diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp index b68549a0c..5e950069d 100644 --- a/source/opt/desc_sroa.cpp +++ b/source/opt/desc_sroa.cpp @@ -63,16 +63,7 @@ bool DescriptorScalarReplacement::IsCandidate(Instruction* var) { // All structures with descriptor assignments must be replaced by variables, // one for each of their members - with the exceptions of buffers. - // Buffers are represented as structures, but we shouldn't replace a buffer - // with its elements. All buffers have offset decorations for members of their - // structure types. - bool has_offset_decoration = false; - context()->get_decoration_mgr()->ForEachDecoration( - var_type_inst->result_id(), SpvDecorationOffset, - [&has_offset_decoration](const Instruction&) { - has_offset_decoration = true; - }); - if (has_offset_decoration) { + if (IsTypeOfStructuredBuffer(var_type_inst)) { return false; } @@ -99,6 +90,23 @@ bool DescriptorScalarReplacement::IsCandidate(Instruction* var) { return true; } +bool DescriptorScalarReplacement::IsTypeOfStructuredBuffer( + const Instruction* type) const { + if (type->opcode() != SpvOpTypeStruct) { + return false; + } + + // All buffers have offset decorations for members of their structure types. + // This is how we distinguish it from a structure of descriptors. + bool has_offset_decoration = false; + context()->get_decoration_mgr()->ForEachDecoration( + type->result_id(), SpvDecorationOffset, + [&has_offset_decoration](const Instruction&) { + has_offset_decoration = true; + }); + return has_offset_decoration; +} + bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) { std::vector access_chain_work_list; std::vector load_work_list; @@ -368,7 +376,8 @@ uint32_t DescriptorScalarReplacement::GetNumBindingsUsedByType( // The number of bindings consumed by a structure is the sum of the bindings // used by its members. - if (type_inst->opcode() == SpvOpTypeStruct) { + if (type_inst->opcode() == SpvOpTypeStruct && + !IsTypeOfStructuredBuffer(type_inst)) { uint32_t sum = 0; for (uint32_t i = 0; i < type_inst->NumInOperands(); i++) sum += GetNumBindingsUsedByType(type_inst->GetSingleWordInOperand(i)); diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h index c3aa0ea2b..cd72fd301 100644 --- a/source/opt/desc_sroa.h +++ b/source/opt/desc_sroa.h @@ -93,6 +93,11 @@ class DescriptorScalarReplacement : public Pass { // bindings used by its members. uint32_t GetNumBindingsUsedByType(uint32_t type_id); + // Returns true if |type| is a type that could be used for a structured buffer + // as opposed to a type that would be used for a structure of resource + // descriptors. + bool IsTypeOfStructuredBuffer(const Instruction* type) const; + // A map from an OpVariable instruction to the set of variables that will be // used to replace it. The entry |replacement_variables_[var][i]| is the id of // a variable that will be used in the place of the the ith element of the diff --git a/test/opt/desc_sroa_test.cpp b/test/opt/desc_sroa_test.cpp index cdcc9a835..b35ad474d 100644 --- a/test/opt/desc_sroa_test.cpp +++ b/test/opt/desc_sroa_test.cpp @@ -729,6 +729,47 @@ TEST_F(DescriptorScalarReplacementTest, ResourceStructAsFunctionParam) { SinglePassRunAndMatch(checks + shader, true); } +TEST_F(DescriptorScalarReplacementTest, BindingForResourceArrayOfStructs) { + // Check that correct binding numbers are given to an array of descriptors + // to structs. + + const std::string shader = R"( +; CHECK: OpDecorate {{%\w+}} Binding 0 +; CHECK: OpDecorate {{%\w+}} Binding 1 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "psmain" + OpExecutionMode %2 OriginUpperLeft + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + OpMemberDecorate %_struct_4 0 Offset 0 + OpMemberDecorate %_struct_4 1 Offset 4 + OpDecorate %_struct_4 Block + %float = OpTypeFloat 32 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 + %_struct_4 = OpTypeStruct %float %int +%_arr__struct_4_uint_2 = OpTypeArray %_struct_4 %uint_2 +%_ptr_Uniform__arr__struct_4_uint_2 = OpTypePointer Uniform %_arr__struct_4_uint_2 + %void = OpTypeVoid + %25 = OpTypeFunction %void +%_ptr_Uniform_int = OpTypePointer Uniform %int + %5 = OpVariable %_ptr_Uniform__arr__struct_4_uint_2 Uniform + %2 = OpFunction %void None %25 + %29 = OpLabel + %40 = OpAccessChain %_ptr_Uniform_int %5 %int_0 %int_1 + %41 = OpAccessChain %_ptr_Uniform_int %5 %int_1 %int_1 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(shader, true); +} + } // namespace } // namespace opt } // namespace spvtools