diff --git a/src/sksl/SkSLConstantFolder.cpp b/src/sksl/SkSLConstantFolder.cpp index 355b53caf7..57a5194c7e 100644 --- a/src/sksl/SkSLConstantFolder.cpp +++ b/src/sksl/SkSLConstantFolder.cpp @@ -123,33 +123,21 @@ static Constructor splat_scalar(const Expression& scalar, const Type& type) { } bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) { - switch (value.kind()) { - case Expression::Kind::kIntLiteral: - *out = value.as().value(); - return true; - case Expression::Kind::kVariableReference: { - const Variable& var = *value.as().variable(); - return (var.modifiers().fFlags & Modifiers::kConst_Flag) && - var.initialValue() && GetConstantInt(*var.initialValue(), out); - } - default: - return false; + const Expression* expr = GetConstantValueForVariable(value); + if (!expr->is()) { + return false; } + *out = expr->as().value(); + return true; } bool ConstantFolder::GetConstantFloat(const Expression& value, SKSL_FLOAT* out) { - switch (value.kind()) { - case Expression::Kind::kFloatLiteral: - *out = value.as().value(); - return true; - case Expression::Kind::kVariableReference: { - const Variable& var = *value.as().variable(); - return (var.modifiers().fFlags & Modifiers::kConst_Flag) && - var.initialValue() && GetConstantFloat(*var.initialValue(), out); - } - default: - return false; + const Expression* expr = GetConstantValueForVariable(value); + if (!expr->is()) { + return false; } + *out = expr->as().value(); + return true; } static bool contains_constant_zero(const Expression& expr) { @@ -161,10 +149,9 @@ static bool contains_constant_zero(const Expression& expr) { } return false; } - SKSL_INT intValue; - SKSL_FLOAT floatValue; - return (ConstantFolder::GetConstantInt(expr, &intValue) && intValue == 0) || - (ConstantFolder::GetConstantFloat(expr, &floatValue) && floatValue == 0.0f); + const Expression* value = ConstantFolder::GetConstantValueForVariable(expr); + return (value->is() && value->as().value() == 0.0) || + (value->is() && value->as().value() == 0.0); } bool ConstantFolder::ErrorOnDivideByZero(const Context& context, int offset, Operator op, @@ -184,28 +171,58 @@ bool ConstantFolder::ErrorOnDivideByZero(const Context& context, int offset, Ope } } +const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) { + for (const Expression* expr = &inExpr;;) { + if (!expr->is()) { + break; + } + const VariableReference& varRef = expr->as(); + if (varRef.refKind() != VariableRefKind::kRead) { + break; + } + const Variable& var = *varRef.variable(); + if (!(var.modifiers().fFlags & Modifiers::kConst_Flag)) { + break; + } + expr = var.initialValue(); + SkASSERT(expr); + if (expr->isCompileTimeConstant()) { + return expr; + } + if (!expr->is()) { + break; + } + } + // We didn't find a compile-time constant at the end. Return the expression as-is. + return &inExpr; +} + std::unique_ptr ConstantFolder::Simplify(const Context& context, int offset, - const Expression& left, + const Expression& leftExpr, Operator op, - const Expression& right) { + const Expression& rightExpr) { + // Replace constant variables with trivial initial-values. + const Expression* left = GetConstantValueForVariable(leftExpr); + const Expression* right = GetConstantValueForVariable(rightExpr); + // If this is the comma operator, the left side is evaluated but not otherwise used in any way. // So if the left side has no side effects, it can just be eliminated entirely. - if (op.kind() == Token::Kind::TK_COMMA && !left.hasSideEffects()) { - return right.clone(); + if (op.kind() == Token::Kind::TK_COMMA && !left->hasSideEffects()) { + return right->clone(); } // If this is the assignment operator, and both sides are the same trivial expression, this is // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`). // This can happen when other parts of the assignment are optimized away. - if (op.kind() == Token::Kind::TK_EQ && Analysis::IsSelfAssignment(left, right)) { - return right.clone(); + if (op.kind() == Token::Kind::TK_EQ && Analysis::IsSelfAssignment(*left, *right)) { + return right->clone(); } // Simplify the expression when both sides are constant Boolean literals. - if (left.is() && right.is()) { - bool leftVal = left.as().value(); - bool rightVal = right.as().value(); + if (left->is() && right->is()) { + bool leftVal = left->as().value(); + bool rightVal = right->as().value(); bool result; switch (op.kind()) { case Token::Kind::TK_LOGICALAND: result = leftVal && rightVal; break; @@ -219,28 +236,28 @@ std::unique_ptr ConstantFolder::Simplify(const Context& context, } // If the left side is a Boolean literal, apply short-circuit optimizations. - if (left.is()) { - return short_circuit_boolean(left, op, right); + if (left->is()) { + return short_circuit_boolean(*left, op, *right); } // If the right side is a Boolean literal... - if (right.is()) { + if (right->is()) { // ... and the left side has no side effects... - if (!left.hasSideEffects()) { + if (!left->hasSideEffects()) { // We can reverse the expressions and short-circuit optimizations are still valid. - return short_circuit_boolean(right, op, left); + return short_circuit_boolean(*right, op, *left); } // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions. - return eliminate_no_op_boolean(left, op, right); + return eliminate_no_op_boolean(*left, op, *right); } - if (ErrorOnDivideByZero(context, offset, op, right)) { + if (ErrorOnDivideByZero(context, offset, op, *right)) { return nullptr; } // Other than the short-circuit cases above, constant folding requires both sides to be constant - if (!left.isCompileTimeConstant() || !right.isCompileTimeConstant()) { + if (!left->isCompileTimeConstant() || !right->isCompileTimeConstant()) { return nullptr; } @@ -253,9 +270,9 @@ std::unique_ptr ConstantFolder::Simplify(const Context& context, #define URESULT(t, op) std::make_unique(context, offset, \ (SKSL_UINT) leftVal op \ (SKSL_UINT) rightVal) - if (left.is() && right.is()) { - SKSL_INT leftVal = left.as().value(); - SKSL_INT rightVal = right.as().value(); + if (left->is() && right->is()) { + SKSL_INT leftVal = left->as().value(); + SKSL_INT rightVal = right->as().value(); switch (op.kind()) { case Token::Kind::TK_PLUS: return URESULT(Int, +); case Token::Kind::TK_MINUS: return URESULT(Int, -); @@ -302,9 +319,9 @@ std::unique_ptr ConstantFolder::Simplify(const Context& context, } // Perform constant folding on pairs of floating-point literals. - if (left.is() && right.is()) { - SKSL_FLOAT leftVal = left.as().value(); - SKSL_FLOAT rightVal = right.as().value(); + if (left->is() && right->is()) { + SKSL_FLOAT leftVal = left->as().value(); + SKSL_FLOAT rightVal = right->as().value(); switch (op.kind()) { case Token::Kind::TK_PLUS: return RESULT(Float, +); case Token::Kind::TK_MINUS: return RESULT(Float, -); @@ -321,14 +338,14 @@ std::unique_ptr ConstantFolder::Simplify(const Context& context, } // Perform constant folding on pairs of vectors. - const Type& leftType = left.type(); - const Type& rightType = right.type(); + const Type& leftType = left->type(); + const Type& rightType = right->type(); if (leftType.isVector() && leftType == rightType) { if (leftType.componentType().isFloat()) { - return simplify_vector(context, left, op, right); + return simplify_vector(context, *left, op, *right); } if (leftType.componentType().isInteger()) { - return simplify_vector(context, left, op, right); + return simplify_vector(context, *left, op, *right); } return nullptr; } @@ -336,11 +353,12 @@ std::unique_ptr ConstantFolder::Simplify(const Context& context, // Perform constant folding on vectors against scalars, e.g.: half4(2) + 2 if (leftType.isVector() && leftType.componentType() == rightType) { if (rightType.isFloat()) { - return simplify_vector(context, left, op, splat_scalar(right, left.type())); + return simplify_vector(context, *left, op, + splat_scalar(*right, left->type())); } if (rightType.isInteger()) { - return simplify_vector(context, left, op, - splat_scalar(right, left.type())); + return simplify_vector(context, *left, op, + splat_scalar(*right, left->type())); } return nullptr; } @@ -348,12 +366,12 @@ std::unique_ptr ConstantFolder::Simplify(const Context& context, // Perform constant folding on scalars against vectors, e.g.: 2 + half4(2) if (rightType.isVector() && rightType.componentType() == leftType) { if (leftType.isFloat()) { - return simplify_vector(context, splat_scalar(left, right.type()), op, - right); + return simplify_vector(context, splat_scalar(*left, right->type()), op, + *right); } if (leftType.isInteger()) { - return simplify_vector(context, splat_scalar(left, right.type()), - op, right); + return simplify_vector(context, splat_scalar(*left, right->type()), + op, *right); } return nullptr; } @@ -372,7 +390,7 @@ std::unique_ptr ConstantFolder::Simplify(const Context& context, return nullptr; } - switch (left.compareConstant(right)) { + switch (left->compareConstant(*right)) { case Expression::ComparisonResult::kNotEqual: equality = !equality; [[fallthrough]]; diff --git a/src/sksl/SkSLConstantFolder.h b/src/sksl/SkSLConstantFolder.h index bcc5bda168..7861454395 100644 --- a/src/sksl/SkSLConstantFolder.h +++ b/src/sksl/SkSLConstantFolder.h @@ -37,6 +37,12 @@ public: */ static bool GetConstantFloat(const Expression& value, SKSL_FLOAT* out); + /** + * If the expression is a const variable with a known compile-time-constant value, returns that + * value. If not, returns the original expression as-is. + */ + static const Expression* GetConstantValueForVariable(const Expression& value); + /** * Reports an error and returns true if op is a division / mod operator and right is zero or * contains a zero element.