From 581279dedd59d8353322fc2d61be07ccdcad0f13 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 19 Jun 2024 13:17:05 -0400 Subject: [PATCH] [OPT] Zero-extend unsigned 16-bit integers when bitcasting (#5714) The folding rule `BitCastScalarOrVector` was incorrectly handling bitcasting to unsigned integers smaller than 32-bits. It was simply copying the entire 32-bit word containing the integer. This conflicts with the requirement in section 2.2.1 of the SPIR-V spec which states that unsigned numeric types with a bit width less than 32-bits must have the high-order bits set to 0. This change include a refactor of the bit extension code to be able to test it better, and to use it in multiple files. Fixes https://github.com/microsoft/DirectXShaderCompiler/issues/6319. --- source/opt/const_folding_rules.cpp | 61 ++----------------- source/opt/constants.cpp | 22 +++++++ source/opt/constants.h | 5 ++ ...ld_spec_constant_op_and_composite_pass.cpp | 13 +--- source/opt/folding_rules.cpp | 10 ++- source/util/bitutils.h | 25 ++++++++ test/opt/fold_test.cpp | 6 +- test/util/bitutils_test.cpp | 40 ++++++++++++ 8 files changed, 108 insertions(+), 74 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 17900af24..a5d4cbe75 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -21,59 +21,6 @@ namespace opt { namespace { constexpr uint32_t kExtractCompositeIdInIdx = 0; -// Returns the value obtained by extracting the |number_of_bits| least -// significant bits from |value|, and sign-extending it to 64-bits. -uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) { - if (number_of_bits == 64) return value; - - uint64_t mask_for_sign_bit = 1ull << (number_of_bits - 1); - uint64_t mask_for_significant_bits = (mask_for_sign_bit << 1) - 1ull; - if (value & mask_for_sign_bit) { - // Set upper bits to 1 - value |= ~mask_for_significant_bits; - } else { - // Clear the upper bits - value &= mask_for_significant_bits; - } - return value; -} - -// Returns the value obtained by extracting the |number_of_bits| least -// significant bits from |value|, and zero-extending it to 64-bits. -uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) { - if (number_of_bits == 64) return value; - - uint64_t mask_for_first_bit_to_clear = 1ull << (number_of_bits); - uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1; - value &= mask_for_bits_to_keep; - return value; -} - -// Returns a constant whose value is `value` and type is `type`. This constant -// will be generated by `const_mgr`. The type must be a scalar integer type. -const analysis::Constant* GenerateIntegerConstant( - const analysis::Integer* integer_type, uint64_t result, - analysis::ConstantManager* const_mgr) { - assert(integer_type != nullptr); - - std::vector words; - if (integer_type->width() == 64) { - // In the 64-bit case, two words are needed to represent the value. - words = {static_cast(result), - static_cast(result >> 32)}; - } else { - // In all other cases, only a single word is needed. - assert(integer_type->width() <= 32); - if (integer_type->IsSigned()) { - result = SignExtendValue(result, integer_type->width()); - } else { - result = ZeroExtendValue(result, integer_type->width()); - } - words = {static_cast(result)}; - } - return const_mgr->GetConstant(integer_type, words); -} - // Returns a constants with the value NaN of the given type. Only works for // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. const analysis::Constant* GetNan(const analysis::Type* type, @@ -1730,7 +1677,7 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, uint64_t result = op(ia, ib); const analysis::Constant* result_constant = - GenerateIntegerConstant(integer_type, result, const_mgr); + const_mgr->GenerateIntegerConstant(integer_type, result); return result_constant; }; } @@ -1745,7 +1692,7 @@ const analysis::Constant* FoldScalarSConvert( const analysis::Integer* integer_type = result_type->AsInteger(); assert(integer_type && "The result type of an SConvert"); int64_t value = a->GetSignExtendedValue(); - return GenerateIntegerConstant(integer_type, value, const_mgr); + return const_mgr->GenerateIntegerConstant(integer_type, value); } // A scalar folding rule that folds OpUConvert. @@ -1762,8 +1709,8 @@ const analysis::Constant* FoldScalarUConvert( // If the operand was an unsigned value with less than 32-bit, it would have // been sign extended earlier, and we need to clear those bits. auto* operand_type = a->type()->AsInteger(); - value = ZeroExtendValue(value, operand_type->width()); - return GenerateIntegerConstant(integer_type, value, const_mgr); + value = utils::ClearHighBits(value, 64 - operand_type->width()); + return const_mgr->GenerateIntegerConstant(integer_type, value); } } // namespace diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp index 6eebbb572..7dc02deaa 100644 --- a/source/opt/constants.cpp +++ b/source/opt/constants.cpp @@ -525,6 +525,28 @@ uint32_t ConstantManager::GetNullConstId(const Type* type) { return GetDefiningInstruction(c)->result_id(); } +const Constant* ConstantManager::GenerateIntegerConstant( + const analysis::Integer* integer_type, uint64_t result) { + assert(integer_type != nullptr); + + std::vector words; + if (integer_type->width() == 64) { + // In the 64-bit case, two words are needed to represent the value. + words = {static_cast(result), + static_cast(result >> 32)}; + } else { + // In all other cases, only a single word is needed. + assert(integer_type->width() <= 32); + if (integer_type->IsSigned()) { + result = utils::SignExtendValue(result, integer_type->width()); + } else { + result = utils::ZeroExtendValue(result, integer_type->width()); + } + words = {static_cast(result)}; + } + return GetConstant(integer_type, words); +} + std::vector Constant::GetVectorComponents( analysis::ConstantManager* const_mgr) const { std::vector components; diff --git a/source/opt/constants.h b/source/opt/constants.h index ae8dc6259..534afa6f5 100644 --- a/source/opt/constants.h +++ b/source/opt/constants.h @@ -671,6 +671,11 @@ class ConstantManager { // Returns the id of a OpConstantNull with type of |type|. uint32_t GetNullConstId(const Type* type); + // Returns a constant whose value is `value` and type is `type`. This constant + // will be generated by `const_mgr`. The type must be a scalar integer type. + const Constant* GenerateIntegerConstant(const analysis::Integer* integer_type, + uint64_t result); + private: // Creates a Constant instance with the given type and a vector of constant // defining words. Returns a unique pointer to the created Constant instance diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp index c568027d2..ddfe59f75 100644 --- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp +++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -247,18 +247,7 @@ utils::SmallVector EncodeIntegerAsWords(const analysis::Type& type, // Truncate first_word if the |type| has width less than uint32. if (bit_width < bits_per_word) { - const uint32_t num_high_bits_to_mask = bits_per_word - bit_width; - const bool is_negative_after_truncation = - result_type_signed && - utils::IsBitAtPositionSet(first_word, bit_width - 1); - - if (is_negative_after_truncation) { - // Truncate and sign-extend |first_word|. No padding words will be - // added and |pad_value| can be left as-is. - first_word = utils::SetHighBits(first_word, num_high_bits_to_mask); - } else { - first_word = utils::ClearHighBits(first_word, num_high_bits_to_mask); - } + first_word = utils::SignExtendValue(first_word, bit_width); } utils::SmallVector words = {first_word}; diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 24979671f..6def9c47f 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -180,8 +180,14 @@ std::vector GetWordsFromNumericScalarOrVectorConstant( const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant( analysis::ConstantManager* const_mgr, const std::vector& words, const analysis::Type* type) { - if (type->AsInteger() || type->AsFloat()) - return const_mgr->GetConstant(type, words); + const spvtools::opt::analysis::Integer* int_type = type->AsInteger(); + + if (int_type && int_type->width() <= 32) { + assert(words.size() == 1); + return const_mgr->GenerateIntegerConstant(int_type, words[0]); + } + + if (int_type || type->AsFloat()) return const_mgr->GetConstant(type, words); if (const auto* vec_type = type->AsVector()) return const_mgr->GetNumericVectorConstantWithWords(vec_type, words); return nullptr; diff --git a/source/util/bitutils.h b/source/util/bitutils.h index 9ced2f962..a121dc356 100644 --- a/source/util/bitutils.h +++ b/source/util/bitutils.h @@ -181,6 +181,31 @@ T ClearHighBits(T word, size_t num_bits_to_set) { false); } +// Returns the value obtained by extracting the |number_of_bits| least +// significant bits from |value|, and sign-extending it to 64-bits. +template +T SignExtendValue(T value, uint32_t number_of_bits) { + const uint32_t bit_width = sizeof(value) * 8; + if (number_of_bits == bit_width) return value; + + bool is_negative = utils::IsBitAtPositionSet(value, number_of_bits - 1); + if (is_negative) { + value = utils::SetHighBits(value, bit_width - number_of_bits); + } else { + value = utils::ClearHighBits(value, bit_width - number_of_bits); + } + return value; +} + +// Returns the value obtained by extracting the |number_of_bits| least +// significant bits from |value|, and zero-extending it to 64-bits. +template +T ZeroExtendValue(T value, uint32_t number_of_bits) { + const uint32_t bit_width = sizeof(value) * 8; + if (number_of_bits == bit_width) return value; + return utils::ClearHighBits(value, bit_width - number_of_bits); +} + } // namespace utils } // namespace spvtools diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 35828ab22..cb14b94fc 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -924,7 +924,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest, "%2 = OpBitcast %ushort %short_0xBC00\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 0xFFFFBC00), + 2, 0xBC00), // Test case 53: Bit-cast half 1 to ushort InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + @@ -940,7 +940,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest, "%2 = OpBitcast %short %ushort_0xBC00\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 0xBC00), + 2, 0xFFFFBC00), // Test case 55: Bit-cast short 0xBC00 to short InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + @@ -996,7 +996,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest, "%2 = OpBitcast %ubyte %byte_n1\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 0xFFFFFFFF), + 2, 0xFF), // Test case 62: Negate 2. InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + diff --git a/test/util/bitutils_test.cpp b/test/util/bitutils_test.cpp index 3be7ed269..aea789766 100644 --- a/test/util/bitutils_test.cpp +++ b/test/util/bitutils_test.cpp @@ -188,6 +188,46 @@ TEST(BitUtilsTest, IsBitSetAtPositionAll) { EXPECT_TRUE(IsBitAtPositionSet(max_u64, i)); } } + +struct ExtendedValueTestCase { + uint32_t input; + uint32_t bit_width; + uint32_t expected_result; +}; + +using SignExtendedValueTest = ::testing::TestWithParam; + +TEST_P(SignExtendedValueTest, SignExtendValue) { + const auto& tc = GetParam(); + auto result = SignExtendValue(tc.input, tc.bit_width); + EXPECT_EQ(result, tc.expected_result); +} +INSTANTIATE_TEST_SUITE_P( + SignExtendValue, SignExtendedValueTest, + ::testing::Values(ExtendedValueTestCase{1, 1, 0xFFFFFFFF}, + ExtendedValueTestCase{1, 2, 0x1}, + ExtendedValueTestCase{2, 1, 0x0}, + ExtendedValueTestCase{0x8, 4, 0xFFFFFFF8}, + ExtendedValueTestCase{0x8765, 16, 0xFFFF8765}, + ExtendedValueTestCase{0x7765, 16, 0x7765}, + ExtendedValueTestCase{0xDEADBEEF, 32, 0xDEADBEEF})); + +using ZeroExtendedValueTest = ::testing::TestWithParam; + +TEST_P(ZeroExtendedValueTest, ZeroExtendValue) { + const auto& tc = GetParam(); + auto result = ZeroExtendValue(tc.input, tc.bit_width); + EXPECT_EQ(result, tc.expected_result); +} + +INSTANTIATE_TEST_SUITE_P( + ZeroExtendValue, ZeroExtendedValueTest, + ::testing::Values(ExtendedValueTestCase{1, 1, 0x1}, + ExtendedValueTestCase{1, 2, 0x1}, + ExtendedValueTestCase{2, 1, 0x0}, + ExtendedValueTestCase{0x8, 4, 0x8}, + ExtendedValueTestCase{0xFF8765, 16, 0x8765}, + ExtendedValueTestCase{0xDEADBEEF, 32, 0xDEADBEEF})); } // namespace } // namespace utils } // namespace spvtools