MSL: Refactor and fix use of quadgroup vs simdgroup.
This commit is contained in:
parent
c08ee860c8
commit
5555f2784b
@ -221,7 +221,7 @@ struct SSBO
|
||||
|
||||
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
|
||||
|
||||
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]])
|
||||
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
|
||||
{
|
||||
uint4 gl_SubgroupEqMask = uint4(1 << gl_SubgroupInvocationID, uint3(0));
|
||||
uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, gl_SubgroupInvocationID, gl_SubgroupSize - gl_SubgroupInvocationID), uint3(0));
|
||||
|
@ -5370,7 +5370,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<typename T>");
|
||||
statement("inline T spvSubgroupBroadcast(T value, ushort lane)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return quad_broadcast(value, lane);");
|
||||
else
|
||||
statement("return simd_broadcast(value, lane);");
|
||||
@ -5379,7 +5379,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<>");
|
||||
statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return !!quad_broadcast((ushort)value, lane);");
|
||||
else
|
||||
statement("return !!simd_broadcast((ushort)value, lane);");
|
||||
@ -5388,7 +5388,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<uint N>");
|
||||
statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
|
||||
else
|
||||
statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
|
||||
@ -5400,7 +5400,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<typename T>");
|
||||
statement("inline T spvSubgroupBroadcastFirst(T value)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return quad_broadcast_first(value);");
|
||||
else
|
||||
statement("return simd_broadcast_first(value);");
|
||||
@ -5409,7 +5409,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<>");
|
||||
statement("inline bool spvSubgroupBroadcastFirst(bool value)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return !!quad_broadcast_first((ushort)value);");
|
||||
else
|
||||
statement("return !!simd_broadcast_first((ushort)value);");
|
||||
@ -5418,7 +5418,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<uint N>");
|
||||
statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
|
||||
else
|
||||
statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
|
||||
@ -5429,7 +5429,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
case SPVFuncImplSubgroupBallot:
|
||||
statement("inline uint4 spvSubgroupBallot(bool value)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
{
|
||||
statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
|
||||
}
|
||||
@ -5557,7 +5557,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<typename T>");
|
||||
statement("inline bool spvSubgroupAllEqual(T value)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return quad_all(all(value == quad_broadcast_first(value)));");
|
||||
else
|
||||
statement("return simd_all(all(value == simd_broadcast_first(value)));");
|
||||
@ -5566,7 +5566,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<>");
|
||||
statement("inline bool spvSubgroupAllEqual(bool value)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return quad_all(value) || !quad_any(value);");
|
||||
else
|
||||
statement("return simd_all(value) || !simd_any(value);");
|
||||
@ -5575,7 +5575,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<uint N>");
|
||||
statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
|
||||
else
|
||||
statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
|
||||
@ -5587,7 +5587,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<typename T>");
|
||||
statement("inline T spvSubgroupShuffle(T value, ushort lane)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return quad_shuffle(value, lane);");
|
||||
else
|
||||
statement("return simd_shuffle(value, lane);");
|
||||
@ -5596,7 +5596,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<>");
|
||||
statement("inline bool spvSubgroupShuffle(bool value, ushort lane)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return !!quad_shuffle((ushort)value, lane);");
|
||||
else
|
||||
statement("return !!simd_shuffle((ushort)value, lane);");
|
||||
@ -5605,7 +5605,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<uint N>");
|
||||
statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
|
||||
else
|
||||
statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
|
||||
@ -5617,7 +5617,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<typename T>");
|
||||
statement("inline T spvSubgroupShuffleXor(T value, ushort mask)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return quad_shuffle_xor(value, mask);");
|
||||
else
|
||||
statement("return simd_shuffle_xor(value, mask);");
|
||||
@ -5626,7 +5626,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<>");
|
||||
statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return !!quad_shuffle_xor((ushort)value, mask);");
|
||||
else
|
||||
statement("return !!simd_shuffle_xor((ushort)value, mask);");
|
||||
@ -5635,7 +5635,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<uint N>");
|
||||
statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
|
||||
else
|
||||
statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
|
||||
@ -5647,7 +5647,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<typename T>");
|
||||
statement("inline T spvSubgroupShuffleUp(T value, ushort delta)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return quad_shuffle_up(value, delta);");
|
||||
else
|
||||
statement("return simd_shuffle_up(value, delta);");
|
||||
@ -5656,7 +5656,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<>");
|
||||
statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return !!quad_shuffle_up((ushort)value, delta);");
|
||||
else
|
||||
statement("return !!simd_shuffle_up((ushort)value, delta);");
|
||||
@ -5665,7 +5665,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<uint N>");
|
||||
statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
|
||||
else
|
||||
statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
|
||||
@ -5677,7 +5677,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<typename T>");
|
||||
statement("inline T spvSubgroupShuffleDown(T value, ushort delta)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return quad_shuffle_down(value, delta);");
|
||||
else
|
||||
statement("return simd_shuffle_down(value, delta);");
|
||||
@ -5686,7 +5686,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<>");
|
||||
statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return !!quad_shuffle_down((ushort)value, delta);");
|
||||
else
|
||||
statement("return !!simd_shuffle_down((ushort)value, delta);");
|
||||
@ -5695,7 +5695,7 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("template<uint N>");
|
||||
statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
|
||||
begin_scope();
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
|
||||
else
|
||||
statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
|
||||
@ -13972,7 +13972,7 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
|
||||
switch (op)
|
||||
{
|
||||
case OpGroupNonUniformElect:
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
emit_op(result_type, id, "quad_is_first()", false);
|
||||
else
|
||||
emit_op(result_type, id, "simd_is_first()", false);
|
||||
@ -14045,14 +14045,14 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformAll:
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
emit_unary_func_op(result_type, id, ops[3], "quad_all");
|
||||
else
|
||||
emit_unary_func_op(result_type, id, ops[3], "simd_all");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformAny:
|
||||
if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
|
||||
if (msl_options.use_quadgroup_operation())
|
||||
emit_unary_func_op(result_type, id, ops[3], "quad_any");
|
||||
else
|
||||
emit_unary_func_op(result_type, id, ops[3], "simd_any");
|
||||
@ -14550,7 +14550,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
|
||||
SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
|
||||
if (!msl_options.supports_msl_version(2))
|
||||
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
|
||||
return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
|
||||
return msl_options.use_quadgroup_operation() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
|
||||
|
||||
case BuiltInSubgroupId:
|
||||
if (msl_options.emulate_subgroups)
|
||||
@ -14558,7 +14558,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
|
||||
SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
|
||||
if (!msl_options.supports_msl_version(2))
|
||||
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
|
||||
return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
|
||||
return msl_options.use_quadgroup_operation() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
|
||||
|
||||
case BuiltInSubgroupLocalInvocationId:
|
||||
if (msl_options.emulate_subgroups)
|
||||
@ -14577,7 +14577,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
|
||||
// We are generating a Metal kernel function.
|
||||
if (!msl_options.supports_msl_version(2))
|
||||
SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
|
||||
return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
|
||||
return msl_options.use_quadgroup_operation() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
|
||||
}
|
||||
else
|
||||
SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");
|
||||
|
@ -393,7 +393,7 @@ public:
|
||||
// and will be addressed using the current ViewIndex.
|
||||
bool arrayed_subpass_input = false;
|
||||
|
||||
// Whether to use SIMD-group or quadgroup functions to implement group nnon-uniform
|
||||
// Whether to use SIMD-group or quadgroup functions to implement group non-uniform
|
||||
// operations. Some GPUs on iOS do not support the SIMD-group functions, only the
|
||||
// quadgroup functions.
|
||||
bool ios_use_simdgroup_functions = false;
|
||||
@ -445,6 +445,11 @@ public:
|
||||
return platform == macOS;
|
||||
}
|
||||
|
||||
bool use_quadgroup_operation() const
|
||||
{
|
||||
return is_ios() && !ios_use_simdgroup_functions;
|
||||
}
|
||||
|
||||
void set_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0)
|
||||
{
|
||||
msl_version = make_msl_version(major, minor, patch);
|
||||
|
Loading…
Reference in New Issue
Block a user