mirror of
https://github.com/KhronosGroup/SPIRV-Cross.git
synced 2024-11-08 13:20:06 +00:00
Partially implement subgroup ops for HLSL SM 6.0.
Lots of stuff that needs tons of emulation, which I'm not going to bother with.
This commit is contained in:
parent
146ea76f52
commit
c266429be9
@ -0,0 +1,67 @@
|
||||
RWByteAddressBuffer _9 : register(u0, space0);
|
||||
|
||||
static uint4 gl_SubgroupEqMask;
|
||||
static uint4 gl_SubgroupGeMask;
|
||||
static uint4 gl_SubgroupGtMask;
|
||||
static uint4 gl_SubgroupLeMask;
|
||||
static uint4 gl_SubgroupLtMask;
|
||||
void comp_main()
|
||||
{
|
||||
_9.Store(0, asuint(float(WaveGetLaneCount())));
|
||||
_9.Store(0, asuint(float(WaveGetLaneIndex())));
|
||||
_9.Store(0, asuint(float4(gl_SubgroupEqMask).x));
|
||||
_9.Store(0, asuint(float4(gl_SubgroupGeMask).x));
|
||||
_9.Store(0, asuint(float4(gl_SubgroupGtMask).x));
|
||||
_9.Store(0, asuint(float4(gl_SubgroupLeMask).x));
|
||||
_9.Store(0, asuint(float4(gl_SubgroupLtMask).x));
|
||||
uint4 _75 = WaveActiveBallot(true);
|
||||
float4 _88 = WaveActiveSum(20.0f.xxxx);
|
||||
int4 _94 = WaveActiveSum(int4(20, 20, 20, 20));
|
||||
float4 _96 = WaveActiveProduct(20.0f.xxxx);
|
||||
int4 _98 = WaveActiveProduct(int4(20, 20, 20, 20));
|
||||
float4 _127 = WavePrefixProduct(_96) * _96;
|
||||
int4 _129 = WavePrefixProduct(_98) * _98;
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main()
|
||||
{
|
||||
gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));
|
||||
if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;
|
||||
if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;
|
||||
if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;
|
||||
if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;
|
||||
gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);
|
||||
if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;
|
||||
if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;
|
||||
if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;
|
||||
if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;
|
||||
if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;
|
||||
if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;
|
||||
uint gt_lane_index = WaveGetLaneIndex() + 1;
|
||||
gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);
|
||||
if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;
|
||||
if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;
|
||||
if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;
|
||||
if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;
|
||||
if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;
|
||||
if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;
|
||||
if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;
|
||||
uint le_lane_index = WaveGetLaneIndex() + 1;
|
||||
gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;
|
||||
if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;
|
||||
if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;
|
||||
if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;
|
||||
if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;
|
||||
if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;
|
||||
if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;
|
||||
if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;
|
||||
gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;
|
||||
if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;
|
||||
if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;
|
||||
if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;
|
||||
if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;
|
||||
if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;
|
||||
if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;
|
||||
comp_main();
|
||||
}
|
@ -0,0 +1,93 @@
|
||||
RWByteAddressBuffer _9 : register(u0, space0);
|
||||
|
||||
static uint4 gl_SubgroupEqMask;
|
||||
static uint4 gl_SubgroupGeMask;
|
||||
static uint4 gl_SubgroupGtMask;
|
||||
static uint4 gl_SubgroupLeMask;
|
||||
static uint4 gl_SubgroupLtMask;
|
||||
void comp_main()
|
||||
{
|
||||
_9.Store(0, asuint(float(WaveGetLaneCount())));
|
||||
_9.Store(0, asuint(float(WaveGetLaneIndex())));
|
||||
bool elected = WaveIsFirstLane();
|
||||
_9.Store(0, asuint(float4(gl_SubgroupEqMask).x));
|
||||
_9.Store(0, asuint(float4(gl_SubgroupGeMask).x));
|
||||
_9.Store(0, asuint(float4(gl_SubgroupGtMask).x));
|
||||
_9.Store(0, asuint(float4(gl_SubgroupLeMask).x));
|
||||
_9.Store(0, asuint(float4(gl_SubgroupLtMask).x));
|
||||
float4 broadcasted = WaveReadLaneAt(10.0f.xxxx, 8u);
|
||||
float3 first = WaveReadLaneFirst(20.0f.xxx);
|
||||
uint4 ballot_value = WaveActiveBallot(true);
|
||||
uint bit_count = countbits(ballot_value.x) + countbits(ballot_value.y) + countbits(ballot_value.z) + countbits(ballot_value.w);
|
||||
bool has_all = WaveActiveAllTrue(true);
|
||||
bool has_any = WaveActiveAnyTrue(true);
|
||||
bool has_equal = WaveActiveAllEqualBool(true);
|
||||
float4 added = WaveActiveSum(20.0f.xxxx);
|
||||
int4 iadded = WaveActiveSum(int4(20, 20, 20, 20));
|
||||
float4 multiplied = WaveActiveProduct(20.0f.xxxx);
|
||||
int4 imultiplied = WaveActiveProduct(int4(20, 20, 20, 20));
|
||||
float4 lo = WaveActiveMin(20.0f.xxxx);
|
||||
float4 hi = WaveActiveMax(20.0f.xxxx);
|
||||
int4 slo = WaveActiveMin(int4(20, 20, 20, 20));
|
||||
int4 shi = WaveActiveMax(int4(20, 20, 20, 20));
|
||||
uint4 ulo = WaveActiveMin(uint4(20u, 20u, 20u, 20u));
|
||||
uint4 uhi = WaveActiveMax(uint4(20u, 20u, 20u, 20u));
|
||||
uint4 anded = WaveActiveBitAnd(ballot_value);
|
||||
uint4 ored = WaveActiveBitOr(ballot_value);
|
||||
uint4 xored = WaveActiveBitXor(ballot_value);
|
||||
added = WavePrefixSum(added) + added;
|
||||
iadded = WavePrefixSum(iadded) + iadded;
|
||||
multiplied = WavePrefixProduct(multiplied) * multiplied;
|
||||
imultiplied = WavePrefixProduct(imultiplied) * imultiplied;
|
||||
added = WavePrefixSum(multiplied);
|
||||
multiplied = WavePrefixProduct(multiplied);
|
||||
iadded = WavePrefixSum(imultiplied);
|
||||
imultiplied = WavePrefixProduct(imultiplied);
|
||||
float4 swap_horiz = QuadReadAcrossX(20.0f.xxxx);
|
||||
float4 swap_vertical = QuadReadAcrossY(20.0f.xxxx);
|
||||
float4 swap_diagonal = QuadReadAcrossDiagonal(20.0f.xxxx);
|
||||
float4 quad_broadcast = QuadReadLaneAt(20.0f.xxxx, 3u);
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main()
|
||||
{
|
||||
gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));
|
||||
if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;
|
||||
if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;
|
||||
if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;
|
||||
if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;
|
||||
gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);
|
||||
if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;
|
||||
if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;
|
||||
if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;
|
||||
if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;
|
||||
if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;
|
||||
if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;
|
||||
uint gt_lane_index = WaveGetLaneIndex() + 1;
|
||||
gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);
|
||||
if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;
|
||||
if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;
|
||||
if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;
|
||||
if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;
|
||||
if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;
|
||||
if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;
|
||||
if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;
|
||||
uint le_lane_index = WaveGetLaneIndex() + 1;
|
||||
gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;
|
||||
if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;
|
||||
if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;
|
||||
if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;
|
||||
if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;
|
||||
if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;
|
||||
if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;
|
||||
if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;
|
||||
gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;
|
||||
if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;
|
||||
if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;
|
||||
if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;
|
||||
if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;
|
||||
if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;
|
||||
if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;
|
||||
comp_main();
|
||||
}
|
131
shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp
Normal file
131
shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp
Normal file
@ -0,0 +1,131 @@
|
||||
#version 450
|
||||
#extension GL_KHR_shader_subgroup_basic : require
|
||||
#extension GL_KHR_shader_subgroup_ballot : require
|
||||
#extension GL_KHR_shader_subgroup_vote : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle_relative : require
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||
#extension GL_KHR_shader_subgroup_clustered : require
|
||||
#extension GL_KHR_shader_subgroup_quad : require
|
||||
layout(local_size_x = 1) in;
|
||||
|
||||
layout(std430, binding = 0) buffer SSBO
|
||||
{
|
||||
float FragColor;
|
||||
};
|
||||
|
||||
void main()
|
||||
{
|
||||
// basic
|
||||
//FragColor = float(gl_NumSubgroups);
|
||||
//FragColor = float(gl_SubgroupID);
|
||||
FragColor = float(gl_SubgroupSize);
|
||||
FragColor = float(gl_SubgroupInvocationID);
|
||||
subgroupBarrier();
|
||||
subgroupMemoryBarrier();
|
||||
subgroupMemoryBarrierBuffer();
|
||||
subgroupMemoryBarrierShared();
|
||||
subgroupMemoryBarrierImage();
|
||||
bool elected = subgroupElect();
|
||||
|
||||
// ballot
|
||||
FragColor = float(gl_SubgroupEqMask);
|
||||
FragColor = float(gl_SubgroupGeMask);
|
||||
FragColor = float(gl_SubgroupGtMask);
|
||||
FragColor = float(gl_SubgroupLeMask);
|
||||
FragColor = float(gl_SubgroupLtMask);
|
||||
vec4 broadcasted = subgroupBroadcast(vec4(10.0), 8u);
|
||||
vec3 first = subgroupBroadcastFirst(vec3(20.0));
|
||||
uvec4 ballot_value = subgroupBallot(true);
|
||||
//bool inverse_ballot_value = subgroupInverseBallot(ballot_value);
|
||||
//bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u);
|
||||
uint bit_count = subgroupBallotBitCount(ballot_value);
|
||||
//uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
|
||||
//uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
|
||||
//uint lsb = subgroupBallotFindLSB(ballot_value);
|
||||
//uint msb = subgroupBallotFindMSB(ballot_value);
|
||||
|
||||
// shuffle
|
||||
//uint shuffled = subgroupShuffle(10u, 8u);
|
||||
//uint shuffled_xor = subgroupShuffleXor(30u, 8u);
|
||||
|
||||
// shuffle relative
|
||||
//uint shuffled_up = subgroupShuffleUp(20u, 4u);
|
||||
//uint shuffled_down = subgroupShuffleDown(20u, 4u);
|
||||
|
||||
// vote
|
||||
bool has_all = subgroupAll(true);
|
||||
bool has_any = subgroupAny(true);
|
||||
bool has_equal = subgroupAllEqual(true);
|
||||
|
||||
// arithmetic
|
||||
vec4 added = subgroupAdd(vec4(20.0));
|
||||
ivec4 iadded = subgroupAdd(ivec4(20));
|
||||
vec4 multiplied = subgroupMul(vec4(20.0));
|
||||
ivec4 imultiplied = subgroupMul(ivec4(20));
|
||||
vec4 lo = subgroupMin(vec4(20.0));
|
||||
vec4 hi = subgroupMax(vec4(20.0));
|
||||
ivec4 slo = subgroupMin(ivec4(20));
|
||||
ivec4 shi = subgroupMax(ivec4(20));
|
||||
uvec4 ulo = subgroupMin(uvec4(20));
|
||||
uvec4 uhi = subgroupMax(uvec4(20));
|
||||
uvec4 anded = subgroupAnd(ballot_value);
|
||||
uvec4 ored = subgroupOr(ballot_value);
|
||||
uvec4 xored = subgroupXor(ballot_value);
|
||||
|
||||
added = subgroupInclusiveAdd(added);
|
||||
iadded = subgroupInclusiveAdd(iadded);
|
||||
multiplied = subgroupInclusiveMul(multiplied);
|
||||
imultiplied = subgroupInclusiveMul(imultiplied);
|
||||
#if 0
|
||||
lo = subgroupInclusiveMin(lo);
|
||||
hi = subgroupInclusiveMax(hi);
|
||||
slo = subgroupInclusiveMin(slo);
|
||||
shi = subgroupInclusiveMax(shi);
|
||||
ulo = subgroupInclusiveMin(ulo);
|
||||
uhi = subgroupInclusiveMax(uhi);
|
||||
anded = subgroupInclusiveAnd(anded);
|
||||
ored = subgroupInclusiveOr(ored);
|
||||
xored = subgroupInclusiveXor(ored);
|
||||
added = subgroupExclusiveAdd(lo);
|
||||
#endif
|
||||
|
||||
added = subgroupExclusiveAdd(multiplied);
|
||||
multiplied = subgroupExclusiveMul(multiplied);
|
||||
iadded = subgroupExclusiveAdd(imultiplied);
|
||||
imultiplied = subgroupExclusiveMul(imultiplied);
|
||||
#if 0
|
||||
lo = subgroupExclusiveMin(lo);
|
||||
hi = subgroupExclusiveMax(hi);
|
||||
ulo = subgroupExclusiveMin(ulo);
|
||||
uhi = subgroupExclusiveMax(uhi);
|
||||
slo = subgroupExclusiveMin(slo);
|
||||
shi = subgroupExclusiveMax(shi);
|
||||
anded = subgroupExclusiveAnd(anded);
|
||||
ored = subgroupExclusiveOr(ored);
|
||||
xored = subgroupExclusiveXor(ored);
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
// clustered
|
||||
added = subgroupClusteredAdd(added, 4u);
|
||||
multiplied = subgroupClusteredMul(multiplied, 4u);
|
||||
iadded = subgroupClusteredAdd(iadded, 4u);
|
||||
imultiplied = subgroupClusteredMul(imultiplied, 4u);
|
||||
lo = subgroupClusteredMin(lo, 4u);
|
||||
hi = subgroupClusteredMax(hi, 4u);
|
||||
ulo = subgroupClusteredMin(ulo, 4u);
|
||||
uhi = subgroupClusteredMax(uhi, 4u);
|
||||
slo = subgroupClusteredMin(slo, 4u);
|
||||
shi = subgroupClusteredMax(shi, 4u);
|
||||
anded = subgroupClusteredAnd(anded, 4u);
|
||||
ored = subgroupClusteredOr(ored, 4u);
|
||||
xored = subgroupClusteredXor(xored, 4u);
|
||||
#endif
|
||||
|
||||
// quad
|
||||
vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0));
|
||||
vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0));
|
||||
vec4 swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0));
|
||||
vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u);
|
||||
}
|
277
spirv_hlsl.cpp
277
spirv_hlsl.cpp
@ -625,6 +625,13 @@ void CompilerHLSL::emit_builtin_inputs_in_struct()
|
||||
break;
|
||||
|
||||
case BuiltInNumWorkgroups:
|
||||
case BuiltInSubgroupSize:
|
||||
case BuiltInSubgroupLocalInvocationId:
|
||||
case BuiltInSubgroupEqMask:
|
||||
case BuiltInSubgroupLtMask:
|
||||
case BuiltInSubgroupLeMask:
|
||||
case BuiltInSubgroupGtMask:
|
||||
case BuiltInSubgroupGeMask:
|
||||
// Handled specially.
|
||||
break;
|
||||
|
||||
@ -864,6 +871,11 @@ std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClas
|
||||
case BuiltInPointCoord:
|
||||
// Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
|
||||
return "float2(0.5f, 0.5f)";
|
||||
case BuiltInSubgroupLocalInvocationId:
|
||||
return "WaveGetLaneIndex()";
|
||||
case BuiltInSubgroupSize:
|
||||
return "WaveGetLaneCount()";
|
||||
|
||||
default:
|
||||
return CompilerGLSL::builtin_to_glsl(builtin, storage);
|
||||
}
|
||||
@ -928,6 +940,22 @@ void CompilerHLSL::emit_builtin_variables()
|
||||
// Handled specially.
|
||||
break;
|
||||
|
||||
case BuiltInSubgroupLocalInvocationId:
|
||||
case BuiltInSubgroupSize:
|
||||
if (hlsl_options.shader_model < 60)
|
||||
SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
|
||||
break;
|
||||
|
||||
case BuiltInSubgroupEqMask:
|
||||
case BuiltInSubgroupLtMask:
|
||||
case BuiltInSubgroupLeMask:
|
||||
case BuiltInSubgroupGtMask:
|
||||
case BuiltInSubgroupGeMask:
|
||||
if (hlsl_options.shader_model < 60)
|
||||
SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
|
||||
type = "uint4";
|
||||
break;
|
||||
|
||||
case BuiltInClipDistance:
|
||||
array_size = clip_distance_count;
|
||||
type = "float";
|
||||
@ -940,7 +968,6 @@ void CompilerHLSL::emit_builtin_variables()
|
||||
|
||||
default:
|
||||
SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
|
||||
break;
|
||||
}
|
||||
|
||||
StorageClass storage = active_input_builtins.get(i) ? StorageClassInput : StorageClassOutput;
|
||||
@ -1225,6 +1252,14 @@ void CompilerHLSL::emit_resources()
|
||||
auto input_builtins = active_input_builtins;
|
||||
input_builtins.clear(BuiltInNumWorkgroups);
|
||||
input_builtins.clear(BuiltInPointCoord);
|
||||
input_builtins.clear(BuiltInSubgroupSize);
|
||||
input_builtins.clear(BuiltInSubgroupLocalInvocationId);
|
||||
input_builtins.clear(BuiltInSubgroupEqMask);
|
||||
input_builtins.clear(BuiltInSubgroupLtMask);
|
||||
input_builtins.clear(BuiltInSubgroupLeMask);
|
||||
input_builtins.clear(BuiltInSubgroupGtMask);
|
||||
input_builtins.clear(BuiltInSubgroupGeMask);
|
||||
|
||||
if (!input_variables.empty() || !input_builtins.empty())
|
||||
{
|
||||
require_input = true;
|
||||
@ -2106,6 +2141,70 @@ void CompilerHLSL::emit_hlsl_entry_point()
|
||||
|
||||
case BuiltInNumWorkgroups:
|
||||
case BuiltInPointCoord:
|
||||
case BuiltInSubgroupSize:
|
||||
case BuiltInSubgroupLocalInvocationId:
|
||||
break;
|
||||
|
||||
case BuiltInSubgroupEqMask:
|
||||
// Emulate these ...
|
||||
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
|
||||
statement("gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
|
||||
statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
|
||||
statement("if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
|
||||
statement("if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
|
||||
statement("if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
|
||||
break;
|
||||
|
||||
case BuiltInSubgroupGeMask:
|
||||
// Emulate these ...
|
||||
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
|
||||
statement("gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
|
||||
statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
|
||||
statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
|
||||
statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
|
||||
statement("if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
|
||||
statement("if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
|
||||
statement("if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
|
||||
break;
|
||||
|
||||
case BuiltInSubgroupGtMask:
|
||||
// Emulate these ...
|
||||
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
|
||||
statement("uint gt_lane_index = WaveGetLaneIndex() + 1;");
|
||||
statement("gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
|
||||
statement("if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
|
||||
statement("if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
|
||||
statement("if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
|
||||
statement("if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
|
||||
statement("if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
|
||||
statement("if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
|
||||
statement("if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
|
||||
break;
|
||||
|
||||
case BuiltInSubgroupLeMask:
|
||||
// Emulate these ...
|
||||
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
|
||||
statement("uint le_lane_index = WaveGetLaneIndex() + 1;");
|
||||
statement("gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
|
||||
statement("if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
|
||||
statement("if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
|
||||
statement("if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
|
||||
statement("if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
|
||||
statement("if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
|
||||
statement("if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
|
||||
statement("if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
|
||||
break;
|
||||
|
||||
case BuiltInSubgroupLtMask:
|
||||
// Emulate these ...
|
||||
// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
|
||||
statement("gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
|
||||
statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
|
||||
statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
|
||||
statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
|
||||
statement("if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
|
||||
statement("if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
|
||||
statement("if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
|
||||
break;
|
||||
|
||||
case BuiltInClipDistance:
|
||||
@ -3528,6 +3627,176 @@ void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
|
||||
register_read(ops[1], ops[2], should_forward(ops[2]));
|
||||
}
|
||||
|
||||
void CompilerHLSL::emit_subgroup_op(const Instruction &i)
|
||||
{
|
||||
if (hlsl_options.shader_model < 60)
|
||||
SPIRV_CROSS_THROW("Wave ops requires SM 6.0 or higher.");
|
||||
|
||||
const uint32_t *ops = stream(i);
|
||||
auto op = static_cast<Op>(i.op);
|
||||
|
||||
uint32_t result_type = ops[0];
|
||||
uint32_t id = ops[1];
|
||||
|
||||
auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
|
||||
if (scope != ScopeSubgroup)
|
||||
SPIRV_CROSS_THROW("Only subgroup scope is supported.");
|
||||
|
||||
const auto make_inclusive_Sum = [&](const string &expr) -> string {
|
||||
return join(expr, " + ", to_expression(ops[4]));
|
||||
};
|
||||
|
||||
const auto make_inclusive_Product = [&](const string &expr) -> string {
|
||||
return join(expr, " * ", to_expression(ops[4]));
|
||||
};
|
||||
|
||||
#define make_inclusive_BitAnd(expr) ""
|
||||
#define make_inclusive_BitOr(expr) ""
|
||||
#define make_inclusive_BitXor(expr) ""
|
||||
#define make_inclusive_Min(expr) ""
|
||||
#define make_inclusive_Max(expr) ""
|
||||
|
||||
switch (op)
|
||||
{
|
||||
case OpGroupNonUniformElect:
|
||||
emit_op(result_type, id, "WaveIsFirstLane()", true);
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBroadcast:
|
||||
emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBroadcastFirst:
|
||||
emit_unary_func_op(result_type, id, ops[3], "WaveReadLaneFirst");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallot:
|
||||
emit_unary_func_op(result_type, id, ops[3], "WaveActiveBallot");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformInverseBallot:
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotBitExtract:
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotFindLSB:
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotFindMSB:
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotBitCount:
|
||||
{
|
||||
auto operation = static_cast<GroupOperation>(ops[3]);
|
||||
if (operation == GroupOperationReduce)
|
||||
{
|
||||
bool forward = should_forward(ops[4]);
|
||||
auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(", to_enclosed_expression(ops[4]), ".y)");
|
||||
auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(", to_enclosed_expression(ops[4]), ".w)");
|
||||
emit_op(result_type, id, join(left, " + ", right), forward);
|
||||
inherit_expression_dependencies(id, ops[4]);
|
||||
}
|
||||
else if (operation == GroupOperationInclusiveScan)
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
|
||||
else if (operation == GroupOperationExclusiveScan)
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
|
||||
else
|
||||
SPIRV_CROSS_THROW("Invalid BitCount operation.");
|
||||
break;
|
||||
}
|
||||
|
||||
case OpGroupNonUniformShuffle:
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement Shuffle in HLSL.");
|
||||
case OpGroupNonUniformShuffleXor:
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement ShuffleXor in HLSL.");
|
||||
case OpGroupNonUniformShuffleUp:
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement ShuffleUp in HLSL.");
|
||||
case OpGroupNonUniformShuffleDown:
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement ShuffleDown in HLSL.");
|
||||
|
||||
case OpGroupNonUniformAll:
|
||||
emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllTrue");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformAny:
|
||||
emit_unary_func_op(result_type, id, ops[3], "WaveActiveAnyTrue");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformAllEqual:
|
||||
{
|
||||
auto &type = get<SPIRType>(result_type);
|
||||
emit_unary_func_op(result_type, id, ops[3],
|
||||
type.basetype == SPIRType::Boolean ? "WaveActiveAllEqualBool" : "WaveActiveAllEqual");
|
||||
break;
|
||||
}
|
||||
|
||||
#define GROUP_OP(op, hlsl_op, supports_scan) \
|
||||
case OpGroupNonUniform##op: \
|
||||
{ \
|
||||
auto operation = static_cast<GroupOperation>(ops[3]); \
|
||||
if (operation == GroupOperationReduce) \
|
||||
emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
|
||||
else if (operation == GroupOperationInclusiveScan && supports_scan) \
|
||||
{ \
|
||||
bool forward = should_forward(ops[4]); \
|
||||
emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
|
||||
inherit_expression_dependencies(id, ops[4]); \
|
||||
} \
|
||||
else if (operation == GroupOperationExclusiveScan && supports_scan) \
|
||||
emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
|
||||
else if (operation == GroupOperationClusteredReduce) \
|
||||
SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
|
||||
else \
|
||||
SPIRV_CROSS_THROW("Invalid group operation."); \
|
||||
break; \
|
||||
}
|
||||
GROUP_OP(FAdd, Sum, true)
|
||||
GROUP_OP(FMul, Product, true)
|
||||
GROUP_OP(FMin, Min, false)
|
||||
GROUP_OP(FMax, Max, false)
|
||||
GROUP_OP(IAdd, Sum, true)
|
||||
GROUP_OP(IMul, Product, true)
|
||||
GROUP_OP(SMin, Min, false)
|
||||
GROUP_OP(SMax, Max, false)
|
||||
GROUP_OP(UMin, Min, false)
|
||||
GROUP_OP(UMax, Max, false)
|
||||
GROUP_OP(BitwiseAnd, BitAnd, false)
|
||||
GROUP_OP(BitwiseOr, BitOr, false)
|
||||
GROUP_OP(BitwiseXor, BitXor, false)
|
||||
#undef GROUP_OP
|
||||
|
||||
case OpGroupNonUniformQuadSwap:
|
||||
{
|
||||
uint32_t direction = get<SPIRConstant>(ops[4]).scalar();
|
||||
if (direction == 0)
|
||||
emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
|
||||
else if (direction == 1)
|
||||
emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossY");
|
||||
else if (direction == 2)
|
||||
emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossDiagonal");
|
||||
else
|
||||
SPIRV_CROSS_THROW("Invalid quad swap direction.");
|
||||
break;
|
||||
}
|
||||
|
||||
case OpGroupNonUniformQuadBroadcast:
|
||||
{
|
||||
emit_binary_func_op(result_type, id, ops[3], ops[4], "QuadReadLaneAt");
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
|
||||
}
|
||||
|
||||
register_control_dependent_expression(id);
|
||||
}
|
||||
|
||||
void CompilerHLSL::emit_instruction(const Instruction &instruction)
|
||||
{
|
||||
auto ops = stream(instruction);
|
||||
@ -4004,6 +4273,12 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
|
||||
semantics = get<SPIRConstant>(ops[2]).scalar();
|
||||
}
|
||||
|
||||
if (memory == ScopeSubgroup)
|
||||
{
|
||||
// No Wave-barriers in HLSL.
|
||||
break;
|
||||
}
|
||||
|
||||
// We only care about these flags, acquire/release and friends are not relevant to GLSL.
|
||||
semantics = mask_relevant_memory_semantics(semantics);
|
||||
|
||||
|
@ -157,6 +157,7 @@ private:
|
||||
void write_access_chain(const SPIRAccessChain &chain, uint32_t value);
|
||||
void emit_store(const Instruction &instruction);
|
||||
void emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op);
|
||||
void emit_subgroup_op(const Instruction &i) override;
|
||||
|
||||
void emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index, const std::string &qualifier,
|
||||
uint32_t base_offset = 0) override;
|
||||
|
@ -155,7 +155,9 @@ def validate_shader_hlsl(shader):
|
||||
sys.exit(1)
|
||||
|
||||
def shader_to_sm(shader):
|
||||
if '.sm51.' in shader:
|
||||
if '.sm60.' in shader:
|
||||
return '60'
|
||||
elif '.sm51.' in shader:
|
||||
return '51'
|
||||
elif '.sm20.' in shader:
|
||||
return '20'
|
||||
|
73
tests-other/hlsl_wave_mask.cpp
Normal file
73
tests-other/hlsl_wave_mask.cpp
Normal file
@ -0,0 +1,73 @@
|
||||
// Ad-hoc test that the wave op masks work as expected.
|
||||
#include <glm/glm.hpp>
|
||||
#include <assert.h>
|
||||
|
||||
using namespace glm;
|
||||
|
||||
static uvec4 gl_SubgroupEqMask;
|
||||
static uvec4 gl_SubgroupGeMask;
|
||||
static uvec4 gl_SubgroupGtMask;
|
||||
static uvec4 gl_SubgroupLeMask;
|
||||
static uvec4 gl_SubgroupLtMask;
|
||||
using uint4 = uvec4;
|
||||
|
||||
static void test_main(unsigned wave_index)
|
||||
{
|
||||
const auto WaveGetLaneIndex = [&]() { return wave_index; };
|
||||
|
||||
gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));
|
||||
if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;
|
||||
if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;
|
||||
if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;
|
||||
if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;
|
||||
gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);
|
||||
if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;
|
||||
if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;
|
||||
if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;
|
||||
if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;
|
||||
if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;
|
||||
if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;
|
||||
uint gt_lane_index = WaveGetLaneIndex() + 1;
|
||||
gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);
|
||||
if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;
|
||||
if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;
|
||||
if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;
|
||||
if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;
|
||||
if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;
|
||||
if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;
|
||||
if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;
|
||||
uint le_lane_index = WaveGetLaneIndex() + 1;
|
||||
gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;
|
||||
if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;
|
||||
if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;
|
||||
if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;
|
||||
if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;
|
||||
if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;
|
||||
if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;
|
||||
if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;
|
||||
gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;
|
||||
if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;
|
||||
if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;
|
||||
if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;
|
||||
if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;
|
||||
if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;
|
||||
if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
for (unsigned subgroup_id = 0; subgroup_id < 128; subgroup_id++)
|
||||
{
|
||||
test_main(subgroup_id);
|
||||
|
||||
for (unsigned bit = 0; bit < 128; bit++)
|
||||
{
|
||||
assert(bool(gl_SubgroupEqMask[bit / 32] & (1u << (bit & 31))) == (bit == subgroup_id));
|
||||
assert(bool(gl_SubgroupGtMask[bit / 32] & (1u << (bit & 31))) == (bit > subgroup_id));
|
||||
assert(bool(gl_SubgroupGeMask[bit / 32] & (1u << (bit & 31))) == (bit >= subgroup_id));
|
||||
assert(bool(gl_SubgroupLtMask[bit / 32] & (1u << (bit & 31))) == (bit < subgroup_id));
|
||||
assert(bool(gl_SubgroupLeMask[bit / 32] & (1u << (bit & 31))) == (bit <= subgroup_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user