diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h index bb086eb663..dbdd129f0f 100644 --- a/src/sksl/ir/SkSLBoolLiteral.h +++ b/src/sksl/ir/SkSLBoolLiteral.h @@ -73,6 +73,11 @@ public: return std::make_unique(fOffset, this->value(), &this->type()); } + const Expression* getConstantSubexpression(int n) const override { + SkASSERT(n == 0); + return this; + } + private: bool fValue; diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp index 9cfa6fed06..f489f80580 100644 --- a/src/sksl/ir/SkSLConstructor.cpp +++ b/src/sksl/ir/SkSLConstructor.cpp @@ -175,65 +175,6 @@ std::unique_ptr Constructor::MakeCompoundConstructor(const Context& return std::make_unique(offset, type, std::move(args)); } -Expression::ComparisonResult Constructor::compareConstant(const Expression& other) const { - if (other.is()) { - return other.compareConstant(*this); - } - if (other.is()) { - return other.compareConstant(*this); - } - if (other.is()) { - return other.compareConstant(*this); - } - if (!other.is()) { - return ComparisonResult::kUnknown; - } - const Constructor& c = other.as(); - const Type& myType = this->type(); - SkASSERT(myType == c.type()); - - if (myType.isVector()) { - if (myType.componentType().isFloat()) { - for (int i = 0; i < myType.columns(); i++) { - if (this->getFVecComponent(i) != c.getFVecComponent(i)) { - return ComparisonResult::kNotEqual; - } - } - return ComparisonResult::kEqual; - } - if (myType.componentType().isInteger()) { - for (int i = 0; i < myType.columns(); i++) { - if (this->getIVecComponent(i) != c.getIVecComponent(i)) { - return ComparisonResult::kNotEqual; - } - } - return ComparisonResult::kEqual; - } - if (myType.componentType().isBoolean()) { - for (int i = 0; i < myType.columns(); i++) { - if (this->getBVecComponent(i) != c.getBVecComponent(i)) { - return ComparisonResult::kNotEqual; - } - } - return ComparisonResult::kEqual; - } - } - - if (myType.isMatrix()) { - for (int col = 0; col < myType.columns(); col++) { - for (int row = 0; row < myType.rows(); row++) { - if (getMatComponent(col, row) != c.getMatComponent(col, row)) { - return ComparisonResult::kNotEqual; - } - } - } - return ComparisonResult::kEqual; - } - - SkDEBUGFAILF("compareConstant unexpected type: %s", myType.description().c_str()); - return ComparisonResult::kUnknown; -} - template ResultType Constructor::getConstantValue(const Expression& expr) const { const Type& type = expr.type(); @@ -313,50 +254,6 @@ template SKSL_INT Constructor::getVecComponent(int) const; template SKSL_FLOAT Constructor::getVecComponent(int) const; template bool Constructor::getVecComponent(int) const; -SKSL_FLOAT Constructor::getMatComponent(int col, int row) const { - SkDEBUGCODE(const Type& myType = this->type();) - SkASSERT(this->isCompileTimeConstant()); - SkASSERT(myType.isMatrix()); - SkASSERT(col < myType.columns() && row < myType.rows()); - if (this->arguments().size() == 1) { - const Type& argType = this->arguments()[0]->type(); - if (argType.isScalar()) { - // single scalar argument, so matrix is of the form: - // x 0 0 - // 0 x 0 - // 0 0 x - // return x if col == row - return col == row ? this->getConstantValue(*this->arguments()[0]) : 0.0; - } - if (argType.isMatrix()) { - SkASSERT(this->arguments()[0]->isAnyConstructor()); - // single matrix argument. make sure we're within the argument's bounds. - if (col < argType.columns() && row < argType.rows()) { - // within bounds, defer to argument - return this->arguments()[0]->getMatComponent(col, row); - } - // out of bounds - return 0.0; - } - } - int currentIndex = 0; - int targetIndex = col * this->type().rows() + row; - for (const auto& arg : this->arguments()) { - const Type& argType = arg->type(); - SkASSERT(targetIndex >= currentIndex); - SkASSERT(argType.rows() == 1); - if (currentIndex + argType.columns() > targetIndex) { - if (argType.columns() == 1) { - return arg->getConstantFloat(); - } else { - return arg->getFVecComponent(targetIndex - currentIndex); - } - } - currentIndex += argType.columns(); - } - SK_ABORT("can't happen, matrix component out of bounds"); -} - SKSL_INT Constructor::getConstantInt() const { // We're looking for scalar integer constructors only, i.e. `int(1)`. SkASSERT(this->arguments().size() == 1); @@ -399,6 +296,45 @@ bool Constructor::getConstantBool() const { (bool)expr.getConstantFloat(); } +const Expression* AnyConstructor::getConstantSubexpression(int n) const { + SkASSERT(n >= 0 && n < (int)this->type().slotCount()); + for (const std::unique_ptr& arg : this->argumentSpan()) { + int argSlots = arg->type().slotCount(); + if (n < argSlots) { + return arg->getConstantSubexpression(n); + } + n -= argSlots; + } + + SkDEBUGFAIL("argument-list slot count doesn't match constructor-type slot count"); + return nullptr; +} + +Expression::ComparisonResult AnyConstructor::compareConstant(const Expression& other) const { + ComparisonResult result = ComparisonResult::kEqual; + SkASSERT(this->type().slotCount() == other.type().slotCount()); + + int exprs = this->type().slotCount(); + for (int n = 0; n < exprs; ++n) { + // Get the n'th subexpression from each side. If either one is null, return "unknown." + const Expression* left = this->getConstantSubexpression(n); + if (!left) { + return ComparisonResult::kUnknown; + } + const Expression* right = other.getConstantSubexpression(n); + if (!right) { + return ComparisonResult::kUnknown; + } + // Recurse into the subexpressions; the literal types will perform real comparisons, and + // most other expressions fall back on the base class Expression which returns unknown. + result = left->compareConstant(*right); + if (result != ComparisonResult::kEqual) { + break; + } + } + return result; +} + AnyConstructor& Expression::asAnyConstructor() { SkASSERT(this->isAnyConstructor()); return static_cast(*this); diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h index d7c6c3a5e4..7ff91c9e93 100644 --- a/src/sksl/ir/SkSLConstructor.h +++ b/src/sksl/ir/SkSLConstructor.h @@ -71,6 +71,10 @@ public: return true; } + const Expression* getConstantSubexpression(int n) const override; + + ComparisonResult compareConstant(const Expression& other) const override; + private: std::unique_ptr fArgument; @@ -183,8 +187,6 @@ public: return std::make_unique(fOffset, this->type(), this->cloneArguments()); } - ComparisonResult compareConstant(const Expression& other) const override; - template ResultType getVecComponent(int index) const; @@ -215,8 +217,6 @@ public: return this->getVecComponent(n); } - SKSL_FLOAT getMatComponent(int col, int row) const override; - SKSL_INT getConstantInt() const override; SKSL_FLOAT getConstantFloat() const override; diff --git a/src/sksl/ir/SkSLConstructorArray.cpp b/src/sksl/ir/SkSLConstructorArray.cpp index 4517b7549f..c305fb4745 100644 --- a/src/sksl/ir/SkSLConstructorArray.cpp +++ b/src/sksl/ir/SkSLConstructorArray.cpp @@ -56,22 +56,4 @@ std::unique_ptr ConstructorArray::Make(const Context& context, return std::make_unique(offset, type, std::move(args)); } -Expression::ComparisonResult ConstructorArray::compareConstant(const Expression& other) const { - // There is only one array-constructor type, so if this comparison had type-checked - // successfully, `other` should be a ConstructorArray with the same array size. - const ConstructorArray& otherArray = other.as(); - int numColumns = this->type().columns(); - SkASSERT(numColumns == otherArray.type().columns()); - - ComparisonResult check = ComparisonResult::kEqual; - for (int index = 0; index < numColumns; index++) { - check = this->arguments()[index]->compareConstant(*otherArray.arguments()[index]); - if (check != ComparisonResult::kEqual) { - break; - } - } - - return check; -} - } // namespace SkSL diff --git a/src/sksl/ir/SkSLConstructorArray.h b/src/sksl/ir/SkSLConstructorArray.h index a2b154dda3..139de5231f 100644 --- a/src/sksl/ir/SkSLConstructorArray.h +++ b/src/sksl/ir/SkSLConstructorArray.h @@ -39,8 +39,6 @@ public: return std::make_unique(fOffset, this->type(), this->cloneArguments()); } - ComparisonResult compareConstant(const Expression& other) const override; - private: using INHERITED = MultiArgumentConstructor; }; diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp index ccdcdefb8f..97cf760544 100644 --- a/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp +++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp @@ -21,35 +21,17 @@ std::unique_ptr ConstructorDiagonalMatrix::Make(const Context& conte return std::make_unique(offset, type, std::move(arg)); } -Expression::ComparisonResult ConstructorDiagonalMatrix::compareConstant( - const Expression& other) const { - SkASSERT(other.type().isMatrix()); - SkASSERT(this->type() == other.type()); +const Expression* ConstructorDiagonalMatrix::getConstantSubexpression(int n) const { + int rows = this->type().rows(); + int row = n % rows; + int col = n / rows; - // The other constructor might not be DiagonalMatrix-based, so we check each cell individually. - for (int col = 0; col < this->type().columns(); col++) { - for (int row = 0; row < this->type().rows(); row++) { - if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) { - return ComparisonResult::kNotEqual; - } - } - } - - return ComparisonResult::kEqual; -} - -SKSL_FLOAT ConstructorDiagonalMatrix::getMatComponent(int col, int row) const { - SkASSERT(this->isCompileTimeConstant()); SkASSERT(col >= 0); SkASSERT(row >= 0); SkASSERT(col < this->type().columns()); SkASSERT(row < this->type().rows()); - // Our matrix is of the form: - // |x 0 0| - // |0 x 0| - // |0 0 x| - return (col == row) ? this->argument()->getConstantFloat() : 0.0; + return (col == row) ? this->argument()->getConstantSubexpression(0) : &fZeroLiteral; } } // namespace SkSL diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.h b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h index 3e870510b0..b9fd229c19 100644 --- a/src/sksl/ir/SkSLConstructorDiagonalMatrix.h +++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h @@ -27,7 +27,8 @@ public: static constexpr Kind kExpressionKind = Kind::kConstructorDiagonalMatrix; ConstructorDiagonalMatrix(int offset, const Type& type, std::unique_ptr arg) - : INHERITED(offset, kExpressionKind, &type, std::move(arg)) {} + : INHERITED(offset, kExpressionKind, &type, std::move(arg)) + , fZeroLiteral(offset, /*value=*/0.0f, &type.componentType()) {} static std::unique_ptr Make(const Context& context, int offset, @@ -39,11 +40,10 @@ public: argument()->clone()); } - ComparisonResult compareConstant(const Expression& other) const override; - - SKSL_FLOAT getMatComponent(int col, int row) const override; + const Expression* getConstantSubexpression(int n) const override; private: + const FloatLiteral fZeroLiteral; using INHERITED = SingleArgumentConstructor; }; diff --git a/src/sksl/ir/SkSLConstructorMatrixResize.cpp b/src/sksl/ir/SkSLConstructorMatrixResize.cpp index e7e41f3c0f..4c8b3745dd 100644 --- a/src/sksl/ir/SkSLConstructorMatrixResize.cpp +++ b/src/sksl/ir/SkSLConstructorMatrixResize.cpp @@ -27,27 +27,11 @@ std::unique_ptr ConstructorMatrixResize::Make(const Context& context return std::make_unique(offset, type, std::move(arg)); } -Expression::ComparisonResult ConstructorMatrixResize::compareConstant( - const Expression& other) const { - SkASSERT(other.type().isMatrix()); - SkASSERT(this->type() == other.type()); - SkASSERT(this->type().rows() == other.type().rows()); - SkASSERT(this->type().columns() == other.type().columns()); +const Expression* ConstructorMatrixResize::getConstantSubexpression(int n) const { + int rows = this->type().rows(); + int row = n % rows; + int col = n / rows; - // Check each cell individually. - for (int col = 0; col < this->type().columns(); col++) { - for (int row = 0; row < this->type().rows(); row++) { - if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) { - return ComparisonResult::kNotEqual; - } - } - } - - return ComparisonResult::kEqual; -} - -SKSL_FLOAT ConstructorMatrixResize::getMatComponent(int col, int row) const { - SkASSERT(this->isCompileTimeConstant()); SkASSERT(col >= 0); SkASSERT(row >= 0); SkASSERT(col < this->type().columns()); @@ -59,13 +43,15 @@ SKSL_FLOAT ConstructorMatrixResize::getMatComponent(int col, int row) const { // |0 0 1| // Where `m` is the matrix being wrapped, and other cells contain the identity matrix. - // Forward `getMatComponent` to the wrapped matrix if the position is in its bounds. + // Forward `getConstantSubexpression` to the wrapped matrix if the position is in its bounds. if (col < this->argument()->type().columns() && row < this->argument()->type().rows()) { - return this->argument()->getMatComponent(col, row); + // Recalculate `n` in terms of the inner matrix's dimensions. + n = row + (col * this->argument()->type().rows()); + return this->argument()->getConstantSubexpression(n); } // Synthesize an identity matrix for out-of-bounds positions. - return (col == row) ? 1.0f : 0.0f; + return (col == row) ? &fOneLiteral : &fZeroLiteral; } } // namespace SkSL diff --git a/src/sksl/ir/SkSLConstructorMatrixResize.h b/src/sksl/ir/SkSLConstructorMatrixResize.h index 2a18403864..00361808d6 100644 --- a/src/sksl/ir/SkSLConstructorMatrixResize.h +++ b/src/sksl/ir/SkSLConstructorMatrixResize.h @@ -28,7 +28,9 @@ public: static constexpr Kind kExpressionKind = Kind::kConstructorMatrixResize; ConstructorMatrixResize(int offset, const Type& type, std::unique_ptr arg) - : INHERITED(offset, kExpressionKind, &type, std::move(arg)) {} + : INHERITED(offset, kExpressionKind, &type, std::move(arg)) + , fZeroLiteral(offset, /*value=*/0.0f, &type.componentType()) + , fOneLiteral(offset, /*value=*/1.0f, &type.componentType()) {} static std::unique_ptr Make(const Context& context, int offset, @@ -40,12 +42,12 @@ public: argument()->clone()); } - Expression::ComparisonResult compareConstant(const Expression& other) const override; - - SKSL_FLOAT getMatComponent(int col, int row) const override; + const Expression* getConstantSubexpression(int n) const override; private: using INHERITED = SingleArgumentConstructor; + const FloatLiteral fZeroLiteral; + const FloatLiteral fOneLiteral; }; } // namespace SkSL diff --git a/src/sksl/ir/SkSLConstructorSplat.cpp b/src/sksl/ir/SkSLConstructorSplat.cpp index d8d0062f99..4d07288061 100644 --- a/src/sksl/ir/SkSLConstructorSplat.cpp +++ b/src/sksl/ir/SkSLConstructorSplat.cpp @@ -25,29 +25,4 @@ std::unique_ptr ConstructorSplat::Make(const Context& context, return std::make_unique(offset, type, std::move(arg)); } -Expression::ComparisonResult ConstructorSplat::compareConstant(const Expression& other) const { - SkASSERT(this->type() == other.type()); - if (!other.isAnyConstructor()) { - return ComparisonResult::kUnknown; - } - - return this->compareConstantConstructor(other.asAnyConstructor()); -} - -Expression::ComparisonResult ConstructorSplat::compareConstantConstructor( - const AnyConstructor& other) const { - ComparisonResult check = ComparisonResult::kEqual; - for (const std::unique_ptr& expr : other.argumentSpan()) { - // We need to recurse to handle nested constructors like `half4(1) == half4(half2(1), 1, 1)` - check = expr->isAnyConstructor() - ? this->compareConstantConstructor(expr->asAnyConstructor()) - : argument()->compareConstant(*expr); - if (check != ComparisonResult::kEqual) { - break; - } - } - - return check; -} - } // namespace SkSL diff --git a/src/sksl/ir/SkSLConstructorSplat.h b/src/sksl/ir/SkSLConstructorSplat.h index 4ca3a076c5..5fb2a454a4 100644 --- a/src/sksl/ir/SkSLConstructorSplat.h +++ b/src/sksl/ir/SkSLConstructorSplat.h @@ -38,8 +38,6 @@ public: return std::make_unique(fOffset, this->type(), argument()->clone()); } - ComparisonResult compareConstant(const Expression& other) const override; - SKSL_FLOAT getFVecComponent(int) const override { return this->argument()->getConstantFloat(); } @@ -52,9 +50,12 @@ public: return this->argument()->getConstantBool(); } -private: - Expression::ComparisonResult compareConstantConstructor(const AnyConstructor& other) const; + const Expression* getConstantSubexpression(int n) const override { + SkASSERT(n >= 0 && n < this->type().columns()); + return this->argument()->getConstantSubexpression(0); + } +private: using INHERITED = SingleArgumentConstructor; }; diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h index c3329928ab..7b5386894e 100644 --- a/src/sksl/ir/SkSLExpression.h +++ b/src/sksl/ir/SkSLExpression.h @@ -178,6 +178,18 @@ public: return this->type().coercionCost(target); } + /** + * Returns the n'th compile-time constant expression within a literal or constructor. + * Use Type::slotCount to determine the number of subexpressions within an expression. + * Subexpressions which are not compile-time constants will return null. + * `vec4(1, vec2(2), 3)` contains four subexpressions: (1, 2, 2, 3) + * `mat2(f)` contains four subexpressions: (null, 0, + * 0, null) + */ + virtual const Expression* getConstantSubexpression(int n) const { + return nullptr; + } + /** * For a vector of floating point values, return the value of the n'th vector component. It is * an error to call this method on an expression which is not a vector of floating-point @@ -215,16 +227,6 @@ public: */ template T getVecComponent(int index) const; - /** - * For a literal matrix expression, return the floating point value of the component at - * [col][row]. It is an error to call this method on an expression which is not a literal - * matrix. - */ - virtual SKSL_FLOAT getMatComponent(int col, int row) const { - SkASSERT(false); - return 0; - } - virtual std::unique_ptr clone() const = 0; private: diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h index 4cc8eb948e..40eebc2ad4 100644 --- a/src/sksl/ir/SkSLFloatLiteral.h +++ b/src/sksl/ir/SkSLFloatLiteral.h @@ -79,6 +79,11 @@ public: return std::make_unique(fOffset, this->value(), &this->type()); } + const Expression* getConstantSubexpression(int n) const override { + SkASSERT(n == 0); + return this; + } + private: float fValue; diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h index 35050790fd..b9763a7070 100644 --- a/src/sksl/ir/SkSLIntLiteral.h +++ b/src/sksl/ir/SkSLIntLiteral.h @@ -81,6 +81,11 @@ public: return std::make_unique(fOffset, this->value(), &this->type()); } + const Expression* getConstantSubexpression(int n) const override { + SkASSERT(n == 0); + return this; + } + private: SKSL_INT fValue;