MSL: Consider pointer arithmetic for OpPtrAccessChain.

If the stride is weird for non-struct types you gotta do what you gotta
do.
This commit is contained in:
Hans-Kristian Arntzen 2024-06-19 13:43:08 +02:00
parent 4b27b458c5
commit bc105b6ad0
4 changed files with 56 additions and 2 deletions

View File

@ -10286,6 +10286,26 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
}
else
{
if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT)
{
SPIRType tmp_type(OpTypeInt);
tmp_type.basetype = SPIRType::UInt64;
tmp_type.width = 64;
tmp_type.vecsize = 1;
tmp_type.columns = 1;
TypeID ptr_type_id = expression_type_id(base);
const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
// This only runs in native pointer backends.
// Can replace reinterpret_cast with a backend string if ever needed.
// We expect this to count as a de-reference.
auto intptr_expr = join("*reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
get_decoration(ptr_type_id, DecorationArrayStride));
expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
}
else
append_index(index, is_literal, true);
}
@ -10706,6 +10726,11 @@ string CompilerGLSL::to_flattened_struct_member(const string &basename, const SP
return ret;
}
uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const
{
SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support.");
}
string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,
AccessChainMeta *meta, bool ptr_chain)
{
@ -10755,7 +10780,22 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
{
AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT;
if (ptr_chain)
{
flags |= ACCESS_CHAIN_PTR_CHAIN_BIT;
// PtrAccessChain could get complicated.
TypeID type_id = expression_type_id(base);
if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride))
{
// If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
// Using packed hacks only gets us so far, and is not designed to deal with pointer to
// random values. It works for structs though.
uint32_t physical_stride = get_physical_type_stride(get_pointee_type(get<SPIRType>(type_id)));
uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
if (physical_stride != requested_stride)
flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
}
}
return access_chain_internal(base, indices, count, flags, meta);
}
}

View File

@ -66,7 +66,8 @@ enum AccessChainFlagBits
ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3,
ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7
};
typedef uint32_t AccessChainFlags;
@ -753,6 +754,10 @@ protected:
std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
AccessChainMeta *meta);
// Only meaningful on backends with physical pointer support (i.e. MSL).
// Relevant for PtrAccessChain.
virtual uint32_t get_physical_type_stride(const SPIRType &type) const;
spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);

View File

@ -17050,6 +17050,13 @@ uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type,
return msl_size;
}
uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const
{
// This should only be relevant for plain types such as scalars and vectors?
// If we're pointing to a struct, it will recursively pick up packed/row-major state.
return get_declared_type_size_msl(type, false, false);
}
// Returns the byte size of a struct member.
uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
{

View File

@ -1028,6 +1028,8 @@ protected:
uint32_t get_physical_tess_level_array_size(spv::BuiltIn builtin) const;
uint32_t get_physical_type_stride(const SPIRType &type) const override;
// MSL packing rules. These compute the effective packing rules as observed by the MSL compiler in the MSL output.
// These values can change depending on various extended decorations which control packing rules.
// We need to make these rules match up with SPIR-V declared rules.