GLSL: Deal with sign in subgroup Min/Max operations.
This commit is contained in:
parent
34ba8ea4f2
commit
5253da9e63
@ -0,0 +1,24 @@
|
||||
#version 450
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||
#extension GL_KHR_shader_subgroup_clustered : require
|
||||
|
||||
layout(location = 0) flat in int index;
|
||||
layout(location = 0) out uint FragColor;
|
||||
|
||||
// Reference fragment shader body: applies every subgroup Min/Max flavour
// (reduce, inclusive scan, exclusive scan, clustered reduce) to the flat
// `index` input, once per signedness combination. Each statement overwrites
// FragColor, so only the final assignment remains in the output variable.
void main()
|
||||
{
|
||||
// Unsigned copy of the signed input, reused by the statements below.
uint _17 = uint(index);
|
||||
// Plain reductions: the first two are evaluated on int and converted to uint
// for the store, the next two are evaluated on uint directly.
FragColor = uint(subgroupMin(index));
|
||||
FragColor = uint(subgroupMax(int(_17)));
|
||||
FragColor = subgroupMin(uint(index));
|
||||
FragColor = subgroupMax(_17);
|
||||
// Inclusive scans evaluated on int; the int result is converted to uint.
FragColor = uint(subgroupInclusiveMax(index));
|
||||
FragColor = uint(subgroupInclusiveMin(int(_17)));
|
||||
// Exclusive scans evaluated on uint; no conversion of the result is needed.
FragColor = subgroupExclusiveMax(uint(index));
|
||||
FragColor = subgroupExclusiveMin(_17);
|
||||
// Clustered reductions: only the first argument is ever wrapped in a
// conversion, the second argument stays the literal 4u.
FragColor = uint(subgroupClusteredMin(index, 4u));
|
||||
FragColor = uint(subgroupClusteredMax(int(_17), 4u));
|
||||
FragColor = subgroupClusteredMin(uint(index), 4u);
|
||||
FragColor = subgroupClusteredMax(_17, 4u);
|
||||
}
|
||||
|
@ -0,0 +1,65 @@
|
||||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Khronos Glslang Reference Front End; 8
|
||||
; Bound: 78
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpCapability GroupNonUniform
|
||||
OpCapability GroupNonUniformArithmetic
|
||||
OpCapability GroupNonUniformClustered
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %main "main" %index %FragColor
|
||||
OpExecutionMode %main OriginUpperLeft
|
||||
OpSource GLSL 450
|
||||
OpSourceExtension "GL_KHR_shader_subgroup_arithmetic"
|
||||
OpSourceExtension "GL_KHR_shader_subgroup_basic"
|
||||
OpSourceExtension "GL_KHR_shader_subgroup_clustered"
|
||||
OpName %main "main"
|
||||
OpName %index "index"
|
||||
OpName %FragColor "FragColor"
|
||||
OpDecorate %index Flat
|
||||
OpDecorate %index Location 0
|
||||
OpDecorate %FragColor Location 0
|
||||
%void = OpTypeVoid
|
||||
%3 = OpTypeFunction %void
|
||||
%uint = OpTypeInt 32 0
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%int = OpTypeInt 32 1
|
||||
%_ptr_Input_int = OpTypePointer Input %int
|
||||
%index = OpVariable %_ptr_Input_int Input
|
||||
%uint_3 = OpConstant %uint 3
|
||||
%uint_4 = OpConstant %uint 4
|
||||
%_ptr_Output_uint = OpTypePointer Output %uint
|
||||
%FragColor = OpVariable %_ptr_Output_uint Output
|
||||
; Entry point. Feeds a signed value (%i) and its unsigned bitcast (%u) into
; signed and unsigned group Min/Max opcodes while always requesting a %uint
; result, so operand/result signedness does not always match the opcode.
; %uint_3 is the execution scope operand (Subgroup in the SPIR-V Scope
; enumeration).
%main = OpFunction %void None %3
|
||||
%5 = OpLabel
|
||||
; %i is the signed input as loaded, %u is the same bits typed as %uint.
%i = OpLoad %int %index
|
||||
%u = OpBitcast %uint %i
|
||||
; Reduce: signed min/max, then unsigned min/max, alternating %i and %u operands.
%res0 = OpGroupNonUniformSMin %uint %uint_3 Reduce %i
|
||||
%res1 = OpGroupNonUniformSMax %uint %uint_3 Reduce %u
|
||||
%res2 = OpGroupNonUniformUMin %uint %uint_3 Reduce %i
|
||||
%res3 = OpGroupNonUniformUMax %uint %uint_3 Reduce %u
|
||||
; InclusiveScan with the signed opcodes.
%res4 = OpGroupNonUniformSMax %uint %uint_3 InclusiveScan %i
|
||||
%res5 = OpGroupNonUniformSMin %uint %uint_3 InclusiveScan %u
|
||||
; ExclusiveScan with the unsigned opcodes.
%res6 = OpGroupNonUniformUMax %uint %uint_3 ExclusiveScan %i
|
||||
%res7 = OpGroupNonUniformUMin %uint %uint_3 ExclusiveScan %u
|
||||
; ClusteredReduce: the trailing %uint_4 is the extra cluster-size operand.
%res8 = OpGroupNonUniformSMin %uint %uint_3 ClusteredReduce %i %uint_4
|
||||
%res9 = OpGroupNonUniformSMax %uint %uint_3 ClusteredReduce %u %uint_4
|
||||
%res10 = OpGroupNonUniformUMin %uint %uint_3 ClusteredReduce %i %uint_4
|
||||
%res11 = OpGroupNonUniformUMax %uint %uint_3 ClusteredReduce %u %uint_4
|
||||
; Every result is stored to the same output, in order; each store overwrites
; the previous one.
OpStore %FragColor %res0
|
||||
OpStore %FragColor %res1
|
||||
OpStore %FragColor %res2
|
||||
OpStore %FragColor %res3
|
||||
OpStore %FragColor %res4
|
||||
OpStore %FragColor %res5
|
||||
OpStore %FragColor %res6
|
||||
OpStore %FragColor %res7
|
||||
OpStore %FragColor %res8
|
||||
OpStore %FragColor %res9
|
||||
OpStore %FragColor %res10
|
||||
OpStore %FragColor %res11
|
||||
OpReturn
|
||||
OpFunctionEnd
|
@ -4464,6 +4464,34 @@ void CompilerGLSL::emit_trinary_func_op_cast(uint32_t result_type, uint32_t resu
|
||||
inherit_expression_dependencies(result_id, op2);
|
||||
}
|
||||
|
||||
void CompilerGLSL::emit_binary_func_op_cast_clustered(uint32_t result_type, uint32_t result_id, uint32_t op0,
|
||||
uint32_t op1, const char *op, SPIRType::BaseType input_type)
|
||||
{
|
||||
// Special purpose method for implementing clustered subgroup opcodes.
|
||||
// Main difference is that op1 does not participate in any casting, it needs to be a literal.
|
||||
auto &out_type = get<SPIRType>(result_type);
|
||||
auto expected_type = out_type;
|
||||
expected_type.basetype = input_type;
|
||||
string cast_op0 =
|
||||
expression_type(op0).basetype != input_type ? bitcast_glsl(expected_type, op0) : to_unpacked_expression(op0);
|
||||
|
||||
string expr;
|
||||
if (out_type.basetype != input_type)
|
||||
{
|
||||
expr = bitcast_glsl_op(out_type, expected_type);
|
||||
expr += '(';
|
||||
expr += join(op, "(", cast_op0, ", ", to_expression(op1), ")");
|
||||
expr += ')';
|
||||
}
|
||||
else
|
||||
{
|
||||
expr += join(op, "(", cast_op0, ", ", to_expression(op1), ")");
|
||||
}
|
||||
|
||||
emit_op(result_type, result_id, expr, should_forward(op0));
|
||||
inherit_expression_dependencies(result_id, op0);
|
||||
}
|
||||
|
||||
void CompilerGLSL::emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
|
||||
const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type)
|
||||
{
|
||||
@ -6014,6 +6042,11 @@ void CompilerGLSL::emit_subgroup_op(const Instruction &i)
|
||||
if (!options.vulkan_semantics)
|
||||
SPIRV_CROSS_THROW("Can only use subgroup operations in Vulkan semantics.");
|
||||
|
||||
// If we need to do implicit bitcasts, make sure we do it with the correct type.
|
||||
uint32_t integer_width = get_integer_width_for_instruction(i);
|
||||
auto int_type = to_signed_basetype(integer_width);
|
||||
auto uint_type = to_unsigned_basetype(integer_width);
|
||||
|
||||
switch (op)
|
||||
{
|
||||
case OpGroupNonUniformElect:
|
||||
@ -6185,20 +6218,39 @@ case OpGroupNonUniform##op: \
|
||||
SPIRV_CROSS_THROW("Invalid group operation."); \
|
||||
break; \
|
||||
}
|
||||
|
||||
// Case-label generator for sign-sensitive subgroup opcodes. Dispatches on the
// group operation in ops[3]: reduce and both scan flavours go through
// emit_unary_func_op_cast with `type` as both input and expected result type,
// while clustered reduce goes through emit_binary_func_op_cast_clustered with
// ops[5] as the extra operand. Any other group operation throws.
#define GLSL_GROUP_OP_CAST(op, glsl_op, type) \
case OpGroupNonUniform##op: \
	{ \
		switch (static_cast<GroupOperation>(ops[3])) \
		{ \
		case GroupOperationReduce: \
			emit_unary_func_op_cast(result_type, id, ops[4], "subgroup" #glsl_op, type, type); \
			break; \
		case GroupOperationInclusiveScan: \
			emit_unary_func_op_cast(result_type, id, ops[4], "subgroupInclusive" #glsl_op, type, type); \
			break; \
		case GroupOperationExclusiveScan: \
			emit_unary_func_op_cast(result_type, id, ops[4], "subgroupExclusive" #glsl_op, type, type); \
			break; \
		case GroupOperationClusteredReduce: \
			emit_binary_func_op_cast_clustered(result_type, id, ops[4], ops[5], "subgroupClustered" #glsl_op, type); \
			break; \
		default: \
			SPIRV_CROSS_THROW("Invalid group operation."); \
		} \
		break; \
	}
|
||||
|
||||
GLSL_GROUP_OP(FAdd, Add)
|
||||
GLSL_GROUP_OP(FMul, Mul)
|
||||
GLSL_GROUP_OP(FMin, Min)
|
||||
GLSL_GROUP_OP(FMax, Max)
|
||||
GLSL_GROUP_OP(IAdd, Add)
|
||||
GLSL_GROUP_OP(IMul, Mul)
|
||||
GLSL_GROUP_OP(SMin, Min)
|
||||
GLSL_GROUP_OP(SMax, Max)
|
||||
GLSL_GROUP_OP(UMin, Min)
|
||||
GLSL_GROUP_OP(UMax, Max)
|
||||
GLSL_GROUP_OP_CAST(SMin, Min, int_type)
|
||||
GLSL_GROUP_OP_CAST(SMax, Max, int_type)
|
||||
GLSL_GROUP_OP_CAST(UMin, Min, uint_type)
|
||||
GLSL_GROUP_OP_CAST(UMax, Max, uint_type)
|
||||
GLSL_GROUP_OP(BitwiseAnd, And)
|
||||
GLSL_GROUP_OP(BitwiseOr, Or)
|
||||
GLSL_GROUP_OP(BitwiseXor, Xor)
|
||||
#undef GLSL_GROUP_OP
|
||||
#undef GLSL_GROUP_OP_CAST
|
||||
// clang-format on
|
||||
|
||||
case OpGroupNonUniformQuadSwap:
|
||||
|
@ -467,6 +467,8 @@ protected:
|
||||
SPIRType::BaseType input_type, SPIRType::BaseType expected_result_type);
|
||||
void emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op,
|
||||
SPIRType::BaseType input_type, bool skip_cast_if_equal_type);
|
||||
// Emits a two-argument call `op(op0, op1)` for clustered subgroup opcodes.
// op0 is bitcast to input_type when its base type differs; op1 is emitted
// without any cast. The call is wrapped in a bitcast to result_type's base
// type when that differs from input_type.
void emit_binary_func_op_cast_clustered(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
|
||||
const char *op, SPIRType::BaseType input_type);
|
||||
void emit_trinary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2,
|
||||
const char *op, SPIRType::BaseType input_type);
|
||||
void emit_trinary_func_op_bitextract(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
|
||||
|
Loading…
Reference in New Issue
Block a user