Optimize loads/stores on nested structs

Also fix LocalAccessChainConvert test: nested structs now convert

Add InsertExtractElim test for nested struct
This commit is contained in:
GregF 2017-11-17 16:47:11 -07:00 committed by David Neto
parent b14291581f
commit e28edd458b
5 changed files with 234 additions and 30 deletions

View File

@ -31,24 +31,28 @@ const uint32_t kInsertCompositeIdInIdx = 1;
} // anonymous namespace
bool InsertExtractElimPass::ExtInsMatch(const ir::Instruction* extInst,
const ir::Instruction* insInst) const {
if (extInst->NumInOperands() != insInst->NumInOperands() - 1) return false;
uint32_t numIdx = extInst->NumInOperands() - 1;
const ir::Instruction* insInst,
const uint32_t extOffset) const {
if (extInst->NumInOperands() - extOffset != insInst->NumInOperands() - 1)
return false;
uint32_t numIdx = extInst->NumInOperands() - 1 - extOffset;
for (uint32_t i = 0; i < numIdx; ++i)
if (extInst->GetSingleWordInOperand(i + 1) !=
if (extInst->GetSingleWordInOperand(i + 1 + extOffset) !=
insInst->GetSingleWordInOperand(i + 2))
return false;
return true;
}
bool InsertExtractElimPass::ExtInsConflict(
const ir::Instruction* extInst, const ir::Instruction* insInst) const {
if (extInst->NumInOperands() == insInst->NumInOperands() - 1) return false;
uint32_t extNumIdx = extInst->NumInOperands() - 1;
const ir::Instruction* extInst, const ir::Instruction* insInst,
const uint32_t extOffset) const {
if (extInst->NumInOperands() - extOffset == insInst->NumInOperands() - 1)
return false;
uint32_t extNumIdx = extInst->NumInOperands() - 1 - extOffset;
uint32_t insNumIdx = insInst->NumInOperands() - 2;
uint32_t numIdx = std::min(extNumIdx, insNumIdx);
for (uint32_t i = 0; i < numIdx; ++i)
if (extInst->GetSingleWordInOperand(i + 1) !=
if (extInst->GetSingleWordInOperand(i + 1 + extOffset) !=
insInst->GetSingleWordInOperand(i + 2))
return false;
return true;
@ -68,13 +72,32 @@ bool InsertExtractElimPass::EliminateInsertExtract(ir::Function* func) {
uint32_t cid = ii->GetSingleWordInOperand(kExtractCompositeIdInIdx);
ir::Instruction* cinst = get_def_use_mgr()->GetDef(cid);
uint32_t replId = 0;
// Offset of extract indices being compared to insert indices.
// Offset increases as indices are matched.
uint32_t extOffset = 0;
while (cinst->opcode() == SpvOpCompositeInsert) {
if (ExtInsConflict(&*ii, cinst)) break;
if (ExtInsMatch(&*ii, cinst)) {
if (ExtInsMatch(&*ii, cinst, extOffset)) {
// Match! Use inserted value as replacement
replId = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx);
break;
}
cid = cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
else if (ExtInsConflict(&*ii, cinst, extOffset)) {
// If extract has fewer indices than the insert, stop searching.
// Otherwise increment offset of extract indices considered and
// continue searching through the inserted value.
if (ii->NumInOperands() - extOffset <
cinst->NumInOperands() - 1) {
break;
}
else {
extOffset += cinst->NumInOperands() - 2;
cid = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx);
}
}
else {
// Consider next composite in insert chain
cid = cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
}
cinst = get_def_use_mgr()->GetDef(cid);
}
// If search ended with CompositeConstruct or ConstantComposite
@ -85,8 +108,8 @@ bool InsertExtractElimPass::EliminateInsertExtract(ir::Function* func) {
// vector composition, and additional CompositeInsert.
if ((cinst->opcode() == SpvOpCompositeConstruct ||
cinst->opcode() == SpvOpConstantComposite) &&
(*ii).NumInOperands() == 2) {
uint32_t compIdx = (*ii).GetSingleWordInOperand(1);
(*ii).NumInOperands() - extOffset == 2) {
uint32_t compIdx = (*ii).GetSingleWordInOperand(extOffset + 1);
if (IsVectorType(cinst->type_id())) {
if (compIdx < cinst->NumInOperands()) {
uint32_t i = 0;

View File

@ -40,25 +40,28 @@ class InsertExtractElimPass : public Pass {
Status Process(ir::IRContext*) override;
private:
// Return true if indices of extract |extInst| and insert |insInst| match
// Return true if indices of extract |extInst| starting at |extOffset|
// match indices of insert |insInst|.
bool ExtInsMatch(const ir::Instruction* extInst,
const ir::Instruction* insInst) const;
const ir::Instruction* insInst,
const uint32_t extOffset) const;
// Return true if indices of extract |extInst| and insert |insInst| conflict,
// specifically, if the insert changes bits specified by the extract, but
// changes either more bits or less bits than the extract specifies,
// meaning the exact value being inserted cannot be used to replace
// the extract.
// Return true if indices of extract |extInst| starting at |extOffset| and
// indices of insert |insInst| conflict, specifically, if the insert
// changes bits specified by the extract, but changes either more bits
// or fewer bits than the extract specifies, meaning the exact value being
// inserted cannot be used to replace the extract.
bool ExtInsConflict(const ir::Instruction* extInst,
const ir::Instruction* insInst) const;
const ir::Instruction* insInst,
const uint32_t extOffset) const;
// Return true if |typeId| is a vector type
bool IsVectorType(uint32_t typeId);
// Look for OpExtract on sequence of OpInserts in |func|. If there is an
// insert with identical indices, replace the extract with the value
// that is inserted if possible. Specifically, replace if there is no
// intervening insert which conflicts.
// Look for OpExtract on sequence of OpInserts in |func|. If there is a
// reaching insert which corresponds to the indices of the extract, replace
// the extract with the value that is inserted. Also resolve extracts from
// CompositeConstruct or ConstantComposite.
bool EliminateInsertExtract(ir::Function* func);
// Initialize extensions whitelist

View File

@ -64,7 +64,7 @@ bool MemPass::IsTargetType(const ir::Instruction* typeInst) const {
int nonMathComp = 0;
typeInst->ForEachInId([&nonMathComp, this](const uint32_t* tid) {
ir::Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid);
if (!IsBaseTargetType(compTypeInst)) ++nonMathComp;
if (!IsTargetType(compTypeInst)) ++nonMathComp;
});
return nonMathComp == 0;
}

View File

@ -270,6 +270,135 @@ OpFunctionEnd
predefs + after, true, true);
}
TEST_F(InsertExtractElimTest, OptimizeNestedStruct) {
// Checks that an extract whose indices reach through a chain of nested
// inserts is replaced by the innermost inserted value.
//
// The following HLSL has been pre-optimized to get the SPIR-V:
// struct S0
// {
// int x;
// SamplerState ss;
// };
//
// struct S1
// {
// float b;
// S0 s0;
// };
//
// struct S2
// {
// int a1;
// S1 resources;
// };
//
// SamplerState samp;
// Texture2D tex;
//
// float4 main(float4 vpos : VPOS) : COLOR0
// {
// S1 s1;
// S2 s2;
// s1.s0.ss = samp;
// s2.resources = s1;
// return tex.Sample(s2.resources.s0.ss, float2(0.5));
// }
// Module preamble (capabilities, names, decorations, types, constants and
// global variables) shared by the input and the expected output.
const std::string predefs =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %_entryPointOutput
OpExecutionMode %main OriginUpperLeft
OpSource HLSL 500
OpName %main "main"
OpName %S0 "S0"
OpMemberName %S0 0 "x"
OpMemberName %S0 1 "ss"
OpName %S1 "S1"
OpMemberName %S1 0 "b"
OpMemberName %S1 1 "s0"
OpName %samp "samp"
OpName %S2 "S2"
OpMemberName %S2 0 "a1"
OpMemberName %S2 1 "resources"
OpName %tex "tex"
OpName %_entryPointOutput "@entryPointOutput"
OpDecorate %samp DescriptorSet 0
OpDecorate %tex DescriptorSet 0
OpDecorate %_entryPointOutput Location 0
%void = OpTypeVoid
%10 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_ptr_Function_v4float = OpTypePointer Function %v4float
%14 = OpTypeFunction %v4float %_ptr_Function_v4float
%int = OpTypeInt 32 1
%16 = OpTypeSampler
%S0 = OpTypeStruct %int %16
%S1 = OpTypeStruct %float %S0
%_ptr_Function_S1 = OpTypePointer Function %S1
%int_1 = OpConstant %int 1
%_ptr_UniformConstant_16 = OpTypePointer UniformConstant %16
%samp = OpVariable %_ptr_UniformConstant_16 UniformConstant
%_ptr_Function_16 = OpTypePointer Function %16
%S2 = OpTypeStruct %int %S1
%_ptr_Function_S2 = OpTypePointer Function %S2
%22 = OpTypeImage %float 2D 0 0 0 1 Unknown
%_ptr_UniformConstant_22 = OpTypePointer UniformConstant %22
%tex = OpVariable %_ptr_UniformConstant_22 UniformConstant
%24 = OpTypeSampledImage %22
%v2float = OpTypeVector %float 2
%float_0_5 = OpConstant %float 0.5
%27 = OpConstantComposite %v2float %float_0_5 %float_0_5
%_ptr_Input_v4float = OpTypePointer Input %v4float
%_ptr_Output_v4float = OpTypePointer Output %v4float
%_entryPointOutput = OpVariable %_ptr_Output_v4float Output
)";
// Input function body. %39 extracts indices 1 1 1 from %37. %37 inserts %35
// at index 1, and %35 in turn inserts the sampler %33 at indices 1 1, so the
// extract's remaining indices (1 1) line up with the inner insert.
const std::string before =
R"(%main = OpFunction %void None %10
%30 = OpLabel
%31 = OpVariable %_ptr_Function_S1 Function
%32 = OpVariable %_ptr_Function_S2 Function
%33 = OpLoad %16 %samp
%34 = OpLoad %S1 %31
%35 = OpCompositeInsert %S1 %33 %34 1 1
OpStore %31 %35
%36 = OpLoad %S2 %32
%37 = OpCompositeInsert %S2 %35 %36 1
OpStore %32 %37
%38 = OpLoad %22 %tex
%39 = OpCompositeExtract %16 %37 1 1 1
%40 = OpSampledImage %24 %38 %39
%41 = OpImageSampleImplicitLod %v4float %40 %27
OpStore %_entryPointOutput %41
OpReturn
OpFunctionEnd
)";
// Expected function body. The extract %39 is gone and %40 uses the inserted
// sampler %33 directly; every other instruction is unchanged.
const std::string after =
R"(%main = OpFunction %void None %10
%30 = OpLabel
%31 = OpVariable %_ptr_Function_S1 Function
%32 = OpVariable %_ptr_Function_S2 Function
%33 = OpLoad %16 %samp
%34 = OpLoad %S1 %31
%35 = OpCompositeInsert %S1 %33 %34 1 1
OpStore %31 %35
%36 = OpLoad %S2 %32
%37 = OpCompositeInsert %S2 %35 %36 1
OpStore %32 %37
%38 = OpLoad %22 %tex
%40 = OpSampledImage %24 %38 %33
%41 = OpImageSampleImplicitLod %v4float %40 %27
OpStore %_entryPointOutput %41
OpReturn
OpFunctionEnd
)";
// Run the pass over the input module and compare against the expected module.
SinglePassRunAndCheck<opt::InsertExtractElimPass>(predefs + before,
predefs + after, true, true);
}
TEST_F(InsertExtractElimTest, ConflictingInsertPreventsOptimization) {
// Note: The SPIR-V assembly has had store/load elimination
// performed to allow the inserts and extracts to directly

View File

@ -459,7 +459,7 @@ OpFunctionEnd
}
TEST_F(LocalAccessChainConvertTest,
UntargetedTypeNotConverted) {
NestedStructsConverted) {
// #version 140
//
@ -481,7 +481,7 @@ TEST_F(LocalAccessChainConvertTest,
// gl_FragColor = s2.s1.v1;
// }
const std::string assembly =
const std::string predefs_before =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
@ -512,7 +512,41 @@ OpName %gl_FragColor "gl_FragColor"
%_ptr_Function_v4float = OpTypePointer Function %v4float
%_ptr_Output_v4float = OpTypePointer Output %v4float
%gl_FragColor = OpVariable %_ptr_Output_v4float Output
%main = OpFunction %void None %9
)";
const std::string predefs_after =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
OpExecutionMode %main OriginUpperLeft
OpSource GLSL 140
OpName %main "main"
OpName %S1_t "S1_t"
OpMemberName %S1_t 0 "v1"
OpName %S2_t "S2_t"
OpMemberName %S2_t 0 "v2"
OpMemberName %S2_t 1 "s1"
OpName %s2 "s2"
OpName %BaseColor "BaseColor"
OpName %gl_FragColor "gl_FragColor"
%void = OpTypeVoid
%9 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%S1_t = OpTypeStruct %v4float
%S2_t = OpTypeStruct %v4float %S1_t
%_ptr_Function_S2_t = OpTypePointer Function %S2_t
%int = OpTypeInt 32 1
%_ptr_Input_v4float = OpTypePointer Input %v4float
%BaseColor = OpVariable %_ptr_Input_v4float Input
%_ptr_Function_v4float = OpTypePointer Function %v4float
%_ptr_Output_v4float = OpTypePointer Output %v4float
%gl_FragColor = OpVariable %_ptr_Output_v4float Output
)";
const std::string before =
R"(%main = OpFunction %void None %9
%19 = OpLabel
%s2 = OpVariable %_ptr_Function_S2_t Function
%20 = OpLoad %v4float %BaseColor
@ -523,10 +557,25 @@ OpStore %21 %20
OpStore %gl_FragColor %23
OpReturn
OpFunctionEnd
)";
const std::string after =
R"(%main = OpFunction %void None %9
%19 = OpLabel
%s2 = OpVariable %_ptr_Function_S2_t Function
%20 = OpLoad %v4float %BaseColor
%24 = OpLoad %S2_t %s2
%25 = OpCompositeInsert %S2_t %20 %24 1 0
OpStore %s2 %25
%26 = OpLoad %S2_t %s2
%27 = OpCompositeExtract %v4float %26 1 0
OpStore %gl_FragColor %27
OpReturn
OpFunctionEnd
)";
SinglePassRunAndCheck<opt::LocalAccessChainConvertPass>(
assembly, assembly, false, true);
predefs_before + before , predefs_after + after, true, true);
}
TEST_F(LocalAccessChainConvertTest,