Merge pull request #2288 from KhronosGroup/fix-2272

MSL: Improve handling of BDA + atomics.
This commit is contained in:
Hans-Kristian Arntzen 2024-03-05 15:03:56 +01:00 committed by GitHub
commit 4db95b762f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 89 additions and 10 deletions

View File

@ -0,0 +1,41 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_atomic>
using namespace metal;
struct Ptr;
struct Registers
{
device Ptr* ptr;
};
struct Ptr
{
uint i;
uint2 i2;
};
struct UBO
{
device Ptr* ptr_ubo;
};
struct SSBO
{
device Ptr* ptr_ssbo;
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
kernel void main0(constant Registers& _12 [[buffer(0)]], constant UBO& _26 [[buffer(1)]], const device SSBO& _35 [[buffer(2)]])
{
uint _23 = atomic_fetch_add_explicit((device atomic_uint*)&_12.ptr->i, 10u, memory_order_relaxed);
uint _32 = atomic_fetch_add_explicit((device atomic_uint*)&_26.ptr_ubo->i, 11u, memory_order_relaxed);
uint _41 = atomic_fetch_add_explicit((device atomic_uint*)&_35.ptr_ssbo->i, 12u, memory_order_relaxed);
uint _51 = atomic_fetch_add_explicit((device atomic_uint*)&(reinterpret_cast<device Ptr*>(as_type<ulong>(_12.ptr->i2)))->i, 13u, memory_order_relaxed);
}

View File

@ -0,0 +1,34 @@
#version 450
#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference_uvec2 : require
layout(local_size_x = 1) in;
layout(buffer_reference) buffer Ptr
{
uint i;
uvec2 i2;
};
layout(push_constant, std430) uniform Registers
{
Ptr ptr;
};
layout(set = 0, binding = 0) uniform UBO
{
Ptr ptr_ubo;
};
layout(set = 0, binding = 1) readonly buffer SSBO
{
Ptr ptr_ssbo;
};
void main()
{
atomicAdd(ptr.i, 10u);
atomicAdd(ptr_ubo.i, 11u);
atomicAdd(ptr_ssbo.i, 12u);
atomicAdd(Ptr(ptr.i2).i, 13u);
}

View File

@ -10127,7 +10127,8 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
{
string exp;
auto &type = get_pointee_type(expression_type(obj));
auto &ptr_type = expression_type(obj);
auto &type = get_pointee_type(ptr_type);
auto expected_type = type.basetype;
if (opcode == OpAtomicUMax || opcode == OpAtomicUMin)
expected_type = to_unsigned_basetype(type.width);
@ -10147,15 +10148,13 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
remapped_type.basetype = expected_type;
auto *var = maybe_get_backing_variable(obj);
if (!var)
SPIRV_CROSS_THROW("No backing variable for atomic operation.");
const auto &res_type = get<SPIRType>(var->basetype);
const auto *res_type = var ? &get<SPIRType>(var->basetype) : nullptr;
assert(type.storage != StorageClassImage || res_type);
bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
bool check_discard = opcode != OpAtomicLoad && needs_frag_discard_checks() &&
((res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image) ||
var->storage == StorageClassStorageBuffer || var->storage == StorageClassUniform);
ptr_type.storage != StorageClassWorkgroup;
// Even compare exchange atomics are vec4 on metal for ... reasons :v
uint32_t vec4_temporary_id = 0;
@ -10199,9 +10198,9 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
{
auto coord = obj_expression.substr(split_index + 1);
exp += join(obj_expression.substr(0, split_index), ".", op, "(");
if (res_type.basetype == SPIRType::Image && res_type.image.arrayed)
if (ptr_type.storage == StorageClassImage && res_type->image.arrayed)
{
switch (res_type.image.dim)
switch (res_type->image.dim)
{
case Dim1D:
exp += join(coord, ".x, ", coord, ".y");
@ -10228,17 +10227,22 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
exp += string(op) + "_explicit(";
exp += "(";
// Emulate texture2D atomic operations
if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
if (ptr_type.storage == StorageClassImage)
{
auto &flags = ir.get_decoration_bitset(var->self);
if (decoration_flags_signal_volatile(flags))
exp += "volatile ";
exp += "device";
}
else
else if (var && ptr_type.storage != StorageClassPhysicalStorageBuffer)
{
exp += get_argument_address_space(*var);
}
else
{
// Fallback scenario, could happen for raw pointers.
exp += ptr_type.storage == StorageClassWorkgroup ? "threadgroup" : "device";
}
exp += " atomic_";
// For signed and unsigned min/max, we can signal this through the pointer type.