HLSL: Support logical subgroup ops.

This commit is contained in:
Hans-Kristian Arntzen 2021-03-08 12:52:03 +01:00
parent 5570043af3
commit d6c2c1b39a
6 changed files with 88 additions and 8 deletions

View File

@ -0,0 +1,28 @@
// Workgroup size constant; its value matches the [numthreads(30, 1, 1)] attribute on main().
static const uint3 gl_WorkGroupSize = uint3(30u, 1u, 1u);
// Read/write byte-address buffer bound to register u0, space0; comp_main() stores its result here.
RWByteAddressBuffer _46 : register(u0, space0);
// File-scope copy of the dispatch thread ID, filled in by main() before comp_main() runs.
static uint3 gl_GlobalInvocationID;
// Entry-point input structure: receives the SV_DispatchThreadID system value.
struct SPIRV_Cross_Input
{
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
};
// Shader body: reduces a per-invocation boolean across the wave four different ways
// and stores the number of reductions that came out true.
void comp_main()
{
// True for every invocation except the one whose global X index is 3.
bool v = gl_GlobalInvocationID.x != 3u;
bool4 v4;
// The boolean is widened to uint, reduced with the bitwise wave op, then narrowed back to bool.
v4.x = bool(WaveActiveBitOr(uint(v)));
v4.y = bool(WaveActiveBitAnd(uint(v)));
v4.z = bool(WaveActiveBitXor(uint(v)));
// WaveActiveAllEqual is called on the bool directly, with no uint round trip.
v4.w = WaveActiveAllEqual(v);
// Component-wise select: each true lane of v4 becomes 1u, each false lane 0u.
uint4 w = uint4(v4.x ? uint4(1u, 1u, 1u, 1u).x : uint4(0u, 0u, 0u, 0u).x, v4.y ? uint4(1u, 1u, 1u, 1u).y : uint4(0u, 0u, 0u, 0u).y, v4.z ? uint4(1u, 1u, 1u, 1u).z : uint4(0u, 0u, 0u, 0u).z, v4.w ? uint4(1u, 1u, 1u, 1u).w : uint4(0u, 0u, 0u, 0u).w);
// Store the sum of the four flags at byte offset (global X index * 4), i.e. one 32-bit slot per invocation.
_46.Store(gl_GlobalInvocationID.x * 4 + 0, ((w.x + w.y) + w.z) + w.w);
}
// Compute entry point with a 30x1x1 thread group: copies the dispatch thread ID
// from the stage input into the file-scope variable, then runs comp_main().
[numthreads(30, 1, 1)]
void main(SPIRV_Cross_Input stage_input)
{
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
comp_main();
}

View File

@ -23,7 +23,7 @@ void comp_main()
uint bit_count = countbits(ballot_value.x) + countbits(ballot_value.y) + countbits(ballot_value.z) + countbits(ballot_value.w);
bool has_all = WaveActiveAllTrue(true);
bool has_any = WaveActiveAnyTrue(true);
bool has_equal = WaveActiveAllEqualBool(true);
bool has_equal = WaveActiveAllEqual(true);
float4 added = WaveActiveSum(20.0f.xxxx);
int4 iadded = WaveActiveSum(int4(20, 20, 20, 20));
float4 multiplied = WaveActiveProduct(20.0f.xxxx);
@ -37,6 +37,9 @@ void comp_main()
uint4 anded = WaveActiveBitAnd(ballot_value);
uint4 ored = WaveActiveBitOr(ballot_value);
uint4 xored = WaveActiveBitXor(ballot_value);
bool4 anded_b = bool4(WaveActiveBitAnd(uint4(bool4(ballot_value.x == uint4(42u, 42u, 42u, 42u).x, ballot_value.y == uint4(42u, 42u, 42u, 42u).y, ballot_value.z == uint4(42u, 42u, 42u, 42u).z, ballot_value.w == uint4(42u, 42u, 42u, 42u).w))));
bool4 ored_b = bool4(WaveActiveBitOr(uint4(bool4(ballot_value.x == uint4(42u, 42u, 42u, 42u).x, ballot_value.y == uint4(42u, 42u, 42u, 42u).y, ballot_value.z == uint4(42u, 42u, 42u, 42u).z, ballot_value.w == uint4(42u, 42u, 42u, 42u).w))));
bool4 xored_b = bool4(WaveActiveBitXor(uint4(bool4(ballot_value.x == uint4(42u, 42u, 42u, 42u).x, ballot_value.y == uint4(42u, 42u, 42u, 42u).y, ballot_value.z == uint4(42u, 42u, 42u, 42u).z, ballot_value.w == uint4(42u, 42u, 42u, 42u).w))));
added = WavePrefixSum(added) + added;
iadded = WavePrefixSum(iadded) + iadded;
multiplied = WavePrefixProduct(multiplied) * multiplied;

View File

@ -0,0 +1,30 @@
#version 450
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_clustered : require
#extension GL_KHR_shader_subgroup_quad : require
// Workgroup of 30 invocations along X.
layout(local_size_x = 30) in;
// Output storage buffer at binding 0; main() writes one uint per global invocation.
layout(std430, binding = 0) buffer SSBO
{
uint FragColor[];
};
// Exercises the subgroup Or/And/Xor reductions and subgroupAllEqual on a scalar bool,
// then writes how many of the four results were true.
void main()
{
// True for every invocation except the one whose global X index is 3.
bool v = gl_GlobalInvocationID.x != 3;
bvec4 v4;
v4.x = subgroupOr(v);
v4.y = subgroupAnd(v);
v4.z = subgroupXor(v);
v4.w = subgroupAllEqual(v);
// bvec4 -> uvec4 conversion so the four flags can be summed.
uvec4 w = uvec4(v4);
// One result slot per global invocation, indexed by the global X index.
FragColor[gl_GlobalInvocationID.x] = w.x + w.y + w.z + w.w;
}

View File

@ -72,6 +72,9 @@ void main()
uvec4 anded = subgroupAnd(ballot_value);
uvec4 ored = subgroupOr(ballot_value);
uvec4 xored = subgroupXor(ballot_value);
bvec4 anded_b = subgroupAnd(equal(ballot_value, uvec4(42)));
bvec4 ored_b = subgroupOr(equal(ballot_value, uvec4(42)));
bvec4 xored_b = subgroupXor(equal(ballot_value, uvec4(42)));
added = subgroupInclusiveAdd(added);
iadded = subgroupInclusiveAdd(iadded);
@ -121,6 +124,10 @@ void main()
anded = subgroupClusteredAnd(anded, 4u);
ored = subgroupClusteredOr(ored, 4u);
xored = subgroupClusteredXor(xored, 4u);
anded_b = subgroupClusteredAnd(equal(anded, uvec4(2u)), 4u);
ored_b = subgroupClusteredOr(equal(ored, uvec4(3u)), 4u);
xored_b = subgroupClusteredXor(equal(xored, uvec4(4u)), 4u);
#endif
// quad

View File

@ -5700,14 +5700,27 @@ void CompilerGLSL::emit_unary_func_op_cast(uint32_t result_type, uint32_t result
// Bit-widths might be different in unary cases because we use it for SConvert/UConvert and friends.
expected_type.basetype = input_type;
expected_type.width = expr_type.width;
string cast_op = expr_type.basetype != input_type ? bitcast_glsl(expected_type, op0) : to_unpacked_expression(op0);
string cast_op;
if (expr_type.basetype != input_type)
{
if (expr_type.basetype == SPIRType::Boolean)
cast_op = join(type_to_glsl(expected_type), "(", to_unpacked_expression(op0), ")");
else
cast_op = bitcast_glsl(expected_type, op0);
}
else
cast_op = to_unpacked_expression(op0);
string expr;
if (out_type.basetype != expected_result_type)
{
expected_type.basetype = expected_result_type;
expected_type.width = out_type.width;
expr = bitcast_glsl_op(out_type, expected_type);
if (out_type.basetype == SPIRType::Boolean)
expr = type_to_glsl(out_type);
else
expr = bitcast_glsl_op(out_type, expected_type);
expr += '(';
expr += join(op, "(", cast_op, ")");
expr += ')';

View File

@ -4635,12 +4635,8 @@ void CompilerHLSL::emit_subgroup_op(const Instruction &i)
break;
case OpGroupNonUniformAllEqual:
{
auto &type = get<SPIRType>(result_type);
emit_unary_func_op(result_type, id, ops[3],
type.basetype == SPIRType::Boolean ? "WaveActiveAllEqualBool" : "WaveActiveAllEqual");
emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllEqual");
break;
}
// clang-format off
#define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
@ -4688,6 +4684,9 @@ case OpGroupNonUniform##op: \
HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
HLSL_GROUP_OP(BitwiseOr, BitOr, false)
HLSL_GROUP_OP(BitwiseXor, BitXor, false)
HLSL_GROUP_OP_CAST(LogicalAnd, BitAnd, uint_type)
HLSL_GROUP_OP_CAST(LogicalOr, BitOr, uint_type)
HLSL_GROUP_OP_CAST(LogicalXor, BitXor, uint_type)
#undef HLSL_GROUP_OP
#undef HLSL_GROUP_OP_CAST