MSL: Refactor and fix use of quadgroup vs simdgroup.

This commit is contained in:
Hans-Kristian Arntzen 2022-02-28 11:58:33 +01:00
parent c08ee860c8
commit 5555f2784b
3 changed files with 35 additions and 30 deletions

View File

@ -221,7 +221,7 @@ struct SSBO
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]])
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
{
uint4 gl_SubgroupEqMask = uint4(1 << gl_SubgroupInvocationID, uint3(0));
uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, gl_SubgroupInvocationID, gl_SubgroupSize - gl_SubgroupInvocationID), uint3(0));

View File

@ -5370,7 +5370,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupBroadcast(T value, ushort lane)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return quad_broadcast(value, lane);");
else
statement("return simd_broadcast(value, lane);");
@ -5379,7 +5379,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return !!quad_broadcast((ushort)value, lane);");
else
statement("return !!simd_broadcast((ushort)value, lane);");
@ -5388,7 +5388,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
else
statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
@ -5400,7 +5400,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupBroadcastFirst(T value)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return quad_broadcast_first(value);");
else
statement("return simd_broadcast_first(value);");
@ -5409,7 +5409,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupBroadcastFirst(bool value)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return !!quad_broadcast_first((ushort)value);");
else
statement("return !!simd_broadcast_first((ushort)value);");
@ -5418,7 +5418,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
else
statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
@ -5429,7 +5429,7 @@ void CompilerMSL::emit_custom_functions()
case SPVFuncImplSubgroupBallot:
statement("inline uint4 spvSubgroupBallot(bool value)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
{
statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
}
@ -5557,7 +5557,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline bool spvSubgroupAllEqual(T value)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return quad_all(all(value == quad_broadcast_first(value)));");
else
statement("return simd_all(all(value == simd_broadcast_first(value)));");
@ -5566,7 +5566,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupAllEqual(bool value)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return quad_all(value) || !quad_any(value);");
else
statement("return simd_all(value) || !simd_any(value);");
@ -5575,7 +5575,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
else
statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
@ -5587,7 +5587,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupShuffle(T value, ushort lane)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle(value, lane);");
else
statement("return simd_shuffle(value, lane);");
@ -5596,7 +5596,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupShuffle(bool value, ushort lane)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle((ushort)value, lane);");
else
statement("return !!simd_shuffle((ushort)value, lane);");
@ -5605,7 +5605,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
else
statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
@ -5617,7 +5617,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupShuffleXor(T value, ushort mask)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle_xor(value, mask);");
else
statement("return simd_shuffle_xor(value, mask);");
@ -5626,7 +5626,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle_xor((ushort)value, mask);");
else
statement("return !!simd_shuffle_xor((ushort)value, mask);");
@ -5635,7 +5635,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
else
statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
@ -5647,7 +5647,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupShuffleUp(T value, ushort delta)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle_up(value, delta);");
else
statement("return simd_shuffle_up(value, delta);");
@ -5656,7 +5656,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle_up((ushort)value, delta);");
else
statement("return !!simd_shuffle_up((ushort)value, delta);");
@ -5665,7 +5665,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
else
statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
@ -5677,7 +5677,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupShuffleDown(T value, ushort delta)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle_down(value, delta);");
else
statement("return simd_shuffle_down(value, delta);");
@ -5686,7 +5686,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle_down((ushort)value, delta);");
else
statement("return !!simd_shuffle_down((ushort)value, delta);");
@ -5695,7 +5695,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
else
statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
@ -13972,7 +13972,7 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
switch (op)
{
case OpGroupNonUniformElect:
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
emit_op(result_type, id, "quad_is_first()", false);
else
emit_op(result_type, id, "simd_is_first()", false);
@ -14045,14 +14045,14 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
break;
case OpGroupNonUniformAll:
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
emit_unary_func_op(result_type, id, ops[3], "quad_all");
else
emit_unary_func_op(result_type, id, ops[3], "simd_all");
break;
case OpGroupNonUniformAny:
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
if (msl_options.use_quadgroup_operation())
emit_unary_func_op(result_type, id, ops[3], "quad_any");
else
emit_unary_func_op(result_type, id, ops[3], "simd_any");
@ -14550,7 +14550,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
return msl_options.use_quadgroup_operation() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
case BuiltInSubgroupId:
if (msl_options.emulate_subgroups)
@ -14558,7 +14558,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
return msl_options.use_quadgroup_operation() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
case BuiltInSubgroupLocalInvocationId:
if (msl_options.emulate_subgroups)
@ -14577,7 +14577,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
// We are generating a Metal kernel function.
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
return msl_options.use_quadgroup_operation() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
}
else
SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");

View File

@ -393,7 +393,7 @@ public:
// and will be addressed using the current ViewIndex.
bool arrayed_subpass_input = false;
// Whether to use SIMD-group or quadgroup functions to implement group nnon-uniform
// Whether to use SIMD-group or quadgroup functions to implement group non-uniform
// operations. Some GPUs on iOS do not support the SIMD-group functions, only the
// quadgroup functions.
bool ios_use_simdgroup_functions = false;
@ -445,6 +445,11 @@ public:
return platform == macOS;
}
bool use_quadgroup_operation() const
{
return is_ios() && !ios_use_simdgroup_functions;
}
void set_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0)
{
msl_version = make_msl_version(major, minor, patch);