Update Metal matrix intrinsic polyfills to allow half.
The inverse, outerProduct and matrixCompMult polyfill functions in Metal were written assuming that all float matrices would use the `float` type. They now use a template so that `half` matrices will work too. Change-Id: I7696c8ad1e4aaffbd71c56b9245485e74cd96c5a Bug: skia:12339 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/463338 Auto-Submit: John Stiles <johnstiles@google.com> Commit-Queue: Brian Osman <brianosman@google.com> Reviewed-by: Brian Osman <brianosman@google.com>
This commit is contained in:
parent
643bd0fc8f
commit
b37100de7d
@ -394,61 +394,64 @@ void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
|
||||
}
|
||||
|
||||
static constexpr char kInverse2x2[] = R"(
|
||||
float2x2 float2x2_inverse(float2x2 m) {
|
||||
return float2x2(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
|
||||
template <typename T>
|
||||
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
|
||||
return matrix<T, 2, 2>(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
|
||||
}
|
||||
)";
|
||||
|
||||
static constexpr char kInverse3x3[] = R"(
|
||||
float3x3 float3x3_inverse(float3x3 m) {
|
||||
float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
|
||||
float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
|
||||
float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
|
||||
float b01 = a22*a11 - a12*a21;
|
||||
float b11 = -a22*a10 + a12*a20;
|
||||
float b21 = a21*a10 - a11*a20;
|
||||
float det = a00*b01 + a01*b11 + a02*b21;
|
||||
return float3x3(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
|
||||
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
|
||||
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
|
||||
template <typename T>
|
||||
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
|
||||
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
|
||||
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
|
||||
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
|
||||
T b01 = a22*a11 - a12*a21;
|
||||
T b11 = -a22*a10 + a12*a20;
|
||||
T b21 = a21*a10 - a11*a20;
|
||||
T det = a00*b01 + a01*b11 + a02*b21;
|
||||
return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
|
||||
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
|
||||
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
|
||||
}
|
||||
)";
|
||||
|
||||
static constexpr char kInverse4x4[] = R"(
|
||||
float4x4 float4x4_inverse(float4x4 m) {
|
||||
float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
|
||||
float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
|
||||
float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
|
||||
float a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
|
||||
float b00 = a00*a11 - a01*a10;
|
||||
float b01 = a00*a12 - a02*a10;
|
||||
float b02 = a00*a13 - a03*a10;
|
||||
float b03 = a01*a12 - a02*a11;
|
||||
float b04 = a01*a13 - a03*a11;
|
||||
float b05 = a02*a13 - a03*a12;
|
||||
float b06 = a20*a31 - a21*a30;
|
||||
float b07 = a20*a32 - a22*a30;
|
||||
float b08 = a20*a33 - a23*a30;
|
||||
float b09 = a21*a32 - a22*a31;
|
||||
float b10 = a21*a33 - a23*a31;
|
||||
float b11 = a22*a33 - a23*a32;
|
||||
float det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
|
||||
return float4x4(a11*b11 - a12*b10 + a13*b09,
|
||||
a02*b10 - a01*b11 - a03*b09,
|
||||
a31*b05 - a32*b04 + a33*b03,
|
||||
a22*b04 - a21*b05 - a23*b03,
|
||||
a12*b08 - a10*b11 - a13*b07,
|
||||
a00*b11 - a02*b08 + a03*b07,
|
||||
a32*b02 - a30*b05 - a33*b01,
|
||||
a20*b05 - a22*b02 + a23*b01,
|
||||
a10*b10 - a11*b08 + a13*b06,
|
||||
a01*b08 - a00*b10 - a03*b06,
|
||||
a30*b04 - a31*b02 + a33*b00,
|
||||
a21*b02 - a20*b04 - a23*b00,
|
||||
a11*b07 - a10*b09 - a12*b06,
|
||||
a00*b09 - a01*b07 + a02*b06,
|
||||
a31*b01 - a30*b03 - a32*b00,
|
||||
a20*b03 - a21*b01 + a22*b00) * (1/det);
|
||||
template <typename T>
|
||||
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
|
||||
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
|
||||
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
|
||||
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
|
||||
T a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
|
||||
T b00 = a00*a11 - a01*a10;
|
||||
T b01 = a00*a12 - a02*a10;
|
||||
T b02 = a00*a13 - a03*a10;
|
||||
T b03 = a01*a12 - a02*a11;
|
||||
T b04 = a01*a13 - a03*a11;
|
||||
T b05 = a02*a13 - a03*a12;
|
||||
T b06 = a20*a31 - a21*a30;
|
||||
T b07 = a20*a32 - a22*a30;
|
||||
T b08 = a20*a33 - a23*a30;
|
||||
T b09 = a21*a32 - a22*a31;
|
||||
T b10 = a21*a33 - a23*a31;
|
||||
T b11 = a22*a33 - a23*a32;
|
||||
T det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
|
||||
return matrix<T, 4, 4>(a11*b11 - a12*b10 + a13*b09,
|
||||
a02*b10 - a01*b11 - a03*b09,
|
||||
a31*b05 - a32*b04 + a33*b03,
|
||||
a22*b04 - a21*b05 - a23*b03,
|
||||
a12*b08 - a10*b11 - a13*b07,
|
||||
a00*b11 - a02*b08 + a03*b07,
|
||||
a32*b02 - a30*b05 - a33*b01,
|
||||
a20*b05 - a22*b02 + a23*b01,
|
||||
a10*b10 - a11*b08 + a13*b06,
|
||||
a01*b08 - a00*b10 - a03*b06,
|
||||
a30*b04 - a31*b02 + a33*b00,
|
||||
a21*b02 - a20*b04 - a23*b00,
|
||||
a11*b07 - a10*b09 - a12*b06,
|
||||
a00*b09 - a01*b07 + a02*b06,
|
||||
a31*b01 - a30*b03 - a32*b00,
|
||||
a20*b03 - a21*b01 + a22*b00) * (1/det);
|
||||
}
|
||||
)";
|
||||
|
||||
@ -458,7 +461,7 @@ String MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments)
|
||||
const Type& type = arguments.front()->type();
|
||||
if (type.isMatrix() && type.rows() == type.columns()) {
|
||||
// Inject the correct polyfill based on the matrix size.
|
||||
String name = this->typeName(type) + "_inverse";
|
||||
auto name = String::printf("mat%d_inverse", type.columns());
|
||||
auto [iter, didInsert] = fWrittenIntrinsics.insert(name);
|
||||
if (didInsert) {
|
||||
switch (type.rows()) {
|
||||
@ -482,8 +485,8 @@ String MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments)
|
||||
|
||||
void MetalCodeGenerator::writeMatrixCompMult() {
|
||||
static constexpr char kMatrixCompMult[] = R"(
|
||||
template <int C, int R>
|
||||
matrix<float, C, R> matrixCompMult(matrix<float, C, R> a, const matrix<float, C, R> b) {
|
||||
template <typename T, int C, int R>
|
||||
matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
|
||||
for (int c = 0; c < C; ++c) {
|
||||
a[c] *= b[c];
|
||||
}
|
||||
@ -500,9 +503,9 @@ matrix<float, C, R> matrixCompMult(matrix<float, C, R> a, const matrix<float, C,
|
||||
|
||||
void MetalCodeGenerator::writeOuterProduct() {
|
||||
static constexpr char kOuterProduct[] = R"(
|
||||
template <int C, int R>
|
||||
matrix<float, C, R> outerProduct(const vec<float, R> a, const vec<float, C> b) {
|
||||
matrix<float, C, R> result;
|
||||
template <typename T, int C, int R>
|
||||
matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
|
||||
matrix<T, C, R> result;
|
||||
for (int c = 0; c < C; ++c) {
|
||||
result[c] = a * b[c];
|
||||
}
|
||||
|
@ -44,17 +44,18 @@ thread bool operator!=(const float4x4 left, const float4x4 right) {
|
||||
return !(left == right);
|
||||
}
|
||||
|
||||
float3x3 float3x3_inverse(float3x3 m) {
|
||||
float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
|
||||
float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
|
||||
float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
|
||||
float b01 = a22*a11 - a12*a21;
|
||||
float b11 = -a22*a10 + a12*a20;
|
||||
float b21 = a21*a10 - a11*a20;
|
||||
float det = a00*b01 + a01*b11 + a02*b21;
|
||||
return float3x3(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
|
||||
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
|
||||
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
|
||||
template <typename T>
|
||||
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
|
||||
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
|
||||
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
|
||||
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
|
||||
T b01 = a22*a11 - a12*a21;
|
||||
T b11 = -a22*a10 + a12*a20;
|
||||
T b21 = a21*a10 - a11*a20;
|
||||
T det = a00*b01 + a01*b11 + a02*b21;
|
||||
return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
|
||||
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
|
||||
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
|
||||
}
|
||||
fragment Outputs fragmentMain(Inputs _in [[stage_in]], constant Uniforms& _uniforms [[buffer(0)]], bool _frontFacing [[front_facing]], float4 _fragCoord [[position]]) {
|
||||
Outputs _out;
|
||||
@ -62,6 +63,6 @@ fragment Outputs fragmentMain(Inputs _in [[stage_in]], constant Uniforms& _unifo
|
||||
float2x2 inv2x2 = float2x2(float2(-2.0, 1.0), float2(1.5, -0.5));
|
||||
float3x3 inv3x3 = float3x3(float3(-24.0, 18.0, 5.0), float3(20.0, -15.0, -4.0), float3(-5.0, 4.0, 1.0));
|
||||
float4x4 inv4x4 = float4x4(float4(-2.0, -0.5, 1.0, 0.5), float4(1.0, 0.5, 0.0, -0.5), float4(-8.0, -1.0, 2.0, 2.0), float4(3.0, 0.5, -1.0, -0.5));
|
||||
_out.sk_FragColor = ((float2x2(float2(-2.0, 1.0), float2(1.5, -0.5)) == inv2x2 && float3x3(float3(-24.0, 18.0, 5.0), float3(20.0, -15.0, -4.0), float3(-5.0, 4.0, 1.0)) == inv3x3) && float4x4(float4(-2.0, -0.5, 1.0, 0.5), float4(1.0, 0.5, 0.0, -0.5), float4(-8.0, -1.0, 2.0, 2.0), float4(3.0, 0.5, -1.0, -0.5)) == inv4x4) && float3x3_inverse(float3x3(float3(1.0, 2.0, 3.0), float3(4.0, 5.0, 6.0), float3(7.0, 8.0, 9.0))) != inv3x3 ? _uniforms.colorGreen : _uniforms.colorRed;
|
||||
_out.sk_FragColor = ((float2x2(float2(-2.0, 1.0), float2(1.5, -0.5)) == inv2x2 && float3x3(float3(-24.0, 18.0, 5.0), float3(20.0, -15.0, -4.0), float3(-5.0, 4.0, 1.0)) == inv3x3) && float4x4(float4(-2.0, -0.5, 1.0, 0.5), float4(1.0, 0.5, 0.0, -0.5), float4(-8.0, -1.0, 2.0, 2.0), float4(3.0, 0.5, -1.0, -0.5)) == inv4x4) && mat3_inverse(float3x3(float3(1.0, 2.0, 3.0), float3(4.0, 5.0, 6.0), float3(7.0, 8.0, 9.0))) != inv3x3 ? _uniforms.colorGreen : _uniforms.colorRed;
|
||||
return _out;
|
||||
}
|
||||
|
@ -19,8 +19,8 @@ thread bool operator!=(const float2x2 left, const float2x2 right);
|
||||
thread bool operator==(const float3x3 left, const float3x3 right);
|
||||
thread bool operator!=(const float3x3 left, const float3x3 right);
|
||||
|
||||
template <int C, int R>
|
||||
matrix<float, C, R> matrixCompMult(matrix<float, C, R> a, const matrix<float, C, R> b) {
|
||||
template <typename T, int C, int R>
|
||||
matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
|
||||
for (int c = 0; c < C; ++c) {
|
||||
a[c] *= b[c];
|
||||
}
|
||||
|
@ -20,8 +20,8 @@ thread bool operator!=(const float4x2 left, const float4x2 right);
|
||||
thread bool operator==(const float4x3 left, const float4x3 right);
|
||||
thread bool operator!=(const float4x3 left, const float4x3 right);
|
||||
|
||||
template <int C, int R>
|
||||
matrix<float, C, R> matrixCompMult(matrix<float, C, R> a, const matrix<float, C, R> b) {
|
||||
template <typename T, int C, int R>
|
||||
matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
|
||||
for (int c = 0; c < C; ++c) {
|
||||
a[c] *= b[c];
|
||||
}
|
||||
|
@ -39,9 +39,9 @@ thread bool operator!=(const float2x2 left, const float2x2 right) {
|
||||
return !(left == right);
|
||||
}
|
||||
|
||||
template <int C, int R>
|
||||
matrix<float, C, R> outerProduct(const vec<float, R> a, const vec<float, C> b) {
|
||||
matrix<float, C, R> result;
|
||||
template <typename T, int C, int R>
|
||||
matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
|
||||
matrix<T, C, R> result;
|
||||
for (int c = 0; c < C; ++c) {
|
||||
result[c] = a * b[c];
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user