Add support for constant folding of matrix-times-matrix.
This code should be easily adaptable to matrix-times-vector as well; just treat the vector as a 1-row or 1-column matrix. I haven't gotten around to writing tests for this, though. Change-Id: If59ae52cd12952b44d3574d54398b2dc66edbcc8 Bug: skia:12819 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/505221 Auto-Submit: John Stiles <johnstiles@google.com> Reviewed-by: Brian Osman <brianosman@google.com> Reviewed-by: Ethan Nicholas <ethannicholas@google.com> Commit-Queue: John Stiles <johnstiles@google.com>
This commit is contained in:
parent
d0234ba3bf
commit
f1bb464ee4
@ -94,6 +94,58 @@ static std::unique_ptr<Expression> simplify_constant_equality(const Context& con
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
|
||||
const Expression& left,
|
||||
const Expression& right) {
|
||||
const Type& leftType = left.type();
|
||||
const Type& rightType = right.type();
|
||||
|
||||
SkASSERT(leftType.isMatrix());
|
||||
SkASSERT(rightType.isMatrix());
|
||||
|
||||
const Type& componentType = leftType.componentType();
|
||||
SkASSERT(componentType.matches(rightType.componentType()));
|
||||
|
||||
const int leftColumns = leftType.columns(),
|
||||
leftRows = leftType.rows(),
|
||||
rightColumns = rightType.columns(),
|
||||
rightRows = rightType.rows(),
|
||||
outColumns = rightColumns,
|
||||
outRows = leftRows;
|
||||
SkASSERT(leftColumns == rightRows);
|
||||
const Type& resultType = componentType.toCompound(context, outColumns, outRows);
|
||||
|
||||
// Fetch the left matrix.
|
||||
double leftVals[4][4];
|
||||
for (int c = 0; c < leftColumns; ++c) {
|
||||
for (int r = 0; r < leftRows; ++r) {
|
||||
leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
|
||||
}
|
||||
}
|
||||
// Fetch the right matrix.
|
||||
double rightVals[4][4];
|
||||
for (int c = 0; c < rightColumns; ++c) {
|
||||
for (int r = 0; r < rightRows; ++r) {
|
||||
rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
|
||||
}
|
||||
}
|
||||
|
||||
ExpressionArray args;
|
||||
args.reserve_back(outColumns * outRows);
|
||||
for (int c = 0; c < outColumns; ++c) {
|
||||
for (int r = 0; r < outRows; ++r) {
|
||||
// Compute a dot product for this position.
|
||||
double val = 0;
|
||||
for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
|
||||
val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
|
||||
}
|
||||
args.push_back(Literal::Make(left.fLine, val, &componentType));
|
||||
}
|
||||
}
|
||||
|
||||
return ConstructorCompound::Make(context, left.fLine, resultType, std::move(args));
|
||||
}
|
||||
|
||||
static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
|
||||
const Expression& left,
|
||||
Operator op,
|
||||
@ -533,8 +585,7 @@ std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
|
||||
|
||||
// Perform matrix * matrix multiplication.
|
||||
if (op.kind() == Token::Kind::TK_STAR && leftType.isMatrix() && rightType.isMatrix()) {
|
||||
// TODO(skia:12819): Implement matrix * matrix multiplication.
|
||||
return nullptr;
|
||||
return simplify_matrix_times_matrix(context, *left, *right);
|
||||
}
|
||||
|
||||
// Perform constant folding on pairs of vectors/matrices.
|
||||
|
@ -14,14 +14,10 @@ bool test_matrix_op_scalar_half_b() {
|
||||
}
|
||||
bool test_matrix_op_matrix_float_b() {
|
||||
bool ok = true;
|
||||
ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
|
||||
ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
|
||||
return ok;
|
||||
}
|
||||
bool test_matrix_op_matrix_half_b() {
|
||||
bool ok = true;
|
||||
ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
|
||||
ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
|
||||
return ok;
|
||||
}
|
||||
vec4 main() {
|
||||
|
@ -8,12 +8,10 @@ bool test_eq_half_b() {
|
||||
}
|
||||
bool test_matrix_op_matrix_float_b() {
|
||||
bool ok = true;
|
||||
ok = ok && mat3x2(1.0, 4.0, 2.0, 5.0, 3.0, 6.0) * mat2x3(7.0, 9.0, 11.0, 8.0, 10.0, 12.0) == mat2(58.0, 139.0, 64.0, 154.0);
|
||||
return ok;
|
||||
}
|
||||
bool test_matrix_op_matrix_half_b() {
|
||||
bool ok = true;
|
||||
ok = ok && mat3x2(1.0, 4.0, 2.0, 5.0, 3.0, 6.0) * mat2x3(7.0, 9.0, 11.0, 8.0, 10.0, 12.0) == mat2(58.0, 139.0, 64.0, 154.0);
|
||||
return ok;
|
||||
}
|
||||
vec4 main() {
|
||||
|
Loading…
Reference in New Issue
Block a user