Merge pull request #1872 from KhronosGroup/fix-1867

MSL: Refactor and fix use of quadgroup vs simdgroup.
This commit is contained in:
Hans-Kristian Arntzen 2022-02-28 12:50:43 +01:00 committed by GitHub
commit 02440e85cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 30 deletions

View File

@ -221,7 +221,7 @@ struct SSBO
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]]) kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
{ {
uint4 gl_SubgroupEqMask = uint4(1 << gl_SubgroupInvocationID, uint3(0)); uint4 gl_SubgroupEqMask = uint4(1 << gl_SubgroupInvocationID, uint3(0));
uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, gl_SubgroupInvocationID, gl_SubgroupSize - gl_SubgroupInvocationID), uint3(0)); uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, gl_SubgroupInvocationID, gl_SubgroupSize - gl_SubgroupInvocationID), uint3(0));

View File

@ -5370,7 +5370,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>"); statement("template<typename T>");
statement("inline T spvSubgroupBroadcast(T value, ushort lane)"); statement("inline T spvSubgroupBroadcast(T value, ushort lane)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return quad_broadcast(value, lane);"); statement("return quad_broadcast(value, lane);");
else else
statement("return simd_broadcast(value, lane);"); statement("return simd_broadcast(value, lane);");
@ -5379,7 +5379,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>"); statement("template<>");
statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)"); statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return !!quad_broadcast((ushort)value, lane);"); statement("return !!quad_broadcast((ushort)value, lane);");
else else
statement("return !!simd_broadcast((ushort)value, lane);"); statement("return !!simd_broadcast((ushort)value, lane);");
@ -5388,7 +5388,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>"); statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)"); statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);"); statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
else else
statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);"); statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
@ -5400,7 +5400,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>"); statement("template<typename T>");
statement("inline T spvSubgroupBroadcastFirst(T value)"); statement("inline T spvSubgroupBroadcastFirst(T value)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return quad_broadcast_first(value);"); statement("return quad_broadcast_first(value);");
else else
statement("return simd_broadcast_first(value);"); statement("return simd_broadcast_first(value);");
@ -5409,7 +5409,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>"); statement("template<>");
statement("inline bool spvSubgroupBroadcastFirst(bool value)"); statement("inline bool spvSubgroupBroadcastFirst(bool value)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return !!quad_broadcast_first((ushort)value);"); statement("return !!quad_broadcast_first((ushort)value);");
else else
statement("return !!simd_broadcast_first((ushort)value);"); statement("return !!simd_broadcast_first((ushort)value);");
@ -5418,7 +5418,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>"); statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)"); statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);"); statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
else else
statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);"); statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
@ -5429,7 +5429,7 @@ void CompilerMSL::emit_custom_functions()
case SPVFuncImplSubgroupBallot: case SPVFuncImplSubgroupBallot:
statement("inline uint4 spvSubgroupBallot(bool value)"); statement("inline uint4 spvSubgroupBallot(bool value)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
{ {
statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);"); statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
} }
@ -5557,7 +5557,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>"); statement("template<typename T>");
statement("inline bool spvSubgroupAllEqual(T value)"); statement("inline bool spvSubgroupAllEqual(T value)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return quad_all(all(value == quad_broadcast_first(value)));"); statement("return quad_all(all(value == quad_broadcast_first(value)));");
else else
statement("return simd_all(all(value == simd_broadcast_first(value)));"); statement("return simd_all(all(value == simd_broadcast_first(value)));");
@ -5566,7 +5566,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>"); statement("template<>");
statement("inline bool spvSubgroupAllEqual(bool value)"); statement("inline bool spvSubgroupAllEqual(bool value)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return quad_all(value) || !quad_any(value);"); statement("return quad_all(value) || !quad_any(value);");
else else
statement("return simd_all(value) || !simd_any(value);"); statement("return simd_all(value) || !simd_any(value);");
@ -5575,7 +5575,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>"); statement("template<uint N>");
statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)"); statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));"); statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
else else
statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));"); statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
@ -5587,7 +5587,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>"); statement("template<typename T>");
statement("inline T spvSubgroupShuffle(T value, ushort lane)"); statement("inline T spvSubgroupShuffle(T value, ushort lane)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle(value, lane);"); statement("return quad_shuffle(value, lane);");
else else
statement("return simd_shuffle(value, lane);"); statement("return simd_shuffle(value, lane);");
@ -5596,7 +5596,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>"); statement("template<>");
statement("inline bool spvSubgroupShuffle(bool value, ushort lane)"); statement("inline bool spvSubgroupShuffle(bool value, ushort lane)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle((ushort)value, lane);"); statement("return !!quad_shuffle((ushort)value, lane);");
else else
statement("return !!simd_shuffle((ushort)value, lane);"); statement("return !!simd_shuffle((ushort)value, lane);");
@ -5605,7 +5605,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>"); statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)"); statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);"); statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
else else
statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);"); statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
@ -5617,7 +5617,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>"); statement("template<typename T>");
statement("inline T spvSubgroupShuffleXor(T value, ushort mask)"); statement("inline T spvSubgroupShuffleXor(T value, ushort mask)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle_xor(value, mask);"); statement("return quad_shuffle_xor(value, mask);");
else else
statement("return simd_shuffle_xor(value, mask);"); statement("return simd_shuffle_xor(value, mask);");
@ -5626,7 +5626,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>"); statement("template<>");
statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)"); statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle_xor((ushort)value, mask);"); statement("return !!quad_shuffle_xor((ushort)value, mask);");
else else
statement("return !!simd_shuffle_xor((ushort)value, mask);"); statement("return !!simd_shuffle_xor((ushort)value, mask);");
@ -5635,7 +5635,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>"); statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)"); statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);"); statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
else else
statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);"); statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
@ -5647,7 +5647,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>"); statement("template<typename T>");
statement("inline T spvSubgroupShuffleUp(T value, ushort delta)"); statement("inline T spvSubgroupShuffleUp(T value, ushort delta)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle_up(value, delta);"); statement("return quad_shuffle_up(value, delta);");
else else
statement("return simd_shuffle_up(value, delta);"); statement("return simd_shuffle_up(value, delta);");
@ -5656,7 +5656,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>"); statement("template<>");
statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)"); statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle_up((ushort)value, delta);"); statement("return !!quad_shuffle_up((ushort)value, delta);");
else else
statement("return !!simd_shuffle_up((ushort)value, delta);"); statement("return !!simd_shuffle_up((ushort)value, delta);");
@ -5665,7 +5665,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>"); statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)"); statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);"); statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
else else
statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);"); statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
@ -5677,7 +5677,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>"); statement("template<typename T>");
statement("inline T spvSubgroupShuffleDown(T value, ushort delta)"); statement("inline T spvSubgroupShuffleDown(T value, ushort delta)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle_down(value, delta);"); statement("return quad_shuffle_down(value, delta);");
else else
statement("return simd_shuffle_down(value, delta);"); statement("return simd_shuffle_down(value, delta);");
@ -5686,7 +5686,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>"); statement("template<>");
statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)"); statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle_down((ushort)value, delta);"); statement("return !!quad_shuffle_down((ushort)value, delta);");
else else
statement("return !!simd_shuffle_down((ushort)value, delta);"); statement("return !!simd_shuffle_down((ushort)value, delta);");
@ -5695,7 +5695,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>"); statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)"); statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
begin_scope(); begin_scope();
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);"); statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
else else
statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);"); statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
@ -13972,7 +13972,7 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
switch (op) switch (op)
{ {
case OpGroupNonUniformElect: case OpGroupNonUniformElect:
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
emit_op(result_type, id, "quad_is_first()", false); emit_op(result_type, id, "quad_is_first()", false);
else else
emit_op(result_type, id, "simd_is_first()", false); emit_op(result_type, id, "simd_is_first()", false);
@ -14045,14 +14045,14 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
break; break;
case OpGroupNonUniformAll: case OpGroupNonUniformAll:
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
emit_unary_func_op(result_type, id, ops[3], "quad_all"); emit_unary_func_op(result_type, id, ops[3], "quad_all");
else else
emit_unary_func_op(result_type, id, ops[3], "simd_all"); emit_unary_func_op(result_type, id, ops[3], "simd_all");
break; break;
case OpGroupNonUniformAny: case OpGroupNonUniformAny:
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) if (msl_options.use_quadgroup_operation())
emit_unary_func_op(result_type, id, ops[3], "quad_any"); emit_unary_func_op(result_type, id, ops[3], "quad_any");
else else
emit_unary_func_op(result_type, id, ops[3], "simd_any"); emit_unary_func_op(result_type, id, ops[3], "simd_any");
@ -14550,7 +14550,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation."); SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
if (!msl_options.supports_msl_version(2)) if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0."); SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup"; return msl_options.use_quadgroup_operation() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
case BuiltInSubgroupId: case BuiltInSubgroupId:
if (msl_options.emulate_subgroups) if (msl_options.emulate_subgroups)
@ -14558,7 +14558,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation."); SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
if (!msl_options.supports_msl_version(2)) if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0."); SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup"; return msl_options.use_quadgroup_operation() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
case BuiltInSubgroupLocalInvocationId: case BuiltInSubgroupLocalInvocationId:
if (msl_options.emulate_subgroups) if (msl_options.emulate_subgroups)
@ -14577,7 +14577,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
// We are generating a Metal kernel function. // We are generating a Metal kernel function.
if (!msl_options.supports_msl_version(2)) if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0."); SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup"; return msl_options.use_quadgroup_operation() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
} }
else else
SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function."); SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");

View File

@ -393,7 +393,7 @@ public:
// and will be addressed using the current ViewIndex. // and will be addressed using the current ViewIndex.
bool arrayed_subpass_input = false; bool arrayed_subpass_input = false;
// Whether to use SIMD-group or quadgroup functions to implement group nnon-uniform // Whether to use SIMD-group or quadgroup functions to implement group non-uniform
// operations. Some GPUs on iOS do not support the SIMD-group functions, only the // operations. Some GPUs on iOS do not support the SIMD-group functions, only the
// quadgroup functions. // quadgroup functions.
bool ios_use_simdgroup_functions = false; bool ios_use_simdgroup_functions = false;
@ -445,6 +445,11 @@ public:
return platform == macOS; return platform == macOS;
} }
// Single switch for choosing quadgroup over SIMD-group functions and builtins.
// Yields true only when is_ios() holds and ios_use_simdgroup_functions is
// false; every other combination yields false.
bool use_quadgroup_operation() const
{
	// Off iOS the answer is always false, regardless of the flag.
	if (!is_ios())
		return false;

	// On iOS, quadgroup is used unless SIMD-group functions were opted into.
	return !ios_use_simdgroup_functions;
}
void set_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) void set_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0)
{ {
msl_version = make_msl_version(major, minor, patch); msl_version = make_msl_version(major, minor, patch);