mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-18 19:20:05 +00:00
spirv-opt: Add const folding for CompositeInsert (#4943)
* spirv-opt: Add const folding pass for CompositeInsert * spirv-opt: Fix anas stack-use-after-scope
This commit is contained in:
parent
a5e766b2b4
commit
54d4e77fa5
@ -120,6 +120,83 @@ ConstantFoldingRule FoldExtractWithConstants() {
|
||||
};
|
||||
}
|
||||
|
||||
// Folds an OpcompositeInsert where input is a composite constant.
|
||||
ConstantFoldingRule FoldInsertWithConstants() {
|
||||
return [](IRContext* context, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants)
|
||||
-> const analysis::Constant* {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
const analysis::Constant* object = constants[0];
|
||||
const analysis::Constant* composite = constants[1];
|
||||
if (object == nullptr || composite == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// If there is more than 1 index, then each additional constant used by the
|
||||
// index will need to be recreated to use the inserted object.
|
||||
std::vector<const analysis::Constant*> chain;
|
||||
std::vector<const analysis::Constant*> components;
|
||||
const analysis::Type* type = nullptr;
|
||||
|
||||
// Work down hierarchy and add all the indexes, not including the final
|
||||
// index.
|
||||
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
|
||||
if (i != inst->NumInOperands() - 1) {
|
||||
chain.push_back(composite);
|
||||
}
|
||||
const uint32_t index = inst->GetSingleWordInOperand(i);
|
||||
components = composite->AsCompositeConstant()->GetComponents();
|
||||
type = composite->AsCompositeConstant()->type();
|
||||
composite = components[index];
|
||||
}
|
||||
|
||||
// Final index in hierarchy is inserted with new object.
|
||||
const uint32_t final_index =
|
||||
inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
|
||||
std::vector<uint32_t> ids;
|
||||
for (size_t i = 0; i < components.size(); i++) {
|
||||
const analysis::Constant* constant =
|
||||
(i == final_index) ? object : components[i];
|
||||
Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
|
||||
ids.push_back(member_inst->result_id());
|
||||
}
|
||||
const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
|
||||
|
||||
// Work backwards up the chain and replace each index with new constant.
|
||||
for (size_t i = chain.size(); i > 0; i--) {
|
||||
// Need to insert any previous instruction into the module first.
|
||||
// Can't just insert in types_values_begin() because it will move above
|
||||
// where the types are declared
|
||||
for (Module::inst_iterator inst_iter = context->types_values_begin();
|
||||
inst_iter != context->types_values_end(); ++inst_iter) {
|
||||
Instruction* x = &*inst_iter;
|
||||
if (inst->result_id() == x->result_id()) {
|
||||
const_mgr->BuildInstructionAndAddToModule(new_constant, &inst_iter);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
composite = chain[i - 1];
|
||||
components = composite->AsCompositeConstant()->GetComponents();
|
||||
type = composite->AsCompositeConstant()->type();
|
||||
ids.clear();
|
||||
for (size_t k = 0; k < components.size(); k++) {
|
||||
const uint32_t index =
|
||||
inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
|
||||
const analysis::Constant* constant =
|
||||
(k == index) ? new_constant : components[k];
|
||||
const uint32_t constant_id =
|
||||
const_mgr->FindDeclaredConstant(constant, 0);
|
||||
ids.push_back(constant_id);
|
||||
}
|
||||
new_constant = const_mgr->GetConstant(type, ids);
|
||||
}
|
||||
|
||||
// If multiple constants were created, only need to return the top index.
|
||||
return new_constant;
|
||||
};
|
||||
}
|
||||
|
||||
ConstantFoldingRule FoldVectorShuffleWithConstants() {
|
||||
return [](IRContext* context, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants)
|
||||
@ -1410,6 +1487,7 @@ void ConstantFoldingRules::AddFoldingRules() {
|
||||
rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
|
||||
|
||||
rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
|
||||
rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
|
||||
|
||||
rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
|
||||
rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
|
||||
|
@ -308,6 +308,72 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeExtractMaxtrix) {
|
||||
builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
|
||||
}
|
||||
|
||||
TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertVector) {
|
||||
const std::string test =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %1 "main"
|
||||
OpExecutionMode %1 LocalSize 1 1 1
|
||||
%void = OpTypeVoid
|
||||
%3 = OpTypeFunction %void
|
||||
%uint = OpTypeInt 32 0
|
||||
%v3uint = OpTypeVector %uint 3
|
||||
%uint_2 = OpConstant %uint 2
|
||||
%uint_3 = OpConstant %uint 3
|
||||
%8 = OpConstantNull %uint
|
||||
%9 = OpSpecConstantComposite %v3uint %uint_2 %uint_2 %uint_2
|
||||
; CHECK: %15 = OpConstantComposite %v3uint %uint_3 %uint_2 %uint_2
|
||||
; CHECK: %uint_3_0 = OpConstant %uint 3
|
||||
; CHECK: %17 = OpConstantComposite %v3uint %8 %uint_2 %uint_2
|
||||
; CHECK: %18 = OpConstantNull %uint
|
||||
%10 = OpSpecConstantOp %v3uint CompositeInsert %uint_3 %9 0
|
||||
%11 = OpSpecConstantOp %uint CompositeExtract %10 0
|
||||
%12 = OpSpecConstantOp %v3uint CompositeInsert %8 %9 0
|
||||
%13 = OpSpecConstantOp %uint CompositeExtract %12 0
|
||||
%1 = OpFunction %void None %3
|
||||
%14 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
|
||||
}
|
||||
|
||||
TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrix) {
|
||||
const std::string test =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %1 "main"
|
||||
OpExecutionMode %1 LocalSize 1 1 1
|
||||
%void = OpTypeVoid
|
||||
%3 = OpTypeFunction %void
|
||||
%float = OpTypeFloat 32
|
||||
%v3float = OpTypeVector %float 3
|
||||
%mat3v3float = OpTypeMatrix %v3float 3
|
||||
%float_1 = OpConstant %float 1
|
||||
%float_2 = OpConstant %float 2
|
||||
%9 = OpSpecConstantComposite %v3float %float_1 %float_1 %float_1
|
||||
%10 = OpSpecConstantComposite %v3float %float_1 %float_1 %float_1
|
||||
%11 = OpSpecConstantComposite %v3float %float_1 %float_2 %float_1
|
||||
%12 = OpSpecConstantComposite %mat3v3float %9 %10 %11
|
||||
; CHECK: %float_2_0 = OpConstant %float 2
|
||||
; CHECK: %18 = OpConstantComposite %v3float %float_1 %float_1 %float_2
|
||||
; CHECK: %19 = OpConstantComposite %mat3v3float %9 %18 %11
|
||||
; CHECK: %float_2_1 = OpConstant %float 2
|
||||
%13 = OpSpecConstantOp %float CompositeExtract %12 2 1
|
||||
%14 = OpSpecConstantOp %mat3v3float CompositeInsert %13 %12 1 2
|
||||
%15 = OpSpecConstantOp %float CompositeExtract %14 1 2
|
||||
%1 = OpFunction %void None %3
|
||||
%16 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
|
||||
}
|
||||
|
||||
// All types and some common constants that are potentially required in
|
||||
// FoldSpecConstantOpAndCompositeTest.
|
||||
std::vector<std::string> CommonTypesAndConstants() {
|
||||
|
Loading…
Reference in New Issue
Block a user