Add support for constant folding of matrix-times-matrix.

This code should be easily adaptable to matrix-times-vector as well;
just treat the vector as a 1-row or 1-column matrix. I haven't gotten
around to writing tests for this, though.

Change-Id: If59ae52cd12952b44d3574d54398b2dc66edbcc8
Bug: skia:12819
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/505221
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Commit-Queue: John Stiles <johnstiles@google.com>
This commit is contained in:
John Stiles 2022-02-07 14:27:23 -05:00 committed by SkCQ
parent d0234ba3bf
commit f1bb464ee4
3 changed files with 53 additions and 8 deletions

View File

@ -94,6 +94,58 @@ static std::unique_ptr<Expression> simplify_constant_equality(const Context& con
return nullptr;
}
static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
const Expression& left,
const Expression& right) {
const Type& leftType = left.type();
const Type& rightType = right.type();
SkASSERT(leftType.isMatrix());
SkASSERT(rightType.isMatrix());
const Type& componentType = leftType.componentType();
SkASSERT(componentType.matches(rightType.componentType()));
const int leftColumns = leftType.columns(),
leftRows = leftType.rows(),
rightColumns = rightType.columns(),
rightRows = rightType.rows(),
outColumns = rightColumns,
outRows = leftRows;
SkASSERT(leftColumns == rightRows);
const Type& resultType = componentType.toCompound(context, outColumns, outRows);
// Fetch the left matrix.
double leftVals[4][4];
for (int c = 0; c < leftColumns; ++c) {
for (int r = 0; r < leftRows; ++r) {
leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
}
}
// Fetch the right matrix.
double rightVals[4][4];
for (int c = 0; c < rightColumns; ++c) {
for (int r = 0; r < rightRows; ++r) {
rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
}
}
ExpressionArray args;
args.reserve_back(outColumns * outRows);
for (int c = 0; c < outColumns; ++c) {
for (int r = 0; r < outRows; ++r) {
// Compute a dot product for this position.
double val = 0;
for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
}
args.push_back(Literal::Make(left.fLine, val, &componentType));
}
}
return ConstructorCompound::Make(context, left.fLine, resultType, std::move(args));
}
static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
const Expression& left,
Operator op,
@ -533,8 +585,7 @@ std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
// Perform matrix * matrix multiplication.
if (op.kind() == Token::Kind::TK_STAR && leftType.isMatrix() && rightType.isMatrix()) {
// TODO(skia:12819): Implement matrix * matrix multiplication.
return nullptr;
return simplify_matrix_times_matrix(context, *left, *right);
}
// Perform constant folding on pairs of vectors/matrices.

View File

@ -14,14 +14,10 @@ bool test_matrix_op_scalar_half_b() {
}
bool test_matrix_op_matrix_float_b() {
bool ok = true;
ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
return ok;
}
bool test_matrix_op_matrix_half_b() {
bool ok = true;
ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
return ok;
}
vec4 main() {

View File

@ -8,12 +8,10 @@ bool test_eq_half_b() {
}
bool test_matrix_op_matrix_float_b() {
bool ok = true;
ok = ok && mat3x2(1.0, 4.0, 2.0, 5.0, 3.0, 6.0) * mat2x3(7.0, 9.0, 11.0, 8.0, 10.0, 12.0) == mat2(58.0, 139.0, 64.0, 154.0);
return ok;
}
bool test_matrix_op_matrix_half_b() {
bool ok = true;
ok = ok && mat3x2(1.0, 4.0, 2.0, 5.0, 3.0, 6.0) * mat2x3(7.0, 9.0, 11.0, 8.0, 10.0, 12.0) == mat2(58.0, 139.0, 64.0, 154.0);
return ok;
}
vec4 main() {