MSL: Add support for subgroup operations.
Some support for subgroups is present starting in Metal 2.0 on both iOS and macOS. macOS gains more complete support in 10.14 (Metal 2.1). Some restrictions are present. On iOS and on macOS 10.13, the implementation of `OpGroupNonUniformElect` is incorrect: if thread 0 has already terminated or is not executing a conditional branch, the first thread that *is* will falsely believe itself not to be. Unfortunately, this operation is part of the "basic" feature set; without it, subgroups cannot be supported at all. The `SubgroupSize` and `SubgroupLocalInvocationId` builtins are only available in compute shaders (and, by extension, tessellation control shaders), despite SPIR-V making them available in all stages. This limits the usefulness of some of the subgroup operations in fragment shaders. Although Metal on macOS supports some clustered, inclusive, and exclusive operations, it does not support them all. In particular, inclusive and exclusive min, max, and, or, and xor; as well as cluster sizes other than 4 are not supported. If this becomes a problem, they could be emulated, but at a significant performance cost due to the need for non-uniform operations.
This commit is contained in:
parent
d11665424d
commit
9d9415754b
@ -8,13 +8,15 @@ constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
|
||||
kernel void main0()
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
threadgroup_barrier(mem_flags::mem_texture);
|
||||
threadgroup_barrier(mem_flags::mem_device);
|
||||
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
threadgroup_barrier(mem_flags::mem_texture);
|
||||
threadgroup_barrier(mem_flags::mem_device);
|
||||
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,92 @@
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct SSBO
|
||||
{
|
||||
float FragColor;
|
||||
};
|
||||
|
||||
inline uint4 spvSubgroupBallot(bool value)
|
||||
{
|
||||
simd_vote vote = simd_ballot(value);
|
||||
// simd_ballot() returns a 64-bit integer-like object, but
|
||||
// SPIR-V callers expect a uint4. We must convert.
|
||||
// FIXME: This won't include higher bits if Apple ever supports
|
||||
// 128 lanes in an SIMD-group.
|
||||
return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> 32) & 0xFFFFFFFF), 0, 0);
|
||||
}
|
||||
|
||||
inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
|
||||
{
|
||||
return !!extract_bits(ballot[bit / 32], bit % 32, 1);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotFindLSB(uint4 ballot)
|
||||
{
|
||||
return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotFindMSB(uint4 ballot)
|
||||
{
|
||||
return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - (clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), ballot.z == 0), ballot.w == 0);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotBitCount(uint4 ballot)
|
||||
{
|
||||
return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
|
||||
return spvSubgroupBallotBitCount(ballot & mask);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
|
||||
return spvSubgroupBallotBitCount(ballot & mask);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline bool spvSubgroupAllEqual(T value)
|
||||
{
|
||||
return simd_all(value == simd_broadcast_first(value));
|
||||
}
|
||||
|
||||
template<>
|
||||
inline bool spvSubgroupAllEqual(bool value)
|
||||
{
|
||||
return simd_all(value) || !simd_any(value);
|
||||
}
|
||||
|
||||
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
|
||||
{
|
||||
uint4 gl_SubgroupEqMask = 27 > 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
|
||||
uint4 gl_SubgroupGeMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupGtMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
|
||||
uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
|
||||
_9.FragColor = float(gl_NumSubgroups);
|
||||
_9.FragColor = float(gl_SubgroupID);
|
||||
_9.FragColor = float(gl_SubgroupSize);
|
||||
_9.FragColor = float(gl_SubgroupInvocationID);
|
||||
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
simdgroup_barrier(mem_flags::mem_device);
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
simdgroup_barrier(mem_flags::mem_texture);
|
||||
_9.FragColor = float4(gl_SubgroupEqMask).x;
|
||||
_9.FragColor = float4(gl_SubgroupGeMask).x;
|
||||
_9.FragColor = float4(gl_SubgroupGtMask).x;
|
||||
_9.FragColor = float4(gl_SubgroupLeMask).x;
|
||||
_9.FragColor = float4(gl_SubgroupLtMask).x;
|
||||
uint4 _83 = spvSubgroupBallot(true);
|
||||
float4 _165 = simd_prefix_inclusive_product(simd_product(float4(20.0)));
|
||||
int4 _167 = simd_prefix_inclusive_product(simd_product(int4(20)));
|
||||
}
|
||||
|
@ -0,0 +1,23 @@
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct SSBO
|
||||
{
|
||||
float FragColor;
|
||||
};
|
||||
|
||||
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]])
|
||||
{
|
||||
_9.FragColor = float(gl_NumSubgroups);
|
||||
_9.FragColor = float(gl_SubgroupID);
|
||||
_9.FragColor = float(gl_SubgroupSize);
|
||||
_9.FragColor = float(gl_SubgroupInvocationID);
|
||||
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
simdgroup_barrier(mem_flags::mem_device);
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
simdgroup_barrier(mem_flags::mem_texture);
|
||||
}
|
||||
|
@ -14,17 +14,22 @@ void barrier_shared()
|
||||
|
||||
void full_barrier()
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
}
|
||||
|
||||
void image_barrier()
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_texture);
|
||||
}
|
||||
|
||||
void buffer_barrier()
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
threadgroup_barrier(mem_flags::mem_device);
|
||||
}
|
||||
|
||||
void group_barrier()
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
}
|
||||
|
||||
void barrier_shared_exec()
|
||||
@ -34,17 +39,22 @@ void barrier_shared_exec()
|
||||
|
||||
void full_barrier_exec()
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
}
|
||||
|
||||
void image_barrier_exec()
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_texture);
|
||||
}
|
||||
|
||||
void buffer_barrier_exec()
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
threadgroup_barrier(mem_flags::mem_device);
|
||||
}
|
||||
|
||||
void group_barrier_exec()
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
}
|
||||
|
||||
void exec_barrier()
|
||||
@ -56,10 +66,12 @@ kernel void main0()
|
||||
{
|
||||
barrier_shared();
|
||||
full_barrier();
|
||||
image_barrier();
|
||||
buffer_barrier();
|
||||
group_barrier();
|
||||
barrier_shared_exec();
|
||||
full_barrier_exec();
|
||||
image_barrier_exec();
|
||||
buffer_barrier_exec();
|
||||
group_barrier_exec();
|
||||
exec_barrier();
|
||||
|
@ -0,0 +1,146 @@
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct SSBO
|
||||
{
|
||||
float FragColor;
|
||||
};
|
||||
|
||||
inline uint4 spvSubgroupBallot(bool value)
|
||||
{
|
||||
simd_vote vote = simd_ballot(value);
|
||||
// simd_ballot() returns a 64-bit integer-like object, but
|
||||
// SPIR-V callers expect a uint4. We must convert.
|
||||
// FIXME: This won't include higher bits if Apple ever supports
|
||||
// 128 lanes in an SIMD-group.
|
||||
return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> 32) & 0xFFFFFFFF), 0, 0);
|
||||
}
|
||||
|
||||
inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
|
||||
{
|
||||
return !!extract_bits(ballot[bit / 32], bit % 32, 1);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotFindLSB(uint4 ballot)
|
||||
{
|
||||
return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotFindMSB(uint4 ballot)
|
||||
{
|
||||
return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - (clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), ballot.z == 0), ballot.w == 0);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotBitCount(uint4 ballot)
|
||||
{
|
||||
return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
|
||||
return spvSubgroupBallotBitCount(ballot & mask);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
|
||||
return spvSubgroupBallotBitCount(ballot & mask);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline bool spvSubgroupAllEqual(T value)
|
||||
{
|
||||
return simd_all(value == simd_broadcast_first(value));
|
||||
}
|
||||
|
||||
template<>
|
||||
inline bool spvSubgroupAllEqual(bool value)
|
||||
{
|
||||
return simd_all(value) || !simd_any(value);
|
||||
}
|
||||
|
||||
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
|
||||
{
|
||||
uint4 gl_SubgroupEqMask = 27 > 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0));
|
||||
uint4 gl_SubgroupGeMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupGtMask = uint4(extract_bits(0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0));
|
||||
uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
|
||||
uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
|
||||
_9.FragColor = float(gl_NumSubgroups);
|
||||
_9.FragColor = float(gl_SubgroupID);
|
||||
_9.FragColor = float(gl_SubgroupSize);
|
||||
_9.FragColor = float(gl_SubgroupInvocationID);
|
||||
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
simdgroup_barrier(mem_flags::mem_device);
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
simdgroup_barrier(mem_flags::mem_texture);
|
||||
bool elected = simd_is_first();
|
||||
_9.FragColor = float4(gl_SubgroupEqMask).x;
|
||||
_9.FragColor = float4(gl_SubgroupGeMask).x;
|
||||
_9.FragColor = float4(gl_SubgroupGtMask).x;
|
||||
_9.FragColor = float4(gl_SubgroupLeMask).x;
|
||||
_9.FragColor = float4(gl_SubgroupLtMask).x;
|
||||
float4 broadcasted = simd_broadcast(float4(10.0), 8u);
|
||||
float3 first = simd_broadcast_first(float3(20.0));
|
||||
uint4 ballot_value = spvSubgroupBallot(true);
|
||||
bool inverse_ballot_value = spvSubgroupBallotBitExtract(ballot_value, gl_SubgroupInvocationID);
|
||||
bool bit_extracted = spvSubgroupBallotBitExtract(uint4(10u), 8u);
|
||||
uint bit_count = spvSubgroupBallotBitCount(ballot_value);
|
||||
uint inclusive_bit_count = spvSubgroupBallotInclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
|
||||
uint exclusive_bit_count = spvSubgroupBallotExclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
|
||||
uint lsb = spvSubgroupBallotFindLSB(ballot_value);
|
||||
uint msb = spvSubgroupBallotFindMSB(ballot_value);
|
||||
uint shuffled = simd_shuffle(10u, 8u);
|
||||
uint shuffled_xor = simd_shuffle_xor(30u, 8u);
|
||||
uint shuffled_up = simd_shuffle_up(20u, 4u);
|
||||
uint shuffled_down = simd_shuffle_down(20u, 4u);
|
||||
bool has_all = simd_all(true);
|
||||
bool has_any = simd_any(true);
|
||||
bool has_equal = spvSubgroupAllEqual(0);
|
||||
has_equal = spvSubgroupAllEqual(true);
|
||||
float4 added = simd_sum(float4(20.0));
|
||||
int4 iadded = simd_sum(int4(20));
|
||||
float4 multiplied = simd_product(float4(20.0));
|
||||
int4 imultiplied = simd_product(int4(20));
|
||||
float4 lo = simd_min(float4(20.0));
|
||||
float4 hi = simd_max(float4(20.0));
|
||||
int4 slo = simd_min(int4(20));
|
||||
int4 shi = simd_max(int4(20));
|
||||
uint4 ulo = simd_min(uint4(20u));
|
||||
uint4 uhi = simd_max(uint4(20u));
|
||||
uint4 anded = simd_and(ballot_value);
|
||||
uint4 ored = simd_or(ballot_value);
|
||||
uint4 xored = simd_xor(ballot_value);
|
||||
added = simd_prefix_inclusive_sum(added);
|
||||
iadded = simd_prefix_inclusive_sum(iadded);
|
||||
multiplied = simd_prefix_inclusive_product(multiplied);
|
||||
imultiplied = simd_prefix_inclusive_product(imultiplied);
|
||||
added = simd_prefix_exclusive_sum(multiplied);
|
||||
multiplied = simd_prefix_exclusive_product(multiplied);
|
||||
iadded = simd_prefix_exclusive_sum(imultiplied);
|
||||
imultiplied = simd_prefix_exclusive_product(imultiplied);
|
||||
added = quad_sum(added);
|
||||
multiplied = quad_product(multiplied);
|
||||
iadded = quad_sum(iadded);
|
||||
imultiplied = quad_product(imultiplied);
|
||||
lo = quad_min(lo);
|
||||
hi = quad_max(hi);
|
||||
ulo = quad_min(ulo);
|
||||
uhi = quad_max(uhi);
|
||||
slo = quad_min(slo);
|
||||
shi = quad_max(shi);
|
||||
anded = quad_and(anded);
|
||||
ored = quad_or(ored);
|
||||
xored = quad_xor(xored);
|
||||
float4 swap_horiz = quad_shuffle_xor(float4(20.0), 1u);
|
||||
float4 swap_vertical = quad_shuffle_xor(float4(20.0), 2u);
|
||||
float4 swap_diagonal = quad_shuffle_xor(float4(20.0), 3u);
|
||||
float4 quad_broadcast0 = quad_broadcast(float4(20.0), 3u);
|
||||
}
|
||||
|
@ -0,0 +1,32 @@
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct SSBO
|
||||
{
|
||||
float FragColor;
|
||||
};
|
||||
|
||||
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]])
|
||||
{
|
||||
_9.FragColor = float(gl_NumSubgroups);
|
||||
_9.FragColor = float(gl_SubgroupID);
|
||||
_9.FragColor = float(gl_SubgroupSize);
|
||||
_9.FragColor = float(gl_SubgroupInvocationID);
|
||||
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
|
||||
simdgroup_barrier(mem_flags::mem_device);
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
simdgroup_barrier(mem_flags::mem_texture);
|
||||
bool elected = (gl_SubgroupInvocationID == 0);
|
||||
uint shuffled = quad_shuffle(10u, 8u);
|
||||
uint shuffled_xor = quad_shuffle_xor(30u, 8u);
|
||||
uint shuffled_up = quad_shuffle_up(20u, 4u);
|
||||
uint shuffled_down = quad_shuffle_down(20u, 4u);
|
||||
float4 swap_horiz = quad_shuffle_xor(float4(20.0), 1u);
|
||||
float4 swap_vertical = quad_shuffle_xor(float4(20.0), 2u);
|
||||
float4 swap_diagonal = quad_shuffle_xor(float4(20.0), 3u);
|
||||
float4 quad_broadcast0 = quad_broadcast(float4(20.0), 3u);
|
||||
}
|
||||
|
@ -11,12 +11,10 @@ void full_barrier()
|
||||
memoryBarrier();
|
||||
}
|
||||
|
||||
#if 0
|
||||
void image_barrier()
|
||||
{
|
||||
memoryBarrierImage();
|
||||
}
|
||||
#endif
|
||||
|
||||
void buffer_barrier()
|
||||
{
|
||||
@ -40,13 +38,11 @@ void full_barrier_exec()
|
||||
barrier();
|
||||
}
|
||||
|
||||
#if 0
|
||||
void image_barrier_exec()
|
||||
{
|
||||
memoryBarrierImage();
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
|
||||
void buffer_barrier_exec()
|
||||
{
|
||||
@ -69,13 +65,13 @@ void main()
|
||||
{
|
||||
barrier_shared();
|
||||
full_barrier();
|
||||
//image_barrier();
|
||||
image_barrier();
|
||||
buffer_barrier();
|
||||
group_barrier();
|
||||
|
||||
barrier_shared_exec();
|
||||
full_barrier_exec();
|
||||
//image_barrier_exec();
|
||||
image_barrier_exec();
|
||||
buffer_barrier_exec();
|
||||
group_barrier_exec();
|
||||
|
||||
|
126
shaders-msl/vulkan/comp/subgroups.nocompat.invalid.vk.msl21.comp
Normal file
126
shaders-msl/vulkan/comp/subgroups.nocompat.invalid.vk.msl21.comp
Normal file
@ -0,0 +1,126 @@
|
||||
#version 450
|
||||
#extension GL_KHR_shader_subgroup_basic : require
|
||||
#extension GL_KHR_shader_subgroup_ballot : require
|
||||
#extension GL_KHR_shader_subgroup_vote : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle_relative : require
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||
#extension GL_KHR_shader_subgroup_clustered : require
|
||||
#extension GL_KHR_shader_subgroup_quad : require
|
||||
layout(local_size_x = 1) in;
|
||||
|
||||
layout(std430, binding = 0) buffer SSBO
|
||||
{
|
||||
float FragColor;
|
||||
};
|
||||
|
||||
void main()
|
||||
{
|
||||
// basic
|
||||
FragColor = float(gl_NumSubgroups);
|
||||
FragColor = float(gl_SubgroupID);
|
||||
FragColor = float(gl_SubgroupSize);
|
||||
FragColor = float(gl_SubgroupInvocationID);
|
||||
subgroupBarrier();
|
||||
subgroupMemoryBarrier();
|
||||
subgroupMemoryBarrierBuffer();
|
||||
subgroupMemoryBarrierShared();
|
||||
subgroupMemoryBarrierImage();
|
||||
bool elected = subgroupElect();
|
||||
|
||||
// ballot
|
||||
FragColor = float(gl_SubgroupEqMask);
|
||||
FragColor = float(gl_SubgroupGeMask);
|
||||
FragColor = float(gl_SubgroupGtMask);
|
||||
FragColor = float(gl_SubgroupLeMask);
|
||||
FragColor = float(gl_SubgroupLtMask);
|
||||
vec4 broadcasted = subgroupBroadcast(vec4(10.0), 8u);
|
||||
vec3 first = subgroupBroadcastFirst(vec3(20.0));
|
||||
uvec4 ballot_value = subgroupBallot(true);
|
||||
bool inverse_ballot_value = subgroupInverseBallot(ballot_value);
|
||||
bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u);
|
||||
uint bit_count = subgroupBallotBitCount(ballot_value);
|
||||
uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
|
||||
uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
|
||||
uint lsb = subgroupBallotFindLSB(ballot_value);
|
||||
uint msb = subgroupBallotFindMSB(ballot_value);
|
||||
|
||||
// shuffle
|
||||
uint shuffled = subgroupShuffle(10u, 8u);
|
||||
uint shuffled_xor = subgroupShuffleXor(30u, 8u);
|
||||
|
||||
// shuffle relative
|
||||
uint shuffled_up = subgroupShuffleUp(20u, 4u);
|
||||
uint shuffled_down = subgroupShuffleDown(20u, 4u);
|
||||
|
||||
// vote
|
||||
bool has_all = subgroupAll(true);
|
||||
bool has_any = subgroupAny(true);
|
||||
bool has_equal = subgroupAllEqual(0);
|
||||
has_equal = subgroupAllEqual(true);
|
||||
|
||||
// arithmetic
|
||||
vec4 added = subgroupAdd(vec4(20.0));
|
||||
ivec4 iadded = subgroupAdd(ivec4(20));
|
||||
vec4 multiplied = subgroupMul(vec4(20.0));
|
||||
ivec4 imultiplied = subgroupMul(ivec4(20));
|
||||
vec4 lo = subgroupMin(vec4(20.0));
|
||||
vec4 hi = subgroupMax(vec4(20.0));
|
||||
ivec4 slo = subgroupMin(ivec4(20));
|
||||
ivec4 shi = subgroupMax(ivec4(20));
|
||||
uvec4 ulo = subgroupMin(uvec4(20));
|
||||
uvec4 uhi = subgroupMax(uvec4(20));
|
||||
uvec4 anded = subgroupAnd(ballot_value);
|
||||
uvec4 ored = subgroupOr(ballot_value);
|
||||
uvec4 xored = subgroupXor(ballot_value);
|
||||
|
||||
added = subgroupInclusiveAdd(added);
|
||||
iadded = subgroupInclusiveAdd(iadded);
|
||||
multiplied = subgroupInclusiveMul(multiplied);
|
||||
imultiplied = subgroupInclusiveMul(imultiplied);
|
||||
//lo = subgroupInclusiveMin(lo); // FIXME: Unsupported by Metal
|
||||
//hi = subgroupInclusiveMax(hi);
|
||||
//slo = subgroupInclusiveMin(slo);
|
||||
//shi = subgroupInclusiveMax(shi);
|
||||
//ulo = subgroupInclusiveMin(ulo);
|
||||
//uhi = subgroupInclusiveMax(uhi);
|
||||
//anded = subgroupInclusiveAnd(anded);
|
||||
//ored = subgroupInclusiveOr(ored);
|
||||
//xored = subgroupInclusiveXor(ored);
|
||||
//added = subgroupExclusiveAdd(lo);
|
||||
|
||||
added = subgroupExclusiveAdd(multiplied);
|
||||
multiplied = subgroupExclusiveMul(multiplied);
|
||||
iadded = subgroupExclusiveAdd(imultiplied);
|
||||
imultiplied = subgroupExclusiveMul(imultiplied);
|
||||
//lo = subgroupExclusiveMin(lo); // FIXME: Unsupported by Metal
|
||||
//hi = subgroupExclusiveMax(hi);
|
||||
//ulo = subgroupExclusiveMin(ulo);
|
||||
//uhi = subgroupExclusiveMax(uhi);
|
||||
//slo = subgroupExclusiveMin(slo);
|
||||
//shi = subgroupExclusiveMax(shi);
|
||||
//anded = subgroupExclusiveAnd(anded);
|
||||
//ored = subgroupExclusiveOr(ored);
|
||||
//xored = subgroupExclusiveXor(ored);
|
||||
|
||||
// clustered
|
||||
added = subgroupClusteredAdd(added, 4u);
|
||||
multiplied = subgroupClusteredMul(multiplied, 4u);
|
||||
iadded = subgroupClusteredAdd(iadded, 4u);
|
||||
imultiplied = subgroupClusteredMul(imultiplied, 4u);
|
||||
lo = subgroupClusteredMin(lo, 4u);
|
||||
hi = subgroupClusteredMax(hi, 4u);
|
||||
ulo = subgroupClusteredMin(ulo, 4u);
|
||||
uhi = subgroupClusteredMax(uhi, 4u);
|
||||
slo = subgroupClusteredMin(slo, 4u);
|
||||
shi = subgroupClusteredMax(shi, 4u);
|
||||
anded = subgroupClusteredAnd(anded, 4u);
|
||||
ored = subgroupClusteredOr(ored, 4u);
|
||||
xored = subgroupClusteredXor(xored, 4u);
|
||||
|
||||
// quad
|
||||
vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0));
|
||||
vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0));
|
||||
vec4 swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0));
|
||||
vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u);
|
||||
}
|
@ -0,0 +1,42 @@
|
||||
#version 450
|
||||
#extension GL_KHR_shader_subgroup_basic : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle_relative : require
|
||||
#extension GL_KHR_shader_subgroup_quad : require
|
||||
layout(local_size_x = 1) in;
|
||||
|
||||
layout(std430, binding = 0) buffer SSBO
|
||||
{
|
||||
float FragColor;
|
||||
};
|
||||
|
||||
// Reduced test for functionality exposed on iOS.
|
||||
|
||||
void main()
|
||||
{
|
||||
// basic
|
||||
FragColor = float(gl_NumSubgroups);
|
||||
FragColor = float(gl_SubgroupID);
|
||||
FragColor = float(gl_SubgroupSize);
|
||||
FragColor = float(gl_SubgroupInvocationID);
|
||||
subgroupBarrier();
|
||||
subgroupMemoryBarrier();
|
||||
subgroupMemoryBarrierBuffer();
|
||||
subgroupMemoryBarrierShared();
|
||||
subgroupMemoryBarrierImage();
|
||||
bool elected = subgroupElect();
|
||||
|
||||
// shuffle
|
||||
uint shuffled = subgroupShuffle(10u, 8u);
|
||||
uint shuffled_xor = subgroupShuffleXor(30u, 8u);
|
||||
|
||||
// shuffle relative
|
||||
uint shuffled_up = subgroupShuffleUp(20u, 4u);
|
||||
uint shuffled_down = subgroupShuffleDown(20u, 4u);
|
||||
|
||||
// quad
|
||||
vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0));
|
||||
vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0));
|
||||
vec4 swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0));
|
||||
vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u);
|
||||
}
|
651
spirv_msl.cpp
651
spirv_msl.cpp
@ -93,7 +93,14 @@ void CompilerMSL::build_implicit_builtins()
|
||||
bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
|
||||
bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex;
|
||||
bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
|
||||
if (need_subpass_input || need_sample_pos || need_vertex_params || need_tesc_params)
|
||||
bool need_subgroup_mask =
|
||||
active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
|
||||
active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
|
||||
active_input_builtins.get(BuiltInSubgroupLtMask);
|
||||
bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
|
||||
active_input_builtins.get(BuiltInSubgroupGtMask));
|
||||
if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
|
||||
needs_subgroup_invocation_id)
|
||||
{
|
||||
bool has_frag_coord = false;
|
||||
bool has_sample_id = false;
|
||||
@ -103,18 +110,21 @@ void CompilerMSL::build_implicit_builtins()
|
||||
bool has_base_instance = false;
|
||||
bool has_invocation_id = false;
|
||||
bool has_primitive_id = false;
|
||||
bool has_subgroup_invocation_id = false;
|
||||
bool has_subgroup_size = false;
|
||||
|
||||
ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
|
||||
if (var.storage != StorageClassInput || !ir.meta[var.self].decoration.builtin)
|
||||
return;
|
||||
|
||||
if (need_subpass_input && ir.meta[var.self].decoration.builtin_type == BuiltInFragCoord)
|
||||
BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;
|
||||
if (need_subpass_input && builtin == BuiltInFragCoord)
|
||||
{
|
||||
builtin_frag_coord_id = var.self;
|
||||
has_frag_coord = true;
|
||||
}
|
||||
|
||||
if (need_sample_pos && ir.meta[var.self].decoration.builtin_type == BuiltInSampleId)
|
||||
if (need_sample_pos && builtin == BuiltInSampleId)
|
||||
{
|
||||
builtin_sample_id_id = var.self;
|
||||
has_sample_id = true;
|
||||
@ -122,7 +132,7 @@ void CompilerMSL::build_implicit_builtins()
|
||||
|
||||
if (need_vertex_params)
|
||||
{
|
||||
switch (ir.meta[var.self].decoration.builtin_type)
|
||||
switch (builtin)
|
||||
{
|
||||
case BuiltInVertexIndex:
|
||||
builtin_vertex_idx_id = var.self;
|
||||
@ -147,7 +157,7 @@ void CompilerMSL::build_implicit_builtins()
|
||||
|
||||
if (need_tesc_params)
|
||||
{
|
||||
switch (ir.meta[var.self].decoration.builtin_type)
|
||||
switch (builtin)
|
||||
{
|
||||
case BuiltInInvocationId:
|
||||
builtin_invocation_id_id = var.self;
|
||||
@ -161,6 +171,18 @@ void CompilerMSL::build_implicit_builtins()
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
|
||||
{
|
||||
builtin_subgroup_invocation_id_id = var.self;
|
||||
has_subgroup_invocation_id = true;
|
||||
}
|
||||
|
||||
if (need_subgroup_ge_mask && builtin == BuiltInSubgroupSize)
|
||||
{
|
||||
builtin_subgroup_size_id = var.self;
|
||||
has_subgroup_size = true;
|
||||
}
|
||||
});
|
||||
|
||||
if (!has_frag_coord && need_subpass_input)
|
||||
@ -311,6 +333,58 @@ void CompilerMSL::build_implicit_builtins()
|
||||
builtin_primitive_id_id = var_id;
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
|
||||
{
|
||||
uint32_t offset = ir.increase_bound_by(3);
|
||||
uint32_t type_id = offset;
|
||||
uint32_t type_ptr_id = offset + 1;
|
||||
uint32_t var_id = offset + 2;
|
||||
|
||||
// Create gl_SubgroupInvocationID.
|
||||
SPIRType uint_type;
|
||||
uint_type.basetype = SPIRType::UInt;
|
||||
uint_type.width = 32;
|
||||
set<SPIRType>(type_id, uint_type);
|
||||
|
||||
SPIRType uint_type_ptr;
|
||||
uint_type_ptr = uint_type;
|
||||
uint_type_ptr.pointer = true;
|
||||
uint_type_ptr.parent_type = type_id;
|
||||
uint_type_ptr.storage = StorageClassInput;
|
||||
auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
|
||||
ptr_type.self = type_id;
|
||||
|
||||
set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
|
||||
set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
|
||||
builtin_subgroup_invocation_id_id = var_id;
|
||||
}
|
||||
|
||||
if (!has_subgroup_size && need_subgroup_ge_mask)
|
||||
{
|
||||
uint32_t offset = ir.increase_bound_by(3);
|
||||
uint32_t type_id = offset;
|
||||
uint32_t type_ptr_id = offset + 1;
|
||||
uint32_t var_id = offset + 2;
|
||||
|
||||
// Create gl_SubgroupSize.
|
||||
SPIRType uint_type;
|
||||
uint_type.basetype = SPIRType::UInt;
|
||||
uint_type.width = 32;
|
||||
set<SPIRType>(type_id, uint_type);
|
||||
|
||||
SPIRType uint_type_ptr;
|
||||
uint_type_ptr = uint_type;
|
||||
uint_type_ptr.pointer = true;
|
||||
uint_type_ptr.parent_type = type_id;
|
||||
uint_type_ptr.storage = StorageClassInput;
|
||||
auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
|
||||
ptr_type.self = type_id;
|
||||
|
||||
set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
|
||||
set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
|
||||
builtin_subgroup_size_id = var_id;
|
||||
}
|
||||
}
|
||||
|
||||
if (needs_aux_buffer_def)
|
||||
@ -598,6 +672,7 @@ string CompilerMSL::compile()
|
||||
update_active_builtins();
|
||||
analyze_image_and_sampler_usage();
|
||||
analyze_sampled_image_usage();
|
||||
preprocess_op_codes();
|
||||
build_implicit_builtins();
|
||||
|
||||
fixup_image_load_store_access();
|
||||
@ -606,9 +681,6 @@ string CompilerMSL::compile()
|
||||
if (aux_buffer_id)
|
||||
active_interface_variables.insert(aux_buffer_id);
|
||||
|
||||
// Preprocess OpCodes to extract the need to output additional header content
|
||||
preprocess_op_codes();
|
||||
|
||||
// Create structs to hold input, output and uniform variables.
|
||||
// Do output first to ensure out. is declared at top of entry function.
|
||||
qual_pos_var_name = "";
|
||||
@ -700,6 +772,9 @@ void CompilerMSL::preprocess_op_codes()
|
||||
is_rasterization_disabled = true;
|
||||
capture_output_to_buffer = true;
|
||||
}
|
||||
|
||||
if (preproc.needs_subgroup_invocation_id)
|
||||
needs_subgroup_invocation_id = true;
|
||||
}
|
||||
|
||||
// Move the Private and Workgroup global variables to the entry function.
|
||||
@ -2877,6 +2952,90 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplSubgroupBallot:
|
||||
statement("inline uint4 spvSubgroupBallot(bool value)");
|
||||
begin_scope();
|
||||
statement("simd_vote vote = simd_ballot(value);");
|
||||
statement("// simd_ballot() returns a 64-bit integer-like object, but");
|
||||
statement("// SPIR-V callers expect a uint4. We must convert.");
|
||||
statement("// FIXME: This won't include higher bits if Apple ever supports");
|
||||
statement("// 128 lanes in an SIMD-group.");
|
||||
statement("return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> "
|
||||
"32) & 0xFFFFFFFF), 0, 0);");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplSubgroupBallotBitExtract:
|
||||
statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
|
||||
begin_scope();
|
||||
statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplSubgroupBallotFindLSB:
|
||||
statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot)");
|
||||
begin_scope();
|
||||
statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
|
||||
"ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplSubgroupBallotFindMSB:
|
||||
statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot)");
|
||||
begin_scope();
|
||||
statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
|
||||
"(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
|
||||
"ballot.z == 0), ballot.w == 0);");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplSubgroupBallotBitCount:
|
||||
statement("inline uint spvSubgroupBallotBitCount(uint4 ballot)");
|
||||
begin_scope();
|
||||
statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
|
||||
end_scope();
|
||||
statement("");
|
||||
statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
|
||||
begin_scope();
|
||||
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
|
||||
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
|
||||
"uint2(0));");
|
||||
statement("return spvSubgroupBallotBitCount(ballot & mask);");
|
||||
end_scope();
|
||||
statement("");
|
||||
statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
|
||||
begin_scope();
|
||||
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
|
||||
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
|
||||
statement("return spvSubgroupBallotBitCount(ballot & mask);");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplSubgroupAllEqual:
|
||||
// Metal doesn't provide a function to evaluate this directly. But, we can
|
||||
// implement this by comparing every thread's value to one thread's value
|
||||
// (in this case, the value of the first active thread). Then, by the transitive
|
||||
// property of equality, if all comparisons return true, then they are all equal.
|
||||
statement("template<typename T>");
|
||||
statement("inline bool spvSubgroupAllEqual(T value)");
|
||||
begin_scope();
|
||||
statement("return simd_all(value == simd_broadcast_first(value));");
|
||||
end_scope();
|
||||
statement("");
|
||||
statement("template<>");
|
||||
statement("inline bool spvSubgroupAllEqual(bool value)");
|
||||
begin_scope();
|
||||
statement("return simd_all(value) || !simd_any(value);");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
@ -3895,33 +4054,70 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
|
||||
if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
|
||||
return;
|
||||
|
||||
string bar_stmt = "threadgroup_barrier(mem_flags::";
|
||||
uint32_t exe_scope = id_exe_scope ? get<SPIRConstant>(id_exe_scope).scalar() : uint32_t(ScopeInvocation);
|
||||
uint32_t mem_scope = id_mem_scope ? get<SPIRConstant>(id_mem_scope).scalar() : uint32_t(ScopeInvocation);
|
||||
// Use the wider of the two scopes (smaller value)
|
||||
exe_scope = min(exe_scope, mem_scope);
|
||||
|
||||
string bar_stmt;
|
||||
if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
|
||||
bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
|
||||
else
|
||||
bar_stmt = "threadgroup_barrier";
|
||||
bar_stmt += "(";
|
||||
|
||||
uint32_t mem_sem = id_mem_sem ? get<SPIRConstant>(id_mem_sem).scalar() : uint32_t(MemorySemanticsMaskNone);
|
||||
|
||||
if (get_execution_model() == ExecutionModelTessellationControl)
|
||||
// Use the | operator to combine flags if we can.
|
||||
if (msl_options.supports_msl_version(1, 2))
|
||||
{
|
||||
string mem_flags = "";
|
||||
// For tesc shaders, this also affects objects in the Output storage class.
|
||||
// Since in Metal, these are placed in a device buffer, we have to sync device memory here.
|
||||
bar_stmt += "mem_device";
|
||||
else if (mem_sem & MemorySemanticsCrossWorkgroupMemoryMask)
|
||||
bar_stmt += "mem_device";
|
||||
else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
|
||||
MemorySemanticsAtomicCounterMemoryMask))
|
||||
bar_stmt += "mem_threadgroup";
|
||||
else if (mem_sem & MemorySemanticsImageMemoryMask)
|
||||
bar_stmt += "mem_texture";
|
||||
if (get_execution_model() == ExecutionModelTessellationControl ||
|
||||
(mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
|
||||
mem_flags += "mem_flags::mem_device";
|
||||
if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
|
||||
MemorySemanticsAtomicCounterMemoryMask))
|
||||
{
|
||||
if (!mem_flags.empty())
|
||||
mem_flags += " | ";
|
||||
mem_flags += "mem_flags::mem_threadgroup";
|
||||
}
|
||||
if (mem_sem & MemorySemanticsImageMemoryMask)
|
||||
{
|
||||
if (!mem_flags.empty())
|
||||
mem_flags += " | ";
|
||||
mem_flags += "mem_flags::mem_texture";
|
||||
}
|
||||
|
||||
if (mem_flags.empty())
|
||||
mem_flags = "mem_flags::mem_none";
|
||||
|
||||
bar_stmt += mem_flags;
|
||||
}
|
||||
else
|
||||
bar_stmt += "mem_none";
|
||||
{
|
||||
if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
|
||||
(mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
|
||||
MemorySemanticsAtomicCounterMemoryMask)))
|
||||
bar_stmt += "mem_flags::mem_device_and_threadgroup";
|
||||
else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
|
||||
bar_stmt += "mem_flags::mem_device";
|
||||
else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
|
||||
MemorySemanticsAtomicCounterMemoryMask))
|
||||
bar_stmt += "mem_flags::mem_threadgroup";
|
||||
else if (mem_sem & MemorySemanticsImageMemoryMask)
|
||||
bar_stmt += "mem_flags::mem_texture";
|
||||
else
|
||||
bar_stmt += "mem_flags::mem_none";
|
||||
}
|
||||
|
||||
if (msl_options.is_ios() && (msl_options.supports_msl_version(2) && !msl_options.supports_msl_version(2, 1)))
|
||||
{
|
||||
bar_stmt += ", ";
|
||||
|
||||
// Use the wider of the two scopes (smaller value)
|
||||
uint32_t exe_scope = id_exe_scope ? get<SPIRConstant>(id_exe_scope).scalar() : uint32_t(ScopeInvocation);
|
||||
uint32_t mem_scope = id_mem_scope ? get<SPIRConstant>(id_mem_scope).scalar() : uint32_t(ScopeInvocation);
|
||||
uint32_t scope = min(exe_scope, mem_scope);
|
||||
switch (scope)
|
||||
switch (mem_scope)
|
||||
{
|
||||
case ScopeCrossDevice:
|
||||
case ScopeDevice:
|
||||
@ -5188,6 +5384,8 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in
|
||||
{
|
||||
case BuiltInInvocationId:
|
||||
case BuiltInPrimitiveId:
|
||||
case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
|
||||
case BuiltInSubgroupSize: // FIXME: Should work in any stage
|
||||
return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
|
||||
case BuiltInPatchVertices:
|
||||
return "";
|
||||
@ -5347,6 +5545,10 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in
|
||||
case BuiltInNumWorkgroups:
|
||||
case BuiltInLocalInvocationId:
|
||||
case BuiltInLocalInvocationIndex:
|
||||
case BuiltInNumSubgroups:
|
||||
case BuiltInSubgroupId:
|
||||
case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
|
||||
case BuiltInSubgroupSize: // FIXME: Should work in any stage
|
||||
return string(" [[") + builtin_qualifier(builtin) + "]]";
|
||||
|
||||
default:
|
||||
@ -5593,7 +5795,9 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args)
|
||||
if (bi_type != BuiltInSamplePosition && bi_type != BuiltInHelperInvocation &&
|
||||
bi_type != BuiltInPatchVertices && bi_type != BuiltInTessLevelInner &&
|
||||
bi_type != BuiltInTessLevelOuter && bi_type != BuiltInPosition && bi_type != BuiltInPointSize &&
|
||||
bi_type != BuiltInClipDistance && bi_type != BuiltInCullDistance)
|
||||
bi_type != BuiltInClipDistance && bi_type != BuiltInCullDistance && bi_type != BuiltInSubgroupEqMask &&
|
||||
bi_type != BuiltInSubgroupGeMask && bi_type != BuiltInSubgroupGtMask &&
|
||||
bi_type != BuiltInSubgroupLeMask && bi_type != BuiltInSubgroupLtMask)
|
||||
{
|
||||
if (!ep_args.empty())
|
||||
ep_args += ", ";
|
||||
@ -5911,6 +6115,94 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
|
||||
entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
|
||||
}
|
||||
break;
|
||||
case BuiltInSubgroupEqMask:
|
||||
if (msl_options.is_ios())
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
|
||||
if (!msl_options.supports_msl_version(2, 1))
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
|
||||
entry_func.fixup_hooks_in.push_back([=]() {
|
||||
statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
|
||||
builtin_subgroup_invocation_id_id, " > 32 ? uint4(0, (1 << (",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
|
||||
to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
|
||||
});
|
||||
break;
|
||||
case BuiltInSubgroupGeMask:
|
||||
if (msl_options.is_ios())
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
|
||||
if (!msl_options.supports_msl_version(2, 1))
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
|
||||
entry_func.fixup_hooks_in.push_back([=]() {
|
||||
// Case where index < 32, size < 32:
|
||||
// mask0 = bfe(0xFFFFFFFF, index, size - index);
|
||||
// mask1 = bfe(0xFFFFFFFF, 0, 0); // Gives 0
|
||||
// Case where index < 32 but size >= 32:
|
||||
// mask0 = bfe(0xFFFFFFFF, index, 32 - index);
|
||||
// mask1 = bfe(0xFFFFFFFF, 0, size - 32);
|
||||
// Case where index >= 32:
|
||||
// mask0 = bfe(0xFFFFFFFF, 32, 0); // Gives 0
|
||||
// mask1 = bfe(0xFFFFFFFF, index - 32, size - index);
|
||||
// This is expressed without branches to avoid divergent
|
||||
// control flow--hence the complicated min/max expressions.
|
||||
// This is further complicated by the fact that if you attempt
|
||||
// to bfe out-of-bounds on Metal, undefined behavior is the
|
||||
// result.
|
||||
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
|
||||
" = uint4(extract_bits(0xFFFFFFFF, min(",
|
||||
to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
|
||||
to_expression(builtin_subgroup_size_id), ", 32) - (int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id),
|
||||
", 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
|
||||
to_expression(builtin_subgroup_size_id), " - (int)max(",
|
||||
to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
|
||||
});
|
||||
break;
|
||||
case BuiltInSubgroupGtMask:
|
||||
if (msl_options.is_ios())
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
|
||||
if (!msl_options.supports_msl_version(2, 1))
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
|
||||
entry_func.fixup_hooks_in.push_back([=]() {
|
||||
// The same logic applies here, except now the index is one
|
||||
// more than the subgroup invocation ID.
|
||||
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
|
||||
" = uint4(extract_bits(0xFFFFFFFF, min(",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
|
||||
to_expression(builtin_subgroup_size_id), ", 32) - (int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id),
|
||||
" - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
|
||||
to_expression(builtin_subgroup_size_id), " - (int)max(",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
|
||||
});
|
||||
break;
|
||||
case BuiltInSubgroupLeMask:
|
||||
if (msl_options.is_ios())
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
|
||||
if (!msl_options.supports_msl_version(2, 1))
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
|
||||
entry_func.fixup_hooks_in.push_back([=]() {
|
||||
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
|
||||
" = uint4(extract_bits(0xFFFFFFFF, 0, min(",
|
||||
to_expression(builtin_subgroup_invocation_id_id),
|
||||
" + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
|
||||
});
|
||||
break;
|
||||
case BuiltInSubgroupLtMask:
|
||||
if (msl_options.is_ios())
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
|
||||
if (!msl_options.supports_msl_version(2, 1))
|
||||
SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
|
||||
entry_func.fixup_hooks_in.push_back([=]() {
|
||||
statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
|
||||
" = uint4(extract_bits(0xFFFFFFFF, 0, min(",
|
||||
to_expression(builtin_subgroup_invocation_id_id),
|
||||
", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
|
||||
to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
|
||||
});
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -6266,6 +6558,7 @@ void CompilerMSL::replace_illegal_names()
|
||||
"M_2_SQRTPI",
|
||||
"M_SQRT2",
|
||||
"M_SQRT1_2",
|
||||
"quad_broadcast",
|
||||
};
|
||||
|
||||
static const unordered_set<string> illegal_func_names = {
|
||||
@ -6748,6 +7041,245 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
|
||||
return img_type_name;
|
||||
}
|
||||
|
||||
void CompilerMSL::emit_subgroup_op(const Instruction &i)
|
||||
{
|
||||
const uint32_t *ops = stream(i);
|
||||
auto op = static_cast<Op>(i.op);
|
||||
|
||||
// Metal 2.0 is required. iOS only supports quad ops. macOS only supports
|
||||
// broadcast and shuffle on 10.13 (2.0), with full support in 10.14 (2.1).
|
||||
// Note that iOS makes no distinction between a quad-group and a subgroup;
|
||||
// all subgroups are quad-groups there.
|
||||
if (!msl_options.supports_msl_version(2))
|
||||
SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
|
||||
|
||||
if (msl_options.is_ios())
|
||||
{
|
||||
switch (op)
|
||||
{
|
||||
default:
|
||||
SPIRV_CROSS_THROW("iOS only supports quad-group operations.");
|
||||
case OpGroupNonUniformElect:
|
||||
case OpGroupNonUniformBroadcast:
|
||||
case OpGroupNonUniformShuffle:
|
||||
case OpGroupNonUniformShuffleXor:
|
||||
case OpGroupNonUniformShuffleUp:
|
||||
case OpGroupNonUniformShuffleDown:
|
||||
case OpGroupNonUniformQuadSwap:
|
||||
case OpGroupNonUniformQuadBroadcast:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
|
||||
{
|
||||
switch (op)
|
||||
{
|
||||
default:
|
||||
SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.0 and up.");
|
||||
case OpGroupNonUniformElect:
|
||||
case OpGroupNonUniformBroadcast:
|
||||
case OpGroupNonUniformShuffle:
|
||||
case OpGroupNonUniformShuffleXor:
|
||||
case OpGroupNonUniformShuffleUp:
|
||||
case OpGroupNonUniformShuffleDown:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t result_type = ops[0];
|
||||
uint32_t id = ops[1];
|
||||
|
||||
auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
|
||||
if (scope != ScopeSubgroup)
|
||||
SPIRV_CROSS_THROW("Only subgroup scope is supported.");
|
||||
|
||||
switch (op)
|
||||
{
|
||||
case OpGroupNonUniformElect:
|
||||
// Vulkan spec says we have to support this if we support subgroups at all.
|
||||
// But Metal prior to macOS 10.14 doesn't have the simd_is_first() function, and
|
||||
// iOS doesn't have it at all. So we fake it by comparing the subgroup-local
|
||||
// ID to 0. This isn't quite correct: this is supposed to return if we're the
|
||||
// lowest *active* thread, but we'll otherwise be unable to support subgroups
|
||||
// on macOS 10.13 or iOS.
|
||||
if (msl_options.is_macos() && msl_options.supports_msl_version(2, 1))
|
||||
emit_op(result_type, id, "simd_is_first()", true);
|
||||
else
|
||||
emit_op(result_type, id, join("(", to_expression(builtin_subgroup_invocation_id_id), " == 0)"), true);
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBroadcast:
|
||||
emit_binary_func_op(result_type, id, ops[3], ops[4],
|
||||
msl_options.is_ios() ? "quad_broadcast" : "simd_broadcast");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBroadcastFirst:
|
||||
emit_unary_func_op(result_type, id, ops[3], "simd_broadcast_first");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallot:
|
||||
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformInverseBallot:
|
||||
emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotBitExtract:
|
||||
emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotFindLSB:
|
||||
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindLSB");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotFindMSB:
|
||||
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindMSB");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotBitCount:
|
||||
{
|
||||
auto operation = static_cast<GroupOperation>(ops[3]);
|
||||
if (operation == GroupOperationReduce)
|
||||
emit_unary_func_op(result_type, id, ops[4], "spvSubgroupBallotBitCount");
|
||||
else if (operation == GroupOperationInclusiveScan)
|
||||
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
|
||||
"spvSubgroupBallotInclusiveBitCount");
|
||||
else if (operation == GroupOperationExclusiveScan)
|
||||
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
|
||||
"spvSubgroupBallotExclusiveBitCount");
|
||||
else
|
||||
SPIRV_CROSS_THROW("Invalid BitCount operation.");
|
||||
break;
|
||||
}
|
||||
|
||||
case OpGroupNonUniformShuffle:
|
||||
emit_binary_func_op(result_type, id, ops[3], ops[4], msl_options.is_ios() ? "quad_shuffle" : "simd_shuffle");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformShuffleXor:
|
||||
emit_binary_func_op(result_type, id, ops[3], ops[4],
|
||||
msl_options.is_ios() ? "quad_shuffle_xor" : "simd_shuffle_xor");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformShuffleUp:
|
||||
emit_binary_func_op(result_type, id, ops[3], ops[4],
|
||||
msl_options.is_ios() ? "quad_shuffle_up" : "simd_shuffle_up");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformShuffleDown:
|
||||
emit_binary_func_op(result_type, id, ops[3], ops[4],
|
||||
msl_options.is_ios() ? "quad_shuffle_down" : "simd_shuffle_down");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformAll:
|
||||
emit_unary_func_op(result_type, id, ops[3], "simd_all");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformAny:
|
||||
emit_unary_func_op(result_type, id, ops[3], "simd_any");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformAllEqual:
|
||||
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
|
||||
break;
|
||||
|
||||
// clang-format off
|
||||
#define MSL_GROUP_OP(op, msl_op) \
|
||||
case OpGroupNonUniform##op: \
|
||||
{ \
|
||||
auto operation = static_cast<GroupOperation>(ops[3]); \
|
||||
if (operation == GroupOperationReduce) \
|
||||
emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
|
||||
else if (operation == GroupOperationInclusiveScan) \
|
||||
emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
|
||||
else if (operation == GroupOperationExclusiveScan) \
|
||||
emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
|
||||
else if (operation == GroupOperationClusteredReduce) \
|
||||
{ \
|
||||
/* Only cluster sizes of 4 are supported. */ \
|
||||
uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
|
||||
if (cluster_size != 4) \
|
||||
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
|
||||
emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
|
||||
} \
|
||||
else \
|
||||
SPIRV_CROSS_THROW("Invalid group operation."); \
|
||||
break; \
|
||||
}
|
||||
MSL_GROUP_OP(FAdd, sum)
|
||||
MSL_GROUP_OP(FMul, product)
|
||||
MSL_GROUP_OP(IAdd, sum)
|
||||
MSL_GROUP_OP(IMul, product)
|
||||
#undef MSL_GROUP_OP
|
||||
// The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
|
||||
#define MSL_GROUP_OP(op, msl_op) \
|
||||
case OpGroupNonUniform##op: \
|
||||
{ \
|
||||
auto operation = static_cast<GroupOperation>(ops[3]); \
|
||||
if (operation == GroupOperationReduce) \
|
||||
emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
|
||||
else if (operation == GroupOperationInclusiveScan) \
|
||||
SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
|
||||
else if (operation == GroupOperationExclusiveScan) \
|
||||
SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
|
||||
else if (operation == GroupOperationClusteredReduce) \
|
||||
{ \
|
||||
/* Only cluster sizes of 4 are supported. */ \
|
||||
uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
|
||||
if (cluster_size != 4) \
|
||||
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
|
||||
emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
|
||||
} \
|
||||
else \
|
||||
SPIRV_CROSS_THROW("Invalid group operation."); \
|
||||
break; \
|
||||
}
|
||||
MSL_GROUP_OP(FMin, min)
|
||||
MSL_GROUP_OP(FMax, max)
|
||||
MSL_GROUP_OP(SMin, min)
|
||||
MSL_GROUP_OP(SMax, max)
|
||||
MSL_GROUP_OP(UMin, min)
|
||||
MSL_GROUP_OP(UMax, max)
|
||||
MSL_GROUP_OP(BitwiseAnd, and)
|
||||
MSL_GROUP_OP(BitwiseOr, or)
|
||||
MSL_GROUP_OP(BitwiseXor, xor)
|
||||
MSL_GROUP_OP(LogicalAnd, and)
|
||||
MSL_GROUP_OP(LogicalOr, or)
|
||||
MSL_GROUP_OP(LogicalXor, xor)
|
||||
// clang-format on
|
||||
|
||||
case OpGroupNonUniformQuadSwap:
|
||||
{
|
||||
// We can implement this easily based on the following table giving
|
||||
// the target lane ID from the direction and current lane ID:
|
||||
// Direction
|
||||
// | 0 | 1 | 2 |
|
||||
// ---+---+---+---+
|
||||
// L 0 | 1 2 3
|
||||
// a 1 | 0 3 2
|
||||
// n 2 | 3 0 1
|
||||
// e 3 | 2 1 0
|
||||
// Notice that target = source ^ (direction + 1).
|
||||
uint32_t mask = get<SPIRConstant>(ops[4]).scalar() + 1;
|
||||
uint32_t mask_id = ir.increase_bound_by(1);
|
||||
set<SPIRConstant>(mask_id, expression_type_id(ops[4]), mask, false);
|
||||
emit_binary_func_op(result_type, id, ops[3], mask_id, "quad_shuffle_xor");
|
||||
break;
|
||||
}
|
||||
|
||||
case OpGroupNonUniformQuadBroadcast:
|
||||
emit_binary_func_op(result_type, id, ops[3], ops[4], "quad_broadcast");
|
||||
break;
|
||||
|
||||
default:
|
||||
SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
|
||||
}
|
||||
|
||||
register_control_dependent_expression(id);
|
||||
}
|
||||
|
||||
string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
|
||||
{
|
||||
if (out_type.basetype == in_type.basetype)
|
||||
@ -6954,6 +7486,32 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
|
||||
case BuiltInLocalInvocationIndex:
|
||||
return "thread_index_in_threadgroup";
|
||||
|
||||
case BuiltInSubgroupSize:
|
||||
return "thread_execution_width";
|
||||
|
||||
case BuiltInNumSubgroups:
|
||||
if (!msl_options.supports_msl_version(2))
|
||||
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
|
||||
return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
|
||||
|
||||
case BuiltInSubgroupId:
|
||||
if (!msl_options.supports_msl_version(2))
|
||||
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
|
||||
return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
|
||||
|
||||
case BuiltInSubgroupLocalInvocationId:
|
||||
if (!msl_options.supports_msl_version(2))
|
||||
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
|
||||
return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
|
||||
|
||||
case BuiltInSubgroupEqMask:
|
||||
case BuiltInSubgroupGeMask:
|
||||
case BuiltInSubgroupGtMask:
|
||||
case BuiltInSubgroupLeMask:
|
||||
case BuiltInSubgroupLtMask:
|
||||
// Shouldn't be reached.
|
||||
SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
|
||||
|
||||
default:
|
||||
return "unsupported-built-in";
|
||||
}
|
||||
@ -7042,7 +7600,17 @@ string CompilerMSL::builtin_type_decl(BuiltIn builtin)
|
||||
case BuiltInWorkgroupId:
|
||||
return "uint3";
|
||||
case BuiltInLocalInvocationIndex:
|
||||
case BuiltInNumSubgroups:
|
||||
case BuiltInSubgroupId:
|
||||
case BuiltInSubgroupSize:
|
||||
case BuiltInSubgroupLocalInvocationId:
|
||||
return "uint";
|
||||
case BuiltInSubgroupEqMask:
|
||||
case BuiltInSubgroupGeMask:
|
||||
case BuiltInSubgroupGtMask:
|
||||
case BuiltInSubgroupLeMask:
|
||||
case BuiltInSubgroupLtMask:
|
||||
return "uint4";
|
||||
|
||||
case BuiltInHelperInvocation:
|
||||
return "bool";
|
||||
@ -7267,6 +7835,20 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
|
||||
uses_atomics = true;
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformElect:
|
||||
if (compiler.msl_options.is_ios() || !compiler.msl_options.supports_msl_version(2, 1))
|
||||
needs_subgroup_invocation_id = true;
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformInverseBallot:
|
||||
needs_subgroup_invocation_id = true;
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotBitCount:
|
||||
if (args[3] != GroupOperationReduce)
|
||||
needs_subgroup_invocation_id = true;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -7433,6 +8015,25 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
|
||||
break;
|
||||
}
|
||||
|
||||
case OpGroupNonUniformBallot:
|
||||
return SPVFuncImplSubgroupBallot;
|
||||
|
||||
case OpGroupNonUniformInverseBallot:
|
||||
case OpGroupNonUniformBallotBitExtract:
|
||||
return SPVFuncImplSubgroupBallotBitExtract;
|
||||
|
||||
case OpGroupNonUniformBallotFindLSB:
|
||||
return SPVFuncImplSubgroupBallotFindLSB;
|
||||
|
||||
case OpGroupNonUniformBallotFindMSB:
|
||||
return SPVFuncImplSubgroupBallotFindMSB;
|
||||
|
||||
case OpGroupNonUniformBallotBitCount:
|
||||
return SPVFuncImplSubgroupBallotBitCount;
|
||||
|
||||
case OpGroupNonUniformAllEqual:
|
||||
return SPVFuncImplSubgroupAllEqual;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -344,6 +344,12 @@ protected:
|
||||
SPVFuncImplRowMajor4x2,
|
||||
SPVFuncImplRowMajor4x3,
|
||||
SPVFuncImplTextureSwizzle,
|
||||
SPVFuncImplSubgroupBallot,
|
||||
SPVFuncImplSubgroupBallotBitExtract,
|
||||
SPVFuncImplSubgroupBallotFindLSB,
|
||||
SPVFuncImplSubgroupBallotFindMSB,
|
||||
SPVFuncImplSubgroupBallotBitCount,
|
||||
SPVFuncImplSubgroupAllEqual,
|
||||
SPVFuncImplArrayCopyMultidimMax = 6
|
||||
};
|
||||
|
||||
@ -354,6 +360,7 @@ protected:
|
||||
void emit_header() override;
|
||||
void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) override;
|
||||
void emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id) override;
|
||||
void emit_subgroup_op(const Instruction &i) override;
|
||||
void emit_fixup() override;
|
||||
std::string to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
|
||||
const std::string &qualifier = "");
|
||||
@ -477,6 +484,8 @@ protected:
|
||||
uint32_t builtin_base_instance_id = 0;
|
||||
uint32_t builtin_invocation_id_id = 0;
|
||||
uint32_t builtin_primitive_id_id = 0;
|
||||
uint32_t builtin_subgroup_invocation_id_id = 0;
|
||||
uint32_t builtin_subgroup_size_id = 0;
|
||||
uint32_t aux_buffer_id = 0;
|
||||
|
||||
void bitcast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
|
||||
@ -518,6 +527,7 @@ protected:
|
||||
bool needs_aux_buffer_def = false;
|
||||
bool used_aux_buffer = false;
|
||||
bool added_builtin_tess_level = false;
|
||||
bool needs_subgroup_invocation_id = false;
|
||||
std::string qual_pos_var_name;
|
||||
std::string stage_in_var_name = "in";
|
||||
std::string stage_out_var_name = "out";
|
||||
@ -561,6 +571,7 @@ protected:
|
||||
bool suppress_missing_prototypes = false;
|
||||
bool uses_atomics = false;
|
||||
bool uses_resource_write = false;
|
||||
bool needs_subgroup_invocation_id = false;
|
||||
};
|
||||
|
||||
// OpcodeHandler that scans for uses of sampled images
|
||||
|
Loading…
Reference in New Issue
Block a user