mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-24 20:40:13 +00:00
[OPT] Zero-extend unsigned 16-bit integers when bitcasting (#5714)
The folding rule `BitCastScalarOrVector` was incorrectly handling bitcasting to unsigned integers smaller than 32-bits. It was simply copying the entire 32-bit word containing the integer. This conflicts with the requirement in section 2.2.1 of the SPIR-V spec which states that unsigned numeric types with a bit width less than 32-bits must have the high-order bits set to 0. This change include a refactor of the bit extension code to be able to test it better, and to use it in multiple files. Fixes https://github.com/microsoft/DirectXShaderCompiler/issues/6319.
This commit is contained in:
parent
80a1aed219
commit
581279dedd
@ -21,59 +21,6 @@ namespace opt {
|
|||||||
namespace {
|
namespace {
|
||||||
constexpr uint32_t kExtractCompositeIdInIdx = 0;
|
constexpr uint32_t kExtractCompositeIdInIdx = 0;
|
||||||
|
|
||||||
// Returns the value obtained by extracting the |number_of_bits| least
|
|
||||||
// significant bits from |value|, and sign-extending it to 64-bits.
|
|
||||||
uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
|
|
||||||
if (number_of_bits == 64) return value;
|
|
||||||
|
|
||||||
uint64_t mask_for_sign_bit = 1ull << (number_of_bits - 1);
|
|
||||||
uint64_t mask_for_significant_bits = (mask_for_sign_bit << 1) - 1ull;
|
|
||||||
if (value & mask_for_sign_bit) {
|
|
||||||
// Set upper bits to 1
|
|
||||||
value |= ~mask_for_significant_bits;
|
|
||||||
} else {
|
|
||||||
// Clear the upper bits
|
|
||||||
value &= mask_for_significant_bits;
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the value obtained by extracting the |number_of_bits| least
|
|
||||||
// significant bits from |value|, and zero-extending it to 64-bits.
|
|
||||||
uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
|
|
||||||
if (number_of_bits == 64) return value;
|
|
||||||
|
|
||||||
uint64_t mask_for_first_bit_to_clear = 1ull << (number_of_bits);
|
|
||||||
uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1;
|
|
||||||
value &= mask_for_bits_to_keep;
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a constant whose value is `value` and type is `type`. This constant
|
|
||||||
// will be generated by `const_mgr`. The type must be a scalar integer type.
|
|
||||||
const analysis::Constant* GenerateIntegerConstant(
|
|
||||||
const analysis::Integer* integer_type, uint64_t result,
|
|
||||||
analysis::ConstantManager* const_mgr) {
|
|
||||||
assert(integer_type != nullptr);
|
|
||||||
|
|
||||||
std::vector<uint32_t> words;
|
|
||||||
if (integer_type->width() == 64) {
|
|
||||||
// In the 64-bit case, two words are needed to represent the value.
|
|
||||||
words = {static_cast<uint32_t>(result),
|
|
||||||
static_cast<uint32_t>(result >> 32)};
|
|
||||||
} else {
|
|
||||||
// In all other cases, only a single word is needed.
|
|
||||||
assert(integer_type->width() <= 32);
|
|
||||||
if (integer_type->IsSigned()) {
|
|
||||||
result = SignExtendValue(result, integer_type->width());
|
|
||||||
} else {
|
|
||||||
result = ZeroExtendValue(result, integer_type->width());
|
|
||||||
}
|
|
||||||
words = {static_cast<uint32_t>(result)};
|
|
||||||
}
|
|
||||||
return const_mgr->GetConstant(integer_type, words);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a constants with the value NaN of the given type. Only works for
|
// Returns a constants with the value NaN of the given type. Only works for
|
||||||
// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
|
// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
|
||||||
const analysis::Constant* GetNan(const analysis::Type* type,
|
const analysis::Constant* GetNan(const analysis::Type* type,
|
||||||
@ -1730,7 +1677,7 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
|
|||||||
uint64_t result = op(ia, ib);
|
uint64_t result = op(ia, ib);
|
||||||
|
|
||||||
const analysis::Constant* result_constant =
|
const analysis::Constant* result_constant =
|
||||||
GenerateIntegerConstant(integer_type, result, const_mgr);
|
const_mgr->GenerateIntegerConstant(integer_type, result);
|
||||||
return result_constant;
|
return result_constant;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -1745,7 +1692,7 @@ const analysis::Constant* FoldScalarSConvert(
|
|||||||
const analysis::Integer* integer_type = result_type->AsInteger();
|
const analysis::Integer* integer_type = result_type->AsInteger();
|
||||||
assert(integer_type && "The result type of an SConvert");
|
assert(integer_type && "The result type of an SConvert");
|
||||||
int64_t value = a->GetSignExtendedValue();
|
int64_t value = a->GetSignExtendedValue();
|
||||||
return GenerateIntegerConstant(integer_type, value, const_mgr);
|
return const_mgr->GenerateIntegerConstant(integer_type, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// A scalar folding rule that folds OpUConvert.
|
// A scalar folding rule that folds OpUConvert.
|
||||||
@ -1762,8 +1709,8 @@ const analysis::Constant* FoldScalarUConvert(
|
|||||||
// If the operand was an unsigned value with less than 32-bit, it would have
|
// If the operand was an unsigned value with less than 32-bit, it would have
|
||||||
// been sign extended earlier, and we need to clear those bits.
|
// been sign extended earlier, and we need to clear those bits.
|
||||||
auto* operand_type = a->type()->AsInteger();
|
auto* operand_type = a->type()->AsInteger();
|
||||||
value = ZeroExtendValue(value, operand_type->width());
|
value = utils::ClearHighBits(value, 64 - operand_type->width());
|
||||||
return GenerateIntegerConstant(integer_type, value, const_mgr);
|
return const_mgr->GenerateIntegerConstant(integer_type, value);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -525,6 +525,28 @@ uint32_t ConstantManager::GetNullConstId(const Type* type) {
|
|||||||
return GetDefiningInstruction(c)->result_id();
|
return GetDefiningInstruction(c)->result_id();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const Constant* ConstantManager::GenerateIntegerConstant(
|
||||||
|
const analysis::Integer* integer_type, uint64_t result) {
|
||||||
|
assert(integer_type != nullptr);
|
||||||
|
|
||||||
|
std::vector<uint32_t> words;
|
||||||
|
if (integer_type->width() == 64) {
|
||||||
|
// In the 64-bit case, two words are needed to represent the value.
|
||||||
|
words = {static_cast<uint32_t>(result),
|
||||||
|
static_cast<uint32_t>(result >> 32)};
|
||||||
|
} else {
|
||||||
|
// In all other cases, only a single word is needed.
|
||||||
|
assert(integer_type->width() <= 32);
|
||||||
|
if (integer_type->IsSigned()) {
|
||||||
|
result = utils::SignExtendValue(result, integer_type->width());
|
||||||
|
} else {
|
||||||
|
result = utils::ZeroExtendValue(result, integer_type->width());
|
||||||
|
}
|
||||||
|
words = {static_cast<uint32_t>(result)};
|
||||||
|
}
|
||||||
|
return GetConstant(integer_type, words);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<const analysis::Constant*> Constant::GetVectorComponents(
|
std::vector<const analysis::Constant*> Constant::GetVectorComponents(
|
||||||
analysis::ConstantManager* const_mgr) const {
|
analysis::ConstantManager* const_mgr) const {
|
||||||
std::vector<const analysis::Constant*> components;
|
std::vector<const analysis::Constant*> components;
|
||||||
|
@ -671,6 +671,11 @@ class ConstantManager {
|
|||||||
// Returns the id of a OpConstantNull with type of |type|.
|
// Returns the id of a OpConstantNull with type of |type|.
|
||||||
uint32_t GetNullConstId(const Type* type);
|
uint32_t GetNullConstId(const Type* type);
|
||||||
|
|
||||||
|
// Returns a constant whose value is `value` and type is `type`. This constant
|
||||||
|
// will be generated by `const_mgr`. The type must be a scalar integer type.
|
||||||
|
const Constant* GenerateIntegerConstant(const analysis::Integer* integer_type,
|
||||||
|
uint64_t result);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Creates a Constant instance with the given type and a vector of constant
|
// Creates a Constant instance with the given type and a vector of constant
|
||||||
// defining words. Returns a unique pointer to the created Constant instance
|
// defining words. Returns a unique pointer to the created Constant instance
|
||||||
|
@ -247,18 +247,7 @@ utils::SmallVector<uint32_t, 2> EncodeIntegerAsWords(const analysis::Type& type,
|
|||||||
|
|
||||||
// Truncate first_word if the |type| has width less than uint32.
|
// Truncate first_word if the |type| has width less than uint32.
|
||||||
if (bit_width < bits_per_word) {
|
if (bit_width < bits_per_word) {
|
||||||
const uint32_t num_high_bits_to_mask = bits_per_word - bit_width;
|
first_word = utils::SignExtendValue(first_word, bit_width);
|
||||||
const bool is_negative_after_truncation =
|
|
||||||
result_type_signed &&
|
|
||||||
utils::IsBitAtPositionSet(first_word, bit_width - 1);
|
|
||||||
|
|
||||||
if (is_negative_after_truncation) {
|
|
||||||
// Truncate and sign-extend |first_word|. No padding words will be
|
|
||||||
// added and |pad_value| can be left as-is.
|
|
||||||
first_word = utils::SetHighBits(first_word, num_high_bits_to_mask);
|
|
||||||
} else {
|
|
||||||
first_word = utils::ClearHighBits(first_word, num_high_bits_to_mask);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
utils::SmallVector<uint32_t, 2> words = {first_word};
|
utils::SmallVector<uint32_t, 2> words = {first_word};
|
||||||
|
@ -180,8 +180,14 @@ std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
|
|||||||
const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
|
const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
|
||||||
analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
|
analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
|
||||||
const analysis::Type* type) {
|
const analysis::Type* type) {
|
||||||
if (type->AsInteger() || type->AsFloat())
|
const spvtools::opt::analysis::Integer* int_type = type->AsInteger();
|
||||||
return const_mgr->GetConstant(type, words);
|
|
||||||
|
if (int_type && int_type->width() <= 32) {
|
||||||
|
assert(words.size() == 1);
|
||||||
|
return const_mgr->GenerateIntegerConstant(int_type, words[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (int_type || type->AsFloat()) return const_mgr->GetConstant(type, words);
|
||||||
if (const auto* vec_type = type->AsVector())
|
if (const auto* vec_type = type->AsVector())
|
||||||
return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
|
return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -181,6 +181,31 @@ T ClearHighBits(T word, size_t num_bits_to_set) {
|
|||||||
false);
|
false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the value obtained by extracting the |number_of_bits| least
|
||||||
|
// significant bits from |value|, and sign-extending it to 64-bits.
|
||||||
|
template <typename T>
|
||||||
|
T SignExtendValue(T value, uint32_t number_of_bits) {
|
||||||
|
const uint32_t bit_width = sizeof(value) * 8;
|
||||||
|
if (number_of_bits == bit_width) return value;
|
||||||
|
|
||||||
|
bool is_negative = utils::IsBitAtPositionSet(value, number_of_bits - 1);
|
||||||
|
if (is_negative) {
|
||||||
|
value = utils::SetHighBits(value, bit_width - number_of_bits);
|
||||||
|
} else {
|
||||||
|
value = utils::ClearHighBits(value, bit_width - number_of_bits);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the value obtained by extracting the |number_of_bits| least
|
||||||
|
// significant bits from |value|, and zero-extending it to 64-bits.
|
||||||
|
template <typename T>
|
||||||
|
T ZeroExtendValue(T value, uint32_t number_of_bits) {
|
||||||
|
const uint32_t bit_width = sizeof(value) * 8;
|
||||||
|
if (number_of_bits == bit_width) return value;
|
||||||
|
return utils::ClearHighBits(value, bit_width - number_of_bits);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace utils
|
} // namespace utils
|
||||||
} // namespace spvtools
|
} // namespace spvtools
|
||||||
|
|
||||||
|
@ -924,7 +924,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
|
|||||||
"%2 = OpBitcast %ushort %short_0xBC00\n" +
|
"%2 = OpBitcast %ushort %short_0xBC00\n" +
|
||||||
"OpReturn\n" +
|
"OpReturn\n" +
|
||||||
"OpFunctionEnd",
|
"OpFunctionEnd",
|
||||||
2, 0xFFFFBC00),
|
2, 0xBC00),
|
||||||
// Test case 53: Bit-cast half 1 to ushort
|
// Test case 53: Bit-cast half 1 to ushort
|
||||||
InstructionFoldingCase<uint32_t>(
|
InstructionFoldingCase<uint32_t>(
|
||||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
@ -940,7 +940,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
|
|||||||
"%2 = OpBitcast %short %ushort_0xBC00\n" +
|
"%2 = OpBitcast %short %ushort_0xBC00\n" +
|
||||||
"OpReturn\n" +
|
"OpReturn\n" +
|
||||||
"OpFunctionEnd",
|
"OpFunctionEnd",
|
||||||
2, 0xBC00),
|
2, 0xFFFFBC00),
|
||||||
// Test case 55: Bit-cast short 0xBC00 to short
|
// Test case 55: Bit-cast short 0xBC00 to short
|
||||||
InstructionFoldingCase<uint32_t>(
|
InstructionFoldingCase<uint32_t>(
|
||||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
@ -996,7 +996,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
|
|||||||
"%2 = OpBitcast %ubyte %byte_n1\n" +
|
"%2 = OpBitcast %ubyte %byte_n1\n" +
|
||||||
"OpReturn\n" +
|
"OpReturn\n" +
|
||||||
"OpFunctionEnd",
|
"OpFunctionEnd",
|
||||||
2, 0xFFFFFFFF),
|
2, 0xFF),
|
||||||
// Test case 62: Negate 2.
|
// Test case 62: Negate 2.
|
||||||
InstructionFoldingCase<uint32_t>(
|
InstructionFoldingCase<uint32_t>(
|
||||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
@ -188,6 +188,46 @@ TEST(BitUtilsTest, IsBitSetAtPositionAll) {
|
|||||||
EXPECT_TRUE(IsBitAtPositionSet(max_u64, i));
|
EXPECT_TRUE(IsBitAtPositionSet(max_u64, i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ExtendedValueTestCase {
|
||||||
|
uint32_t input;
|
||||||
|
uint32_t bit_width;
|
||||||
|
uint32_t expected_result;
|
||||||
|
};
|
||||||
|
|
||||||
|
using SignExtendedValueTest = ::testing::TestWithParam<ExtendedValueTestCase>;
|
||||||
|
|
||||||
|
TEST_P(SignExtendedValueTest, SignExtendValue) {
|
||||||
|
const auto& tc = GetParam();
|
||||||
|
auto result = SignExtendValue(tc.input, tc.bit_width);
|
||||||
|
EXPECT_EQ(result, tc.expected_result);
|
||||||
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
SignExtendValue, SignExtendedValueTest,
|
||||||
|
::testing::Values(ExtendedValueTestCase{1, 1, 0xFFFFFFFF},
|
||||||
|
ExtendedValueTestCase{1, 2, 0x1},
|
||||||
|
ExtendedValueTestCase{2, 1, 0x0},
|
||||||
|
ExtendedValueTestCase{0x8, 4, 0xFFFFFFF8},
|
||||||
|
ExtendedValueTestCase{0x8765, 16, 0xFFFF8765},
|
||||||
|
ExtendedValueTestCase{0x7765, 16, 0x7765},
|
||||||
|
ExtendedValueTestCase{0xDEADBEEF, 32, 0xDEADBEEF}));
|
||||||
|
|
||||||
|
using ZeroExtendedValueTest = ::testing::TestWithParam<ExtendedValueTestCase>;
|
||||||
|
|
||||||
|
TEST_P(ZeroExtendedValueTest, ZeroExtendValue) {
|
||||||
|
const auto& tc = GetParam();
|
||||||
|
auto result = ZeroExtendValue(tc.input, tc.bit_width);
|
||||||
|
EXPECT_EQ(result, tc.expected_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
ZeroExtendValue, ZeroExtendedValueTest,
|
||||||
|
::testing::Values(ExtendedValueTestCase{1, 1, 0x1},
|
||||||
|
ExtendedValueTestCase{1, 2, 0x1},
|
||||||
|
ExtendedValueTestCase{2, 1, 0x0},
|
||||||
|
ExtendedValueTestCase{0x8, 4, 0x8},
|
||||||
|
ExtendedValueTestCase{0xFF8765, 16, 0x8765},
|
||||||
|
ExtendedValueTestCase{0xDEADBEEF, 32, 0xDEADBEEF}));
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace utils
|
} // namespace utils
|
||||||
} // namespace spvtools
|
} // namespace spvtools
|
||||||
|
Loading…
Reference in New Issue
Block a user