Support Narrow Types in BitCast Folding Rule (#4941)

* Support Narrow Types in BitCast Folding Rule

This change adds support for narrow types in the BitCastScalarOrVector
folding rule. According to Section 2.2.1 of the SPIR-V spec, types that
are narrower than 32 bits are automatically either sign extended, or
zero extended depending on the type. With that guaranteed, we should
be able to use the first 32-bit word of any narrow type for the folding
logic without performing any special conversions.

In order to reduce code duplication, this change moves the
GetU32BitValue and GetU64BitValue functions from IntConstant to
ScalarConstant. Without this move, we would have needed an identical
version of GetU32BitValue on FloatConstant.

* Add Tests for 16-bit BitCast Folding

This change adds several new test cases to the
IntegerInstructionFoldingTest which trigger the 16-bit BitCast logic.
The logic for half types was also added to the integer case since we
can't easily validate half float types in C++ code. It's easier to
validate them as unsigned integers instead. Pllus this also allows us
to verify the SPIR-V constant sign extension logic too.

* Add 8-Bit Folding Test Cases

This change adds a couple more test cases to the integer instruction
folding test suite in order to ensure that the BitCast logic also
works correctly with the Int8 shader capability.
This commit is contained in:
gmitrano-unity 2022-10-06 10:35:18 -04:00 committed by GitHub
parent a6e6454ef2
commit 1cecf91701
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 127 additions and 23 deletions

View File

@ -163,6 +163,21 @@ class ScalarConstant : public Constant {
return is_zero; return is_zero;
} }
uint32_t GetU32BitValue() const {
// Relies on unsigned values smaller than 32-bit being zero extended. See
// section 2.2.1 of the SPIR-V spec.
assert(words().size() == 1);
return words()[0];
}
uint64_t GetU64BitValue() const {
// Relies on unsigned values smaller than 64-bit being zero extended. See
// section 2.2.1 of the SPIR-V spec.
assert(words().size() == 2);
return static_cast<uint64_t>(words()[1]) << 32 |
static_cast<uint64_t>(words()[0]);
}
protected: protected:
ScalarConstant(const Type* ty, const std::vector<uint32_t>& w) ScalarConstant(const Type* ty, const std::vector<uint32_t>& w)
: Constant(ty), words_(w) {} : Constant(ty), words_(w) {}
@ -189,13 +204,6 @@ class IntConstant : public ScalarConstant {
return words()[0]; return words()[0];
} }
uint32_t GetU32BitValue() const {
// Relies on unsigned values smaller than 32-bit being zero extended. See
// section 2.2.1 of the SPIR-V spec.
assert(words().size() == 1);
return words()[0];
}
int64_t GetS64BitValue() const { int64_t GetS64BitValue() const {
// Relies on unsigned values smaller than 64-bit being sign extended. See // Relies on unsigned values smaller than 64-bit being sign extended. See
// section 2.2.1 of the SPIR-V spec. // section 2.2.1 of the SPIR-V spec.
@ -204,14 +212,6 @@ class IntConstant : public ScalarConstant {
static_cast<uint64_t>(words()[0]); static_cast<uint64_t>(words()[0]);
} }
uint64_t GetU64BitValue() const {
// Relies on unsigned values smaller than 64-bit being zero extended. See
// section 2.2.1 of the SPIR-V spec.
assert(words().size() == 2);
return static_cast<uint64_t>(words()[1]) << 32 |
static_cast<uint64_t>(words()[0]);
}
// Make a copy of this IntConstant instance. // Make a copy of this IntConstant instance.
std::unique_ptr<IntConstant> CopyIntConstant() const { std::unique_ptr<IntConstant> CopyIntConstant() const {
return MakeUnique<IntConstant>(type_->AsInteger(), words_); return MakeUnique<IntConstant>(type_->AsInteger(), words_);

View File

@ -136,25 +136,28 @@ std::vector<uint32_t> GetWordsFromScalarIntConstant(
const analysis::IntConstant* c) { const analysis::IntConstant* c) {
assert(c != nullptr); assert(c != nullptr);
uint32_t width = c->type()->AsInteger()->width(); uint32_t width = c->type()->AsInteger()->width();
assert(width == 32 || width == 64); assert(width == 8 || width == 16 || width == 32 || width == 64);
if (width == 64) { if (width == 64) {
uint64_t uval = static_cast<uint64_t>(c->GetU64()); uint64_t uval = static_cast<uint64_t>(c->GetU64());
return ExtractInts(uval); return ExtractInts(uval);
} }
return {c->GetU32()}; // Section 2.2.1 of the SPIR-V spec guarantees that all integer types
// smaller than 32-bits are automatically zero or sign extended to 32-bits.
return {c->GetU32BitValue()};
} }
std::vector<uint32_t> GetWordsFromScalarFloatConstant( std::vector<uint32_t> GetWordsFromScalarFloatConstant(
const analysis::FloatConstant* c) { const analysis::FloatConstant* c) {
assert(c != nullptr); assert(c != nullptr);
uint32_t width = c->type()->AsFloat()->width(); uint32_t width = c->type()->AsFloat()->width();
assert(width == 32 || width == 64); assert(width == 16 || width == 32 || width == 64);
if (width == 64) { if (width == 64) {
utils::FloatProxy<double> result(c->GetDouble()); utils::FloatProxy<double> result(c->GetDouble());
return result.GetWords(); return result.GetWords();
} }
utils::FloatProxy<float> result(c->GetFloat()); // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types
return result.GetWords(); // smaller than 32-bits are automatically zero extended to 32-bits.
return {c->GetU32BitValue()};
} }
std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant( std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(

View File

@ -92,8 +92,13 @@ TEST_P(IntegerInstructionFoldingTest, Case) {
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
EXPECT_EQ(inst->opcode(), SpvOpConstant); EXPECT_EQ(inst->opcode(), SpvOpConstant);
analysis::ConstantManager* const_mrg = context->get_constant_mgr(); analysis::ConstantManager* const_mrg = context->get_constant_mgr();
const analysis::IntConstant* result = const analysis::Constant* constant = const_mrg->GetConstantFromInst(inst);
const_mrg->GetConstantFromInst(inst)->AsIntConstant(); // We expect to see either integer types or 16-bit float types here.
EXPECT_TRUE((constant->AsIntConstant() != nullptr) ||
((constant->AsFloatConstant() != nullptr) &&
(constant->type()->AsFloat()->width() == 16)));
const analysis::ScalarConstant* result =
const_mrg->GetConstantFromInst(inst)->AsScalarConstant();
EXPECT_NE(result, nullptr); EXPECT_NE(result, nullptr);
if (result != nullptr) { if (result != nullptr) {
EXPECT_EQ(result->GetU32BitValue(), tc.expected_result); EXPECT_EQ(result->GetU32BitValue(), tc.expected_result);
@ -115,6 +120,7 @@ const std::string& Header() {
static const std::string header = R"(OpCapability Shader static const std::string header = R"(OpCapability Shader
OpCapability Float16 OpCapability Float16
OpCapability Float64 OpCapability Float64
OpCapability Int8
OpCapability Int16 OpCapability Int16
OpCapability Int64 OpCapability Int64
%1 = OpExtInstImport "GLSL.std.450" %1 = OpExtInstImport "GLSL.std.450"
@ -134,6 +140,9 @@ OpName %main "main"
%false = OpConstantFalse %bool %false = OpConstantFalse %bool
%bool_null = OpConstantNull %bool %bool_null = OpConstantNull %bool
%short = OpTypeInt 16 1 %short = OpTypeInt 16 1
%ushort = OpTypeInt 16 0
%byte = OpTypeInt 8 1
%ubyte = OpTypeInt 8 0
%int = OpTypeInt 32 1 %int = OpTypeInt 32 1
%long = OpTypeInt 64 1 %long = OpTypeInt 64 1
%uint = OpTypeInt 32 0 %uint = OpTypeInt 32 0
@ -169,6 +178,8 @@ OpName %main "main"
%short_0 = OpConstant %short 0 %short_0 = OpConstant %short 0
%short_2 = OpConstant %short 2 %short_2 = OpConstant %short 2
%short_3 = OpConstant %short 3 %short_3 = OpConstant %short 3
%ubyte_1 = OpConstant %ubyte 1
%byte_n1 = OpConstant %byte -1
%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps. %100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
%103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps. %103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps.
%int_0 = OpConstant %int 0 %int_0 = OpConstant %int 0
@ -302,6 +313,8 @@ OpName %main "main"
%int_0xC05FD666 = OpConstant %int 0xC05FD666 %int_0xC05FD666 = OpConstant %int 0xC05FD666
%int_0x66666666 = OpConstant %int 0x66666666 %int_0x66666666 = OpConstant %int 0x66666666
%v4int_0x3FF00000_0x00000000_0xC05FD666_0x66666666 = OpConstantComposite %v4int %int_0x00000000 %int_0x3FF00000 %int_0x66666666 %int_0xC05FD666 %v4int_0x3FF00000_0x00000000_0xC05FD666_0x66666666 = OpConstantComposite %v4int %int_0x00000000 %int_0x3FF00000 %int_0x66666666 %int_0xC05FD666
%ushort_0xBC00 = OpConstant %ushort 0xBC00
%short_0xBC00 = OpConstant %short 0xBC00
)"; )";
return header; return header;
@ -776,7 +789,95 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
"%2 = OpBitcast %uint %float_1\n" + "%2 = OpBitcast %uint %float_1\n" +
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd", "OpFunctionEnd",
2, static_cast<uint32_t>(0x3f800000)) 2, static_cast<uint32_t>(0x3f800000)),
// Test case 49: Bit-cast ushort 0xBC00 to ushort
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %ushort %ushort_0xBC00\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0xBC00),
// Test case 50: Bit-cast short 0xBC00 to ushort
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %ushort %short_0xBC00\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0xFFFFBC00),
// Test case 51: Bit-cast half 1 to ushort
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %ushort %half_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0x3C00),
// Test case 52: Bit-cast ushort 0xBC00 to short
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %short %ushort_0xBC00\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0xBC00),
// Test case 53: Bit-cast short 0xBC00 to short
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %short %short_0xBC00\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0xFFFFBC00),
// Test case 54: Bit-cast half 1 to short
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %short %half_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0x3C00),
// Test case 55: Bit-cast ushort 0xBC00 to half
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %half %ushort_0xBC00\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0xBC00),
// Test case 56: Bit-cast short 0xBC00 to half
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %half %short_0xBC00\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0xFFFFBC00),
// Test case 57: Bit-cast half 1 to half
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %half %half_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0x3C00),
// Test case 58: Bit-cast ubyte 1 to byte
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %byte %ubyte_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1),
// Test case 59: Bit-cast byte -1 to ubyte
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpBitcast %ubyte %byte_n1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0xFFFFFFFF)
)); ));
// clang-format on // clang-format on