MSL: Cast to packed format when using unexpected stride.

We're still technically missing handling for Aligned mask on Load/Store,
but that needs separate analysis and gets horribly annoying ...
This should cover most use cases.
This commit is contained in:
Hans-Kristian Arntzen 2024-06-19 14:24:50 +02:00
parent eeb35a97e9
commit 098427a9ce
2 changed files with 25 additions and 6 deletions

View File

@ -10297,14 +10297,27 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
TypeID ptr_type_id = expression_type_id(base); TypeID ptr_type_id = expression_type_id(base);
const SPIRType &ptr_type = get<SPIRType>(ptr_type_id); const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
const SPIRType &pointee_type = get_pointee_type(ptr_type);
// This only runs in native pointer backends. // This only runs in native pointer backends.
// Can replace reinterpret_cast with a backend string if ever needed. // Can replace reinterpret_cast with a backend string if ever needed.
// We expect this to count as a de-reference. // We expect this to count as a de-reference.
auto intptr_expr = join("*reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")"); // This leaks some MSL details, but feels slightly overkill to
// add yet another virtual interface just for this.
auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ", intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
get_decoration(ptr_type_id, DecorationArrayStride)); get_decoration(ptr_type_id, DecorationArrayStride));
expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT)
{
is_packed = true;
expr = join("*reinterpret_cast<device packed_", type_to_glsl(pointee_type),
" *>(", intptr_expr, ")");
}
else
{
expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
}
} }
else else
append_index(index, is_literal, true); append_index(index, is_literal, true);
@ -10790,10 +10803,15 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
// If there is a mismatch we have to go via 64-bit pointer arithmetic :'( // If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
// Using packed hacks only gets us so far, and is not designed to deal with pointer to // Using packed hacks only gets us so far, and is not designed to deal with pointer to
// random values. It works for structs though. // random values. It works for structs though.
uint32_t physical_stride = get_physical_type_stride(get_pointee_type(get<SPIRType>(type_id))); auto &pointee_type = get_pointee_type(get<SPIRType>(type_id));
uint32_t physical_stride = get_physical_type_stride(pointee_type);
uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride); uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
if (physical_stride != requested_stride) if (physical_stride != requested_stride)
{
flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT; flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
if (is_vector(pointee_type))
flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT;
}
} }
} }

View File

@ -67,7 +67,8 @@ enum AccessChainFlagBits
ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4, ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5, ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6, ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7 ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7,
ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8
}; };
typedef uint32_t AccessChainFlags; typedef uint32_t AccessChainFlags;
@ -754,8 +755,8 @@ protected:
std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags, std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
AccessChainMeta *meta); AccessChainMeta *meta);
// Only meaningful on backends with physical pointer support (i.e. MSL). // Only meaningful on backends with physical pointer support ala MSL.
// Relevant for PtrAccessChain. // Relevant for PtrAccessChain / BDA.
virtual uint32_t get_physical_type_stride(const SPIRType &type) const; virtual uint32_t get_physical_type_stride(const SPIRType &type) const;
spv::StorageClass get_expression_effective_storage_class(uint32_t ptr); spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);