diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp index 1f25b33b8..e1bb61d79 100644 --- a/source/opt/desc_sroa.cpp +++ b/source/opt/desc_sroa.cpp @@ -56,7 +56,23 @@ bool DescriptorScalarReplacement::IsCandidate(Instruction* var) { uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1); Instruction* var_type_inst = context()->get_def_use_mgr()->GetDef(var_type_id); - if (var_type_inst->opcode() != SpvOpTypeArray) { + if (var_type_inst->opcode() != SpvOpTypeArray && + var_type_inst->opcode() != SpvOpTypeStruct) { + return false; + } + + // All structures with descriptor assignments must be replaced by variables, + // one for each of their members - with the exceptions of buffers. + // Buffers are represented as structures, but we shouldn't replace a buffer + // with its elements. All buffers have offset decorations for members of their + // structure types. + bool has_offset_decoration = false; + context()->get_decoration_mgr()->ForEachDecoration( + var_type_inst->result_id(), SpvDecorationOffset, + [&has_offset_decoration](const Instruction&) { + has_offset_decoration = true; + }); + if (has_offset_decoration) { return false; } @@ -177,21 +193,36 @@ uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var, uint32_t ptr_type_id = var->type_id(); Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id); assert(ptr_type_inst->opcode() == SpvOpTypePointer && - "Variable should be a pointer to an array."); - uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1); - Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id); - assert(arr_type_inst->opcode() == SpvOpTypeArray && - "Variable should be a pointer to an array."); + "Variable should be a pointer to an array or structure."); + uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1); + Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(pointee_type_id); + const bool is_array = pointee_type_inst->opcode() == SpvOpTypeArray; + const bool is_struct = pointee_type_inst->opcode() == SpvOpTypeStruct; + assert((is_array || is_struct) && + "Variable should be a pointer to an array or structure."); - uint32_t array_len_id = arr_type_inst->GetSingleWordInOperand(1); - const analysis::Constant* array_len_const = - context()->get_constant_mgr()->FindDeclaredConstant(array_len_id); - assert(array_len_const != nullptr && "Array length must be a constant."); - uint32_t array_len = array_len_const->GetU32(); + // For arrays, each array element should be replaced with a new replacement + // variable + if (is_array) { + uint32_t array_len_id = pointee_type_inst->GetSingleWordInOperand(1); + const analysis::Constant* array_len_const = + context()->get_constant_mgr()->FindDeclaredConstant(array_len_id); + assert(array_len_const != nullptr && "Array length must be a constant."); + uint32_t array_len = array_len_const->GetU32(); - replacement_vars = replacement_variables_ - .insert({var, std::vector(array_len, 0)}) - .first; + replacement_vars = replacement_variables_ + .insert({var, std::vector(array_len, 0)}) + .first; + } + // For structures, each member should be replaced with a new replacement + // variable + if (is_struct) { + const uint32_t num_members = pointee_type_inst->NumInOperands(); + replacement_vars = + replacement_variables_ + .insert({var, std::vector(num_members, 0)}) + .first; + } } if (replacement_vars->second[idx] == 0) { @@ -212,12 +243,17 @@ uint32_t DescriptorScalarReplacement::CreateReplacementVariable( uint32_t ptr_type_id = var->type_id(); Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id); assert(ptr_type_inst->opcode() == SpvOpTypePointer && - "Variable should be a pointer to an array."); - uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1); - Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id); - assert(arr_type_inst->opcode() == SpvOpTypeArray && - "Variable should be a pointer to an array."); - uint32_t element_type_id = arr_type_inst->GetSingleWordInOperand(0); + "Variable should be a pointer to an array or structure."); + uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1); + Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(pointee_type_id); + const bool is_array = pointee_type_inst->opcode() == SpvOpTypeArray; + const bool is_struct = pointee_type_inst->opcode() == SpvOpTypeStruct; + assert((is_array || is_struct) && + "Variable should be a pointer to an array or structure."); + + uint32_t element_type_id = + is_array ? pointee_type_inst->GetSingleWordInOperand(0) + : pointee_type_inst->GetSingleWordInOperand(idx); uint32_t ptr_element_type_id = context()->get_type_mgr()->FindPointerToType( element_type_id, storage_class); @@ -242,19 +278,33 @@ uint32_t DescriptorScalarReplacement::CreateReplacementVariable( uint32_t decoration = new_decoration->GetSingleWordInOperand(1u); if (decoration == SpvDecorationBinding) { - uint32_t new_binding = new_decoration->GetSingleWordInOperand(2) + idx; + uint32_t new_binding = + new_decoration->GetSingleWordInOperand(2) + + idx * GetNumBindingsUsedByType(ptr_element_type_id); new_decoration->SetInOperand(2, {new_binding}); } context()->AddAnnotationInst(std::move(new_decoration)); } // Create a new OpName for the replacement variable. + std::vector> names_to_add; for (auto p : context()->GetNames(var->result_id())) { Instruction* name_inst = p.second; std::string name_str = utils::MakeString(name_inst->GetOperand(1).words); - name_str += "["; - name_str += utils::ToString(idx); - name_str += "]"; + if (is_array) { + name_str += "[" + utils::ToString(idx) + "]"; + } + if (is_struct) { + Instruction* member_name_inst = + context()->GetMemberName(pointee_type_inst->result_id(), idx); + name_str += "."; + if (member_name_inst) + name_str += utils::MakeString(member_name_inst->GetOperand(2).words); + else + // In case the member does not have a name assigned to it, use the + // member index. + name_str += utils::ToString(idx); + } std::unique_ptr new_name(new Instruction( context(), SpvOpName, 0, 0, @@ -262,12 +312,53 @@ uint32_t DescriptorScalarReplacement::CreateReplacementVariable( {SPV_OPERAND_TYPE_ID, {id}}, {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}})); Instruction* new_name_inst = new_name.get(); - context()->AddDebug2Inst(std::move(new_name)); get_def_use_mgr()->AnalyzeInstDefUse(new_name_inst); + names_to_add.push_back(std::move(new_name)); } + // We shouldn't add the new names when we are iterating over name ranges + // above. We can add all the new names now. + for (auto& new_name : names_to_add) + context()->AddDebug2Inst(std::move(new_name)); + return id; } +uint32_t DescriptorScalarReplacement::GetNumBindingsUsedByType( + uint32_t type_id) { + Instruction* type_inst = get_def_use_mgr()->GetDef(type_id); + + // If it's a pointer, look at the underlying type. + if (type_inst->opcode() == SpvOpTypePointer) { + type_id = type_inst->GetSingleWordInOperand(1); + type_inst = get_def_use_mgr()->GetDef(type_id); + } + + // Arrays consume N*M binding numbers where N is the array length, and M is + // the number of bindings used by each array element. + if (type_inst->opcode() == SpvOpTypeArray) { + uint32_t element_type_id = type_inst->GetSingleWordInOperand(0); + uint32_t length_id = type_inst->GetSingleWordInOperand(1); + const analysis::Constant* length_const = + context()->get_constant_mgr()->FindDeclaredConstant(length_id); + // OpTypeArray's length must always be a constant + assert(length_const != nullptr); + uint32_t num_elems = length_const->GetU32(); + return num_elems * GetNumBindingsUsedByType(element_type_id); + } + + // The number of bindings consumed by a structure is the sum of the bindings + // used by its members. + if (type_inst->opcode() == SpvOpTypeStruct) { + uint32_t sum = 0; + for (uint32_t i = 0; i < type_inst->NumInOperands(); i++) + sum += GetNumBindingsUsedByType(type_inst->GetSingleWordInOperand(i)); + return sum; + } + + // All other types are considered to take up 1 binding number. + return 1; +} + } // namespace opt } // namespace spvtools diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h index a95c6b582..24ec6e2bf 100644 --- a/source/opt/desc_sroa.h +++ b/source/opt/desc_sroa.h @@ -70,6 +70,15 @@ class DescriptorScalarReplacement : public Pass { // element of |var|. uint32_t CreateReplacementVariable(Instruction* var, uint32_t idx); + // Returns the number of bindings used by the given |type_id|. + // All types are considered to use 1 binding slot, except: + // 1- A pointer type consumes as many binding numbers as its pointee. + // 2- An array of size N consumes N*M binding numbers, where M is the number + // of bindings used by each array element. + // 3- The number of bindings consumed by a structure is the sum of the + // bindings used by its members. + uint32_t GetNumBindingsUsedByType(uint32_t type_id); + // A map from an OpVariable instruction to the set of variables that will be // used to replace it. The entry |replacement_variables_[var][i]| is the id of // a variable that will be used in the place of the the ith element of the diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index a1b63ff9e..b19365741 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h @@ -357,6 +357,13 @@ class IRContext { inline IteratorRange::iterator> GetNames(uint32_t id); + // Returns an OpMemberName instruction that targets |struct_type_id| at + // index |index|. Returns nullptr if no such instruction exists. + // While the SPIR-V spec does not prohibit having multiple OpMemberName + // instructions for the same structure member, it is hard to imagine a member + // having more than one name. This method returns the first one it finds. + inline Instruction* GetMemberName(uint32_t struct_type_id, uint32_t index); + // Sets the message consumer to the given |consumer|. |consumer| which will be // invoked every time there is a message to be communicated to the outside. void SetMessageConsumer(MessageConsumer c) { consumer_ = std::move(c); } @@ -1061,7 +1068,9 @@ void IRContext::AddDebug1Inst(std::unique_ptr&& d) { void IRContext::AddDebug2Inst(std::unique_ptr&& d) { if (AreAnalysesValid(kAnalysisNameMap)) { if (d->opcode() == SpvOpName || d->opcode() == SpvOpMemberName) { - id_to_name_->insert({d->result_id(), d.get()}); + // OpName and OpMemberName do not have result-ids. The target of the + // instruction is at InOperand index 0. + id_to_name_->insert({d->GetSingleWordInOperand(0), d.get()}); } } module()->AddDebug2Inst(std::move(d)); @@ -1135,6 +1144,21 @@ IRContext::GetNames(uint32_t id) { return make_range(std::move(result.first), std::move(result.second)); } +Instruction* IRContext::GetMemberName(uint32_t struct_type_id, uint32_t index) { + if (!AreAnalysesValid(kAnalysisNameMap)) { + BuildIdToNameMap(); + } + auto result = id_to_name_->equal_range(struct_type_id); + for (auto i = result.first; i != result.second; ++i) { + auto* name_instr = i->second; + if (name_instr->opcode() == SpvOpMemberName && + name_instr->GetSingleWordInOperand(1) == index) { + return name_instr; + } + } + return nullptr; +} + } // namespace opt } // namespace spvtools diff --git a/test/opt/desc_sroa_test.cpp b/test/opt/desc_sroa_test.cpp index 11074c347..f12d39e87 100644 --- a/test/opt/desc_sroa_test.cpp +++ b/test/opt/desc_sroa_test.cpp @@ -25,7 +25,116 @@ namespace { using DescriptorScalarReplacementTest = PassTest<::testing::Test>; -TEST_F(DescriptorScalarReplacementTest, ExpandTexture) { +std::string GetStructureArrayTestSpirv() { + // The SPIR-V for the following high-level shader: + // Flattening structures and arrays should result in the following binding + // numbers. Only the ones that are actually used in the shader should be in + // the final SPIR-V. + // + // globalS[0][0].t[0] 0 (used) + // globalS[0][0].t[1] 1 + // globalS[0][0].s[0] 2 (used) + // globalS[0][0].s[1] 3 + // globalS[0][1].t[0] 4 + // globalS[0][1].t[1] 5 + // globalS[0][1].s[0] 6 + // globalS[0][1].s[1] 7 + // globalS[1][0].t[0] 8 + // globalS[1][0].t[1] 9 + // globalS[1][0].s[0] 10 + // globalS[1][0].s[1] 11 + // globalS[1][1].t[0] 12 + // globalS[1][1].t[1] 13 (used) + // globalS[1][1].s[0] 14 + // globalS[1][1].s[1] 15 (used) + + /* + struct S { + Texture2D t[2]; + SamplerState s[2]; + }; + + S globalS[2][2]; + + float4 main() : SV_Target { + return globalS[0][0].t[0].Sample(globalS[0][0].s[0], float2(0,0)) + + globalS[1][1].t[1].Sample(globalS[1][1].s[1], float2(0,0)); + } + */ + + return R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %out_var_SV_Target + OpExecutionMode %main OriginUpperLeft + OpName %S "S" + OpMemberName %S 0 "t" + OpMemberName %S 1 "s" + OpName %type_2d_image "type.2d.image" + OpName %type_sampler "type.sampler" + OpName %globalS "globalS" + OpName %out_var_SV_Target "out.var.SV_Target" + OpName %main "main" + OpName %src_main "src.main" + OpName %bb_entry "bb.entry" + OpName %type_sampled_image "type.sampled.image" + OpDecorate %out_var_SV_Target Location 0 + OpDecorate %globalS DescriptorSet 0 + OpDecorate %globalS Binding 0 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %float = OpTypeFloat 32 + %float_0 = OpConstant %float 0 + %v2float = OpTypeVector %float 2 + %10 = OpConstantComposite %v2float %float_0 %float_0 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%type_2d_image = OpTypeImage %float 2D 2 0 0 1 Unknown +%_arr_type_2d_image_uint_2 = OpTypeArray %type_2d_image %uint_2 +%type_sampler = OpTypeSampler +%_arr_type_sampler_uint_2 = OpTypeArray %type_sampler %uint_2 + %S = OpTypeStruct %_arr_type_2d_image_uint_2 %_arr_type_sampler_uint_2 +%_arr_S_uint_2 = OpTypeArray %S %uint_2 +%_arr__arr_S_uint_2_uint_2 = OpTypeArray %_arr_S_uint_2 %uint_2 +%_ptr_UniformConstant__arr__arr_S_uint_2_uint_2 = OpTypePointer UniformConstant %_arr__arr_S_uint_2_uint_2 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %void = OpTypeVoid + %24 = OpTypeFunction %void + %28 = OpTypeFunction %v4float +%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image +%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler +%type_sampled_image = OpTypeSampledImage %type_2d_image + %globalS = OpVariable %_ptr_UniformConstant__arr__arr_S_uint_2_uint_2 UniformConstant +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output + %main = OpFunction %void None %24 + %25 = OpLabel + %26 = OpFunctionCall %v4float %src_main + OpStore %out_var_SV_Target %26 + OpReturn + OpFunctionEnd + %src_main = OpFunction %v4float None %28 + %bb_entry = OpLabel + %31 = OpAccessChain %_ptr_UniformConstant_type_2d_image %globalS %int_0 %int_0 %int_0 %int_0 + %32 = OpLoad %type_2d_image %31 + %34 = OpAccessChain %_ptr_UniformConstant_type_sampler %globalS %int_0 %int_0 %int_1 %int_0 + %35 = OpLoad %type_sampler %34 + %37 = OpSampledImage %type_sampled_image %32 %35 + %38 = OpImageSampleImplicitLod %v4float %37 %10 None + %39 = OpAccessChain %_ptr_UniformConstant_type_2d_image %globalS %int_1 %int_1 %int_0 %int_1 + %40 = OpLoad %type_2d_image %39 + %41 = OpAccessChain %_ptr_UniformConstant_type_sampler %globalS %int_1 %int_1 %int_1 %int_1 + %42 = OpLoad %type_sampler %41 + %43 = OpSampledImage %type_sampled_image %40 %42 + %44 = OpImageSampleImplicitLod %v4float %43 %10 None + %45 = OpFAdd %v4float %38 %44 + OpReturnValue %45 + OpFunctionEnd + )"; +} + +TEST_F(DescriptorScalarReplacementTest, ExpandArrayOfTextures) { const std::string text = R"( ; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0 ; CHECK: OpDecorate [[var1]] Binding 0 @@ -94,7 +203,7 @@ TEST_F(DescriptorScalarReplacementTest, ExpandTexture) { SinglePassRunAndMatch(text, true); } -TEST_F(DescriptorScalarReplacementTest, ExpandSampler) { +TEST_F(DescriptorScalarReplacementTest, ExpandArrayOfSamplers) { const std::string text = R"( ; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0 ; CHECK: OpDecorate [[var1]] Binding 1 @@ -145,7 +254,7 @@ TEST_F(DescriptorScalarReplacementTest, ExpandSampler) { SinglePassRunAndMatch(text, true); } -TEST_F(DescriptorScalarReplacementTest, ExpandSSBO) { +TEST_F(DescriptorScalarReplacementTest, ExpandArrayOfSSBOs) { // Tests the expansion of an SSBO. Also check that an access chain with more // than 1 index is correctly handled. const std::string text = R"( @@ -265,6 +374,177 @@ TEST_F(DescriptorScalarReplacementTest, NameNewVariables) { SinglePassRunAndMatch(text, true); } + +TEST_F(DescriptorScalarReplacementTest, DontExpandCBuffers) { + // Checks that constant buffers are not expanded. + // Constant buffers are represented as global structures, but they should not + // be replaced with new variables for their elements. + /* + cbuffer MyCbuffer : register(b1) { + float2 a; + float2 b; + }; + float main() : A { + return a.x + b.y; + } + */ + const std::string text = R"( +; CHECK: OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_0 %int_0 +; CHECK: OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_1 %int_1 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %out_var_A + OpSource HLSL 600 + OpName %type_MyCbuffer "type.MyCbuffer" + OpMemberName %type_MyCbuffer 0 "a" + OpMemberName %type_MyCbuffer 1 "b" + OpName %MyCbuffer "MyCbuffer" + OpName %out_var_A "out.var.A" + OpName %main "main" + OpDecorate %out_var_A Location 0 + OpDecorate %MyCbuffer DescriptorSet 0 + OpDecorate %MyCbuffer Binding 1 + OpMemberDecorate %type_MyCbuffer 0 Offset 0 + OpMemberDecorate %type_MyCbuffer 1 Offset 8 + OpDecorate %type_MyCbuffer Block + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 +%type_MyCbuffer = OpTypeStruct %v2float %v2float +%_ptr_Uniform_type_MyCbuffer = OpTypePointer Uniform %type_MyCbuffer +%_ptr_Output_float = OpTypePointer Output %float + %void = OpTypeVoid + %13 = OpTypeFunction %void +%_ptr_Uniform_float = OpTypePointer Uniform %float + %MyCbuffer = OpVariable %_ptr_Uniform_type_MyCbuffer Uniform + %out_var_A = OpVariable %_ptr_Output_float Output + %main = OpFunction %void None %13 + %15 = OpLabel + %16 = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_0 %int_0 + %17 = OpLoad %float %16 + %18 = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_1 %int_1 + %19 = OpLoad %float %18 + %20 = OpFAdd %float %17 %19 + OpStore %out_var_A %20 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DescriptorScalarReplacementTest, DontExpandStructuredBuffers) { + // Checks that structured buffers are not expanded. + // Structured buffers are represented as global structures, that have one + // member which is a runtime array. + /* + struct S { + float2 a; + float2 b; + }; + RWStructuredBuffer sb; + float main() : A { + return sb[0].a.x + sb[0].b.x; + } + */ + const std::string text = R"( +; CHECK: OpAccessChain %_ptr_Uniform_float %sb %int_0 %uint_0 %int_0 %int_0 +; CHECK: OpAccessChain %_ptr_Uniform_float %sb %int_0 %uint_0 %int_1 %int_0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %out_var_A + OpName %type_RWStructuredBuffer_S "type.RWStructuredBuffer.S" + OpName %S "S" + OpMemberName %S 0 "a" + OpMemberName %S 1 "b" + OpName %sb "sb" + OpName %out_var_A "out.var.A" + OpName %main "main" + OpDecorate %out_var_A Location 0 + OpDecorate %sb DescriptorSet 0 + OpDecorate %sb Binding 0 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %_runtimearr_S ArrayStride 16 + OpMemberDecorate %type_RWStructuredBuffer_S 0 Offset 0 + OpDecorate %type_RWStructuredBuffer_S BufferBlock + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %int_1 = OpConstant %int 1 + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %S = OpTypeStruct %v2float %v2float +%_runtimearr_S = OpTypeRuntimeArray %S +%type_RWStructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_ptr_Uniform_type_RWStructuredBuffer_S = OpTypePointer Uniform %type_RWStructuredBuffer_S +%_ptr_Output_float = OpTypePointer Output %float + %void = OpTypeVoid + %17 = OpTypeFunction %void +%_ptr_Uniform_float = OpTypePointer Uniform %float + %sb = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform + %out_var_A = OpVariable %_ptr_Output_float Output + %main = OpFunction %void None %17 + %19 = OpLabel + %20 = OpAccessChain %_ptr_Uniform_float %sb %int_0 %uint_0 %int_0 %int_0 + %21 = OpLoad %float %20 + %22 = OpAccessChain %_ptr_Uniform_float %sb %int_0 %uint_0 %int_1 %int_0 + %23 = OpLoad %float %22 + %24 = OpFAdd %float %21 %23 + OpStore %out_var_A %24 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DescriptorScalarReplacementTest, StructureArrayNames) { + // Checks that names are properly generated for multi-dimension arrays and + // structure members. + const std::string checks = R"( +; CHECK: OpName %globalS_0__0__t_0_ "globalS[0][0].t[0]" +; CHECK: OpName %globalS_0__0__s_0_ "globalS[0][0].s[0]" +; CHECK: OpName %globalS_1__1__t_1_ "globalS[1][1].t[1]" +; CHECK: OpName %globalS_1__1__s_1_ "globalS[1][1].s[1]" + )"; + + const std::string text = checks + GetStructureArrayTestSpirv(); + SinglePassRunAndMatch(text, true); +} + +TEST_F(DescriptorScalarReplacementTest, StructureArrayBindings) { + // Checks that flattening structures and arrays results in correct binding + // numbers. + const std::string checks = R"( +; CHECK: OpDecorate %globalS_0__0__t_0_ Binding 0 +; CHECK: OpDecorate %globalS_0__0__s_0_ Binding 2 +; CHECK: OpDecorate %globalS_1__1__t_1_ Binding 13 +; CHECK: OpDecorate %globalS_1__1__s_1_ Binding 15 + )"; + + const std::string text = checks + GetStructureArrayTestSpirv(); + SinglePassRunAndMatch(text, true); +} + +TEST_F(DescriptorScalarReplacementTest, StructureArrayReplacements) { + // Checks that all access chains indexing into structures and/or arrays are + // replaced with direct access to replacement variables. + const std::string checks = R"( +; CHECK-NOT: OpAccessChain +; CHECK: OpLoad %type_2d_image %globalS_0__0__t_0_ +; CHECK: OpLoad %type_sampler %globalS_0__0__s_0_ +; CHECK: OpLoad %type_2d_image %globalS_1__1__t_1_ +; CHECK: OpLoad %type_sampler %globalS_1__1__s_1_ + )"; + + const std::string text = checks + GetStructureArrayTestSpirv(); + SinglePassRunAndMatch(text, true); +} + } // namespace } // namespace opt } // namespace spvtools