From e28edd458b729da7bbfd51e375feb33103709e6f Mon Sep 17 00:00:00 2001 From: GregF Date: Fri, 17 Nov 2017 16:47:11 -0700 Subject: [PATCH] Optimize loads/stores on nested structs Also fix LocalAccessChainConvert test: nested structs now convert Add InsertExtractElim test for nested struct --- source/opt/insert_extract_elim.cpp | 49 +++++-- source/opt/insert_extract_elim.h | 27 ++-- source/opt/mem_pass.cpp | 2 +- test/opt/insert_extract_elim_test.cpp | 129 +++++++++++++++++++ test/opt/local_access_chain_convert_test.cpp | 57 +++++++- 5 files changed, 234 insertions(+), 30 deletions(-) diff --git a/source/opt/insert_extract_elim.cpp b/source/opt/insert_extract_elim.cpp index d3c648548..b506d9ebb 100644 --- a/source/opt/insert_extract_elim.cpp +++ b/source/opt/insert_extract_elim.cpp @@ -31,24 +31,28 @@ const uint32_t kInsertCompositeIdInIdx = 1; } // anonymous namespace bool InsertExtractElimPass::ExtInsMatch(const ir::Instruction* extInst, - const ir::Instruction* insInst) const { - if (extInst->NumInOperands() != insInst->NumInOperands() - 1) return false; - uint32_t numIdx = extInst->NumInOperands() - 1; + const ir::Instruction* insInst, + const uint32_t extOffset) const { + if (extInst->NumInOperands() - extOffset != insInst->NumInOperands() - 1) + return false; + uint32_t numIdx = extInst->NumInOperands() - 1 - extOffset; for (uint32_t i = 0; i < numIdx; ++i) - if (extInst->GetSingleWordInOperand(i + 1) != + if (extInst->GetSingleWordInOperand(i + 1 + extOffset) != insInst->GetSingleWordInOperand(i + 2)) return false; return true; } bool InsertExtractElimPass::ExtInsConflict( - const ir::Instruction* extInst, const ir::Instruction* insInst) const { - if (extInst->NumInOperands() == insInst->NumInOperands() - 1) return false; - uint32_t extNumIdx = extInst->NumInOperands() - 1; + const ir::Instruction* extInst, const ir::Instruction* insInst, + const uint32_t extOffset) const { + if (extInst->NumInOperands() - extOffset == insInst->NumInOperands() - 1) + return false; + uint32_t extNumIdx = extInst->NumInOperands() - 1 - extOffset; uint32_t insNumIdx = insInst->NumInOperands() - 2; uint32_t numIdx = std::min(extNumIdx, insNumIdx); for (uint32_t i = 0; i < numIdx; ++i) - if (extInst->GetSingleWordInOperand(i + 1) != + if (extInst->GetSingleWordInOperand(i + 1 + extOffset) != insInst->GetSingleWordInOperand(i + 2)) return false; return true; @@ -68,13 +72,32 @@ bool InsertExtractElimPass::EliminateInsertExtract(ir::Function* func) { uint32_t cid = ii->GetSingleWordInOperand(kExtractCompositeIdInIdx); ir::Instruction* cinst = get_def_use_mgr()->GetDef(cid); uint32_t replId = 0; + // Offset of extract indices being compared to insert indices. + // Offset increases as indices are matched. + uint32_t extOffset = 0; while (cinst->opcode() == SpvOpCompositeInsert) { - if (ExtInsConflict(&*ii, cinst)) break; - if (ExtInsMatch(&*ii, cinst)) { + if (ExtInsMatch(&*ii, cinst, extOffset)) { + // Match! Use inserted value as replacement replId = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx); break; } - cid = cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx); + else if (ExtInsConflict(&*ii, cinst, extOffset)) { + // If extract has fewer indices than the insert, stop searching. + // Otherwise increment offset of extract indices considered and + // continue searching through the inserted value + if (ii->NumInOperands() - extOffset < + cinst->NumInOperands() - 1) { + break; + } + else { + extOffset += cinst->NumInOperands() - 2; + cid = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx); + } + } + else { + // Consider next composite in insert chain + cid = cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx); + } cinst = get_def_use_mgr()->GetDef(cid); } // If search ended with CompositeConstruct or ConstantComposite @@ -85,8 +108,8 @@ bool InsertExtractElimPass::EliminateInsertExtract(ir::Function* func) { // vector composition, and additional CompositeInsert. if ((cinst->opcode() == SpvOpCompositeConstruct || cinst->opcode() == SpvOpConstantComposite) && - (*ii).NumInOperands() == 2) { - uint32_t compIdx = (*ii).GetSingleWordInOperand(1); + (*ii).NumInOperands() - extOffset == 2) { + uint32_t compIdx = (*ii).GetSingleWordInOperand(extOffset + 1); if (IsVectorType(cinst->type_id())) { if (compIdx < cinst->NumInOperands()) { uint32_t i = 0; diff --git a/source/opt/insert_extract_elim.h b/source/opt/insert_extract_elim.h index d5dba00f0..cb0f4f3fd 100644 --- a/source/opt/insert_extract_elim.h +++ b/source/opt/insert_extract_elim.h @@ -40,25 +40,28 @@ class InsertExtractElimPass : public Pass { Status Process(ir::IRContext*) override; private: - // Return true if indices of extract |extInst| and insert |insInst| match + // Return true if indices of extract |extInst| starting at |extOffset| + // match indices of insert |insInst|. bool ExtInsMatch(const ir::Instruction* extInst, - const ir::Instruction* insInst) const; + const ir::Instruction* insInst, + const uint32_t extOffset) const; - // Return true if indices of extract |extInst| and insert |insInst| conflict, - // specifically, if the insert changes bits specified by the extract, but - // changes either more bits or less bits than the extract specifies, - // meaning the exact value being inserted cannot be used to replace - // the extract. + // Return true if indices of extract |extInst| starting at |extOffset| and + // indices of insert |insInst| conflict, specifically, if the insert + // changes bits specified by the extract, but changes either more bits + // or less bits than the extract specifies, meaning the exact value being + // inserted cannot be used to replace the extract. bool ExtInsConflict(const ir::Instruction* extInst, - const ir::Instruction* insInst) const; + const ir::Instruction* insInst, + const uint32_t extOffset) const; // Return true if |typeId| is a vector type bool IsVectorType(uint32_t typeId); - // Look for OpExtract on sequence of OpInserts in |func|. If there is an - // insert with identical indices, replace the extract with the value - // that is inserted if possible. Specifically, replace if there is no - // intervening insert which conflicts. + // Look for OpExtract on sequence of OpInserts in |func|. If there is a + // reaching insert which corresponds to the indices of the extract, replace + // the extract with the value that is inserted. Also resolve extracts from + // CompositeConstruct or ConstantComposite. bool EliminateInsertExtract(ir::Function* func); // Initialize extensions whitelist diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index ae5baa816..1292804a1 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp @@ -64,7 +64,7 @@ bool MemPass::IsTargetType(const ir::Instruction* typeInst) const { int nonMathComp = 0; typeInst->ForEachInId([&nonMathComp, this](const uint32_t* tid) { ir::Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid); - if (!IsBaseTargetType(compTypeInst)) ++nonMathComp; + if (!IsTargetType(compTypeInst)) ++nonMathComp; }); return nonMathComp == 0; } diff --git a/test/opt/insert_extract_elim_test.cpp b/test/opt/insert_extract_elim_test.cpp index 64dcef92d..bc90a3002 100644 --- a/test/opt/insert_extract_elim_test.cpp +++ b/test/opt/insert_extract_elim_test.cpp @@ -270,6 +270,135 @@ OpFunctionEnd predefs + after, true, true); } +TEST_F(InsertExtractElimTest, OptimizeNestedStruct) { + // The following HLSL has been pre-optimized to get the SPIR-V: + // struct S0 + // { + // int x; + // SamplerState ss; + // }; + // + // struct S1 + // { + // float b; + // S0 s0; + // }; + // + // struct S2 + // { + // int a1; + // S1 resources; + // }; + // + // SamplerState samp; + // Texture2D tex; + // + // float4 main(float4 vpos : VPOS) : COLOR0 + // { + // S1 s1; + // S2 s2; + // s1.s0.ss = samp; + // s2.resources = s1; + // return tex.Sample(s2.resources.s0.ss, float2(0.5)); + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %_entryPointOutput +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 500 +OpName %main "main" +OpName %S0 "S0" +OpMemberName %S0 0 "x" +OpMemberName %S0 1 "ss" +OpName %S1 "S1" +OpMemberName %S1 0 "b" +OpMemberName %S1 1 "s0" +OpName %samp "samp" +OpName %S2 "S2" +OpMemberName %S2 0 "a1" +OpMemberName %S2 1 "resources" +OpName %tex "tex" +OpName %_entryPointOutput "@entryPointOutput" +OpDecorate %samp DescriptorSet 0 +OpDecorate %tex DescriptorSet 0 +OpDecorate %_entryPointOutput Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%14 = OpTypeFunction %v4float %_ptr_Function_v4float +%int = OpTypeInt 32 1 +%16 = OpTypeSampler +%S0 = OpTypeStruct %int %16 +%S1 = OpTypeStruct %float %S0 +%_ptr_Function_S1 = OpTypePointer Function %S1 +%int_1 = OpConstant %int 1 +%_ptr_UniformConstant_16 = OpTypePointer UniformConstant %16 +%samp = OpVariable %_ptr_UniformConstant_16 UniformConstant +%_ptr_Function_16 = OpTypePointer Function %16 +%S2 = OpTypeStruct %int %S1 +%_ptr_Function_S2 = OpTypePointer Function %S2 +%22 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_22 = OpTypePointer UniformConstant %22 +%tex = OpVariable %_ptr_UniformConstant_22 UniformConstant +%24 = OpTypeSampledImage %22 +%v2float = OpTypeVector %float 2 +%float_0_5 = OpConstant %float 0.5 +%27 = OpConstantComposite %v2float %float_0_5 %float_0_5 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%30 = OpLabel +%31 = OpVariable %_ptr_Function_S1 Function +%32 = OpVariable %_ptr_Function_S2 Function +%33 = OpLoad %16 %samp +%34 = OpLoad %S1 %31 +%35 = OpCompositeInsert %S1 %33 %34 1 1 +OpStore %31 %35 +%36 = OpLoad %S2 %32 +%37 = OpCompositeInsert %S2 %35 %36 1 +OpStore %32 %37 +%38 = OpLoad %22 %tex +%39 = OpCompositeExtract %16 %37 1 1 1 +%40 = OpSampledImage %24 %38 %39 +%41 = OpImageSampleImplicitLod %v4float %40 %27 +OpStore %_entryPointOutput %41 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%30 = OpLabel +%31 = OpVariable %_ptr_Function_S1 Function +%32 = OpVariable %_ptr_Function_S2 Function +%33 = OpLoad %16 %samp +%34 = OpLoad %S1 %31 +%35 = OpCompositeInsert %S1 %33 %34 1 1 +OpStore %31 %35 +%36 = OpLoad %S2 %32 +%37 = OpCompositeInsert %S2 %35 %36 1 +OpStore %32 %37 +%38 = OpLoad %22 %tex +%40 = OpSampledImage %24 %38 %33 +%41 = OpImageSampleImplicitLod %v4float %40 %27 +OpStore %_entryPointOutput %41 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + TEST_F(InsertExtractElimTest, ConflictingInsertPreventsOptimization) { // Note: The SPIR-V assembly has had store/load elimination // performed to allow the inserts and extracts to directly diff --git a/test/opt/local_access_chain_convert_test.cpp b/test/opt/local_access_chain_convert_test.cpp index 3b5a2fa9f..a450e6b2c 100644 --- a/test/opt/local_access_chain_convert_test.cpp +++ b/test/opt/local_access_chain_convert_test.cpp @@ -459,7 +459,7 @@ OpFunctionEnd } TEST_F(LocalAccessChainConvertTest, - UntargetedTypeNotConverted) { + NestedStructsConverted) { // #version 140 // @@ -481,7 +481,7 @@ TEST_F(LocalAccessChainConvertTest, // gl_FragColor = s2.s1.v1; // } - const std::string assembly = + const std::string predefs_before = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -512,7 +512,41 @@ OpName %gl_FragColor "gl_FragColor" %_ptr_Function_v4float = OpTypePointer Function %v4float %_ptr_Output_v4float = OpTypePointer Output %v4float %gl_FragColor = OpVariable %_ptr_Output_v4float Output -%main = OpFunction %void None %9 +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S1_t "S1_t" +OpMemberName %S1_t 0 "v1" +OpName %S2_t "S2_t" +OpMemberName %S2_t 0 "v2" +OpMemberName %S2_t 1 "s1" +OpName %s2 "s2" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S1_t = OpTypeStruct %v4float +%S2_t = OpTypeStruct %v4float %S1_t +%_ptr_Function_S2_t = OpTypePointer Function %S2_t +%int = OpTypeInt 32 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 %19 = OpLabel %s2 = OpVariable %_ptr_Function_S2_t Function %20 = OpLoad %v4float %BaseColor @@ -523,10 +557,25 @@ OpStore %21 %20 OpStore %gl_FragColor %23 OpReturn OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%s2 = OpVariable %_ptr_Function_S2_t Function +%20 = OpLoad %v4float %BaseColor +%24 = OpLoad %S2_t %s2 +%25 = OpCompositeInsert %S2_t %20 %24 1 0 +OpStore %s2 %25 +%26 = OpLoad %S2_t %s2 +%27 = OpCompositeExtract %v4float %26 1 0 +OpStore %gl_FragColor %27 +OpReturn +OpFunctionEnd )"; SinglePassRunAndCheck( - assembly, assembly, false, true); + predefs_before + before , predefs_after + after, true, true); } TEST_F(LocalAccessChainConvertTest,