mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-18 19:20:05 +00:00
Add more folding rules for vector shuffle.
Adds rule to fold OpVectorShuffle with constant inputs. Adds rules to fold OpCompositeExtrac being fed by an OpVectorShuffle.
This commit is contained in:
parent
90e1637ce4
commit
588f4fcc95
@ -49,6 +49,59 @@ ConstantFoldingRule FoldExtractWithConstants() {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ConstantFoldingRule FoldVectorShuffleWithConstants() {
|
||||||
|
return [](ir::Instruction* inst,
|
||||||
|
const std::vector<const analysis::Constant*>& constants)
|
||||||
|
-> const analysis::Constant* {
|
||||||
|
assert(inst->opcode() == SpvOpVectorShuffle);
|
||||||
|
const analysis::Constant* c1 = constants[0];
|
||||||
|
const analysis::Constant* c2 = constants[1];
|
||||||
|
if (c1 == nullptr || c2 == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::IRContext* context = inst->context();
|
||||||
|
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||||
|
const analysis::Type* element_type = c1->type()->AsVector()->element_type();
|
||||||
|
|
||||||
|
std::vector<const analysis::Constant*> c1_components;
|
||||||
|
if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
|
||||||
|
c1_components = vec_const->GetComponents();
|
||||||
|
} else {
|
||||||
|
assert(c1->AsNullConstant());
|
||||||
|
const analysis::Constant* element =
|
||||||
|
const_mgr->GetConstant(element_type, {});
|
||||||
|
c1_components.resize(c1->type()->AsVector()->element_count(), element);
|
||||||
|
}
|
||||||
|
std::vector<const analysis::Constant*> c2_components;
|
||||||
|
if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
|
||||||
|
c2_components = vec_const->GetComponents();
|
||||||
|
} else {
|
||||||
|
assert(c2->AsNullConstant());
|
||||||
|
const analysis::Constant* element =
|
||||||
|
const_mgr->GetConstant(element_type, {});
|
||||||
|
c2_components.resize(c2->type()->AsVector()->element_count(), element);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<uint32_t> ids;
|
||||||
|
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
|
||||||
|
uint32_t index = inst->GetSingleWordInOperand(i);
|
||||||
|
if (index < c1_components.size()) {
|
||||||
|
ir::Instruction* member_inst =
|
||||||
|
const_mgr->GetDefiningInstruction(c1_components[index]);
|
||||||
|
ids.push_back(member_inst->result_id());
|
||||||
|
} else {
|
||||||
|
ir::Instruction* member_inst = const_mgr->GetDefiningInstruction(
|
||||||
|
c2_components[index - c1_components.size()]);
|
||||||
|
ids.push_back(member_inst->result_id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
analysis::TypeManager* type_mgr = context->get_type_mgr();
|
||||||
|
return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
ConstantFoldingRule FoldCompositeWithConstants() {
|
ConstantFoldingRule FoldCompositeWithConstants() {
|
||||||
// Folds an OpCompositeConstruct where all of the inputs are constants to a
|
// Folds an OpCompositeConstruct where all of the inputs are constants to a
|
||||||
// constant. A new constant is created if necessary.
|
// constant. A new constant is created if necessary.
|
||||||
@ -306,6 +359,8 @@ spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
|
|||||||
rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
|
rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
|
||||||
rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
|
rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
|
||||||
rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
|
rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
|
||||||
|
|
||||||
|
rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace spvtools
|
} // namespace spvtools
|
||||||
|
@ -1444,6 +1444,54 @@ FoldingRule InsertFeedingExtract() {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When a VectorShuffle is feeding an Extract, we can extract from one of the
|
||||||
|
// operands of the VectorShuffle. We just need to adjust the index in the
|
||||||
|
// extract instruction.
|
||||||
|
FoldingRule VectorShuffleFeedingExtract() {
|
||||||
|
return [](ir::Instruction* inst,
|
||||||
|
const std::vector<const analysis::Constant*>&) {
|
||||||
|
assert(inst->opcode() == SpvOpCompositeExtract &&
|
||||||
|
"Wrong opcode. Should be OpCompositeExtract.");
|
||||||
|
analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
|
||||||
|
analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
|
||||||
|
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
|
||||||
|
ir::Instruction* cinst = def_use_mgr->GetDef(cid);
|
||||||
|
|
||||||
|
if (cinst->opcode() != SpvOpVectorShuffle) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the size of the first vector operand of the VectorShuffle
|
||||||
|
ir::Instruction* first_input =
|
||||||
|
def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
|
||||||
|
analysis::Type* first_input_type =
|
||||||
|
type_mgr->GetType(first_input->type_id());
|
||||||
|
assert(first_input_type->AsVector() &&
|
||||||
|
"Input to vector shuffle should be vectors.");
|
||||||
|
uint32_t first_input_size = first_input_type->AsVector()->element_count();
|
||||||
|
|
||||||
|
// Get index of the element the vector shuffle is placing in the position
|
||||||
|
// being extracted.
|
||||||
|
uint32_t new_index =
|
||||||
|
cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
|
||||||
|
|
||||||
|
// Get the id of the of the vector the elemtent comes from, and update the
|
||||||
|
// index if needed.
|
||||||
|
uint32_t new_vector = 0;
|
||||||
|
if (new_index < first_input_size) {
|
||||||
|
new_vector = cinst->GetSingleWordInOperand(0);
|
||||||
|
} else {
|
||||||
|
new_vector = cinst->GetSingleWordInOperand(1);
|
||||||
|
new_index -= first_input_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the extract instruction.
|
||||||
|
inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
|
||||||
|
inst->SetInOperand(1, {new_index});
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
FoldingRule RedundantPhi() {
|
FoldingRule RedundantPhi() {
|
||||||
// An OpPhi instruction where all values are the same or the result of the phi
|
// An OpPhi instruction where all values are the same or the result of the phi
|
||||||
// itself, can be replaced by the value itself.
|
// itself, can be replaced by the value itself.
|
||||||
@ -1725,6 +1773,7 @@ spvtools::opt::FoldingRules::FoldingRules() {
|
|||||||
|
|
||||||
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
|
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
|
||||||
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
|
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
|
||||||
|
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
|
||||||
|
|
||||||
rules_[SpvOpExtInst].push_back(RedundantFMix());
|
rules_[SpvOpExtInst].push_back(RedundantFMix());
|
||||||
|
|
||||||
|
@ -177,6 +177,7 @@ OpName %main "main"
|
|||||||
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
|
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
|
||||||
%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
|
%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
|
||||||
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
|
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
|
||||||
|
%v2int_null = OpConstantNull %v2int
|
||||||
%102 = OpConstantComposite %v2int %103 %103
|
%102 = OpConstantComposite %v2int %103 %103
|
||||||
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
|
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
|
||||||
%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
|
%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
|
||||||
@ -186,6 +187,7 @@ OpName %main "main"
|
|||||||
%float_n1 = OpConstant %float -1
|
%float_n1 = OpConstant %float -1
|
||||||
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
|
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
|
||||||
%float_0 = OpConstant %float 0
|
%float_0 = OpConstant %float 0
|
||||||
|
%float_half = OpConstant %float 0.5
|
||||||
%float_1 = OpConstant %float 1
|
%float_1 = OpConstant %float 1
|
||||||
%float_2 = OpConstant %float 2
|
%float_2 = OpConstant %float 2
|
||||||
%float_3 = OpConstant %float 3
|
%float_3 = OpConstant %float 3
|
||||||
@ -391,6 +393,70 @@ INSTANTIATE_TEST_CASE_P(TestCase, IntegerInstructionFoldingTest,
|
|||||||
));
|
));
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
using IntVectorInstructionFoldingTest =
|
||||||
|
::testing::TestWithParam<InstructionFoldingCase<std::vector<uint32_t>>>;
|
||||||
|
|
||||||
|
TEST_P(IntVectorInstructionFoldingTest, Case) {
|
||||||
|
const auto& tc = GetParam();
|
||||||
|
|
||||||
|
// Build module.
|
||||||
|
std::unique_ptr<ir::IRContext> context =
|
||||||
|
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
|
||||||
|
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
|
||||||
|
ASSERT_NE(nullptr, context);
|
||||||
|
|
||||||
|
// Fold the instruction to test.
|
||||||
|
opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||||
|
ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
|
||||||
|
bool succeeded = opt::FoldInstruction(inst);
|
||||||
|
|
||||||
|
// Make sure the instruction folded as expected.
|
||||||
|
EXPECT_TRUE(succeeded);
|
||||||
|
if (inst != nullptr) {
|
||||||
|
EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
|
||||||
|
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
|
||||||
|
std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
|
||||||
|
EXPECT_THAT(opcodes, Contains(inst->opcode()));
|
||||||
|
opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr();
|
||||||
|
const opt::analysis::Constant* result =
|
||||||
|
const_mrg->GetConstantFromInst(inst);
|
||||||
|
EXPECT_NE(result, nullptr);
|
||||||
|
if (result != nullptr) {
|
||||||
|
const std::vector<const opt::analysis::Constant*>& componenets =
|
||||||
|
result->AsVectorConstant()->GetComponents();
|
||||||
|
EXPECT_EQ(componenets.size(), tc.expected_result.size());
|
||||||
|
for (size_t i = 0; i < componenets.size(); i++) {
|
||||||
|
EXPECT_EQ(tc.expected_result[i], componenets[i]->GetU32());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
INSTANTIATE_TEST_CASE_P(TestCase, IntVectorInstructionFoldingTest,
|
||||||
|
::testing::Values(
|
||||||
|
// Test case 0: fold 0*n
|
||||||
|
InstructionFoldingCase<std::vector<uint32_t>>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%n = OpVariable %_ptr_int Function\n" +
|
||||||
|
"%load = OpLoad %int %n\n" +
|
||||||
|
"%2 = OpVectorShuffle %v2int %v2int_2_2 %v2int_2_3 0 3\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, {2,3}),
|
||||||
|
InstructionFoldingCase<std::vector<uint32_t>>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%n = OpVariable %_ptr_int Function\n" +
|
||||||
|
"%load = OpLoad %int %n\n" +
|
||||||
|
"%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 3\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, {0,3})
|
||||||
|
));
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
using BooleanInstructionFoldingTest =
|
using BooleanInstructionFoldingTest =
|
||||||
::testing::TestWithParam<InstructionFoldingCase<bool>>;
|
::testing::TestWithParam<InstructionFoldingCase<bool>>;
|
||||||
|
|
||||||
@ -2116,7 +2182,7 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe
|
|||||||
"OpReturn\n" +
|
"OpReturn\n" +
|
||||||
"OpFunctionEnd",
|
"OpFunctionEnd",
|
||||||
2, 0),
|
2, 0),
|
||||||
// Test case 9: constant struct has OpUndef
|
// Test case 9: Extracting a member of element inserted via Insert
|
||||||
InstructionFoldingCase<uint32_t>(
|
InstructionFoldingCase<uint32_t>(
|
||||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
"%main_lab = OpLabel\n" +
|
"%main_lab = OpLabel\n" +
|
||||||
@ -2127,7 +2193,7 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe
|
|||||||
"OpReturn\n" +
|
"OpReturn\n" +
|
||||||
"OpFunctionEnd",
|
"OpFunctionEnd",
|
||||||
4, 103),
|
4, 103),
|
||||||
// Test case 10: constant struct has OpUndef
|
// Test case 10: Extracting a element that is partially changed by Insert. (Don't fold)
|
||||||
InstructionFoldingCase<uint32_t>(
|
InstructionFoldingCase<uint32_t>(
|
||||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
"%main_lab = OpLabel\n" +
|
"%main_lab = OpLabel\n" +
|
||||||
@ -2137,7 +2203,29 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe
|
|||||||
"%4 = OpCompositeExtract %v2int %3 0\n" +
|
"%4 = OpCompositeExtract %v2int %3 0\n" +
|
||||||
"OpReturn\n" +
|
"OpReturn\n" +
|
||||||
"OpFunctionEnd",
|
"OpFunctionEnd",
|
||||||
4, 0)
|
4, 0),
|
||||||
|
// Test case 11: Extracting from result of vector shuffle (first input)
|
||||||
|
InstructionFoldingCase<uint32_t>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%n = OpVariable %_ptr_v2int Function\n" +
|
||||||
|
"%2 = OpLoad %v2int %n\n" +
|
||||||
|
"%3 = OpVectorShuffle %v2int %102 %2 3 0\n" +
|
||||||
|
"%4 = OpCompositeExtract %int %3 1\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
4, INT_7_ID),
|
||||||
|
// Test case 12: Extracting from result of vector shuffle (second input)
|
||||||
|
InstructionFoldingCase<uint32_t>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%n = OpVariable %_ptr_v2int Function\n" +
|
||||||
|
"%2 = OpLoad %v2int %n\n" +
|
||||||
|
"%3 = OpVectorShuffle %v2int %2 %102 2 0\n" +
|
||||||
|
"%4 = OpCompositeExtract %int %3 0\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
4, INT_7_ID)
|
||||||
));
|
));
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,
|
INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,
|
||||||
|
Loading…
Reference in New Issue
Block a user