MSL: Consider pointer arithmetic for OpPtrAccessChain.
If the stride is weird for non-struct types you gotta do what you gotta do.
This commit is contained in:
parent
4b27b458c5
commit
bc105b6ad0
@ -10286,7 +10286,27 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
|
||||
}
|
||||
else
|
||||
{
|
||||
append_index(index, is_literal, true);
|
||||
if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT)
|
||||
{
|
||||
SPIRType tmp_type(OpTypeInt);
|
||||
tmp_type.basetype = SPIRType::UInt64;
|
||||
tmp_type.width = 64;
|
||||
tmp_type.vecsize = 1;
|
||||
tmp_type.columns = 1;
|
||||
|
||||
TypeID ptr_type_id = expression_type_id(base);
|
||||
const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
|
||||
|
||||
// This only runs in native pointer backends.
|
||||
// Can replace reinterpret_cast with a backend string if ever needed.
|
||||
// We expect this to count as a de-reference.
|
||||
auto intptr_expr = join("*reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
|
||||
intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
|
||||
get_decoration(ptr_type_id, DecorationArrayStride));
|
||||
expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
|
||||
}
|
||||
else
|
||||
append_index(index, is_literal, true);
|
||||
}
|
||||
|
||||
if (type->basetype == SPIRType::ControlPointArray)
|
||||
@ -10706,6 +10726,11 @@ string CompilerGLSL::to_flattened_struct_member(const string &basename, const SP
|
||||
return ret;
|
||||
}
|
||||
|
||||
uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const
|
||||
{
|
||||
SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support.");
|
||||
}
|
||||
|
||||
string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,
|
||||
AccessChainMeta *meta, bool ptr_chain)
|
||||
{
|
||||
@ -10755,7 +10780,22 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
|
||||
{
|
||||
AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT;
|
||||
if (ptr_chain)
|
||||
{
|
||||
flags |= ACCESS_CHAIN_PTR_CHAIN_BIT;
|
||||
// PtrAccessChain could get complicated.
|
||||
TypeID type_id = expression_type_id(base);
|
||||
if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride))
|
||||
{
|
||||
// If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
|
||||
// Using packed hacks only gets us so far, and is not designed to deal with pointer to
|
||||
// random values. It works for structs though.
|
||||
uint32_t physical_stride = get_physical_type_stride(get_pointee_type(get<SPIRType>(type_id)));
|
||||
uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
|
||||
if (physical_stride != requested_stride)
|
||||
flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
|
||||
}
|
||||
}
|
||||
|
||||
return access_chain_internal(base, indices, count, flags, meta);
|
||||
}
|
||||
}
|
||||
|
@ -66,7 +66,8 @@ enum AccessChainFlagBits
|
||||
ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3,
|
||||
ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
|
||||
ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
|
||||
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6
|
||||
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
|
||||
ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7
|
||||
};
|
||||
typedef uint32_t AccessChainFlags;
|
||||
|
||||
@ -753,6 +754,10 @@ protected:
|
||||
std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
|
||||
AccessChainMeta *meta);
|
||||
|
||||
// Only meaningful on backends with physical pointer support (i.e. MSL).
|
||||
// Relevant for PtrAccessChain.
|
||||
virtual uint32_t get_physical_type_stride(const SPIRType &type) const;
|
||||
|
||||
spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
|
||||
virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);
|
||||
|
||||
|
@ -17050,6 +17050,13 @@ uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type,
|
||||
return msl_size;
|
||||
}
|
||||
|
||||
uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const
|
||||
{
|
||||
// This should only be relevant for plain types such as scalars and vectors?
|
||||
// If we're pointing to a struct, it will recursively pick up packed/row-major state.
|
||||
return get_declared_type_size_msl(type, false, false);
|
||||
}
|
||||
|
||||
// Returns the byte size of a struct member.
|
||||
uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
|
||||
{
|
||||
|
@ -1028,6 +1028,8 @@ protected:
|
||||
|
||||
uint32_t get_physical_tess_level_array_size(spv::BuiltIn builtin) const;
|
||||
|
||||
uint32_t get_physical_type_stride(const SPIRType &type) const override;
|
||||
|
||||
// MSL packing rules. These compute the effective packing rules as observed by the MSL compiler in the MSL output.
|
||||
// These values can change depending on various extended decorations which control packing rules.
|
||||
// We need to make these rules match up with SPIR-V declared rules.
|
||||
|
Loading…
Reference in New Issue
Block a user