MSL: Fixes from review for SPV_KHR_physical_storage_buffer extension.

- Assign ulongn physical type to buffer pointers in short arrays
  when array stride is larger than pointer size.
- Support GL_EXT_buffer_reference_uvec2 casting
  buffer reference pointers to and from uvec2 values.
- When packing structs, include structs inside physical buffers.
- Update mechanism for traversing pointer arrays when calculating type sizes.
- Added unit test shaders.
This commit is contained in:
Bill Hollings 2022-07-01 16:10:41 -04:00
parent 78eb5043f9
commit 4185acc70d
10 changed files with 181 additions and 42 deletions

View File

@ -0,0 +1,28 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO;
struct UBO
{
uint2 b;
};
struct SSBO
{
packed_float3 a1;
float a2;
};
kernel void main0(constant UBO& _10 [[buffer(0)]])
{
((device SSBO*)as_type<uint64_t>(_10.b))->a1 = float3(1.0, 2.0, 3.0);
uint2 _35 = as_type<uint2>((uint64_t)((device SSBO*)as_type<uint64_t>(_10.b + uint2(32u))));
uint2 v2 = _35;
device SSBO* _39 = ((device SSBO*)as_type<uint64_t>(_35));
float3 v3 = float3(_39->a1);
_39->a1 = float3(_39->a1) + float3(1.0);
}

View File

@ -9,7 +9,7 @@ struct t24
{
int4 m0[2];
int m1;
device t21* m2[2];
ulong2 m2[2];
device t21* m3;
float2x4 m4;
};
@ -18,7 +18,7 @@ struct t21
{
int4 m0[2];
int m1;
device t21* m2[2];
ulong2 m2[2];
device t21* m3;
float2x4 m4;
};
@ -47,33 +47,33 @@ kernel void main0(constant t24& u24 [[buffer(0)]], constant t35& u35 [[buffer(1)
v8 = _75;
int _82 = _75 | int(u24.m4[1u][1] - 6.0);
v8 = _82;
int _92 = _82 | (u24.m2[0]->m0[0].x - 3);
int _92 = _82 | (((device t21*)u24.m2[0].x)->m0[0].x - 3);
v8 = _92;
int _101 = _92 | (u24.m2[0]->m0[u35.m0[1]].x - 4);
int _101 = _92 | (((device t21*)u24.m2[0].x)->m0[u35.m0[1]].x - 4);
v8 = _101;
int _109 = _101 | (u24.m2[0]->m1 - 5);
int _109 = _101 | (((device t21*)u24.m2[0].x)->m1 - 5);
v8 = _109;
int _118 = _109 | int(u24.m2[0]->m4[0u][0] - 6.0);
int _118 = _109 | int(((device t21*)u24.m2[0].x)->m4[0u][0] - 6.0);
v8 = _118;
int _127 = _118 | int(u24.m2[0]->m4[1u][0] - 8.0);
int _127 = _118 | int(((device t21*)u24.m2[0].x)->m4[1u][0] - 8.0);
v8 = _127;
int _136 = _127 | int(u24.m2[0]->m4[0u][1] - 7.0);
int _136 = _127 | int(((device t21*)u24.m2[0].x)->m4[0u][1] - 7.0);
v8 = _136;
int _145 = _136 | int(u24.m2[0]->m4[1u][1] - 9.0);
int _145 = _136 | int(((device t21*)u24.m2[0].x)->m4[1u][1] - 9.0);
v8 = _145;
int _155 = _145 | (u24.m2[u35.m0[1]]->m0[0].x - 6);
int _155 = _145 | (((device t21*)u24.m2[u35.m0[1]].x)->m0[0].x - 6);
v8 = _155;
int _167 = _155 | (u24.m2[u35.m0[1]]->m0[u35.m0[1]].x - 7);
int _167 = _155 | (((device t21*)u24.m2[u35.m0[1]].x)->m0[u35.m0[1]].x - 7);
v8 = _167;
int _177 = _167 | (u24.m2[u35.m0[1]]->m1 - 8);
int _177 = _167 | (((device t21*)u24.m2[u35.m0[1]].x)->m1 - 8);
v8 = _177;
int _187 = _177 | int(u24.m2[u35.m0[1]]->m4[0u][0] - 9.0);
int _187 = _177 | int(((device t21*)u24.m2[u35.m0[1]].x)->m4[0u][0] - 9.0);
v8 = _187;
int _198 = _187 | int(u24.m2[u35.m0[1]]->m4[1u][0] - 11.0);
int _198 = _187 | int(((device t21*)u24.m2[u35.m0[1]].x)->m4[1u][0] - 11.0);
v8 = _198;
int _209 = _198 | int(u24.m2[u35.m0[1]]->m4[0u][1] - 10.0);
int _209 = _198 | int(((device t21*)u24.m2[u35.m0[1]].x)->m4[0u][1] - 10.0);
v8 = _209;
int _220 = _209 | int(u24.m2[u35.m0[1]]->m4[1u][1] - 12.0);
int _220 = _209 | int(((device t21*)u24.m2[u35.m0[1]].x)->m4[1u][1] - 12.0);
v8 = _220;
int _228 = _220 | (u24.m3->m0[0].x - 9);
v8 = _228;

View File

@ -0,0 +1,26 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO;
struct UBO
{
uint2 b;
};
struct SSBO
{
packed_float3 a1;
float a2;
};
kernel void main0(constant UBO& _10 [[buffer(0)]])
{
((device SSBO*)as_type<uint64_t>(_10.b))->a1 = float3(1.0, 2.0, 3.0);
uint2 v2 = as_type<uint2>((uint64_t)((device SSBO*)as_type<uint64_t>(_10.b + uint2(32u))));
float3 v3 = float3(((device SSBO*)as_type<uint64_t>(v2))->a1);
((device SSBO*)as_type<uint64_t>(v2))->a1 = v3 + float3(1.0);
}

View File

@ -9,7 +9,7 @@ struct t24
{
int4 m0[2];
int m1;
device t21* m2[2];
ulong2 m2[2];
device t21* m3;
float2x4 m4;
};
@ -18,7 +18,7 @@ struct t21
{
int4 m0[2];
int m1;
device t21* m2[2];
ulong2 m2[2];
device t21* m3;
float2x4 m4;
};
@ -40,20 +40,20 @@ kernel void main0(constant t24& u24 [[buffer(0)]], constant t35& u35 [[buffer(1)
v8 |= int(u24.m4[1u][0] - 5.0);
v8 |= int(u24.m4[0u][1] - 4.0);
v8 |= int(u24.m4[1u][1] - 6.0);
v8 |= (u24.m2[0]->m0[0].x - 3);
v8 |= (u24.m2[0]->m0[u35.m0[1]].x - 4);
v8 |= (u24.m2[0]->m1 - 5);
v8 |= int(u24.m2[0]->m4[0u][0] - 6.0);
v8 |= int(u24.m2[0]->m4[1u][0] - 8.0);
v8 |= int(u24.m2[0]->m4[0u][1] - 7.0);
v8 |= int(u24.m2[0]->m4[1u][1] - 9.0);
v8 |= (u24.m2[u35.m0[1]]->m0[0].x - 6);
v8 |= (u24.m2[u35.m0[1]]->m0[u35.m0[1]].x - 7);
v8 |= (u24.m2[u35.m0[1]]->m1 - 8);
v8 |= int(u24.m2[u35.m0[1]]->m4[0u][0] - 9.0);
v8 |= int(u24.m2[u35.m0[1]]->m4[1u][0] - 11.0);
v8 |= int(u24.m2[u35.m0[1]]->m4[0u][1] - 10.0);
v8 |= int(u24.m2[u35.m0[1]]->m4[1u][1] - 12.0);
v8 |= (((device t21*)u24.m2[0].x)->m0[0].x - 3);
v8 |= (((device t21*)u24.m2[0].x)->m0[u35.m0[1]].x - 4);
v8 |= (((device t21*)u24.m2[0].x)->m1 - 5);
v8 |= int(((device t21*)u24.m2[0].x)->m4[0u][0] - 6.0);
v8 |= int(((device t21*)u24.m2[0].x)->m4[1u][0] - 8.0);
v8 |= int(((device t21*)u24.m2[0].x)->m4[0u][1] - 7.0);
v8 |= int(((device t21*)u24.m2[0].x)->m4[1u][1] - 9.0);
v8 |= (((device t21*)u24.m2[u35.m0[1]].x)->m0[0].x - 6);
v8 |= (((device t21*)u24.m2[u35.m0[1]].x)->m0[u35.m0[1]].x - 7);
v8 |= (((device t21*)u24.m2[u35.m0[1]].x)->m1 - 8);
v8 |= int(((device t21*)u24.m2[u35.m0[1]].x)->m4[0u][0] - 9.0);
v8 |= int(((device t21*)u24.m2[u35.m0[1]].x)->m4[1u][0] - 11.0);
v8 |= int(((device t21*)u24.m2[u35.m0[1]].x)->m4[0u][1] - 10.0);
v8 |= int(((device t21*)u24.m2[u35.m0[1]].x)->m4[1u][1] - 12.0);
v8 |= (u24.m3->m0[0].x - 9);
v8 |= (u24.m3->m0[u35.m0[1]].x - 10);
v8 |= (u24.m3->m1 - 11);

View File

@ -0,0 +1,22 @@
#version 450
#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference_uvec2 : require
layout(buffer_reference, buffer_reference_align = 4) buffer SSBO
{
vec3 a1; // Will be 12-byte packed
float a2;
};
layout(push_constant) uniform UBO
{
uvec2 b;
};
void main()
{
SSBO(b).a1 = vec3(1.0, 2.0, 3.0); // uvec2 -> buff ref and assign to packed
uvec2 v2 = uvec2(SSBO(b + 32)); // uvec2 -> buff ref -> uvec2
vec3 v3 = SSBO(v2).a1; // uvec2 -> buff ref and assign from packed
SSBO(v2).a1 = v3 + 1.0; // uvec2 -> buff ref and assign to packed
}

View File

@ -8981,6 +8981,7 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
if (!is_literal)
mod_flags &= ~ACCESS_CHAIN_INDEX_IS_LITERAL_BIT;
access_chain_internal_append_index(expr, base, type, mod_flags, access_chain_is_arrayed, index);
check_physical_type_cast(expr, type, physical_type);
};
for (uint32_t i = 0; i < count; i++)
@ -9313,6 +9314,10 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
return expr;
}
void CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uint32_t)
{
}
void CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &)
{
}

View File

@ -711,6 +711,7 @@ protected:
spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);
virtual void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type);
virtual void prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
spv::StorageClass storage, bool &is_packed);

View File

@ -1967,6 +1967,13 @@ void CompilerMSL::mark_packable_structs()
mark_as_packable(type);
}
});
// Physical storage buffer pointers can appear outside of the context of a variable, if the address
// is calculated from a ulong or uvec2 and cast to a pointer, so check if they need to be packed too.
ir.for_each_typed_id<SPIRType>([&](uint32_t, SPIRType &type) {
if (type.basetype == SPIRType::Struct && type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
mark_as_packable(type);
});
}
// If the specified type is a struct, it and any nested structs
@ -4325,8 +4332,17 @@ void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t in
auto physical_type = mbr_type;
physical_type.vecsize = elems_per_stride;
if (!is_buff_ptr)
physical_type.parent_type = 0;
physical_type.parent_type = 0;
// If this is a physical buffer pointer, replace type with a ulongn vector.
if (is_buff_ptr)
{
physical_type.width = 64;
physical_type.basetype = to_unsigned_basetype(physical_type.width);
physical_type.pointer = false;
physical_type.pointer_depth = false;
physical_type.forward_pointer = false;
}
uint32_t type_id = ir.increase_bound_by(1);
set<SPIRType>(type_id, physical_type);
@ -7709,6 +7725,23 @@ void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t
set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, interface_index);
}
// If the physical type of a physical buffer pointer has been changed
// to a ulong or ulongn vector, add a cast back to the pointer type.
void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type)
{
auto *p_physical_type = maybe_get<SPIRType>(physical_type);
if (p_physical_type &&
p_physical_type->storage == StorageClassPhysicalStorageBuffer &&
p_physical_type->basetype == to_unsigned_basetype(64))
{
if (p_physical_type->vecsize > 1)
expr += ".x";
expr = join("((", type_to_glsl(*type), ")", expr, ")");
}
}
// Override for MSL-specific syntax instructions
void CompilerMSL::emit_instruction(const Instruction &instruction)
{
@ -14362,18 +14395,31 @@ string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in
// size (eg. short shift right becomes int), which means chaining integer ops
// together may introduce size variations that SPIR-V doesn't know about.
if (same_size_cast && !integral_cast)
{
return "as_type<" + type_to_glsl(out_type) + ">";
}
else
{
return type_to_glsl(out_type);
}
}
bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
bool CompilerMSL::emit_complex_bitcast(uint32_t result_type, uint32_t id, uint32_t op0)
{
return false;
auto &out_type = get<SPIRType>(result_type);
auto &in_type = expression_type(op0);
bool uvec2_to_ptr = (in_type.basetype == SPIRType::UInt && in_type.vecsize == 2 &&
out_type.pointer && out_type.storage == StorageClassPhysicalStorageBuffer);
bool ptr_to_uvec2 = (in_type.pointer && in_type.storage == StorageClassPhysicalStorageBuffer &&
out_type.basetype == SPIRType::UInt && out_type.vecsize == 2);
string expr;
// Casting between uvec2 and buffer storage pointer per GL_EXT_buffer_reference_uvec2
if (uvec2_to_ptr)
expr = join("((", type_to_glsl(out_type), ")as_type<uint64_t>(", to_unpacked_expression(op0), "))");
else if (ptr_to_uvec2)
expr = join("as_type<", type_to_glsl(out_type), ">((uint64_t)", to_unpacked_expression(op0), ")");
else
return false;
emit_op(result_type, id, expr, should_forward(op0));
return true;
}
// Returns an MSL string identifying the name of a SPIR-V builtin.
@ -15050,9 +15096,18 @@ uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_p
if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
{
uint32_t type_size = 8 * (type.vecsize == 3 ? 4 : type.vecsize);
// Work our way through potentially layered arrays,
// stopping when we hit a pointer that is not also an array.
size_t dim_cnt = type.array.size();
for (uint32_t dim_idx = 0; dim_idx < dim_cnt; dim_idx++)
type_size *= to_array_size_literal(type, dim_idx);
int32_t dim_idx = dim_cnt - 1;
auto *p_type = &type;
while (!type_is_pointer(*p_type) && dim_idx >= 0)
{
type_size *= to_array_size_literal(*p_type, dim_idx);
p_type = &get<SPIRType>(p_type->parent_type);
dim_idx--;
}
return type_size;
}

View File

@ -986,6 +986,8 @@ protected:
void prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage,
bool &is_packed) override;
void fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length);
void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override;
bool emit_tessellation_access_chain(const uint32_t *ops, uint32_t length);
bool emit_tessellation_io_load(uint32_t result_type, uint32_t id, uint32_t ptr);
bool is_out_of_bounds_tessellation_level(uint32_t id_lhs);