mirror of
https://github.com/KhronosGroup/SPIRV-Cross.git
synced 2024-11-09 13:50:05 +00:00
MSL: Cast to packed format when using unexpected stride.
We're still technically missing handling for Aligned mask on Load/Store, but that needs separate analysis and gets horribly annoying ... This should cover most use cases.
This commit is contained in:
parent
eeb35a97e9
commit
098427a9ce
@ -10297,14 +10297,27 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
|
|||||||
|
|
||||||
TypeID ptr_type_id = expression_type_id(base);
|
TypeID ptr_type_id = expression_type_id(base);
|
||||||
const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
|
const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
|
||||||
|
const SPIRType &pointee_type = get_pointee_type(ptr_type);
|
||||||
|
|
||||||
// This only runs in native pointer backends.
|
// This only runs in native pointer backends.
|
||||||
// Can replace reinterpret_cast with a backend string if ever needed.
|
// Can replace reinterpret_cast with a backend string if ever needed.
|
||||||
// We expect this to count as a de-reference.
|
// We expect this to count as a de-reference.
|
||||||
auto intptr_expr = join("*reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
|
// This leaks some MSL details, but feels slightly overkill to
|
||||||
|
// add yet another virtual interface just for this.
|
||||||
|
auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
|
||||||
intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
|
intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
|
||||||
get_decoration(ptr_type_id, DecorationArrayStride));
|
get_decoration(ptr_type_id, DecorationArrayStride));
|
||||||
expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
|
|
||||||
|
if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT)
|
||||||
|
{
|
||||||
|
is_packed = true;
|
||||||
|
expr = join("*reinterpret_cast<device packed_", type_to_glsl(pointee_type),
|
||||||
|
" *>(", intptr_expr, ")");
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
append_index(index, is_literal, true);
|
append_index(index, is_literal, true);
|
||||||
@ -10790,10 +10803,15 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
|
|||||||
// If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
|
// If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
|
||||||
// Using packed hacks only gets us so far, and is not designed to deal with pointer to
|
// Using packed hacks only gets us so far, and is not designed to deal with pointer to
|
||||||
// random values. It works for structs though.
|
// random values. It works for structs though.
|
||||||
uint32_t physical_stride = get_physical_type_stride(get_pointee_type(get<SPIRType>(type_id)));
|
auto &pointee_type = get_pointee_type(get<SPIRType>(type_id));
|
||||||
|
uint32_t physical_stride = get_physical_type_stride(pointee_type);
|
||||||
uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
|
uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
|
||||||
if (physical_stride != requested_stride)
|
if (physical_stride != requested_stride)
|
||||||
|
{
|
||||||
flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
|
flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
|
||||||
|
if (is_vector(pointee_type))
|
||||||
|
flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,7 +67,8 @@ enum AccessChainFlagBits
|
|||||||
ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
|
ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
|
||||||
ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
|
ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
|
||||||
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
|
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
|
||||||
ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7
|
ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7,
|
||||||
|
ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8
|
||||||
};
|
};
|
||||||
typedef uint32_t AccessChainFlags;
|
typedef uint32_t AccessChainFlags;
|
||||||
|
|
||||||
@ -754,8 +755,8 @@ protected:
|
|||||||
std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
|
std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
|
||||||
AccessChainMeta *meta);
|
AccessChainMeta *meta);
|
||||||
|
|
||||||
// Only meaningful on backends with physical pointer support (i.e. MSL).
|
// Only meaningful on backends with physical pointer support ala MSL.
|
||||||
// Relevant for PtrAccessChain.
|
// Relevant for PtrAccessChain / BDA.
|
||||||
virtual uint32_t get_physical_type_stride(const SPIRType &type) const;
|
virtual uint32_t get_physical_type_stride(const SPIRType &type) const;
|
||||||
|
|
||||||
spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
|
spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
|
||||||
|
Loading…
Reference in New Issue
Block a user