GLSL: Support KHR_subgroup_arithmetic IMul/FMul

Support NV workarounds for IMul/FMul Reduce/InclusiveScan/ExclusiveScan
This commit is contained in:
georgeouzou 2023-04-03 19:13:30 +03:00
parent ab3a6212b8
commit 168e9f2cc9
2 changed files with 109 additions and 52 deletions

View File

@ -4050,10 +4050,48 @@ void CompilerGLSL::emit_subgroup_arithmetic_workaround(std::string func, Op op,
break; break;
} }
// Integer multiply: register one overload per unsigned/signed scalar and
// vector type. The second field is the identity value substituted for lanes
// whose shuffled value is not valid; for multiplication that is 1.
case OpGroupNonUniformIMul:
{
type_infos.emplace_back(TypeInfo{ "uint", "1u" });
type_infos.emplace_back(TypeInfo{ "uvec2", "uvec2(1u)" });
type_infos.emplace_back(TypeInfo{ "uvec3", "uvec3(1u)" });
type_infos.emplace_back(TypeInfo{ "uvec4", "uvec4(1u)" });
type_infos.emplace_back(TypeInfo{ "int", "1" });
type_infos.emplace_back(TypeInfo{ "ivec2", "ivec2(1)" });
type_infos.emplace_back(TypeInfo{ "ivec3", "ivec3(1)" });
type_infos.emplace_back(TypeInfo{ "ivec4", "ivec4(1)" });
break;
}
// Floating-point multiply: register one overload per float/double scalar and
// vector type. The second field is the identity value substituted for lanes
// whose shuffled value is not valid, so it must be the multiplicative
// identity (1) for every type -- a 0 here would zero out the whole product.
case OpGroupNonUniformFMul:
{
type_infos.emplace_back(TypeInfo{ "float", "1.0f" });
type_infos.emplace_back(TypeInfo{ "vec2", "vec2(1.0f)" });
type_infos.emplace_back(TypeInfo{ "vec3", "vec3(1.0f)" });
type_infos.emplace_back(TypeInfo{ "vec4", "vec4(1.0f)" });
// Scalar double identity is 1.0LF, consistent with the dvecN entries below.
type_infos.emplace_back(TypeInfo{ "double", "1.0LF" });
type_infos.emplace_back(TypeInfo{ "dvec2", "dvec2(1.0LF)" });
type_infos.emplace_back(TypeInfo{ "dvec3", "dvec3(1.0LF)" });
type_infos.emplace_back(TypeInfo{ "dvec4", "dvec4(1.0LF)" });
break;
}
default: default:
SPIRV_CROSS_THROW("Unsupported workaround for arithmetic group operation"); SPIRV_CROSS_THROW("Unsupported workaround for arithmetic group operation");
} }
// Classify the arithmetic op once; the emitted GLSL accumulates with the
// matching compound-assignment operator. Any op that is neither an add nor a
// multiply leaves the symbol empty.
const bool op_is_addition = (op == OpGroupNonUniformIAdd) || (op == OpGroupNonUniformFAdd);
const bool op_is_multiplication = (op == OpGroupNonUniformIMul) || (op == OpGroupNonUniformFMul);
std::string op_symbol = op_is_addition ? "+=" : (op_is_multiplication ? "*=" : "");
for (const TypeInfo& t : type_infos) { for (const TypeInfo& t : type_infos) {
statement(t.type, " ", func, "(", t.type, " v)"); statement(t.type, " ", func, "(", t.type, " v)");
begin_scope(); begin_scope();
@ -4065,15 +4103,18 @@ void CompilerGLSL::emit_subgroup_arithmetic_workaround(std::string func, Op op,
statement(result, " = v;"); statement(result, " = v;");
statement("for (uint i = 1u; i <= total; i <<= 1u)"); statement("for (uint i = 1u; i <= total; i <<= 1u)");
begin_scope(); begin_scope();
statement("bool valid;");
if (group_op == GroupOperationReduce) if (group_op == GroupOperationReduce)
{ {
statement(result, " += shuffleXorNV(", result, ", i, gl_SubgroupSize);"); statement(t.type, " s = shuffleXorNV(", result, ", i, gl_SubgroupSize, valid);");
} }
else if (group_op == GroupOperationExclusiveScan || group_op == GroupOperationInclusiveScan) else if (group_op == GroupOperationExclusiveScan || group_op == GroupOperationInclusiveScan)
{ {
statement("bool valid;");
statement(t.type, " s = shuffleUpNV(", result, ", i, gl_SubgroupSize, valid);"); statement(t.type, " s = shuffleUpNV(", result, ", i, gl_SubgroupSize, valid);");
statement(result, " += valid ? s : ", t.identity, ";"); }
if (op_is_addition || op_is_multiplication)
{
statement(result, " ", op_symbol, " valid ? s : ", t.identity, ";");
} }
end_scope(); end_scope();
if (group_op == GroupOperationExclusiveScan) if (group_op == GroupOperationExclusiveScan)
@ -4103,7 +4144,10 @@ void CompilerGLSL::emit_subgroup_arithmetic_workaround(std::string func, Op op,
{ {
statement("valid = valid && (i < total);"); statement("valid = valid && (i < total);");
} }
statement(result, " += valid ? s : ", t.identity, ";"); if (op_is_addition || op_is_multiplication)
{
statement(result, " ", op_symbol, " valid ? s : ", t.identity, ";");
}
end_scope(); end_scope();
end_scope(); end_scope();
statement("return ", result, ";"); statement("return ", result, ";");
@ -4552,6 +4596,19 @@ void CompilerGLSL::emit_extension_workarounds(spv::ExecutionModel model)
OpGroupNonUniformFAdd, GroupOperationExclusiveScan); OpGroupNonUniformFAdd, GroupOperationExclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticFAddInclusiveScan, "subgroupInclusiveAdd", arithmetic_feature_helper(Supp::SubgroupArithmeticFAddInclusiveScan, "subgroupInclusiveAdd",
OpGroupNonUniformFAdd, GroupOperationInclusiveScan); OpGroupNonUniformFAdd, GroupOperationInclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticIMulReduce, "subgroupMul", OpGroupNonUniformIMul,
GroupOperationReduce);
arithmetic_feature_helper(Supp::SubgroupArithmeticIMulExclusiveScan, "subgroupExclusiveMul",
OpGroupNonUniformIMul, GroupOperationExclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticIMulInclusiveScan, "subgroupInclusiveMul",
OpGroupNonUniformIMul, GroupOperationInclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticFMulReduce, "subgroupMul", OpGroupNonUniformFMul,
GroupOperationReduce);
arithmetic_feature_helper(Supp::SubgroupArithmeticFMulExclusiveScan, "subgroupExclusiveMul",
OpGroupNonUniformFMul, GroupOperationExclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticFMulInclusiveScan, "subgroupInclusiveMul",
OpGroupNonUniformFMul, GroupOperationInclusiveScan);
} }
if (!workaround_ubo_load_overload_types.empty()) if (!workaround_ubo_load_overload_types.empty())
@ -7273,6 +7330,8 @@ bool CompilerGLSL::is_supported_subgroup_op_in_opengl(spv::Op op, const uint32_t
return true; return true;
case OpGroupNonUniformIAdd: case OpGroupNonUniformIAdd:
case OpGroupNonUniformFAdd: case OpGroupNonUniformFAdd:
case OpGroupNonUniformIMul:
case OpGroupNonUniformFMul:
{ {
const GroupOperation operation = static_cast<GroupOperation>(ops[3]); const GroupOperation operation = static_cast<GroupOperation>(ops[3]);
if (operation == GroupOperationReduce || operation == GroupOperationInclusiveScan || if (operation == GroupOperationReduce || operation == GroupOperationInclusiveScan ||
@ -8946,58 +9005,34 @@ void CompilerGLSL::emit_subgroup_op(const Instruction &i)
} }
break; break;
case OpGroupNonUniformIAdd: // clang-format off
{ #define GLSL_GROUP_OP(OP)\
auto operation = static_cast<GroupOperation>(ops[3]); case OpGroupNonUniform##OP:\
if (operation == GroupOperationClusteredReduce) {\
{ auto operation = static_cast<GroupOperation>(ops[3]);\
require_extension_internal("GL_KHR_shader_subgroup_clustered"); if (operation == GroupOperationClusteredReduce)\
} require_extension_internal("GL_KHR_shader_subgroup_clustered");\
else if (operation == GroupOperationReduce) else if (operation == GroupOperationReduce)\
{ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##Reduce);\
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticIAddReduce); else if (operation == GroupOperationExclusiveScan)\
} request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##ExclusiveScan);\
else if (operation == GroupOperationExclusiveScan) else if (operation == GroupOperationInclusiveScan)\
{ request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##InclusiveScan);\
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticIAddExclusiveScan); else\
} SPIRV_CROSS_THROW("Invalid group operation.");\
else if (operation == GroupOperationInclusiveScan) break;\
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticIAddInclusiveScan);
}
else
SPIRV_CROSS_THROW("Invalid group operation.");
break;
} }
case OpGroupNonUniformFAdd: GLSL_GROUP_OP(IAdd)
{ GLSL_GROUP_OP(FAdd)
auto operation = static_cast<GroupOperation>(ops[3]); GLSL_GROUP_OP(IMul)
if (operation == GroupOperationClusteredReduce) GLSL_GROUP_OP(FMul)
{
require_extension_internal("GL_KHR_shader_subgroup_clustered"); #undef GLSL_GROUP_OP
} // clang-format on
else if (operation == GroupOperationReduce)
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticFAddReduce);
}
else if (operation == GroupOperationExclusiveScan)
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticFAddExclusiveScan);
}
else if (operation == GroupOperationInclusiveScan)
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticFAddInclusiveScan);
}
else
SPIRV_CROSS_THROW("Invalid group operation.");
break;
}
case OpGroupNonUniformFMul:
case OpGroupNonUniformFMin: case OpGroupNonUniformFMin:
case OpGroupNonUniformFMax: case OpGroupNonUniformFMax:
case OpGroupNonUniformIMul:
case OpGroupNonUniformSMin: case OpGroupNonUniformSMin:
case OpGroupNonUniformSMax: case OpGroupNonUniformSMax:
case OpGroupNonUniformUMin: case OpGroupNonUniformUMin:
@ -17916,9 +17951,15 @@ CompilerGLSL::ShaderSubgroupSupportHelper::FeatureVector CompilerGLSL::ShaderSub
case SubgroupArithmeticIAddInclusiveScan: case SubgroupArithmeticIAddInclusiveScan:
case SubgroupArithmeticFAddReduce: case SubgroupArithmeticFAddReduce:
case SubgroupArithmeticFAddInclusiveScan: case SubgroupArithmeticFAddInclusiveScan:
case SubgroupArithmeticIMulReduce:
case SubgroupArithmeticIMulInclusiveScan:
case SubgroupArithmeticFMulReduce:
case SubgroupArithmeticFMulInclusiveScan:
return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount, SubgroupMask, SubgroupBallotBitExtract }; return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount, SubgroupMask, SubgroupBallotBitExtract };
case SubgroupArithmeticIAddExclusiveScan: case SubgroupArithmeticIAddExclusiveScan:
case SubgroupArithmeticFAddExclusiveScan: case SubgroupArithmeticFAddExclusiveScan:
case SubgroupArithmeticIMulExclusiveScan:
case SubgroupArithmeticFMulExclusiveScan:
return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount, return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount,
SubgroupMask, SubgroupElect, SubgroupBallotBitExtract }; SubgroupMask, SubgroupElect, SubgroupBallotBitExtract };
default: default:
@ -17939,7 +17980,9 @@ bool CompilerGLSL::ShaderSubgroupSupportHelper::can_feature_be_implemented_witho
true, // SubgroupBalloFindLSB_MSB true, // SubgroupBalloFindLSB_MSB
false, false, false, false, false, false, false, false,
true, // SubgroupMemBarrier - replaced with workgroup memory barriers true, // SubgroupMemBarrier - replaced with workgroup memory barriers
false, false, true, false, false, false, false, false, false, false false, false, true, false,
false, false, false, false, false, false, // iadd, fadd
false, false, false, false, false, false, // imul , fmul
}; };
return retval[feature]; return retval[feature];
@ -17955,6 +17998,8 @@ CompilerGLSL::ShaderSubgroupSupportHelper::Candidate CompilerGLSL::ShaderSubgrou
KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot,
KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
}; };
return extensions[feature]; return extensions[feature];
@ -18056,6 +18101,12 @@ CompilerGLSL::ShaderSubgroupSupportHelper::CandidateVector CompilerGLSL::ShaderS
case SubgroupArithmeticFAddReduce: case SubgroupArithmeticFAddReduce:
case SubgroupArithmeticFAddExclusiveScan: case SubgroupArithmeticFAddExclusiveScan:
case SubgroupArithmeticFAddInclusiveScan: case SubgroupArithmeticFAddInclusiveScan:
case SubgroupArithmeticIMulReduce:
case SubgroupArithmeticIMulExclusiveScan:
case SubgroupArithmeticIMulInclusiveScan:
case SubgroupArithmeticFMulReduce:
case SubgroupArithmeticFMulExclusiveScan:
case SubgroupArithmeticFMulInclusiveScan:
return { KHR_shader_subgroup_arithmetic, NV_shader_thread_shuffle }; return { KHR_shader_subgroup_arithmetic, NV_shader_thread_shuffle };
default: default:
return {}; return {};

View File

@ -331,6 +331,12 @@ protected:
SubgroupArithmeticFAddReduce = 19, SubgroupArithmeticFAddReduce = 19,
SubgroupArithmeticFAddExclusiveScan = 20, SubgroupArithmeticFAddExclusiveScan = 20,
SubgroupArithmeticFAddInclusiveScan = 21, SubgroupArithmeticFAddInclusiveScan = 21,
SubgroupArithmeticIMulReduce = 22,
SubgroupArithmeticIMulExclusiveScan = 23,
SubgroupArithmeticIMulInclusiveScan = 24,
SubgroupArithmeticFMulReduce = 25,
SubgroupArithmeticFMulExclusiveScan = 26,
SubgroupArithmeticFMulInclusiveScan = 27,
FeatureCount FeatureCount
}; };