From 52bceb3569fd36edc35275603ecb83a68ae81aa1 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Fri, 2 Mar 2018 09:19:50 -0500 Subject: [PATCH] Handles more cases of redundant selects * Handles OpConstantNull and vector types * vector selects (except against a null) are converted to vector shuffles * Added tests --- source/opt/folding_rules.cpp | 62 ++++++++++++--- test/opt/fold_test.cpp | 147 +++++++++++++++++++++++++++-------- 2 files changed, 164 insertions(+), 45 deletions(-) diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 7e4dddba3..812193154 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -1549,25 +1549,65 @@ FoldingRule RedundantSelect() { assert(inst->NumInOperands() == 3); assert(constants.size() == 3); - const analysis::BoolConstant* bc = - constants[0] ? constants[0]->AsBoolConstant() : nullptr; uint32_t true_id = inst->GetSingleWordInOperand(1); uint32_t false_id = inst->GetSingleWordInOperand(2); - if (bc) { - // Select condition is constant, result is known - inst->SetOpcode(SpvOpCopyObject); - inst->SetInOperands( - {{SPV_OPERAND_TYPE_ID, {bc->value() ? true_id : false_id}}}); - return true; - } else if (true_id == false_id) { + if (true_id == false_id) { // Both results are the same, condition doesn't matter inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}}); return true; - } else { - return false; + } else if (constants[0]) { + const analysis::Type* type = constants[0]->type(); + if (type->AsBool()) { + // Scalar constant value, select the corresponding value. + inst->SetOpcode(SpvOpCopyObject); + if (constants[0]->AsNullConstant() || + !constants[0]->AsBoolConstant()->value()) { + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}}); + } else { + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}}); + } + return true; + } else { + assert(type->AsVector()); + if (constants[0]->AsNullConstant()) { + // All values come from false id. + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}}); + return true; + } else { + // Convert to a vector shuffle. + std::vector ops; + ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}}); + ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}}); + const analysis::VectorConstant* vector_const = + constants[0]->AsVectorConstant(); + uint32_t size = + static_cast(vector_const->GetComponents().size()); + for (uint32_t i = 0; i != size; ++i) { + const analysis::Constant* component = + vector_const->GetComponents()[i]; + if (component->AsNullConstant() || + !component->AsBoolConstant()->value()) { + // Selecting from the false vector which is the second input + // vector to the shuffle. Offset the index by |size|. + ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}}); + } else { + // Selecting from true vector which is the first input vector to + // the shuffle. + ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}); + } + } + + inst->SetOpcode(SpvOpVectorShuffle); + inst->SetInOperands(std::move(ops)); + return true; + } + } } + + return false; }; } diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 345cfeacb..7d813d4c5 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -130,6 +130,7 @@ OpName %main "main" %101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps. %true = OpConstantTrue %bool %false = OpConstantFalse %bool +%bool_null = OpConstantNull %bool %short = OpTypeInt 16 1 %int = OpTypeInt 32 1 %long = OpTypeInt 64 1 @@ -139,6 +140,7 @@ OpName %main "main" %v4float = OpTypeVector %float 4 %v4double = OpTypeVector %double 4 %v2float = OpTypeVector %float 2 +%v2bool = OpTypeVector %bool 2 %struct_v2int_int_int = OpTypeStruct %v2int %int %int %_ptr_int = OpTypePointer Function %int %_ptr_uint = OpTypePointer Function %uint @@ -176,6 +178,9 @@ OpName %main "main" %v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3 %v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2 %v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4 +%v2bool_null = OpConstantNull %v2bool +%v2bool_true_false = OpConstantComposite %v2bool %true %false +%v2bool_false_true = OpConstantComposite %v2bool %false %true %struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int %v2int_null = OpConstantNull %v2int %102 = OpConstantComposite %v2int %103 %103 @@ -2336,40 +2341,6 @@ INSTANTIATE_TEST_CASE_P(PhiFoldingTest, GeneralInstructionFoldingTest, 2, 0) )); -INSTANTIATE_TEST_CASE_P(SelectFoldingTest, GeneralInstructionFoldingTest, -::testing::Values( - // Test case 0: Fold select with the same values for both sides - InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%n = OpVariable %_ptr_bool Function\n" + - "%load = OpLoad %bool %n\n" + - "%2 = OpSelect %int %load %100 %100\n" + - "OpReturn\n" + - "OpFunctionEnd", - 2, INT_0_ID), - // Test case 1: Fold select true to left side - InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%n = OpVariable %_ptr_int Function\n" + - "%load = OpLoad %bool %n\n" + - "%2 = OpSelect %int %true %100 %n\n" + - "OpReturn\n" + - "OpFunctionEnd", - 2, INT_0_ID), - // Test case 2: Fold select false to right side - InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%n = OpVariable %_ptr_int Function\n" + - "%load = OpLoad %bool %n\n" + - "%2 = OpSelect %int %false %n %100\n" + - "OpReturn\n" + - "OpFunctionEnd", - 2, INT_0_ID) -)); - INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest, ::testing::Values( // Test case 0: Don't fold n + 1.0 @@ -4302,5 +4273,113 @@ INSTANTIATE_TEST_CASE_P(MergeSubTest, MatchingInstructionFoldingTest, "OpFunctionEnd\n", 4, true) )); + +INSTANTIATE_TEST_CASE_P(SelectFoldingTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: Fold select with the same values for both sides + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" + + "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_bool Function\n" + + "%load = OpLoad %bool %n\n" + + "%2 = OpSelect %int %load %100 %100\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 1: Fold select true to left side + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" + + "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %bool %n\n" + + "%2 = OpSelect %int %true %100 %n\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: Fold select false to right side + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" + + "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %bool %n\n" + + "%2 = OpSelect %int %false %n %100\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 3: Fold select null to right side + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" + + "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSelect %int %bool_null %load %100\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 4: vector null + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + + "; CHECK: [[int2:%\\w+]] = OpConstant [[int]] 2\n" + + "; CHECK: [[v2int2_2:%\\w+]] = OpConstantComposite [[v2int]] [[int2]] [[int2]]\n" + + "; CHECK: %2 = OpCopyObject [[v2int]] [[v2int2_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%load = OpLoad %v2int %n\n" + + "%2 = OpSelect %v2int %v2bool_null %load %v2int_2_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: vector select + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + + "; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 0 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v2int Function\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %n\n" + + "%3 = OpLoad %v2int %n\n" + + "%4 = OpSelect %v2int %v2bool_true_false %2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 6: vector select + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + + "; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 2 1\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v2int Function\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %n\n" + + "%3 = OpLoad %v2int %n\n" + + "%4 = OpSelect %v2int %v2bool_false_true %2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true) +)); #endif } // anonymous namespace