Merge pull request #2346 from KhronosGroup/fix-2336
MSL: Handle OpPtrAccessChain with ArrayStride
commit d79ba7d714
@@ -0,0 +1,22 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

struct Registers
{
    device float3* a;
    device float3* b;
    uint2 c;
    uint2 d;
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(64u, 1u, 1u);

kernel void main0(constant Registers& _7 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
    device float3* _41 = reinterpret_cast<device float3*>(as_type<ulong>(_7.c));
    *reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_7.a) + gl_GlobalInvocationID.x * 12) = float3(*reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_7.a) + gl_GlobalInvocationID.x * 12)) + _7.b[gl_GlobalInvocationID.x];
    *reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_41) + gl_GlobalInvocationID.x * 12) = float3(*reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_41) + gl_GlobalInvocationID.x * 12)) + (reinterpret_cast<device float3*>(as_type<ulong>(_7.d)))[gl_GlobalInvocationID.x];
}

@@ -0,0 +1,98 @@
; SPIR-V
; Version: 1.0
; Generator: Khronos Glslang Reference Front End; 11
; Bound: 66
; Schema: 0
OpCapability Shader
OpCapability PhysicalStorageBufferAddresses
OpExtension "SPV_KHR_physical_storage_buffer"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel PhysicalStorageBuffer64 GLSL450
OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationID
OpExecutionMode %main LocalSize 64 1 1
OpSource GLSL 450
OpSourceExtension "GL_EXT_buffer_reference"
OpSourceExtension "GL_EXT_buffer_reference_uvec2"
OpSourceExtension "GL_EXT_scalar_block_layout"
OpName %main "main"
OpName %Registers "Registers"
OpMemberName %Registers 0 "a"
OpMemberName %Registers 1 "b"
OpMemberName %Registers 2 "c"
OpMemberName %Registers 3 "d"
OpName %_ ""
OpName %gl_GlobalInvocationID "gl_GlobalInvocationID"
OpMemberDecorate %Registers 0 Offset 0
OpMemberDecorate %Registers 1 Offset 8
OpMemberDecorate %Registers 2 Offset 16
OpMemberDecorate %Registers 3 Offset 24
OpDecorate %Registers Block
OpDecorate %v3float_stride12_ptr ArrayStride 12
OpDecorate %v3float_stride16_ptr ArrayStride 16
OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
%void = OpTypeVoid
%3 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%v2uint = OpTypeVector %uint 2
%float = OpTypeFloat 32
%v3float = OpTypeVector %float 3
%_ptr_PhysicalStorageBuffer_v3float = OpTypePointer PhysicalStorageBuffer %v3float
%v3float_stride12_ptr = OpTypePointer PhysicalStorageBuffer %v3float
%v3float_stride16_ptr = OpTypePointer PhysicalStorageBuffer %v3float
%v3float_stride12_ptr_push = OpTypePointer PushConstant %v3float_stride12_ptr
%v3float_stride16_ptr_push = OpTypePointer PushConstant %v3float_stride16_ptr
%v2uint_ptr = OpTypePointer PushConstant %v2uint
%Registers = OpTypeStruct %v3float_stride12_ptr %v3float_stride16_ptr %v2uint %v2uint
%_ptr_PushConstant_Registers = OpTypePointer PushConstant %Registers
%_ = OpVariable %_ptr_PushConstant_Registers PushConstant
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
%uint_0 = OpConstant %uint 0
%_ptr_Input_uint = OpTypePointer Input %uint
%int_1 = OpConstant %int 1
%int_2 = OpConstant %int 2
%_ptr_PushConstant_v2uint = OpTypePointer PushConstant %v2uint
%int_3 = OpConstant %int 3
%uint_64 = OpConstant %uint 64
%uint_1 = OpConstant %uint 1
%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_64 %uint_1 %uint_1
%main = OpFunction %void None %3
%5 = OpLabel
%29 = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0
%index = OpLoad %uint %29

%ptr_member_0 = OpAccessChain %v3float_stride12_ptr_push %_ %int_0
%ptr0 = OpLoad %v3float_stride12_ptr %ptr_member_0

%ptr_member_1 = OpAccessChain %v3float_stride16_ptr_push %_ %int_1
%ptr1 = OpLoad %v3float_stride16_ptr %ptr_member_1

%ptr_member_2 = OpAccessChain %v2uint_ptr %_ %int_2
%ptr2v = OpLoad %v2uint %ptr_member_2
%ptr2 = OpBitcast %v3float_stride12_ptr %ptr2v

%ptr_member_3 = OpAccessChain %v2uint_ptr %_ %int_3
%ptr3v = OpLoad %v2uint %ptr_member_3
%ptr3 = OpBitcast %v3float_stride16_ptr %ptr3v

%ptr0_chain = OpPtrAccessChain %v3float_stride12_ptr %ptr0 %index
%ptr1_chain = OpPtrAccessChain %v3float_stride16_ptr %ptr1 %index
%ptr2_chain = OpPtrAccessChain %v3float_stride12_ptr %ptr2 %index
%ptr3_chain = OpPtrAccessChain %v3float_stride16_ptr %ptr3 %index

%loaded0 = OpLoad %v3float %ptr0_chain Aligned 4
%loaded1 = OpLoad %v3float %ptr1_chain Aligned 16
%loaded2 = OpLoad %v3float %ptr2_chain Aligned 4
%loaded3 = OpLoad %v3float %ptr3_chain Aligned 16

%added0 = OpFAdd %v3float %loaded0 %loaded1
%added1 = OpFAdd %v3float %loaded2 %loaded3
OpStore %ptr0_chain %added0 Aligned 4
OpStore %ptr2_chain %added1 Aligned 4

OpReturn
OpFunctionEnd
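Note on the pair of test files above: with GL_EXT_scalar_block_layout, the stride-12 pointer type in the SPIR-V declares ArrayStride 12 for a float3, while MSL's native float3 stride is 16, so the backend cannot simply emit ptr[index]. The reference MSL instead does byte arithmetic and a packed_float3 load/store. A minimal host-side C++ sketch of that addressing (packed_vec3 and load_strided are illustrative names, not SPIRV-Cross code):

#include <cstdint>
#include <cstring>

// 12-byte element, analogous to MSL packed_float3 (a plain float3 is padded to 16 bytes).
struct packed_vec3 { float x, y, z; };

// Advance by the declared ArrayStride in bytes rather than by sizeof(element),
// then read the 12-byte packed element at that address.
static packed_vec3 load_strided(const void *base, uint32_t index, uint32_t array_stride)
{
    auto addr = reinterpret_cast<uintptr_t>(base) + uintptr_t(index) * array_stride;
    packed_vec3 v;
    std::memcpy(&v, reinterpret_cast<const void *>(addr), sizeof(v));
    return v;
}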
@@ -5213,7 +5213,8 @@ string CompilerGLSL::to_enclosed_unpacked_expression(uint32_t id, bool register_
string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expression_read)
{
	auto &type = expression_type(id);
-	if (type.pointer && should_dereference(id))
+
+	if (is_pointer(type) && should_dereference(id))
		return dereference_expression(type, to_enclosed_expression(id, register_expression_read));
	else
		return to_expression(id, register_expression_read);
@@ -5222,7 +5223,7 @@ string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expre
string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression_read)
{
	auto &type = expression_type(id);
-	if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
+	if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
		return address_of_expression(to_enclosed_expression(id, register_expression_read));
	else
		return to_unpacked_expression(id, register_expression_read);
@@ -5231,7 +5232,7 @@ string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression
string CompilerGLSL::to_enclosed_pointer_expression(uint32_t id, bool register_expression_read)
{
	auto &type = expression_type(id);
-	if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
+	if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
		return address_of_expression(to_enclosed_expression(id, register_expression_read));
	else
		return to_enclosed_unpacked_expression(id, register_expression_read);
@@ -10286,7 +10287,40 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
		}
		else
		{
-			append_index(index, is_literal, true);
+			if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT)
+			{
+				SPIRType tmp_type(OpTypeInt);
+				tmp_type.basetype = SPIRType::UInt64;
+				tmp_type.width = 64;
+				tmp_type.vecsize = 1;
+				tmp_type.columns = 1;
+
+				TypeID ptr_type_id = expression_type_id(base);
+				const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
+				const SPIRType &pointee_type = get_pointee_type(ptr_type);
+
+				// This only runs in native pointer backends.
+				// Can replace reinterpret_cast with a backend string if ever needed.
+				// We expect this to count as a de-reference.
+				// This leaks some MSL details, but feels slightly overkill to
+				// add yet another virtual interface just for this.
+				auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
+				intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
+				                    get_decoration(ptr_type_id, DecorationArrayStride));
+
+				if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT)
+				{
+					is_packed = true;
+					expr = join("*reinterpret_cast<device packed_", type_to_glsl(pointee_type),
+					            " *>(", intptr_expr, ")");
+				}
+				else
+				{
+					expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
+				}
+			}
+			else
+				append_index(index, is_literal, true);
		}

		if (type->basetype == SPIRType::ControlPointArray)
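For the test above, this new path is what rewrites the pointer chain on _7.a into the byte-offset form seen in the reference MSL. A rough standalone sketch of that string rewrite (plain C++ with std::ostringstream standing in for SPIRV-Cross's join(); the function name is hypothetical):

#include <cstdint>
#include <sstream>
#include <string>

// Builds the POINTER_ARITH + CAST_TO_SCALAR form:
// *reinterpret_cast<device packed_T *>(reinterpret_cast<ulong>(expr) + index * stride)
static std::string rewrite_ptr_access_chain(const std::string &expr, const std::string &index,
                                            uint32_t array_stride, const std::string &pointee_type)
{
    std::ostringstream os;
    os << "*reinterpret_cast<device packed_" << pointee_type << " *>("
       << "reinterpret_cast<ulong>(" << expr << ") + " << index << " * " << array_stride << ")";
    return os.str();
}

// rewrite_ptr_access_chain("_7.a", "gl_GlobalInvocationID.x", 12, "float3")
// reproduces the expression emitted in the reference output.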
@@ -10706,6 +10740,11 @@ string CompilerGLSL::to_flattened_struct_member(const string &basename, const SP
	return ret;
}

+uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const
+{
+	SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support.");
+}
+
string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,
                                  AccessChainMeta *meta, bool ptr_chain)
{
@@ -10755,7 +10794,27 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
	{
		AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT;
		if (ptr_chain)
+		{
			flags |= ACCESS_CHAIN_PTR_CHAIN_BIT;
+			// PtrAccessChain could get complicated.
+			TypeID type_id = expression_type_id(base);
+			if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride))
+			{
+				// If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
+				// Using packed hacks only gets us so far, and is not designed to deal with pointer to
+				// random values. It works for structs though.
+				auto &pointee_type = get_pointee_type(get<SPIRType>(type_id));
+				uint32_t physical_stride = get_physical_type_stride(pointee_type);
+				uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
+				if (physical_stride != requested_stride)
+				{
+					flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
+					if (is_vector(pointee_type))
+						flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT;
+				}
+			}
+		}
+
		return access_chain_internal(base, indices, count, flags, meta);
	}
}
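Worked values for the test shader (a simplified sketch of the stride comparison, not the actual member function; the 16-byte float3 stride is taken from Metal's layout rules):

#include <cstdint>

// Stand-ins for the flag bits added in this commit.
enum : uint32_t
{
    PTR_CHAIN_POINTER_ARITH_BIT = 1u << 7,
    PTR_CHAIN_CAST_TO_SCALAR_BIT = 1u << 8
};

// physical_stride: what MSL naturally uses for the pointee (16 for float3).
// requested_stride: DecorationArrayStride on the pointer type (12 with scalar layout, 16 with std430).
static uint32_t ptr_chain_flags(uint32_t physical_stride, uint32_t requested_stride, bool pointee_is_vector)
{
    uint32_t flags = 0;
    if (physical_stride != requested_stride)
    {
        flags |= PTR_CHAIN_POINTER_ARITH_BIT;      // fall back to 64-bit byte arithmetic
        if (pointee_is_vector)
            flags |= PTR_CHAIN_CAST_TO_SCALAR_BIT; // load/store through the packed_ vector type
    }
    return flags;
}

// ptr_chain_flags(16, 12, true) -> both bits set: the stride-12 pointers in the test.
// ptr_chain_flags(16, 16, true) -> 0: the stride-16 pointers keep plain ptr[index] indexing.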
@@ -66,7 +66,9 @@ enum AccessChainFlagBits
	ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3,
	ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
	ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
-	ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6
+	ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
+	ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7,
+	ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8
};
typedef uint32_t AccessChainFlags;

@@ -753,6 +755,10 @@ protected:
	std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
	                                  AccessChainMeta *meta);

+	// Only meaningful on backends with physical pointer support ala MSL.
+	// Relevant for PtrAccessChain / BDA.
+	virtual uint32_t get_physical_type_stride(const SPIRType &type) const;
+
	spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
	virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);

@@ -4803,7 +4803,7 @@ bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32
		return false;
	}

-	if (!mbr_type.array.empty())
+	if (is_array(mbr_type))
	{
		// If we have an array type, array stride must match exactly with SPIR-V.
@@ -17050,13 +17050,21 @@ uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type,
	return msl_size;
}

+uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const
+{
+	// This should only be relevant for plain types such as scalars and vectors?
+	// If we're pointing to a struct, it will recursively pick up packed/row-major state.
+	return get_declared_type_size_msl(type, false, false);
+}
+
// Returns the byte size of a struct member.
uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
	// Pointers take 8 bytes each
+	// Match both pointer and array-of-pointer here.
	if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
	{
-		uint32_t type_size = 8 * (type.vecsize == 3 ? 4 : type.vecsize);
+		uint32_t type_size = 8;

		// Work our way through potentially layered arrays,
		// stopping when we hit a pointer that is not also an array.
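The stride used on the MSL side of that comparison is just the natural (unpacked, non-row-major) size of the pointee, so float3 reports 16 while the SPIR-V pointer may be decorated with ArrayStride 12. A tiny illustrative sketch of those numbers (hard-coded from Metal's usual layout rules, not queried from the compiler):

#include <cstdint>

// Natural MSL stride for a float vector: a 3-component vector pads to 4 components
// unless the packed_ form is used.
static uint32_t assumed_msl_vector_stride(uint32_t scalar_bytes, uint32_t vecsize)
{
    uint32_t effective = (vecsize == 3) ? 4 : vecsize;
    return scalar_bytes * effective;
}

// assumed_msl_vector_stride(4, 3) == 16, whereas scalar block layout can declare
// ArrayStride 12 for the same pointer; that mismatch is what routes access_chain()
// through the pointer-arithmetic path above.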
@@ -17131,9 +17139,10 @@ uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t
// Returns the byte alignment of a type.
uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
-	// Pointers aligns on multiples of 8 bytes
+	// Pointers align on multiples of 8 bytes.
+	// Deliberately ignore array-ness here. It's not relevant for alignment.
	if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
-		return 8 * (type.vecsize == 3 ? 4 : type.vecsize);
+		return 8;

	switch (type.basetype)
	{
@@ -1028,6 +1028,8 @@ protected:

	uint32_t get_physical_tess_level_array_size(spv::BuiltIn builtin) const;

+	uint32_t get_physical_type_stride(const SPIRType &type) const override;
+
	// MSL packing rules. These compute the effective packing rules as observed by the MSL compiler in the MSL output.
	// These values can change depending on various extended decorations which control packing rules.
	// We need to make these rules match up with SPIR-V declared rules.