Update Metal matrix intrinsic polyfills to allow half.

The inverse, outerProduct and matrixCompMult polyfill functions in Metal
were written assuming that all float matrices would use the `float`
type. They now use a template so that `half` matrices will work too.

Change-Id: I7696c8ad1e4aaffbd71c56b9245485e74cd96c5a
Bug: skia:12339
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/463338
Auto-Submit: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
This commit is contained in:
John Stiles 2021-10-25 19:51:13 -04:00 committed by SkCQ
parent 643bd0fc8f
commit b37100de7d
5 changed files with 76 additions and 72 deletions

View File

@ -394,61 +394,64 @@ void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
}
static constexpr char kInverse2x2[] = R"(
float2x2 float2x2_inverse(float2x2 m) {
return float2x2(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
return matrix<T, 2, 2>(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
}
)";
static constexpr char kInverse3x3[] = R"(
float3x3 float3x3_inverse(float3x3 m) {
float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
float b01 = a22*a11 - a12*a21;
float b11 = -a22*a10 + a12*a20;
float b21 = a21*a10 - a11*a20;
float det = a00*b01 + a01*b11 + a02*b21;
return float3x3(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
T b01 = a22*a11 - a12*a21;
T b11 = -a22*a10 + a12*a20;
T b21 = a21*a10 - a11*a20;
T det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
static constexpr char kInverse4x4[] = R"(
float4x4 float4x4_inverse(float4x4 m) {
float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
float a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
float b00 = a00*a11 - a01*a10;
float b01 = a00*a12 - a02*a10;
float b02 = a00*a13 - a03*a10;
float b03 = a01*a12 - a02*a11;
float b04 = a01*a13 - a03*a11;
float b05 = a02*a13 - a03*a12;
float b06 = a20*a31 - a21*a30;
float b07 = a20*a32 - a22*a30;
float b08 = a20*a33 - a23*a30;
float b09 = a21*a32 - a22*a31;
float b10 = a21*a33 - a23*a31;
float b11 = a22*a33 - a23*a32;
float det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return float4x4(a11*b11 - a12*b10 + a13*b09,
a02*b10 - a01*b11 - a03*b09,
a31*b05 - a32*b04 + a33*b03,
a22*b04 - a21*b05 - a23*b03,
a12*b08 - a10*b11 - a13*b07,
a00*b11 - a02*b08 + a03*b07,
a32*b02 - a30*b05 - a33*b01,
a20*b05 - a22*b02 + a23*b01,
a10*b10 - a11*b08 + a13*b06,
a01*b08 - a00*b10 - a03*b06,
a30*b04 - a31*b02 + a33*b00,
a21*b02 - a20*b04 - a23*b00,
a11*b07 - a10*b09 - a12*b06,
a00*b09 - a01*b07 + a02*b06,
a31*b01 - a30*b03 - a32*b00,
a20*b03 - a21*b01 + a22*b00) * (1/det);
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
T a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
T b00 = a00*a11 - a01*a10;
T b01 = a00*a12 - a02*a10;
T b02 = a00*a13 - a03*a10;
T b03 = a01*a12 - a02*a11;
T b04 = a01*a13 - a03*a11;
T b05 = a02*a13 - a03*a12;
T b06 = a20*a31 - a21*a30;
T b07 = a20*a32 - a22*a30;
T b08 = a20*a33 - a23*a30;
T b09 = a21*a32 - a22*a31;
T b10 = a21*a33 - a23*a31;
T b11 = a22*a33 - a23*a32;
T det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return matrix<T, 4, 4>(a11*b11 - a12*b10 + a13*b09,
a02*b10 - a01*b11 - a03*b09,
a31*b05 - a32*b04 + a33*b03,
a22*b04 - a21*b05 - a23*b03,
a12*b08 - a10*b11 - a13*b07,
a00*b11 - a02*b08 + a03*b07,
a32*b02 - a30*b05 - a33*b01,
a20*b05 - a22*b02 + a23*b01,
a10*b10 - a11*b08 + a13*b06,
a01*b08 - a00*b10 - a03*b06,
a30*b04 - a31*b02 + a33*b00,
a21*b02 - a20*b04 - a23*b00,
a11*b07 - a10*b09 - a12*b06,
a00*b09 - a01*b07 + a02*b06,
a31*b01 - a30*b03 - a32*b00,
a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
@ -458,7 +461,7 @@ String MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments)
const Type& type = arguments.front()->type();
if (type.isMatrix() && type.rows() == type.columns()) {
// Inject the correct polyfill based on the matrix size.
String name = this->typeName(type) + "_inverse";
auto name = String::printf("mat%d_inverse", type.columns());
auto [iter, didInsert] = fWrittenIntrinsics.insert(name);
if (didInsert) {
switch (type.rows()) {
@ -482,8 +485,8 @@ String MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments)
void MetalCodeGenerator::writeMatrixCompMult() {
static constexpr char kMatrixCompMult[] = R"(
template <int C, int R>
matrix<float, C, R> matrixCompMult(matrix<float, C, R> a, const matrix<float, C, R> b) {
template <typename T, int C, int R>
matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
for (int c = 0; c < C; ++c) {
a[c] *= b[c];
}
@ -500,9 +503,9 @@ matrix<float, C, R> matrixCompMult(matrix<float, C, R> a, const matrix<float, C,
void MetalCodeGenerator::writeOuterProduct() {
static constexpr char kOuterProduct[] = R"(
template <int C, int R>
matrix<float, C, R> outerProduct(const vec<float, R> a, const vec<float, C> b) {
matrix<float, C, R> result;
template <typename T, int C, int R>
matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
matrix<T, C, R> result;
for (int c = 0; c < C; ++c) {
result[c] = a * b[c];
}

View File

@ -44,17 +44,18 @@ thread bool operator!=(const float4x4 left, const float4x4 right) {
return !(left == right);
}
float3x3 float3x3_inverse(float3x3 m) {
float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
float b01 = a22*a11 - a12*a21;
float b11 = -a22*a10 + a12*a20;
float b21 = a21*a10 - a11*a20;
float det = a00*b01 + a01*b11 + a02*b21;
return float3x3(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
T b01 = a22*a11 - a12*a21;
T b11 = -a22*a10 + a12*a20;
T b21 = a21*a10 - a11*a20;
T det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
fragment Outputs fragmentMain(Inputs _in [[stage_in]], constant Uniforms& _uniforms [[buffer(0)]], bool _frontFacing [[front_facing]], float4 _fragCoord [[position]]) {
Outputs _out;
@ -62,6 +63,6 @@ fragment Outputs fragmentMain(Inputs _in [[stage_in]], constant Uniforms& _unifo
float2x2 inv2x2 = float2x2(float2(-2.0, 1.0), float2(1.5, -0.5));
float3x3 inv3x3 = float3x3(float3(-24.0, 18.0, 5.0), float3(20.0, -15.0, -4.0), float3(-5.0, 4.0, 1.0));
float4x4 inv4x4 = float4x4(float4(-2.0, -0.5, 1.0, 0.5), float4(1.0, 0.5, 0.0, -0.5), float4(-8.0, -1.0, 2.0, 2.0), float4(3.0, 0.5, -1.0, -0.5));
_out.sk_FragColor = ((float2x2(float2(-2.0, 1.0), float2(1.5, -0.5)) == inv2x2 && float3x3(float3(-24.0, 18.0, 5.0), float3(20.0, -15.0, -4.0), float3(-5.0, 4.0, 1.0)) == inv3x3) && float4x4(float4(-2.0, -0.5, 1.0, 0.5), float4(1.0, 0.5, 0.0, -0.5), float4(-8.0, -1.0, 2.0, 2.0), float4(3.0, 0.5, -1.0, -0.5)) == inv4x4) && float3x3_inverse(float3x3(float3(1.0, 2.0, 3.0), float3(4.0, 5.0, 6.0), float3(7.0, 8.0, 9.0))) != inv3x3 ? _uniforms.colorGreen : _uniforms.colorRed;
_out.sk_FragColor = ((float2x2(float2(-2.0, 1.0), float2(1.5, -0.5)) == inv2x2 && float3x3(float3(-24.0, 18.0, 5.0), float3(20.0, -15.0, -4.0), float3(-5.0, 4.0, 1.0)) == inv3x3) && float4x4(float4(-2.0, -0.5, 1.0, 0.5), float4(1.0, 0.5, 0.0, -0.5), float4(-8.0, -1.0, 2.0, 2.0), float4(3.0, 0.5, -1.0, -0.5)) == inv4x4) && mat3_inverse(float3x3(float3(1.0, 2.0, 3.0), float3(4.0, 5.0, 6.0), float3(7.0, 8.0, 9.0))) != inv3x3 ? _uniforms.colorGreen : _uniforms.colorRed;
return _out;
}

View File

@ -19,8 +19,8 @@ thread bool operator!=(const float2x2 left, const float2x2 right);
thread bool operator==(const float3x3 left, const float3x3 right);
thread bool operator!=(const float3x3 left, const float3x3 right);
template <int C, int R>
matrix<float, C, R> matrixCompMult(matrix<float, C, R> a, const matrix<float, C, R> b) {
template <typename T, int C, int R>
matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
for (int c = 0; c < C; ++c) {
a[c] *= b[c];
}

View File

@ -20,8 +20,8 @@ thread bool operator!=(const float4x2 left, const float4x2 right);
thread bool operator==(const float4x3 left, const float4x3 right);
thread bool operator!=(const float4x3 left, const float4x3 right);
template <int C, int R>
matrix<float, C, R> matrixCompMult(matrix<float, C, R> a, const matrix<float, C, R> b) {
template <typename T, int C, int R>
matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
for (int c = 0; c < C; ++c) {
a[c] *= b[c];
}

View File

@ -39,9 +39,9 @@ thread bool operator!=(const float2x2 left, const float2x2 right) {
return !(left == right);
}
template <int C, int R>
matrix<float, C, R> outerProduct(const vec<float, R> a, const vec<float, C> b) {
matrix<float, C, R> result;
template <typename T, int C, int R>
matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
matrix<T, C, R> result;
for (int c = 0; c < C; ++c) {
result[c] = a * b[c];
}