MSL: Mask ballots passed to Ballot bit ops.

Only the least *n* bits are significant, where *n* is the subgroup size.
The Vulkan CTS actually checks this.

The `FindLSB` tests weren't actually failing, but I masked that anyway,
in case there's some corner case the CTS is missing.
This commit is contained in:
Chip Davis 2020-10-20 23:59:30 -05:00
parent 781367d083
commit 065b5bda3c
4 changed files with 83 additions and 32 deletions

View File

@ -27,31 +27,41 @@ inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
return !!extract_bits(ballot[bit / 32], bit % 32, 1);
}
inline uint spvSubgroupBallotFindLSB(uint4 ballot)
inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
ballot &= mask;
return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);
}
inline uint spvSubgroupBallotFindMSB(uint4 ballot)
inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
ballot &= mask;
return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - (clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), ballot.z == 0), ballot.w == 0);
}
inline uint spvSubgroupBallotBitCount(uint4 ballot)
inline uint spvPopCount4(uint4 ballot)
{
return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);
}
inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
return spvPopCount4(ballot & mask);
}
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
return spvSubgroupBallotBitCount(ballot & mask);
return spvPopCount4(ballot & mask);
}
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
return spvSubgroupBallotBitCount(ballot & mask);
return spvPopCount4(ballot & mask);
}
template<typename T>
@ -99,11 +109,11 @@ kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgrou
uint4 ballot_value = spvSubgroupBallot(true);
bool inverse_ballot_value = spvSubgroupBallotBitExtract(ballot_value, gl_SubgroupInvocationID);
bool bit_extracted = spvSubgroupBallotBitExtract(uint4(10u), 8u);
uint bit_count = spvSubgroupBallotBitCount(ballot_value);
uint bit_count = spvSubgroupBallotBitCount(ballot_value, gl_SubgroupSize);
uint inclusive_bit_count = spvSubgroupBallotInclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
uint exclusive_bit_count = spvSubgroupBallotExclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
uint lsb = spvSubgroupBallotFindLSB(ballot_value);
uint msb = spvSubgroupBallotFindMSB(ballot_value);
uint lsb = spvSubgroupBallotFindLSB(ballot_value, gl_SubgroupSize);
uint msb = spvSubgroupBallotFindMSB(ballot_value, gl_SubgroupSize);
uint shuffled = simd_shuffle(10u, 8u);
uint shuffled_xor = simd_shuffle_xor(30u, 8u);
uint shuffled_up = simd_shuffle_up(20u, 4u);

View File

@ -25,31 +25,41 @@ inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
return !!extract_bits(ballot[bit / 32], bit % 32, 1);
}
inline uint spvSubgroupBallotFindLSB(uint4 ballot)
inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
ballot &= mask;
return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);
}
inline uint spvSubgroupBallotFindMSB(uint4 ballot)
inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
ballot &= mask;
return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - (clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), ballot.z == 0), ballot.w == 0);
}
inline uint spvSubgroupBallotBitCount(uint4 ballot)
inline uint spvPopCount4(uint4 ballot)
{
return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);
}
inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
return spvPopCount4(ballot & mask);
}
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
return spvSubgroupBallotBitCount(ballot & mask);
return spvPopCount4(ballot & mask);
}
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
{
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
return spvSubgroupBallotBitCount(ballot & mask);
return spvPopCount4(ballot & mask);
}
template<typename T>
@ -93,11 +103,11 @@ fragment main0_out main0()
uint4 ballot_value = spvSubgroupBallot(true);
bool inverse_ballot_value = spvSubgroupBallotBitExtract(ballot_value, gl_SubgroupInvocationID);
bool bit_extracted = spvSubgroupBallotBitExtract(uint4(10u), 8u);
uint bit_count = spvSubgroupBallotBitCount(ballot_value);
uint bit_count = spvSubgroupBallotBitCount(ballot_value, gl_SubgroupSize);
uint inclusive_bit_count = spvSubgroupBallotInclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
uint exclusive_bit_count = spvSubgroupBallotExclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
uint lsb = spvSubgroupBallotFindLSB(ballot_value);
uint msb = spvSubgroupBallotFindMSB(ballot_value);
uint lsb = spvSubgroupBallotFindLSB(ballot_value, gl_SubgroupSize);
uint msb = spvSubgroupBallotFindMSB(ballot_value, gl_SubgroupSize);
uint shuffled = simd_shuffle(10u, 8u);
uint shuffled_xor = simd_shuffle_xor(30u, 8u);
uint shuffled_up = simd_shuffle_up(20u, 4u);

View File

@ -160,7 +160,7 @@ void CompilerMSL::build_implicit_builtins()
bool need_sample_mask = msl_options.additional_fixed_sample_mask != 0xffffffff;
if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params ||
needs_subgroup_invocation_id || need_sample_mask)
needs_subgroup_invocation_id || needs_subgroup_size || need_sample_mask)
{
bool has_frag_coord = false;
bool has_sample_id = false;
@ -287,7 +287,7 @@ void CompilerMSL::build_implicit_builtins()
has_subgroup_invocation_id = true;
}
if (need_subgroup_ge_mask && builtin == BuiltInSubgroupSize)
if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize)
{
builtin_subgroup_size_id = var.self;
mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
@ -593,7 +593,7 @@ void CompilerMSL::build_implicit_builtins()
mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
}
if (!has_subgroup_size && need_subgroup_ge_mask)
if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size))
{
uint32_t offset = ir.increase_bound_by(2);
uint32_t type_ptr_id = offset;
@ -1280,6 +1280,8 @@ void CompilerMSL::preprocess_op_codes()
if (preproc.needs_subgroup_invocation_id)
needs_subgroup_invocation_id = true;
if (preproc.needs_subgroup_size)
needs_subgroup_size = true;
}
// Move the Private and Workgroup global variables to the entry function.
@ -4631,8 +4633,11 @@ void CompilerMSL::emit_custom_functions()
break;
case SPVFuncImplSubgroupBallotFindLSB:
statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot)");
statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)");
begin_scope();
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
statement("ballot &= mask;");
statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
"ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
end_scope();
@ -4640,8 +4645,11 @@ void CompilerMSL::emit_custom_functions()
break;
case SPVFuncImplSubgroupBallotFindMSB:
statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot)");
statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)");
begin_scope();
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
statement("ballot &= mask;");
statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
"(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
"ballot.z == 0), ballot.w == 0);");
@ -4650,24 +4658,31 @@ void CompilerMSL::emit_custom_functions()
break;
case SPVFuncImplSubgroupBallotBitCount:
statement("inline uint spvSubgroupBallotBitCount(uint4 ballot)");
statement("inline uint spvPopCount4(uint4 ballot)");
begin_scope();
statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
end_scope();
statement("");
statement("inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)");
begin_scope();
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
statement("return spvPopCount4(ballot & mask);");
end_scope();
statement("");
statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
begin_scope();
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
"uint2(0));");
statement("return spvSubgroupBallotBitCount(ballot & mask);");
statement("return spvPopCount4(ballot & mask);");
end_scope();
statement("");
statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
begin_scope();
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
statement("return spvSubgroupBallotBitCount(ballot & mask);");
statement("return spvPopCount4(ballot & mask);");
end_scope();
statement("");
break;
@ -11891,26 +11906,33 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
break;
case OpGroupNonUniformBallotFindLSB:
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindLSB");
emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB");
break;
case OpGroupNonUniformBallotFindMSB:
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindMSB");
emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB");
break;
case OpGroupNonUniformBallotBitCount:
{
auto operation = static_cast<GroupOperation>(ops[3]);
if (operation == GroupOperationReduce)
emit_unary_func_op(result_type, id, ops[4], "spvSubgroupBallotBitCount");
else if (operation == GroupOperationInclusiveScan)
switch (operation)
{
case GroupOperationReduce:
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_size_id, "spvSubgroupBallotBitCount");
break;
case GroupOperationInclusiveScan:
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
"spvSubgroupBallotInclusiveBitCount");
else if (operation == GroupOperationExclusiveScan)
break;
case GroupOperationExclusiveScan:
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
"spvSubgroupBallotExclusiveBitCount");
else
break;
default:
SPIRV_CROSS_THROW("Invalid BitCount operation.");
break;
}
break;
}
@ -13010,8 +13032,15 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
needs_subgroup_invocation_id = true;
break;
case OpGroupNonUniformBallotFindLSB:
case OpGroupNonUniformBallotFindMSB:
needs_subgroup_size = true;
break;
case OpGroupNonUniformBallotBitCount:
if (args[3] != GroupOperationReduce)
if (args[3] == GroupOperationReduce)
needs_subgroup_size = true;
else
needs_subgroup_invocation_id = true;
break;

View File

@ -913,6 +913,7 @@ protected:
bool used_swizzle_buffer = false;
bool added_builtin_tess_level = false;
bool needs_subgroup_invocation_id = false;
bool needs_subgroup_size = false;
std::string qual_pos_var_name;
std::string stage_in_var_name = "in";
std::string stage_out_var_name = "out";
@ -984,6 +985,7 @@ protected:
bool uses_atomics = false;
bool uses_resource_write = false;
bool needs_subgroup_invocation_id = false;
bool needs_subgroup_size = false;
};
// OpcodeHandler that scans for uses of sampled images