Partially implement subgroup ops for HLSL SM 6.0.

Lots of stuff that needs tons of emulation, which I'm not going to
bother with.
This commit is contained in:
Hans-Kristian Arntzen 2018-04-11 15:02:02 +02:00
parent 146ea76f52
commit c266429be9
7 changed files with 644 additions and 2 deletions

View File

@ -0,0 +1,67 @@
RWByteAddressBuffer _9 : register(u0, space0);
static uint4 gl_SubgroupEqMask;
static uint4 gl_SubgroupGeMask;
static uint4 gl_SubgroupGtMask;
static uint4 gl_SubgroupLeMask;
static uint4 gl_SubgroupLtMask;
void comp_main()
{
_9.Store(0, asuint(float(WaveGetLaneCount())));
_9.Store(0, asuint(float(WaveGetLaneIndex())));
_9.Store(0, asuint(float4(gl_SubgroupEqMask).x));
_9.Store(0, asuint(float4(gl_SubgroupGeMask).x));
_9.Store(0, asuint(float4(gl_SubgroupGtMask).x));
_9.Store(0, asuint(float4(gl_SubgroupLeMask).x));
_9.Store(0, asuint(float4(gl_SubgroupLtMask).x));
uint4 _75 = WaveActiveBallot(true);
float4 _88 = WaveActiveSum(20.0f.xxxx);
int4 _94 = WaveActiveSum(int4(20, 20, 20, 20));
float4 _96 = WaveActiveProduct(20.0f.xxxx);
int4 _98 = WaveActiveProduct(int4(20, 20, 20, 20));
float4 _127 = WavePrefixProduct(_96) * _96;
int4 _129 = WavePrefixProduct(_98) * _98;
}
[numthreads(1, 1, 1)]
void main()
{
gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));
if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;
if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;
if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;
if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;
gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);
if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;
if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;
if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;
if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;
if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;
if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;
uint gt_lane_index = WaveGetLaneIndex() + 1;
gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);
if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;
if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;
if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;
if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;
if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;
if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;
if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;
uint le_lane_index = WaveGetLaneIndex() + 1;
gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;
if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;
if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;
if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;
if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;
if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;
if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;
if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;
gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;
if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;
if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;
if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;
if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;
if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;
if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;
comp_main();
}

View File

@ -0,0 +1,93 @@
RWByteAddressBuffer _9 : register(u0, space0);
static uint4 gl_SubgroupEqMask;
static uint4 gl_SubgroupGeMask;
static uint4 gl_SubgroupGtMask;
static uint4 gl_SubgroupLeMask;
static uint4 gl_SubgroupLtMask;
void comp_main()
{
_9.Store(0, asuint(float(WaveGetLaneCount())));
_9.Store(0, asuint(float(WaveGetLaneIndex())));
bool elected = WaveIsFirstLane();
_9.Store(0, asuint(float4(gl_SubgroupEqMask).x));
_9.Store(0, asuint(float4(gl_SubgroupGeMask).x));
_9.Store(0, asuint(float4(gl_SubgroupGtMask).x));
_9.Store(0, asuint(float4(gl_SubgroupLeMask).x));
_9.Store(0, asuint(float4(gl_SubgroupLtMask).x));
float4 broadcasted = WaveReadLaneAt(10.0f.xxxx, 8u);
float3 first = WaveReadLaneFirst(20.0f.xxx);
uint4 ballot_value = WaveActiveBallot(true);
uint bit_count = countbits(ballot_value.x) + countbits(ballot_value.y) + countbits(ballot_value.z) + countbits(ballot_value.w);
bool has_all = WaveActiveAllTrue(true);
bool has_any = WaveActiveAnyTrue(true);
bool has_equal = WaveActiveAllEqualBool(true);
float4 added = WaveActiveSum(20.0f.xxxx);
int4 iadded = WaveActiveSum(int4(20, 20, 20, 20));
float4 multiplied = WaveActiveProduct(20.0f.xxxx);
int4 imultiplied = WaveActiveProduct(int4(20, 20, 20, 20));
float4 lo = WaveActiveMin(20.0f.xxxx);
float4 hi = WaveActiveMax(20.0f.xxxx);
int4 slo = WaveActiveMin(int4(20, 20, 20, 20));
int4 shi = WaveActiveMax(int4(20, 20, 20, 20));
uint4 ulo = WaveActiveMin(uint4(20u, 20u, 20u, 20u));
uint4 uhi = WaveActiveMax(uint4(20u, 20u, 20u, 20u));
uint4 anded = WaveActiveBitAnd(ballot_value);
uint4 ored = WaveActiveBitOr(ballot_value);
uint4 xored = WaveActiveBitXor(ballot_value);
added = WavePrefixSum(added) + added;
iadded = WavePrefixSum(iadded) + iadded;
multiplied = WavePrefixProduct(multiplied) * multiplied;
imultiplied = WavePrefixProduct(imultiplied) * imultiplied;
added = WavePrefixSum(multiplied);
multiplied = WavePrefixProduct(multiplied);
iadded = WavePrefixSum(imultiplied);
imultiplied = WavePrefixProduct(imultiplied);
float4 swap_horiz = QuadReadAcrossX(20.0f.xxxx);
float4 swap_vertical = QuadReadAcrossY(20.0f.xxxx);
float4 swap_diagonal = QuadReadAcrossDiagonal(20.0f.xxxx);
float4 quad_broadcast = QuadReadLaneAt(20.0f.xxxx, 3u);
}
[numthreads(1, 1, 1)]
void main()
{
gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));
if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;
if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;
if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;
if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;
gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);
if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;
if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;
if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;
if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;
if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;
if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;
uint gt_lane_index = WaveGetLaneIndex() + 1;
gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);
if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;
if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;
if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;
if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;
if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;
if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;
if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;
uint le_lane_index = WaveGetLaneIndex() + 1;
gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;
if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;
if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;
if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;
if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;
if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;
if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;
if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;
gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;
if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;
if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;
if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;
if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;
if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;
if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;
comp_main();
}

View File

@ -0,0 +1,131 @@
#version 450
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_clustered : require
#extension GL_KHR_shader_subgroup_quad : require
layout(local_size_x = 1) in;
layout(std430, binding = 0) buffer SSBO
{
float FragColor;
};
void main()
{
// basic
//FragColor = float(gl_NumSubgroups);
//FragColor = float(gl_SubgroupID);
FragColor = float(gl_SubgroupSize);
FragColor = float(gl_SubgroupInvocationID);
subgroupBarrier();
subgroupMemoryBarrier();
subgroupMemoryBarrierBuffer();
subgroupMemoryBarrierShared();
subgroupMemoryBarrierImage();
bool elected = subgroupElect();
// ballot
FragColor = float(gl_SubgroupEqMask);
FragColor = float(gl_SubgroupGeMask);
FragColor = float(gl_SubgroupGtMask);
FragColor = float(gl_SubgroupLeMask);
FragColor = float(gl_SubgroupLtMask);
vec4 broadcasted = subgroupBroadcast(vec4(10.0), 8u);
vec3 first = subgroupBroadcastFirst(vec3(20.0));
uvec4 ballot_value = subgroupBallot(true);
//bool inverse_ballot_value = subgroupInverseBallot(ballot_value);
//bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u);
uint bit_count = subgroupBallotBitCount(ballot_value);
//uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
//uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
//uint lsb = subgroupBallotFindLSB(ballot_value);
//uint msb = subgroupBallotFindMSB(ballot_value);
// shuffle
//uint shuffled = subgroupShuffle(10u, 8u);
//uint shuffled_xor = subgroupShuffleXor(30u, 8u);
// shuffle relative
//uint shuffled_up = subgroupShuffleUp(20u, 4u);
//uint shuffled_down = subgroupShuffleDown(20u, 4u);
// vote
bool has_all = subgroupAll(true);
bool has_any = subgroupAny(true);
bool has_equal = subgroupAllEqual(true);
// arithmetic
vec4 added = subgroupAdd(vec4(20.0));
ivec4 iadded = subgroupAdd(ivec4(20));
vec4 multiplied = subgroupMul(vec4(20.0));
ivec4 imultiplied = subgroupMul(ivec4(20));
vec4 lo = subgroupMin(vec4(20.0));
vec4 hi = subgroupMax(vec4(20.0));
ivec4 slo = subgroupMin(ivec4(20));
ivec4 shi = subgroupMax(ivec4(20));
uvec4 ulo = subgroupMin(uvec4(20));
uvec4 uhi = subgroupMax(uvec4(20));
uvec4 anded = subgroupAnd(ballot_value);
uvec4 ored = subgroupOr(ballot_value);
uvec4 xored = subgroupXor(ballot_value);
added = subgroupInclusiveAdd(added);
iadded = subgroupInclusiveAdd(iadded);
multiplied = subgroupInclusiveMul(multiplied);
imultiplied = subgroupInclusiveMul(imultiplied);
#if 0
lo = subgroupInclusiveMin(lo);
hi = subgroupInclusiveMax(hi);
slo = subgroupInclusiveMin(slo);
shi = subgroupInclusiveMax(shi);
ulo = subgroupInclusiveMin(ulo);
uhi = subgroupInclusiveMax(uhi);
anded = subgroupInclusiveAnd(anded);
ored = subgroupInclusiveOr(ored);
xored = subgroupInclusiveXor(ored);
added = subgroupExclusiveAdd(lo);
#endif
added = subgroupExclusiveAdd(multiplied);
multiplied = subgroupExclusiveMul(multiplied);
iadded = subgroupExclusiveAdd(imultiplied);
imultiplied = subgroupExclusiveMul(imultiplied);
#if 0
lo = subgroupExclusiveMin(lo);
hi = subgroupExclusiveMax(hi);
ulo = subgroupExclusiveMin(ulo);
uhi = subgroupExclusiveMax(uhi);
slo = subgroupExclusiveMin(slo);
shi = subgroupExclusiveMax(shi);
anded = subgroupExclusiveAnd(anded);
ored = subgroupExclusiveOr(ored);
xored = subgroupExclusiveXor(ored);
#endif
#if 0
// clustered
added = subgroupClusteredAdd(added, 4u);
multiplied = subgroupClusteredMul(multiplied, 4u);
iadded = subgroupClusteredAdd(iadded, 4u);
imultiplied = subgroupClusteredMul(imultiplied, 4u);
lo = subgroupClusteredMin(lo, 4u);
hi = subgroupClusteredMax(hi, 4u);
ulo = subgroupClusteredMin(ulo, 4u);
uhi = subgroupClusteredMax(uhi, 4u);
slo = subgroupClusteredMin(slo, 4u);
shi = subgroupClusteredMax(shi, 4u);
anded = subgroupClusteredAnd(anded, 4u);
ored = subgroupClusteredOr(ored, 4u);
xored = subgroupClusteredXor(xored, 4u);
#endif
// quad
vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0));
vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0));
vec4 swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0));
vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u);
}

View File

@ -625,6 +625,13 @@ void CompilerHLSL::emit_builtin_inputs_in_struct()
break;
case BuiltInNumWorkgroups:
case BuiltInSubgroupSize:
case BuiltInSubgroupLocalInvocationId:
case BuiltInSubgroupEqMask:
case BuiltInSubgroupLtMask:
case BuiltInSubgroupLeMask:
case BuiltInSubgroupGtMask:
case BuiltInSubgroupGeMask:
// Handled specially.
break;
@ -864,6 +871,11 @@ std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClas
case BuiltInPointCoord:
// Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
return "float2(0.5f, 0.5f)";
case BuiltInSubgroupLocalInvocationId:
return "WaveGetLaneIndex()";
case BuiltInSubgroupSize:
return "WaveGetLaneCount()";
default:
return CompilerGLSL::builtin_to_glsl(builtin, storage);
}
@ -928,6 +940,22 @@ void CompilerHLSL::emit_builtin_variables()
// Handled specially.
break;
case BuiltInSubgroupLocalInvocationId:
case BuiltInSubgroupSize:
if (hlsl_options.shader_model < 60)
SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
break;
case BuiltInSubgroupEqMask:
case BuiltInSubgroupLtMask:
case BuiltInSubgroupLeMask:
case BuiltInSubgroupGtMask:
case BuiltInSubgroupGeMask:
if (hlsl_options.shader_model < 60)
SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
type = "uint4";
break;
case BuiltInClipDistance:
array_size = clip_distance_count;
type = "float";
@ -940,7 +968,6 @@ void CompilerHLSL::emit_builtin_variables()
default:
SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
break;
}
StorageClass storage = active_input_builtins.get(i) ? StorageClassInput : StorageClassOutput;
@ -1225,6 +1252,14 @@ void CompilerHLSL::emit_resources()
auto input_builtins = active_input_builtins;
input_builtins.clear(BuiltInNumWorkgroups);
input_builtins.clear(BuiltInPointCoord);
input_builtins.clear(BuiltInSubgroupSize);
input_builtins.clear(BuiltInSubgroupLocalInvocationId);
input_builtins.clear(BuiltInSubgroupEqMask);
input_builtins.clear(BuiltInSubgroupLtMask);
input_builtins.clear(BuiltInSubgroupLeMask);
input_builtins.clear(BuiltInSubgroupGtMask);
input_builtins.clear(BuiltInSubgroupGeMask);
if (!input_variables.empty() || !input_builtins.empty())
{
require_input = true;
@ -2106,6 +2141,70 @@ void CompilerHLSL::emit_hlsl_entry_point()
case BuiltInNumWorkgroups:
case BuiltInPointCoord:
case BuiltInSubgroupSize:
case BuiltInSubgroupLocalInvocationId:
break;
case BuiltInSubgroupEqMask:
// Emulate these ...
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
statement("gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
statement("if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
statement("if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
statement("if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
break;
case BuiltInSubgroupGeMask:
// Emulate these ...
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
statement("gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
statement("if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
statement("if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
statement("if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
break;
case BuiltInSubgroupGtMask:
// Emulate these ...
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
statement("uint gt_lane_index = WaveGetLaneIndex() + 1;");
statement("gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
statement("if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
statement("if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
statement("if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
statement("if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
statement("if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
statement("if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
statement("if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
break;
case BuiltInSubgroupLeMask:
// Emulate these ...
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
statement("uint le_lane_index = WaveGetLaneIndex() + 1;");
statement("gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
statement("if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
statement("if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
statement("if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
statement("if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
statement("if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
statement("if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
statement("if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
break;
case BuiltInSubgroupLtMask:
// Emulate these ...
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
statement("gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
statement("if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
statement("if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
statement("if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
break;
case BuiltInClipDistance:
@ -3528,6 +3627,176 @@ void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
register_read(ops[1], ops[2], should_forward(ops[2]));
}
void CompilerHLSL::emit_subgroup_op(const Instruction &i)
{
if (hlsl_options.shader_model < 60)
SPIRV_CROSS_THROW("Wave ops requires SM 6.0 or higher.");
const uint32_t *ops = stream(i);
auto op = static_cast<Op>(i.op);
uint32_t result_type = ops[0];
uint32_t id = ops[1];
auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
if (scope != ScopeSubgroup)
SPIRV_CROSS_THROW("Only subgroup scope is supported.");
const auto make_inclusive_Sum = [&](const string &expr) -> string {
return join(expr, " + ", to_expression(ops[4]));
};
const auto make_inclusive_Product = [&](const string &expr) -> string {
return join(expr, " * ", to_expression(ops[4]));
};
#define make_inclusive_BitAnd(expr) ""
#define make_inclusive_BitOr(expr) ""
#define make_inclusive_BitXor(expr) ""
#define make_inclusive_Min(expr) ""
#define make_inclusive_Max(expr) ""
switch (op)
{
case OpGroupNonUniformElect:
emit_op(result_type, id, "WaveIsFirstLane()", true);
break;
case OpGroupNonUniformBroadcast:
emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
break;
case OpGroupNonUniformBroadcastFirst:
emit_unary_func_op(result_type, id, ops[3], "WaveReadLaneFirst");
break;
case OpGroupNonUniformBallot:
emit_unary_func_op(result_type, id, ops[3], "WaveActiveBallot");
break;
case OpGroupNonUniformInverseBallot:
SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
break;
case OpGroupNonUniformBallotBitExtract:
SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
break;
case OpGroupNonUniformBallotFindLSB:
SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
break;
case OpGroupNonUniformBallotFindMSB:
SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
break;
case OpGroupNonUniformBallotBitCount:
{
auto operation = static_cast<GroupOperation>(ops[3]);
if (operation == GroupOperationReduce)
{
bool forward = should_forward(ops[4]);
auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(", to_enclosed_expression(ops[4]), ".y)");
auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(", to_enclosed_expression(ops[4]), ".w)");
emit_op(result_type, id, join(left, " + ", right), forward);
inherit_expression_dependencies(id, ops[4]);
}
else if (operation == GroupOperationInclusiveScan)
SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
else if (operation == GroupOperationExclusiveScan)
SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
else
SPIRV_CROSS_THROW("Invalid BitCount operation.");
break;
}
case OpGroupNonUniformShuffle:
SPIRV_CROSS_THROW("Cannot trivially implement Shuffle in HLSL.");
case OpGroupNonUniformShuffleXor:
SPIRV_CROSS_THROW("Cannot trivially implement ShuffleXor in HLSL.");
case OpGroupNonUniformShuffleUp:
SPIRV_CROSS_THROW("Cannot trivially implement ShuffleUp in HLSL.");
case OpGroupNonUniformShuffleDown:
SPIRV_CROSS_THROW("Cannot trivially implement ShuffleDown in HLSL.");
case OpGroupNonUniformAll:
emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllTrue");
break;
case OpGroupNonUniformAny:
emit_unary_func_op(result_type, id, ops[3], "WaveActiveAnyTrue");
break;
case OpGroupNonUniformAllEqual:
{
auto &type = get<SPIRType>(result_type);
emit_unary_func_op(result_type, id, ops[3],
type.basetype == SPIRType::Boolean ? "WaveActiveAllEqualBool" : "WaveActiveAllEqual");
break;
}
#define GROUP_OP(op, hlsl_op, supports_scan) \
case OpGroupNonUniform##op: \
{ \
auto operation = static_cast<GroupOperation>(ops[3]); \
if (operation == GroupOperationReduce) \
emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
else if (operation == GroupOperationInclusiveScan && supports_scan) \
{ \
bool forward = should_forward(ops[4]); \
emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
inherit_expression_dependencies(id, ops[4]); \
} \
else if (operation == GroupOperationExclusiveScan && supports_scan) \
emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
else if (operation == GroupOperationClusteredReduce) \
SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
break; \
}
GROUP_OP(FAdd, Sum, true)
GROUP_OP(FMul, Product, true)
GROUP_OP(FMin, Min, false)
GROUP_OP(FMax, Max, false)
GROUP_OP(IAdd, Sum, true)
GROUP_OP(IMul, Product, true)
GROUP_OP(SMin, Min, false)
GROUP_OP(SMax, Max, false)
GROUP_OP(UMin, Min, false)
GROUP_OP(UMax, Max, false)
GROUP_OP(BitwiseAnd, BitAnd, false)
GROUP_OP(BitwiseOr, BitOr, false)
GROUP_OP(BitwiseXor, BitXor, false)
#undef GROUP_OP
case OpGroupNonUniformQuadSwap:
{
uint32_t direction = get<SPIRConstant>(ops[4]).scalar();
if (direction == 0)
emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
else if (direction == 1)
emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossY");
else if (direction == 2)
emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossDiagonal");
else
SPIRV_CROSS_THROW("Invalid quad swap direction.");
break;
}
case OpGroupNonUniformQuadBroadcast:
{
emit_binary_func_op(result_type, id, ops[3], ops[4], "QuadReadLaneAt");
break;
}
default:
SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
}
register_control_dependent_expression(id);
}
void CompilerHLSL::emit_instruction(const Instruction &instruction)
{
auto ops = stream(instruction);
@ -4004,6 +4273,12 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
semantics = get<SPIRConstant>(ops[2]).scalar();
}
if (memory == ScopeSubgroup)
{
// No Wave-barriers in HLSL.
break;
}
// We only care about these flags, acquire/release and friends are not relevant to GLSL.
semantics = mask_relevant_memory_semantics(semantics);

View File

@ -157,6 +157,7 @@ private:
void write_access_chain(const SPIRAccessChain &chain, uint32_t value);
void emit_store(const Instruction &instruction);
void emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op);
void emit_subgroup_op(const Instruction &i) override;
void emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index, const std::string &qualifier,
uint32_t base_offset = 0) override;

View File

@ -155,7 +155,9 @@ def validate_shader_hlsl(shader):
sys.exit(1)
def shader_to_sm(shader):
if '.sm51.' in shader:
if '.sm60.' in shader:
return '60'
elif '.sm51.' in shader:
return '51'
elif '.sm20.' in shader:
return '20'

View File

@ -0,0 +1,73 @@
// Ad-hoc test that the wave op masks work as expected.
#include <glm/glm.hpp>
#include <assert.h>
using namespace glm;
static uvec4 gl_SubgroupEqMask;
static uvec4 gl_SubgroupGeMask;
static uvec4 gl_SubgroupGtMask;
static uvec4 gl_SubgroupLeMask;
static uvec4 gl_SubgroupLtMask;
using uint4 = uvec4;
static void test_main(unsigned wave_index)
{
const auto WaveGetLaneIndex = [&]() { return wave_index; };
gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));
if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;
if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;
if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;
if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;
gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);
if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;
if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;
if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;
if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;
if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;
if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;
uint gt_lane_index = WaveGetLaneIndex() + 1;
gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);
if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;
if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;
if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;
if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;
if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;
if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;
if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;
uint le_lane_index = WaveGetLaneIndex() + 1;
gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;
if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;
if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;
if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;
if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;
if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;
if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;
if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;
gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;
if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;
if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;
if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;
if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;
if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;
if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;
}
int main()
{
for (unsigned subgroup_id = 0; subgroup_id < 128; subgroup_id++)
{
test_main(subgroup_id);
for (unsigned bit = 0; bit < 128; bit++)
{
assert(bool(gl_SubgroupEqMask[bit / 32] & (1u << (bit & 31))) == (bit == subgroup_id));
assert(bool(gl_SubgroupGtMask[bit / 32] & (1u << (bit & 31))) == (bit > subgroup_id));
assert(bool(gl_SubgroupGeMask[bit / 32] & (1u << (bit & 31))) == (bit >= subgroup_id));
assert(bool(gl_SubgroupLtMask[bit / 32] & (1u << (bit & 31))) == (bit < subgroup_id));
assert(bool(gl_SubgroupLeMask[bit / 32] & (1u << (bit & 31))) == (bit <= subgroup_id));
}
}
}