diff --git a/main.cpp b/main.cpp index 00db1be1..6340e0e8 100644 --- a/main.cpp +++ b/main.cpp @@ -516,6 +516,7 @@ struct CLIArguments bool msl_texture_buffer_native = false; bool msl_multiview = false; bool msl_view_index_from_device_index = false; + bool msl_dispatch_base = false; bool glsl_emit_push_constant_as_ubo = false; bool glsl_emit_ubo_as_plain_uniforms = false; bool emit_line_directives = false; @@ -596,6 +597,7 @@ static void print_help() "\t[--msl-discrete-descriptor-set ]\n" "\t[--msl-multiview]\n" "\t[--msl-view-index-from-device-index]\n" + "\t[--msl-dispatch-base]\n" "\t[--hlsl]\n" "\t[--reflect]\n" "\t[--shader-model]\n" @@ -756,6 +758,7 @@ static string compile_iteration(const CLIArguments &args, std::vector msl_opts.texture_buffer_native = args.msl_texture_buffer_native; msl_opts.multiview = args.msl_multiview; msl_opts.view_index_from_device_index = args.msl_view_index_from_device_index; + msl_opts.dispatch_base = args.msl_dispatch_base; msl_comp->set_msl_options(msl_opts); for (auto &v : args.msl_discrete_descriptor_sets) msl_comp->add_discrete_descriptor_set(v); @@ -1078,6 +1081,7 @@ static int main_inner(int argc, char *argv[]) cbs.add("--msl-multiview", [&args](CLIParser &) { args.msl_multiview = true; }); cbs.add("--msl-view-index-from-device-index", [&args](CLIParser &) { args.msl_view_index_from_device_index = true; }); + cbs.add("--msl-dispatch-base", [&args](CLIParser &) { args.msl_dispatch_base = true; }); cbs.add("--extension", [&args](CLIParser &parser) { args.extensions.push_back(parser.next_string()); }); cbs.add("--rename-entry-point", [&args](CLIParser &parser) { auto old_name = parser.next_string(); diff --git a/reference/opt/shaders-msl/comp/basic.dispatchbase.comp b/reference/opt/shaders-msl/comp/basic.dispatchbase.comp new file mode 100644 index 00000000..ebbc144c --- /dev/null +++ b/reference/opt/shaders-msl/comp/basic.dispatchbase.comp @@ -0,0 +1,38 @@ +#pragma clang diagnostic ignored "-Wunused-variable" + +#include +#include +#include + +using namespace metal; + +struct SSBO +{ + float4 in_data[1]; +}; + +struct SSBO2 +{ + float4 out_data[1]; +}; + +struct SSBO3 +{ + uint counter; +}; + +constant uint _59_tmp [[function_constant(10)]]; +constant uint _59 = is_function_constant_defined(_59_tmp) ? _59_tmp : 1u; +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(_59, 1u, 1u); + +kernel void main0(const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 spvDispatchBase [[grid_origin]]) +{ + gl_GlobalInvocationID += spvDispatchBase * gl_WorkGroupSize; + float4 _33 = _27.in_data[gl_GlobalInvocationID.x]; + if (dot(_33, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875) + { + uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed); + _49.out_data[_56] = _33; + } +} + diff --git a/reference/opt/shaders-msl/comp/basic.dispatchbase.msl11.comp b/reference/opt/shaders-msl/comp/basic.dispatchbase.msl11.comp new file mode 100644 index 00000000..8c3d2576 --- /dev/null +++ b/reference/opt/shaders-msl/comp/basic.dispatchbase.msl11.comp @@ -0,0 +1,34 @@ +#pragma clang diagnostic ignored "-Wunused-variable" + +#include +#include +#include + +using namespace metal; + +struct SSBO +{ + float4 in_data[1]; +}; + +struct SSBO2 +{ + float4 out_data[1]; +}; + +struct SSBO3 +{ + uint counter; +}; + +kernel void main0(constant uint3& spvDispatchBase [[buffer(29)]], const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +{ + gl_GlobalInvocationID += spvDispatchBase * uint3(1, 1, 1); + float4 _33 = _27.in_data[gl_GlobalInvocationID.x]; + if (dot(_33, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875) + { + uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed); + _49.out_data[_56] = _33; + } +} + diff --git a/reference/shaders-msl/comp/basic.dispatchbase.comp b/reference/shaders-msl/comp/basic.dispatchbase.comp new file mode 100644 index 00000000..92d517cf --- /dev/null +++ b/reference/shaders-msl/comp/basic.dispatchbase.comp @@ -0,0 +1,41 @@ +#pragma clang diagnostic ignored "-Wunused-variable" + +#include +#include +#include + +using namespace metal; + +struct SSBO +{ + float4 in_data[1]; +}; + +struct SSBO2 +{ + float4 out_data[1]; +}; + +struct SSBO3 +{ + uint counter; +}; + +constant uint _59_tmp [[function_constant(10)]]; +constant uint _59 = is_function_constant_defined(_59_tmp) ? _59_tmp : 1u; +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(_59, 1u, 1u); + +kernel void main0(const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint3 spvDispatchBase [[grid_origin]]) +{ + gl_GlobalInvocationID += spvDispatchBase * gl_WorkGroupSize; + gl_WorkGroupID += spvDispatchBase; + uint ident = gl_GlobalInvocationID.x; + uint workgroup = gl_WorkGroupID.x; + float4 idata = _27.in_data[ident]; + if (dot(idata, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875) + { + uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed); + _49.out_data[_56] = idata; + } +} + diff --git a/reference/shaders-msl/comp/basic.dispatchbase.msl11.comp b/reference/shaders-msl/comp/basic.dispatchbase.msl11.comp new file mode 100644 index 00000000..084518a5 --- /dev/null +++ b/reference/shaders-msl/comp/basic.dispatchbase.msl11.comp @@ -0,0 +1,37 @@ +#pragma clang diagnostic ignored "-Wunused-variable" + +#include +#include +#include + +using namespace metal; + +struct SSBO +{ + float4 in_data[1]; +}; + +struct SSBO2 +{ + float4 out_data[1]; +}; + +struct SSBO3 +{ + uint counter; +}; + +kernel void main0(constant uint3& spvDispatchBase [[buffer(29)]], const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) +{ + gl_GlobalInvocationID += spvDispatchBase * uint3(1, 1, 1); + gl_WorkGroupID += spvDispatchBase; + uint ident = gl_GlobalInvocationID.x; + uint workgroup = gl_WorkGroupID.x; + float4 idata = _27.in_data[ident]; + if (dot(idata, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875) + { + uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed); + _49.out_data[_56] = idata; + } +} + diff --git a/shaders-msl/comp/basic.dispatchbase.comp b/shaders-msl/comp/basic.dispatchbase.comp new file mode 100644 index 00000000..2c873468 --- /dev/null +++ b/shaders-msl/comp/basic.dispatchbase.comp @@ -0,0 +1,29 @@ +#version 310 es +layout(local_size_x_id = 10) in; + +layout(std430, binding = 0) readonly buffer SSBO +{ + vec4 in_data[]; +}; + +layout(std430, binding = 1) writeonly buffer SSBO2 +{ + vec4 out_data[]; +}; + +layout(std430, binding = 2) buffer SSBO3 +{ + uint counter; +}; + +void main() +{ + uint ident = gl_GlobalInvocationID.x; + uint workgroup = gl_WorkGroupID.x; + vec4 idata = in_data[ident]; + if (dot(idata, vec4(1.0, 5.0, 6.0, 2.0)) > 8.2) + { + out_data[atomicAdd(counter, 1u)] = idata; + } +} + diff --git a/shaders-msl/comp/basic.dispatchbase.msl11.comp b/shaders-msl/comp/basic.dispatchbase.msl11.comp new file mode 100644 index 00000000..91453332 --- /dev/null +++ b/shaders-msl/comp/basic.dispatchbase.msl11.comp @@ -0,0 +1,29 @@ +#version 310 es +layout(local_size_x = 1) in; + +layout(std430, binding = 0) readonly buffer SSBO +{ + vec4 in_data[]; +}; + +layout(std430, binding = 1) writeonly buffer SSBO2 +{ + vec4 out_data[]; +}; + +layout(std430, binding = 2) buffer SSBO3 +{ + uint counter; +}; + +void main() +{ + uint ident = gl_GlobalInvocationID.x; + uint workgroup = gl_WorkGroupID.x; + vec4 idata = in_data[ident]; + if (dot(idata, vec4(1.0, 5.0, 6.0, 2.0)) > 8.2) + { + out_data[atomicAdd(counter, 1u)] = idata; + } +} + diff --git a/spirv_common.hpp b/spirv_common.hpp index bc626436..3db55cf5 100644 --- a/spirv_common.hpp +++ b/spirv_common.hpp @@ -1433,6 +1433,10 @@ enum ExtendedDecorations // Marks a buffer block for using explicit offsets (GLSL/HLSL). SPIRVCrossDecorationExplicitOffset, + // Apply to a variable in the Input storage class; marks it as holding the base group passed to vkCmdDispatchBase(). + // In MSL, this is used to adjust the WorkgroupId and GlobalInvocationId variables. + SPIRVCrossDecorationBuiltInDispatchBase, + SPIRVCrossDecorationCount }; diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 9f33034f..05d0421c 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -107,8 +107,11 @@ void CompilerMSL::build_implicit_builtins() active_input_builtins.get(BuiltInSubgroupGtMask)); bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index && (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex)); + bool need_dispatch_base = + msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute && + (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId)); if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params || - need_multiview || needs_subgroup_invocation_id) + need_multiview || need_dispatch_base || needs_subgroup_invocation_id) { bool has_frag_coord = false; bool has_sample_id = false; @@ -121,6 +124,7 @@ void CompilerMSL::build_implicit_builtins() bool has_subgroup_invocation_id = false; bool has_subgroup_size = false; bool has_view_idx = false; + uint32_t workgroup_id_type = 0; ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { if (var.storage != StorageClassInput || !ir.meta[var.self].decoration.builtin) @@ -208,6 +212,13 @@ void CompilerMSL::build_implicit_builtins() has_view_idx = true; } } + + // The base workgroup needs to have the same type and vector size + // as the workgroup or invocation ID, so keep track of the type that + // was used. + if (need_dispatch_base && workgroup_id_type == 0 && + (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId)) + workgroup_id_type = var.basetype; }); if (!has_frag_coord && need_subpass_input) @@ -457,6 +468,42 @@ void CompilerMSL::build_implicit_builtins() builtin_subgroup_size_id = var_id; mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id); } + + if (need_dispatch_base) + { + uint32_t var_id; + if (msl_options.supports_msl_version(1, 2)) + { + // If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin + // to convey this information and save a buffer slot. + uint32_t offset = ir.increase_bound_by(1); + var_id = offset; + + set(var_id, workgroup_id_type, StorageClassInput); + set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase); + get_entry_point().interface_variables.push_back(var_id); + } + else + { + // Otherwise, we need to fall back to a good ol' fashioned buffer. + uint32_t offset = ir.increase_bound_by(2); + var_id = offset; + uint32_t type_id = offset + 1; + + SPIRType var_type = get(workgroup_id_type); + var_type.storage = StorageClassUniform; + set(type_id, var_type); + + set(var_id, type_id, StorageClassUniform); + // This should never match anything. + set_decoration(var_id, DecorationDescriptorSet, ~(5u)); + set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index); + set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, + msl_options.indirect_params_buffer_index); + } + set_name(var_id, "spvDispatchBase"); + builtin_dispatch_base_id = var_id; + } } if (needs_swizzle_buffer_def) @@ -802,6 +849,8 @@ string CompilerMSL::compile() active_interface_variables.insert(view_mask_buffer_id); if (builtin_layer_id) active_interface_variables.insert(builtin_layer_id); + if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2)) + active_interface_variables.insert(builtin_dispatch_base_id); // Create structs to hold input, output and uniform variables. // Do output first to ensure out. is declared at top of entry function. @@ -6748,6 +6797,19 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args) ep_args += "]]"; } } + + if (var.storage == StorageClassInput && + has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase)) + { + // This is a special implicit builtin, not corresponding to any SPIR-V builtin, + // which holds the base that was passed to vkCmdDispatchBase(). If it's present, + // assume we emitted it for a good reason. + assert(msl_options.supports_msl_version(1, 2)); + if (!ep_args.empty()) + ep_args += ", "; + + ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]"; + } }); // Correct the types of all encountered active builtins. We couldn't do this before @@ -7023,7 +7085,11 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) default: if (!ep_args.empty()) ep_args += ", "; - ep_args += type_to_glsl(type, var_id) + " " + r.name; + if (!type.pointer) + ep_args += get_type_address_space(get(var.basetype), var_id) + " " + + type_to_glsl(type, var_id) + "& " + r.name; + else + ep_args += type_to_glsl(type, var_id) + " " + r.name; ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]"; break; } @@ -7343,6 +7409,35 @@ void CompilerMSL::fix_up_shader_inputs_outputs() msl_options.device_index, ";"); }); break; + case BuiltInWorkgroupId: + if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId)) + break; + + // The vkCmdDispatchBase() command lets the client set the base value + // of WorkgroupId. Metal has no direct equivalent; we must make this + // adjustment ourselves. + entry_func.fixup_hooks_in.push_back([=]() { + statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";"); + }); + break; + case BuiltInGlobalInvocationId: + if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId)) + break; + + // GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize. + // This needs to be adjusted too. + entry_func.fixup_hooks_in.push_back([=]() { + auto &execution = get_entry_point(); + uint32_t workgroup_size_id = execution.workgroup_size.constant; + if (workgroup_size_id) + statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), + " * ", to_expression(workgroup_size_id), ";"); + else + statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), + " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ", + execution.workgroup_size.z, ");"); + }); + break; default: break; } diff --git a/spirv_msl.hpp b/spirv_msl.hpp index abd481b3..21473e2d 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -198,6 +198,7 @@ public: bool tess_domain_origin_lower_left = false; bool multiview = false; bool view_index_from_device_index = false; + bool dispatch_base = false; // Enable use of MSL 2.0 indirect argument buffers. // MSL 2.0 must also be enabled. @@ -225,7 +226,7 @@ public: msl_version = make_msl_version(major, minor, patch); } - bool supports_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) + bool supports_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) const { return msl_version >= make_msl_version(major, minor, patch); } @@ -276,6 +277,13 @@ public: return msl_options.multiview && !msl_options.view_index_from_device_index; } + // Provide feedback to calling API to allow it to pass a buffer + // containing the dispatch base workgroup ID. + bool needs_dispatch_base_buffer() const + { + return msl_options.dispatch_base && !msl_options.supports_msl_version(1, 2); + } + // Provide feedback to calling API to allow it to pass an output // buffer if the shader needs it. bool needs_output_buffer() const @@ -563,6 +571,7 @@ protected: uint32_t builtin_primitive_id_id = 0; uint32_t builtin_subgroup_invocation_id_id = 0; uint32_t builtin_subgroup_size_id = 0; + uint32_t builtin_dispatch_base_id = 0; uint32_t swizzle_buffer_id = 0; uint32_t buffer_size_buffer_id = 0; uint32_t view_mask_buffer_id = 0; diff --git a/test_shaders.py b/test_shaders.py index b3f69253..73b72865 100755 --- a/test_shaders.py +++ b/test_shaders.py @@ -207,6 +207,8 @@ def cross_compile_msl(shader, spirv, opt, iterations, paths): msl_args.append('--msl-multiview') if '.viewfromdev.' in shader: msl_args.append('--msl-view-index-from-device-index') + if '.dispatchbase.' in shader: + msl_args.append('--msl-dispatch-base') subprocess.check_call(msl_args)