MSL: Support descriptor sets with recursive content when using argument buffers.

When using argument buffers, handle descriptor set entry points with
recursive content, similar to discrete entry points with recursive content.

- For descriptor sets entry points with recursive content, add
  descriptor set to recursive_inputs, and create a local var for it.

- For recursive entry points that are contained in a descriptor set
  argument buffer, don't add entry point to recursive_inputs, or create
  a local var for that content entry point.

- Add test shader.
This commit is contained in:
Bill Hollings 2024-06-14 13:32:26 -04:00
parent 2d990d355a
commit 488559ff4c
4 changed files with 116 additions and 5 deletions

View File

@ -0,0 +1,33 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct recurs;
struct recurs
{
int m1;
device recurs* m2;
};
struct recurs_1
{
int m1;
device recurs_1* m2;
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
struct spvDescriptorSetBuffer0
{
device recurs* nums [[id(0)]];
texture2d<uint, access::write> tex [[id(1)]];
};
kernel void main0(constant void* spvDescriptorSet0_vp [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
constant auto& spvDescriptorSet0 = *(constant spvDescriptorSetBuffer0*)spvDescriptorSet0_vp;
spvDescriptorSet0.tex.write(uint4(uint(((*spvDescriptorSet0.nums).m1 + (*spvDescriptorSet0.nums).m2->m1) + (*spvDescriptorSet0.nums).m2->m2->m1), 0u, 0u, 1u), uint2(int2(gl_GlobalInvocationID.xy)));
}

View File

@ -0,0 +1,37 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct recurs;
struct recurs
{
int m1;
device recurs* m2;
};
struct recurs_1
{
int m1;
device recurs_1* m2;
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
struct spvDescriptorSetBuffer0
{
device recurs* nums [[id(0)]];
texture2d<uint, access::write> tex [[id(1)]];
};
kernel void main0(constant void* spvDescriptorSet0_vp [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
constant auto& spvDescriptorSet0 = *(constant spvDescriptorSetBuffer0*)spvDescriptorSet0_vp;
int rslt = 0;
rslt += (*spvDescriptorSet0.nums).m1;
rslt += (*spvDescriptorSet0.nums).m2->m1;
rslt += (*spvDescriptorSet0.nums).m2->m2->m1;
spvDescriptorSet0.tex.write(uint4(uint(rslt), 0u, 0u, 1u), uint2(int2(gl_GlobalInvocationID.xy)));
}

View File

@ -0,0 +1,21 @@
#version 450
#extension GL_EXT_buffer_reference2 : require
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(buffer_reference) buffer recurs;
layout(buffer_reference, buffer_reference_align = 16, set = 0, binding = 1, std140) buffer recurs
{
int m1;
recurs m2;
} nums;
layout(set = 0, binding = 0, r32ui) uniform writeonly uimage2D tex;
void main()
{
int rslt = 0;
rslt += nums.m1;
rslt += nums.m2.m1;
rslt += nums.m2.m2.m1;
imageStore(tex, ivec2(gl_GlobalInvocationID.xy), uvec4(rslt, 0u, 0u, 1u));
}

View File

@ -13602,7 +13602,13 @@ string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
claimed_bindings.set(buffer_binding);
ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);
ep_args += get_argument_address_space(var) + " ";
if (recursive_inputs.count(type.self))
ep_args += string("void* ") + to_restrict(id, true) + to_name(id) + "_vp";
else
ep_args += type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);
ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
@ -14053,7 +14059,8 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
}
}
if (msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
if ( !msl_options.argument_buffers &&
msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
(var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer))
{
@ -18340,7 +18347,8 @@ void CompilerMSL::analyze_argument_buffers()
else
buffer_type.storage = StorageClassUniform;
set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
auto buffer_type_name = join("spvDescriptorSetBuffer", desc_set);
set_name(type_id, buffer_type_name);
auto &ptr_type = set<SPIRType>(ptr_type_id, OpTypePointer);
ptr_type = buffer_type;
@ -18350,8 +18358,9 @@ void CompilerMSL::analyze_argument_buffers()
ptr_type.parent_type = type_id;
uint32_t buffer_variable_id = next_id;
set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
auto &buffer_var = set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
auto buffer_name = join("spvDescriptorSet", desc_set);
set_name(buffer_variable_id, buffer_name);
// Ids must be emitted in ID order.
stable_sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
@ -18559,6 +18568,17 @@ void CompilerMSL::analyze_argument_buffers()
set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationOverlappingBinding);
member_index++;
}
if (msl_options.replace_recursive_inputs && type_contains_recursion(buffer_type))
{
recursive_inputs.insert(type_id);
auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
auto addr_space = get_argument_address_space(buffer_var);
entry_func.fixup_hooks_in.push_back([this, addr_space, buffer_name, buffer_type_name]() {
statement(addr_space, " auto& ", buffer_name, " = *(", addr_space, " ", buffer_type_name, "*)", buffer_name, "_vp;");
});
}
}
}