SPIRV-Cross/reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl22.ios.comp
Chip Davis 68908355a9 MSL: Expand subgroup support.
Add support for declaring a fixed subgroup size. Metal, like Vulkan with
`VK_EXT_subgroup_size_control`, allows the thread execution width to
vary depending on factors such as register usage. Unfortunately, this
breaks several tests that depend on the subgroup size being what the
device says it is. So we'll fix the subgroup size at the size the device
declares. The extra invocations in the subgroup will appear to be
inactive. Because of this, the ballot mask builtins are now ANDed with
the active subgroup mask.

Add support for emulating a subgroup of size 1. This is intended to be
used by Vulkan Portability implementations (e.g. MoltenVK) when the
hardware/software combo provides insufficient support for subgroups.
Luckily for us, Vulkan 1.1 only requires that the subgroup size be at
least 1.

Add support for quadgroup and SIMD-group functions which were added to
iOS in Metal 2.2 and 2.3. This will allow clients to take advantage of
expanded quadgroup and SIMD-group support in recent Metal versions and
on recent Apple GPUs (families 6 and 7).

Gut emulation of subgroup builtins in fragment shaders. It turns out
codegen for the SIMD-group functions in fragment wasn't implemented for
AMD on Mojave; it's a safe bet that it wasn't implemented for the other
drivers either. Subgroup support in fragment shaders now requires Metal
2.2.
2020-11-20 15:55:49 -06:00

283 lines
9.2 KiB
Plaintext

#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Layout of the storage buffer that main0 receives as its [[buffer(0)]] argument.
struct SSBO
{
// Single float slot; main0 assigns it repeatedly, each store replacing the last.
float FragColor;
};
// Program-scope workgroup-size constant; uint3(1u) sets all three components to 1.
// Tagged [[maybe_unused]]; nothing in the visible source reads it.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
// Subgroup broadcast implemented on top of quad_broadcast: returns the value
// that the intrinsic reports for quadgroup lane `lane`.
template<typename T>
inline T spvSubgroupBroadcast(T value, ushort lane)
{
    const T from_lane = quad_broadcast(value, lane);
    return from_lane;
}

// bool variant: the flag is widened to ushort for the intrinsic call and
// turned back into a bool with a non-zero test.
template<>
inline bool spvSubgroupBroadcast(bool value, ushort lane)
{
    const ushort widened = ushort(value);
    return quad_broadcast(widened, lane) != 0;
}

// Boolean-vector variant: each component travels as a ushort and the result
// is converted back to vec<bool, N>.
template<uint N>
inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)
{
    using BoolVec = vec<bool, N>;
    using UShortVec = vec<ushort, N>;

    const UShortVec widened = UShortVec(value);
    return BoolVec(quad_broadcast(widened, lane));
}
// Subgroup "broadcast first" implemented on top of quad_broadcast_first:
// returns whatever that intrinsic yields for `value`.
template<typename T>
inline T spvSubgroupBroadcastFirst(T value)
{
    const T from_first = quad_broadcast_first(value);
    return from_first;
}

// bool variant: widened to ushort for the intrinsic call, then reduced to a
// bool with a non-zero test.
template<>
inline bool spvSubgroupBroadcastFirst(bool value)
{
    const ushort widened = ushort(value);
    return quad_broadcast_first(widened) != 0;
}

// Boolean-vector variant: components are carried as ushort through the
// intrinsic and converted back to vec<bool, N> afterwards.
template<uint N>
inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)
{
    using BoolVec = vec<bool, N>;
    using UShortVec = vec<ushort, N>;

    const UShortVec widened = UShortVec(value);
    return BoolVec(quad_broadcast_first(widened));
}
// Builds a four-word ballot from quad_ballot(value): the vote, converted to
// quad_vote::vote_t, goes into the x component and y/z/w are zero.
inline uint4 spvSubgroupBallot(bool value)
{
    const quad_vote::vote_t votes = static_cast<quad_vote::vote_t>(quad_ballot(value));
    return uint4(votes, 0, 0, 0);
}
// Tests a single bit of a four-word ballot.
//
// ballot - ballot value, treated as four consecutive 32-bit words.
// bit    - bit index; bit / 32 selects the word, bit % 32 the position in it.
// Returns true when the addressed bit is set.
inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
{
    const uint word = ballot[bit / 32];
    const uint offset = bit % 32;
    return extract_bits(word, offset, 1) != 0;
}
// Finds the index of the lowest set bit in a ballot.
//
// ballot          - ballot value; it is first ANDed with a mask whose x word
//                   keeps the low gl_SubgroupSize bits and whose y/z/w are zero.
// gl_SubgroupSize - number of low bits of the x word that are kept.
// Returns the bit index counted from word x upwards, or uint(-1) when no bit
// survives the mask.
inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)
{
    const uint live_bits = extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize);
    ballot &= uint4(live_bits, uint3(0));

    // Scan the words from least to most significant; the first non-zero one wins.
    if (ballot.x != 0)
        return ctz(ballot.x);
    if (ballot.y != 0)
        return 32 + ctz(ballot.y);
    if (ballot.z != 0)
        return 64 + ctz(ballot.z);
    if (ballot.w != 0)
        return 96 + ctz(ballot.w);
    return uint(-1);
}
// Finds the index of the highest set bit in a ballot.
//
// ballot          - ballot value; it is first ANDed with a mask whose x word
//                   keeps the low gl_SubgroupSize bits and whose y/z/w are zero.
// gl_SubgroupSize - number of low bits of the x word that are kept.
// Returns the bit index counted from word x upwards, or uint(-1) when no bit
// survives the mask.
inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)
{
    const uint live_bits = extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize);
    ballot &= uint4(live_bits, uint3(0));

    // Scan the words from most to least significant; the first non-zero one wins.
    if (ballot.w != 0)
        return 128 - (clz(ballot.w) + 1);
    if (ballot.z != 0)
        return 96 - (clz(ballot.z) + 1);
    if (ballot.y != 0)
        return 64 - (clz(ballot.y) + 1);
    if (ballot.x != 0)
        return 32 - (clz(ballot.x) + 1);
    return uint(-1);
}
// Returns the total number of set bits across all four words of `ballot`.
inline uint spvPopCount4(uint4 ballot)
{
    // Count per component in one call, then add the four partial counts.
    const uint4 per_word = popcount(ballot);
    return per_word.x + per_word.y + per_word.z + per_word.w;
}
// Counts the set bits of a ballot that fall inside the subgroup.
//
// ballot          - ballot value.
// gl_SubgroupSize - number of low bits of the x word that are counted; the
//                   y/z/w words are masked out entirely.
// Returns the population count of the masked ballot.
inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)
{
    const uint live_bits = extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize);
    const uint4 in_range = ballot & uint4(live_bits, uint3(0));
    return spvPopCount4(in_range);
}
// Counts the set ballot bits at positions up to and including the caller's.
//
// ballot                  - ballot value.
// gl_SubgroupInvocationID - caller's index; bits [0, ID] of the x word are
//                           counted, the y/z/w words are masked out.
// Returns the population count of the masked ballot.
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
    const uint up_to_self = extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1);
    const uint4 selected = ballot & uint4(up_to_self, uint3(0));
    return spvPopCount4(selected);
}
// Counts the set ballot bits at positions strictly below the caller's.
//
// ballot                  - ballot value.
// gl_SubgroupInvocationID - caller's index; bits [0, ID) of the x word are
//                           counted, the y/z/w words are masked out.
// Returns the population count of the masked ballot.
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
    // A uint4 needs four components: the scalar x word plus a uint3 of zeros.
    // Padding with a uint2 would leave the vector one component short, and it
    // would disagree with every other ballot mask built in this file.
    uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint3(0));
    return spvPopCount4(ballot & mask);
}
// Reports whether `value` compares equal to quad_broadcast_first(value) in
// every component, combined across the quadgroup with quad_all.
template<typename T>
inline bool spvSubgroupAllEqual(T value)
{
    const T reference = quad_broadcast_first(value);
    return quad_all(all(value == reference));
}

// bool variant: equal means either quad_all(value) holds or quad_any(value)
// does not. quad_any is only consulted when quad_all is false.
template<>
inline bool spvSubgroupAllEqual(bool value)
{
    if (quad_all(value))
        return true;
    return !quad_any(value);
}

// Boolean-vector variant: the reference value is fetched as ushort components,
// converted back to vec<bool, N>, and then compared component-wise.
template<uint N>
inline bool spvSubgroupAllEqual(vec<bool, N> value)
{
    using BoolVec = vec<bool, N>;
    using UShortVec = vec<ushort, N>;

    const BoolVec reference = BoolVec(quad_broadcast_first(UShortVec(value)));
    return quad_all(all(value == reference));
}
// Subgroup shuffle implemented on top of quad_shuffle: returns the intrinsic's
// result for `value` and quadgroup lane `lane`.
template<typename T>
inline T spvSubgroupShuffle(T value, ushort lane)
{
    const T picked = quad_shuffle(value, lane);
    return picked;
}

// bool variant: widened to ushort for the intrinsic call, then reduced to a
// bool with a non-zero test.
template<>
inline bool spvSubgroupShuffle(bool value, ushort lane)
{
    const ushort widened = ushort(value);
    return quad_shuffle(widened, lane) != 0;
}

// Boolean-vector variant: components are carried as ushort through the
// intrinsic and converted back to vec<bool, N> afterwards.
template<uint N>
inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)
{
    using BoolVec = vec<bool, N>;
    using UShortVec = vec<ushort, N>;

    const UShortVec widened = UShortVec(value);
    return BoolVec(quad_shuffle(widened, lane));
}
// Subgroup XOR shuffle implemented on top of quad_shuffle_xor: returns the
// intrinsic's result for `value` and `mask`.
template<typename T>
inline T spvSubgroupShuffleXor(T value, ushort mask)
{
    const T picked = quad_shuffle_xor(value, mask);
    return picked;
}

// bool variant: widened to ushort for the intrinsic call, then reduced to a
// bool with a non-zero test.
template<>
inline bool spvSubgroupShuffleXor(bool value, ushort mask)
{
    const ushort widened = ushort(value);
    return quad_shuffle_xor(widened, mask) != 0;
}

// Boolean-vector variant: components are carried as ushort through the
// intrinsic and converted back to vec<bool, N> afterwards.
template<uint N>
inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)
{
    using BoolVec = vec<bool, N>;
    using UShortVec = vec<ushort, N>;

    const UShortVec widened = UShortVec(value);
    return BoolVec(quad_shuffle_xor(widened, mask));
}
// Subgroup shuffle-up implemented on top of quad_shuffle_up: returns the
// intrinsic's result for `value` and `delta`.
template<typename T>
inline T spvSubgroupShuffleUp(T value, ushort delta)
{
    const T picked = quad_shuffle_up(value, delta);
    return picked;
}

// bool variant: widened to ushort for the intrinsic call, then reduced to a
// bool with a non-zero test.
template<>
inline bool spvSubgroupShuffleUp(bool value, ushort delta)
{
    const ushort widened = ushort(value);
    return quad_shuffle_up(widened, delta) != 0;
}

// Boolean-vector variant: components are carried as ushort through the
// intrinsic and converted back to vec<bool, N> afterwards.
template<uint N>
inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)
{
    using BoolVec = vec<bool, N>;
    using UShortVec = vec<ushort, N>;

    const UShortVec widened = UShortVec(value);
    return BoolVec(quad_shuffle_up(widened, delta));
}
// Subgroup shuffle-down implemented on top of quad_shuffle_down: returns the
// intrinsic's result for `value` and `delta`.
template<typename T>
inline T spvSubgroupShuffleDown(T value, ushort delta)
{
    const T picked = quad_shuffle_down(value, delta);
    return picked;
}

// bool variant: widened to ushort for the intrinsic call, then reduced to a
// bool with a non-zero test.
template<>
inline bool spvSubgroupShuffleDown(bool value, ushort delta)
{
    const ushort widened = ushort(value);
    return quad_shuffle_down(widened, delta) != 0;
}

// Boolean-vector variant: components are carried as ushort through the
// intrinsic and converted back to vec<bool, N> afterwards.
template<uint N>
inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)
{
    using BoolVec = vec<bool, N>;
    using UShortVec = vec<ushort, N>;

    const UShortVec widened = UShortVec(value);
    return BoolVec(quad_shuffle_down(widened, delta));
}
// Quad broadcast wrapper: forwards `value` and `lane` to quad_broadcast and
// returns its result.
template<typename T>
inline T spvQuadBroadcast(T value, uint lane)
{
    const T from_lane = quad_broadcast(value, lane);
    return from_lane;
}

// bool variant: widened to ushort for the intrinsic call, then reduced to a
// bool with a non-zero test.
template<>
inline bool spvQuadBroadcast(bool value, uint lane)
{
    const ushort widened = ushort(value);
    return quad_broadcast(widened, lane) != 0;
}

// Boolean-vector variant: components are carried as ushort through the
// intrinsic and converted back to vec<bool, N> afterwards.
template<uint N>
inline vec<bool, N> spvQuadBroadcast(vec<bool, N> value, uint lane)
{
    using BoolVec = vec<bool, N>;
    using UShortVec = vec<ushort, N>;

    const UShortVec widened = UShortVec(value);
    return BoolVec(quad_broadcast(widened, lane));
}
// Quad swap expressed as an XOR shuffle: the lane mask handed to
// quad_shuffle_xor is dir + 1, so dir 0/1/2 become masks 1/2/3.
template<typename T>
inline T spvQuadSwap(T value, uint dir)
{
    const uint xor_mask = dir + 1;
    return quad_shuffle_xor(value, xor_mask);
}

// bool variant: widened to ushort for the intrinsic call, then reduced to a
// bool with a non-zero test.
template<>
inline bool spvQuadSwap(bool value, uint dir)
{
    const uint xor_mask = dir + 1;
    const ushort widened = ushort(value);
    return quad_shuffle_xor(widened, xor_mask) != 0;
}

// Boolean-vector variant: components are carried as ushort through the
// intrinsic and converted back to vec<bool, N> afterwards.
template<uint N>
inline vec<bool, N> spvQuadSwap(vec<bool, N> value, uint dir)
{
    using BoolVec = vec<bool, N>;
    using UShortVec = vec<ushort, N>;

    const uint xor_mask = dir + 1;
    const UShortVec widened = UShortVec(value);
    return BoolVec(quad_shuffle_xor(widened, xor_mask));
}
// Compute kernel entry point. It stores the subgroup builtins and the derived
// subgroup masks into the SSBO, issues a series of simdgroup barriers, and then
// calls each spvSubgroup*/spvQuad* helper once. The helper results are kept in
// locals that are never read afterwards.
//
// _9                      - storage buffer bound at [[buffer(0)]].
// gl_NumSubgroups         - from [[quadgroups_per_threadgroup]].
// gl_SubgroupID           - from [[quadgroup_index_in_threadgroup]].
// gl_SubgroupSize         - from [[thread_execution_width]].
// gl_SubgroupInvocationID - from [[thread_index_in_quadgroup]].
// NOTE(review): the size comes from the thread execution width while the other
// three builtins are quadgroup-based - confirm that this mix is intended.
kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]])
{
// Subgroup masks; only the x word is ever populated, y/z/w are zero.
// Eq: the single bit at the caller's invocation index.
uint4 gl_SubgroupEqMask = uint4(1 << gl_SubgroupInvocationID, uint3(0));
// Ge: all-ones inserted at offset ID with width (size - ID).
uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, gl_SubgroupInvocationID, gl_SubgroupSize - gl_SubgroupInvocationID), uint3(0));
// Gt: all-ones inserted at offset ID + 1 with width (size - ID - 1).
uint4 gl_SubgroupGtMask = uint4(insert_bits(0u, 0xFFFFFFFF, gl_SubgroupInvocationID + 1, gl_SubgroupSize - gl_SubgroupInvocationID - 1), uint3(0));
// Le: the low ID + 1 bits.
uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));
// Lt: the low ID bits.
uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint3(0));
// Each assignment below overwrites the same field, so only the last store persists.
_9.FragColor = float(gl_NumSubgroups);
_9.FragColor = float(gl_SubgroupID);
_9.FragColor = float(gl_SubgroupSize);
_9.FragColor = float(gl_SubgroupInvocationID);
// Barriers: two with all three memory flags, then one per individual flag.
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);
simdgroup_barrier(mem_flags::mem_device);
simdgroup_barrier(mem_flags::mem_threadgroup);
simdgroup_barrier(mem_flags::mem_texture);
// Election result taken from quad_is_first().
bool _39 = quad_is_first();
bool elected = _39;
// Store the x word of each mask, converted to float.
_9.FragColor = float4(gl_SubgroupEqMask).x;
_9.FragColor = float4(gl_SubgroupGeMask).x;
_9.FragColor = float4(gl_SubgroupGtMask).x;
_9.FragColor = float4(gl_SubgroupLeMask).x;
_9.FragColor = float4(gl_SubgroupLtMask).x;
// Broadcast helpers.
// NOTE(review): lane 8 is presumably outside a four-thread quadgroup - confirm.
float4 broadcasted = spvSubgroupBroadcast(float4(10.0), 8u);
bool2 broadcasted_bool = spvSubgroupBroadcast(bool2(true), 8u);
float3 first = spvSubgroupBroadcastFirst(float3(20.0));
bool4 first_bool = spvSubgroupBroadcastFirst(bool4(false));
// Ballot helpers.
uint4 ballot_value = spvSubgroupBallot(true);
bool inverse_ballot_value = spvSubgroupBallotBitExtract(ballot_value, gl_SubgroupInvocationID);
bool bit_extracted = spvSubgroupBallotBitExtract(uint4(10u), 8u);
uint bit_count = spvSubgroupBallotBitCount(ballot_value, gl_SubgroupSize);
uint inclusive_bit_count = spvSubgroupBallotInclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
uint exclusive_bit_count = spvSubgroupBallotExclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
uint lsb = spvSubgroupBallotFindLSB(ballot_value, gl_SubgroupSize);
uint msb = spvSubgroupBallotFindMSB(ballot_value, gl_SubgroupSize);
// Shuffle helpers (plain, xor, up, down), each with a uint and a bool argument.
uint shuffled = spvSubgroupShuffle(10u, 8u);
bool shuffled_bool = spvSubgroupShuffle(true, 9u);
uint shuffled_xor = spvSubgroupShuffleXor(30u, 8u);
bool shuffled_xor_bool = spvSubgroupShuffleXor(false, 9u);
uint shuffled_up = spvSubgroupShuffleUp(20u, 4u);
bool shuffled_up_bool = spvSubgroupShuffleUp(true, 4u);
uint shuffled_down = spvSubgroupShuffleDown(20u, 4u);
bool shuffled_down_bool = spvSubgroupShuffleDown(false, 4u);
// Vote helpers; has_equal is reassigned by each spvSubgroupAllEqual call.
bool has_all = quad_all(true);
bool has_any = quad_any(true);
bool has_equal = spvSubgroupAllEqual(0);
has_equal = spvSubgroupAllEqual(true);
has_equal = spvSubgroupAllEqual(float3(0.0, 1.0, 2.0));
has_equal = spvSubgroupAllEqual(bool4(true, true, false, true));
// Quad helpers: spvQuadSwap with dir 0, 1 and 2, then spvQuadBroadcast of lane 3.
float4 swap_horiz = spvQuadSwap(float4(20.0), 0u);
bool4 swap_horiz_bool = spvQuadSwap(bool4(true), 0u);
float4 swap_vertical = spvQuadSwap(float4(20.0), 1u);
bool4 swap_vertical_bool = spvQuadSwap(bool4(true), 1u);
float4 swap_diagonal = spvQuadSwap(float4(20.0), 2u);
bool4 swap_diagonal_bool = spvQuadSwap(bool4(true), 2u);
float4 quad_broadcast0 = spvQuadBroadcast(float4(20.0), 3u);
bool4 quad_broadcast_bool = spvQuadBroadcast(bool4(true), 3u);
}