From dd21816b00f5cc39483d1813f3d09b58fc61222d Mon Sep 17 00:00:00 2001 From: Ethan Nicholas Date: Thu, 8 Oct 2020 05:48:01 -0400 Subject: [PATCH] moved SkSL Ternary data into IRNode Change-Id: I70e63aaa73082024c8f0887a941d54cfd12aa2b6 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/323883 Reviewed-by: John Stiles Commit-Queue: Ethan Nicholas --- src/sksl/SkSLAnalysis.cpp | 4 +- src/sksl/SkSLByteCodeGenerator.cpp | 10 ++--- src/sksl/SkSLCFGGenerator.cpp | 18 ++++----- src/sksl/SkSLCompiler.cpp | 14 +++---- src/sksl/SkSLDehydrator.cpp | 6 +-- src/sksl/SkSLGLSLCodeGenerator.cpp | 6 +-- src/sksl/SkSLInliner.cpp | 6 +-- src/sksl/SkSLMetalCodeGenerator.cpp | 10 ++--- src/sksl/SkSLSPIRVCodeGenerator.cpp | 22 +++++------ src/sksl/ir/SkSLBinaryExpression.h | 2 +- src/sksl/ir/SkSLTernaryExpression.h | 61 ++++++++++++++++++++--------- 11 files changed, 91 insertions(+), 68 deletions(-) diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp index 0a7714e6c4..7187f154b4 100644 --- a/src/sksl/SkSLAnalysis.cpp +++ b/src/sksl/SkSLAnalysis.cpp @@ -408,8 +408,8 @@ bool TProgramVisitor::visitExpression(EXPR e) { case Expression::Kind::kTernary: { auto& t = e.template as(); - return this->visitExpression(*t.fTest) || this->visitExpression(*t.fIfTrue) || - this->visitExpression(*t.fIfFalse); + return this->visitExpression(*t.test()) || this->visitExpression(*t.ifTrue()) || + this->visitExpression(*t.ifFalse()); } default: SkUNREACHABLE; diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp index fa55aa9003..9e2b778c1f 100644 --- a/src/sksl/SkSLByteCodeGenerator.cpp +++ b/src/sksl/SkSLByteCodeGenerator.cpp @@ -1427,14 +1427,14 @@ void ByteCodeGenerator::writeSwizzle(const Swizzle& s) { void ByteCodeGenerator::writeTernaryExpression(const TernaryExpression& t) { int count = SlotCount(t.type()); - SkASSERT(count == SlotCount(t.fIfTrue->type())); - SkASSERT(count == SlotCount(t.fIfFalse->type())); + SkASSERT(count == SlotCount(t.ifTrue()->type())); + SkASSERT(count == SlotCount(t.ifFalse()->type())); - this->writeExpression(*t.fTest); + this->writeExpression(*t.test()); this->write(ByteCodeInstruction::kMaskPush); - this->writeExpression(*t.fIfTrue); + this->writeExpression(*t.ifTrue()); this->write(ByteCodeInstruction::kMaskNegate); - this->writeExpression(*t.fIfFalse); + this->writeExpression(*t.ifFalse()); this->write(ByteCodeInstruction::kMaskBlend, count); } diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp index 084073cb3d..cd6a17d0de 100644 --- a/src/sksl/SkSLCFGGenerator.cpp +++ b/src/sksl/SkSLCFGGenerator.cpp @@ -141,13 +141,13 @@ bool BasicBlock::tryRemoveLValueBefore(std::vector::iterator* } case Expression::Kind::kTernary: { TernaryExpression& ternary = lvalue->as(); - if (!this->tryRemoveExpressionBefore(iter, ternary.fTest.get())) { + if (!this->tryRemoveExpressionBefore(iter, ternary.test().get())) { return false; } - if (!this->tryRemoveLValueBefore(iter, ternary.fIfTrue.get())) { + if (!this->tryRemoveLValueBefore(iter, ternary.ifTrue().get())) { return false; } - return this->tryRemoveLValueBefore(iter, ternary.fIfFalse.get()); + return this->tryRemoveLValueBefore(iter, ternary.ifFalse().get()); } default: #ifdef SK_DEBUG @@ -413,15 +413,15 @@ void CFGGenerator::addExpression(CFG& cfg, std::unique_ptr* e, bool break; case Expression::Kind::kTernary: { TernaryExpression& t = e->get()->as(); - this->addExpression(cfg, &t.fTest, constantPropagate); + this->addExpression(cfg, &t.test(), constantPropagate); cfg.currentBlock().fNodes.push_back(BasicBlock::MakeExpression(e, constantPropagate)); BlockId start = cfg.fCurrent; cfg.newBlock(); - this->addExpression(cfg, &t.fIfTrue, constantPropagate); + this->addExpression(cfg, &t.ifTrue(), constantPropagate); BlockId next = cfg.newBlock(); cfg.fCurrent = start; cfg.newBlock(); - this->addExpression(cfg, &t.fIfFalse, constantPropagate); + this->addExpression(cfg, &t.ifFalse(), constantPropagate); cfg.addExit(cfg.fCurrent, next); cfg.fCurrent = next; break; @@ -454,12 +454,12 @@ void CFGGenerator::addLValue(CFG& cfg, std::unique_ptr* e) { break; case Expression::Kind::kTernary: { TernaryExpression& ternary = e->get()->as(); - this->addExpression(cfg, &ternary.fTest, /*constantPropagate=*/true); + this->addExpression(cfg, &ternary.test(), /*constantPropagate=*/true); // Technically we will of course only evaluate one or the other, but if the test turns // out to be constant, the ternary will get collapsed down to just one branch anyway. So // it should be ok to pretend that we always evaluate both branches here. - this->addLValue(cfg, &ternary.fIfTrue); - this->addLValue(cfg, &ternary.fIfFalse); + this->addLValue(cfg, &ternary.ifTrue()); + this->addLValue(cfg, &ternary.ifFalse()); break; } default: diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp index adbf37435e..85dbbc7fe1 100644 --- a/src/sksl/SkSLCompiler.cpp +++ b/src/sksl/SkSLCompiler.cpp @@ -466,10 +466,10 @@ void Compiler::addDefinition(const Expression* lvalue, std::unique_ptraddDefinition(lvalue->as().fIfTrue.get(), + this->addDefinition(lvalue->as().ifTrue().get(), (std::unique_ptr*) &fContext->fDefined_Expression, definitions); - this->addDefinition(lvalue->as().fIfFalse.get(), + this->addDefinition(lvalue->as().ifFalse().get(), (std::unique_ptr*) &fContext->fDefined_Expression, definitions); break; @@ -633,7 +633,7 @@ static bool is_dead(const Expression& lvalue) { } case Expression::Kind::kTernary: { const TernaryExpression& t = lvalue.as(); - return !t.fTest->hasSideEffects() && is_dead(*t.fIfTrue) && is_dead(*t.fIfFalse); + return !t.test()->hasSideEffects() && is_dead(*t.ifTrue()) && is_dead(*t.ifFalse()); } case Expression::Kind::kExternalValue: return false; @@ -926,13 +926,13 @@ void Compiler::simplifyExpression(DefinitionMap& definitions, } case Expression::Kind::kTernary: { TernaryExpression* t = &expr->as(); - if (t->fTest->kind() == Expression::Kind::kBoolLiteral) { + if (t->test()->is()) { // ternary has a constant test, replace it with either the true or // false branch - if (t->fTest->as().value()) { - (*iter)->setExpression(std::move(t->fIfTrue)); + if (t->test()->as().value()) { + (*iter)->setExpression(std::move(t->ifTrue())); } else { - (*iter)->setExpression(std::move(t->fIfFalse)); + (*iter)->setExpression(std::move(t->ifFalse())); } *outUpdated = true; *outNeedsRescan = true; diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp index 679e5ed44c..2b81eaae3a 100644 --- a/src/sksl/SkSLDehydrator.cpp +++ b/src/sksl/SkSLDehydrator.cpp @@ -370,9 +370,9 @@ void Dehydrator::write(const Expression* e) { case Expression::Kind::kTernary: { const TernaryExpression& t = e->as(); this->writeU8(Rehydrator::kTernary_Command); - this->write(t.fTest.get()); - this->write(t.fIfTrue.get()); - this->write(t.fIfFalse.get()); + this->write(t.test().get()); + this->write(t.ifTrue().get()); + this->write(t.ifFalse().get()); break; } case Expression::Kind::kVariableReference: { diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp index e25972738f..8eef94241e 100644 --- a/src/sksl/SkSLGLSLCodeGenerator.cpp +++ b/src/sksl/SkSLGLSLCodeGenerator.cpp @@ -976,11 +976,11 @@ void GLSLCodeGenerator::writeTernaryExpression(const TernaryExpression& t, if (kTernary_Precedence >= parentPrecedence) { this->write("("); } - this->writeExpression(*t.fTest, kTernary_Precedence); + this->writeExpression(*t.test(), kTernary_Precedence); this->write(" ? "); - this->writeExpression(*t.fIfTrue, kTernary_Precedence); + this->writeExpression(*t.ifTrue(), kTernary_Precedence); this->write(" : "); - this->writeExpression(*t.fIfFalse, kTernary_Precedence); + this->writeExpression(*t.ifFalse(), kTernary_Precedence); if (kTernary_Precedence >= parentPrecedence) { this->write(")"); } diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp index c705ddeb1c..41f35f03c0 100644 --- a/src/sksl/SkSLInliner.cpp +++ b/src/sksl/SkSLInliner.cpp @@ -408,8 +408,8 @@ std::unique_ptr Inliner::inlineExpression(int offset, } case Expression::Kind::kTernary: { const TernaryExpression& t = expression.as(); - return std::make_unique(offset, expr(t.fTest), - expr(t.fIfTrue), expr(t.fIfFalse)); + return std::make_unique(offset, expr(t.test()), + expr(t.ifTrue()), expr(t.ifFalse())); } case Expression::Kind::kTypeReference: return expression.clone(); @@ -1043,7 +1043,7 @@ public: case Expression::Kind::kTernary: { TernaryExpression& ternaryExpr = (*expr)->as(); // The test expression is a candidate for inlining. - this->visitExpression(&ternaryExpr.fTest); + this->visitExpression(&ternaryExpr.test()); // The true- and false-expressions cannot be inlined, because we are only allowed to // evaluate one side. break; diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp index c5df51645c..980d499b87 100644 --- a/src/sksl/SkSLMetalCodeGenerator.cpp +++ b/src/sksl/SkSLMetalCodeGenerator.cpp @@ -876,11 +876,11 @@ void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t, if (kTernary_Precedence >= parentPrecedence) { this->write("("); } - this->writeExpression(*t.fTest, kTernary_Precedence); + this->writeExpression(*t.test(), kTernary_Precedence); this->write(" ? "); - this->writeExpression(*t.fIfTrue, kTernary_Precedence); + this->writeExpression(*t.ifTrue(), kTernary_Precedence); this->write(" : "); - this->writeExpression(*t.fIfFalse, kTernary_Precedence); + this->writeExpression(*t.ifFalse(), kTernary_Precedence); if (kTernary_Precedence >= parentPrecedence) { this->write(")"); } @@ -1711,8 +1711,8 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expressi return this->requirements(e->as().fOperand.get()); case Expression::Kind::kTernary: { const TernaryExpression& t = e->as(); - return this->requirements(t.fTest.get()) | this->requirements(t.fIfTrue.get()) | - this->requirements(t.fIfFalse.get()); + return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) | + this->requirements(t.ifFalse().get()); } case Expression::Kind::kVariableReference: { const VariableReference& v = e->as(); diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp index 9f45c5b7f1..63a354338a 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.cpp +++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp @@ -1788,18 +1788,18 @@ std::unique_ptr SPIRVCodeGenerator::getLValue(const } case Expression::Kind::kTernary: { TernaryExpression& t = (TernaryExpression&) expr; - SpvId test = this->writeExpression(*t.fTest, out); + SpvId test = this->writeExpression(*t.test(), out); SpvId end = this->nextId(); SpvId ifTrueLabel = this->nextId(); SpvId ifFalseLabel = this->nextId(); this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out); this->writeLabel(ifTrueLabel, out); - SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer(); + SpvId ifTrue = this->getLValue(*t.ifTrue(), out)->getPointer(); SkASSERT(ifTrue); this->writeInstruction(SpvOpBranch, end, out); ifTrueLabel = fCurrentBlock; - SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer(); + SpvId ifFalse = this->getLValue(*t.ifFalse(), out)->getPointer(); SkASSERT(ifFalse); ifFalseLabel = fCurrentBlock; this->writeInstruction(SpvOpBranch, end, out); @@ -2389,14 +2389,14 @@ SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) { const Type& type = t.type(); - SpvId test = this->writeExpression(*t.fTest, out); - if (t.fIfTrue->type().columns() == 1 && - t.fIfTrue->isCompileTimeConstant() && - t.fIfFalse->isCompileTimeConstant()) { + SpvId test = this->writeExpression(*t.test(), out); + if (t.ifTrue()->type().columns() == 1 && + t.ifTrue()->isCompileTimeConstant() && + t.ifFalse()->isCompileTimeConstant()) { // both true and false are constants, can just use OpSelect SpvId result = this->nextId(); - SpvId trueId = this->writeExpression(*t.fIfTrue, out); - SpvId falseId = this->writeExpression(*t.fIfFalse, out); + SpvId trueId = this->writeExpression(*t.ifTrue(), out); + SpvId falseId = this->writeExpression(*t.ifFalse(), out); this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId, out); return result; @@ -2412,10 +2412,10 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, Out this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out); this->writeLabel(trueLabel, out); - this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out); + this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifTrue(), out), out); this->writeInstruction(SpvOpBranch, end, out); this->writeLabel(falseLabel, out); - this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out); + this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifFalse(), out), out); this->writeInstruction(SpvOpBranch, end, out); this->writeLabel(end, out); SpvId result = this->nextId(); diff --git a/src/sksl/ir/SkSLBinaryExpression.h b/src/sksl/ir/SkSLBinaryExpression.h index 5ef6240973..dcfce09179 100644 --- a/src/sksl/ir/SkSLBinaryExpression.h +++ b/src/sksl/ir/SkSLBinaryExpression.h @@ -31,7 +31,7 @@ static inline bool check_ref(const Expression& expr) { return check_ref(*expr.as().fBase); case Expression::Kind::kTernary: { const TernaryExpression& t = expr.as(); - return check_ref(*t.fIfTrue) && check_ref(*t.fIfFalse); + return check_ref(*t.ifTrue()) && check_ref(*t.ifFalse()); } case Expression::Kind::kVariableReference: { const VariableReference& ref = expr.as(); diff --git a/src/sksl/ir/SkSLTernaryExpression.h b/src/sksl/ir/SkSLTernaryExpression.h index f590bf8191..c5aef106c6 100644 --- a/src/sksl/ir/SkSLTernaryExpression.h +++ b/src/sksl/ir/SkSLTernaryExpression.h @@ -16,43 +16,66 @@ namespace SkSL { /** * A ternary expression (test ? ifTrue : ifFalse). */ -struct TernaryExpression : public Expression { +class TernaryExpression : public Expression { +public: static constexpr Kind kExpressionKind = Kind::kTernary; TernaryExpression(int offset, std::unique_ptr test, std::unique_ptr ifTrue, std::unique_ptr ifFalse) - : INHERITED(offset, kExpressionKind, &ifTrue->type()) - , fTest(std::move(test)) - , fIfTrue(std::move(ifTrue)) - , fIfFalse(std::move(ifFalse)) { - SkASSERT(fIfTrue->type() == fIfFalse->type()); + : INHERITED(offset, kExpressionKind, &ifTrue->type()) { + SkASSERT(ifTrue->type() == ifFalse->type()); + fExpressionChildren.reserve(3); + fExpressionChildren.push_back(std::move(test)); + fExpressionChildren.push_back(std::move(ifTrue)); + fExpressionChildren.push_back(std::move(ifFalse)); + } + + std::unique_ptr& test() { + return fExpressionChildren[0]; + } + + const std::unique_ptr& test() const { + return fExpressionChildren[0]; + } + + std::unique_ptr& ifTrue() { + return fExpressionChildren[1]; + } + + const std::unique_ptr& ifTrue() const { + return fExpressionChildren[1]; + } + + std::unique_ptr& ifFalse() { + return fExpressionChildren[2]; + } + + const std::unique_ptr& ifFalse() const { + return fExpressionChildren[2]; } bool hasProperty(Property property) const override { - return fTest->hasProperty(property) || fIfTrue->hasProperty(property) || - fIfFalse->hasProperty(property); + return this->test()->hasProperty(property) || this->ifTrue()->hasProperty(property) || + this->ifFalse()->hasProperty(property); } bool isConstantOrUniform() const override { - return fTest->isConstantOrUniform() && fIfTrue->isConstantOrUniform() && - fIfFalse->isConstantOrUniform(); + return this->test()->isConstantOrUniform() && this->ifTrue()->isConstantOrUniform() && + this->ifFalse()->isConstantOrUniform(); } std::unique_ptr clone() const override { - return std::unique_ptr(new TernaryExpression(fOffset, fTest->clone(), - fIfTrue->clone(), - fIfFalse->clone())); + return std::unique_ptr(new TernaryExpression(fOffset, this->test()->clone(), + this->ifTrue()->clone(), + this->ifFalse()->clone())); } String description() const override { - return "(" + fTest->description() + " ? " + fIfTrue->description() + " : " + - fIfFalse->description() + ")"; + return "(" + this->test()->description() + " ? " + this->ifTrue()->description() + " : " + + this->ifFalse()->description() + ")"; } - std::unique_ptr fTest; - std::unique_ptr fIfTrue; - std::unique_ptr fIfFalse; - +private: using INHERITED = Expression; };