From 098427a9ceaeecd02ded11aa8e781adf01803e60 Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Wed, 19 Jun 2024 14:24:50 +0200 Subject: [PATCH] MSL: Cast to packed format when using unexpected stride. We're still technically missing handling for Aligned mask on Load/Store, but that needs separate analysis and gets horribly annoying ... This should cover most use cases. --- spirv_glsl.cpp | 24 +++++++++++++++++++++--- spirv_glsl.hpp | 7 ++++--- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 2b7595cf..afe18fb2 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -10297,14 +10297,27 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice TypeID ptr_type_id = expression_type_id(base); const SPIRType &ptr_type = get(ptr_type_id); + const SPIRType &pointee_type = get_pointee_type(ptr_type); // This only runs in native pointer backends. // Can replace reinterpret_cast with a backend string if ever needed. // We expect this to count as a de-reference. - auto intptr_expr = join("*reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")"); + // This leaks some MSL details, but feels slightly overkill to + // add yet another virtual interface just for this. + auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")"); intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ", get_decoration(ptr_type_id, DecorationArrayStride)); - expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")"); + + if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT) + { + is_packed = true; + expr = join("*reinterpret_cast(", intptr_expr, ")"); + } + else + { + expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")"); + } } else append_index(index, is_literal, true); @@ -10790,10 +10803,15 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32 // If there is a mismatch we have to go via 64-bit pointer arithmetic :'( // Using packed hacks only gets us so far, and is not designed to deal with pointer to // random values. It works for structs though. - uint32_t physical_stride = get_physical_type_stride(get_pointee_type(get(type_id))); + auto &pointee_type = get_pointee_type(get(type_id)); + uint32_t physical_stride = get_physical_type_stride(pointee_type); uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride); if (physical_stride != requested_stride) + { flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT; + if (is_vector(pointee_type)) + flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT; + } } } diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index 2f3fd96c..8a002632 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -67,7 +67,8 @@ enum AccessChainFlagBits ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4, ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5, ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6, - ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7 + ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7, + ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8 }; typedef uint32_t AccessChainFlags; @@ -754,8 +755,8 @@ protected: std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags, AccessChainMeta *meta); - // Only meaningful on backends with physical pointer support (i.e. MSL). - // Relevant for PtrAccessChain. + // Only meaningful on backends with physical pointer support ala MSL. + // Relevant for PtrAccessChain / BDA. virtual uint32_t get_physical_type_stride(const SPIRType &type) const; spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);