spirv-opt: Add const folding for CompositeInsert (#4943)

* spirv-opt: Add const folding pass for CompositeInsert

* spirv-opt: Fix anas stack-use-after-scope
This commit is contained in:
Spencer Fricke 2022-11-09 00:50:42 +09:00 committed by GitHub
parent a5e766b2b4
commit 54d4e77fa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 144 additions and 0 deletions

View File

@ -120,6 +120,83 @@ ConstantFoldingRule FoldExtractWithConstants() {
};
}
// Folds an OpcompositeInsert where input is a composite constant.
ConstantFoldingRule FoldInsertWithConstants() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Constant* object = constants[0];
const analysis::Constant* composite = constants[1];
if (object == nullptr || composite == nullptr) {
return nullptr;
}
// If there is more than 1 index, then each additional constant used by the
// index will need to be recreated to use the inserted object.
std::vector<const analysis::Constant*> chain;
std::vector<const analysis::Constant*> components;
const analysis::Type* type = nullptr;
// Work down hierarchy and add all the indexes, not including the final
// index.
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
if (i != inst->NumInOperands() - 1) {
chain.push_back(composite);
}
const uint32_t index = inst->GetSingleWordInOperand(i);
components = composite->AsCompositeConstant()->GetComponents();
type = composite->AsCompositeConstant()->type();
composite = components[index];
}
// Final index in hierarchy is inserted with new object.
const uint32_t final_index =
inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
std::vector<uint32_t> ids;
for (size_t i = 0; i < components.size(); i++) {
const analysis::Constant* constant =
(i == final_index) ? object : components[i];
Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
ids.push_back(member_inst->result_id());
}
const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
// Work backwards up the chain and replace each index with new constant.
for (size_t i = chain.size(); i > 0; i--) {
// Need to insert any previous instruction into the module first.
// Can't just insert in types_values_begin() because it will move above
// where the types are declared
for (Module::inst_iterator inst_iter = context->types_values_begin();
inst_iter != context->types_values_end(); ++inst_iter) {
Instruction* x = &*inst_iter;
if (inst->result_id() == x->result_id()) {
const_mgr->BuildInstructionAndAddToModule(new_constant, &inst_iter);
break;
}
}
composite = chain[i - 1];
components = composite->AsCompositeConstant()->GetComponents();
type = composite->AsCompositeConstant()->type();
ids.clear();
for (size_t k = 0; k < components.size(); k++) {
const uint32_t index =
inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
const analysis::Constant* constant =
(k == index) ? new_constant : components[k];
const uint32_t constant_id =
const_mgr->FindDeclaredConstant(constant, 0);
ids.push_back(constant_id);
}
new_constant = const_mgr->GetConstant(type, ids);
}
// If multiple constants were created, only need to return the top index.
return new_constant;
};
}
ConstantFoldingRule FoldVectorShuffleWithConstants() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
@ -1410,6 +1487,7 @@ void ConstantFoldingRules::AddFoldingRules() {
rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());

View File

@ -308,6 +308,72 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeExtractMaxtrix) {
builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
}
TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertVector) {
const std::string test =
R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "main"
OpExecutionMode %1 LocalSize 1 1 1
%void = OpTypeVoid
%3 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%v3uint = OpTypeVector %uint 3
%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
%8 = OpConstantNull %uint
%9 = OpSpecConstantComposite %v3uint %uint_2 %uint_2 %uint_2
; CHECK: %15 = OpConstantComposite %v3uint %uint_3 %uint_2 %uint_2
; CHECK: %uint_3_0 = OpConstant %uint 3
; CHECK: %17 = OpConstantComposite %v3uint %8 %uint_2 %uint_2
; CHECK: %18 = OpConstantNull %uint
%10 = OpSpecConstantOp %v3uint CompositeInsert %uint_3 %9 0
%11 = OpSpecConstantOp %uint CompositeExtract %10 0
%12 = OpSpecConstantOp %v3uint CompositeInsert %8 %9 0
%13 = OpSpecConstantOp %uint CompositeExtract %12 0
%1 = OpFunction %void None %3
%14 = OpLabel
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
}
TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrix) {
const std::string test =
R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "main"
OpExecutionMode %1 LocalSize 1 1 1
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%v3float = OpTypeVector %float 3
%mat3v3float = OpTypeMatrix %v3float 3
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%9 = OpSpecConstantComposite %v3float %float_1 %float_1 %float_1
%10 = OpSpecConstantComposite %v3float %float_1 %float_1 %float_1
%11 = OpSpecConstantComposite %v3float %float_1 %float_2 %float_1
%12 = OpSpecConstantComposite %mat3v3float %9 %10 %11
; CHECK: %float_2_0 = OpConstant %float 2
; CHECK: %18 = OpConstantComposite %v3float %float_1 %float_1 %float_2
; CHECK: %19 = OpConstantComposite %mat3v3float %9 %18 %11
; CHECK: %float_2_1 = OpConstant %float 2
%13 = OpSpecConstantOp %float CompositeExtract %12 2 1
%14 = OpSpecConstantOp %mat3v3float CompositeInsert %13 %12 1 2
%15 = OpSpecConstantOp %float CompositeExtract %14 1 2
%1 = OpFunction %void None %3
%16 = OpLabel
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
}
// All types and some common constants that are potentially required in
// FoldSpecConstantOpAndCompositeTest.
std::vector<std::string> CommonTypesAndConstants() {