From 588f4fcc95b3ef926cd05b3378c8a90d3ba73f66 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Mon, 26 Feb 2018 14:47:11 -0500 Subject: [PATCH] Add more folding rules for vector shuffle. Adds rule to fold OpVectorShuffle with constant inputs. Adds rules to fold OpCompositeExtrac being fed by an OpVectorShuffle. --- source/opt/const_folding_rules.cpp | 55 +++++++++++++++++ source/opt/folding_rules.cpp | 49 ++++++++++++++++ test/opt/fold_test.cpp | 94 +++++++++++++++++++++++++++++- 3 files changed, 195 insertions(+), 3 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 6e82874f6..ad0f8f79a 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -49,6 +49,59 @@ ConstantFoldingRule FoldExtractWithConstants() { }; } +ConstantFoldingRule FoldVectorShuffleWithConstants() { + return [](ir::Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + assert(inst->opcode() == SpvOpVectorShuffle); + const analysis::Constant* c1 = constants[0]; + const analysis::Constant* c2 = constants[1]; + if (c1 == nullptr || c2 == nullptr) { + return nullptr; + } + + ir::IRContext* context = inst->context(); + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + const analysis::Type* element_type = c1->type()->AsVector()->element_type(); + + std::vector c1_components; + if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) { + c1_components = vec_const->GetComponents(); + } else { + assert(c1->AsNullConstant()); + const analysis::Constant* element = + const_mgr->GetConstant(element_type, {}); + c1_components.resize(c1->type()->AsVector()->element_count(), element); + } + std::vector c2_components; + if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) { + c2_components = vec_const->GetComponents(); + } else { + assert(c2->AsNullConstant()); + const analysis::Constant* element = + const_mgr->GetConstant(element_type, {}); + c2_components.resize(c2->type()->AsVector()->element_count(), element); + } + + std::vector ids; + for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { + uint32_t index = inst->GetSingleWordInOperand(i); + if (index < c1_components.size()) { + ir::Instruction* member_inst = + const_mgr->GetDefiningInstruction(c1_components[index]); + ids.push_back(member_inst->result_id()); + } else { + ir::Instruction* member_inst = const_mgr->GetDefiningInstruction( + c2_components[index - c1_components.size()]); + ids.push_back(member_inst->result_id()); + } + } + + analysis::TypeManager* type_mgr = context->get_type_mgr(); + return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); + }; +} // namespace + ConstantFoldingRule FoldCompositeWithConstants() { // Folds an OpCompositeConstruct where all of the inputs are constants to a // constant. A new constant is created if necessary. @@ -306,6 +359,8 @@ spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() { rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual()); + + rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants()); } } // namespace opt } // namespace spvtools diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 4f2c128ca..30f5641f5 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -1444,6 +1444,54 @@ FoldingRule InsertFeedingExtract() { }; } +// When a VectorShuffle is feeding an Extract, we can extract from one of the +// operands of the VectorShuffle. We just need to adjust the index in the +// extract instruction. +FoldingRule VectorShuffleFeedingExtract() { + return [](ir::Instruction* inst, + const std::vector&) { + assert(inst->opcode() == SpvOpCompositeExtract && + "Wrong opcode. Should be OpCompositeExtract."); + analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); + uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); + ir::Instruction* cinst = def_use_mgr->GetDef(cid); + + if (cinst->opcode() != SpvOpVectorShuffle) { + return false; + } + + // Find the size of the first vector operand of the VectorShuffle + ir::Instruction* first_input = + def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + analysis::Type* first_input_type = + type_mgr->GetType(first_input->type_id()); + assert(first_input_type->AsVector() && + "Input to vector shuffle should be vectors."); + uint32_t first_input_size = first_input_type->AsVector()->element_count(); + + // Get index of the element the vector shuffle is placing in the position + // being extracted. + uint32_t new_index = + cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1)); + + // Get the id of the of the vector the elemtent comes from, and update the + // index if needed. + uint32_t new_vector = 0; + if (new_index < first_input_size) { + new_vector = cinst->GetSingleWordInOperand(0); + } else { + new_vector = cinst->GetSingleWordInOperand(1); + new_index -= first_input_size; + } + + // Update the extract instruction. + inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector}); + inst->SetInOperand(1, {new_index}); + return true; + }; +} + FoldingRule RedundantPhi() { // An OpPhi instruction where all values are the same or the result of the phi // itself, can be replaced by the value itself. @@ -1725,6 +1773,7 @@ spvtools::opt::FoldingRules::FoldingRules() { rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract()); rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract()); + rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract()); rules_[SpvOpExtInst].push_back(RedundantFMix()); diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 0fa1d1507..04112df36 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -177,6 +177,7 @@ OpName %main "main" %v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2 %v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4 %struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int +%v2int_null = OpConstantNull %v2int %102 = OpConstantComposite %v2int %103 %103 %v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 %struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0 @@ -186,6 +187,7 @@ OpName %main "main" %float_n1 = OpConstant %float -1 %104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps. %float_0 = OpConstant %float 0 +%float_half = OpConstant %float 0.5 %float_1 = OpConstant %float 1 %float_2 = OpConstant %float 2 %float_3 = OpConstant %float 3 @@ -391,6 +393,70 @@ INSTANTIATE_TEST_CASE_P(TestCase, IntegerInstructionFoldingTest, )); // clang-format on +using IntVectorInstructionFoldingTest = + ::testing::TestWithParam>>; + +TEST_P(IntVectorInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + bool succeeded = opt::FoldInstruction(inst); + + // Make sure the instruction folded as expected. + EXPECT_TRUE(succeeded); + if (inst != nullptr) { + EXPECT_EQ(inst->opcode(), SpvOpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + std::vector opcodes = {SpvOpConstantComposite}; + EXPECT_THAT(opcodes, Contains(inst->opcode())); + opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const opt::analysis::Constant* result = + const_mrg->GetConstantFromInst(inst); + EXPECT_NE(result, nullptr); + if (result != nullptr) { + const std::vector& componenets = + result->AsVectorConstant()->GetComponents(); + EXPECT_EQ(componenets.size(), tc.expected_result.size()); + for (size_t i = 0; i < componenets.size(); i++) { + EXPECT_EQ(tc.expected_result[i], componenets[i]->GetU32()); + } + } + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(TestCase, IntVectorInstructionFoldingTest, +::testing::Values( + // Test case 0: fold 0*n + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpVectorShuffle %v2int %v2int_2_2 %v2int_2_3 0 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {2,3}), + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {0,3}) +)); +// clang-format on + using BooleanInstructionFoldingTest = ::testing::TestWithParam>; @@ -2116,7 +2182,7 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 0), - // Test case 9: constant struct has OpUndef + // Test case 9: Extracting a member of element inserted via Insert InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -2127,7 +2193,7 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 4, 103), - // Test case 10: constant struct has OpUndef + // Test case 10: Extracting a element that is partially changed by Insert. (Don't fold) InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -2137,7 +2203,29 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe "%4 = OpCompositeExtract %v2int %3 0\n" + "OpReturn\n" + "OpFunctionEnd", - 4, 0) + 4, 0), + // Test case 11: Extracting from result of vector shuffle (first input) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %n\n" + + "%3 = OpVectorShuffle %v2int %102 %2 3 0\n" + + "%4 = OpCompositeExtract %int %3 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, INT_7_ID), + // Test case 12: Extracting from result of vector shuffle (second input) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %n\n" + + "%3 = OpVectorShuffle %v2int %2 %102 2 0\n" + + "%4 = OpCompositeExtract %int %3 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, INT_7_ID) )); INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,