Add more folding rules for vector shuffle.

Adds rule to fold OpVectorShuffle with constant inputs.

Adds rules to fold OpCompositeExtrac being fed by an OpVectorShuffle.
This commit is contained in:
Steven Perron 2018-02-26 14:47:11 -05:00
parent 90e1637ce4
commit 588f4fcc95
3 changed files with 195 additions and 3 deletions

View File

@ -49,6 +49,59 @@ ConstantFoldingRule FoldExtractWithConstants() {
}; };
} }
ConstantFoldingRule FoldVectorShuffleWithConstants() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
assert(inst->opcode() == SpvOpVectorShuffle);
const analysis::Constant* c1 = constants[0];
const analysis::Constant* c2 = constants[1];
if (c1 == nullptr || c2 == nullptr) {
return nullptr;
}
ir::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* element_type = c1->type()->AsVector()->element_type();
std::vector<const analysis::Constant*> c1_components;
if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
c1_components = vec_const->GetComponents();
} else {
assert(c1->AsNullConstant());
const analysis::Constant* element =
const_mgr->GetConstant(element_type, {});
c1_components.resize(c1->type()->AsVector()->element_count(), element);
}
std::vector<const analysis::Constant*> c2_components;
if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
c2_components = vec_const->GetComponents();
} else {
assert(c2->AsNullConstant());
const analysis::Constant* element =
const_mgr->GetConstant(element_type, {});
c2_components.resize(c2->type()->AsVector()->element_count(), element);
}
std::vector<uint32_t> ids;
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
uint32_t index = inst->GetSingleWordInOperand(i);
if (index < c1_components.size()) {
ir::Instruction* member_inst =
const_mgr->GetDefiningInstruction(c1_components[index]);
ids.push_back(member_inst->result_id());
} else {
ir::Instruction* member_inst = const_mgr->GetDefiningInstruction(
c2_components[index - c1_components.size()]);
ids.push_back(member_inst->result_id());
}
}
analysis::TypeManager* type_mgr = context->get_type_mgr();
return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
};
} // namespace
ConstantFoldingRule FoldCompositeWithConstants() { ConstantFoldingRule FoldCompositeWithConstants() {
// Folds an OpCompositeConstruct where all of the inputs are constants to a // Folds an OpCompositeConstruct where all of the inputs are constants to a
// constant. A new constant is created if necessary. // constant. A new constant is created if necessary.
@ -306,6 +359,8 @@ spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual()); rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
} }
} // namespace opt } // namespace opt
} // namespace spvtools } // namespace spvtools

View File

@ -1444,6 +1444,54 @@ FoldingRule InsertFeedingExtract() {
}; };
} }
// When a VectorShuffle is feeding an Extract, we can extract from one of the
// operands of the VectorShuffle. We just need to adjust the index in the
// extract instruction.
FoldingRule VectorShuffleFeedingExtract() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpCompositeExtract &&
"Wrong opcode. Should be OpCompositeExtract.");
analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
ir::Instruction* cinst = def_use_mgr->GetDef(cid);
if (cinst->opcode() != SpvOpVectorShuffle) {
return false;
}
// Find the size of the first vector operand of the VectorShuffle
ir::Instruction* first_input =
def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
analysis::Type* first_input_type =
type_mgr->GetType(first_input->type_id());
assert(first_input_type->AsVector() &&
"Input to vector shuffle should be vectors.");
uint32_t first_input_size = first_input_type->AsVector()->element_count();
// Get index of the element the vector shuffle is placing in the position
// being extracted.
uint32_t new_index =
cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
// Get the id of the of the vector the elemtent comes from, and update the
// index if needed.
uint32_t new_vector = 0;
if (new_index < first_input_size) {
new_vector = cinst->GetSingleWordInOperand(0);
} else {
new_vector = cinst->GetSingleWordInOperand(1);
new_index -= first_input_size;
}
// Update the extract instruction.
inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
inst->SetInOperand(1, {new_index});
return true;
};
}
FoldingRule RedundantPhi() { FoldingRule RedundantPhi() {
// An OpPhi instruction where all values are the same or the result of the phi // An OpPhi instruction where all values are the same or the result of the phi
// itself, can be replaced by the value itself. // itself, can be replaced by the value itself.
@ -1725,6 +1773,7 @@ spvtools::opt::FoldingRules::FoldingRules() {
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract()); rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract()); rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
rules_[SpvOpExtInst].push_back(RedundantFMix()); rules_[SpvOpExtInst].push_back(RedundantFMix());

View File

@ -177,6 +177,7 @@ OpName %main "main"
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2 %v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4 %v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int %struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
%v2int_null = OpConstantNull %v2int
%102 = OpConstantComposite %v2int %103 %103 %102 = OpConstantComposite %v2int %103 %103
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 %v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0 %struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
@ -186,6 +187,7 @@ OpName %main "main"
%float_n1 = OpConstant %float -1 %float_n1 = OpConstant %float -1
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps. %104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
%float_0 = OpConstant %float 0 %float_0 = OpConstant %float 0
%float_half = OpConstant %float 0.5
%float_1 = OpConstant %float 1 %float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2 %float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3 %float_3 = OpConstant %float 3
@ -391,6 +393,70 @@ INSTANTIATE_TEST_CASE_P(TestCase, IntegerInstructionFoldingTest,
)); ));
// clang-format on // clang-format on
using IntVectorInstructionFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<std::vector<uint32_t>>>;
TEST_P(IntVectorInstructionFoldingTest, Case) {
const auto& tc = GetParam();
// Build module.
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ASSERT_NE(nullptr, context);
// Fold the instruction to test.
opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
bool succeeded = opt::FoldInstruction(inst);
// Make sure the instruction folded as expected.
EXPECT_TRUE(succeeded);
if (inst != nullptr) {
EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
EXPECT_THAT(opcodes, Contains(inst->opcode()));
opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr();
const opt::analysis::Constant* result =
const_mrg->GetConstantFromInst(inst);
EXPECT_NE(result, nullptr);
if (result != nullptr) {
const std::vector<const opt::analysis::Constant*>& componenets =
result->AsVectorConstant()->GetComponents();
EXPECT_EQ(componenets.size(), tc.expected_result.size());
for (size_t i = 0; i < componenets.size(); i++) {
EXPECT_EQ(tc.expected_result[i], componenets[i]->GetU32());
}
}
}
}
// clang-format off
INSTANTIATE_TEST_CASE_P(TestCase, IntVectorInstructionFoldingTest,
::testing::Values(
// Test case 0: fold 0*n
InstructionFoldingCase<std::vector<uint32_t>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load = OpLoad %int %n\n" +
"%2 = OpVectorShuffle %v2int %v2int_2_2 %v2int_2_3 0 3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {2,3}),
InstructionFoldingCase<std::vector<uint32_t>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load = OpLoad %int %n\n" +
"%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {0,3})
));
// clang-format on
using BooleanInstructionFoldingTest = using BooleanInstructionFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<bool>>; ::testing::TestWithParam<InstructionFoldingCase<bool>>;
@ -2116,7 +2182,7 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd", "OpFunctionEnd",
2, 0), 2, 0),
// Test case 9: constant struct has OpUndef // Test case 9: Extracting a member of element inserted via Insert
InstructionFoldingCase<uint32_t>( InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" + Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" + "%main_lab = OpLabel\n" +
@ -2127,7 +2193,7 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd", "OpFunctionEnd",
4, 103), 4, 103),
// Test case 10: constant struct has OpUndef // Test case 10: Extracting a element that is partially changed by Insert. (Don't fold)
InstructionFoldingCase<uint32_t>( InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" + Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" + "%main_lab = OpLabel\n" +
@ -2137,7 +2203,29 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe
"%4 = OpCompositeExtract %v2int %3 0\n" + "%4 = OpCompositeExtract %v2int %3 0\n" +
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd", "OpFunctionEnd",
4, 0) 4, 0),
// Test case 11: Extracting from result of vector shuffle (first input)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v2int Function\n" +
"%2 = OpLoad %v2int %n\n" +
"%3 = OpVectorShuffle %v2int %102 %2 3 0\n" +
"%4 = OpCompositeExtract %int %3 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, INT_7_ID),
// Test case 12: Extracting from result of vector shuffle (second input)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v2int Function\n" +
"%2 = OpLoad %v2int %n\n" +
"%3 = OpVectorShuffle %v2int %2 %102 2 0\n" +
"%4 = OpCompositeExtract %int %3 0\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, INT_7_ID)
)); ));
INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest, INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,