Implement all of subgroup.
This commit is contained in:
parent
f6c0e53f58
commit
e1ccfd5dbb
@ -60,17 +60,29 @@ void main()
|
|||||||
|
|
||||||
// arithmetic
|
// arithmetic
|
||||||
vec4 added = subgroupAdd(vec4(20.0));
|
vec4 added = subgroupAdd(vec4(20.0));
|
||||||
|
ivec4 iadded = subgroupAdd(ivec4(20));
|
||||||
vec4 multiplied = subgroupMul(vec4(20.0));
|
vec4 multiplied = subgroupMul(vec4(20.0));
|
||||||
|
ivec4 imultiplied = subgroupMul(ivec4(20));
|
||||||
vec4 lo = subgroupMin(vec4(20.0));
|
vec4 lo = subgroupMin(vec4(20.0));
|
||||||
vec4 hi = subgroupMax(vec4(20.0));
|
vec4 hi = subgroupMax(vec4(20.0));
|
||||||
|
ivec4 slo = subgroupMin(ivec4(20));
|
||||||
|
ivec4 shi = subgroupMax(ivec4(20));
|
||||||
|
uvec4 ulo = subgroupMin(uvec4(20));
|
||||||
|
uvec4 uhi = subgroupMax(uvec4(20));
|
||||||
uvec4 anded = subgroupAnd(ballot_value);
|
uvec4 anded = subgroupAnd(ballot_value);
|
||||||
uvec4 ored = subgroupOr(ballot_value);
|
uvec4 ored = subgroupOr(ballot_value);
|
||||||
uvec4 xored = subgroupXor(ballot_value);
|
uvec4 xored = subgroupXor(ballot_value);
|
||||||
|
|
||||||
added = subgroupInclusiveAdd(added);
|
added = subgroupInclusiveAdd(added);
|
||||||
|
iadded = subgroupInclusiveAdd(iadded);
|
||||||
multiplied = subgroupInclusiveMul(multiplied);
|
multiplied = subgroupInclusiveMul(multiplied);
|
||||||
|
imultiplied = subgroupInclusiveMul(imultiplied);
|
||||||
lo = subgroupInclusiveMin(lo);
|
lo = subgroupInclusiveMin(lo);
|
||||||
hi = subgroupInclusiveMax(hi);
|
hi = subgroupInclusiveMax(hi);
|
||||||
|
slo = subgroupInclusiveMin(slo);
|
||||||
|
shi = subgroupInclusiveMax(shi);
|
||||||
|
ulo = subgroupInclusiveMin(ulo);
|
||||||
|
uhi = subgroupInclusiveMax(uhi);
|
||||||
anded = subgroupInclusiveAnd(anded);
|
anded = subgroupInclusiveAnd(anded);
|
||||||
ored = subgroupInclusiveOr(ored);
|
ored = subgroupInclusiveOr(ored);
|
||||||
xored = subgroupInclusiveXor(ored);
|
xored = subgroupInclusiveXor(ored);
|
||||||
@ -78,8 +90,14 @@ void main()
|
|||||||
|
|
||||||
added = subgroupExclusiveAdd(multiplied);
|
added = subgroupExclusiveAdd(multiplied);
|
||||||
multiplied = subgroupExclusiveMul(multiplied);
|
multiplied = subgroupExclusiveMul(multiplied);
|
||||||
|
iadded = subgroupExclusiveAdd(imultiplied);
|
||||||
|
imultiplied = subgroupExclusiveMul(imultiplied);
|
||||||
lo = subgroupExclusiveMin(lo);
|
lo = subgroupExclusiveMin(lo);
|
||||||
hi = subgroupExclusiveMax(hi);
|
hi = subgroupExclusiveMax(hi);
|
||||||
|
ulo = subgroupExclusiveMin(ulo);
|
||||||
|
uhi = subgroupExclusiveMax(uhi);
|
||||||
|
slo = subgroupExclusiveMin(slo);
|
||||||
|
shi = subgroupExclusiveMax(shi);
|
||||||
anded = subgroupExclusiveAnd(anded);
|
anded = subgroupExclusiveAnd(anded);
|
||||||
ored = subgroupExclusiveOr(ored);
|
ored = subgroupExclusiveOr(ored);
|
||||||
xored = subgroupExclusiveXor(ored);
|
xored = subgroupExclusiveXor(ored);
|
||||||
@ -87,8 +105,14 @@ void main()
|
|||||||
// clustered
|
// clustered
|
||||||
added = subgroupClusteredAdd(added, 4u);
|
added = subgroupClusteredAdd(added, 4u);
|
||||||
multiplied = subgroupClusteredMul(multiplied, 4u);
|
multiplied = subgroupClusteredMul(multiplied, 4u);
|
||||||
|
iadded = subgroupClusteredAdd(iadded, 4u);
|
||||||
|
imultiplied = subgroupClusteredMul(imultiplied, 4u);
|
||||||
lo = subgroupClusteredMin(lo, 4u);
|
lo = subgroupClusteredMin(lo, 4u);
|
||||||
hi = subgroupClusteredMax(hi, 4u);
|
hi = subgroupClusteredMax(hi, 4u);
|
||||||
|
ulo = subgroupClusteredMin(ulo, 4u);
|
||||||
|
uhi = subgroupClusteredMax(uhi, 4u);
|
||||||
|
slo = subgroupClusteredMin(slo, 4u);
|
||||||
|
shi = subgroupClusteredMax(shi, 4u);
|
||||||
anded = subgroupClusteredAnd(anded, 4u);
|
anded = subgroupClusteredAnd(anded, 4u);
|
||||||
ored = subgroupClusteredOr(ored, 4u);
|
ored = subgroupClusteredOr(ored, 4u);
|
||||||
xored = subgroupClusteredXor(xored, 4u);
|
xored = subgroupClusteredXor(xored, 4u);
|
||||||
|
111
spirv_glsl.cpp
111
spirv_glsl.cpp
@ -4511,6 +4511,12 @@ void CompilerGLSL::emit_subgroup_op(const Instruction &i)
|
|||||||
case OpGroupNonUniformFMul:
|
case OpGroupNonUniformFMul:
|
||||||
case OpGroupNonUniformFMin:
|
case OpGroupNonUniformFMin:
|
||||||
case OpGroupNonUniformFMax:
|
case OpGroupNonUniformFMax:
|
||||||
|
case OpGroupNonUniformIAdd:
|
||||||
|
case OpGroupNonUniformIMul:
|
||||||
|
case OpGroupNonUniformSMin:
|
||||||
|
case OpGroupNonUniformSMax:
|
||||||
|
case OpGroupNonUniformUMin:
|
||||||
|
case OpGroupNonUniformUMax:
|
||||||
case OpGroupNonUniformBitwiseAnd:
|
case OpGroupNonUniformBitwiseAnd:
|
||||||
case OpGroupNonUniformBitwiseOr:
|
case OpGroupNonUniformBitwiseOr:
|
||||||
case OpGroupNonUniformBitwiseXor:
|
case OpGroupNonUniformBitwiseXor:
|
||||||
@ -4573,27 +4579,106 @@ void CompilerGLSL::emit_subgroup_op(const Instruction &i)
|
|||||||
emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupBallotBitExtract");
|
emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupBallotBitExtract");
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case OpGroupNonUniformBallotBitCount:
|
|
||||||
case OpGroupNonUniformBallotFindLSB:
|
case OpGroupNonUniformBallotFindLSB:
|
||||||
|
emit_unary_func_op(result_type, id, ops[3], "subgroupBallotFindLSB");
|
||||||
|
break;
|
||||||
|
|
||||||
case OpGroupNonUniformBallotFindMSB:
|
case OpGroupNonUniformBallotFindMSB:
|
||||||
|
emit_unary_func_op(result_type, id, ops[3], "subgroupBallotFindMSB");
|
||||||
|
break;
|
||||||
|
|
||||||
|
case OpGroupNonUniformBallotBitCount:
|
||||||
|
{
|
||||||
|
auto operation = static_cast<GroupOperation>(ops[3]);
|
||||||
|
if (operation == GroupOperationReduce)
|
||||||
|
emit_unary_func_op(result_type, id, ops[4], "subgroupBallotBitCount");
|
||||||
|
else if (operation == GroupOperationInclusiveScan)
|
||||||
|
emit_unary_func_op(result_type, id, ops[4], "subgroupBallotInclusiveBitCount");
|
||||||
|
else if (operation == GroupOperationExclusiveScan)
|
||||||
|
emit_unary_func_op(result_type, id, ops[4], "subgroupBallotExclusiveBitCount");
|
||||||
|
else
|
||||||
|
SPIRV_CROSS_THROW("Invalid BitCount operation.");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case OpGroupNonUniformShuffle:
|
case OpGroupNonUniformShuffle:
|
||||||
|
emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffle");
|
||||||
|
break;
|
||||||
|
|
||||||
case OpGroupNonUniformShuffleXor:
|
case OpGroupNonUniformShuffleXor:
|
||||||
|
emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffleXor");
|
||||||
|
break;
|
||||||
|
|
||||||
case OpGroupNonUniformShuffleUp:
|
case OpGroupNonUniformShuffleUp:
|
||||||
|
emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffleUp");
|
||||||
|
break;
|
||||||
|
|
||||||
case OpGroupNonUniformShuffleDown:
|
case OpGroupNonUniformShuffleDown:
|
||||||
|
emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffleDown");
|
||||||
|
break;
|
||||||
|
|
||||||
case OpGroupNonUniformAll:
|
case OpGroupNonUniformAll:
|
||||||
|
emit_unary_func_op(result_type, id, ops[3], "subgroupAll");
|
||||||
|
break;
|
||||||
|
|
||||||
case OpGroupNonUniformAny:
|
case OpGroupNonUniformAny:
|
||||||
|
emit_unary_func_op(result_type, id, ops[3], "subgroupAny");
|
||||||
|
break;
|
||||||
|
|
||||||
case OpGroupNonUniformAllEqual:
|
case OpGroupNonUniformAllEqual:
|
||||||
case OpGroupNonUniformFAdd:
|
emit_unary_func_op(result_type, id, ops[3], "subgroupAllEqual");
|
||||||
case OpGroupNonUniformFMul:
|
break;
|
||||||
case OpGroupNonUniformFMin:
|
|
||||||
case OpGroupNonUniformFMax:
|
#define GROUP_OP(op, glsl_op) \
|
||||||
case OpGroupNonUniformBitwiseAnd:
|
case OpGroupNonUniform##op: \
|
||||||
case OpGroupNonUniformBitwiseOr:
|
{ \
|
||||||
case OpGroupNonUniformBitwiseXor:
|
auto operation = static_cast<GroupOperation>(ops[3]); \
|
||||||
|
if (operation == GroupOperationReduce) \
|
||||||
|
emit_unary_func_op(result_type, id, ops[4], "subgroup" #glsl_op); \
|
||||||
|
else if (operation == GroupOperationInclusiveScan) \
|
||||||
|
emit_unary_func_op(result_type, id, ops[4], "subgroupInclusive" #glsl_op); \
|
||||||
|
else if (operation == GroupOperationExclusiveScan) \
|
||||||
|
emit_unary_func_op(result_type, id, ops[4], "subgroupExclusive" #glsl_op); \
|
||||||
|
else if (operation == GroupOperationClusteredReduce) \
|
||||||
|
emit_binary_func_op(result_type, id, ops[4], ops[5], "subgroupClustered" #glsl_op); \
|
||||||
|
else \
|
||||||
|
SPIRV_CROSS_THROW("Invalid group operation."); \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
|
GROUP_OP(FAdd, Add)
|
||||||
|
GROUP_OP(FMul, Mul)
|
||||||
|
GROUP_OP(FMin, Min)
|
||||||
|
GROUP_OP(FMax, Max)
|
||||||
|
GROUP_OP(IAdd, Add)
|
||||||
|
GROUP_OP(IMul, Mul)
|
||||||
|
GROUP_OP(SMin, Min)
|
||||||
|
GROUP_OP(SMax, Max)
|
||||||
|
GROUP_OP(UMin, Min)
|
||||||
|
GROUP_OP(UMax, Max)
|
||||||
|
GROUP_OP(BitwiseAnd, And)
|
||||||
|
GROUP_OP(BitwiseOr, Or)
|
||||||
|
GROUP_OP(BitwiseXor, Xor)
|
||||||
|
#undef GROUP_OP
|
||||||
|
|
||||||
case OpGroupNonUniformQuadSwap:
|
case OpGroupNonUniformQuadSwap:
|
||||||
|
{
|
||||||
|
uint32_t direction = get<SPIRConstant>(ops[4]).scalar();
|
||||||
|
if (direction == 0)
|
||||||
|
emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapHorizontal");
|
||||||
|
else if (direction == 1)
|
||||||
|
emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapVertical");
|
||||||
|
else if (direction == 2)
|
||||||
|
emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapDiagonal");
|
||||||
|
else
|
||||||
|
SPIRV_CROSS_THROW("Invalid quad swap direction.");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case OpGroupNonUniformQuadBroadcast:
|
case OpGroupNonUniformQuadBroadcast:
|
||||||
emit_op(result_type, id, "subgroupRandom()", false);
|
{
|
||||||
return;
|
emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupQuadBroadcast");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
|
SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
|
||||||
@ -7691,9 +7776,15 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
|
|||||||
case OpGroupNonUniformAny:
|
case OpGroupNonUniformAny:
|
||||||
case OpGroupNonUniformAllEqual:
|
case OpGroupNonUniformAllEqual:
|
||||||
case OpGroupNonUniformFAdd:
|
case OpGroupNonUniformFAdd:
|
||||||
|
case OpGroupNonUniformIAdd:
|
||||||
case OpGroupNonUniformFMul:
|
case OpGroupNonUniformFMul:
|
||||||
|
case OpGroupNonUniformIMul:
|
||||||
case OpGroupNonUniformFMin:
|
case OpGroupNonUniformFMin:
|
||||||
case OpGroupNonUniformFMax:
|
case OpGroupNonUniformFMax:
|
||||||
|
case OpGroupNonUniformSMin:
|
||||||
|
case OpGroupNonUniformSMax:
|
||||||
|
case OpGroupNonUniformUMin:
|
||||||
|
case OpGroupNonUniformUMax:
|
||||||
case OpGroupNonUniformBitwiseAnd:
|
case OpGroupNonUniformBitwiseAnd:
|
||||||
case OpGroupNonUniformBitwiseOr:
|
case OpGroupNonUniformBitwiseOr:
|
||||||
case OpGroupNonUniformBitwiseXor:
|
case OpGroupNonUniformBitwiseXor:
|
||||||
|
Loading…
Reference in New Issue
Block a user