Implement constant folding for componentwise matrix-matrix ops.

Now, constant mat+mat, mat-mat, and mat/mat operations can be optimized
away. mat*mat does not operate componentwise and will need to be
handled differently.

Change-Id: Iabac6e58999eac46c256d7dcdb9b95d05de530bc
Bug: skia:12819
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/498716
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
This commit is contained in:
John Stiles 2022-01-25 12:02:39 -05:00 committed by SkCQ
parent a9f7a8b617
commit 0f5bc280a0
2 changed files with 8 additions and 22 deletions

View File

@ -531,8 +531,14 @@ std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
#undef RESULT
}
// Perform constant folding on pairs of vectors.
if (leftType.isVector() && leftType.matches(rightType)) {
// Perform matrix * matrix multiplication.
if (op.kind() == Token::Kind::TK_STAR && leftType.isMatrix() && rightType.isMatrix()) {
// TODO(skia:12819): Implement matrix * matrix multiplication.
return nullptr;
}
// Perform constant folding on pairs of vectors/matrices.
if (is_vec_or_mat(leftType) && leftType.matches(rightType)) {
return simplify_componentwise(context, *left, op, *right);
}

View File

@ -14,32 +14,12 @@ bool test_matrix_op_scalar_half_b() {
}
bool test_matrix_op_matrix_float_b() {
bool ok = true;
const mat3 splat_4 = mat3(4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0);
ok = ok && mat3(2.0) + splat_4 == mat3(6.0, 4.0, 4.0, 4.0, 6.0, 4.0, 4.0, 4.0, 6.0);
ok = ok && mat3(2.0) - splat_4 == mat3(-2.0, -4.0, -4.0, -4.0, -2.0, -4.0, -4.0, -4.0, -2.0);
ok = ok && mat3(2.0) / splat_4 == mat3(0.5);
ok = ok && splat_4 + mat3(2.0) == mat3(6.0, 4.0, 4.0, 4.0, 6.0, 4.0, 4.0, 4.0, 6.0);
ok = ok && splat_4 - mat3(2.0) == mat3(2.0, 4.0, 4.0, 4.0, 2.0, 4.0, 4.0, 4.0, 2.0);
ok = ok && mat2(4.0, 4.0, 4.0, 4.0) / mat2(2.0, 2.0, 2.0, 2.0) == mat2(2.0, 2.0, 2.0, 2.0);
ok = ok && mat4(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0) + mat4(16.0, 15.0, 14.0, 13.0, 12.0, 11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0) == mat4(17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0);
ok = ok && mat2(10.0, 20.0, 30.0, 40.0) - mat2(1.0, 2.0, 3.0, 4.0) == mat2(9.0, 18.0, 27.0, 36.0);
ok = ok && mat2(10.0, 20.0, 30.0, 40.0) / mat2(5.0, 4.0, 30.0, 1.0) == mat2(2.0, 5.0, 1.0, 40.0);
ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
return ok;
}
bool test_matrix_op_matrix_half_b() {
bool ok = true;
const mat3 splat_4 = mat3(4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0);
ok = ok && mat3(2.0) + splat_4 == mat3(6.0, 4.0, 4.0, 4.0, 6.0, 4.0, 4.0, 4.0, 6.0);
ok = ok && mat3(2.0) - splat_4 == mat3(-2.0, -4.0, -4.0, -4.0, -2.0, -4.0, -4.0, -4.0, -2.0);
ok = ok && mat3(2.0) / splat_4 == mat3(0.5);
ok = ok && splat_4 + mat3(2.0) == mat3(6.0, 4.0, 4.0, 4.0, 6.0, 4.0, 4.0, 4.0, 6.0);
ok = ok && splat_4 - mat3(2.0) == mat3(2.0, 4.0, 4.0, 4.0, 2.0, 4.0, 4.0, 4.0, 2.0);
ok = ok && mat2(4.0, 4.0, 4.0, 4.0) / mat2(2.0, 2.0, 2.0, 2.0) == mat2(2.0, 2.0, 2.0, 2.0);
ok = ok && mat4(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0) + mat4(16.0, 15.0, 14.0, 13.0, 12.0, 11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0) == mat4(17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0);
ok = ok && mat2(10.0, 20.0, 30.0, 40.0) - mat2(1.0, 2.0, 3.0, 4.0) == mat2(9.0, 18.0, 27.0, 36.0);
ok = ok && mat2(10.0, 20.0, 30.0, 40.0) / mat2(5.0, 4.0, 30.0, 1.0) == mat2(2.0, 5.0, 1.0, 40.0);
ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
return ok;