Replace getMatExpression with getConstantSubexpression.
This approach gives us similar flexibility but requires fewer lines of code to get the same result. In a followup CL we will be able to eliminate get[BFI]VecExpression as well. This approach will also scale to arrays and structs if we want to support constant-folding on these. Change-Id: Ib0034935926c7004f84ba62ddbdb3168df8ce91d Reviewed-on: https://skia-review.googlesource.com/c/skia/+/393076 Commit-Queue: John Stiles <johnstiles@google.com> Auto-Submit: John Stiles <johnstiles@google.com> Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
This commit is contained in:
parent
5abb9e1426
commit
7ec097c05f
@ -73,6 +73,11 @@ public:
|
||||
return std::make_unique<BoolLiteral>(fOffset, this->value(), &this->type());
|
||||
}
|
||||
|
||||
const Expression* getConstantSubexpression(int n) const override {
|
||||
SkASSERT(n == 0);
|
||||
return this;
|
||||
}
|
||||
|
||||
private:
|
||||
bool fValue;
|
||||
|
||||
|
@ -175,65 +175,6 @@ std::unique_ptr<Expression> Constructor::MakeCompoundConstructor(const Context&
|
||||
return std::make_unique<Constructor>(offset, type, std::move(args));
|
||||
}
|
||||
|
||||
Expression::ComparisonResult Constructor::compareConstant(const Expression& other) const {
|
||||
if (other.is<ConstructorDiagonalMatrix>()) {
|
||||
return other.compareConstant(*this);
|
||||
}
|
||||
if (other.is<ConstructorMatrixResize>()) {
|
||||
return other.compareConstant(*this);
|
||||
}
|
||||
if (other.is<ConstructorSplat>()) {
|
||||
return other.compareConstant(*this);
|
||||
}
|
||||
if (!other.is<Constructor>()) {
|
||||
return ComparisonResult::kUnknown;
|
||||
}
|
||||
const Constructor& c = other.as<Constructor>();
|
||||
const Type& myType = this->type();
|
||||
SkASSERT(myType == c.type());
|
||||
|
||||
if (myType.isVector()) {
|
||||
if (myType.componentType().isFloat()) {
|
||||
for (int i = 0; i < myType.columns(); i++) {
|
||||
if (this->getFVecComponent(i) != c.getFVecComponent(i)) {
|
||||
return ComparisonResult::kNotEqual;
|
||||
}
|
||||
}
|
||||
return ComparisonResult::kEqual;
|
||||
}
|
||||
if (myType.componentType().isInteger()) {
|
||||
for (int i = 0; i < myType.columns(); i++) {
|
||||
if (this->getIVecComponent(i) != c.getIVecComponent(i)) {
|
||||
return ComparisonResult::kNotEqual;
|
||||
}
|
||||
}
|
||||
return ComparisonResult::kEqual;
|
||||
}
|
||||
if (myType.componentType().isBoolean()) {
|
||||
for (int i = 0; i < myType.columns(); i++) {
|
||||
if (this->getBVecComponent(i) != c.getBVecComponent(i)) {
|
||||
return ComparisonResult::kNotEqual;
|
||||
}
|
||||
}
|
||||
return ComparisonResult::kEqual;
|
||||
}
|
||||
}
|
||||
|
||||
if (myType.isMatrix()) {
|
||||
for (int col = 0; col < myType.columns(); col++) {
|
||||
for (int row = 0; row < myType.rows(); row++) {
|
||||
if (getMatComponent(col, row) != c.getMatComponent(col, row)) {
|
||||
return ComparisonResult::kNotEqual;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ComparisonResult::kEqual;
|
||||
}
|
||||
|
||||
SkDEBUGFAILF("compareConstant unexpected type: %s", myType.description().c_str());
|
||||
return ComparisonResult::kUnknown;
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
ResultType Constructor::getConstantValue(const Expression& expr) const {
|
||||
const Type& type = expr.type();
|
||||
@ -313,50 +254,6 @@ template SKSL_INT Constructor::getVecComponent(int) const;
|
||||
template SKSL_FLOAT Constructor::getVecComponent(int) const;
|
||||
template bool Constructor::getVecComponent(int) const;
|
||||
|
||||
SKSL_FLOAT Constructor::getMatComponent(int col, int row) const {
|
||||
SkDEBUGCODE(const Type& myType = this->type();)
|
||||
SkASSERT(this->isCompileTimeConstant());
|
||||
SkASSERT(myType.isMatrix());
|
||||
SkASSERT(col < myType.columns() && row < myType.rows());
|
||||
if (this->arguments().size() == 1) {
|
||||
const Type& argType = this->arguments()[0]->type();
|
||||
if (argType.isScalar()) {
|
||||
// single scalar argument, so matrix is of the form:
|
||||
// x 0 0
|
||||
// 0 x 0
|
||||
// 0 0 x
|
||||
// return x if col == row
|
||||
return col == row ? this->getConstantValue<SKSL_FLOAT>(*this->arguments()[0]) : 0.0;
|
||||
}
|
||||
if (argType.isMatrix()) {
|
||||
SkASSERT(this->arguments()[0]->isAnyConstructor());
|
||||
// single matrix argument. make sure we're within the argument's bounds.
|
||||
if (col < argType.columns() && row < argType.rows()) {
|
||||
// within bounds, defer to argument
|
||||
return this->arguments()[0]->getMatComponent(col, row);
|
||||
}
|
||||
// out of bounds
|
||||
return 0.0;
|
||||
}
|
||||
}
|
||||
int currentIndex = 0;
|
||||
int targetIndex = col * this->type().rows() + row;
|
||||
for (const auto& arg : this->arguments()) {
|
||||
const Type& argType = arg->type();
|
||||
SkASSERT(targetIndex >= currentIndex);
|
||||
SkASSERT(argType.rows() == 1);
|
||||
if (currentIndex + argType.columns() > targetIndex) {
|
||||
if (argType.columns() == 1) {
|
||||
return arg->getConstantFloat();
|
||||
} else {
|
||||
return arg->getFVecComponent(targetIndex - currentIndex);
|
||||
}
|
||||
}
|
||||
currentIndex += argType.columns();
|
||||
}
|
||||
SK_ABORT("can't happen, matrix component out of bounds");
|
||||
}
|
||||
|
||||
SKSL_INT Constructor::getConstantInt() const {
|
||||
// We're looking for scalar integer constructors only, i.e. `int(1)`.
|
||||
SkASSERT(this->arguments().size() == 1);
|
||||
@ -399,6 +296,45 @@ bool Constructor::getConstantBool() const {
|
||||
(bool)expr.getConstantFloat();
|
||||
}
|
||||
|
||||
const Expression* AnyConstructor::getConstantSubexpression(int n) const {
|
||||
SkASSERT(n >= 0 && n < (int)this->type().slotCount());
|
||||
for (const std::unique_ptr<Expression>& arg : this->argumentSpan()) {
|
||||
int argSlots = arg->type().slotCount();
|
||||
if (n < argSlots) {
|
||||
return arg->getConstantSubexpression(n);
|
||||
}
|
||||
n -= argSlots;
|
||||
}
|
||||
|
||||
SkDEBUGFAIL("argument-list slot count doesn't match constructor-type slot count");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Expression::ComparisonResult AnyConstructor::compareConstant(const Expression& other) const {
|
||||
ComparisonResult result = ComparisonResult::kEqual;
|
||||
SkASSERT(this->type().slotCount() == other.type().slotCount());
|
||||
|
||||
int exprs = this->type().slotCount();
|
||||
for (int n = 0; n < exprs; ++n) {
|
||||
// Get the n'th subexpression from each side. If either one is null, return "unknown."
|
||||
const Expression* left = this->getConstantSubexpression(n);
|
||||
if (!left) {
|
||||
return ComparisonResult::kUnknown;
|
||||
}
|
||||
const Expression* right = other.getConstantSubexpression(n);
|
||||
if (!right) {
|
||||
return ComparisonResult::kUnknown;
|
||||
}
|
||||
// Recurse into the subexpressions; the literal types will perform real comparisons, and
|
||||
// most other expressions fall back on the base class Expression which returns unknown.
|
||||
result = left->compareConstant(*right);
|
||||
if (result != ComparisonResult::kEqual) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
AnyConstructor& Expression::asAnyConstructor() {
|
||||
SkASSERT(this->isAnyConstructor());
|
||||
return static_cast<AnyConstructor&>(*this);
|
||||
|
@ -71,6 +71,10 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
const Expression* getConstantSubexpression(int n) const override;
|
||||
|
||||
ComparisonResult compareConstant(const Expression& other) const override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<Expression> fArgument;
|
||||
|
||||
@ -183,8 +187,6 @@ public:
|
||||
return std::make_unique<Constructor>(fOffset, this->type(), this->cloneArguments());
|
||||
}
|
||||
|
||||
ComparisonResult compareConstant(const Expression& other) const override;
|
||||
|
||||
template <typename ResultType>
|
||||
ResultType getVecComponent(int index) const;
|
||||
|
||||
@ -215,8 +217,6 @@ public:
|
||||
return this->getVecComponent<bool>(n);
|
||||
}
|
||||
|
||||
SKSL_FLOAT getMatComponent(int col, int row) const override;
|
||||
|
||||
SKSL_INT getConstantInt() const override;
|
||||
|
||||
SKSL_FLOAT getConstantFloat() const override;
|
||||
|
@ -56,22 +56,4 @@ std::unique_ptr<Expression> ConstructorArray::Make(const Context& context,
|
||||
return std::make_unique<ConstructorArray>(offset, type, std::move(args));
|
||||
}
|
||||
|
||||
Expression::ComparisonResult ConstructorArray::compareConstant(const Expression& other) const {
|
||||
// There is only one array-constructor type, so if this comparison had type-checked
|
||||
// successfully, `other` should be a ConstructorArray with the same array size.
|
||||
const ConstructorArray& otherArray = other.as<ConstructorArray>();
|
||||
int numColumns = this->type().columns();
|
||||
SkASSERT(numColumns == otherArray.type().columns());
|
||||
|
||||
ComparisonResult check = ComparisonResult::kEqual;
|
||||
for (int index = 0; index < numColumns; index++) {
|
||||
check = this->arguments()[index]->compareConstant(*otherArray.arguments()[index]);
|
||||
if (check != ComparisonResult::kEqual) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return check;
|
||||
}
|
||||
|
||||
} // namespace SkSL
|
||||
|
@ -39,8 +39,6 @@ public:
|
||||
return std::make_unique<ConstructorArray>(fOffset, this->type(), this->cloneArguments());
|
||||
}
|
||||
|
||||
ComparisonResult compareConstant(const Expression& other) const override;
|
||||
|
||||
private:
|
||||
using INHERITED = MultiArgumentConstructor;
|
||||
};
|
||||
|
@ -21,35 +21,17 @@ std::unique_ptr<Expression> ConstructorDiagonalMatrix::Make(const Context& conte
|
||||
return std::make_unique<ConstructorDiagonalMatrix>(offset, type, std::move(arg));
|
||||
}
|
||||
|
||||
Expression::ComparisonResult ConstructorDiagonalMatrix::compareConstant(
|
||||
const Expression& other) const {
|
||||
SkASSERT(other.type().isMatrix());
|
||||
SkASSERT(this->type() == other.type());
|
||||
const Expression* ConstructorDiagonalMatrix::getConstantSubexpression(int n) const {
|
||||
int rows = this->type().rows();
|
||||
int row = n % rows;
|
||||
int col = n / rows;
|
||||
|
||||
// The other constructor might not be DiagonalMatrix-based, so we check each cell individually.
|
||||
for (int col = 0; col < this->type().columns(); col++) {
|
||||
for (int row = 0; row < this->type().rows(); row++) {
|
||||
if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) {
|
||||
return ComparisonResult::kNotEqual;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ComparisonResult::kEqual;
|
||||
}
|
||||
|
||||
SKSL_FLOAT ConstructorDiagonalMatrix::getMatComponent(int col, int row) const {
|
||||
SkASSERT(this->isCompileTimeConstant());
|
||||
SkASSERT(col >= 0);
|
||||
SkASSERT(row >= 0);
|
||||
SkASSERT(col < this->type().columns());
|
||||
SkASSERT(row < this->type().rows());
|
||||
|
||||
// Our matrix is of the form:
|
||||
// |x 0 0|
|
||||
// |0 x 0|
|
||||
// |0 0 x|
|
||||
return (col == row) ? this->argument()->getConstantFloat() : 0.0;
|
||||
return (col == row) ? this->argument()->getConstantSubexpression(0) : &fZeroLiteral;
|
||||
}
|
||||
|
||||
} // namespace SkSL
|
||||
|
@ -27,7 +27,8 @@ public:
|
||||
static constexpr Kind kExpressionKind = Kind::kConstructorDiagonalMatrix;
|
||||
|
||||
ConstructorDiagonalMatrix(int offset, const Type& type, std::unique_ptr<Expression> arg)
|
||||
: INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
|
||||
: INHERITED(offset, kExpressionKind, &type, std::move(arg))
|
||||
, fZeroLiteral(offset, /*value=*/0.0f, &type.componentType()) {}
|
||||
|
||||
static std::unique_ptr<Expression> Make(const Context& context,
|
||||
int offset,
|
||||
@ -39,11 +40,10 @@ public:
|
||||
argument()->clone());
|
||||
}
|
||||
|
||||
ComparisonResult compareConstant(const Expression& other) const override;
|
||||
|
||||
SKSL_FLOAT getMatComponent(int col, int row) const override;
|
||||
const Expression* getConstantSubexpression(int n) const override;
|
||||
|
||||
private:
|
||||
const FloatLiteral fZeroLiteral;
|
||||
using INHERITED = SingleArgumentConstructor;
|
||||
};
|
||||
|
||||
|
@ -27,27 +27,11 @@ std::unique_ptr<Expression> ConstructorMatrixResize::Make(const Context& context
|
||||
return std::make_unique<ConstructorMatrixResize>(offset, type, std::move(arg));
|
||||
}
|
||||
|
||||
Expression::ComparisonResult ConstructorMatrixResize::compareConstant(
|
||||
const Expression& other) const {
|
||||
SkASSERT(other.type().isMatrix());
|
||||
SkASSERT(this->type() == other.type());
|
||||
SkASSERT(this->type().rows() == other.type().rows());
|
||||
SkASSERT(this->type().columns() == other.type().columns());
|
||||
const Expression* ConstructorMatrixResize::getConstantSubexpression(int n) const {
|
||||
int rows = this->type().rows();
|
||||
int row = n % rows;
|
||||
int col = n / rows;
|
||||
|
||||
// Check each cell individually.
|
||||
for (int col = 0; col < this->type().columns(); col++) {
|
||||
for (int row = 0; row < this->type().rows(); row++) {
|
||||
if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) {
|
||||
return ComparisonResult::kNotEqual;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ComparisonResult::kEqual;
|
||||
}
|
||||
|
||||
SKSL_FLOAT ConstructorMatrixResize::getMatComponent(int col, int row) const {
|
||||
SkASSERT(this->isCompileTimeConstant());
|
||||
SkASSERT(col >= 0);
|
||||
SkASSERT(row >= 0);
|
||||
SkASSERT(col < this->type().columns());
|
||||
@ -59,13 +43,15 @@ SKSL_FLOAT ConstructorMatrixResize::getMatComponent(int col, int row) const {
|
||||
// |0 0 1|
|
||||
// Where `m` is the matrix being wrapped, and other cells contain the identity matrix.
|
||||
|
||||
// Forward `getMatComponent` to the wrapped matrix if the position is in its bounds.
|
||||
// Forward `getConstantSubexpression` to the wrapped matrix if the position is in its bounds.
|
||||
if (col < this->argument()->type().columns() && row < this->argument()->type().rows()) {
|
||||
return this->argument()->getMatComponent(col, row);
|
||||
// Recalculate `n` in terms of the inner matrix's dimensions.
|
||||
n = row + (col * this->argument()->type().rows());
|
||||
return this->argument()->getConstantSubexpression(n);
|
||||
}
|
||||
|
||||
// Synthesize an identity matrix for out-of-bounds positions.
|
||||
return (col == row) ? 1.0f : 0.0f;
|
||||
return (col == row) ? &fOneLiteral : &fZeroLiteral;
|
||||
}
|
||||
|
||||
} // namespace SkSL
|
||||
|
@ -28,7 +28,9 @@ public:
|
||||
static constexpr Kind kExpressionKind = Kind::kConstructorMatrixResize;
|
||||
|
||||
ConstructorMatrixResize(int offset, const Type& type, std::unique_ptr<Expression> arg)
|
||||
: INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
|
||||
: INHERITED(offset, kExpressionKind, &type, std::move(arg))
|
||||
, fZeroLiteral(offset, /*value=*/0.0f, &type.componentType())
|
||||
, fOneLiteral(offset, /*value=*/1.0f, &type.componentType()) {}
|
||||
|
||||
static std::unique_ptr<Expression> Make(const Context& context,
|
||||
int offset,
|
||||
@ -40,12 +42,12 @@ public:
|
||||
argument()->clone());
|
||||
}
|
||||
|
||||
Expression::ComparisonResult compareConstant(const Expression& other) const override;
|
||||
|
||||
SKSL_FLOAT getMatComponent(int col, int row) const override;
|
||||
const Expression* getConstantSubexpression(int n) const override;
|
||||
|
||||
private:
|
||||
using INHERITED = SingleArgumentConstructor;
|
||||
const FloatLiteral fZeroLiteral;
|
||||
const FloatLiteral fOneLiteral;
|
||||
};
|
||||
|
||||
} // namespace SkSL
|
||||
|
@ -25,29 +25,4 @@ std::unique_ptr<Expression> ConstructorSplat::Make(const Context& context,
|
||||
return std::make_unique<ConstructorSplat>(offset, type, std::move(arg));
|
||||
}
|
||||
|
||||
Expression::ComparisonResult ConstructorSplat::compareConstant(const Expression& other) const {
|
||||
SkASSERT(this->type() == other.type());
|
||||
if (!other.isAnyConstructor()) {
|
||||
return ComparisonResult::kUnknown;
|
||||
}
|
||||
|
||||
return this->compareConstantConstructor(other.asAnyConstructor());
|
||||
}
|
||||
|
||||
Expression::ComparisonResult ConstructorSplat::compareConstantConstructor(
|
||||
const AnyConstructor& other) const {
|
||||
ComparisonResult check = ComparisonResult::kEqual;
|
||||
for (const std::unique_ptr<Expression>& expr : other.argumentSpan()) {
|
||||
// We need to recurse to handle nested constructors like `half4(1) == half4(half2(1), 1, 1)`
|
||||
check = expr->isAnyConstructor()
|
||||
? this->compareConstantConstructor(expr->asAnyConstructor())
|
||||
: argument()->compareConstant(*expr);
|
||||
if (check != ComparisonResult::kEqual) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return check;
|
||||
}
|
||||
|
||||
} // namespace SkSL
|
||||
|
@ -38,8 +38,6 @@ public:
|
||||
return std::make_unique<ConstructorSplat>(fOffset, this->type(), argument()->clone());
|
||||
}
|
||||
|
||||
ComparisonResult compareConstant(const Expression& other) const override;
|
||||
|
||||
SKSL_FLOAT getFVecComponent(int) const override {
|
||||
return this->argument()->getConstantFloat();
|
||||
}
|
||||
@ -52,9 +50,12 @@ public:
|
||||
return this->argument()->getConstantBool();
|
||||
}
|
||||
|
||||
private:
|
||||
Expression::ComparisonResult compareConstantConstructor(const AnyConstructor& other) const;
|
||||
const Expression* getConstantSubexpression(int n) const override {
|
||||
SkASSERT(n >= 0 && n < this->type().columns());
|
||||
return this->argument()->getConstantSubexpression(0);
|
||||
}
|
||||
|
||||
private:
|
||||
using INHERITED = SingleArgumentConstructor;
|
||||
};
|
||||
|
||||
|
@ -178,6 +178,18 @@ public:
|
||||
return this->type().coercionCost(target);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the n'th compile-time constant expression within a literal or constructor.
|
||||
* Use Type::slotCount to determine the number of subexpressions within an expression.
|
||||
* Subexpressions which are not compile-time constants will return null.
|
||||
* `vec4(1, vec2(2), 3)` contains four subexpressions: (1, 2, 2, 3)
|
||||
* `mat2(f)` contains four subexpressions: (null, 0,
|
||||
* 0, null)
|
||||
*/
|
||||
virtual const Expression* getConstantSubexpression(int n) const {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* For a vector of floating point values, return the value of the n'th vector component. It is
|
||||
* an error to call this method on an expression which is not a vector of floating-point
|
||||
@ -215,16 +227,6 @@ public:
|
||||
*/
|
||||
template <typename T> T getVecComponent(int index) const;
|
||||
|
||||
/**
|
||||
* For a literal matrix expression, return the floating point value of the component at
|
||||
* [col][row]. It is an error to call this method on an expression which is not a literal
|
||||
* matrix.
|
||||
*/
|
||||
virtual SKSL_FLOAT getMatComponent(int col, int row) const {
|
||||
SkASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<Expression> clone() const = 0;
|
||||
|
||||
private:
|
||||
|
@ -79,6 +79,11 @@ public:
|
||||
return std::make_unique<FloatLiteral>(fOffset, this->value(), &this->type());
|
||||
}
|
||||
|
||||
const Expression* getConstantSubexpression(int n) const override {
|
||||
SkASSERT(n == 0);
|
||||
return this;
|
||||
}
|
||||
|
||||
private:
|
||||
float fValue;
|
||||
|
||||
|
@ -81,6 +81,11 @@ public:
|
||||
return std::make_unique<IntLiteral>(fOffset, this->value(), &this->type());
|
||||
}
|
||||
|
||||
const Expression* getConstantSubexpression(int n) const override {
|
||||
SkASSERT(n == 0);
|
||||
return this;
|
||||
}
|
||||
|
||||
private:
|
||||
SKSL_INT fValue;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user