diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp
index d7aef672..0098d6ac 100644
--- a/spirv_glsl.cpp
+++ b/spirv_glsl.cpp
@@ -4050,10 +4050,48 @@ void CompilerGLSL::emit_subgroup_arithmetic_workaround(std::string func, Op op,
 		break;
 	}
 
+	case OpGroupNonUniformIMul:
+	{
+		type_infos.emplace_back(TypeInfo{ "uint", "1u" });
+		type_infos.emplace_back(TypeInfo{ "uvec2", "uvec2(1u)" });
+		type_infos.emplace_back(TypeInfo{ "uvec3", "uvec3(1u)" });
+		type_infos.emplace_back(TypeInfo{ "uvec4", "uvec4(1u)" });
+		type_infos.emplace_back(TypeInfo{ "int", "1" });
+		type_infos.emplace_back(TypeInfo{ "ivec2", "ivec2(1)" });
+		type_infos.emplace_back(TypeInfo{ "ivec3", "ivec3(1)" });
+		type_infos.emplace_back(TypeInfo{ "ivec4", "ivec4(1)" });
+		break;
+	}
+
+	case OpGroupNonUniformFMul:
+	{
+		type_infos.emplace_back(TypeInfo{ "float", "1.0f" });
+		type_infos.emplace_back(TypeInfo{ "vec2", "vec2(1.0f)" });
+		type_infos.emplace_back(TypeInfo{ "vec3", "vec3(1.0f)" });
+		type_infos.emplace_back(TypeInfo{ "vec4", "vec4(1.0f)" });
+		type_infos.emplace_back(TypeInfo{ "double", "1.0LF" });
+		type_infos.emplace_back(TypeInfo{ "dvec2", "dvec2(1.0LF)" });
+		type_infos.emplace_back(TypeInfo{ "dvec3", "dvec3(1.0LF)" });
+		type_infos.emplace_back(TypeInfo{ "dvec4", "dvec4(1.0LF)" });
+		break;
+	}
+
 	default:
 		SPIRV_CROSS_THROW("Unsupported workaround for arithmetic group operation");
 	}
 
+	const bool op_is_addition = op == OpGroupNonUniformIAdd || op == OpGroupNonUniformFAdd;
+	const bool op_is_multiplication = op == OpGroupNonUniformIMul || op == OpGroupNonUniformFMul;
+	std::string op_symbol;
+	if (op_is_addition)
+	{
+		op_symbol = "+=";
+	}
+	else if (op_is_multiplication)
+	{
+		op_symbol = "*=";
+	}
+
 	for (const TypeInfo& t : type_infos)
 	{
 		statement(t.type, " ", func, "(", t.type, " v)");
 		begin_scope();
@@ -4065,15 +4103,18 @@ void CompilerGLSL::emit_subgroup_arithmetic_workaround(std::string func, Op op,
 		statement(result, " = v;");
 		statement("for (uint i = 1u; i <= total; i <<= 1u)");
 		begin_scope();
+		statement("bool valid;");
 		if (group_op == GroupOperationReduce)
 		{
-			statement(result, " += shuffleXorNV(", result, ", i, gl_SubgroupSize);");
+			statement(t.type, " s = shuffleXorNV(", result, ", i, gl_SubgroupSize, valid);");
 		}
 		else if (group_op == GroupOperationExclusiveScan || group_op == GroupOperationInclusiveScan)
 		{
-			statement("bool valid;");
 			statement(t.type, " s = shuffleUpNV(", result, ", i, gl_SubgroupSize, valid);");
-			statement(result, " += valid ? s : ", t.identity, ";");
+		}
+		if (op_is_addition || op_is_multiplication)
+		{
+			statement(result, " ", op_symbol, " valid ? s : ", t.identity, ";");
 		}
 		end_scope();
 		if (group_op == GroupOperationExclusiveScan)
@@ -4103,7 +4144,10 @@ void CompilerGLSL::emit_subgroup_arithmetic_workaround(std::string func, Op op,
 			{
 				statement("valid = valid && (i < total);");
 			}
-			statement(result, " += valid ? s : ", t.identity, ";");
+			if (op_is_addition || op_is_multiplication)
+			{
+				statement(result, " ", op_symbol, " valid ? s : ", t.identity, ";");
s : ", t.identity, ";"); + } end_scope(); end_scope(); statement("return ", result, ";"); @@ -4552,6 +4596,19 @@ void CompilerGLSL::emit_extension_workarounds(spv::ExecutionModel model) OpGroupNonUniformFAdd, GroupOperationExclusiveScan); arithmetic_feature_helper(Supp::SubgroupArithmeticFAddInclusiveScan, "subgroupInclusiveAdd", OpGroupNonUniformFAdd, GroupOperationInclusiveScan); + + arithmetic_feature_helper(Supp::SubgroupArithmeticIMulReduce, "subgroupMul", OpGroupNonUniformIMul, + GroupOperationReduce); + arithmetic_feature_helper(Supp::SubgroupArithmeticIMulExclusiveScan, "subgroupExclusiveMul", + OpGroupNonUniformIMul, GroupOperationExclusiveScan); + arithmetic_feature_helper(Supp::SubgroupArithmeticIMulInclusiveScan, "subgroupInclusiveMul", + OpGroupNonUniformIMul, GroupOperationInclusiveScan); + arithmetic_feature_helper(Supp::SubgroupArithmeticFMulReduce, "subgroupMul", OpGroupNonUniformFMul, + GroupOperationReduce); + arithmetic_feature_helper(Supp::SubgroupArithmeticFMulExclusiveScan, "subgroupExclusiveMul", + OpGroupNonUniformFMul, GroupOperationExclusiveScan); + arithmetic_feature_helper(Supp::SubgroupArithmeticFMulInclusiveScan, "subgroupInclusiveMul", + OpGroupNonUniformFMul, GroupOperationInclusiveScan); } if (!workaround_ubo_load_overload_types.empty()) @@ -7273,6 +7330,8 @@ bool CompilerGLSL::is_supported_subgroup_op_in_opengl(spv::Op op, const uint32_t return true; case OpGroupNonUniformIAdd: case OpGroupNonUniformFAdd: + case OpGroupNonUniformIMul: + case OpGroupNonUniformFMul: { const GroupOperation operation = static_cast(ops[3]); if (operation == GroupOperationReduce || operation == GroupOperationInclusiveScan || @@ -8946,58 +9005,34 @@ void CompilerGLSL::emit_subgroup_op(const Instruction &i) } break; - case OpGroupNonUniformIAdd: - { - auto operation = static_cast(ops[3]); - if (operation == GroupOperationClusteredReduce) - { - require_extension_internal("GL_KHR_shader_subgroup_clustered"); - } - else if (operation == GroupOperationReduce) - { - request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticIAddReduce); - } - else if (operation == GroupOperationExclusiveScan) - { - request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticIAddExclusiveScan); - } - else if (operation == GroupOperationInclusiveScan) - { - request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticIAddInclusiveScan); - } - else - SPIRV_CROSS_THROW("Invalid group operation."); - break; + // clang-format off +#define GLSL_GROUP_OP(OP)\ + case OpGroupNonUniform##OP:\ + {\ + auto operation = static_cast(ops[3]);\ + if (operation == GroupOperationClusteredReduce)\ + require_extension_internal("GL_KHR_shader_subgroup_clustered");\ + else if (operation == GroupOperationReduce)\ + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##Reduce);\ + else if (operation == GroupOperationExclusiveScan)\ + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##ExclusiveScan);\ + else if (operation == GroupOperationInclusiveScan)\ + request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##InclusiveScan);\ + else\ + SPIRV_CROSS_THROW("Invalid group operation.");\ + break;\ } - case OpGroupNonUniformFAdd: - { - auto operation = static_cast(ops[3]); - if (operation == GroupOperationClusteredReduce) - { - require_extension_internal("GL_KHR_shader_subgroup_clustered"); - } - else if (operation == GroupOperationReduce) - { - 
-			request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticFAddReduce);
-		}
-		else if (operation == GroupOperationExclusiveScan)
-		{
-			request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticFAddExclusiveScan);
-		}
-		else if (operation == GroupOperationInclusiveScan)
-		{
-			request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticFAddInclusiveScan);
-		}
-		else
-			SPIRV_CROSS_THROW("Invalid group operation.");
-		break;
-	}
+	GLSL_GROUP_OP(IAdd)
+	GLSL_GROUP_OP(FAdd)
+	GLSL_GROUP_OP(IMul)
+	GLSL_GROUP_OP(FMul)
+
+#undef GLSL_GROUP_OP
+	// clang-format on
 
-	case OpGroupNonUniformFMul:
 	case OpGroupNonUniformFMin:
 	case OpGroupNonUniformFMax:
-	case OpGroupNonUniformIMul:
 	case OpGroupNonUniformSMin:
 	case OpGroupNonUniformSMax:
 	case OpGroupNonUniformUMin:
@@ -17916,9 +17951,15 @@ CompilerGLSL::ShaderSubgroupSupportHelper::FeatureVector CompilerGLSL::ShaderSub
 	case SubgroupArithmeticIAddInclusiveScan:
 	case SubgroupArithmeticFAddReduce:
 	case SubgroupArithmeticFAddInclusiveScan:
+	case SubgroupArithmeticIMulReduce:
+	case SubgroupArithmeticIMulInclusiveScan:
+	case SubgroupArithmeticFMulReduce:
+	case SubgroupArithmeticFMulInclusiveScan:
 		return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount, SubgroupMask, SubgroupBallotBitExtract };
 	case SubgroupArithmeticIAddExclusiveScan:
 	case SubgroupArithmeticFAddExclusiveScan:
+	case SubgroupArithmeticIMulExclusiveScan:
+	case SubgroupArithmeticFMulExclusiveScan:
 		return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount, SubgroupMask, SubgroupElect,
 			     SubgroupBallotBitExtract };
 	default:
@@ -17939,7 +17980,9 @@ bool CompilerGLSL::ShaderSubgroupSupportHelper::can_feature_be_implemented_witho
 		true, // SubgroupBalloFindLSB_MSB
 		false, false, false, false,
 		true, // SubgroupMemBarrier - replaced with workgroup memory barriers
-		false, false, true, false, false, false, false, false, false, false
+		false, false, true, false,
+		false, false, false, false, false, false, // iadd, fadd
+		false, false, false, false, false, false, // imul, fmul
 	};
 
 	return retval[feature];
@@ -17955,6 +17998,8 @@ CompilerGLSL::ShaderSubgroupSupportHelper::Candidate CompilerGLSL::ShaderSubgrou
 		KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot,
 		KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
 		KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
+		KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
+		KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
 	};
 
 	return extensions[feature];
@@ -18056,6 +18101,12 @@ CompilerGLSL::ShaderSubgroupSupportHelper::CandidateVector CompilerGLSL::ShaderS
 	case SubgroupArithmeticFAddReduce:
 	case SubgroupArithmeticFAddExclusiveScan:
 	case SubgroupArithmeticFAddInclusiveScan:
+	case SubgroupArithmeticIMulReduce:
+	case SubgroupArithmeticIMulExclusiveScan:
+	case SubgroupArithmeticIMulInclusiveScan:
+	case SubgroupArithmeticFMulReduce:
+	case SubgroupArithmeticFMulExclusiveScan:
+	case SubgroupArithmeticFMulInclusiveScan:
 		return { KHR_shader_subgroup_arithmetic, NV_shader_thread_shuffle };
 	default:
 		return {};
diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp
index 8ac55026..8c8f470c 100644
--- a/spirv_glsl.hpp
+++ b/spirv_glsl.hpp
@@ -331,6 +331,12 @@ protected:
 		SubgroupArithmeticFAddReduce = 19,
 		SubgroupArithmeticFAddExclusiveScan = 20,
 		SubgroupArithmeticFAddInclusiveScan = 21,
+		SubgroupArithmeticIMulReduce = 22,
+		SubgroupArithmeticIMulExclusiveScan = 23,
+		SubgroupArithmeticIMulInclusiveScan = 24,
+		SubgroupArithmeticFMulReduce = 25,
+		SubgroupArithmeticFMulExclusiveScan = 26,
+		SubgroupArithmeticFMulInclusiveScan = 27,
 		FeatureCount
 	};
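
For context, the generated multiplication helpers reuse the existing shuffle-based pattern; only the identity table and the compound-assignment operator selected via op_symbol differ from the addition path. The GLSL below is a hand-written sketch of the reduce path emitted for subgroupMul(float) when GL_KHR_shader_subgroup_arithmetic is unavailable but GL_NV_shader_thread_shuffle is present. It is not the literal generated text: the name emulated_subgroup_mul, the compute-shader scaffolding, and the use of GL_KHR_shader_subgroup_basic for gl_SubgroupSize are assumptions made for the example, and the ballot-based occupancy check plus the partial-subgroup fallback touched by the third hunk are omitted.

```glsl
#version 450
#extension GL_NV_shader_thread_shuffle : require
#extension GL_KHR_shader_subgroup_basic : require

layout(local_size_x = 64) in;

layout(std430, binding = 0) buffer Data
{
    float values[];
};

// Butterfly reduction over the subgroup: XOR-shuffle partner lanes at strides
// 1, 2, 4, ... and fold them in with '*', substituting the multiplicative
// identity whenever the partner lane is reported invalid.
float emulated_subgroup_mul(float v)
{
    float reduction = v;
    uint total = gl_SubgroupSize / 2u;
    for (uint i = 1u; i <= total; i <<= 1u)
    {
        bool valid;
        float s = shuffleXorNV(reduction, i, gl_SubgroupSize, valid);
        reduction *= valid ? s : 1.0f;
    }
    return reduction;
}

void main()
{
    uint idx = gl_GlobalInvocationID.x;
    values[idx] = emulated_subgroup_mul(values[idx]);
}
```

Folding in the multiplicative identity (1, 1.0f, 1.0LF) whenever the partner lane is invalid is what lets the same loop shape serve both += and *=, so the patch only needs the op_symbol switch and the new TypeInfo identities rather than a separate emission path; it is also why the double identity in the OpGroupNonUniformFMul table must be 1.0LF rather than 0.0LF.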