MSL: Adjust BuiltInWorkgroupId for vkCmdDispatchBase().

This command allows the caller to set the base value of
`BuiltInWorkgroupId`, and thus of `BuiltInGlobalInvocationId`. Metal
provides no direct support for this... but it does provide a builtin,
`[[grid_origin]]`, normally used to pass the base values for the stage
input region, which we will now abuse to pass the dispatch base and
avoid burning a buffer binding.

`[[grid_origin]]`, as part of Metal's support for compute stage input,
requires MSL 1.2. For 1.0 and 1.1, we're forced to provide a buffer.

(Curiously, this builtin was undocumented until the MSL 2.2 release. Go
figure.)
This commit is contained in:
Chip Davis 2019-07-22 13:08:04 -05:00
parent 07bb1a53e0
commit fb5ee4cb5c
11 changed files with 325 additions and 3 deletions

View File

@ -516,6 +516,7 @@ struct CLIArguments
bool msl_texture_buffer_native = false;
bool msl_multiview = false;
bool msl_view_index_from_device_index = false;
bool msl_dispatch_base = false;
bool glsl_emit_push_constant_as_ubo = false;
bool glsl_emit_ubo_as_plain_uniforms = false;
bool emit_line_directives = false;
@ -596,6 +597,7 @@ static void print_help()
"\t[--msl-discrete-descriptor-set <index>]\n"
"\t[--msl-multiview]\n"
"\t[--msl-view-index-from-device-index]\n"
"\t[--msl-dispatch-base]\n"
"\t[--hlsl]\n"
"\t[--reflect]\n"
"\t[--shader-model]\n"
@ -756,6 +758,7 @@ static string compile_iteration(const CLIArguments &args, std::vector<uint32_t>
msl_opts.texture_buffer_native = args.msl_texture_buffer_native;
msl_opts.multiview = args.msl_multiview;
msl_opts.view_index_from_device_index = args.msl_view_index_from_device_index;
msl_opts.dispatch_base = args.msl_dispatch_base;
msl_comp->set_msl_options(msl_opts);
for (auto &v : args.msl_discrete_descriptor_sets)
msl_comp->add_discrete_descriptor_set(v);
@ -1078,6 +1081,7 @@ static int main_inner(int argc, char *argv[])
cbs.add("--msl-multiview", [&args](CLIParser &) { args.msl_multiview = true; });
cbs.add("--msl-view-index-from-device-index",
[&args](CLIParser &) { args.msl_view_index_from_device_index = true; });
cbs.add("--msl-dispatch-base", [&args](CLIParser &) { args.msl_dispatch_base = true; });
cbs.add("--extension", [&args](CLIParser &parser) { args.extensions.push_back(parser.next_string()); });
cbs.add("--rename-entry-point", [&args](CLIParser &parser) {
auto old_name = parser.next_string();

View File

@ -0,0 +1,38 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_atomic>
using namespace metal;
// Input buffer layout; main0 binds it read-only (const device) at buffer(0).
// NOTE(review): the array is declared with one element but main0 indexes it with a
// runtime value, so it presumably stands in for a runtime-sized array - confirm.
struct SSBO
{
float4 in_data[1];
};
// Output buffer layout; main0 binds it at buffer(1) and writes the selected values into it.
struct SSBO2
{
float4 out_data[1];
};
// Counter buffer layout; main0 binds it at buffer(2) and increments `counter` atomically
// to obtain the next free index in SSBO2::out_data.
struct SSBO3
{
uint counter;
};
// Workgroup X size: taken from function constant 10 when it is defined, otherwise 1.
constant uint _59_tmp [[function_constant(10)]];
constant uint _59 = is_function_constant_defined(_59_tmp) ? _59_tmp : 1u;
// Workgroup size vector (Y and Z fixed at 1); main0 multiplies the dispatch base by it.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(_59, 1u, 1u);
// Compute entry point. Each invocation reads one float4 from _27.in_data at its global X
// index and, if its dot product with (1, 5, 6, 2) exceeds the threshold, appends it to
// _49.out_data at an index obtained from the atomic counter in _52.
// In this variant spvDispatchBase is received through the [[grid_origin]] attribute,
// so no buffer binding is used for it.
kernel void main0(const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 spvDispatchBase [[grid_origin]])
{
// Rebase the global invocation ID before its first use: the dispatch base is scaled by
// the workgroup size and added on.
gl_GlobalInvocationID += spvDispatchBase * gl_WorkGroupSize;
float4 _33 = _27.in_data[gl_GlobalInvocationID.x];
if (dot(_33, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875)
{
// Relaxed fetch-add on the counter; the value it returns is used as the output index.
uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed);
_49.out_data[_56] = _33;
}
}

View File

@ -0,0 +1,34 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_atomic>
using namespace metal;
// Input buffer layout; main0 binds it read-only (const device) at buffer(0).
// NOTE(review): the array is declared with one element but main0 indexes it with a
// runtime value, so it presumably stands in for a runtime-sized array - confirm.
struct SSBO
{
float4 in_data[1];
};
// Output buffer layout; main0 binds it at buffer(1) and writes the selected values into it.
struct SSBO2
{
float4 out_data[1];
};
// Counter buffer layout; main0 binds it at buffer(2) and increments `counter` atomically
// to obtain the next free index in SSBO2::out_data.
struct SSBO3
{
uint counter;
};
// Compute entry point. Each invocation reads one float4 from _27.in_data at its global X
// index and, if its dot product with (1, 5, 6, 2) exceeds the threshold, appends it to
// _49.out_data at an index obtained from the atomic counter in _52.
// In this variant spvDispatchBase is read from a constant buffer bound at buffer(29)
// instead of coming from a kernel input attribute.
kernel void main0(constant uint3& spvDispatchBase [[buffer(29)]], const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
// Rebase the global invocation ID before its first use. The scale factor is a literal
// (1, 1, 1) here. NOTE(review): presumably the shader's fixed workgroup size - confirm.
gl_GlobalInvocationID += spvDispatchBase * uint3(1, 1, 1);
float4 _33 = _27.in_data[gl_GlobalInvocationID.x];
if (dot(_33, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875)
{
// Relaxed fetch-add on the counter; the value it returns is used as the output index.
uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed);
_49.out_data[_56] = _33;
}
}

View File

@ -0,0 +1,41 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_atomic>
using namespace metal;
// Input buffer layout; main0 binds it read-only (const device) at buffer(0).
// NOTE(review): the array is declared with one element but main0 indexes it with a
// runtime value, so it presumably stands in for a runtime-sized array - confirm.
struct SSBO
{
float4 in_data[1];
};
// Output buffer layout; main0 binds it at buffer(1) and writes the selected values into it.
struct SSBO2
{
float4 out_data[1];
};
// Counter buffer layout; main0 binds it at buffer(2) and increments `counter` atomically
// to obtain the next free index in SSBO2::out_data.
struct SSBO3
{
uint counter;
};
// Workgroup X size: taken from function constant 10 when it is defined, otherwise 1.
constant uint _59_tmp [[function_constant(10)]];
constant uint _59 = is_function_constant_defined(_59_tmp) ? _59_tmp : 1u;
// Workgroup size vector (Y and Z fixed at 1); main0 multiplies the dispatch base by it.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(_59, 1u, 1u);
// Compute entry point. Each invocation reads one float4 from _27.in_data at its global X
// index and, if its dot product with (1, 5, 6, 2) exceeds the threshold, appends it to
// _49.out_data at an index obtained from the atomic counter in _52.
// This variant takes both the global invocation ID and the workgroup ID, and receives
// spvDispatchBase through the [[grid_origin]] attribute.
kernel void main0(const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint3 spvDispatchBase [[grid_origin]])
{
// Rebase both builtins before their first use: the invocation ID moves by the dispatch
// base scaled by the workgroup size, the workgroup ID by the dispatch base itself.
gl_GlobalInvocationID += spvDispatchBase * gl_WorkGroupSize;
gl_WorkGroupID += spvDispatchBase;
uint ident = gl_GlobalInvocationID.x;
// `workgroup` is assigned but not read again in this function.
uint workgroup = gl_WorkGroupID.x;
float4 idata = _27.in_data[ident];
if (dot(idata, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875)
{
// Relaxed fetch-add on the counter; the value it returns is used as the output index.
uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed);
_49.out_data[_56] = idata;
}
}

View File

@ -0,0 +1,37 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_atomic>
using namespace metal;
// Input buffer layout; main0 binds it read-only (const device) at buffer(0).
// NOTE(review): the array is declared with one element but main0 indexes it with a
// runtime value, so it presumably stands in for a runtime-sized array - confirm.
struct SSBO
{
float4 in_data[1];
};
// Output buffer layout; main0 binds it at buffer(1) and writes the selected values into it.
struct SSBO2
{
float4 out_data[1];
};
// Counter buffer layout; main0 binds it at buffer(2) and increments `counter` atomically
// to obtain the next free index in SSBO2::out_data.
struct SSBO3
{
uint counter;
};
// Compute entry point. Each invocation reads one float4 from _27.in_data at its global X
// index and, if its dot product with (1, 5, 6, 2) exceeds the threshold, appends it to
// _49.out_data at an index obtained from the atomic counter in _52.
// This variant takes both the global invocation ID and the workgroup ID, and reads
// spvDispatchBase from a constant buffer bound at buffer(29).
kernel void main0(constant uint3& spvDispatchBase [[buffer(29)]], const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
// Rebase both builtins before their first use. The invocation ID scale factor is a
// literal (1, 1, 1). NOTE(review): presumably the shader's fixed workgroup size - confirm.
gl_GlobalInvocationID += spvDispatchBase * uint3(1, 1, 1);
gl_WorkGroupID += spvDispatchBase;
uint ident = gl_GlobalInvocationID.x;
// `workgroup` is assigned but not read again in this function.
uint workgroup = gl_WorkGroupID.x;
float4 idata = _27.in_data[ident];
if (dot(idata, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875)
{
// Relaxed fetch-add on the counter; the value it returns is used as the output index.
uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed);
_49.out_data[_56] = idata;
}
}

View File

@ -0,0 +1,29 @@
#version 310 es
// The workgroup X size is supplied by specialization constant id 10.
layout(local_size_x_id = 10) in;
// Read-only input values, one vec4 per element (unsized array).
layout(std430, binding = 0) readonly buffer SSBO
{
vec4 in_data[];
};
// Write-only destination for the values that pass the test in main().
layout(std430, binding = 1) writeonly buffer SSBO2
{
vec4 out_data[];
};
// Counter that main() bumps with atomicAdd to pick the next out_data slot.
layout(std430, binding = 2) buffer SSBO3
{
uint counter;
};
// Tests the vec4 at this invocation's global X index and, when its dot product with
// (1, 5, 6, 2) is greater than 8.2, appends it to out_data.
void main()
{
uint ident = gl_GlobalInvocationID.x;
// Not read again in this function.
// NOTE(review): presumably present so that gl_WorkGroupID is statically used - confirm.
uint workgroup = gl_WorkGroupID.x;
vec4 idata = in_data[ident];
if (dot(idata, vec4(1.0, 5.0, 6.0, 2.0)) > 8.2)
{
// atomicAdd returns the counter's value before the increment; it is used as the index.
out_data[atomicAdd(counter, 1u)] = idata;
}
}

View File

@ -0,0 +1,29 @@
#version 310 es
// Fixed workgroup X size of 1.
layout(local_size_x = 1) in;
// Read-only input values, one vec4 per element (unsized array).
layout(std430, binding = 0) readonly buffer SSBO
{
vec4 in_data[];
};
// Write-only destination for the values that pass the test in main().
layout(std430, binding = 1) writeonly buffer SSBO2
{
vec4 out_data[];
};
// Counter that main() bumps with atomicAdd to pick the next out_data slot.
layout(std430, binding = 2) buffer SSBO3
{
uint counter;
};
// Tests the vec4 at this invocation's global X index and, when its dot product with
// (1, 5, 6, 2) is greater than 8.2, appends it to out_data.
void main()
{
uint ident = gl_GlobalInvocationID.x;
// Not read again in this function.
// NOTE(review): presumably present so that gl_WorkGroupID is statically used - confirm.
uint workgroup = gl_WorkGroupID.x;
vec4 idata = in_data[ident];
if (dot(idata, vec4(1.0, 5.0, 6.0, 2.0)) > 8.2)
{
// atomicAdd returns the counter's value before the increment; it is used as the index.
out_data[atomicAdd(counter, 1u)] = idata;
}
}

View File

@ -1433,6 +1433,10 @@ enum ExtendedDecorations
// Marks a buffer block for using explicit offsets (GLSL/HLSL).
SPIRVCrossDecorationExplicitOffset,
// Apply to a variable in the Input storage class; marks it as holding the base group passed to vkCmdDispatchBase().
// In MSL, this is used to adjust the WorkgroupId and GlobalInvocationId variables.
SPIRVCrossDecorationBuiltInDispatchBase,
SPIRVCrossDecorationCount
};

View File

@ -107,8 +107,11 @@ void CompilerMSL::build_implicit_builtins()
active_input_builtins.get(BuiltInSubgroupGtMask));
bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
(msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
bool need_dispatch_base =
msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
(active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
need_multiview || needs_subgroup_invocation_id)
need_multiview || need_dispatch_base || needs_subgroup_invocation_id)
{
bool has_frag_coord = false;
bool has_sample_id = false;
@ -121,6 +124,7 @@ void CompilerMSL::build_implicit_builtins()
bool has_subgroup_invocation_id = false;
bool has_subgroup_size = false;
bool has_view_idx = false;
uint32_t workgroup_id_type = 0;
ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
if (var.storage != StorageClassInput || !ir.meta[var.self].decoration.builtin)
@ -208,6 +212,13 @@ void CompilerMSL::build_implicit_builtins()
has_view_idx = true;
}
}
// The base workgroup needs to have the same type and vector size
// as the workgroup or invocation ID, so keep track of the type that
// was used.
if (need_dispatch_base && workgroup_id_type == 0 &&
(builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
workgroup_id_type = var.basetype;
});
if (!has_frag_coord && need_subpass_input)
@ -457,6 +468,42 @@ void CompilerMSL::build_implicit_builtins()
builtin_subgroup_size_id = var_id;
mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
}
if (need_dispatch_base)
{
uint32_t var_id;
if (msl_options.supports_msl_version(1, 2))
{
// If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
// to convey this information and save a buffer slot.
uint32_t offset = ir.increase_bound_by(1);
var_id = offset;
set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
get_entry_point().interface_variables.push_back(var_id);
}
else
{
// Otherwise, we need to fall back to a good ol' fashioned buffer.
uint32_t offset = ir.increase_bound_by(2);
var_id = offset;
uint32_t type_id = offset + 1;
SPIRType var_type = get<SPIRType>(workgroup_id_type);
var_type.storage = StorageClassUniform;
set<SPIRType>(type_id, var_type);
set<SPIRVariable>(var_id, type_id, StorageClassUniform);
// This should never match anything.
set_decoration(var_id, DecorationDescriptorSet, ~(5u));
set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
msl_options.indirect_params_buffer_index);
}
set_name(var_id, "spvDispatchBase");
builtin_dispatch_base_id = var_id;
}
}
if (needs_swizzle_buffer_def)
@ -802,6 +849,8 @@ string CompilerMSL::compile()
active_interface_variables.insert(view_mask_buffer_id);
if (builtin_layer_id)
active_interface_variables.insert(builtin_layer_id);
if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
active_interface_variables.insert(builtin_dispatch_base_id);
// Create structs to hold input, output and uniform variables.
// Do output first to ensure out. is declared at top of entry function.
@ -6748,6 +6797,19 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args)
ep_args += "]]";
}
}
if (var.storage == StorageClassInput &&
has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase))
{
// This is a special implicit builtin, not corresponding to any SPIR-V builtin,
// which holds the base that was passed to vkCmdDispatchBase(). If it's present,
// assume we emitted it for a good reason.
assert(msl_options.supports_msl_version(1, 2));
if (!ep_args.empty())
ep_args += ", ";
ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]";
}
});
// Correct the types of all encountered active builtins. We couldn't do this before
@ -7023,7 +7085,11 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
default:
if (!ep_args.empty())
ep_args += ", ";
ep_args += type_to_glsl(type, var_id) + " " + r.name;
if (!type.pointer)
ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
type_to_glsl(type, var_id) + "& " + r.name;
else
ep_args += type_to_glsl(type, var_id) + " " + r.name;
ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
break;
}
@ -7343,6 +7409,35 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
msl_options.device_index, ";");
});
break;
case BuiltInWorkgroupId:
if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
break;
// The vkCmdDispatchBase() command lets the client set the base value
// of WorkgroupId. Metal has no direct equivalent; we must make this
// adjustment ourselves.
entry_func.fixup_hooks_in.push_back([=]() {
statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
});
break;
case BuiltInGlobalInvocationId:
if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
break;
// GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
// This needs to be adjusted too.
entry_func.fixup_hooks_in.push_back([=]() {
auto &execution = get_entry_point();
uint32_t workgroup_size_id = execution.workgroup_size.constant;
if (workgroup_size_id)
statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
" * ", to_expression(workgroup_size_id), ";");
else
statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
" * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
execution.workgroup_size.z, ");");
});
break;
default:
break;
}

View File

@ -198,6 +198,7 @@ public:
bool tess_domain_origin_lower_left = false;
bool multiview = false;
bool view_index_from_device_index = false;
bool dispatch_base = false;
// Enable use of MSL 2.0 indirect argument buffers.
// MSL 2.0 must also be enabled.
@ -225,7 +226,7 @@ public:
msl_version = make_msl_version(major, minor, patch);
}
bool supports_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0)
// Reports whether the configured MSL version is at least major.minor.patch.
// Callable on const option objects.
bool supports_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) const
{
	const auto required_version = make_msl_version(major, minor, patch);
	return required_version <= msl_version;
}
@ -276,6 +277,13 @@ public:
return msl_options.multiview && !msl_options.view_index_from_device_index;
}
// Feedback for the calling API: reports whether a buffer carrying the
// dispatch base workgroup ID has to be supplied. That is the case only
// when the dispatch_base option is on and the target is older than MSL 1.2.
bool needs_dispatch_base_buffer() const
{
	if (!msl_options.dispatch_base)
		return false;
	return !msl_options.supports_msl_version(1, 2);
}
// Provide feedback to calling API to allow it to pass an output
// buffer if the shader needs it.
bool needs_output_buffer() const
@ -563,6 +571,7 @@ protected:
uint32_t builtin_primitive_id_id = 0;
uint32_t builtin_subgroup_invocation_id_id = 0;
uint32_t builtin_subgroup_size_id = 0;
uint32_t builtin_dispatch_base_id = 0;
uint32_t swizzle_buffer_id = 0;
uint32_t buffer_size_buffer_id = 0;
uint32_t view_mask_buffer_id = 0;

View File

@ -207,6 +207,8 @@ def cross_compile_msl(shader, spirv, opt, iterations, paths):
msl_args.append('--msl-multiview')
if '.viewfromdev.' in shader:
msl_args.append('--msl-view-index-from-device-index')
if '.dispatchbase.' in shader:
msl_args.append('--msl-dispatch-base')
subprocess.check_call(msl_args)