Replace getMatExpression with getConstantSubexpression.

This approach gives us similar flexibility but requires fewer lines of
code to get the same result. In a followup CL we will be able to
eliminate get[BFI]VecExpression as well. This approach will also scale
to arrays and structs if we want to support constant-folding on these.

Change-Id: Ib0034935926c7004f84ba62ddbdb3168df8ce91d
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/393076
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
This commit is contained in:
John Stiles 2021-04-06 13:08:54 -04:00 committed by Skia Commit-Bot
parent 5abb9e1426
commit 7ec097c05f
14 changed files with 99 additions and 220 deletions

View File

@ -73,6 +73,11 @@ public:
return std::make_unique<BoolLiteral>(fOffset, this->value(), &this->type());
}
const Expression* getConstantSubexpression(int n) const override {
SkASSERT(n == 0);
return this;
}
private:
bool fValue;

View File

@ -175,65 +175,6 @@ std::unique_ptr<Expression> Constructor::MakeCompoundConstructor(const Context&
return std::make_unique<Constructor>(offset, type, std::move(args));
}
Expression::ComparisonResult Constructor::compareConstant(const Expression& other) const {
if (other.is<ConstructorDiagonalMatrix>()) {
return other.compareConstant(*this);
}
if (other.is<ConstructorMatrixResize>()) {
return other.compareConstant(*this);
}
if (other.is<ConstructorSplat>()) {
return other.compareConstant(*this);
}
if (!other.is<Constructor>()) {
return ComparisonResult::kUnknown;
}
const Constructor& c = other.as<Constructor>();
const Type& myType = this->type();
SkASSERT(myType == c.type());
if (myType.isVector()) {
if (myType.componentType().isFloat()) {
for (int i = 0; i < myType.columns(); i++) {
if (this->getFVecComponent(i) != c.getFVecComponent(i)) {
return ComparisonResult::kNotEqual;
}
}
return ComparisonResult::kEqual;
}
if (myType.componentType().isInteger()) {
for (int i = 0; i < myType.columns(); i++) {
if (this->getIVecComponent(i) != c.getIVecComponent(i)) {
return ComparisonResult::kNotEqual;
}
}
return ComparisonResult::kEqual;
}
if (myType.componentType().isBoolean()) {
for (int i = 0; i < myType.columns(); i++) {
if (this->getBVecComponent(i) != c.getBVecComponent(i)) {
return ComparisonResult::kNotEqual;
}
}
return ComparisonResult::kEqual;
}
}
if (myType.isMatrix()) {
for (int col = 0; col < myType.columns(); col++) {
for (int row = 0; row < myType.rows(); row++) {
if (getMatComponent(col, row) != c.getMatComponent(col, row)) {
return ComparisonResult::kNotEqual;
}
}
}
return ComparisonResult::kEqual;
}
SkDEBUGFAILF("compareConstant unexpected type: %s", myType.description().c_str());
return ComparisonResult::kUnknown;
}
template <typename ResultType>
ResultType Constructor::getConstantValue(const Expression& expr) const {
const Type& type = expr.type();
@ -313,50 +254,6 @@ template SKSL_INT Constructor::getVecComponent(int) const;
template SKSL_FLOAT Constructor::getVecComponent(int) const;
template bool Constructor::getVecComponent(int) const;
SKSL_FLOAT Constructor::getMatComponent(int col, int row) const {
SkDEBUGCODE(const Type& myType = this->type();)
SkASSERT(this->isCompileTimeConstant());
SkASSERT(myType.isMatrix());
SkASSERT(col < myType.columns() && row < myType.rows());
if (this->arguments().size() == 1) {
const Type& argType = this->arguments()[0]->type();
if (argType.isScalar()) {
// single scalar argument, so matrix is of the form:
// x 0 0
// 0 x 0
// 0 0 x
// return x if col == row
return col == row ? this->getConstantValue<SKSL_FLOAT>(*this->arguments()[0]) : 0.0;
}
if (argType.isMatrix()) {
SkASSERT(this->arguments()[0]->isAnyConstructor());
// single matrix argument. make sure we're within the argument's bounds.
if (col < argType.columns() && row < argType.rows()) {
// within bounds, defer to argument
return this->arguments()[0]->getMatComponent(col, row);
}
// out of bounds
return 0.0;
}
}
int currentIndex = 0;
int targetIndex = col * this->type().rows() + row;
for (const auto& arg : this->arguments()) {
const Type& argType = arg->type();
SkASSERT(targetIndex >= currentIndex);
SkASSERT(argType.rows() == 1);
if (currentIndex + argType.columns() > targetIndex) {
if (argType.columns() == 1) {
return arg->getConstantFloat();
} else {
return arg->getFVecComponent(targetIndex - currentIndex);
}
}
currentIndex += argType.columns();
}
SK_ABORT("can't happen, matrix component out of bounds");
}
SKSL_INT Constructor::getConstantInt() const {
// We're looking for scalar integer constructors only, i.e. `int(1)`.
SkASSERT(this->arguments().size() == 1);
@ -399,6 +296,45 @@ bool Constructor::getConstantBool() const {
(bool)expr.getConstantFloat();
}
const Expression* AnyConstructor::getConstantSubexpression(int n) const {
SkASSERT(n >= 0 && n < (int)this->type().slotCount());
for (const std::unique_ptr<Expression>& arg : this->argumentSpan()) {
int argSlots = arg->type().slotCount();
if (n < argSlots) {
return arg->getConstantSubexpression(n);
}
n -= argSlots;
}
SkDEBUGFAIL("argument-list slot count doesn't match constructor-type slot count");
return nullptr;
}
Expression::ComparisonResult AnyConstructor::compareConstant(const Expression& other) const {
ComparisonResult result = ComparisonResult::kEqual;
SkASSERT(this->type().slotCount() == other.type().slotCount());
int exprs = this->type().slotCount();
for (int n = 0; n < exprs; ++n) {
// Get the n'th subexpression from each side. If either one is null, return "unknown."
const Expression* left = this->getConstantSubexpression(n);
if (!left) {
return ComparisonResult::kUnknown;
}
const Expression* right = other.getConstantSubexpression(n);
if (!right) {
return ComparisonResult::kUnknown;
}
// Recurse into the subexpressions; the literal types will perform real comparisons, and
// most other expressions fall back on the base class Expression which returns unknown.
result = left->compareConstant(*right);
if (result != ComparisonResult::kEqual) {
break;
}
}
return result;
}
AnyConstructor& Expression::asAnyConstructor() {
SkASSERT(this->isAnyConstructor());
return static_cast<AnyConstructor&>(*this);

View File

@ -71,6 +71,10 @@ public:
return true;
}
const Expression* getConstantSubexpression(int n) const override;
ComparisonResult compareConstant(const Expression& other) const override;
private:
std::unique_ptr<Expression> fArgument;
@ -183,8 +187,6 @@ public:
return std::make_unique<Constructor>(fOffset, this->type(), this->cloneArguments());
}
ComparisonResult compareConstant(const Expression& other) const override;
template <typename ResultType>
ResultType getVecComponent(int index) const;
@ -215,8 +217,6 @@ public:
return this->getVecComponent<bool>(n);
}
SKSL_FLOAT getMatComponent(int col, int row) const override;
SKSL_INT getConstantInt() const override;
SKSL_FLOAT getConstantFloat() const override;

View File

@ -56,22 +56,4 @@ std::unique_ptr<Expression> ConstructorArray::Make(const Context& context,
return std::make_unique<ConstructorArray>(offset, type, std::move(args));
}
Expression::ComparisonResult ConstructorArray::compareConstant(const Expression& other) const {
// There is only one array-constructor type, so if this comparison had type-checked
// successfully, `other` should be a ConstructorArray with the same array size.
const ConstructorArray& otherArray = other.as<ConstructorArray>();
int numColumns = this->type().columns();
SkASSERT(numColumns == otherArray.type().columns());
ComparisonResult check = ComparisonResult::kEqual;
for (int index = 0; index < numColumns; index++) {
check = this->arguments()[index]->compareConstant(*otherArray.arguments()[index]);
if (check != ComparisonResult::kEqual) {
break;
}
}
return check;
}
} // namespace SkSL

View File

@ -39,8 +39,6 @@ public:
return std::make_unique<ConstructorArray>(fOffset, this->type(), this->cloneArguments());
}
ComparisonResult compareConstant(const Expression& other) const override;
private:
using INHERITED = MultiArgumentConstructor;
};

View File

@ -21,35 +21,17 @@ std::unique_ptr<Expression> ConstructorDiagonalMatrix::Make(const Context& conte
return std::make_unique<ConstructorDiagonalMatrix>(offset, type, std::move(arg));
}
Expression::ComparisonResult ConstructorDiagonalMatrix::compareConstant(
const Expression& other) const {
SkASSERT(other.type().isMatrix());
SkASSERT(this->type() == other.type());
const Expression* ConstructorDiagonalMatrix::getConstantSubexpression(int n) const {
int rows = this->type().rows();
int row = n % rows;
int col = n / rows;
// The other constructor might not be DiagonalMatrix-based, so we check each cell individually.
for (int col = 0; col < this->type().columns(); col++) {
for (int row = 0; row < this->type().rows(); row++) {
if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) {
return ComparisonResult::kNotEqual;
}
}
}
return ComparisonResult::kEqual;
}
SKSL_FLOAT ConstructorDiagonalMatrix::getMatComponent(int col, int row) const {
SkASSERT(this->isCompileTimeConstant());
SkASSERT(col >= 0);
SkASSERT(row >= 0);
SkASSERT(col < this->type().columns());
SkASSERT(row < this->type().rows());
// Our matrix is of the form:
// |x 0 0|
// |0 x 0|
// |0 0 x|
return (col == row) ? this->argument()->getConstantFloat() : 0.0;
return (col == row) ? this->argument()->getConstantSubexpression(0) : &fZeroLiteral;
}
} // namespace SkSL

View File

@ -27,7 +27,8 @@ public:
static constexpr Kind kExpressionKind = Kind::kConstructorDiagonalMatrix;
ConstructorDiagonalMatrix(int offset, const Type& type, std::unique_ptr<Expression> arg)
: INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
: INHERITED(offset, kExpressionKind, &type, std::move(arg))
, fZeroLiteral(offset, /*value=*/0.0f, &type.componentType()) {}
static std::unique_ptr<Expression> Make(const Context& context,
int offset,
@ -39,11 +40,10 @@ public:
argument()->clone());
}
ComparisonResult compareConstant(const Expression& other) const override;
SKSL_FLOAT getMatComponent(int col, int row) const override;
const Expression* getConstantSubexpression(int n) const override;
private:
const FloatLiteral fZeroLiteral;
using INHERITED = SingleArgumentConstructor;
};

View File

@ -27,27 +27,11 @@ std::unique_ptr<Expression> ConstructorMatrixResize::Make(const Context& context
return std::make_unique<ConstructorMatrixResize>(offset, type, std::move(arg));
}
Expression::ComparisonResult ConstructorMatrixResize::compareConstant(
const Expression& other) const {
SkASSERT(other.type().isMatrix());
SkASSERT(this->type() == other.type());
SkASSERT(this->type().rows() == other.type().rows());
SkASSERT(this->type().columns() == other.type().columns());
const Expression* ConstructorMatrixResize::getConstantSubexpression(int n) const {
int rows = this->type().rows();
int row = n % rows;
int col = n / rows;
// Check each cell individually.
for (int col = 0; col < this->type().columns(); col++) {
for (int row = 0; row < this->type().rows(); row++) {
if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) {
return ComparisonResult::kNotEqual;
}
}
}
return ComparisonResult::kEqual;
}
SKSL_FLOAT ConstructorMatrixResize::getMatComponent(int col, int row) const {
SkASSERT(this->isCompileTimeConstant());
SkASSERT(col >= 0);
SkASSERT(row >= 0);
SkASSERT(col < this->type().columns());
@ -59,13 +43,15 @@ SKSL_FLOAT ConstructorMatrixResize::getMatComponent(int col, int row) const {
// |0 0 1|
// Where `m` is the matrix being wrapped, and other cells contain the identity matrix.
// Forward `getMatComponent` to the wrapped matrix if the position is in its bounds.
// Forward `getConstantSubexpression` to the wrapped matrix if the position is in its bounds.
if (col < this->argument()->type().columns() && row < this->argument()->type().rows()) {
return this->argument()->getMatComponent(col, row);
// Recalculate `n` in terms of the inner matrix's dimensions.
n = row + (col * this->argument()->type().rows());
return this->argument()->getConstantSubexpression(n);
}
// Synthesize an identity matrix for out-of-bounds positions.
return (col == row) ? 1.0f : 0.0f;
return (col == row) ? &fOneLiteral : &fZeroLiteral;
}
} // namespace SkSL

View File

@ -28,7 +28,9 @@ public:
static constexpr Kind kExpressionKind = Kind::kConstructorMatrixResize;
ConstructorMatrixResize(int offset, const Type& type, std::unique_ptr<Expression> arg)
: INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
: INHERITED(offset, kExpressionKind, &type, std::move(arg))
, fZeroLiteral(offset, /*value=*/0.0f, &type.componentType())
, fOneLiteral(offset, /*value=*/1.0f, &type.componentType()) {}
static std::unique_ptr<Expression> Make(const Context& context,
int offset,
@ -40,12 +42,12 @@ public:
argument()->clone());
}
Expression::ComparisonResult compareConstant(const Expression& other) const override;
SKSL_FLOAT getMatComponent(int col, int row) const override;
const Expression* getConstantSubexpression(int n) const override;
private:
using INHERITED = SingleArgumentConstructor;
const FloatLiteral fZeroLiteral;
const FloatLiteral fOneLiteral;
};
} // namespace SkSL

View File

@ -25,29 +25,4 @@ std::unique_ptr<Expression> ConstructorSplat::Make(const Context& context,
return std::make_unique<ConstructorSplat>(offset, type, std::move(arg));
}
Expression::ComparisonResult ConstructorSplat::compareConstant(const Expression& other) const {
SkASSERT(this->type() == other.type());
if (!other.isAnyConstructor()) {
return ComparisonResult::kUnknown;
}
return this->compareConstantConstructor(other.asAnyConstructor());
}
Expression::ComparisonResult ConstructorSplat::compareConstantConstructor(
const AnyConstructor& other) const {
ComparisonResult check = ComparisonResult::kEqual;
for (const std::unique_ptr<Expression>& expr : other.argumentSpan()) {
// We need to recurse to handle nested constructors like `half4(1) == half4(half2(1), 1, 1)`
check = expr->isAnyConstructor()
? this->compareConstantConstructor(expr->asAnyConstructor())
: argument()->compareConstant(*expr);
if (check != ComparisonResult::kEqual) {
break;
}
}
return check;
}
} // namespace SkSL

View File

@ -38,8 +38,6 @@ public:
return std::make_unique<ConstructorSplat>(fOffset, this->type(), argument()->clone());
}
ComparisonResult compareConstant(const Expression& other) const override;
SKSL_FLOAT getFVecComponent(int) const override {
return this->argument()->getConstantFloat();
}
@ -52,9 +50,12 @@ public:
return this->argument()->getConstantBool();
}
private:
Expression::ComparisonResult compareConstantConstructor(const AnyConstructor& other) const;
const Expression* getConstantSubexpression(int n) const override {
SkASSERT(n >= 0 && n < this->type().columns());
return this->argument()->getConstantSubexpression(0);
}
private:
using INHERITED = SingleArgumentConstructor;
};

View File

@ -178,6 +178,18 @@ public:
return this->type().coercionCost(target);
}
/**
* Returns the n'th compile-time constant expression within a literal or constructor.
* Use Type::slotCount to determine the number of subexpressions within an expression.
* Subexpressions which are not compile-time constants will return null.
* `vec4(1, vec2(2), 3)` contains four subexpressions: (1, 2, 2, 3)
* `mat2(f)` contains four subexpressions: (null, 0,
* 0, null)
*/
virtual const Expression* getConstantSubexpression(int n) const {
return nullptr;
}
/**
* For a vector of floating point values, return the value of the n'th vector component. It is
* an error to call this method on an expression which is not a vector of floating-point
@ -215,16 +227,6 @@ public:
*/
template <typename T> T getVecComponent(int index) const;
/**
* For a literal matrix expression, return the floating point value of the component at
* [col][row]. It is an error to call this method on an expression which is not a literal
* matrix.
*/
virtual SKSL_FLOAT getMatComponent(int col, int row) const {
SkASSERT(false);
return 0;
}
virtual std::unique_ptr<Expression> clone() const = 0;
private:

View File

@ -79,6 +79,11 @@ public:
return std::make_unique<FloatLiteral>(fOffset, this->value(), &this->type());
}
const Expression* getConstantSubexpression(int n) const override {
SkASSERT(n == 0);
return this;
}
private:
float fValue;

View File

@ -81,6 +81,11 @@ public:
return std::make_unique<IntLiteral>(fOffset, this->value(), &this->type());
}
const Expression* getConstantSubexpression(int n) const override {
SkASSERT(n == 0);
return this;
}
private:
SKSL_INT fValue;