MSL: Add support for subgroup operations.

Some support for subgroups is present starting in Metal 2.0 on both iOS
and macOS. macOS gains more complete support in 10.14 (Metal 2.1).

Some restrictions are present. On iOS and on macOS 10.13, the
implementation of `OpGroupNonUniformElect` is incorrect: if thread 0 has
already terminated or is not executing a conditional branch, the first
thread that *is* will falsely believe itself not to be. Unfortunately,
this operation is part of the "basic" feature set; without it, subgroups
cannot be supported at all.

The `SubgroupSize` and `SubgroupLocalInvocationId` builtins are only
available in compute shaders (and, by extension, tessellation control
shaders), despite SPIR-V making them available in all stages. This
limits the usefulness of some of the subgroup operations in fragment
shaders.

Although Metal on macOS supports some clustered, inclusive, and
exclusive operations, it does not support them all. In particular,
inclusive and exclusive min, max, and, or, and xor; as well as cluster
sizes other than 4 are not supported. If this becomes a problem, they
could be emulated, but at a significant performance cost due to the need
for non-uniform operations.
This commit is contained in:
Chip Davis 2019-05-15 16:03:30 -05:00
parent d11665424d
commit 9d9415754b
11 changed files with 1126 additions and 43 deletions

View File

@ -8,13 +8,15 @@ constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
kernel void main0()
{
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
threadgroup_barrier(mem_flags::mem_texture);
threadgroup_barrier(mem_flags::mem_device);
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_none);
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_none);
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
threadgroup_barrier(mem_flags::mem_texture);
threadgroup_barrier(mem_flags::mem_device);
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
threadgroup_barrier(mem_flags::mem_threadgroup);
}

View File

@ -0,0 +1,92 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO
{
float FragColor;
};
inline uint4 spvSubgroupBallot(bool value)
{
simd_vote vote = simd_ballot(value);
// simd_ballot() returns a 64-bit integer-like object, but
// SPIR-V callers expect a uint4. We must convert.
// FIXME: This won't include higher bits if Apple ever supports
// 128 lanes in an SIMD-group.
return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> 32) & 0xFFFFFFFF), 0, 0);
}
inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
{
return !!extract_bits(ballot[bit / 32], bit % 32, 1);
}
inline uint spvSubgroupBallotFindLSB(uint4 ballot)
{
return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);
}
inline uint spvSubgroupBallotFindMSB(uint4 ballot)
{
return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - (clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), ballot.z == 0), ballot.w == 0);
}
inline uint spvSubgroupBallotBitCount(uint4 ballot)
{
return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);
}
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
return spvSubgroupBallotBitCount(ballot & mask);
}
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
return spvSubgroupBallotBitCount(ballot & mask);
}
template<typename T>
inline bool spvSubgroupAllEqual(T value)
{
return simd_all(value == simd_broadcast_first(value));
}
template<>
inline bool spvSubgroupAllEqual(bool value)
{
return simd_all(value) || !simd_any(value);
}
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
{
uint4 gl_SubgroupEqMask = 27 > 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
uint4 gl_SubgroupGeMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
uint4 gl_SubgroupGtMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
_9.FragColor = float(gl_NumSubgroups);
_9.FragColor = float(gl_SubgroupID);
_9.FragColor = float(gl_SubgroupSize);
_9.FragColor = float(gl_SubgroupInvocationID);
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device);
simdgroup_barrier(mem_flags::mem_threadgroup);
simdgroup_barrier(mem_flags::mem_texture);
_9.FragColor = float4(gl_SubgroupEqMask).x;
_9.FragColor = float4(gl_SubgroupGeMask).x;
_9.FragColor = float4(gl_SubgroupGtMask).x;
_9.FragColor = float4(gl_SubgroupLeMask).x;
_9.FragColor = float4(gl_SubgroupLtMask).x;
uint4 _83 = spvSubgroupBallot(true);
float4 _165 = simd_prefix_inclusive_product(simd_product(float4(20.0)));
int4 _167 = simd_prefix_inclusive_product(simd_product(int4(20)));
}

View File

@ -0,0 +1,23 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO
{
float FragColor;
};
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]])
{
_9.FragColor = float(gl_NumSubgroups);
_9.FragColor = float(gl_SubgroupID);
_9.FragColor = float(gl_SubgroupSize);
_9.FragColor = float(gl_SubgroupInvocationID);
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device);
simdgroup_barrier(mem_flags::mem_threadgroup);
simdgroup_barrier(mem_flags::mem_texture);
}

View File

@ -14,17 +14,22 @@ void barrier_shared()
void full_barrier()
{
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
}
void image_barrier()
{
threadgroup_barrier(mem_flags::mem_texture);
}
void buffer_barrier()
{
threadgroup_barrier(mem_flags::mem_none);
threadgroup_barrier(mem_flags::mem_device);
}
void group_barrier()
{
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
}
void barrier_shared_exec()
@ -34,17 +39,22 @@ void barrier_shared_exec()
void full_barrier_exec()
{
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
}
void image_barrier_exec()
{
threadgroup_barrier(mem_flags::mem_texture);
}
void buffer_barrier_exec()
{
threadgroup_barrier(mem_flags::mem_none);
threadgroup_barrier(mem_flags::mem_device);
}
void group_barrier_exec()
{
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
}
void exec_barrier()
@ -56,10 +66,12 @@ kernel void main0()
{
barrier_shared();
full_barrier();
image_barrier();
buffer_barrier();
group_barrier();
barrier_shared_exec();
full_barrier_exec();
image_barrier_exec();
buffer_barrier_exec();
group_barrier_exec();
exec_barrier();

View File

@ -0,0 +1,146 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO
{
float FragColor;
};
inline uint4 spvSubgroupBallot(bool value)
{
simd_vote vote = simd_ballot(value);
// simd_ballot() returns a 64-bit integer-like object, but
// SPIR-V callers expect a uint4. We must convert.
// FIXME: This won't include higher bits if Apple ever supports
// 128 lanes in an SIMD-group.
return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> 32) & 0xFFFFFFFF), 0, 0);
}
inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
{
return !!extract_bits(ballot[bit / 32], bit % 32, 1);
}
inline uint spvSubgroupBallotFindLSB(uint4 ballot)
{
return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);
}
inline uint spvSubgroupBallotFindMSB(uint4 ballot)
{
return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - (clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), ballot.z == 0), ballot.w == 0);
}
inline uint spvSubgroupBallotBitCount(uint4 ballot)
{
return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);
}
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
return spvSubgroupBallotBitCount(ballot & mask);
}
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
return spvSubgroupBallotBitCount(ballot & mask);
}
template<typename T>
inline bool spvSubgroupAllEqual(T value)
{
return simd_all(value == simd_broadcast_first(value));
}
template<>
inline bool spvSubgroupAllEqual(bool value)
{
return simd_all(value) || !simd_any(value);
}
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
{
uint4 gl_SubgroupEqMask = 27 > 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
uint4 gl_SubgroupGeMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
uint4 gl_SubgroupGtMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
_9.FragColor = float(gl_NumSubgroups);
_9.FragColor = float(gl_SubgroupID);
_9.FragColor = float(gl_SubgroupSize);
_9.FragColor = float(gl_SubgroupInvocationID);
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device);
simdgroup_barrier(mem_flags::mem_threadgroup);
simdgroup_barrier(mem_flags::mem_texture);
bool elected = simd_is_first();
_9.FragColor = float4(gl_SubgroupEqMask).x;
_9.FragColor = float4(gl_SubgroupGeMask).x;
_9.FragColor = float4(gl_SubgroupGtMask).x;
_9.FragColor = float4(gl_SubgroupLeMask).x;
_9.FragColor = float4(gl_SubgroupLtMask).x;
float4 broadcasted = simd_broadcast(float4(10.0), 8u);
float3 first = simd_broadcast_first(float3(20.0));
uint4 ballot_value = spvSubgroupBallot(true);
bool inverse_ballot_value = spvSubgroupBallotBitExtract(ballot_value, gl_SubgroupInvocationID);
bool bit_extracted = spvSubgroupBallotBitExtract(uint4(10u), 8u);
uint bit_count = spvSubgroupBallotBitCount(ballot_value);
uint inclusive_bit_count = spvSubgroupBallotInclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
uint exclusive_bit_count = spvSubgroupBallotExclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
uint lsb = spvSubgroupBallotFindLSB(ballot_value);
uint msb = spvSubgroupBallotFindMSB(ballot_value);
uint shuffled = simd_shuffle(10u, 8u);
uint shuffled_xor = simd_shuffle_xor(30u, 8u);
uint shuffled_up = simd_shuffle_up(20u, 4u);
uint shuffled_down = simd_shuffle_down(20u, 4u);
bool has_all = simd_all(true);
bool has_any = simd_any(true);
bool has_equal = spvSubgroupAllEqual(0);
has_equal = spvSubgroupAllEqual(true);
float4 added = simd_sum(float4(20.0));
int4 iadded = simd_sum(int4(20));
float4 multiplied = simd_product(float4(20.0));
int4 imultiplied = simd_product(int4(20));
float4 lo = simd_min(float4(20.0));
float4 hi = simd_max(float4(20.0));
int4 slo = simd_min(int4(20));
int4 shi = simd_max(int4(20));
uint4 ulo = simd_min(uint4(20u));
uint4 uhi = simd_max(uint4(20u));
uint4 anded = simd_and(ballot_value);
uint4 ored = simd_or(ballot_value);
uint4 xored = simd_xor(ballot_value);
added = simd_prefix_inclusive_sum(added);
iadded = simd_prefix_inclusive_sum(iadded);
multiplied = simd_prefix_inclusive_product(multiplied);
imultiplied = simd_prefix_inclusive_product(imultiplied);
added = simd_prefix_exclusive_sum(multiplied);
multiplied = simd_prefix_exclusive_product(multiplied);
iadded = simd_prefix_exclusive_sum(imultiplied);
imultiplied = simd_prefix_exclusive_product(imultiplied);
added = quad_sum(added);
multiplied = quad_product(multiplied);
iadded = quad_sum(iadded);
imultiplied = quad_product(imultiplied);
lo = quad_min(lo);
hi = quad_max(hi);
ulo = quad_min(ulo);
uhi = quad_max(uhi);
slo = quad_min(slo);
shi = quad_max(shi);
anded = quad_and(anded);
ored = quad_or(ored);
xored = quad_xor(xored);
float4 swap_horiz = quad_shuffle_xor(float4(20.0), 1u);
float4 swap_vertical = quad_shuffle_xor(float4(20.0), 2u);
float4 swap_diagonal = quad_shuffle_xor(float4(20.0), 3u);
float4 quad_broadcast0 = quad_broadcast(float4(20.0), 3u);
}

View File

@ -0,0 +1,32 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO
{
float FragColor;
};
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]])
{
_9.FragColor = float(gl_NumSubgroups);
_9.FragColor = float(gl_SubgroupID);
_9.FragColor = float(gl_SubgroupSize);
_9.FragColor = float(gl_SubgroupInvocationID);
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device);
simdgroup_barrier(mem_flags::mem_threadgroup);
simdgroup_barrier(mem_flags::mem_texture);
bool elected = (gl_SubgroupInvocationID == 0);
uint shuffled = quad_shuffle(10u, 8u);
uint shuffled_xor = quad_shuffle_xor(30u, 8u);
uint shuffled_up = quad_shuffle_up(20u, 4u);
uint shuffled_down = quad_shuffle_down(20u, 4u);
float4 swap_horiz = quad_shuffle_xor(float4(20.0), 1u);
float4 swap_vertical = quad_shuffle_xor(float4(20.0), 2u);
float4 swap_diagonal = quad_shuffle_xor(float4(20.0), 3u);
float4 quad_broadcast0 = quad_broadcast(float4(20.0), 3u);
}

View File

@ -11,12 +11,10 @@ void full_barrier()
memoryBarrier();
}
#if 0
void image_barrier()
{
memoryBarrierImage();
}
#endif
void buffer_barrier()
{
@ -40,13 +38,11 @@ void full_barrier_exec()
barrier();
}
#if 0
void image_barrier_exec()
{
memoryBarrierImage();
barrier();
}
#endif
void buffer_barrier_exec()
{
@ -69,13 +65,13 @@ void main()
{
barrier_shared();
full_barrier();
//image_barrier();
image_barrier();
buffer_barrier();
group_barrier();
barrier_shared_exec();
full_barrier_exec();
//image_barrier_exec();
image_barrier_exec();
buffer_barrier_exec();
group_barrier_exec();

View File

@ -0,0 +1,126 @@
#version 450
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_clustered : require
#extension GL_KHR_shader_subgroup_quad : require
layout(local_size_x = 1) in;
layout(std430, binding = 0) buffer SSBO
{
float FragColor;
};
void main()
{
// basic
FragColor = float(gl_NumSubgroups);
FragColor = float(gl_SubgroupID);
FragColor = float(gl_SubgroupSize);
FragColor = float(gl_SubgroupInvocationID);
subgroupBarrier();
subgroupMemoryBarrier();
subgroupMemoryBarrierBuffer();
subgroupMemoryBarrierShared();
subgroupMemoryBarrierImage();
bool elected = subgroupElect();
// ballot
FragColor = float(gl_SubgroupEqMask);
FragColor = float(gl_SubgroupGeMask);
FragColor = float(gl_SubgroupGtMask);
FragColor = float(gl_SubgroupLeMask);
FragColor = float(gl_SubgroupLtMask);
vec4 broadcasted = subgroupBroadcast(vec4(10.0), 8u);
vec3 first = subgroupBroadcastFirst(vec3(20.0));
uvec4 ballot_value = subgroupBallot(true);
bool inverse_ballot_value = subgroupInverseBallot(ballot_value);
bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u);
uint bit_count = subgroupBallotBitCount(ballot_value);
uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
uint lsb = subgroupBallotFindLSB(ballot_value);
uint msb = subgroupBallotFindMSB(ballot_value);
// shuffle
uint shuffled = subgroupShuffle(10u, 8u);
uint shuffled_xor = subgroupShuffleXor(30u, 8u);
// shuffle relative
uint shuffled_up = subgroupShuffleUp(20u, 4u);
uint shuffled_down = subgroupShuffleDown(20u, 4u);
// vote
bool has_all = subgroupAll(true);
bool has_any = subgroupAny(true);
bool has_equal = subgroupAllEqual(0);
has_equal = subgroupAllEqual(true);
// arithmetic
vec4 added = subgroupAdd(vec4(20.0));
ivec4 iadded = subgroupAdd(ivec4(20));
vec4 multiplied = subgroupMul(vec4(20.0));
ivec4 imultiplied = subgroupMul(ivec4(20));
vec4 lo = subgroupMin(vec4(20.0));
vec4 hi = subgroupMax(vec4(20.0));
ivec4 slo = subgroupMin(ivec4(20));
ivec4 shi = subgroupMax(ivec4(20));
uvec4 ulo = subgroupMin(uvec4(20));
uvec4 uhi = subgroupMax(uvec4(20));
uvec4 anded = subgroupAnd(ballot_value);
uvec4 ored = subgroupOr(ballot_value);
uvec4 xored = subgroupXor(ballot_value);
added = subgroupInclusiveAdd(added);
iadded = subgroupInclusiveAdd(iadded);
multiplied = subgroupInclusiveMul(multiplied);
imultiplied = subgroupInclusiveMul(imultiplied);
//lo = subgroupInclusiveMin(lo); // FIXME: Unsupported by Metal
//hi = subgroupInclusiveMax(hi);
//slo = subgroupInclusiveMin(slo);
//shi = subgroupInclusiveMax(shi);
//ulo = subgroupInclusiveMin(ulo);
//uhi = subgroupInclusiveMax(uhi);
//anded = subgroupInclusiveAnd(anded);
//ored = subgroupInclusiveOr(ored);
//xored = subgroupInclusiveXor(ored);
//added = subgroupExclusiveAdd(lo);
added = subgroupExclusiveAdd(multiplied);
multiplied = subgroupExclusiveMul(multiplied);
iadded = subgroupExclusiveAdd(imultiplied);
imultiplied = subgroupExclusiveMul(imultiplied);
//lo = subgroupExclusiveMin(lo); // FIXME: Unsupported by Metal
//hi = subgroupExclusiveMax(hi);
//ulo = subgroupExclusiveMin(ulo);
//uhi = subgroupExclusiveMax(uhi);
//slo = subgroupExclusiveMin(slo);
//shi = subgroupExclusiveMax(shi);
//anded = subgroupExclusiveAnd(anded);
//ored = subgroupExclusiveOr(ored);
//xored = subgroupExclusiveXor(ored);
// clustered
added = subgroupClusteredAdd(added, 4u);
multiplied = subgroupClusteredMul(multiplied, 4u);
iadded = subgroupClusteredAdd(iadded, 4u);
imultiplied = subgroupClusteredMul(imultiplied, 4u);
lo = subgroupClusteredMin(lo, 4u);
hi = subgroupClusteredMax(hi, 4u);
ulo = subgroupClusteredMin(ulo, 4u);
uhi = subgroupClusteredMax(uhi, 4u);
slo = subgroupClusteredMin(slo, 4u);
shi = subgroupClusteredMax(shi, 4u);
anded = subgroupClusteredAnd(anded, 4u);
ored = subgroupClusteredOr(ored, 4u);
xored = subgroupClusteredXor(xored, 4u);
// quad
vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0));
vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0));
vec4 swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0));
vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u);
}

View File

@ -0,0 +1,42 @@
#version 450
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_quad : require
layout(local_size_x = 1) in;
layout(std430, binding = 0) buffer SSBO
{
float FragColor;
};
// Reduced test for functionality exposed on iOS.
void main()
{
// basic
FragColor = float(gl_NumSubgroups);
FragColor = float(gl_SubgroupID);
FragColor = float(gl_SubgroupSize);
FragColor = float(gl_SubgroupInvocationID);
subgroupBarrier();
subgroupMemoryBarrier();
subgroupMemoryBarrierBuffer();
subgroupMemoryBarrierShared();
subgroupMemoryBarrierImage();
bool elected = subgroupElect();
// shuffle
uint shuffled = subgroupShuffle(10u, 8u);
uint shuffled_xor = subgroupShuffleXor(30u, 8u);
// shuffle relative
uint shuffled_up = subgroupShuffleUp(20u, 4u);
uint shuffled_down = subgroupShuffleDown(20u, 4u);
// quad
vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0));
vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0));
vec4 swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0));
vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u);
}

View File

@ -93,7 +93,14 @@ void CompilerMSL::build_implicit_builtins()
bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex;
bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
if (need_subpass_input || need_sample_pos || need_vertex_params || need_tesc_params)
bool need_subgroup_mask =
active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
active_input_builtins.get(BuiltInSubgroupLtMask);
bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
active_input_builtins.get(BuiltInSubgroupGtMask));
if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
needs_subgroup_invocation_id)
{
bool has_frag_coord = false;
bool has_sample_id = false;
@ -103,18 +110,21 @@ void CompilerMSL::build_implicit_builtins()
bool has_base_instance = false;
bool has_invocation_id = false;
bool has_primitive_id = false;
bool has_subgroup_invocation_id = false;
bool has_subgroup_size = false;
ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
if (var.storage != StorageClassInput || !ir.meta[var.self].decoration.builtin)
return;
if (need_subpass_input && ir.meta[var.self].decoration.builtin_type == BuiltInFragCoord)
BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;
if (need_subpass_input && builtin == BuiltInFragCoord)
{
builtin_frag_coord_id = var.self;
has_frag_coord = true;
}
if (need_sample_pos && ir.meta[var.self].decoration.builtin_type == BuiltInSampleId)
if (need_sample_pos && builtin == BuiltInSampleId)
{
builtin_sample_id_id = var.self;
has_sample_id = true;
@ -122,7 +132,7 @@ void CompilerMSL::build_implicit_builtins()
if (need_vertex_params)
{
switch (ir.meta[var.self].decoration.builtin_type)
switch (builtin)
{
case BuiltInVertexIndex:
builtin_vertex_idx_id = var.self;
@ -147,7 +157,7 @@ void CompilerMSL::build_implicit_builtins()
if (need_tesc_params)
{
switch (ir.meta[var.self].decoration.builtin_type)
switch (builtin)
{
case BuiltInInvocationId:
builtin_invocation_id_id = var.self;
@ -161,6 +171,18 @@ void CompilerMSL::build_implicit_builtins()
break;
}
}
if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
{
builtin_subgroup_invocation_id_id = var.self;
has_subgroup_invocation_id = true;
}
if (need_subgroup_ge_mask && builtin == BuiltInSubgroupSize)
{
builtin_subgroup_size_id = var.self;
has_subgroup_size = true;
}
});
if (!has_frag_coord && need_subpass_input)
@ -311,6 +333,58 @@ void CompilerMSL::build_implicit_builtins()
builtin_primitive_id_id = var_id;
}
}
if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
{
uint32_t offset = ir.increase_bound_by(3);
uint32_t type_id = offset;
uint32_t type_ptr_id = offset + 1;
uint32_t var_id = offset + 2;
// Create gl_SubgroupInvocationID.
SPIRType uint_type;
uint_type.basetype = SPIRType::UInt;
uint_type.width = 32;
set<SPIRType>(type_id, uint_type);
SPIRType uint_type_ptr;
uint_type_ptr = uint_type;
uint_type_ptr.pointer = true;
uint_type_ptr.parent_type = type_id;
uint_type_ptr.storage = StorageClassInput;
auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
ptr_type.self = type_id;
set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
builtin_subgroup_invocation_id_id = var_id;
}
if (!has_subgroup_size && need_subgroup_ge_mask)
{
uint32_t offset = ir.increase_bound_by(3);
uint32_t type_id = offset;
uint32_t type_ptr_id = offset + 1;
uint32_t var_id = offset + 2;
// Create gl_SubgroupSize.
SPIRType uint_type;
uint_type.basetype = SPIRType::UInt;
uint_type.width = 32;
set<SPIRType>(type_id, uint_type);
SPIRType uint_type_ptr;
uint_type_ptr = uint_type;
uint_type_ptr.pointer = true;
uint_type_ptr.parent_type = type_id;
uint_type_ptr.storage = StorageClassInput;
auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
ptr_type.self = type_id;
set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
builtin_subgroup_size_id = var_id;
}
}
if (needs_aux_buffer_def)
@ -598,6 +672,7 @@ string CompilerMSL::compile()
update_active_builtins();
analyze_image_and_sampler_usage();
analyze_sampled_image_usage();
preprocess_op_codes();
build_implicit_builtins();
fixup_image_load_store_access();
@ -606,9 +681,6 @@ string CompilerMSL::compile()
if (aux_buffer_id)
active_interface_variables.insert(aux_buffer_id);
// Preprocess OpCodes to extract the need to output additional header content
preprocess_op_codes();
// Create structs to hold input, output and uniform variables.
// Do output first to ensure out. is declared at top of entry function.
qual_pos_var_name = "";
@ -700,6 +772,9 @@ void CompilerMSL::preprocess_op_codes()
is_rasterization_disabled = true;
capture_output_to_buffer = true;
}
if (preproc.needs_subgroup_invocation_id)
needs_subgroup_invocation_id = true;
}
// Move the Private and Workgroup global variables to the entry function.
@ -2877,6 +2952,90 @@ void CompilerMSL::emit_custom_functions()
statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
end_scope();
statement("");
break;
case SPVFuncImplSubgroupBallot:
statement("inline uint4 spvSubgroupBallot(bool value)");
begin_scope();
statement("simd_vote vote = simd_ballot(value);");
statement("// simd_ballot() returns a 64-bit integer-like object, but");
statement("// SPIR-V callers expect a uint4. We must convert.");
statement("// FIXME: This won't include higher bits if Apple ever supports");
statement("// 128 lanes in an SIMD-group.");
statement("return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> "
"32) & 0xFFFFFFFF), 0, 0);");
end_scope();
statement("");
break;
case SPVFuncImplSubgroupBallotBitExtract:
statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
begin_scope();
statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
end_scope();
statement("");
break;
case SPVFuncImplSubgroupBallotFindLSB:
statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot)");
begin_scope();
statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
"ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
end_scope();
statement("");
break;
case SPVFuncImplSubgroupBallotFindMSB:
statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot)");
begin_scope();
statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
"(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
"ballot.z == 0), ballot.w == 0);");
end_scope();
statement("");
break;
case SPVFuncImplSubgroupBallotBitCount:
statement("inline uint spvSubgroupBallotBitCount(uint4 ballot)");
begin_scope();
statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
end_scope();
statement("");
statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
begin_scope();
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
"uint2(0));");
statement("return spvSubgroupBallotBitCount(ballot & mask);");
end_scope();
statement("");
statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
begin_scope();
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
statement("return spvSubgroupBallotBitCount(ballot & mask);");
end_scope();
statement("");
break;
case SPVFuncImplSubgroupAllEqual:
// Metal doesn't provide a function to evaluate this directly. But, we can
// implement this by comparing every thread's value to one thread's value
// (in this case, the value of the first active thread). Then, by the transitive
// property of equality, if all comparisons return true, then they are all equal.
statement("template<typename T>");
statement("inline bool spvSubgroupAllEqual(T value)");
begin_scope();
statement("return simd_all(value == simd_broadcast_first(value));");
end_scope();
statement("");
statement("template<>");
statement("inline bool spvSubgroupAllEqual(bool value)");
begin_scope();
statement("return simd_all(value) || !simd_any(value);");
end_scope();
statement("");
break;
default:
break;
@ -3895,33 +4054,70 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
return;
string bar_stmt = "threadgroup_barrier(mem_flags::";
uint32_t exe_scope = id_exe_scope ? get<SPIRConstant>(id_exe_scope).scalar() : uint32_t(ScopeInvocation);
uint32_t mem_scope = id_mem_scope ? get<SPIRConstant>(id_mem_scope).scalar() : uint32_t(ScopeInvocation);
// Use the wider of the two scopes (smaller value)
exe_scope = min(exe_scope, mem_scope);
string bar_stmt;
if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
else
bar_stmt = "threadgroup_barrier";
bar_stmt += "(";
uint32_t mem_sem = id_mem_sem ? get<SPIRConstant>(id_mem_sem).scalar() : uint32_t(MemorySemanticsMaskNone);
if (get_execution_model() == ExecutionModelTessellationControl)
// Use the | operator to combine flags if we can.
if (msl_options.supports_msl_version(1, 2))
{
string mem_flags = "";
// For tesc shaders, this also affects objects in the Output storage class.
// Since in Metal, these are placed in a device buffer, we have to sync device memory here.
bar_stmt += "mem_device";
else if (mem_sem & MemorySemanticsCrossWorkgroupMemoryMask)
bar_stmt += "mem_device";
else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
MemorySemanticsAtomicCounterMemoryMask))
bar_stmt += "mem_threadgroup";
else if (mem_sem & MemorySemanticsImageMemoryMask)
bar_stmt += "mem_texture";
if (get_execution_model() == ExecutionModelTessellationControl ||
(mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
mem_flags += "mem_flags::mem_device";
if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
MemorySemanticsAtomicCounterMemoryMask))
{
if (!mem_flags.empty())
mem_flags += " | ";
mem_flags += "mem_flags::mem_threadgroup";
}
if (mem_sem & MemorySemanticsImageMemoryMask)
{
if (!mem_flags.empty())
mem_flags += " | ";
mem_flags += "mem_flags::mem_texture";
}
if (mem_flags.empty())
mem_flags = "mem_flags::mem_none";
bar_stmt += mem_flags;
}
else
bar_stmt += "mem_none";
{
if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
(mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
MemorySemanticsAtomicCounterMemoryMask)))
bar_stmt += "mem_flags::mem_device_and_threadgroup";
else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
bar_stmt += "mem_flags::mem_device";
else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
MemorySemanticsAtomicCounterMemoryMask))
bar_stmt += "mem_flags::mem_threadgroup";
else if (mem_sem & MemorySemanticsImageMemoryMask)
bar_stmt += "mem_flags::mem_texture";
else
bar_stmt += "mem_flags::mem_none";
}
if (msl_options.is_ios() && (msl_options.supports_msl_version(2) && !msl_options.supports_msl_version(2, 1)))
{
bar_stmt += ", ";
// Use the wider of the two scopes (smaller value)
uint32_t exe_scope = id_exe_scope ? get<SPIRConstant>(id_exe_scope).scalar() : uint32_t(ScopeInvocation);
uint32_t mem_scope = id_mem_scope ? get<SPIRConstant>(id_mem_scope).scalar() : uint32_t(ScopeInvocation);
uint32_t scope = min(exe_scope, mem_scope);
switch (scope)
switch (mem_scope)
{
case ScopeCrossDevice:
case ScopeDevice:
@ -5188,6 +5384,8 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in
{
case BuiltInInvocationId:
case BuiltInPrimitiveId:
case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
case BuiltInSubgroupSize: // FIXME: Should work in any stage
return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
case BuiltInPatchVertices:
return "";
@ -5347,6 +5545,10 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in
case BuiltInNumWorkgroups:
case BuiltInLocalInvocationId:
case BuiltInLocalInvocationIndex:
case BuiltInNumSubgroups:
case BuiltInSubgroupId:
case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
case BuiltInSubgroupSize: // FIXME: Should work in any stage
return string(" [[") + builtin_qualifier(builtin) + "]]";
default:
@ -5593,7 +5795,9 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args)
if (bi_type != BuiltInSamplePosition && bi_type != BuiltInHelperInvocation &&
bi_type != BuiltInPatchVertices && bi_type != BuiltInTessLevelInner &&
bi_type != BuiltInTessLevelOuter && bi_type != BuiltInPosition && bi_type != BuiltInPointSize &&
bi_type != BuiltInClipDistance && bi_type != BuiltInCullDistance)
bi_type != BuiltInClipDistance && bi_type != BuiltInCullDistance && bi_type != BuiltInSubgroupEqMask &&
bi_type != BuiltInSubgroupGeMask && bi_type != BuiltInSubgroupGtMask &&
bi_type != BuiltInSubgroupLeMask && bi_type != BuiltInSubgroupLtMask)
{
if (!ep_args.empty())
ep_args += ", ";
@ -5911,6 +6115,94 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
}
break;
case BuiltInSubgroupEqMask:
if (msl_options.is_ios())
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
if (!msl_options.supports_msl_version(2, 1))
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
entry_func.fixup_hooks_in.push_back([=]() {
statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
builtin_subgroup_invocation_id_id, " > 32 ? uint4(0, (1 << (",
to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
});
break;
case BuiltInSubgroupGeMask:
if (msl_options.is_ios())
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
if (!msl_options.supports_msl_version(2, 1))
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
entry_func.fixup_hooks_in.push_back([=]() {
// Case where index < 32, size < 32:
// mask0 = bfe(0xFFFFFFFF, index, size - index);
// mask1 = bfe(0xFFFFFFFF, 0, 0); // Gives 0
// Case where index < 32 but size >= 32:
// mask0 = bfe(0xFFFFFFFF, index, 32 - index);
// mask1 = bfe(0xFFFFFFFF, 0, size - 32);
// Case where index >= 32:
// mask0 = bfe(0xFFFFFFFF, 32, 0); // Gives 0
// mask1 = bfe(0xFFFFFFFF, index - 32, size - index);
// This is expressed without branches to avoid divergent
// control flow--hence the complicated min/max expressions.
// This is further complicated by the fact that if you attempt
// to bfe out-of-bounds on Metal, undefined behavior is the
// result.
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
" = uint4(extract_bits(0xFFFFFFFF, min(",
to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
to_expression(builtin_subgroup_size_id), ", 32) - (int)",
to_expression(builtin_subgroup_invocation_id_id),
", 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
to_expression(builtin_subgroup_size_id), " - (int)max(",
to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
});
break;
case BuiltInSubgroupGtMask:
if (msl_options.is_ios())
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
if (!msl_options.supports_msl_version(2, 1))
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
entry_func.fixup_hooks_in.push_back([=]() {
// The same logic applies here, except now the index is one
// more than the subgroup invocation ID.
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
" = uint4(extract_bits(0xFFFFFFFF, min(",
to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
to_expression(builtin_subgroup_size_id), ", 32) - (int)",
to_expression(builtin_subgroup_invocation_id_id),
" - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
to_expression(builtin_subgroup_size_id), " - (int)max(",
to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
});
break;
case BuiltInSubgroupLeMask:
if (msl_options.is_ios())
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
if (!msl_options.supports_msl_version(2, 1))
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
entry_func.fixup_hooks_in.push_back([=]() {
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
" = uint4(extract_bits(0xFFFFFFFF, 0, min(",
to_expression(builtin_subgroup_invocation_id_id),
" + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
});
break;
case BuiltInSubgroupLtMask:
if (msl_options.is_ios())
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
if (!msl_options.supports_msl_version(2, 1))
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
entry_func.fixup_hooks_in.push_back([=]() {
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
" = uint4(extract_bits(0xFFFFFFFF, 0, min(",
to_expression(builtin_subgroup_invocation_id_id),
", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
});
break;
default:
break;
}
@ -6266,6 +6558,7 @@ void CompilerMSL::replace_illegal_names()
"M_2_SQRTPI",
"M_SQRT2",
"M_SQRT1_2",
"quad_broadcast",
};
static const unordered_set<string> illegal_func_names = {
@ -6748,6 +7041,245 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
return img_type_name;
}
void CompilerMSL::emit_subgroup_op(const Instruction &i)
{
const uint32_t *ops = stream(i);
auto op = static_cast<Op>(i.op);
// Metal 2.0 is required. iOS only supports quad ops. macOS only supports
// broadcast and shuffle on 10.13 (2.0), with full support in 10.14 (2.1).
// Note that iOS makes no distinction between a quad-group and a subgroup;
// all subgroups are quad-groups there.
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
if (msl_options.is_ios())
{
switch (op)
{
default:
SPIRV_CROSS_THROW("iOS only supports quad-group operations.");
case OpGroupNonUniformElect:
case OpGroupNonUniformBroadcast:
case OpGroupNonUniformShuffle:
case OpGroupNonUniformShuffleXor:
case OpGroupNonUniformShuffleUp:
case OpGroupNonUniformShuffleDown:
case OpGroupNonUniformQuadSwap:
case OpGroupNonUniformQuadBroadcast:
break;
}
}
if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
{
switch (op)
{
default:
SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.0 and up.");
case OpGroupNonUniformElect:
case OpGroupNonUniformBroadcast:
case OpGroupNonUniformShuffle:
case OpGroupNonUniformShuffleXor:
case OpGroupNonUniformShuffleUp:
case OpGroupNonUniformShuffleDown:
break;
}
}
uint32_t result_type = ops[0];
uint32_t id = ops[1];
auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
if (scope != ScopeSubgroup)
SPIRV_CROSS_THROW("Only subgroup scope is supported.");
switch (op)
{
case OpGroupNonUniformElect:
// Vulkan spec says we have to support this if we support subgroups at all.
// But Metal prior to macOS 10.14 doesn't have the simd_is_first() function, and
// iOS doesn't have it at all. So we fake it by comparing the subgroup-local
// ID to 0. This isn't quite correct: this is supposed to return if we're the
// lowest *active* thread, but we'll otherwise be unable to support subgroups
// on macOS 10.13 or iOS.
if (msl_options.is_macos() && msl_options.supports_msl_version(2, 1))
emit_op(result_type, id, "simd_is_first()", true);
else
emit_op(result_type, id, join("(", to_expression(builtin_subgroup_invocation_id_id), " == 0)"), true);
break;
case OpGroupNonUniformBroadcast:
emit_binary_func_op(result_type, id, ops[3], ops[4],
msl_options.is_ios() ? "quad_broadcast" : "simd_broadcast");
break;
case OpGroupNonUniformBroadcastFirst:
emit_unary_func_op(result_type, id, ops[3], "simd_broadcast_first");
break;
case OpGroupNonUniformBallot:
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
break;
case OpGroupNonUniformInverseBallot:
emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
break;
case OpGroupNonUniformBallotBitExtract:
emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
break;
case OpGroupNonUniformBallotFindLSB:
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindLSB");
break;
case OpGroupNonUniformBallotFindMSB:
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindMSB");
break;
case OpGroupNonUniformBallotBitCount:
{
auto operation = static_cast<GroupOperation>(ops[3]);
if (operation == GroupOperationReduce)
emit_unary_func_op(result_type, id, ops[4], "spvSubgroupBallotBitCount");
else if (operation == GroupOperationInclusiveScan)
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
"spvSubgroupBallotInclusiveBitCount");
else if (operation == GroupOperationExclusiveScan)
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
"spvSubgroupBallotExclusiveBitCount");
else
SPIRV_CROSS_THROW("Invalid BitCount operation.");
break;
}
case OpGroupNonUniformShuffle:
emit_binary_func_op(result_type, id, ops[3], ops[4], msl_options.is_ios() ? "quad_shuffle" : "simd_shuffle");
break;
case OpGroupNonUniformShuffleXor:
emit_binary_func_op(result_type, id, ops[3], ops[4],
msl_options.is_ios() ? "quad_shuffle_xor" : "simd_shuffle_xor");
break;
case OpGroupNonUniformShuffleUp:
emit_binary_func_op(result_type, id, ops[3], ops[4],
msl_options.is_ios() ? "quad_shuffle_up" : "simd_shuffle_up");
break;
case OpGroupNonUniformShuffleDown:
emit_binary_func_op(result_type, id, ops[3], ops[4],
msl_options.is_ios() ? "quad_shuffle_down" : "simd_shuffle_down");
break;
case OpGroupNonUniformAll:
emit_unary_func_op(result_type, id, ops[3], "simd_all");
break;
case OpGroupNonUniformAny:
emit_unary_func_op(result_type, id, ops[3], "simd_any");
break;
case OpGroupNonUniformAllEqual:
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
break;
// clang-format off
#define MSL_GROUP_OP(op, msl_op) \
case OpGroupNonUniform##op: \
{ \
auto operation = static_cast<GroupOperation>(ops[3]); \
if (operation == GroupOperationReduce) \
emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
else if (operation == GroupOperationInclusiveScan) \
emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
else if (operation == GroupOperationExclusiveScan) \
emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
else if (operation == GroupOperationClusteredReduce) \
{ \
/* Only cluster sizes of 4 are supported. */ \
uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
if (cluster_size != 4) \
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
} \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
break; \
}
MSL_GROUP_OP(FAdd, sum)
MSL_GROUP_OP(FMul, product)
MSL_GROUP_OP(IAdd, sum)
MSL_GROUP_OP(IMul, product)
#undef MSL_GROUP_OP
// The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
#define MSL_GROUP_OP(op, msl_op) \
case OpGroupNonUniform##op: \
{ \
auto operation = static_cast<GroupOperation>(ops[3]); \
if (operation == GroupOperationReduce) \
emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
else if (operation == GroupOperationInclusiveScan) \
SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
else if (operation == GroupOperationExclusiveScan) \
SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
else if (operation == GroupOperationClusteredReduce) \
{ \
/* Only cluster sizes of 4 are supported. */ \
uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
if (cluster_size != 4) \
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
} \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
break; \
}
MSL_GROUP_OP(FMin, min)
MSL_GROUP_OP(FMax, max)
MSL_GROUP_OP(SMin, min)
MSL_GROUP_OP(SMax, max)
MSL_GROUP_OP(UMin, min)
MSL_GROUP_OP(UMax, max)
MSL_GROUP_OP(BitwiseAnd, and)
MSL_GROUP_OP(BitwiseOr, or)
MSL_GROUP_OP(BitwiseXor, xor)
MSL_GROUP_OP(LogicalAnd, and)
MSL_GROUP_OP(LogicalOr, or)
MSL_GROUP_OP(LogicalXor, xor)
// clang-format on
case OpGroupNonUniformQuadSwap:
{
// We can implement this easily based on the following table giving
// the target lane ID from the direction and current lane ID:
// Direction
// | 0 | 1 | 2 |
// ---+---+---+---+
// L 0 | 1 2 3
// a 1 | 0 3 2
// n 2 | 3 0 1
// e 3 | 2 1 0
// Notice that target = source ^ (direction + 1).
uint32_t mask = get<SPIRConstant>(ops[4]).scalar() + 1;
uint32_t mask_id = ir.increase_bound_by(1);
set<SPIRConstant>(mask_id, expression_type_id(ops[4]), mask, false);
emit_binary_func_op(result_type, id, ops[3], mask_id, "quad_shuffle_xor");
break;
}
case OpGroupNonUniformQuadBroadcast:
emit_binary_func_op(result_type, id, ops[3], ops[4], "quad_broadcast");
break;
default:
SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
}
register_control_dependent_expression(id);
}
string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
{
if (out_type.basetype == in_type.basetype)
@ -6954,6 +7486,32 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
case BuiltInLocalInvocationIndex:
return "thread_index_in_threadgroup";
case BuiltInSubgroupSize:
return "thread_execution_width";
case BuiltInNumSubgroups:
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
case BuiltInSubgroupId:
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
case BuiltInSubgroupLocalInvocationId:
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
case BuiltInSubgroupEqMask:
case BuiltInSubgroupGeMask:
case BuiltInSubgroupGtMask:
case BuiltInSubgroupLeMask:
case BuiltInSubgroupLtMask:
// Shouldn't be reached.
SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
default:
return "unsupported-built-in";
}
@ -7042,7 +7600,17 @@ string CompilerMSL::builtin_type_decl(BuiltIn builtin)
case BuiltInWorkgroupId:
return "uint3";
case BuiltInLocalInvocationIndex:
case BuiltInNumSubgroups:
case BuiltInSubgroupId:
case BuiltInSubgroupSize:
case BuiltInSubgroupLocalInvocationId:
return "uint";
case BuiltInSubgroupEqMask:
case BuiltInSubgroupGeMask:
case BuiltInSubgroupGtMask:
case BuiltInSubgroupLeMask:
case BuiltInSubgroupLtMask:
return "uint4";
case BuiltInHelperInvocation:
return "bool";
@ -7267,6 +7835,20 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
uses_atomics = true;
break;
case OpGroupNonUniformElect:
if (compiler.msl_options.is_ios() || !compiler.msl_options.supports_msl_version(2, 1))
needs_subgroup_invocation_id = true;
break;
case OpGroupNonUniformInverseBallot:
needs_subgroup_invocation_id = true;
break;
case OpGroupNonUniformBallotBitCount:
if (args[3] != GroupOperationReduce)
needs_subgroup_invocation_id = true;
break;
default:
break;
}
@ -7433,6 +8015,25 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
break;
}
case OpGroupNonUniformBallot:
return SPVFuncImplSubgroupBallot;
case OpGroupNonUniformInverseBallot:
case OpGroupNonUniformBallotBitExtract:
return SPVFuncImplSubgroupBallotBitExtract;
case OpGroupNonUniformBallotFindLSB:
return SPVFuncImplSubgroupBallotFindLSB;
case OpGroupNonUniformBallotFindMSB:
return SPVFuncImplSubgroupBallotFindMSB;
case OpGroupNonUniformBallotBitCount:
return SPVFuncImplSubgroupBallotBitCount;
case OpGroupNonUniformAllEqual:
return SPVFuncImplSubgroupAllEqual;
default:
break;
}

View File

@ -344,6 +344,12 @@ protected:
SPVFuncImplRowMajor4x2,
SPVFuncImplRowMajor4x3,
SPVFuncImplTextureSwizzle,
SPVFuncImplSubgroupBallot,
SPVFuncImplSubgroupBallotBitExtract,
SPVFuncImplSubgroupBallotFindLSB,
SPVFuncImplSubgroupBallotFindMSB,
SPVFuncImplSubgroupBallotBitCount,
SPVFuncImplSubgroupAllEqual,
SPVFuncImplArrayCopyMultidimMax = 6
};
@ -354,6 +360,7 @@ protected:
void emit_header() override;
void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) override;
void emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id) override;
void emit_subgroup_op(const Instruction &i) override;
void emit_fixup() override;
std::string to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
const std::string &qualifier = "");
@ -477,6 +484,8 @@ protected:
uint32_t builtin_base_instance_id = 0;
uint32_t builtin_invocation_id_id = 0;
uint32_t builtin_primitive_id_id = 0;
uint32_t builtin_subgroup_invocation_id_id = 0;
uint32_t builtin_subgroup_size_id = 0;
uint32_t aux_buffer_id = 0;
void bitcast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
@ -518,6 +527,7 @@ protected:
bool needs_aux_buffer_def = false;
bool used_aux_buffer = false;
bool added_builtin_tess_level = false;
bool needs_subgroup_invocation_id = false;
std::string qual_pos_var_name;
std::string stage_in_var_name = "in";
std::string stage_out_var_name = "out";
@ -561,6 +571,7 @@ protected:
bool suppress_missing_prototypes = false;
bool uses_atomics = false;
bool uses_resource_write = false;
bool needs_subgroup_invocation_id = false;
};
// OpcodeHandler that scans for uses of sampled images