MSL: Correct definitions of subgroup ballot mask variables.
`SubgroupEqMask` had a fencepost error (the invocation ID was compared with `> 32` instead of `>= 32`) that gave wrong values for invocation ID 32. For `SubgroupGeMask` and `SubgroupGtMask`, I forgot to shift the values from `extract_bits()` up so that the mask is in the correct position. Using `insert_bits()` instead should fold these two operations (generating the run of set bits and shifting it into position) into one. `SubgroupLtMask` and `SubgroupLeMask` were already correct.
This commit is contained in:
parent
a57b4b1b2e
commit
6ccb902462
@ -68,9 +68,9 @@ inline bool spvSubgroupAllEqual(bool value)
|
||||
|
||||
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
|
||||
{
|
||||
uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID > 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
|
||||
uint4 gl_SubgroupGeMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupGtMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID >= 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
|
||||
uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupGtMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
|
||||
uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
|
||||
_9.FragColor = float(gl_NumSubgroups);
|
||||
|
@ -69,9 +69,9 @@ fragment main0_out main0()
|
||||
main0_out out = {};
|
||||
uint gl_SubgroupSize = simd_sum(1);
|
||||
uint gl_SubgroupInvocationID = simd_prefix_exclusive_sum(1);
|
||||
uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID > 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
|
||||
uint4 gl_SubgroupGeMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupGtMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID >= 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
|
||||
uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupGtMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
|
||||
uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
|
||||
out.FragColor = float(gl_SubgroupSize);
|
||||
|
@ -10463,7 +10463,7 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
|
||||
entry_func.fixup_hooks_in.push_back([=]() {
|
||||
statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " > 32 ? uint4(0, (1 << (",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? uint4(0, (1 << (",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
|
||||
to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
|
||||
});
|
||||
@ -10475,25 +10475,25 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
|
||||
entry_func.fixup_hooks_in.push_back([=]() {
|
||||
// Case where index < 32, size < 32:
|
||||
// mask0 = bfe(0xFFFFFFFF, index, size - index);
|
||||
// mask1 = bfe(0xFFFFFFFF, 0, 0); // Gives 0
|
||||
// mask0 = bfi(0, 0xFFFFFFFF, index, size - index);
|
||||
// mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0
|
||||
// Case where index < 32 but size >= 32:
|
||||
// mask0 = bfe(0xFFFFFFFF, index, 32 - index);
|
||||
// mask1 = bfe(0xFFFFFFFF, 0, size - 32);
|
||||
// mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index);
|
||||
// mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32);
|
||||
// Case where index >= 32:
|
||||
// mask0 = bfe(0xFFFFFFFF, 32, 0); // Gives 0
|
||||
// mask1 = bfe(0xFFFFFFFF, index - 32, size - index);
|
||||
// mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0
|
||||
// mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index);
|
||||
// This is expressed without branches to avoid divergent
|
||||
// control flow--hence the complicated min/max expressions.
|
||||
// This is further complicated by the fact that if you attempt
|
||||
// to bfe out-of-bounds on Metal, undefined behavior is the
|
||||
// to bfi/bfe out-of-bounds on Metal, undefined behavior is the
|
||||
// result.
|
||||
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
|
||||
" = uint4(extract_bits(0xFFFFFFFF, min(",
|
||||
" = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
|
||||
to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
|
||||
to_expression(builtin_subgroup_size_id), ", 32) - (int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id),
|
||||
", 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
|
||||
", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
|
||||
to_expression(builtin_subgroup_size_id), " - (int)max(",
|
||||
to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
|
||||
@ -10508,11 +10508,11 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
|
||||
// The same logic applies here, except now the index is one
|
||||
// more than the subgroup invocation ID.
|
||||
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
|
||||
" = uint4(extract_bits(0xFFFFFFFF, min(",
|
||||
" = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
|
||||
to_expression(builtin_subgroup_size_id), ", 32) - (int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id),
|
||||
" - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
|
||||
" - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
|
||||
to_expression(builtin_subgroup_size_id), " - (int)max(",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
|
||||
|
Loading…
Reference in New Issue
Block a user