GLSL: Support KHR_subgroup_arithmetic IMul/FMul

Support NV workarounds for IMul/FMul Reduce/InclusiveScan/ExclusiveScan
This commit is contained in:
georgeouzou 2023-04-03 19:13:30 +03:00
parent ab3a6212b8
commit 168e9f2cc9
2 changed files with 109 additions and 52 deletions

View File

@ -4050,10 +4050,48 @@ void CompilerGLSL::emit_subgroup_arithmetic_workaround(std::string func, Op op,
break;
}
case OpGroupNonUniformIMul:
{
type_infos.emplace_back(TypeInfo{ "uint", "1u" });
type_infos.emplace_back(TypeInfo{ "uvec2", "uvec2(1u)" });
type_infos.emplace_back(TypeInfo{ "uvec3", "uvec3(1u)" });
type_infos.emplace_back(TypeInfo{ "uvec4", "uvec4(1u)" });
type_infos.emplace_back(TypeInfo{ "int", "1" });
type_infos.emplace_back(TypeInfo{ "ivec2", "ivec2(1)" });
type_infos.emplace_back(TypeInfo{ "ivec3", "ivec3(1)" });
type_infos.emplace_back(TypeInfo{ "ivec4", "ivec4(1)" });
break;
}
case OpGroupNonUniformFMul:
{
type_infos.emplace_back(TypeInfo{ "float", "1.0f" });
type_infos.emplace_back(TypeInfo{ "vec2", "vec2(1.0f)" });
type_infos.emplace_back(TypeInfo{ "vec3", "vec3(1.0f)" });
type_infos.emplace_back(TypeInfo{ "vec4", "vec4(1.0f)" });
type_infos.emplace_back(TypeInfo{ "double", "0.0LF" });
type_infos.emplace_back(TypeInfo{ "dvec2", "dvec2(1.0LF)" });
type_infos.emplace_back(TypeInfo{ "dvec3", "dvec3(1.0LF)" });
type_infos.emplace_back(TypeInfo{ "dvec4", "dvec4(1.0LF)" });
break;
}
default:
SPIRV_CROSS_THROW("Unsupported workaround for arithmetic group operation");
}
const bool op_is_addition = op == OpGroupNonUniformIAdd || op == OpGroupNonUniformFAdd;
const bool op_is_multiplication = op == OpGroupNonUniformIMul || op == OpGroupNonUniformFMul;
std::string op_symbol;
if (op_is_addition)
{
op_symbol = "+=";
}
else if (op_is_multiplication)
{
op_symbol = "*=";
}
for (const TypeInfo& t : type_infos) {
statement(t.type, " ", func, "(", t.type, " v)");
begin_scope();
@ -4065,15 +4103,18 @@ void CompilerGLSL::emit_subgroup_arithmetic_workaround(std::string func, Op op,
statement(result, " = v;");
statement("for (uint i = 1u; i <= total; i <<= 1u)");
begin_scope();
statement("bool valid;");
if (group_op == GroupOperationReduce)
{
statement(result, " += shuffleXorNV(", result, ", i, gl_SubgroupSize);");
statement(t.type, " s = shuffleXorNV(", result, ", i, gl_SubgroupSize, valid);");
}
else if (group_op == GroupOperationExclusiveScan || group_op == GroupOperationInclusiveScan)
{
statement("bool valid;");
statement(t.type, " s = shuffleUpNV(", result, ", i, gl_SubgroupSize, valid);");
statement(result, " += valid ? s : ", t.identity, ";");
}
if (op_is_addition || op_is_multiplication)
{
statement(result, " ", op_symbol, " valid ? s : ", t.identity, ";");
}
end_scope();
if (group_op == GroupOperationExclusiveScan)
@ -4103,7 +4144,10 @@ void CompilerGLSL::emit_subgroup_arithmetic_workaround(std::string func, Op op,
{
statement("valid = valid && (i < total);");
}
statement(result, " += valid ? s : ", t.identity, ";");
if (op_is_addition || op_is_multiplication)
{
statement(result, " ", op_symbol, " valid ? s : ", t.identity, ";");
}
end_scope();
end_scope();
statement("return ", result, ";");
@ -4552,6 +4596,19 @@ void CompilerGLSL::emit_extension_workarounds(spv::ExecutionModel model)
OpGroupNonUniformFAdd, GroupOperationExclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticFAddInclusiveScan, "subgroupInclusiveAdd",
OpGroupNonUniformFAdd, GroupOperationInclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticIMulReduce, "subgroupMul", OpGroupNonUniformIMul,
GroupOperationReduce);
arithmetic_feature_helper(Supp::SubgroupArithmeticIMulExclusiveScan, "subgroupExclusiveMul",
OpGroupNonUniformIMul, GroupOperationExclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticIMulInclusiveScan, "subgroupInclusiveMul",
OpGroupNonUniformIMul, GroupOperationInclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticFMulReduce, "subgroupMul", OpGroupNonUniformFMul,
GroupOperationReduce);
arithmetic_feature_helper(Supp::SubgroupArithmeticFMulExclusiveScan, "subgroupExclusiveMul",
OpGroupNonUniformFMul, GroupOperationExclusiveScan);
arithmetic_feature_helper(Supp::SubgroupArithmeticFMulInclusiveScan, "subgroupInclusiveMul",
OpGroupNonUniformFMul, GroupOperationInclusiveScan);
}
if (!workaround_ubo_load_overload_types.empty())
@ -7273,6 +7330,8 @@ bool CompilerGLSL::is_supported_subgroup_op_in_opengl(spv::Op op, const uint32_t
return true;
case OpGroupNonUniformIAdd:
case OpGroupNonUniformFAdd:
case OpGroupNonUniformIMul:
case OpGroupNonUniformFMul:
{
const GroupOperation operation = static_cast<GroupOperation>(ops[3]);
if (operation == GroupOperationReduce || operation == GroupOperationInclusiveScan ||
@ -8946,58 +9005,34 @@ void CompilerGLSL::emit_subgroup_op(const Instruction &i)
}
break;
case OpGroupNonUniformIAdd:
{
auto operation = static_cast<GroupOperation>(ops[3]);
if (operation == GroupOperationClusteredReduce)
{
require_extension_internal("GL_KHR_shader_subgroup_clustered");
}
else if (operation == GroupOperationReduce)
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticIAddReduce);
}
else if (operation == GroupOperationExclusiveScan)
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticIAddExclusiveScan);
}
else if (operation == GroupOperationInclusiveScan)
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticIAddInclusiveScan);
}
else
SPIRV_CROSS_THROW("Invalid group operation.");
break;
// clang-format off
#define GLSL_GROUP_OP(OP)\
case OpGroupNonUniform##OP:\
{\
auto operation = static_cast<GroupOperation>(ops[3]);\
if (operation == GroupOperationClusteredReduce)\
require_extension_internal("GL_KHR_shader_subgroup_clustered");\
else if (operation == GroupOperationReduce)\
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##Reduce);\
else if (operation == GroupOperationExclusiveScan)\
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##ExclusiveScan);\
else if (operation == GroupOperationInclusiveScan)\
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmetic##OP##InclusiveScan);\
else\
SPIRV_CROSS_THROW("Invalid group operation.");\
break;\
}
case OpGroupNonUniformFAdd:
{
auto operation = static_cast<GroupOperation>(ops[3]);
if (operation == GroupOperationClusteredReduce)
{
require_extension_internal("GL_KHR_shader_subgroup_clustered");
}
else if (operation == GroupOperationReduce)
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticFAddReduce);
}
else if (operation == GroupOperationExclusiveScan)
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticFAddExclusiveScan);
}
else if (operation == GroupOperationInclusiveScan)
{
request_subgroup_feature(ShaderSubgroupSupportHelper::SubgroupArithmeticFAddInclusiveScan);
}
else
SPIRV_CROSS_THROW("Invalid group operation.");
break;
}
GLSL_GROUP_OP(IAdd)
GLSL_GROUP_OP(FAdd)
GLSL_GROUP_OP(IMul)
GLSL_GROUP_OP(FMul)
#undef GLSL_GROUP_OP
// clang-format on
case OpGroupNonUniformFMul:
case OpGroupNonUniformFMin:
case OpGroupNonUniformFMax:
case OpGroupNonUniformIMul:
case OpGroupNonUniformSMin:
case OpGroupNonUniformSMax:
case OpGroupNonUniformUMin:
@ -17916,9 +17951,15 @@ CompilerGLSL::ShaderSubgroupSupportHelper::FeatureVector CompilerGLSL::ShaderSub
case SubgroupArithmeticIAddInclusiveScan:
case SubgroupArithmeticFAddReduce:
case SubgroupArithmeticFAddInclusiveScan:
case SubgroupArithmeticIMulReduce:
case SubgroupArithmeticIMulInclusiveScan:
case SubgroupArithmeticFMulReduce:
case SubgroupArithmeticFMulInclusiveScan:
return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount, SubgroupMask, SubgroupBallotBitExtract };
case SubgroupArithmeticIAddExclusiveScan:
case SubgroupArithmeticFAddExclusiveScan:
case SubgroupArithmeticIMulExclusiveScan:
case SubgroupArithmeticFMulExclusiveScan:
return { SubgroupSize, SubgroupBallot, SubgroupBallotBitCount,
SubgroupMask, SubgroupElect, SubgroupBallotBitExtract };
default:
@ -17939,7 +17980,9 @@ bool CompilerGLSL::ShaderSubgroupSupportHelper::can_feature_be_implemented_witho
true, // SubgroupBalloFindLSB_MSB
false, false, false, false,
true, // SubgroupMemBarrier - replaced with workgroup memory barriers
false, false, true, false, false, false, false, false, false, false
false, false, true, false,
false, false, false, false, false, false, // iadd, fadd
false, false, false, false, false, false, // imul , fmul
};
return retval[feature];
@ -17955,6 +17998,8 @@ CompilerGLSL::ShaderSubgroupSupportHelper::Candidate CompilerGLSL::ShaderSubgrou
KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot, KHR_shader_subgroup_ballot,
KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic, KHR_shader_subgroup_arithmetic,
};
return extensions[feature];
@ -18056,6 +18101,12 @@ CompilerGLSL::ShaderSubgroupSupportHelper::CandidateVector CompilerGLSL::ShaderS
case SubgroupArithmeticFAddReduce:
case SubgroupArithmeticFAddExclusiveScan:
case SubgroupArithmeticFAddInclusiveScan:
case SubgroupArithmeticIMulReduce:
case SubgroupArithmeticIMulExclusiveScan:
case SubgroupArithmeticIMulInclusiveScan:
case SubgroupArithmeticFMulReduce:
case SubgroupArithmeticFMulExclusiveScan:
case SubgroupArithmeticFMulInclusiveScan:
return { KHR_shader_subgroup_arithmetic, NV_shader_thread_shuffle };
default:
return {};

View File

@ -331,6 +331,12 @@ protected:
SubgroupArithmeticFAddReduce = 19,
SubgroupArithmeticFAddExclusiveScan = 20,
SubgroupArithmeticFAddInclusiveScan = 21,
SubgroupArithmeticIMulReduce = 22,
SubgroupArithmeticIMulExclusiveScan = 23,
SubgroupArithmeticIMulInclusiveScan = 24,
SubgroupArithmeticFMulReduce = 25,
SubgroupArithmeticFMulExclusiveScan = 26,
SubgroupArithmeticFMulInclusiveScan = 27,
FeatureCount
};