MSL: Support std140 half matrices and arrays.

Super awkward since there is no clean way to express this.
This commit is contained in:
Hans-Kristian Arntzen 2023-11-27 13:36:49 +01:00
parent 42299f92ef
commit 57dbfa0400
8 changed files with 431 additions and 12 deletions

View File

@ -0,0 +1,66 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// std140 forces a 16-byte stride on f16 matrix columns/rows and array
// elements; MSL has no native type with that layout, so the real payload
// lives in .data and the second half4 is explicit padding.
struct half8 { alignas(16) half4 data; half4 padding_for_std140_fix_your_shader; };
// std140 half matrices are lowered to arrays of padded column/row vectors.
using half2x8 = half8[2];
using half3x8 = half8[3];
using half4x8 = half8[4];
// Same padding trick for 16-bit integer vectors.
struct ushort8 { alignas(16) ushort4 data; ushort4 padding_for_std140_fix_your_shader; };
struct short8 { alignas(16) short4 data; short4 padding_for_std140_fix_your_shader; };
// MSL view of the std140 SSBO: a column-major f16matCxR becomes C padded
// half8 columns, a row-major one becomes R padded rows, and each std140
// array element gets its own padded half8.
struct Foo
{
half2x8 c23; // column-major f16mat2x3 (2 padded columns)
half3x8 c32; // column-major f16mat3x2 (3 padded columns)
half3x8 r23; // row-major f16mat2x3 (3 padded rows)
half2x8 r32; // row-major f16mat3x2 (2 padded rows)
half8 h1[6]; // float16_t[6], one padded element per entry
half8 h2[6]; // f16vec2[6]
half8 h3[6]; // f16vec3[6]
half8 h4[6]; // f16vec4[6]
};
// Fragment shader output block: a single color attachment.
struct main0_out
{
float4 FragColor [[color(0)]];
};
// Generated MSL for the std140 f16 store test: exercises scalar, vector,
// full-matrix and array stores through the padded half8 representation.
fragment main0_out main0(device Foo& _20 [[buffer(0)]])
{
main0_out out = {};
// Store scalar into column-major matrices: cast the padded .data member to
// a raw device half* and index the component.
((device half*)&_20.c23[1].data)[2u] = half(1.0);
((device half*)&_20.c32[2].data)[1u] = half(2.0);
// Row-major scalar stores cast the half8 element itself; the .data
// dereference is dropped in pointer context since .data sits at offset 0.
((device half*)&_20.r23[2u])[1] = half(3.0);
((device half*)&_20.r32[1u])[2] = half(4.0);
// Store a whole column of a column-major matrix into .data.
_20.c23[1].data = half3(half(0.0), half(1.0), half(2.0));
_20.c32[1].data = half2(half(0.0), half(1.0));
// Storing a "column" of a row-major matrix scatters one component per row.
((device half*)&_20.r23[0])[1] = half3(half(0.0), half(1.0), half(2.0)).x;
((device half*)&_20.r23[1])[1] = half3(half(0.0), half(1.0), half(2.0)).y;
((device half*)&_20.r23[2])[1] = half3(half(0.0), half(1.0), half(2.0)).z;
((device half*)&_20.r32[0])[1] = half2(half(0.0), half(1.0)).x;
((device half*)&_20.r32[1])[1] = half2(half(0.0), half(1.0)).y;
// Full matrix store, column-major: reinterpret each padded element as a
// plain halfN reference and assign one source column to it.
(device half3&)_20.c23[0] = half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0];
(device half3&)_20.c23[1] = half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1];
(device half2&)_20.c32[0] = half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[0];
(device half2&)_20.c32[1] = half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[1];
(device half2&)_20.c32[2] = half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[2];
// Full matrix store, row-major: each destination row gathers one component
// from every source column (an unrolled transpose).
(device half2&)_20.r23[0] = half2(half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0][0], half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1][0]);
(device half2&)_20.r23[1] = half2(half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0][1], half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1][1]);
(device half2&)_20.r23[2] = half2(half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0][2], half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1][2]);
(device half3&)_20.r32[0] = half3(half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[0][0], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[1][0], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[2][0]);
(device half3&)_20.r32[1] = half3(half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[0][1], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[1][1], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[2][1]);
// Array element stores write only the .data payload; padding is untouched.
_20.h1[5].data = half(1.0);
_20.h2[5].data = half2(half(1.0), half(2.0));
_20.h3[5].data = half3(half(1.0), half(2.0), half(3.0));
_20.h4[5].data = half4(half(1.0), half(2.0), half(3.0), half(4.0));
// Scalar stores into array elements via a half* cast of .data.
((device half*)&_20.h2[5].data)[1u] = half(10.0);
((device half*)&_20.h3[5].data)[2u] = half(11.0);
((device half*)&_20.h4[5].data)[3u] = half(12.0);
out.FragColor = float4(1.0);
return out;
}

View File

@ -0,0 +1,100 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// std140 forces a 16-byte stride on f16 matrix columns/rows and array
// elements; MSL has no native type with that layout, so the real payload
// lives in .data and the second half4 is explicit padding.
struct half8 { alignas(16) half4 data; half4 padding_for_std140_fix_your_shader; };
// std140 half matrices are lowered to arrays of padded column/row vectors.
using half2x8 = half8[2];
using half3x8 = half8[3];
using half4x8 = half8[4];
// Same padding trick for 16-bit integer vectors.
struct ushort8 { alignas(16) ushort4 data; ushort4 padding_for_std140_fix_your_shader; };
struct short8 { alignas(16) short4 data; short4 padding_for_std140_fix_your_shader; };
// MSL view of the std140 UBO covering every f16 matrix shape: column-major
// f16matCxR becomes C padded half8 columns, row-major becomes R padded rows,
// and each std140 array element gets its own padded half8.
struct Foo
{
// Column-major matrices: one padded half8 per column.
half2x8 c22;
half2x8 c22arr[3];
half2x8 c23;
half2x8 c24;
half3x8 c32;
half3x8 c33;
half3x8 c34;
half4x8 c42;
half4x8 c43;
half4x8 c44;
// Row-major matrices: one padded half8 per row.
half2x8 r22;
half2x8 r22arr[3];
half3x8 r23;
half4x8 r24;
half2x8 r32;
half3x8 r33;
half4x8 r34;
half2x8 r42;
half3x8 r43;
half4x8 r44;
// std140 arrays of f16 scalars/vectors: one padded half8 per element.
half8 h1[6];
half8 h2[6];
half8 h3[6];
half8 h4[6];
};
// Fragment shader output block: a single color attachment.
struct main0_out
{
float4 FragColor [[color(0)]];
};
// Generated MSL for the std140 f16 load test: every matrix column/row and
// array element is stored as a padded half8, so loads swizzle the real
// payload back out of .data.
fragment main0_out main0(constant Foo& u [[buffer(0)]])
{
main0_out out = {};
// Load columns of column-major matrices: swizzle .data down to the real size.
half2 c2 = u.c22[0].data.xy + u.c22[1].data.xy;
c2 = u.c22arr[2][0].data.xy + u.c22arr[2][1].data.xy;
half3 c3 = u.c23[0].data.xyz + u.c23[1].data.xyz;
half4 c4 = u.c24[0].data + u.c24[1].data;
c2 = (u.c32[0].data.xy + u.c32[1].data.xy) + u.c32[2].data.xy;
c3 = (u.c33[0].data.xyz + u.c33[1].data.xyz) + u.c33[2].data.xyz;
c4 = (u.c34[0].data + u.c34[1].data) + u.c34[2].data;
c2 = ((u.c42[0].data.xy + u.c42[1].data.xy) + u.c42[2].data.xy) + u.c42[3].data.xy;
c3 = ((u.c43[0].data.xyz + u.c43[1].data.xyz) + u.c43[2].data.xyz) + u.c43[3].data.xyz;
c4 = ((u.c44[0].data + u.c44[1].data) + u.c44[2].data) + u.c44[3].data;
// Load scalars.
half c = ((u.c22[0].data.x + u.c22[0].data.y) + u.c22[1].data.x) + u.c22[1].data.y;
c = ((u.c22arr[2][0].data.x + u.c22arr[2][0].data.y) + u.c22arr[2][1].data.x) + u.c22arr[2][1].data.y;
// Load full column-major matrices by reassembling the unpadded columns.
half2x2 c22 = half2x2(u.c22[0].data.xy, u.c22[1].data.xy);
c22 = half2x2(u.c22arr[2][0].data.xy, u.c22arr[2][1].data.xy);
half2x3 c23 = half2x3(u.c23[0].data.xyz, u.c23[1].data.xyz);
half2x4 c24 = half2x4(u.c24[0].data, u.c24[1].data);
half3x2 c32 = half3x2(u.c32[0].data.xy, u.c32[1].data.xy, u.c32[2].data.xy);
half3x3 c33 = half3x3(u.c33[0].data.xyz, u.c33[1].data.xyz, u.c33[2].data.xyz);
half3x4 c34 = half3x4(u.c34[0].data, u.c34[1].data, u.c34[2].data);
half4x2 c42 = half4x2(u.c42[0].data.xy, u.c42[1].data.xy, u.c42[2].data.xy, u.c42[3].data.xy);
half4x3 c43 = half4x3(u.c43[0].data.xyz, u.c43[1].data.xyz, u.c43[2].data.xyz, u.c43[3].data.xyz);
half4x4 c44 = half4x4(u.c44[0].data, u.c44[1].data, u.c44[2].data, u.c44[3].data);
// Row-major column loads are unrolled: gather one component per stored row.
half2 r2 = half2(u.r22[0].data[0], u.r22[1].data[0]) + half2(u.r22[0].data[1], u.r22[1].data[1]);
r2 = half2(u.r22arr[2][0].data[0], u.r22arr[2][1].data[0]) + half2(u.r22arr[2][0].data[1], u.r22arr[2][1].data[1]);
half3 r3 = half3(u.r23[0].data[0], u.r23[1].data[0], u.r23[2].data[0]) + half3(u.r23[0].data[1], u.r23[1].data[1], u.r23[2].data[1]);
half4 r4 = half4(u.r24[0].data[0], u.r24[1].data[0], u.r24[2].data[0], u.r24[3].data[0]) + half4(u.r24[0].data[1], u.r24[1].data[1], u.r24[2].data[1], u.r24[3].data[1]);
r2 = (half2(u.r32[0].data[0], u.r32[1].data[0]) + half2(u.r32[0].data[1], u.r32[1].data[1])) + half2(u.r32[0].data[2], u.r32[1].data[2]);
r3 = (half3(u.r33[0].data[0], u.r33[1].data[0], u.r33[2].data[0]) + half3(u.r33[0].data[1], u.r33[1].data[1], u.r33[2].data[1])) + half3(u.r33[0].data[2], u.r33[1].data[2], u.r33[2].data[2]);
r4 = (half4(u.r34[0].data[0], u.r34[1].data[0], u.r34[2].data[0], u.r34[3].data[0]) + half4(u.r34[0].data[1], u.r34[1].data[1], u.r34[2].data[1], u.r34[3].data[1])) + half4(u.r34[0].data[2], u.r34[1].data[2], u.r34[2].data[2], u.r34[3].data[2]);
r2 = ((half2(u.r42[0].data[0], u.r42[1].data[0]) + half2(u.r42[0].data[1], u.r42[1].data[1])) + half2(u.r42[0].data[2], u.r42[1].data[2])) + half2(u.r42[0].data[3], u.r42[1].data[3]);
r3 = ((half3(u.r43[0].data[0], u.r43[1].data[0], u.r43[2].data[0]) + half3(u.r43[0].data[1], u.r43[1].data[1], u.r43[2].data[1])) + half3(u.r43[0].data[2], u.r43[1].data[2], u.r43[2].data[2])) + half3(u.r43[0].data[3], u.r43[1].data[3], u.r43[2].data[3]);
r4 = ((half4(u.r44[0].data[0], u.r44[1].data[0], u.r44[2].data[0], u.r44[3].data[0]) + half4(u.r44[0].data[1], u.r44[1].data[1], u.r44[2].data[1], u.r44[3].data[1])) + half4(u.r44[0].data[2], u.r44[1].data[2], u.r44[2].data[2], u.r44[3].data[2])) + half4(u.r44[0].data[3], u.r44[1].data[3], u.r44[2].data[3], u.r44[3].data[3]);
// Load scalars from a row-major matrix.
half r = ((u.r22[0u].data[0] + u.r22[1u].data[0]) + u.r22[0u].data[1]) + u.r22[1u].data[1];
// Full row-major matrices: assemble from the padded rows, then transpose.
half2x2 r22 = transpose(half2x2(u.r22[0].data.xy, u.r22[1].data.xy));
half2x3 r23 = transpose(half3x2(u.r23[0].data.xy, u.r23[1].data.xy, u.r23[2].data.xy));
half2x4 r24 = transpose(half4x2(u.r24[0].data.xy, u.r24[1].data.xy, u.r24[2].data.xy, u.r24[3].data.xy));
half3x2 r32 = transpose(half2x3(u.r32[0].data.xyz, u.r32[1].data.xyz));
half3x3 r33 = transpose(half3x3(u.r33[0].data.xyz, u.r33[1].data.xyz, u.r33[2].data.xyz));
half3x4 r34 = transpose(half4x3(u.r34[0].data.xyz, u.r34[1].data.xyz, u.r34[2].data.xyz, u.r34[3].data.xyz));
half4x2 r42 = transpose(half2x4(u.r42[0].data, u.r42[1].data));
half4x3 r43 = transpose(half3x4(u.r43[0].data, u.r43[1].data, u.r43[2].data));
half4x4 r44 = transpose(half4x4(u.r44[0].data, u.r44[1].data, u.r44[2].data, u.r44[3].data));
// Array loads read only the .data payload.
half h1 = u.h1[5].data.x;
half2 h2 = u.h2[5].data.xy;
half3 h3 = u.h3[5].data.xyz;
half4 h4 = half4(u.h4[5].data);
out.FragColor = float4(1.0);
return out;
}

View File

@ -0,0 +1,51 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
// std140 SSBO with f16 matrices and arrays. Under std140 every matrix
// column/row and array element is padded to a 16-byte stride, which the MSL
// backend must emulate (it has no native half type with that layout).
layout(set = 0, binding = 0, std140) buffer Foo
{
f16mat2x3 c23; // column-major
f16mat3x2 c32; // column-major
layout(row_major) f16mat2x3 r23;
layout(row_major) f16mat3x2 r32;
float16_t h1[6]; // std140 arrays: 16-byte stride per element
f16vec2 h2[6];
f16vec3 h3[6];
f16vec4 h4[6];
};
layout(location = 0) out vec4 FragColor;
// Exercises every std140 f16 store pattern the backend must remap:
// scalar, column vector, full matrix, array element, and array scalar,
// for both column-major and row-major matrices.
void main()
{
// Store scalar
c23[1][2] = 1.0hf;
c32[2][1] = 2.0hf;
r23[1][2] = 3.0hf;
r32[2][1] = 4.0hf;
// Store vector (one matrix column)
c23[1] = f16vec3(0, 1, 2);
c32[1] = f16vec2(0, 1);
r23[1] = f16vec3(0, 1, 2);
r32[1] = f16vec2(0, 1);
// Store matrix
c23 = f16mat2x3(1, 2, 3, 4, 5, 6);
c32 = f16mat3x2(1, 2, 3, 4, 5, 6);
r23 = f16mat2x3(1, 2, 3, 4, 5, 6);
r32 = f16mat3x2(1, 2, 3, 4, 5, 6);
// Store array
h1[5] = 1.0hf;
h2[5] = f16vec2(1, 2);
h3[5] = f16vec3(1, 2, 3);
h4[5] = f16vec4(1, 2, 3, 4);
// Store scalar in array
h2[5][1] = 10.0hf;
h3[5][2] = 11.0hf;
h4[5][3] = 12.0hf;
FragColor = vec4(1.0);
}

View File

@ -0,0 +1,110 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
// std140 UBO covering every f16 matrix shape in both layouts, plus f16
// scalar/vector arrays. std140 pads each column/row/element to a 16-byte
// stride, which the MSL backend must emulate for half types.
layout(set = 0, binding = 0, std140) uniform Foo
{
// Column-major matrices.
f16mat2x2 c22;
f16mat2x2 c22arr[3];
f16mat2x3 c23;
f16mat2x4 c24;
f16mat3x2 c32;
f16mat3x3 c33;
f16mat3x4 c34;
f16mat4x2 c42;
f16mat4x3 c43;
f16mat4x4 c44;
// Row-major matrices.
layout(row_major) f16mat2x2 r22;
layout(row_major) f16mat2x2 r22arr[3];
layout(row_major) f16mat2x3 r23;
layout(row_major) f16mat2x4 r24;
layout(row_major) f16mat3x2 r32;
layout(row_major) f16mat3x3 r33;
layout(row_major) f16mat3x4 r34;
layout(row_major) f16mat4x2 r42;
layout(row_major) f16mat4x3 r43;
layout(row_major) f16mat4x4 r44;
// Arrays: 16-byte stride per element under std140.
float16_t h1[6];
f16vec2 h2[6];
f16vec3 h3[6];
f16vec4 h4[6];
} u;
layout(location = 0) out vec4 FragColor;
// Exercises every std140 f16 load pattern the backend must remap: column
// vectors, scalars, and whole matrices, for both column-major and row-major
// layouts, plus scalar/vector array loads.
void main()
{
// Load vectors.
f16vec2 c2 = u.c22[0] + u.c22[1];
c2 = u.c22arr[2][0] + u.c22arr[2][1];
f16vec3 c3 = u.c23[0] + u.c23[1];
f16vec4 c4 = u.c24[0] + u.c24[1];
c2 = u.c32[0] + u.c32[1] + u.c32[2];
c3 = u.c33[0] + u.c33[1] + u.c33[2];
c4 = u.c34[0] + u.c34[1] + u.c34[2];
c2 = u.c42[0] + u.c42[1] + u.c42[2] + u.c42[3];
c3 = u.c43[0] + u.c43[1] + u.c43[2] + u.c43[3];
c4 = u.c44[0] + u.c44[1] + u.c44[2] + u.c44[3];
// Load scalars.
float16_t c = u.c22[0].x + u.c22[0].y + u.c22[1].x + u.c22[1].y;
c = u.c22arr[2][0].x + u.c22arr[2][0].y + u.c22arr[2][1].x + u.c22arr[2][1].y;
// Load full matrix.
f16mat2x2 c22 = u.c22;
c22 = u.c22arr[2];
f16mat2x3 c23 = u.c23;
f16mat2x4 c24 = u.c24;
f16mat3x2 c32 = u.c32;
f16mat3x3 c33 = u.c33;
f16mat3x4 c34 = u.c34;
f16mat4x2 c42 = u.c42;
f16mat4x3 c43 = u.c43;
f16mat4x4 c44 = u.c44;
// Same, but row-major (the backend must unroll or transpose these).
f16vec2 r2 = u.r22[0] + u.r22[1];
r2 = u.r22arr[2][0] + u.r22arr[2][1];
f16vec3 r3 = u.r23[0] + u.r23[1];
f16vec4 r4 = u.r24[0] + u.r24[1];
r2 = u.r32[0] + u.r32[1] + u.r32[2];
r3 = u.r33[0] + u.r33[1] + u.r33[2];
r4 = u.r34[0] + u.r34[1] + u.r34[2];
r2 = u.r42[0] + u.r42[1] + u.r42[2] + u.r42[3];
r3 = u.r43[0] + u.r43[1] + u.r43[2] + u.r43[3];
r4 = u.r44[0] + u.r44[1] + u.r44[2] + u.r44[3];
// Load scalars.
float16_t r = u.r22[0].x + u.r22[0].y + u.r22[1].x + u.r22[1].y;
// Load full matrix.
f16mat2x2 r22 = u.r22;
f16mat2x3 r23 = u.r23;
f16mat2x4 r24 = u.r24;
f16mat3x2 r32 = u.r32;
f16mat3x3 r33 = u.r33;
f16mat3x4 r34 = u.r34;
f16mat4x2 r42 = u.r42;
f16mat4x3 r43 = u.r43;
f16mat4x4 r44 = u.r44;
// Array element loads.
float16_t h1 = u.h1[5];
f16vec2 h2 = u.h2[5];
f16vec3 h3 = u.h3[5];
f16vec4 h4 = u.h4[5];
FragColor = vec4(1.0);
}

View File

@ -10176,6 +10176,16 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
type_id = type->parent_type;
type = &get<SPIRType>(type_id);
// If the physical type has an unnatural vecsize,
// we must assume it's a faked struct where the .data member
// is used for the real payload.
if (physical_type && (is_vector(*type) || is_scalar(*type)))
{
auto &phys = get<SPIRType>(physical_type);
if (phys.vecsize > 4)
expr += ".data";
}
access_chain_is_arrayed = true;
}
// For structs, the index refers to a constant, which indexes into the members, possibly through a redirection mapping.
@ -10261,6 +10271,16 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
expr += to_unpacked_expression(index, register_expression_read);
expr += "]";
// If the physical type has an unnatural vecsize,
// we must assume it's a faked struct where the .data member
// is used for the real payload.
if (physical_type)
{
auto &phys = get<SPIRType>(physical_type);
if (phys.vecsize > 4 || phys.columns > 4)
expr += ".data";
}
type_id = type->parent_type;
type = &get<SPIRType>(type_id);
}
@ -10275,6 +10295,18 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
if (column_index != string::npos)
{
deferred_index = expr.substr(column_index);
auto end_deferred_index = deferred_index.find_last_of(']');
if (end_deferred_index != string::npos && end_deferred_index + 1 != deferred_index.size())
{
// If we have any data member fixups, it must be transposed so that it refers to this index.
// E.g. [0].data followed by [1] would be shuffled to [1][0].data which is wrong,
// and needs to be [1].data[0] instead.
end_deferred_index++;
deferred_index = deferred_index.substr(end_deferred_index) +
deferred_index.substr(0, end_deferred_index);
}
expr.resize(column_index);
}
}
@ -10353,8 +10385,14 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
if (row_major_matrix_needs_conversion && !ignore_potential_sliced_writes)
{
prepare_access_chain_for_scalar_access(expr, get<SPIRType>(type->parent_type), effective_storage,
is_packed);
if (prepare_access_chain_for_scalar_access(expr, get<SPIRType>(type->parent_type), effective_storage,
is_packed))
{
// We're in a pointer context now, so just remove any member dereference.
auto first_index = deferred_index.find_first_of('[');
if (first_index != string::npos && first_index != 0)
deferred_index = deferred_index.substr(first_index);
}
}
if (access_meshlet_position_y)
@ -10413,8 +10451,9 @@ void CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uin
{
}
void CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &)
bool CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &)
{
return false;
}
string CompilerGLSL::to_flattened_struct_member(const string &basename, const SPIRType &type, uint32_t index)
@ -14957,6 +14996,17 @@ string CompilerGLSL::convert_row_major_matrix(string exp_str, const SPIRType &ex
auto column_expr = exp_str.substr(column_index);
exp_str.resize(column_index);
auto end_deferred_index = column_expr.find_last_of(']');
if (end_deferred_index != string::npos && end_deferred_index + 1 != column_expr.size())
{
// If we have any data member fixups, it must be transposed so that it refers to this index.
// E.g. [0].data followed by [1] would be shuffled to [1][0].data which is wrong,
// and needs to be [1].data[0] instead.
end_deferred_index++;
column_expr = column_expr.substr(end_deferred_index) +
column_expr.substr(0, end_deferred_index);
}
auto transposed_expr = type_to_glsl_constructor(exp_type) + "(";
// Loading a column from a row-major matrix. Unroll the load.

View File

@ -757,7 +757,7 @@ protected:
virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);
virtual void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type);
virtual void prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
virtual bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
spv::StorageClass storage, bool &is_packed);
std::string access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,

View File

@ -4775,9 +4775,17 @@ void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t in
if (elems_per_stride == 3)
SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
else if (elems_per_stride > 4)
else if (elems_per_stride > 4 && elems_per_stride != 8)
SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
if (elems_per_stride == 8)
{
if (mbr_type.width == 16)
add_spv_func_and_recompile(SPVFuncImplHalfStd140);
else
SPIRV_CROSS_THROW("Unexpected type in std140 wide array resolve.");
}
auto physical_type = mbr_type;
physical_type.vecsize = elems_per_stride;
physical_type.parent_type = 0;
@ -4809,13 +4817,20 @@ void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t in
if (elems_per_stride == 3)
SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
else if (elems_per_stride > 4)
else if (elems_per_stride > 4 && elems_per_stride != 8)
SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
if (elems_per_stride == 8)
{
if (mbr_type.basetype != SPIRType::Half)
SPIRV_CROSS_THROW("Unexpected type in std140 wide matrix stride resolve.");
add_spv_func_and_recompile(SPVFuncImplHalfStd140);
}
bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
auto physical_type = mbr_type;
physical_type.parent_type = 0;
if (row_major)
physical_type.columns = elems_per_stride;
else
@ -5114,6 +5129,13 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
{
auto lhs_expr = to_enclosed_expression(lhs_expression);
auto column_index = lhs_expr.find_last_of('[');
// Get rid of any ".data" half8 handling here, we're casting to scalar anyway.
auto end_column_index = lhs_expr.find_last_of(']');
auto end_dot_index = lhs_expr.find_last_of('.');
if (end_dot_index != string::npos && end_dot_index > end_column_index)
lhs_expr.resize(end_dot_index);
if (column_index != string::npos)
{
statement("((", cast_addr_space, " ", type_to_glsl(write_type), "*)&",
@ -5124,7 +5146,9 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
lhs_e->need_transpose = true;
}
else if ((is_matrix(physical_type) || is_array(physical_type)) && physical_type.vecsize > type.vecsize)
else if ((is_matrix(physical_type) || is_array(physical_type)) &&
physical_type.vecsize <= 4 &&
physical_type.vecsize > type.vecsize)
{
assert(type.vecsize >= 1 && type.vecsize <= 3);
@ -5181,6 +5205,7 @@ string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type
".x",
".xy",
".xyz",
"",
};
if (physical_type && is_vector(*physical_type) && is_array(*physical_type) &&
@ -5193,7 +5218,7 @@ string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type
else if (physical_type && is_matrix(*physical_type) && is_vector(type) && physical_type->vecsize > type.vecsize)
{
// Extract column from padded matrix.
assert(type.vecsize >= 1 && type.vecsize <= 3);
assert(type.vecsize >= 1 && type.vecsize <= 4);
return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
}
else if (is_matrix(type))
@ -5215,6 +5240,7 @@ string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type
string unpack_expr = join(base_type, columns, "x", vecsize, "(");
const char *load_swiz = "";
const char *data_swiz = physical_vecsize > 4 ? ".data" : "";
if (physical_vecsize != vecsize)
load_swiz = swizzle_lut[vecsize - 1];
@ -5227,7 +5253,7 @@ string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type
if (packed)
unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz);
else
unpack_expr += join(expr_str, "[", i, "]", load_swiz);
unpack_expr += join(expr_str, "[", i, "]", data_swiz, load_swiz);
}
unpack_expr += ")";
@ -7335,6 +7361,18 @@ void CompilerMSL::emit_custom_functions()
}
break;
case SPVFuncImplHalfStd140:
// .data is used in access chain.
statement("struct half8 { alignas(16) half4 data; half4 padding_for_std140_fix_your_shader; };");
// Physical type remapping is used to load/store full matrices anyway.
statement("using half2x8 = half8[2];");
statement("using half3x8 = half8[3];");
statement("using half4x8 = half8[4];");
statement("struct ushort8 { alignas(16) ushort4 data; ushort4 padding_for_std140_fix_your_shader; };");
statement("struct short8 { alignas(16) short4 data; short4 padding_for_std140_fix_your_shader; };");
statement("");
break;
default:
break;
}
@ -8338,7 +8376,7 @@ bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
(builtin == BuiltInTessLevelOuter && c->scalar() == 3);
}
void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
bool CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
spv::StorageClass storage, bool &is_packed)
{
// If there is any risk of writes happening with the access chain in question,
@ -8352,7 +8390,10 @@ void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, cons
// Further indexing should happen with packed rules (array index, not swizzle).
is_packed = true;
return true;
}
else
return false;
}
bool CompilerMSL::access_chain_needs_stage_io_builtin_translation(uint32_t base)

View File

@ -815,6 +815,7 @@ protected:
SPVFuncImplVariableDescriptor,
SPVFuncImplVariableSizedDescriptor,
SPVFuncImplVariableDescriptorArray,
SPVFuncImplHalfStd140
};
// If the underlying resource has been used for comparison then duplicate loads of that resource must be too
@ -1097,7 +1098,7 @@ protected:
void analyze_sampled_image_usage();
bool access_chain_needs_stage_io_builtin_translation(uint32_t base) override;
void prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage,
bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage,
bool &is_packed) override;
void fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length);
void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override;