From 488559ff4c2e7dcb7cd256bffaa649a8dbe45cef Mon Sep 17 00:00:00 2001 From: Bill Hollings Date: Fri, 14 Jun 2024 13:32:26 -0400 Subject: [PATCH 1/2] MSL: Support descriptor sets with recursive content when using argument buffers. When using argument buffers, handle descriptor set entry points with recursive content, similar to discrete entry points with recursive content. - For descriptor sets entry points with recursive content, add descriptor set to recursive_inputs, and create a local var for it. - For recursive entry points that are contained in a descriptor set argument buffer, don't add entry point to recursive_inputs, or create a local var for that content entry point. - Add test shader. --- ...eplace-recursive-inputs.msl3.argument.comp | 33 +++++++++++++++++ ...eplace-recursive-inputs.msl3.argument.comp | 37 +++++++++++++++++++ ...eplace-recursive-inputs.msl3.argument.comp | 21 +++++++++++ spirv_msl.cpp | 30 ++++++++++++--- 4 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 reference/opt/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp create mode 100644 reference/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp create mode 100644 shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp diff --git a/reference/opt/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp b/reference/opt/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp new file mode 100644 index 00000000..90d105bd --- /dev/null +++ b/reference/opt/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp @@ -0,0 +1,33 @@ +#include +#include + +using namespace metal; + +struct recurs; + +struct recurs +{ + int m1; + device recurs* m2; +}; + +struct recurs_1 +{ + int m1; + device recurs_1* m2; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); + +struct spvDescriptorSetBuffer0 +{ + device recurs* nums [[id(0)]]; + texture2d tex [[id(1)]]; +}; + +kernel void main0(constant void* spvDescriptorSet0_vp [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +{ + constant auto& spvDescriptorSet0 = *(constant spvDescriptorSetBuffer0*)spvDescriptorSet0_vp; + spvDescriptorSet0.tex.write(uint4(uint(((*spvDescriptorSet0.nums).m1 + (*spvDescriptorSet0.nums).m2->m1) + (*spvDescriptorSet0.nums).m2->m2->m1), 0u, 0u, 1u), uint2(int2(gl_GlobalInvocationID.xy))); +} + diff --git a/reference/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp b/reference/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp new file mode 100644 index 00000000..4dfda270 --- /dev/null +++ b/reference/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp @@ -0,0 +1,37 @@ +#include +#include + +using namespace metal; + +struct recurs; + +struct recurs +{ + int m1; + device recurs* m2; +}; + +struct recurs_1 +{ + int m1; + device recurs_1* m2; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); + +struct spvDescriptorSetBuffer0 +{ + device recurs* nums [[id(0)]]; + texture2d tex [[id(1)]]; +}; + +kernel void main0(constant void* spvDescriptorSet0_vp [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +{ + constant auto& spvDescriptorSet0 = *(constant spvDescriptorSetBuffer0*)spvDescriptorSet0_vp; + int rslt = 0; + rslt += (*spvDescriptorSet0.nums).m1; + rslt += (*spvDescriptorSet0.nums).m2->m1; + rslt += (*spvDescriptorSet0.nums).m2->m2->m1; + spvDescriptorSet0.tex.write(uint4(uint(rslt), 0u, 0u, 1u), uint2(int2(gl_GlobalInvocationID.xy))); +} + diff --git a/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp b/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp new file mode 100644 index 00000000..ce776525 --- /dev/null +++ b/shaders-msl/comp/metal3_1_regression_patch.replace-recursive-inputs.msl3.argument.comp @@ -0,0 +1,21 @@ +#version 450 +#extension GL_EXT_buffer_reference2 : require +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(buffer_reference) buffer recurs; +layout(buffer_reference, buffer_reference_align = 16, set = 0, binding = 1, std140) buffer recurs +{ + int m1; + recurs m2; +} nums; + +layout(set = 0, binding = 0, r32ui) uniform writeonly uimage2D tex; + +void main() +{ + int rslt = 0; + rslt += nums.m1; + rslt += nums.m2.m1; + rslt += nums.m2.m2.m1; + imageStore(tex, ivec2(gl_GlobalInvocationID.xy), uvec4(rslt, 0u, 0u, 1u)); +} diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 0682caf6..5eca5ada 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -13602,7 +13602,13 @@ string CompilerMSL::entry_point_args_argument_buffer(bool append_comma) claimed_bindings.set(buffer_binding); - ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id); + ep_args += get_argument_address_space(var) + " "; + + if (recursive_inputs.count(type.self)) + ep_args += string("void* ") + to_restrict(id, true) + to_name(id) + "_vp"; + else + ep_args += type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id); + ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]"; next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1); @@ -14053,7 +14059,8 @@ void CompilerMSL::fix_up_shader_inputs_outputs() } } - if (msl_options.replace_recursive_inputs && type_contains_recursion(type) && + if ( !msl_options.argument_buffers && + msl_options.replace_recursive_inputs && type_contains_recursion(type) && (var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant || var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)) { @@ -18340,7 +18347,8 @@ void CompilerMSL::analyze_argument_buffers() else buffer_type.storage = StorageClassUniform; - set_name(type_id, join("spvDescriptorSetBuffer", desc_set)); + auto buffer_type_name = join("spvDescriptorSetBuffer", desc_set); + set_name(type_id, buffer_type_name); auto &ptr_type = set(ptr_type_id, OpTypePointer); ptr_type = buffer_type; @@ -18350,8 +18358,9 @@ void CompilerMSL::analyze_argument_buffers() ptr_type.parent_type = type_id; uint32_t buffer_variable_id = next_id; - set(buffer_variable_id, ptr_type_id, StorageClassUniform); - set_name(buffer_variable_id, join("spvDescriptorSet", desc_set)); + auto &buffer_var = set(buffer_variable_id, ptr_type_id, StorageClassUniform); + auto buffer_name = join("spvDescriptorSet", desc_set); + set_name(buffer_variable_id, buffer_name); // Ids must be emitted in ID order. stable_sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool { @@ -18559,6 +18568,17 @@ void CompilerMSL::analyze_argument_buffers() set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationOverlappingBinding); member_index++; } + + if (msl_options.replace_recursive_inputs && type_contains_recursion(buffer_type)) + { + recursive_inputs.insert(type_id); + auto &entry_func = this->get(ir.default_entry_point); + auto addr_space = get_argument_address_space(buffer_var); + entry_func.fixup_hooks_in.push_back([this, addr_space, buffer_name, buffer_type_name]() { + statement(addr_space, " auto& ", buffer_name, " = *(", addr_space, " ", buffer_type_name, "*)", buffer_name, "_vp;"); + }); + } + } } From d7ad3d72578c463225e998ba79b71e743ec2c8c4 Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Mon, 17 Jun 2024 12:44:22 +0200 Subject: [PATCH 2/2] Apply suggestions from code review Fix nits --- spirv_msl.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 5eca5ada..2cd7666b 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -14059,8 +14059,8 @@ void CompilerMSL::fix_up_shader_inputs_outputs() } } - if ( !msl_options.argument_buffers && - msl_options.replace_recursive_inputs && type_contains_recursion(type) && + if (!msl_options.argument_buffers && + msl_options.replace_recursive_inputs && type_contains_recursion(type) && (var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant || var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)) { @@ -18578,7 +18578,6 @@ void CompilerMSL::analyze_argument_buffers() statement(addr_space, " auto& ", buffer_name, " = *(", addr_space, " ", buffer_type_name, "*)", buffer_name, "_vp;"); }); } - } }