GLSL: Deal with sign in subgroup Min/Max operations.

This commit is contained in:
Hans-Kristian Arntzen 2020-01-09 12:01:54 +01:00
parent 34ba8ea4f2
commit 5253da9e63
4 changed files with 147 additions and 4 deletions

View File

@ -0,0 +1,24 @@
#version 450
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_clustered : require
layout(location = 0) flat in int index;
layout(location = 0) out uint FragColor;
void main()
{
uint _17 = uint(index);
FragColor = uint(subgroupMin(index));
FragColor = uint(subgroupMax(int(_17)));
FragColor = subgroupMin(uint(index));
FragColor = subgroupMax(_17);
FragColor = uint(subgroupInclusiveMax(index));
FragColor = uint(subgroupInclusiveMin(int(_17)));
FragColor = subgroupExclusiveMax(uint(index));
FragColor = subgroupExclusiveMin(_17);
FragColor = uint(subgroupClusteredMin(index, 4u));
FragColor = uint(subgroupClusteredMax(int(_17), 4u));
FragColor = subgroupClusteredMin(uint(index), 4u);
FragColor = subgroupClusteredMax(_17, 4u);
}

View File

@ -0,0 +1,65 @@
; SPIR-V
; Version: 1.3
; Generator: Khronos Glslang Reference Front End; 8
; Bound: 78
; Schema: 0
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformArithmetic
OpCapability GroupNonUniformClustered
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %index %FragColor
OpExecutionMode %main OriginUpperLeft
OpSource GLSL 450
OpSourceExtension "GL_KHR_shader_subgroup_arithmetic"
OpSourceExtension "GL_KHR_shader_subgroup_basic"
OpSourceExtension "GL_KHR_shader_subgroup_clustered"
OpName %main "main"
OpName %index "index"
OpName %FragColor "FragColor"
OpDecorate %index Flat
OpDecorate %index Location 0
OpDecorate %FragColor Location 0
%void = OpTypeVoid
%3 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%uint_0 = OpConstant %uint 0
%int = OpTypeInt 32 1
%_ptr_Input_int = OpTypePointer Input %int
%index = OpVariable %_ptr_Input_int Input
%uint_3 = OpConstant %uint 3
%uint_4 = OpConstant %uint 4
%_ptr_Output_uint = OpTypePointer Output %uint
%FragColor = OpVariable %_ptr_Output_uint Output
%main = OpFunction %void None %3
%5 = OpLabel
%i = OpLoad %int %index
%u = OpBitcast %uint %i
%res0 = OpGroupNonUniformSMin %uint %uint_3 Reduce %i
%res1 = OpGroupNonUniformSMax %uint %uint_3 Reduce %u
%res2 = OpGroupNonUniformUMin %uint %uint_3 Reduce %i
%res3 = OpGroupNonUniformUMax %uint %uint_3 Reduce %u
%res4 = OpGroupNonUniformSMax %uint %uint_3 InclusiveScan %i
%res5 = OpGroupNonUniformSMin %uint %uint_3 InclusiveScan %u
%res6 = OpGroupNonUniformUMax %uint %uint_3 ExclusiveScan %i
%res7 = OpGroupNonUniformUMin %uint %uint_3 ExclusiveScan %u
%res8 = OpGroupNonUniformSMin %uint %uint_3 ClusteredReduce %i %uint_4
%res9 = OpGroupNonUniformSMax %uint %uint_3 ClusteredReduce %u %uint_4
%res10 = OpGroupNonUniformUMin %uint %uint_3 ClusteredReduce %i %uint_4
%res11 = OpGroupNonUniformUMax %uint %uint_3 ClusteredReduce %u %uint_4
OpStore %FragColor %res0
OpStore %FragColor %res1
OpStore %FragColor %res2
OpStore %FragColor %res3
OpStore %FragColor %res4
OpStore %FragColor %res5
OpStore %FragColor %res6
OpStore %FragColor %res7
OpStore %FragColor %res8
OpStore %FragColor %res9
OpStore %FragColor %res10
OpStore %FragColor %res11
OpReturn
OpFunctionEnd

View File

@ -4464,6 +4464,34 @@ void CompilerGLSL::emit_trinary_func_op_cast(uint32_t result_type, uint32_t resu
inherit_expression_dependencies(result_id, op2); inherit_expression_dependencies(result_id, op2);
} }
void CompilerGLSL::emit_binary_func_op_cast_clustered(uint32_t result_type, uint32_t result_id, uint32_t op0,
uint32_t op1, const char *op, SPIRType::BaseType input_type)
{
// Special purpose method for implementing clustered subgroup opcodes.
// Main difference is that op1 does not participate in any casting, it needs to be a literal.
auto &out_type = get<SPIRType>(result_type);
auto expected_type = out_type;
expected_type.basetype = input_type;
string cast_op0 =
expression_type(op0).basetype != input_type ? bitcast_glsl(expected_type, op0) : to_unpacked_expression(op0);
string expr;
if (out_type.basetype != input_type)
{
expr = bitcast_glsl_op(out_type, expected_type);
expr += '(';
expr += join(op, "(", cast_op0, ", ", to_expression(op1), ")");
expr += ')';
}
else
{
expr += join(op, "(", cast_op0, ", ", to_expression(op1), ")");
}
emit_op(result_type, result_id, expr, should_forward(op0));
inherit_expression_dependencies(result_id, op0);
}
void CompilerGLSL::emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, void CompilerGLSL::emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type) const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type)
{ {
@ -6014,6 +6042,11 @@ void CompilerGLSL::emit_subgroup_op(const Instruction &i)
if (!options.vulkan_semantics) if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Can only use subgroup operations in Vulkan semantics."); SPIRV_CROSS_THROW("Can only use subgroup operations in Vulkan semantics.");
// If we need to do implicit bitcasts, make sure we do it with the correct type.
uint32_t integer_width = get_integer_width_for_instruction(i);
auto int_type = to_signed_basetype(integer_width);
auto uint_type = to_unsigned_basetype(integer_width);
switch (op) switch (op)
{ {
case OpGroupNonUniformElect: case OpGroupNonUniformElect:
@ -6185,20 +6218,39 @@ case OpGroupNonUniform##op: \
SPIRV_CROSS_THROW("Invalid group operation."); \ SPIRV_CROSS_THROW("Invalid group operation."); \
break; \ break; \
} }
#define GLSL_GROUP_OP_CAST(op, glsl_op, type) \
case OpGroupNonUniform##op: \
{ \
auto operation = static_cast<GroupOperation>(ops[3]); \
if (operation == GroupOperationReduce) \
emit_unary_func_op_cast(result_type, id, ops[4], "subgroup" #glsl_op, type, type); \
else if (operation == GroupOperationInclusiveScan) \
emit_unary_func_op_cast(result_type, id, ops[4], "subgroupInclusive" #glsl_op, type, type); \
else if (operation == GroupOperationExclusiveScan) \
emit_unary_func_op_cast(result_type, id, ops[4], "subgroupExclusive" #glsl_op, type, type); \
else if (operation == GroupOperationClusteredReduce) \
emit_binary_func_op_cast_clustered(result_type, id, ops[4], ops[5], "subgroupClustered" #glsl_op, type); \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
break; \
}
GLSL_GROUP_OP(FAdd, Add) GLSL_GROUP_OP(FAdd, Add)
GLSL_GROUP_OP(FMul, Mul) GLSL_GROUP_OP(FMul, Mul)
GLSL_GROUP_OP(FMin, Min) GLSL_GROUP_OP(FMin, Min)
GLSL_GROUP_OP(FMax, Max) GLSL_GROUP_OP(FMax, Max)
GLSL_GROUP_OP(IAdd, Add) GLSL_GROUP_OP(IAdd, Add)
GLSL_GROUP_OP(IMul, Mul) GLSL_GROUP_OP(IMul, Mul)
GLSL_GROUP_OP(SMin, Min) GLSL_GROUP_OP_CAST(SMin, Min, int_type)
GLSL_GROUP_OP(SMax, Max) GLSL_GROUP_OP_CAST(SMax, Max, int_type)
GLSL_GROUP_OP(UMin, Min) GLSL_GROUP_OP_CAST(UMin, Min, uint_type)
GLSL_GROUP_OP(UMax, Max) GLSL_GROUP_OP_CAST(UMax, Max, uint_type)
GLSL_GROUP_OP(BitwiseAnd, And) GLSL_GROUP_OP(BitwiseAnd, And)
GLSL_GROUP_OP(BitwiseOr, Or) GLSL_GROUP_OP(BitwiseOr, Or)
GLSL_GROUP_OP(BitwiseXor, Xor) GLSL_GROUP_OP(BitwiseXor, Xor)
#undef GLSL_GROUP_OP #undef GLSL_GROUP_OP
#undef GLSL_GROUP_OP_CAST
// clang-format on // clang-format on
case OpGroupNonUniformQuadSwap: case OpGroupNonUniformQuadSwap:

View File

@ -467,6 +467,8 @@ protected:
SPIRType::BaseType input_type, SPIRType::BaseType expected_result_type); SPIRType::BaseType input_type, SPIRType::BaseType expected_result_type);
void emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op, void emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op,
SPIRType::BaseType input_type, bool skip_cast_if_equal_type); SPIRType::BaseType input_type, bool skip_cast_if_equal_type);
void emit_binary_func_op_cast_clustered(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
const char *op, SPIRType::BaseType input_type);
void emit_trinary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2, void emit_trinary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2,
const char *op, SPIRType::BaseType input_type); const char *op, SPIRType::BaseType input_type);
void emit_trinary_func_op_bitextract(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, void emit_trinary_func_op_bitextract(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,