MSL: Mask ballots passed to Ballot bit ops.
Only the least *n* bits are significant, where *n* is the subgroup size. The Vulkan CTS actually checks this. The `FindLSB` tests weren't actually failing, but I masked that anyway, in case there's some corner case the CTS is missing.
This commit is contained in:
parent
781367d083
commit
065b5bda3c
@ -27,31 +27,41 @@ inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
|
||||
return !!extract_bits(ballot[bit / 32], bit % 32, 1);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotFindLSB(uint4 ballot)
|
||||
inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
|
||||
ballot &= mask;
|
||||
return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotFindMSB(uint4 ballot)
|
||||
inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
|
||||
ballot &= mask;
|
||||
return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - (clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), ballot.z == 0), ballot.w == 0);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotBitCount(uint4 ballot)
|
||||
inline uint spvPopCount4(uint4 ballot)
|
||||
{
|
||||
return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
|
||||
return spvPopCount4(ballot & mask);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
|
||||
return spvSubgroupBallotBitCount(ballot & mask);
|
||||
return spvPopCount4(ballot & mask);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
|
||||
return spvSubgroupBallotBitCount(ballot & mask);
|
||||
return spvPopCount4(ballot & mask);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@ -99,11 +109,11 @@ kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgrou
|
||||
uint4 ballot_value = spvSubgroupBallot(true);
|
||||
bool inverse_ballot_value = spvSubgroupBallotBitExtract(ballot_value, gl_SubgroupInvocationID);
|
||||
bool bit_extracted = spvSubgroupBallotBitExtract(uint4(10u), 8u);
|
||||
uint bit_count = spvSubgroupBallotBitCount(ballot_value);
|
||||
uint bit_count = spvSubgroupBallotBitCount(ballot_value, gl_SubgroupSize);
|
||||
uint inclusive_bit_count = spvSubgroupBallotInclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
|
||||
uint exclusive_bit_count = spvSubgroupBallotExclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
|
||||
uint lsb = spvSubgroupBallotFindLSB(ballot_value);
|
||||
uint msb = spvSubgroupBallotFindMSB(ballot_value);
|
||||
uint lsb = spvSubgroupBallotFindLSB(ballot_value, gl_SubgroupSize);
|
||||
uint msb = spvSubgroupBallotFindMSB(ballot_value, gl_SubgroupSize);
|
||||
uint shuffled = simd_shuffle(10u, 8u);
|
||||
uint shuffled_xor = simd_shuffle_xor(30u, 8u);
|
||||
uint shuffled_up = simd_shuffle_up(20u, 4u);
|
||||
|
@ -25,31 +25,41 @@ inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)
|
||||
return !!extract_bits(ballot[bit / 32], bit % 32, 1);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotFindLSB(uint4 ballot)
|
||||
inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
|
||||
ballot &= mask;
|
||||
return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotFindMSB(uint4 ballot)
|
||||
inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
|
||||
ballot &= mask;
|
||||
return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - (clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), ballot.z == 0), ballot.w == 0);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotBitCount(uint4 ballot)
|
||||
inline uint spvPopCount4(uint4 ballot)
|
||||
{
|
||||
return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));
|
||||
return spvPopCount4(ballot & mask);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0));
|
||||
return spvSubgroupBallotBitCount(ballot & mask);
|
||||
return spvPopCount4(ballot & mask);
|
||||
}
|
||||
|
||||
inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)
|
||||
{
|
||||
uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));
|
||||
return spvSubgroupBallotBitCount(ballot & mask);
|
||||
return spvPopCount4(ballot & mask);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@ -93,11 +103,11 @@ fragment main0_out main0()
|
||||
uint4 ballot_value = spvSubgroupBallot(true);
|
||||
bool inverse_ballot_value = spvSubgroupBallotBitExtract(ballot_value, gl_SubgroupInvocationID);
|
||||
bool bit_extracted = spvSubgroupBallotBitExtract(uint4(10u), 8u);
|
||||
uint bit_count = spvSubgroupBallotBitCount(ballot_value);
|
||||
uint bit_count = spvSubgroupBallotBitCount(ballot_value, gl_SubgroupSize);
|
||||
uint inclusive_bit_count = spvSubgroupBallotInclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
|
||||
uint exclusive_bit_count = spvSubgroupBallotExclusiveBitCount(ballot_value, gl_SubgroupInvocationID);
|
||||
uint lsb = spvSubgroupBallotFindLSB(ballot_value);
|
||||
uint msb = spvSubgroupBallotFindMSB(ballot_value);
|
||||
uint lsb = spvSubgroupBallotFindLSB(ballot_value, gl_SubgroupSize);
|
||||
uint msb = spvSubgroupBallotFindMSB(ballot_value, gl_SubgroupSize);
|
||||
uint shuffled = simd_shuffle(10u, 8u);
|
||||
uint shuffled_xor = simd_shuffle_xor(30u, 8u);
|
||||
uint shuffled_up = simd_shuffle_up(20u, 4u);
|
||||
|
@ -160,7 +160,7 @@ void CompilerMSL::build_implicit_builtins()
|
||||
bool need_sample_mask = msl_options.additional_fixed_sample_mask != 0xffffffff;
|
||||
if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
|
||||
need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params ||
|
||||
needs_subgroup_invocation_id || need_sample_mask)
|
||||
needs_subgroup_invocation_id || needs_subgroup_size || need_sample_mask)
|
||||
{
|
||||
bool has_frag_coord = false;
|
||||
bool has_sample_id = false;
|
||||
@ -287,7 +287,7 @@ void CompilerMSL::build_implicit_builtins()
|
||||
has_subgroup_invocation_id = true;
|
||||
}
|
||||
|
||||
if (need_subgroup_ge_mask && builtin == BuiltInSubgroupSize)
|
||||
if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize)
|
||||
{
|
||||
builtin_subgroup_size_id = var.self;
|
||||
mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
|
||||
@ -593,7 +593,7 @@ void CompilerMSL::build_implicit_builtins()
|
||||
mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
|
||||
}
|
||||
|
||||
if (!has_subgroup_size && need_subgroup_ge_mask)
|
||||
if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size))
|
||||
{
|
||||
uint32_t offset = ir.increase_bound_by(2);
|
||||
uint32_t type_ptr_id = offset;
|
||||
@ -1280,6 +1280,8 @@ void CompilerMSL::preprocess_op_codes()
|
||||
|
||||
if (preproc.needs_subgroup_invocation_id)
|
||||
needs_subgroup_invocation_id = true;
|
||||
if (preproc.needs_subgroup_size)
|
||||
needs_subgroup_size = true;
|
||||
}
|
||||
|
||||
// Move the Private and Workgroup global variables to the entry function.
|
||||
@ -4631,8 +4633,11 @@ void CompilerMSL::emit_custom_functions()
|
||||
break;
|
||||
|
||||
case SPVFuncImplSubgroupBallotFindLSB:
|
||||
statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot)");
|
||||
statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)");
|
||||
begin_scope();
|
||||
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
|
||||
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
|
||||
statement("ballot &= mask;");
|
||||
statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
|
||||
"ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
|
||||
end_scope();
|
||||
@ -4640,8 +4645,11 @@ void CompilerMSL::emit_custom_functions()
|
||||
break;
|
||||
|
||||
case SPVFuncImplSubgroupBallotFindMSB:
|
||||
statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot)");
|
||||
statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)");
|
||||
begin_scope();
|
||||
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
|
||||
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
|
||||
statement("ballot &= mask;");
|
||||
statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
|
||||
"(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
|
||||
"ballot.z == 0), ballot.w == 0);");
|
||||
@ -4650,24 +4658,31 @@ void CompilerMSL::emit_custom_functions()
|
||||
break;
|
||||
|
||||
case SPVFuncImplSubgroupBallotBitCount:
|
||||
statement("inline uint spvSubgroupBallotBitCount(uint4 ballot)");
|
||||
statement("inline uint spvPopCount4(uint4 ballot)");
|
||||
begin_scope();
|
||||
statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
|
||||
end_scope();
|
||||
statement("");
|
||||
statement("inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)");
|
||||
begin_scope();
|
||||
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
|
||||
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
|
||||
statement("return spvPopCount4(ballot & mask);");
|
||||
end_scope();
|
||||
statement("");
|
||||
statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
|
||||
begin_scope();
|
||||
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
|
||||
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
|
||||
"uint2(0));");
|
||||
statement("return spvSubgroupBallotBitCount(ballot & mask);");
|
||||
statement("return spvPopCount4(ballot & mask);");
|
||||
end_scope();
|
||||
statement("");
|
||||
statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
|
||||
begin_scope();
|
||||
statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
|
||||
"extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
|
||||
statement("return spvSubgroupBallotBitCount(ballot & mask);");
|
||||
statement("return spvPopCount4(ballot & mask);");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
@ -11891,26 +11906,33 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotFindLSB:
|
||||
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindLSB");
|
||||
emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotFindMSB:
|
||||
emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindMSB");
|
||||
emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB");
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotBitCount:
|
||||
{
|
||||
auto operation = static_cast<GroupOperation>(ops[3]);
|
||||
if (operation == GroupOperationReduce)
|
||||
emit_unary_func_op(result_type, id, ops[4], "spvSubgroupBallotBitCount");
|
||||
else if (operation == GroupOperationInclusiveScan)
|
||||
switch (operation)
|
||||
{
|
||||
case GroupOperationReduce:
|
||||
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_size_id, "spvSubgroupBallotBitCount");
|
||||
break;
|
||||
case GroupOperationInclusiveScan:
|
||||
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
|
||||
"spvSubgroupBallotInclusiveBitCount");
|
||||
else if (operation == GroupOperationExclusiveScan)
|
||||
break;
|
||||
case GroupOperationExclusiveScan:
|
||||
emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
|
||||
"spvSubgroupBallotExclusiveBitCount");
|
||||
else
|
||||
break;
|
||||
default:
|
||||
SPIRV_CROSS_THROW("Invalid BitCount operation.");
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@ -13010,8 +13032,15 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
|
||||
needs_subgroup_invocation_id = true;
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotFindLSB:
|
||||
case OpGroupNonUniformBallotFindMSB:
|
||||
needs_subgroup_size = true;
|
||||
break;
|
||||
|
||||
case OpGroupNonUniformBallotBitCount:
|
||||
if (args[3] != GroupOperationReduce)
|
||||
if (args[3] == GroupOperationReduce)
|
||||
needs_subgroup_size = true;
|
||||
else
|
||||
needs_subgroup_invocation_id = true;
|
||||
break;
|
||||
|
||||
|
@ -913,6 +913,7 @@ protected:
|
||||
bool used_swizzle_buffer = false;
|
||||
bool added_builtin_tess_level = false;
|
||||
bool needs_subgroup_invocation_id = false;
|
||||
bool needs_subgroup_size = false;
|
||||
std::string qual_pos_var_name;
|
||||
std::string stage_in_var_name = "in";
|
||||
std::string stage_out_var_name = "out";
|
||||
@ -984,6 +985,7 @@ protected:
|
||||
bool uses_atomics = false;
|
||||
bool uses_resource_write = false;
|
||||
bool needs_subgroup_invocation_id = false;
|
||||
bool needs_subgroup_size = false;
|
||||
};
|
||||
|
||||
// OpcodeHandler that scans for uses of sampled images
|
||||
|
Loading…
Reference in New Issue
Block a user