mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-18 19:20:05 +00:00
Handles more cases of redundant selects
* Handles OpConstantNull and vector types * vector selects (except against a null) are converted to vector shuffles * Added tests
This commit is contained in:
parent
a7cec7843c
commit
52bceb3569
@ -1549,25 +1549,65 @@ FoldingRule RedundantSelect() {
|
|||||||
assert(inst->NumInOperands() == 3);
|
assert(inst->NumInOperands() == 3);
|
||||||
assert(constants.size() == 3);
|
assert(constants.size() == 3);
|
||||||
|
|
||||||
const analysis::BoolConstant* bc =
|
|
||||||
constants[0] ? constants[0]->AsBoolConstant() : nullptr;
|
|
||||||
uint32_t true_id = inst->GetSingleWordInOperand(1);
|
uint32_t true_id = inst->GetSingleWordInOperand(1);
|
||||||
uint32_t false_id = inst->GetSingleWordInOperand(2);
|
uint32_t false_id = inst->GetSingleWordInOperand(2);
|
||||||
|
|
||||||
if (bc) {
|
if (true_id == false_id) {
|
||||||
// Select condition is constant, result is known
|
|
||||||
inst->SetOpcode(SpvOpCopyObject);
|
|
||||||
inst->SetInOperands(
|
|
||||||
{{SPV_OPERAND_TYPE_ID, {bc->value() ? true_id : false_id}}});
|
|
||||||
return true;
|
|
||||||
} else if (true_id == false_id) {
|
|
||||||
// Both results are the same, condition doesn't matter
|
// Both results are the same, condition doesn't matter
|
||||||
inst->SetOpcode(SpvOpCopyObject);
|
inst->SetOpcode(SpvOpCopyObject);
|
||||||
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
|
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else if (constants[0]) {
|
||||||
return false;
|
const analysis::Type* type = constants[0]->type();
|
||||||
|
if (type->AsBool()) {
|
||||||
|
// Scalar constant value, select the corresponding value.
|
||||||
|
inst->SetOpcode(SpvOpCopyObject);
|
||||||
|
if (constants[0]->AsNullConstant() ||
|
||||||
|
!constants[0]->AsBoolConstant()->value()) {
|
||||||
|
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
|
||||||
|
} else {
|
||||||
|
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
assert(type->AsVector());
|
||||||
|
if (constants[0]->AsNullConstant()) {
|
||||||
|
// All values come from false id.
|
||||||
|
inst->SetOpcode(SpvOpCopyObject);
|
||||||
|
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// Convert to a vector shuffle.
|
||||||
|
std::vector<ir::Operand> ops;
|
||||||
|
ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
|
||||||
|
ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
|
||||||
|
const analysis::VectorConstant* vector_const =
|
||||||
|
constants[0]->AsVectorConstant();
|
||||||
|
uint32_t size =
|
||||||
|
static_cast<uint32_t>(vector_const->GetComponents().size());
|
||||||
|
for (uint32_t i = 0; i != size; ++i) {
|
||||||
|
const analysis::Constant* component =
|
||||||
|
vector_const->GetComponents()[i];
|
||||||
|
if (component->AsNullConstant() ||
|
||||||
|
!component->AsBoolConstant()->value()) {
|
||||||
|
// Selecting from the false vector which is the second input
|
||||||
|
// vector to the shuffle. Offset the index by |size|.
|
||||||
|
ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
|
||||||
|
} else {
|
||||||
|
// Selecting from true vector which is the first input vector to
|
||||||
|
// the shuffle.
|
||||||
|
ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inst->SetOpcode(SpvOpVectorShuffle);
|
||||||
|
inst->SetInOperands(std::move(ops));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,6 +130,7 @@ OpName %main "main"
|
|||||||
%101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps.
|
%101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps.
|
||||||
%true = OpConstantTrue %bool
|
%true = OpConstantTrue %bool
|
||||||
%false = OpConstantFalse %bool
|
%false = OpConstantFalse %bool
|
||||||
|
%bool_null = OpConstantNull %bool
|
||||||
%short = OpTypeInt 16 1
|
%short = OpTypeInt 16 1
|
||||||
%int = OpTypeInt 32 1
|
%int = OpTypeInt 32 1
|
||||||
%long = OpTypeInt 64 1
|
%long = OpTypeInt 64 1
|
||||||
@ -139,6 +140,7 @@ OpName %main "main"
|
|||||||
%v4float = OpTypeVector %float 4
|
%v4float = OpTypeVector %float 4
|
||||||
%v4double = OpTypeVector %double 4
|
%v4double = OpTypeVector %double 4
|
||||||
%v2float = OpTypeVector %float 2
|
%v2float = OpTypeVector %float 2
|
||||||
|
%v2bool = OpTypeVector %bool 2
|
||||||
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
|
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
|
||||||
%_ptr_int = OpTypePointer Function %int
|
%_ptr_int = OpTypePointer Function %int
|
||||||
%_ptr_uint = OpTypePointer Function %uint
|
%_ptr_uint = OpTypePointer Function %uint
|
||||||
@ -176,6 +178,9 @@ OpName %main "main"
|
|||||||
%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3
|
%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3
|
||||||
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
|
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
|
||||||
%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
|
%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
|
||||||
|
%v2bool_null = OpConstantNull %v2bool
|
||||||
|
%v2bool_true_false = OpConstantComposite %v2bool %true %false
|
||||||
|
%v2bool_false_true = OpConstantComposite %v2bool %false %true
|
||||||
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
|
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
|
||||||
%v2int_null = OpConstantNull %v2int
|
%v2int_null = OpConstantNull %v2int
|
||||||
%102 = OpConstantComposite %v2int %103 %103
|
%102 = OpConstantComposite %v2int %103 %103
|
||||||
@ -2336,40 +2341,6 @@ INSTANTIATE_TEST_CASE_P(PhiFoldingTest, GeneralInstructionFoldingTest,
|
|||||||
2, 0)
|
2, 0)
|
||||||
));
|
));
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(SelectFoldingTest, GeneralInstructionFoldingTest,
|
|
||||||
::testing::Values(
|
|
||||||
// Test case 0: Fold select with the same values for both sides
|
|
||||||
InstructionFoldingCase<uint32_t>(
|
|
||||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
|
||||||
"%main_lab = OpLabel\n" +
|
|
||||||
"%n = OpVariable %_ptr_bool Function\n" +
|
|
||||||
"%load = OpLoad %bool %n\n" +
|
|
||||||
"%2 = OpSelect %int %load %100 %100\n" +
|
|
||||||
"OpReturn\n" +
|
|
||||||
"OpFunctionEnd",
|
|
||||||
2, INT_0_ID),
|
|
||||||
// Test case 1: Fold select true to left side
|
|
||||||
InstructionFoldingCase<uint32_t>(
|
|
||||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
|
||||||
"%main_lab = OpLabel\n" +
|
|
||||||
"%n = OpVariable %_ptr_int Function\n" +
|
|
||||||
"%load = OpLoad %bool %n\n" +
|
|
||||||
"%2 = OpSelect %int %true %100 %n\n" +
|
|
||||||
"OpReturn\n" +
|
|
||||||
"OpFunctionEnd",
|
|
||||||
2, INT_0_ID),
|
|
||||||
// Test case 2: Fold select false to right side
|
|
||||||
InstructionFoldingCase<uint32_t>(
|
|
||||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
|
||||||
"%main_lab = OpLabel\n" +
|
|
||||||
"%n = OpVariable %_ptr_int Function\n" +
|
|
||||||
"%load = OpLoad %bool %n\n" +
|
|
||||||
"%2 = OpSelect %int %false %n %100\n" +
|
|
||||||
"OpReturn\n" +
|
|
||||||
"OpFunctionEnd",
|
|
||||||
2, INT_0_ID)
|
|
||||||
));
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest,
|
INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest,
|
||||||
::testing::Values(
|
::testing::Values(
|
||||||
// Test case 0: Don't fold n + 1.0
|
// Test case 0: Don't fold n + 1.0
|
||||||
@ -4302,5 +4273,113 @@ INSTANTIATE_TEST_CASE_P(MergeSubTest, MatchingInstructionFoldingTest,
|
|||||||
"OpFunctionEnd\n",
|
"OpFunctionEnd\n",
|
||||||
4, true)
|
4, true)
|
||||||
));
|
));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(SelectFoldingTest, MatchingInstructionFoldingTest,
|
||||||
|
::testing::Values(
|
||||||
|
// Test case 0: Fold select with the same values for both sides
|
||||||
|
InstructionFoldingCase<bool>(
|
||||||
|
Header() +
|
||||||
|
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
|
||||||
|
"; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
|
||||||
|
"; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
|
||||||
|
"%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%n = OpVariable %_ptr_bool Function\n" +
|
||||||
|
"%load = OpLoad %bool %n\n" +
|
||||||
|
"%2 = OpSelect %int %load %100 %100\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, true),
|
||||||
|
// Test case 1: Fold select true to left side
|
||||||
|
InstructionFoldingCase<bool>(
|
||||||
|
Header() +
|
||||||
|
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
|
||||||
|
"; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
|
||||||
|
"; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
|
||||||
|
"%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%n = OpVariable %_ptr_int Function\n" +
|
||||||
|
"%load = OpLoad %bool %n\n" +
|
||||||
|
"%2 = OpSelect %int %true %100 %n\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, true),
|
||||||
|
// Test case 2: Fold select false to right side
|
||||||
|
InstructionFoldingCase<bool>(
|
||||||
|
Header() +
|
||||||
|
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
|
||||||
|
"; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
|
||||||
|
"; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
|
||||||
|
"%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%n = OpVariable %_ptr_int Function\n" +
|
||||||
|
"%load = OpLoad %bool %n\n" +
|
||||||
|
"%2 = OpSelect %int %false %n %100\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, true),
|
||||||
|
// Test case 3: Fold select null to right side
|
||||||
|
InstructionFoldingCase<bool>(
|
||||||
|
Header() +
|
||||||
|
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
|
||||||
|
"; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
|
||||||
|
"; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
|
||||||
|
"%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%n = OpVariable %_ptr_int Function\n" +
|
||||||
|
"%load = OpLoad %int %n\n" +
|
||||||
|
"%2 = OpSelect %int %bool_null %load %100\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, true),
|
||||||
|
// Test case 4: vector null
|
||||||
|
InstructionFoldingCase<bool>(
|
||||||
|
Header() +
|
||||||
|
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
|
||||||
|
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
|
||||||
|
"; CHECK: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
|
||||||
|
"; CHECK: [[v2int2_2:%\\w+]] = OpConstantComposite [[v2int]] [[int2]] [[int2]]\n" +
|
||||||
|
"; CHECK: %2 = OpCopyObject [[v2int]] [[v2int2_2]]\n" +
|
||||||
|
"%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%n = OpVariable %_ptr_v2int Function\n" +
|
||||||
|
"%load = OpLoad %v2int %n\n" +
|
||||||
|
"%2 = OpSelect %v2int %v2bool_null %load %v2int_2_2\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, true),
|
||||||
|
// Test case 5: vector select
|
||||||
|
InstructionFoldingCase<bool>(
|
||||||
|
Header() +
|
||||||
|
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
|
||||||
|
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
|
||||||
|
"; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 0 3\n" +
|
||||||
|
"%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%m = OpVariable %_ptr_v2int Function\n" +
|
||||||
|
"%n = OpVariable %_ptr_v2int Function\n" +
|
||||||
|
"%2 = OpLoad %v2int %n\n" +
|
||||||
|
"%3 = OpLoad %v2int %n\n" +
|
||||||
|
"%4 = OpSelect %v2int %v2bool_true_false %2 %3\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
4, true),
|
||||||
|
// Test case 6: vector select
|
||||||
|
InstructionFoldingCase<bool>(
|
||||||
|
Header() +
|
||||||
|
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
|
||||||
|
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
|
||||||
|
"; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 2 1\n" +
|
||||||
|
"%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%m = OpVariable %_ptr_v2int Function\n" +
|
||||||
|
"%n = OpVariable %_ptr_v2int Function\n" +
|
||||||
|
"%2 = OpLoad %v2int %n\n" +
|
||||||
|
"%3 = OpLoad %v2int %n\n" +
|
||||||
|
"%4 = OpSelect %v2int %v2bool_false_true %2 %3\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
4, true)
|
||||||
|
));
|
||||||
#endif
|
#endif
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user