Handles more cases of redundant selects

* Handles OpConstantNull and vector types
 * vector selects (except against a null) are converted to vector
 shuffles
* Added tests
This commit is contained in:
Alan Baker 2018-03-02 09:19:50 -05:00
parent a7cec7843c
commit 52bceb3569
2 changed files with 164 additions and 45 deletions

View File

@ -1549,25 +1549,65 @@ FoldingRule RedundantSelect() {
assert(inst->NumInOperands() == 3);
assert(constants.size() == 3);
const analysis::BoolConstant* bc =
constants[0] ? constants[0]->AsBoolConstant() : nullptr;
uint32_t true_id = inst->GetSingleWordInOperand(1);
uint32_t false_id = inst->GetSingleWordInOperand(2);
if (bc) {
// Select condition is constant, result is known
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {bc->value() ? true_id : false_id}}});
return true;
} else if (true_id == false_id) {
if (true_id == false_id) {
// Both results are the same, condition doesn't matter
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
return true;
} else if (constants[0]) {
const analysis::Type* type = constants[0]->type();
if (type->AsBool()) {
// Scalar constant value, select the corresponding value.
inst->SetOpcode(SpvOpCopyObject);
if (constants[0]->AsNullConstant() ||
!constants[0]->AsBoolConstant()->value()) {
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
} else {
return false;
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
}
return true;
} else {
assert(type->AsVector());
if (constants[0]->AsNullConstant()) {
// All values come from false id.
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
return true;
} else {
// Convert to a vector shuffle.
std::vector<ir::Operand> ops;
ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
const analysis::VectorConstant* vector_const =
constants[0]->AsVectorConstant();
uint32_t size =
static_cast<uint32_t>(vector_const->GetComponents().size());
for (uint32_t i = 0; i != size; ++i) {
const analysis::Constant* component =
vector_const->GetComponents()[i];
if (component->AsNullConstant() ||
!component->AsBoolConstant()->value()) {
// Selecting from the false vector which is the second input
// vector to the shuffle. Offset the index by |size|.
ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
} else {
// Selecting from true vector which is the first input vector to
// the shuffle.
ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
}
}
inst->SetOpcode(SpvOpVectorShuffle);
inst->SetInOperands(std::move(ops));
return true;
}
}
}
return false;
};
}

View File

@ -130,6 +130,7 @@ OpName %main "main"
%101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps.
%true = OpConstantTrue %bool
%false = OpConstantFalse %bool
%bool_null = OpConstantNull %bool
%short = OpTypeInt 16 1
%int = OpTypeInt 32 1
%long = OpTypeInt 64 1
@ -139,6 +140,7 @@ OpName %main "main"
%v4float = OpTypeVector %float 4
%v4double = OpTypeVector %double 4
%v2float = OpTypeVector %float 2
%v2bool = OpTypeVector %bool 2
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
%_ptr_uint = OpTypePointer Function %uint
@ -176,6 +178,9 @@ OpName %main "main"
%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
%v2bool_null = OpConstantNull %v2bool
%v2bool_true_false = OpConstantComposite %v2bool %true %false
%v2bool_false_true = OpConstantComposite %v2bool %false %true
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
%v2int_null = OpConstantNull %v2int
%102 = OpConstantComposite %v2int %103 %103
@ -2336,40 +2341,6 @@ INSTANTIATE_TEST_CASE_P(PhiFoldingTest, GeneralInstructionFoldingTest,
2, 0)
));
INSTANTIATE_TEST_CASE_P(SelectFoldingTest, GeneralInstructionFoldingTest,
::testing::Values(
// Test case 0: Fold select with the same values for both sides
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_bool Function\n" +
"%load = OpLoad %bool %n\n" +
"%2 = OpSelect %int %load %100 %100\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, INT_0_ID),
// Test case 1: Fold select true to left side
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load = OpLoad %bool %n\n" +
"%2 = OpSelect %int %true %100 %n\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, INT_0_ID),
// Test case 2: Fold select false to right side
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load = OpLoad %bool %n\n" +
"%2 = OpSelect %int %false %n %100\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, INT_0_ID)
));
INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest,
::testing::Values(
// Test case 0: Don't fold n + 1.0
@ -4302,5 +4273,113 @@ INSTANTIATE_TEST_CASE_P(MergeSubTest, MatchingInstructionFoldingTest,
"OpFunctionEnd\n",
4, true)
));
INSTANTIATE_TEST_CASE_P(SelectFoldingTest, MatchingInstructionFoldingTest,
::testing::Values(
// Test case 0: Fold select with the same values for both sides
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
"; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_bool Function\n" +
"%load = OpLoad %bool %n\n" +
"%2 = OpSelect %int %load %100 %100\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, true),
// Test case 1: Fold select true to left side
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
"; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load = OpLoad %bool %n\n" +
"%2 = OpSelect %int %true %100 %n\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, true),
// Test case 2: Fold select false to right side
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
"; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load = OpLoad %bool %n\n" +
"%2 = OpSelect %int %false %n %100\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, true),
// Test case 3: Fold select null to right side
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
"; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load = OpLoad %int %n\n" +
"%2 = OpSelect %int %bool_null %load %100\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, true),
// Test case 4: vector null
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
"; CHECK: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
"; CHECK: [[v2int2_2:%\\w+]] = OpConstantComposite [[v2int]] [[int2]] [[int2]]\n" +
"; CHECK: %2 = OpCopyObject [[v2int]] [[v2int2_2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v2int Function\n" +
"%load = OpLoad %v2int %n\n" +
"%2 = OpSelect %v2int %v2bool_null %load %v2int_2_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, true),
// Test case 5: vector select
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
"; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 0 3\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%m = OpVariable %_ptr_v2int Function\n" +
"%n = OpVariable %_ptr_v2int Function\n" +
"%2 = OpLoad %v2int %n\n" +
"%3 = OpLoad %v2int %n\n" +
"%4 = OpSelect %v2int %v2bool_true_false %2 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 6: vector select
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
"; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 2 1\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%m = OpVariable %_ptr_v2int Function\n" +
"%n = OpVariable %_ptr_v2int Function\n" +
"%2 = OpLoad %v2int %n\n" +
"%3 = OpLoad %v2int %n\n" +
"%4 = OpSelect %v2int %v2bool_false_true %2 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true)
));
#endif
} // anonymous namespace