HLSL: Implement GroupOperation(Inclusive/Exclusive)Scan.

This commit is contained in:
Pedro J. Estébanez 2022-07-20 22:00:35 +02:00
parent d8d051381f
commit 1fe470b199
3 changed files with 29 additions and 5 deletions

View File

@@ -21,6 +21,8 @@ void comp_main()
float3 first = WaveReadLaneFirst(20.0f.xxx);
uint4 ballot_value = WaveActiveBallot(true);
uint bit_count = countbits(ballot_value.x) + countbits(ballot_value.y) + countbits(ballot_value.z) + countbits(ballot_value.w);
uint inclusive_bit_count = countbits(ballot_value.x & gl_SubgroupLeMask.x) + countbits(ballot_value.y & gl_SubgroupLeMask.y) + countbits(ballot_value.z & gl_SubgroupLeMask.z) + countbits(ballot_value.w & gl_SubgroupLeMask.w);
uint exclusive_bit_count = countbits(ballot_value.x & gl_SubgroupLtMask.x) + countbits(ballot_value.y & gl_SubgroupLtMask.y) + countbits(ballot_value.z & gl_SubgroupLtMask.z) + countbits(ballot_value.w & gl_SubgroupLtMask.w);
uint shuffled = WaveReadLaneAt(10u, 8u);
uint shuffled_xor = WaveReadLaneAt(30u, WaveGetLaneIndex() ^ 8u);
uint shuffled_up = WaveReadLaneAt(20u, WaveGetLaneIndex() - 4u);

View File

@@ -40,8 +40,8 @@ void main()
//bool inverse_ballot_value = subgroupInverseBallot(ballot_value);
//bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u);
uint bit_count = subgroupBallotBitCount(ballot_value);
//uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
//uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
//uint lsb = subgroupBallotFindLSB(ballot_value);
//uint msb = subgroupBallotFindMSB(ballot_value);

View File

@@ -4728,9 +4728,9 @@ void CompilerHLSL::emit_subgroup_op(const Instruction &i)
case OpGroupNonUniformBallotBitCount:
{
auto operation = static_cast<GroupOperation>(ops[3]);
bool forward = should_forward(ops[4]);
if (operation == GroupOperationReduce)
{
bool forward = should_forward(ops[4]);
auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(",
to_enclosed_expression(ops[4]), ".y)");
auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(",
@@ -4739,9 +4739,31 @@ void CompilerHLSL::emit_subgroup_op(const Instruction &i)
inherit_expression_dependencies(id, ops[4]);
}
else if (operation == GroupOperationInclusiveScan)
SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
{
auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x & gl_SubgroupLeMask.x) + countbits(",
to_enclosed_expression(ops[4]), ".y & gl_SubgroupLeMask.y)");
auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z & gl_SubgroupLeMask.z) + countbits(",
to_enclosed_expression(ops[4]), ".w & gl_SubgroupLeMask.w)");
emit_op(result_type, id, join(left, " + ", right), forward);
if (!active_input_builtins.get(BuiltInSubgroupLeMask))
{
active_input_builtins.set(BuiltInSubgroupLeMask);
force_recompile_guarantee_forward_progress();
}
}
else if (operation == GroupOperationExclusiveScan)
SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
{
auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x & gl_SubgroupLtMask.x) + countbits(",
to_enclosed_expression(ops[4]), ".y & gl_SubgroupLtMask.y)");
auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z & gl_SubgroupLtMask.z) + countbits(",
to_enclosed_expression(ops[4]), ".w & gl_SubgroupLtMask.w)");
emit_op(result_type, id, join(left, " + ", right), forward);
if (!active_input_builtins.get(BuiltInSubgroupLtMask))
{
active_input_builtins.set(BuiltInSubgroupLtMask);
force_recompile_guarantee_forward_progress();
}
}
else
SPIRV_CROSS_THROW("Invalid BitCount operation.");
break;