diff --git a/source/opt/constants.h b/source/opt/constants.h index c039ae087..588ca3e76 100644 --- a/source/opt/constants.h +++ b/source/opt/constants.h @@ -163,6 +163,21 @@ class ScalarConstant : public Constant { return is_zero; } + uint32_t GetU32BitValue() const { + // Relies on unsigned values smaller than 32-bit being zero extended. See + // section 2.2.1 of the SPIR-V spec. + assert(words().size() == 1); + return words()[0]; + } + + uint64_t GetU64BitValue() const { + // Relies on unsigned values smaller than 64-bit being zero extended. See + // section 2.2.1 of the SPIR-V spec. + assert(words().size() == 2); + return static_cast(words()[1]) << 32 | + static_cast(words()[0]); + } + protected: ScalarConstant(const Type* ty, const std::vector& w) : Constant(ty), words_(w) {} @@ -189,13 +204,6 @@ class IntConstant : public ScalarConstant { return words()[0]; } - uint32_t GetU32BitValue() const { - // Relies on unsigned values smaller than 32-bit being zero extended. See - // section 2.2.1 of the SPIR-V spec. - assert(words().size() == 1); - return words()[0]; - } - int64_t GetS64BitValue() const { // Relies on unsigned values smaller than 64-bit being sign extended. See // section 2.2.1 of the SPIR-V spec. @@ -204,14 +212,6 @@ class IntConstant : public ScalarConstant { static_cast(words()[0]); } - uint64_t GetU64BitValue() const { - // Relies on unsigned values smaller than 64-bit being zero extended. See - // section 2.2.1 of the SPIR-V spec. - assert(words().size() == 2); - return static_cast(words()[1]) << 32 | - static_cast(words()[0]); - } - // Make a copy of this IntConstant instance. std::unique_ptr CopyIntConstant() const { return MakeUnique(type_->AsInteger(), words_); diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 3f10bd009..3d803addc 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -136,25 +136,28 @@ std::vector GetWordsFromScalarIntConstant( const analysis::IntConstant* c) { assert(c != nullptr); uint32_t width = c->type()->AsInteger()->width(); - assert(width == 32 || width == 64); + assert(width == 8 || width == 16 || width == 32 || width == 64); if (width == 64) { uint64_t uval = static_cast(c->GetU64()); return ExtractInts(uval); } - return {c->GetU32()}; + // Section 2.2.1 of the SPIR-V spec guarantees that all integer types + // smaller than 32-bits are automatically zero or sign extended to 32-bits. + return {c->GetU32BitValue()}; } std::vector GetWordsFromScalarFloatConstant( const analysis::FloatConstant* c) { assert(c != nullptr); uint32_t width = c->type()->AsFloat()->width(); - assert(width == 32 || width == 64); + assert(width == 16 || width == 32 || width == 64); if (width == 64) { utils::FloatProxy result(c->GetDouble()); return result.GetWords(); } - utils::FloatProxy result(c->GetFloat()); - return result.GetWords(); + // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types + // smaller than 32-bits are automatically zero extended to 32-bits. + return {c->GetU32BitValue()}; } std::vector GetWordsFromNumericScalarOrVectorConstant( diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index a034e959a..b32480326 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -92,8 +92,13 @@ TEST_P(IntegerInstructionFoldingTest, Case) { inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); EXPECT_EQ(inst->opcode(), SpvOpConstant); analysis::ConstantManager* const_mrg = context->get_constant_mgr(); - const analysis::IntConstant* result = - const_mrg->GetConstantFromInst(inst)->AsIntConstant(); + const analysis::Constant* constant = const_mrg->GetConstantFromInst(inst); + // We expect to see either integer types or 16-bit float types here. + EXPECT_TRUE((constant->AsIntConstant() != nullptr) || + ((constant->AsFloatConstant() != nullptr) && + (constant->type()->AsFloat()->width() == 16))); + const analysis::ScalarConstant* result = + const_mrg->GetConstantFromInst(inst)->AsScalarConstant(); EXPECT_NE(result, nullptr); if (result != nullptr) { EXPECT_EQ(result->GetU32BitValue(), tc.expected_result); @@ -115,6 +120,7 @@ const std::string& Header() { static const std::string header = R"(OpCapability Shader OpCapability Float16 OpCapability Float64 +OpCapability Int8 OpCapability Int16 OpCapability Int64 %1 = OpExtInstImport "GLSL.std.450" @@ -134,6 +140,9 @@ OpName %main "main" %false = OpConstantFalse %bool %bool_null = OpConstantNull %bool %short = OpTypeInt 16 1 +%ushort = OpTypeInt 16 0 +%byte = OpTypeInt 8 1 +%ubyte = OpTypeInt 8 0 %int = OpTypeInt 32 1 %long = OpTypeInt 64 1 %uint = OpTypeInt 32 0 @@ -169,6 +178,8 @@ OpName %main "main" %short_0 = OpConstant %short 0 %short_2 = OpConstant %short 2 %short_3 = OpConstant %short 3 +%ubyte_1 = OpConstant %ubyte 1 +%byte_n1 = OpConstant %byte -1 %100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps. %103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps. %int_0 = OpConstant %int 0 @@ -302,6 +313,8 @@ OpName %main "main" %int_0xC05FD666 = OpConstant %int 0xC05FD666 %int_0x66666666 = OpConstant %int 0x66666666 %v4int_0x3FF00000_0x00000000_0xC05FD666_0x66666666 = OpConstantComposite %v4int %int_0x00000000 %int_0x3FF00000 %int_0x66666666 %int_0xC05FD666 +%ushort_0xBC00 = OpConstant %ushort 0xBC00 +%short_0xBC00 = OpConstant %short 0xBC00 )"; return header; @@ -776,7 +789,95 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest, "%2 = OpBitcast %uint %float_1\n" + "OpReturn\n" + "OpFunctionEnd", - 2, static_cast(0x3f800000)) + 2, static_cast(0x3f800000)), + // Test case 49: Bit-cast ushort 0xBC00 to ushort + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %ushort %ushort_0xBC00\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0xBC00), + // Test case 50: Bit-cast short 0xBC00 to ushort + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %ushort %short_0xBC00\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0xFFFFBC00), + // Test case 51: Bit-cast half 1 to ushort + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %ushort %half_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0x3C00), + // Test case 52: Bit-cast ushort 0xBC00 to short + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %short %ushort_0xBC00\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0xBC00), + // Test case 53: Bit-cast short 0xBC00 to short + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %short %short_0xBC00\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0xFFFFBC00), + // Test case 54: Bit-cast half 1 to short + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %short %half_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0x3C00), + // Test case 55: Bit-cast ushort 0xBC00 to half + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %half %ushort_0xBC00\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0xBC00), + // Test case 56: Bit-cast short 0xBC00 to half + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %half %short_0xBC00\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0xFFFFBC00), + // Test case 57: Bit-cast half 1 to half + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %half %half_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0x3C00), + // Test case 58: Bit-cast ubyte 1 to byte + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %byte %ubyte_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 1), + // Test case 59: Bit-cast byte -1 to ubyte + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %ubyte %byte_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0xFFFFFFFF) )); // clang-format on