Merge pull request #2234 from KhronosGroup/fix-2226

MSL: Support std140 half matrices and arrays.
This commit is contained in:
Hans-Kristian Arntzen 2023-11-27 17:17:13 +01:00 committed by GitHub
commit 3717660e14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 450 additions and 14 deletions

View File

@ -0,0 +1,64 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Wraps a half vector so it occupies a full 16-byte std140 slot.
// The real payload lives in .data; alignas(16) supplies the padding.
template <typename T>
struct spvPaddedStd140 { alignas(16) T data; };
// A std140 half matrix is lowered to an array of padded columns
// (or padded rows, when the matrix is declared row-major).
template <typename T, int n>
using spvPaddedStd140Matrix = spvPaddedStd140<T>[n];
// std140 SSBO layout: every matrix column/row and every array element
// is widened to a 16-byte stride via the spvPaddedStd140 wrappers.
struct Foo
{
// Column-major half matrices: arrays of padded columns.
spvPaddedStd140Matrix<half3, 2> c23;
spvPaddedStd140Matrix<half2, 3> c32;
// Row-major counterparts: arrays of padded rows (dimensions swapped).
spvPaddedStd140Matrix<half2, 3> r23;
spvPaddedStd140Matrix<half3, 2> r32;
// std140 arrays: each element sits in its own 16-byte slot.
spvPaddedStd140<half> h1[6];
spvPaddedStd140<half2> h2[6];
spvPaddedStd140<half3> h3[6];
spvPaddedStd140<half4> h4[6];
};
// Fragment stage outputs.
struct main0_out
{
float4 FragColor [[color(0)]];
};
// Exercises stores into padded std140 half members: per-scalar,
// per-vector, whole-matrix, and array-element writes.
fragment main0_out main0(device Foo& _20 [[buffer(0)]])
{
main0_out out = {};
// Scalar stores into column-major matrices: cast the padded .data
// payload to a raw half pointer and index the component directly.
((device half*)&_20.c23[1].data)[2u] = half(1.0);
((device half*)&_20.c32[2].data)[1u] = half(2.0);
// Row-major scalar stores: indices are swapped, and the .data member
// is dropped since the cast already yields a raw pointer.
((device half*)&_20.r23[2u])[1] = half(3.0);
((device half*)&_20.r32[1u])[2] = half(4.0);
// Column vector stores go straight into the padded payload.
_20.c23[1].data = half3(half(0.0), half(1.0), half(2.0));
_20.c32[1].data = half2(half(0.0), half(1.0));
// Storing a column into a row-major matrix is unrolled into one
// scalar store per row.
((device half*)&_20.r23[0])[1] = half3(half(0.0), half(1.0), half(2.0)).x;
((device half*)&_20.r23[1])[1] = half3(half(0.0), half(1.0), half(2.0)).y;
((device half*)&_20.r23[2])[1] = half3(half(0.0), half(1.0), half(2.0)).z;
((device half*)&_20.r32[0])[1] = half2(half(0.0), half(1.0)).x;
((device half*)&_20.r32[1])[1] = half2(half(0.0), half(1.0)).y;
// Whole-matrix stores are unrolled column by column.
(device half3&)_20.c23[0] = half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0];
(device half3&)_20.c23[1] = half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1];
(device half2&)_20.c32[0] = half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[0];
(device half2&)_20.c32[1] = half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[1];
(device half2&)_20.c32[2] = half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[2];
// Row-major whole-matrix stores: each destination row gathers one
// component from every source column (transpose on store).
(device half2&)_20.r23[0] = half2(half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0][0], half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1][0]);
(device half2&)_20.r23[1] = half2(half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0][1], half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1][1]);
(device half2&)_20.r23[2] = half2(half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0][2], half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1][2]);
(device half3&)_20.r32[0] = half3(half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[0][0], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[1][0], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[2][0]);
(device half3&)_20.r32[1] = half3(half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[0][1], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[1][1], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[2][1]);
// Array element stores write the padded payload directly.
_20.h1[5].data = half(1.0);
_20.h2[5].data = half2(half(1.0), half(2.0));
_20.h3[5].data = half3(half(1.0), half(2.0), half(3.0));
_20.h4[5].data = half4(half(1.0), half(2.0), half(3.0), half(4.0));
// Scalar stores into array elements again go through a raw pointer cast.
((device half*)&_20.h2[5].data)[1u] = half(10.0);
((device half*)&_20.h3[5].data)[2u] = half(11.0);
((device half*)&_20.h4[5].data)[3u] = half(12.0);
out.FragColor = float4(1.0);
return out;
}

View File

@ -0,0 +1,98 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Wraps a half vector so it occupies a full 16-byte std140 slot.
// The real payload lives in .data; alignas(16) supplies the padding.
template <typename T>
struct spvPaddedStd140 { alignas(16) T data; };
// A std140 half matrix is lowered to an array of padded columns
// (or padded rows, when the matrix is declared row-major).
template <typename T, int n>
using spvPaddedStd140Matrix = spvPaddedStd140<T>[n];
// std140 UBO layout covering every half matrix shape plus arrays.
// Each column/row/element is widened to a 16-byte stride.
struct Foo
{
// Column-major half matrices (cNM = N columns of halfM).
spvPaddedStd140Matrix<half2, 2> c22;
spvPaddedStd140Matrix<half2, 2> c22arr[3];
spvPaddedStd140Matrix<half3, 2> c23;
spvPaddedStd140Matrix<half4, 2> c24;
spvPaddedStd140Matrix<half2, 3> c32;
spvPaddedStd140Matrix<half3, 3> c33;
spvPaddedStd140Matrix<half4, 3> c34;
spvPaddedStd140Matrix<half2, 4> c42;
spvPaddedStd140Matrix<half3, 4> c43;
spvPaddedStd140Matrix<half4, 4> c44;
// Row-major counterparts: stored as arrays of padded rows,
// so the template dimensions are transposed relative to cNM.
spvPaddedStd140Matrix<half2, 2> r22;
spvPaddedStd140Matrix<half2, 2> r22arr[3];
spvPaddedStd140Matrix<half2, 3> r23;
spvPaddedStd140Matrix<half2, 4> r24;
spvPaddedStd140Matrix<half3, 2> r32;
spvPaddedStd140Matrix<half3, 3> r33;
spvPaddedStd140Matrix<half3, 4> r34;
spvPaddedStd140Matrix<half4, 2> r42;
spvPaddedStd140Matrix<half4, 3> r43;
spvPaddedStd140Matrix<half4, 4> r44;
// std140 arrays: each element sits in its own 16-byte slot.
spvPaddedStd140<half> h1[6];
spvPaddedStd140<half2> h2[6];
spvPaddedStd140<half3> h3[6];
spvPaddedStd140<half4> h4[6];
};
// Fragment stage outputs.
struct main0_out
{
float4 FragColor [[color(0)]];
};
// Exercises loads from padded std140 half members: per-vector,
// per-scalar, whole-matrix (both majorness), and array-element reads.
fragment main0_out main0(constant Foo& u [[buffer(0)]])
{
main0_out out = {};
// Column-major vector loads: each padded column is unwrapped via .data.
half2 c2 = half2(u.c22[0].data) + half2(u.c22[1].data);
c2 = half2(u.c22arr[2][0].data) + half2(u.c22arr[2][1].data);
half3 c3 = half3(u.c23[0].data) + half3(u.c23[1].data);
half4 c4 = half4(u.c24[0].data) + half4(u.c24[1].data);
c2 = (half2(u.c32[0].data) + half2(u.c32[1].data)) + half2(u.c32[2].data);
c3 = (half3(u.c33[0].data) + half3(u.c33[1].data)) + half3(u.c33[2].data);
c4 = (half4(u.c34[0].data) + half4(u.c34[1].data)) + half4(u.c34[2].data);
c2 = ((half2(u.c42[0].data) + half2(u.c42[1].data)) + half2(u.c42[2].data)) + half2(u.c42[3].data);
c3 = ((half3(u.c43[0].data) + half3(u.c43[1].data)) + half3(u.c43[2].data)) + half3(u.c43[3].data);
c4 = ((half4(u.c44[0].data) + half4(u.c44[1].data)) + half4(u.c44[2].data)) + half4(u.c44[3].data);
// Scalar loads address components of the padded payload directly.
half c = ((u.c22[0].data.x + u.c22[0].data.y) + u.c22[1].data.x) + u.c22[1].data.y;
c = ((u.c22arr[2][0].data.x + u.c22arr[2][0].data.y) + u.c22arr[2][1].data.x) + u.c22arr[2][1].data.y;
// Whole-matrix loads rebuild the matrix from padded columns,
// swizzling away the padding where the column is narrower than 4.
half2x2 c22 = half2x2(u.c22[0].data.xy, u.c22[1].data.xy);
c22 = half2x2(u.c22arr[2][0].data.xy, u.c22arr[2][1].data.xy);
half2x3 c23 = half2x3(u.c23[0].data.xyz, u.c23[1].data.xyz);
half2x4 c24 = half2x4(u.c24[0].data, u.c24[1].data);
half3x2 c32 = half3x2(u.c32[0].data.xy, u.c32[1].data.xy, u.c32[2].data.xy);
half3x3 c33 = half3x3(u.c33[0].data.xyz, u.c33[1].data.xyz, u.c33[2].data.xyz);
half3x4 c34 = half3x4(u.c34[0].data, u.c34[1].data, u.c34[2].data);
half4x2 c42 = half4x2(u.c42[0].data.xy, u.c42[1].data.xy, u.c42[2].data.xy, u.c42[3].data.xy);
half4x3 c43 = half4x3(u.c43[0].data.xyz, u.c43[1].data.xyz, u.c43[2].data.xyz, u.c43[3].data.xyz);
half4x4 c44 = half4x4(u.c44[0].data, u.c44[1].data, u.c44[2].data, u.c44[3].data);
// Row-major vector loads: a logical column is gathered one
// component at a time from every padded row.
half2 r2 = half2(u.r22[0].data[0], u.r22[1].data[0]) + half2(u.r22[0].data[1], u.r22[1].data[1]);
r2 = half2(u.r22arr[2][0].data[0], u.r22arr[2][1].data[0]) + half2(u.r22arr[2][0].data[1], u.r22arr[2][1].data[1]);
half3 r3 = half3(u.r23[0].data[0], u.r23[1].data[0], u.r23[2].data[0]) + half3(u.r23[0].data[1], u.r23[1].data[1], u.r23[2].data[1]);
half4 r4 = half4(u.r24[0].data[0], u.r24[1].data[0], u.r24[2].data[0], u.r24[3].data[0]) + half4(u.r24[0].data[1], u.r24[1].data[1], u.r24[2].data[1], u.r24[3].data[1]);
r2 = (half2(u.r32[0].data[0], u.r32[1].data[0]) + half2(u.r32[0].data[1], u.r32[1].data[1])) + half2(u.r32[0].data[2], u.r32[1].data[2]);
r3 = (half3(u.r33[0].data[0], u.r33[1].data[0], u.r33[2].data[0]) + half3(u.r33[0].data[1], u.r33[1].data[1], u.r33[2].data[1])) + half3(u.r33[0].data[2], u.r33[1].data[2], u.r33[2].data[2]);
r4 = (half4(u.r34[0].data[0], u.r34[1].data[0], u.r34[2].data[0], u.r34[3].data[0]) + half4(u.r34[0].data[1], u.r34[1].data[1], u.r34[2].data[1], u.r34[3].data[1])) + half4(u.r34[0].data[2], u.r34[1].data[2], u.r34[2].data[2], u.r34[3].data[2]);
r2 = ((half2(u.r42[0].data[0], u.r42[1].data[0]) + half2(u.r42[0].data[1], u.r42[1].data[1])) + half2(u.r42[0].data[2], u.r42[1].data[2])) + half2(u.r42[0].data[3], u.r42[1].data[3]);
r3 = ((half3(u.r43[0].data[0], u.r43[1].data[0], u.r43[2].data[0]) + half3(u.r43[0].data[1], u.r43[1].data[1], u.r43[2].data[1])) + half3(u.r43[0].data[2], u.r43[1].data[2], u.r43[2].data[2])) + half3(u.r43[0].data[3], u.r43[1].data[3], u.r43[2].data[3]);
r4 = ((half4(u.r44[0].data[0], u.r44[1].data[0], u.r44[2].data[0], u.r44[3].data[0]) + half4(u.r44[0].data[1], u.r44[1].data[1], u.r44[2].data[1], u.r44[3].data[1])) + half4(u.r44[0].data[2], u.r44[1].data[2], u.r44[2].data[2], u.r44[3].data[2])) + half4(u.r44[0].data[3], u.r44[1].data[3], u.r44[2].data[3], u.r44[3].data[3]);
// Row-major scalar loads.
half r = ((u.r22[0u].data[0] + u.r22[1u].data[0]) + u.r22[0u].data[1]) + u.r22[1u].data[1];
// Whole row-major matrices are rebuilt from padded rows, then
// transposed back to the logical column-major shape.
half2x2 r22 = transpose(half2x2(u.r22[0].data.xy, u.r22[1].data.xy));
half2x3 r23 = transpose(half3x2(u.r23[0].data.xy, u.r23[1].data.xy, u.r23[2].data.xy));
half2x4 r24 = transpose(half4x2(u.r24[0].data.xy, u.r24[1].data.xy, u.r24[2].data.xy, u.r24[3].data.xy));
half3x2 r32 = transpose(half2x3(u.r32[0].data.xyz, u.r32[1].data.xyz));
half3x3 r33 = transpose(half3x3(u.r33[0].data.xyz, u.r33[1].data.xyz, u.r33[2].data.xyz));
half3x4 r34 = transpose(half4x3(u.r34[0].data.xyz, u.r34[1].data.xyz, u.r34[2].data.xyz, u.r34[3].data.xyz));
half4x2 r42 = transpose(half2x4(u.r42[0].data, u.r42[1].data));
half4x3 r43 = transpose(half3x4(u.r43[0].data, u.r43[1].data, u.r43[2].data));
half4x4 r44 = transpose(half4x4(u.r44[0].data, u.r44[1].data, u.r44[2].data, u.r44[3].data));
// Array element loads unwrap the padded payload.
half h1 = half(u.h1[5].data);
half2 h2 = half2(u.h2[5].data);
half3 h3 = half3(u.h3[5].data);
half4 h4 = half4(u.h4[5].data);
out.FragColor = float4(1.0);
return out;
}

View File

@ -0,0 +1,51 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
// std140 SSBO with half matrices in both majorness plus half arrays;
// c* = column-major, r* = row-major.
layout(set = 0, binding = 0, std140) buffer Foo
{
f16mat2x3 c23;
f16mat3x2 c32;
layout(row_major) f16mat2x3 r23;
layout(row_major) f16mat3x2 r32;
float16_t h1[6];
f16vec2 h2[6];
f16vec3 h3[6];
f16vec4 h4[6];
};
layout(location = 0) out vec4 FragColor;
// Writes every flavor of half member: individual scalars, whole
// columns, whole matrices, array elements, and scalars inside arrays.
void main()
{
// Store scalar (matrix[column][row] addressing)
c23[1][2] = 1.0hf;
c32[2][1] = 2.0hf;
r23[1][2] = 3.0hf;
r32[2][1] = 4.0hf;
// Store vector (one full column)
c23[1] = f16vec3(0, 1, 2);
c32[1] = f16vec2(0, 1);
r23[1] = f16vec3(0, 1, 2);
r32[1] = f16vec2(0, 1);
// Store matrix
c23 = f16mat2x3(1, 2, 3, 4, 5, 6);
c32 = f16mat3x2(1, 2, 3, 4, 5, 6);
r23 = f16mat2x3(1, 2, 3, 4, 5, 6);
r32 = f16mat3x2(1, 2, 3, 4, 5, 6);
// Store array element
h1[5] = 1.0hf;
h2[5] = f16vec2(1, 2);
h3[5] = f16vec3(1, 2, 3);
h4[5] = f16vec4(1, 2, 3, 4);
// Store scalar in array element
h2[5][1] = 10.0hf;
h3[5][2] = 11.0hf;
h4[5][3] = 12.0hf;
FragColor = vec4(1.0);
}

View File

@ -0,0 +1,110 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
// std140 UBO covering every half matrix shape (2x2..4x4) in both
// majorness, including matrix arrays, plus scalar/vector half arrays.
layout(set = 0, binding = 0, std140) uniform Foo
{
// Column-major matrices.
f16mat2x2 c22;
f16mat2x2 c22arr[3];
f16mat2x3 c23;
f16mat2x4 c24;
f16mat3x2 c32;
f16mat3x3 c33;
f16mat3x4 c34;
f16mat4x2 c42;
f16mat4x3 c43;
f16mat4x4 c44;
// Row-major matrices.
layout(row_major) f16mat2x2 r22;
layout(row_major) f16mat2x2 r22arr[3];
layout(row_major) f16mat2x3 r23;
layout(row_major) f16mat2x4 r24;
layout(row_major) f16mat3x2 r32;
layout(row_major) f16mat3x3 r33;
layout(row_major) f16mat3x4 r34;
layout(row_major) f16mat4x2 r42;
layout(row_major) f16mat4x3 r43;
layout(row_major) f16mat4x4 r44;
// std140 arrays (16-byte stride per element).
float16_t h1[6];
f16vec2 h2[6];
f16vec3 h3[6];
f16vec4 h4[6];
} u;
layout(location = 0) out vec4 FragColor;
// Reads every flavor of half member: columns, scalars, whole matrices
// (both majorness), and array elements.
void main()
{
// Load vectors.
f16vec2 c2 = u.c22[0] + u.c22[1];
c2 = u.c22arr[2][0] + u.c22arr[2][1];
f16vec3 c3 = u.c23[0] + u.c23[1];
f16vec4 c4 = u.c24[0] + u.c24[1];
c2 = u.c32[0] + u.c32[1] + u.c32[2];
c3 = u.c33[0] + u.c33[1] + u.c33[2];
c4 = u.c34[0] + u.c34[1] + u.c34[2];
c2 = u.c42[0] + u.c42[1] + u.c42[2] + u.c42[3];
c3 = u.c43[0] + u.c43[1] + u.c43[2] + u.c43[3];
c4 = u.c44[0] + u.c44[1] + u.c44[2] + u.c44[3];
// Load scalars.
float16_t c = u.c22[0].x + u.c22[0].y + u.c22[1].x + u.c22[1].y;
c = u.c22arr[2][0].x + u.c22arr[2][0].y + u.c22arr[2][1].x + u.c22arr[2][1].y;
// Load full matrix.
f16mat2x2 c22 = u.c22;
c22 = u.c22arr[2];
f16mat2x3 c23 = u.c23;
f16mat2x4 c24 = u.c24;
f16mat3x2 c32 = u.c32;
f16mat3x3 c33 = u.c33;
f16mat3x4 c34 = u.c34;
f16mat4x2 c42 = u.c42;
f16mat4x3 c43 = u.c43;
f16mat4x4 c44 = u.c44;
// Same, but row-major.
f16vec2 r2 = u.r22[0] + u.r22[1];
r2 = u.r22arr[2][0] + u.r22arr[2][1];
f16vec3 r3 = u.r23[0] + u.r23[1];
f16vec4 r4 = u.r24[0] + u.r24[1];
r2 = u.r32[0] + u.r32[1] + u.r32[2];
r3 = u.r33[0] + u.r33[1] + u.r33[2];
r4 = u.r34[0] + u.r34[1] + u.r34[2];
r2 = u.r42[0] + u.r42[1] + u.r42[2] + u.r42[3];
r3 = u.r43[0] + u.r43[1] + u.r43[2] + u.r43[3];
r4 = u.r44[0] + u.r44[1] + u.r44[2] + u.r44[3];
// Load scalars.
float16_t r = u.r22[0].x + u.r22[0].y + u.r22[1].x + u.r22[1].y;
// Load full matrix.
f16mat2x2 r22 = u.r22;
f16mat2x3 r23 = u.r23;
f16mat2x4 r24 = u.r24;
f16mat3x2 r32 = u.r32;
f16mat3x3 r33 = u.r33;
f16mat3x4 r34 = u.r34;
f16mat4x2 r42 = u.r42;
f16mat4x3 r43 = u.r43;
f16mat4x4 r44 = u.r44;
// Load array elements.
float16_t h1 = u.h1[5];
f16vec2 h2 = u.h2[5];
f16vec3 h3 = u.h3[5];
f16vec4 h4 = u.h4[5];
FragColor = vec4(1.0);
}

View File

@ -10176,6 +10176,16 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
type_id = type->parent_type;
type = &get<SPIRType>(type_id);
// If the physical type has an unnatural vecsize,
// we must assume it's a faked struct where the .data member
// is used for the real payload.
if (physical_type && (is_vector(*type) || is_scalar(*type)))
{
auto &phys = get<SPIRType>(physical_type);
if (phys.vecsize > 4)
expr += ".data";
}
access_chain_is_arrayed = true;
}
// For structs, the index refers to a constant, which indexes into the members, possibly through a redirection mapping.
@ -10261,6 +10271,16 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
expr += to_unpacked_expression(index, register_expression_read);
expr += "]";
// If the physical type has an unnatural vecsize,
// we must assume it's a faked struct where the .data member
// is used for the real payload.
if (physical_type)
{
auto &phys = get<SPIRType>(physical_type);
if (phys.vecsize > 4 || phys.columns > 4)
expr += ".data";
}
type_id = type->parent_type;
type = &get<SPIRType>(type_id);
}
@ -10275,6 +10295,18 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
if (column_index != string::npos)
{
deferred_index = expr.substr(column_index);
auto end_deferred_index = deferred_index.find_last_of(']');
if (end_deferred_index != string::npos && end_deferred_index + 1 != deferred_index.size())
{
// If we have any data member fixups, it must be transposed so that it refers to this index.
// E.g. [0].data followed by [1] would be shuffled to [1][0].data which is wrong,
// and needs to be [1].data[0] instead.
end_deferred_index++;
deferred_index = deferred_index.substr(end_deferred_index) +
deferred_index.substr(0, end_deferred_index);
}
expr.resize(column_index);
}
}
@ -10353,8 +10385,14 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
if (row_major_matrix_needs_conversion && !ignore_potential_sliced_writes)
{
prepare_access_chain_for_scalar_access(expr, get<SPIRType>(type->parent_type), effective_storage,
is_packed);
if (prepare_access_chain_for_scalar_access(expr, get<SPIRType>(type->parent_type), effective_storage,
is_packed))
{
// We're in a pointer context now, so just remove any member dereference.
auto first_index = deferred_index.find_first_of('[');
if (first_index != string::npos && first_index != 0)
deferred_index = deferred_index.substr(first_index);
}
}
if (access_meshlet_position_y)
@ -10413,8 +10451,9 @@ void CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uin
{
}
void CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &)
bool CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &)
{
return false;
}
string CompilerGLSL::to_flattened_struct_member(const string &basename, const SPIRType &type, uint32_t index)
@ -14957,6 +14996,17 @@ string CompilerGLSL::convert_row_major_matrix(string exp_str, const SPIRType &ex
auto column_expr = exp_str.substr(column_index);
exp_str.resize(column_index);
auto end_deferred_index = column_expr.find_last_of(']');
if (end_deferred_index != string::npos && end_deferred_index + 1 != column_expr.size())
{
// If we have any data member fixups, it must be transposed so that it refers to this index.
// E.g. [0].data followed by [1] would be shuffled to [1][0].data which is wrong,
// and needs to be [1].data[0] instead.
end_deferred_index++;
column_expr = column_expr.substr(end_deferred_index) +
column_expr.substr(0, end_deferred_index);
}
auto transposed_expr = type_to_glsl_constructor(exp_type) + "(";
// Loading a column from a row-major matrix. Unroll the load.

View File

@ -757,7 +757,7 @@ protected:
virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);
virtual void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type);
virtual void prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
virtual bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
spv::StorageClass storage, bool &is_packed);
std::string access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,

View File

@ -4775,9 +4775,17 @@ void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t in
if (elems_per_stride == 3)
SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
else if (elems_per_stride > 4)
else if (elems_per_stride > 4 && elems_per_stride != 8)
SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
if (elems_per_stride == 8)
{
if (mbr_type.width == 16)
add_spv_func_and_recompile(SPVFuncImplPaddedStd140);
else
SPIRV_CROSS_THROW("Unexpected type in std140 wide array resolve.");
}
auto physical_type = mbr_type;
physical_type.vecsize = elems_per_stride;
physical_type.parent_type = 0;
@ -4809,13 +4817,20 @@ void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t in
if (elems_per_stride == 3)
SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
else if (elems_per_stride > 4)
else if (elems_per_stride > 4 && elems_per_stride != 8)
SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
if (elems_per_stride == 8)
{
if (mbr_type.basetype != SPIRType::Half)
SPIRV_CROSS_THROW("Unexpected type in std140 wide matrix stride resolve.");
add_spv_func_and_recompile(SPVFuncImplPaddedStd140);
}
bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
auto physical_type = mbr_type;
physical_type.parent_type = 0;
if (row_major)
physical_type.columns = elems_per_stride;
else
@ -5114,6 +5129,13 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
{
auto lhs_expr = to_enclosed_expression(lhs_expression);
auto column_index = lhs_expr.find_last_of('[');
// Get rid of any ".data" half8 handling here, we're casting to scalar anyway.
auto end_column_index = lhs_expr.find_last_of(']');
auto end_dot_index = lhs_expr.find_last_of('.');
if (end_dot_index != string::npos && end_dot_index > end_column_index)
lhs_expr.resize(end_dot_index);
if (column_index != string::npos)
{
statement("((", cast_addr_space, " ", type_to_glsl(write_type), "*)&",
@ -5124,7 +5146,9 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
lhs_e->need_transpose = true;
}
else if ((is_matrix(physical_type) || is_array(physical_type)) && physical_type.vecsize > type.vecsize)
else if ((is_matrix(physical_type) || is_array(physical_type)) &&
physical_type.vecsize <= 4 &&
physical_type.vecsize > type.vecsize)
{
assert(type.vecsize >= 1 && type.vecsize <= 3);
@ -5181,19 +5205,26 @@ string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type
".x",
".xy",
".xyz",
"",
};
// TODO: Move everything to the template wrapper?
bool uses_std140_wrapper = physical_type && physical_type->vecsize > 4;
if (physical_type && is_vector(*physical_type) && is_array(*physical_type) &&
!uses_std140_wrapper &&
physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, swizzle_lut[type.vecsize - 1]))
{
// std140 array cases for vectors.
assert(type.vecsize >= 1 && type.vecsize <= 3);
return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
}
else if (physical_type && is_matrix(*physical_type) && is_vector(type) && physical_type->vecsize > type.vecsize)
else if (physical_type && is_matrix(*physical_type) && is_vector(type) &&
!uses_std140_wrapper &&
physical_type->vecsize > type.vecsize)
{
// Extract column from padded matrix.
assert(type.vecsize >= 1 && type.vecsize <= 3);
assert(type.vecsize >= 1 && type.vecsize <= 4);
return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
}
else if (is_matrix(type))
@ -5215,6 +5246,7 @@ string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type
string unpack_expr = join(base_type, columns, "x", vecsize, "(");
const char *load_swiz = "";
const char *data_swiz = physical_vecsize > 4 ? ".data" : "";
if (physical_vecsize != vecsize)
load_swiz = swizzle_lut[vecsize - 1];
@ -5227,7 +5259,7 @@ string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type
if (packed)
unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz);
else
unpack_expr += join(expr_str, "[", i, "]", load_swiz);
unpack_expr += join(expr_str, "[", i, "]", data_swiz, load_swiz);
}
unpack_expr += ")";
@ -7335,6 +7367,15 @@ void CompilerMSL::emit_custom_functions()
}
break;
case SPVFuncImplPaddedStd140:
// .data is used in access chain.
statement("template <typename T>");
statement("struct spvPaddedStd140 { alignas(16) T data; };");
statement("template <typename T, int n>");
statement("using spvPaddedStd140Matrix = spvPaddedStd140<T>[n];");
statement("");
break;
default:
break;
}
@ -8338,7 +8379,7 @@ bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
(builtin == BuiltInTessLevelOuter && c->scalar() == 3);
}
void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
bool CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
spv::StorageClass storage, bool &is_packed)
{
// If there is any risk of writes happening with the access chain in question,
@ -8352,7 +8393,10 @@ void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, cons
// Further indexing should happen with packed rules (array index, not swizzle).
is_packed = true;
return true;
}
else
return false;
}
bool CompilerMSL::access_chain_needs_stage_io_builtin_translation(uint32_t base)
@ -11776,6 +11820,7 @@ void CompilerMSL::emit_fixup()
string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
const string &qualifier)
{
uint32_t orig_member_type_id = member_type_id;
if (member_is_remapped_physical_type(type, index))
member_type_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID);
auto &physical_type = get<SPIRType>(member_type_id);
@ -11887,7 +11932,24 @@ string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_
array_type = type_to_array_glsl(physical_type);
}
auto result = join(pack_pfx, type_to_glsl(*declared_type, orig_id, true), " ", qualifier,
string decl_type;
if (declared_type->vecsize > 4)
{
auto orig_type = get<SPIRType>(orig_member_type_id);
if (is_matrix(orig_type) && row_major)
swap(orig_type.vecsize, orig_type.columns);
orig_type.columns = 1;
decl_type = type_to_glsl(orig_type, orig_id, true);
if (declared_type->columns > 1)
decl_type = join("spvPaddedStd140Matrix<", decl_type, ", ", declared_type->columns, ">");
else
decl_type = join("spvPaddedStd140<", decl_type, ">");
}
else
decl_type = type_to_glsl(*declared_type, orig_id, true);
auto result = join(pack_pfx, decl_type, " ", qualifier,
to_member_name(type, index), member_attribute_qualifier(type, index), array_type, ";");
is_using_builtin_array = false;

View File

@ -815,6 +815,7 @@ protected:
SPVFuncImplVariableDescriptor,
SPVFuncImplVariableSizedDescriptor,
SPVFuncImplVariableDescriptorArray,
SPVFuncImplPaddedStd140
};
// If the underlying resource has been used for comparison then duplicate loads of that resource must be too
@ -1097,7 +1098,7 @@ protected:
void analyze_sampled_image_usage();
bool access_chain_needs_stage_io_builtin_translation(uint32_t base) override;
void prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage,
bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage,
bool &is_packed) override;
void fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length);
void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override;