MSL: Correct definitions of subgroup ballot mask variables.

`SubgroupEqMask` had a fencepost error that gave wrong values for
invocation ID 32.

For `SubgroupGeMask` and `SubgroupGtMask`, I forgot to shift the values
from `extract_bits()` up so that the mask is in the correct position.
Using `insert_bits()` instead should fold these two operations into one.

`SubgroupLtMask` and `SubgroupLeMask` were already correct.
This commit is contained in:
Chip Davis 2020-10-20 21:58:07 -05:00
parent a57b4b1b2e
commit 6ccb902462
3 changed files with 18 additions and 18 deletions

View File

@ -68,9 +68,9 @@ inline bool spvSubgroupAllEqual(bool value)
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
{
uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID > 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
uint4 gl_SubgroupGeMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
uint4 gl_SubgroupGtMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID >= 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
uint4 gl_SubgroupGtMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
_9.FragColor = float(gl_NumSubgroups);

View File

@ -69,9 +69,9 @@ fragment main0_out main0()
main0_out out = {};
uint gl_SubgroupSize = simd_sum(1);
uint gl_SubgroupInvocationID = simd_prefix_exclusive_sum(1);
uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID > 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
uint4 gl_SubgroupGeMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
uint4 gl_SubgroupGtMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID >= 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
uint4 gl_SubgroupGtMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
out.FragColor = float(gl_SubgroupSize);

View File

@ -10463,7 +10463,7 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
entry_func.fixup_hooks_in.push_back([=]() {
statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
to_expression(builtin_subgroup_invocation_id_id), " > 32 ? uint4(0, (1 << (",
to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? uint4(0, (1 << (",
to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
});
@ -10475,25 +10475,25 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
entry_func.fixup_hooks_in.push_back([=]() {
// Case where index < 32, size < 32:
// mask0 = bfe(0xFFFFFFFF, index, size - index);
// mask1 = bfe(0xFFFFFFFF, 0, 0); // Gives 0
// mask0 = bfi(0, 0xFFFFFFFF, index, size - index);
// mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0
// Case where index < 32 but size >= 32:
// mask0 = bfe(0xFFFFFFFF, index, 32 - index);
// mask1 = bfe(0xFFFFFFFF, 0, size - 32);
// mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index);
// mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32);
// Case where index >= 32:
// mask0 = bfe(0xFFFFFFFF, 32, 0); // Gives 0
// mask1 = bfe(0xFFFFFFFF, index - 32, size - index);
// mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0
// mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index);
// This is expressed without branches to avoid divergent
// control flow--hence the complicated min/max expressions.
// This is further complicated by the fact that if you attempt
// to bfe out-of-bounds on Metal, undefined behavior is the
// to bfi/bfe out-of-bounds on Metal, undefined behavior is the
// result.
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
" = uint4(extract_bits(0xFFFFFFFF, min(",
" = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
to_expression(builtin_subgroup_size_id), ", 32) - (int)",
to_expression(builtin_subgroup_invocation_id_id),
", 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
to_expression(builtin_subgroup_size_id), " - (int)max(",
to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
@ -10508,11 +10508,11 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
// The same logic applies here, except now the index is one
// more than the subgroup invocation ID.
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
" = uint4(extract_bits(0xFFFFFFFF, min(",
" = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
to_expression(builtin_subgroup_size_id), ", 32) - (int)",
to_expression(builtin_subgroup_invocation_id_id),
" - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
" - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
to_expression(builtin_subgroup_size_id), " - (int)max(",
to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");