HLSL: Support logical subgroup ops.
This commit is contained in:
parent
5570043af3
commit
d6c2c1b39a
@ -0,0 +1,28 @@
|
||||
static const uint3 gl_WorkGroupSize = uint3(30u, 1u, 1u);
|
||||
|
||||
RWByteAddressBuffer _46 : register(u0, space0);
|
||||
|
||||
static uint3 gl_GlobalInvocationID;
|
||||
struct SPIRV_Cross_Input
|
||||
{
|
||||
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
|
||||
};
|
||||
|
||||
void comp_main()
|
||||
{
|
||||
bool v = gl_GlobalInvocationID.x != 3u;
|
||||
bool4 v4;
|
||||
v4.x = bool(WaveActiveBitOr(uint(v)));
|
||||
v4.y = bool(WaveActiveBitAnd(uint(v)));
|
||||
v4.z = bool(WaveActiveBitXor(uint(v)));
|
||||
v4.w = WaveActiveAllEqual(v);
|
||||
uint4 w = uint4(v4.x ? uint4(1u, 1u, 1u, 1u).x : uint4(0u, 0u, 0u, 0u).x, v4.y ? uint4(1u, 1u, 1u, 1u).y : uint4(0u, 0u, 0u, 0u).y, v4.z ? uint4(1u, 1u, 1u, 1u).z : uint4(0u, 0u, 0u, 0u).z, v4.w ? uint4(1u, 1u, 1u, 1u).w : uint4(0u, 0u, 0u, 0u).w);
|
||||
_46.Store(gl_GlobalInvocationID.x * 4 + 0, ((w.x + w.y) + w.z) + w.w);
|
||||
}
|
||||
|
||||
[numthreads(30, 1, 1)]
|
||||
void main(SPIRV_Cross_Input stage_input)
|
||||
{
|
||||
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
|
||||
comp_main();
|
||||
}
|
@ -23,7 +23,7 @@ void comp_main()
|
||||
uint bit_count = countbits(ballot_value.x) + countbits(ballot_value.y) + countbits(ballot_value.z) + countbits(ballot_value.w);
|
||||
bool has_all = WaveActiveAllTrue(true);
|
||||
bool has_any = WaveActiveAnyTrue(true);
|
||||
bool has_equal = WaveActiveAllEqualBool(true);
|
||||
bool has_equal = WaveActiveAllEqual(true);
|
||||
float4 added = WaveActiveSum(20.0f.xxxx);
|
||||
int4 iadded = WaveActiveSum(int4(20, 20, 20, 20));
|
||||
float4 multiplied = WaveActiveProduct(20.0f.xxxx);
|
||||
@ -37,6 +37,9 @@ void comp_main()
|
||||
uint4 anded = WaveActiveBitAnd(ballot_value);
|
||||
uint4 ored = WaveActiveBitOr(ballot_value);
|
||||
uint4 xored = WaveActiveBitXor(ballot_value);
|
||||
bool4 anded_b = bool4(WaveActiveBitAnd(uint4(bool4(ballot_value.x == uint4(42u, 42u, 42u, 42u).x, ballot_value.y == uint4(42u, 42u, 42u, 42u).y, ballot_value.z == uint4(42u, 42u, 42u, 42u).z, ballot_value.w == uint4(42u, 42u, 42u, 42u).w))));
|
||||
bool4 ored_b = bool4(WaveActiveBitOr(uint4(bool4(ballot_value.x == uint4(42u, 42u, 42u, 42u).x, ballot_value.y == uint4(42u, 42u, 42u, 42u).y, ballot_value.z == uint4(42u, 42u, 42u, 42u).z, ballot_value.w == uint4(42u, 42u, 42u, 42u).w))));
|
||||
bool4 xored_b = bool4(WaveActiveBitXor(uint4(bool4(ballot_value.x == uint4(42u, 42u, 42u, 42u).x, ballot_value.y == uint4(42u, 42u, 42u, 42u).y, ballot_value.z == uint4(42u, 42u, 42u, 42u).z, ballot_value.w == uint4(42u, 42u, 42u, 42u).w))));
|
||||
added = WavePrefixSum(added) + added;
|
||||
iadded = WavePrefixSum(iadded) + iadded;
|
||||
multiplied = WavePrefixProduct(multiplied) * multiplied;
|
||||
|
@ -0,0 +1,30 @@
|
||||
#version 450
|
||||
#extension GL_KHR_shader_subgroup_basic : require
|
||||
#extension GL_KHR_shader_subgroup_ballot : require
|
||||
#extension GL_KHR_shader_subgroup_vote : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle_relative : require
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||
#extension GL_KHR_shader_subgroup_clustered : require
|
||||
#extension GL_KHR_shader_subgroup_quad : require
|
||||
layout(local_size_x = 30) in;
|
||||
|
||||
layout(std430, binding = 0) buffer SSBO
|
||||
{
|
||||
uint FragColor[];
|
||||
};
|
||||
|
||||
void main()
|
||||
{
|
||||
bool v = gl_GlobalInvocationID.x != 3;
|
||||
bvec4 v4;
|
||||
v4.x = subgroupOr(v);
|
||||
v4.y = subgroupAnd(v);
|
||||
v4.z = subgroupXor(v);
|
||||
v4.w = subgroupAllEqual(v);
|
||||
|
||||
uvec4 w = uvec4(v4);
|
||||
FragColor[gl_GlobalInvocationID.x] = w.x + w.y + w.z + w.w;
|
||||
}
|
||||
|
||||
|
@ -72,6 +72,9 @@ void main()
|
||||
uvec4 anded = subgroupAnd(ballot_value);
|
||||
uvec4 ored = subgroupOr(ballot_value);
|
||||
uvec4 xored = subgroupXor(ballot_value);
|
||||
bvec4 anded_b = subgroupAnd(equal(ballot_value, uvec4(42)));
|
||||
bvec4 ored_b = subgroupOr(equal(ballot_value, uvec4(42)));
|
||||
bvec4 xored_b = subgroupXor(equal(ballot_value, uvec4(42)));
|
||||
|
||||
added = subgroupInclusiveAdd(added);
|
||||
iadded = subgroupInclusiveAdd(iadded);
|
||||
@ -121,6 +124,10 @@ void main()
|
||||
anded = subgroupClusteredAnd(anded, 4u);
|
||||
ored = subgroupClusteredOr(ored, 4u);
|
||||
xored = subgroupClusteredXor(xored, 4u);
|
||||
|
||||
anded_b = subgroupClusteredAnd(equal(anded, uvec4(2u)), 4u);
|
||||
ored_b = subgroupClusteredOr(equal(ored, uvec4(3u)), 4u);
|
||||
xored_b = subgroupClusteredXor(equal(xored, uvec4(4u)), 4u);
|
||||
#endif
|
||||
|
||||
// quad
|
||||
|
@ -5700,13 +5700,26 @@ void CompilerGLSL::emit_unary_func_op_cast(uint32_t result_type, uint32_t result
|
||||
// Bit-widths might be different in unary cases because we use it for SConvert/UConvert and friends.
|
||||
expected_type.basetype = input_type;
|
||||
expected_type.width = expr_type.width;
|
||||
string cast_op = expr_type.basetype != input_type ? bitcast_glsl(expected_type, op0) : to_unpacked_expression(op0);
|
||||
|
||||
string cast_op;
|
||||
if (expr_type.basetype != input_type)
|
||||
{
|
||||
if (expr_type.basetype == SPIRType::Boolean)
|
||||
cast_op = join(type_to_glsl(expected_type), "(", to_unpacked_expression(op0), ")");
|
||||
else
|
||||
cast_op = bitcast_glsl(expected_type, op0);
|
||||
}
|
||||
else
|
||||
cast_op = to_unpacked_expression(op0);
|
||||
|
||||
string expr;
|
||||
if (out_type.basetype != expected_result_type)
|
||||
{
|
||||
expected_type.basetype = expected_result_type;
|
||||
expected_type.width = out_type.width;
|
||||
if (out_type.basetype == SPIRType::Boolean)
|
||||
expr = type_to_glsl(out_type);
|
||||
else
|
||||
expr = bitcast_glsl_op(out_type, expected_type);
|
||||
expr += '(';
|
||||
expr += join(op, "(", cast_op, ")");
|
||||
|
@ -4635,12 +4635,8 @@ void CompilerHLSL::emit_subgroup_op(const Instruction &i)
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformAllEqual:
|
||||
{
|
||||
auto &type = get<SPIRType>(result_type);
|
||||
emit_unary_func_op(result_type, id, ops[3],
|
||||
type.basetype == SPIRType::Boolean ? "WaveActiveAllEqualBool" : "WaveActiveAllEqual");
|
||||
emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllEqual");
|
||||
break;
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
|
||||
@ -4688,6 +4684,9 @@ case OpGroupNonUniform##op: \
|
||||
HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
|
||||
HLSL_GROUP_OP(BitwiseOr, BitOr, false)
|
||||
HLSL_GROUP_OP(BitwiseXor, BitXor, false)
|
||||
HLSL_GROUP_OP_CAST(LogicalAnd, BitAnd, uint_type)
|
||||
HLSL_GROUP_OP_CAST(LogicalOr, BitOr, uint_type)
|
||||
HLSL_GROUP_OP_CAST(LogicalXor, BitXor, uint_type)
|
||||
|
||||
#undef HLSL_GROUP_OP
|
||||
#undef HLSL_GROUP_OP_CAST
|
||||
|
Loading…
Reference in New Issue
Block a user