From 5253da9e6352d8cc34b42312a029cd3c849c50b3 Mon Sep 17 00:00:00 2001
From: Hans-Kristian Arntzen
Date: Thu, 9 Jan 2020 12:01:54 +0100
Subject: [PATCH] GLSL: Deal with sign in subgroup Min/Max operations.

---
 ...up-arithmetic-cast.nocompat.vk.asm.frag.vk | 24 +++++++
 ...group-arithmetic-cast.nocompat.vk.asm.frag | 65 +++++++++++++++++++
 spirv_glsl.cpp                                | 60 +++++++++++++++--
 spirv_glsl.hpp                                |  2 +
 4 files changed, 147 insertions(+), 4 deletions(-)
 create mode 100644 reference/shaders-no-opt/asm/frag/subgroup-arithmetic-cast.nocompat.vk.asm.frag.vk
 create mode 100644 shaders-no-opt/asm/frag/subgroup-arithmetic-cast.nocompat.vk.asm.frag

diff --git a/reference/shaders-no-opt/asm/frag/subgroup-arithmetic-cast.nocompat.vk.asm.frag.vk b/reference/shaders-no-opt/asm/frag/subgroup-arithmetic-cast.nocompat.vk.asm.frag.vk
new file mode 100644
index 00000000..130cab7d
--- /dev/null
+++ b/reference/shaders-no-opt/asm/frag/subgroup-arithmetic-cast.nocompat.vk.asm.frag.vk
@@ -0,0 +1,24 @@
+#version 450
+#extension GL_KHR_shader_subgroup_arithmetic : require
+#extension GL_KHR_shader_subgroup_clustered : require
+
+layout(location = 0) flat in int index;
+layout(location = 0) out uint FragColor;
+
+void main()
+{
+    uint _17 = uint(index);
+    FragColor = uint(subgroupMin(index));
+    FragColor = uint(subgroupMax(int(_17)));
+    FragColor = subgroupMin(uint(index));
+    FragColor = subgroupMax(_17);
+    FragColor = uint(subgroupInclusiveMax(index));
+    FragColor = uint(subgroupInclusiveMin(int(_17)));
+    FragColor = subgroupExclusiveMax(uint(index));
+    FragColor = subgroupExclusiveMin(_17);
+    FragColor = uint(subgroupClusteredMin(index, 4u));
+    FragColor = uint(subgroupClusteredMax(int(_17), 4u));
+    FragColor = subgroupClusteredMin(uint(index), 4u);
+    FragColor = subgroupClusteredMax(_17, 4u);
+}
+
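Note on the reference output above: GLSL picks the signed or unsigned flavour of subgroupMin()/subgroupMax() purely from the argument type, so whenever the SPIR-V opcode's signedness disagrees with the operand or result type, a bitcast has to be emitted on either side of the call. A minimal sketch of why the signedness matters (this shader and its values are illustrative, not part of the patch); assume one invocation in the subgroup sees index == -1 and another sees index == 2:

    #version 450
    #extension GL_KHR_shader_subgroup_arithmetic : require

    layout(location = 0) flat in int index;
    layout(location = 0) out uint FragColor;

    void main()
    {
        // Signed reduction: -1 < 2, so the subgroup minimum is -1
        // (0xFFFFFFFFu once cast to uint).
        uint smin = uint(subgroupMin(index));
        // Unsigned reduction: -1 reinterprets as 0xFFFFFFFFu, so 2u wins instead.
        uint umin = subgroupMin(uint(index));
        FragColor = smin ^ umin; // nonzero whenever the two flavours disagree
    }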
diff --git a/shaders-no-opt/asm/frag/subgroup-arithmetic-cast.nocompat.vk.asm.frag b/shaders-no-opt/asm/frag/subgroup-arithmetic-cast.nocompat.vk.asm.frag
new file mode 100644
index 00000000..a47c6b78
--- /dev/null
+++ b/shaders-no-opt/asm/frag/subgroup-arithmetic-cast.nocompat.vk.asm.frag
@@ -0,0 +1,65 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Khronos Glslang Reference Front End; 8
+; Bound: 78
+; Schema: 0
+               OpCapability Shader
+               OpCapability GroupNonUniform
+               OpCapability GroupNonUniformArithmetic
+               OpCapability GroupNonUniformClustered
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %index %FragColor
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 450
+               OpSourceExtension "GL_KHR_shader_subgroup_arithmetic"
+               OpSourceExtension "GL_KHR_shader_subgroup_basic"
+               OpSourceExtension "GL_KHR_shader_subgroup_clustered"
+               OpName %main "main"
+               OpName %index "index"
+               OpName %FragColor "FragColor"
+               OpDecorate %index Flat
+               OpDecorate %index Location 0
+               OpDecorate %FragColor Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+     %uint_0 = OpConstant %uint 0
+        %int = OpTypeInt 32 1
+%_ptr_Input_int = OpTypePointer Input %int
+      %index = OpVariable %_ptr_Input_int Input
+     %uint_3 = OpConstant %uint 3
+     %uint_4 = OpConstant %uint 4
+%_ptr_Output_uint = OpTypePointer Output %uint
+  %FragColor = OpVariable %_ptr_Output_uint Output
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+          %i = OpLoad %int %index
+          %u = OpBitcast %uint %i
+       %res0 = OpGroupNonUniformSMin %uint %uint_3 Reduce %i
+       %res1 = OpGroupNonUniformSMax %uint %uint_3 Reduce %u
+       %res2 = OpGroupNonUniformUMin %uint %uint_3 Reduce %i
+       %res3 = OpGroupNonUniformUMax %uint %uint_3 Reduce %u
+       %res4 = OpGroupNonUniformSMax %uint %uint_3 InclusiveScan %i
+       %res5 = OpGroupNonUniformSMin %uint %uint_3 InclusiveScan %u
+       %res6 = OpGroupNonUniformUMax %uint %uint_3 ExclusiveScan %i
+       %res7 = OpGroupNonUniformUMin %uint %uint_3 ExclusiveScan %u
+       %res8 = OpGroupNonUniformSMin %uint %uint_3 ClusteredReduce %i %uint_4
+       %res9 = OpGroupNonUniformSMax %uint %uint_3 ClusteredReduce %u %uint_4
+      %res10 = OpGroupNonUniformUMin %uint %uint_3 ClusteredReduce %i %uint_4
+      %res11 = OpGroupNonUniformUMax %uint %uint_3 ClusteredReduce %u %uint_4
+               OpStore %FragColor %res0
+               OpStore %FragColor %res1
+               OpStore %FragColor %res2
+               OpStore %FragColor %res3
+               OpStore %FragColor %res4
+               OpStore %FragColor %res5
+               OpStore %FragColor %res6
+               OpStore %FragColor %res7
+               OpStore %FragColor %res8
+               OpStore %FragColor %res9
+               OpStore %FragColor %res10
+               OpStore %FragColor %res11
+               OpReturn
+               OpFunctionEnd
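The test above intentionally mismatches opcode signedness and operand/result types: the signed opcodes (OpGroupNonUniformSMin/SMax) are variously fed the uint %u while always returning %uint, and the unsigned opcodes (UMin/UMax) are fed the int %i. For the ClusteredReduce cases, only the value operand participates in the casting; the cluster size must stay an uncast literal argument, which is what the new emit_binary_func_op_cast_clustered() path below handles. A sketch of the expected translation for one such instruction, mirroring the reference output earlier in the patch (the variable names are illustrative):

    // %r = OpGroupNonUniformSMax %uint %uint_3 ClusteredReduce %u %uint_4
    // The uint operand is bitcast to int to select the signed overload,
    // the cluster size 4u passes through untouched, and the signed result
    // is bitcast back to the uint result type:
    uint r = uint(subgroupClusteredMax(int(u), 4u));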
diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp
index 20cb624a..99cf57c6 100644
--- a/spirv_glsl.cpp
+++ b/spirv_glsl.cpp
@@ -4464,6 +4464,34 @@ void CompilerGLSL::emit_trinary_func_op_cast(uint32_t result_type, uint32_t resu
 	inherit_expression_dependencies(result_id, op2);
 }
 
+void CompilerGLSL::emit_binary_func_op_cast_clustered(uint32_t result_type, uint32_t result_id, uint32_t op0,
+                                                      uint32_t op1, const char *op, SPIRType::BaseType input_type)
+{
+	// Special purpose method for implementing clustered subgroup opcodes.
+	// Main difference is that op1 does not participate in any casting; it must remain a literal.
+	auto &out_type = get<SPIRType>(result_type);
+	auto expected_type = out_type;
+	expected_type.basetype = input_type;
+	string cast_op0 =
+	    expression_type(op0).basetype != input_type ? bitcast_glsl(expected_type, op0) : to_unpacked_expression(op0);
+
+	string expr;
+	if (out_type.basetype != input_type)
+	{
+		expr = bitcast_glsl_op(out_type, expected_type);
+		expr += '(';
+		expr += join(op, "(", cast_op0, ", ", to_expression(op1), ")");
+		expr += ')';
+	}
+	else
+	{
+		expr += join(op, "(", cast_op0, ", ", to_expression(op1), ")");
+	}
+
+	emit_op(result_type, result_id, expr, should_forward(op0));
+	inherit_expression_dependencies(result_id, op0);
+}
+
 void CompilerGLSL::emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
                                             const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type)
 {
@@ -6014,6 +6042,11 @@ void CompilerGLSL::emit_subgroup_op(const Instruction &i)
 	if (!options.vulkan_semantics)
 		SPIRV_CROSS_THROW("Can only use subgroup operations in Vulkan semantics.");
 
+	// If we need to do implicit bitcasts, make sure we do it with the correct type.
+	uint32_t integer_width = get_integer_width_for_instruction(i);
+	auto int_type = to_signed_basetype(integer_width);
+	auto uint_type = to_unsigned_basetype(integer_width);
+
 	switch (op)
 	{
 	case OpGroupNonUniformElect:
@@ -6185,20 +6218,39 @@ case OpGroupNonUniform##op: \
 			SPIRV_CROSS_THROW("Invalid group operation."); \
 		break; \
 	}
+
+#define GLSL_GROUP_OP_CAST(op, glsl_op, type) \
+case OpGroupNonUniform##op: \
+	{ \
+		auto operation = static_cast<GroupOperation>(ops[3]); \
+		if (operation == GroupOperationReduce) \
+			emit_unary_func_op_cast(result_type, id, ops[4], "subgroup" #glsl_op, type, type); \
+		else if (operation == GroupOperationInclusiveScan) \
+			emit_unary_func_op_cast(result_type, id, ops[4], "subgroupInclusive" #glsl_op, type, type); \
+		else if (operation == GroupOperationExclusiveScan) \
+			emit_unary_func_op_cast(result_type, id, ops[4], "subgroupExclusive" #glsl_op, type, type); \
+		else if (operation == GroupOperationClusteredReduce) \
+			emit_binary_func_op_cast_clustered(result_type, id, ops[4], ops[5], "subgroupClustered" #glsl_op, type); \
+		else \
+			SPIRV_CROSS_THROW("Invalid group operation."); \
+		break; \
+	}
+
 	GLSL_GROUP_OP(FAdd, Add)
 	GLSL_GROUP_OP(FMul, Mul)
 	GLSL_GROUP_OP(FMin, Min)
 	GLSL_GROUP_OP(FMax, Max)
 	GLSL_GROUP_OP(IAdd, Add)
 	GLSL_GROUP_OP(IMul, Mul)
-	GLSL_GROUP_OP(SMin, Min)
-	GLSL_GROUP_OP(SMax, Max)
-	GLSL_GROUP_OP(UMin, Min)
-	GLSL_GROUP_OP(UMax, Max)
+	GLSL_GROUP_OP_CAST(SMin, Min, int_type)
+	GLSL_GROUP_OP_CAST(SMax, Max, int_type)
+	GLSL_GROUP_OP_CAST(UMin, Min, uint_type)
+	GLSL_GROUP_OP_CAST(UMax, Max, uint_type)
 	GLSL_GROUP_OP(BitwiseAnd, And)
 	GLSL_GROUP_OP(BitwiseOr, Or)
 	GLSL_GROUP_OP(BitwiseXor, Xor)
 #undef GLSL_GROUP_OP
+#undef GLSL_GROUP_OP_CAST
 		// clang-format on
 
 	case OpGroupNonUniformQuadSwap:
diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp
index 32d8f42c..907f2fc7 100644
--- a/spirv_glsl.hpp
+++ b/spirv_glsl.hpp
@@ -467,6 +467,8 @@ protected:
 	                             SPIRType::BaseType input_type, SPIRType::BaseType expected_result_type);
 	void emit_binary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
 	                              const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type);
+	void emit_binary_func_op_cast_clustered(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
+	                                        const char *op, SPIRType::BaseType input_type);
 	void emit_trinary_func_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2,
 	                               const char *op, SPIRType::BaseType input_type);
 	void emit_trinary_func_op_bitextract(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
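For reference, the four group operations routed through the new GLSL_GROUP_OP_CAST macro map onto the GLSL built-ins as follows; a sketch assuming a uint value v reaching a signed-max opcode (the names and the 4u cluster size are illustrative):

    uint r0 = uint(subgroupMax(int(v)));               // GroupOperationReduce
    uint r1 = uint(subgroupInclusiveMax(int(v)));      // GroupOperationInclusiveScan
    uint r2 = uint(subgroupExclusiveMax(int(v)));      // GroupOperationExclusiveScan
    uint r3 = uint(subgroupClusteredMax(int(v), 4u));  // GroupOperationClusteredReduce

The first three forms go through emit_unary_func_op_cast() with the same input and expected result type; only the clustered form needs the dedicated helper, since its second argument must not be cast.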