moved SkSL Ternary data into IRNode

Change-Id: I70e63aaa73082024c8f0887a941d54cfd12aa2b6
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/323883
Reviewed-by: John Stiles <johnstiles@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
This commit is contained in:
Ethan Nicholas 2020-10-08 05:48:01 -04:00 committed by Skia Commit-Bot
parent 8c44ecae4f
commit dd21816b00
11 changed files with 91 additions and 68 deletions

View File

@ -408,8 +408,8 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitExpression(EXPR e) {
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
auto& t = e.template as<TernaryExpression>(); auto& t = e.template as<TernaryExpression>();
return this->visitExpression(*t.fTest) || this->visitExpression(*t.fIfTrue) || return this->visitExpression(*t.test()) || this->visitExpression(*t.ifTrue()) ||
this->visitExpression(*t.fIfFalse); this->visitExpression(*t.ifFalse());
} }
default: default:
SkUNREACHABLE; SkUNREACHABLE;

View File

@ -1427,14 +1427,14 @@ void ByteCodeGenerator::writeSwizzle(const Swizzle& s) {
void ByteCodeGenerator::writeTernaryExpression(const TernaryExpression& t) { void ByteCodeGenerator::writeTernaryExpression(const TernaryExpression& t) {
int count = SlotCount(t.type()); int count = SlotCount(t.type());
SkASSERT(count == SlotCount(t.fIfTrue->type())); SkASSERT(count == SlotCount(t.ifTrue()->type()));
SkASSERT(count == SlotCount(t.fIfFalse->type())); SkASSERT(count == SlotCount(t.ifFalse()->type()));
this->writeExpression(*t.fTest); this->writeExpression(*t.test());
this->write(ByteCodeInstruction::kMaskPush); this->write(ByteCodeInstruction::kMaskPush);
this->writeExpression(*t.fIfTrue); this->writeExpression(*t.ifTrue());
this->write(ByteCodeInstruction::kMaskNegate); this->write(ByteCodeInstruction::kMaskNegate);
this->writeExpression(*t.fIfFalse); this->writeExpression(*t.ifFalse());
this->write(ByteCodeInstruction::kMaskBlend, count); this->write(ByteCodeInstruction::kMaskBlend, count);
} }

View File

@ -141,13 +141,13 @@ bool BasicBlock::tryRemoveLValueBefore(std::vector<BasicBlock::Node>::iterator*
} }
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
TernaryExpression& ternary = lvalue->as<TernaryExpression>(); TernaryExpression& ternary = lvalue->as<TernaryExpression>();
if (!this->tryRemoveExpressionBefore(iter, ternary.fTest.get())) { if (!this->tryRemoveExpressionBefore(iter, ternary.test().get())) {
return false; return false;
} }
if (!this->tryRemoveLValueBefore(iter, ternary.fIfTrue.get())) { if (!this->tryRemoveLValueBefore(iter, ternary.ifTrue().get())) {
return false; return false;
} }
return this->tryRemoveLValueBefore(iter, ternary.fIfFalse.get()); return this->tryRemoveLValueBefore(iter, ternary.ifFalse().get());
} }
default: default:
#ifdef SK_DEBUG #ifdef SK_DEBUG
@ -413,15 +413,15 @@ void CFGGenerator::addExpression(CFG& cfg, std::unique_ptr<Expression>* e, bool
break; break;
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
TernaryExpression& t = e->get()->as<TernaryExpression>(); TernaryExpression& t = e->get()->as<TernaryExpression>();
this->addExpression(cfg, &t.fTest, constantPropagate); this->addExpression(cfg, &t.test(), constantPropagate);
cfg.currentBlock().fNodes.push_back(BasicBlock::MakeExpression(e, constantPropagate)); cfg.currentBlock().fNodes.push_back(BasicBlock::MakeExpression(e, constantPropagate));
BlockId start = cfg.fCurrent; BlockId start = cfg.fCurrent;
cfg.newBlock(); cfg.newBlock();
this->addExpression(cfg, &t.fIfTrue, constantPropagate); this->addExpression(cfg, &t.ifTrue(), constantPropagate);
BlockId next = cfg.newBlock(); BlockId next = cfg.newBlock();
cfg.fCurrent = start; cfg.fCurrent = start;
cfg.newBlock(); cfg.newBlock();
this->addExpression(cfg, &t.fIfFalse, constantPropagate); this->addExpression(cfg, &t.ifFalse(), constantPropagate);
cfg.addExit(cfg.fCurrent, next); cfg.addExit(cfg.fCurrent, next);
cfg.fCurrent = next; cfg.fCurrent = next;
break; break;
@ -454,12 +454,12 @@ void CFGGenerator::addLValue(CFG& cfg, std::unique_ptr<Expression>* e) {
break; break;
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
TernaryExpression& ternary = e->get()->as<TernaryExpression>(); TernaryExpression& ternary = e->get()->as<TernaryExpression>();
this->addExpression(cfg, &ternary.fTest, /*constantPropagate=*/true); this->addExpression(cfg, &ternary.test(), /*constantPropagate=*/true);
// Technically we will of course only evaluate one or the other, but if the test turns // Technically we will of course only evaluate one or the other, but if the test turns
// out to be constant, the ternary will get collapsed down to just one branch anyway. So // out to be constant, the ternary will get collapsed down to just one branch anyway. So
// it should be ok to pretend that we always evaluate both branches here. // it should be ok to pretend that we always evaluate both branches here.
this->addLValue(cfg, &ternary.fIfTrue); this->addLValue(cfg, &ternary.ifTrue());
this->addLValue(cfg, &ternary.fIfFalse); this->addLValue(cfg, &ternary.ifFalse());
break; break;
} }
default: default:

View File

@ -466,10 +466,10 @@ void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expressio
// To simplify analysis, we just pretend that we write to both sides of the ternary. // To simplify analysis, we just pretend that we write to both sides of the ternary.
// This allows for false positives (meaning we fail to detect that a variable might not // This allows for false positives (meaning we fail to detect that a variable might not
// have been assigned), but is preferable to false negatives. // have been assigned), but is preferable to false negatives.
this->addDefinition(lvalue->as<TernaryExpression>().fIfTrue.get(), this->addDefinition(lvalue->as<TernaryExpression>().ifTrue().get(),
(std::unique_ptr<Expression>*) &fContext->fDefined_Expression, (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
definitions); definitions);
this->addDefinition(lvalue->as<TernaryExpression>().fIfFalse.get(), this->addDefinition(lvalue->as<TernaryExpression>().ifFalse().get(),
(std::unique_ptr<Expression>*) &fContext->fDefined_Expression, (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
definitions); definitions);
break; break;
@ -633,7 +633,7 @@ static bool is_dead(const Expression& lvalue) {
} }
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
const TernaryExpression& t = lvalue.as<TernaryExpression>(); const TernaryExpression& t = lvalue.as<TernaryExpression>();
return !t.fTest->hasSideEffects() && is_dead(*t.fIfTrue) && is_dead(*t.fIfFalse); return !t.test()->hasSideEffects() && is_dead(*t.ifTrue()) && is_dead(*t.ifFalse());
} }
case Expression::Kind::kExternalValue: case Expression::Kind::kExternalValue:
return false; return false;
@ -926,13 +926,13 @@ void Compiler::simplifyExpression(DefinitionMap& definitions,
} }
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
TernaryExpression* t = &expr->as<TernaryExpression>(); TernaryExpression* t = &expr->as<TernaryExpression>();
if (t->fTest->kind() == Expression::Kind::kBoolLiteral) { if (t->test()->is<BoolLiteral>()) {
// ternary has a constant test, replace it with either the true or // ternary has a constant test, replace it with either the true or
// false branch // false branch
if (t->fTest->as<BoolLiteral>().value()) { if (t->test()->as<BoolLiteral>().value()) {
(*iter)->setExpression(std::move(t->fIfTrue)); (*iter)->setExpression(std::move(t->ifTrue()));
} else { } else {
(*iter)->setExpression(std::move(t->fIfFalse)); (*iter)->setExpression(std::move(t->ifFalse()));
} }
*outUpdated = true; *outUpdated = true;
*outNeedsRescan = true; *outNeedsRescan = true;

View File

@ -370,9 +370,9 @@ void Dehydrator::write(const Expression* e) {
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
const TernaryExpression& t = e->as<TernaryExpression>(); const TernaryExpression& t = e->as<TernaryExpression>();
this->writeU8(Rehydrator::kTernary_Command); this->writeU8(Rehydrator::kTernary_Command);
this->write(t.fTest.get()); this->write(t.test().get());
this->write(t.fIfTrue.get()); this->write(t.ifTrue().get());
this->write(t.fIfFalse.get()); this->write(t.ifFalse().get());
break; break;
} }
case Expression::Kind::kVariableReference: { case Expression::Kind::kVariableReference: {

View File

@ -976,11 +976,11 @@ void GLSLCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
if (kTernary_Precedence >= parentPrecedence) { if (kTernary_Precedence >= parentPrecedence) {
this->write("("); this->write("(");
} }
this->writeExpression(*t.fTest, kTernary_Precedence); this->writeExpression(*t.test(), kTernary_Precedence);
this->write(" ? "); this->write(" ? ");
this->writeExpression(*t.fIfTrue, kTernary_Precedence); this->writeExpression(*t.ifTrue(), kTernary_Precedence);
this->write(" : "); this->write(" : ");
this->writeExpression(*t.fIfFalse, kTernary_Precedence); this->writeExpression(*t.ifFalse(), kTernary_Precedence);
if (kTernary_Precedence >= parentPrecedence) { if (kTernary_Precedence >= parentPrecedence) {
this->write(")"); this->write(")");
} }

View File

@ -408,8 +408,8 @@ std::unique_ptr<Expression> Inliner::inlineExpression(int offset,
} }
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
const TernaryExpression& t = expression.as<TernaryExpression>(); const TernaryExpression& t = expression.as<TernaryExpression>();
return std::make_unique<TernaryExpression>(offset, expr(t.fTest), return std::make_unique<TernaryExpression>(offset, expr(t.test()),
expr(t.fIfTrue), expr(t.fIfFalse)); expr(t.ifTrue()), expr(t.ifFalse()));
} }
case Expression::Kind::kTypeReference: case Expression::Kind::kTypeReference:
return expression.clone(); return expression.clone();
@ -1043,7 +1043,7 @@ public:
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
TernaryExpression& ternaryExpr = (*expr)->as<TernaryExpression>(); TernaryExpression& ternaryExpr = (*expr)->as<TernaryExpression>();
// The test expression is a candidate for inlining. // The test expression is a candidate for inlining.
this->visitExpression(&ternaryExpr.fTest); this->visitExpression(&ternaryExpr.test());
// The true- and false-expressions cannot be inlined, because we are only allowed to // The true- and false-expressions cannot be inlined, because we are only allowed to
// evaluate one side. // evaluate one side.
break; break;

View File

@ -876,11 +876,11 @@ void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
if (kTernary_Precedence >= parentPrecedence) { if (kTernary_Precedence >= parentPrecedence) {
this->write("("); this->write("(");
} }
this->writeExpression(*t.fTest, kTernary_Precedence); this->writeExpression(*t.test(), kTernary_Precedence);
this->write(" ? "); this->write(" ? ");
this->writeExpression(*t.fIfTrue, kTernary_Precedence); this->writeExpression(*t.ifTrue(), kTernary_Precedence);
this->write(" : "); this->write(" : ");
this->writeExpression(*t.fIfFalse, kTernary_Precedence); this->writeExpression(*t.ifFalse(), kTernary_Precedence);
if (kTernary_Precedence >= parentPrecedence) { if (kTernary_Precedence >= parentPrecedence) {
this->write(")"); this->write(")");
} }
@ -1711,8 +1711,8 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expressi
return this->requirements(e->as<PostfixExpression>().fOperand.get()); return this->requirements(e->as<PostfixExpression>().fOperand.get());
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
const TernaryExpression& t = e->as<TernaryExpression>(); const TernaryExpression& t = e->as<TernaryExpression>();
return this->requirements(t.fTest.get()) | this->requirements(t.fIfTrue.get()) | return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) |
this->requirements(t.fIfFalse.get()); this->requirements(t.ifFalse().get());
} }
case Expression::Kind::kVariableReference: { case Expression::Kind::kVariableReference: {
const VariableReference& v = e->as<VariableReference>(); const VariableReference& v = e->as<VariableReference>();

View File

@ -1788,18 +1788,18 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const
} }
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
TernaryExpression& t = (TernaryExpression&) expr; TernaryExpression& t = (TernaryExpression&) expr;
SpvId test = this->writeExpression(*t.fTest, out); SpvId test = this->writeExpression(*t.test(), out);
SpvId end = this->nextId(); SpvId end = this->nextId();
SpvId ifTrueLabel = this->nextId(); SpvId ifTrueLabel = this->nextId();
SpvId ifFalseLabel = this->nextId(); SpvId ifFalseLabel = this->nextId();
this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out); this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
this->writeLabel(ifTrueLabel, out); this->writeLabel(ifTrueLabel, out);
SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer(); SpvId ifTrue = this->getLValue(*t.ifTrue(), out)->getPointer();
SkASSERT(ifTrue); SkASSERT(ifTrue);
this->writeInstruction(SpvOpBranch, end, out); this->writeInstruction(SpvOpBranch, end, out);
ifTrueLabel = fCurrentBlock; ifTrueLabel = fCurrentBlock;
SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer(); SpvId ifFalse = this->getLValue(*t.ifFalse(), out)->getPointer();
SkASSERT(ifFalse); SkASSERT(ifFalse);
ifFalseLabel = fCurrentBlock; ifFalseLabel = fCurrentBlock;
this->writeInstruction(SpvOpBranch, end, out); this->writeInstruction(SpvOpBranch, end, out);
@ -2389,14 +2389,14 @@ SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream
SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) { SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
const Type& type = t.type(); const Type& type = t.type();
SpvId test = this->writeExpression(*t.fTest, out); SpvId test = this->writeExpression(*t.test(), out);
if (t.fIfTrue->type().columns() == 1 && if (t.ifTrue()->type().columns() == 1 &&
t.fIfTrue->isCompileTimeConstant() && t.ifTrue()->isCompileTimeConstant() &&
t.fIfFalse->isCompileTimeConstant()) { t.ifFalse()->isCompileTimeConstant()) {
// both true and false are constants, can just use OpSelect // both true and false are constants, can just use OpSelect
SpvId result = this->nextId(); SpvId result = this->nextId();
SpvId trueId = this->writeExpression(*t.fIfTrue, out); SpvId trueId = this->writeExpression(*t.ifTrue(), out);
SpvId falseId = this->writeExpression(*t.fIfFalse, out); SpvId falseId = this->writeExpression(*t.ifFalse(), out);
this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId, this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
out); out);
return result; return result;
@ -2412,10 +2412,10 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, Out
this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out); this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
this->writeLabel(trueLabel, out); this->writeLabel(trueLabel, out);
this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out); this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifTrue(), out), out);
this->writeInstruction(SpvOpBranch, end, out); this->writeInstruction(SpvOpBranch, end, out);
this->writeLabel(falseLabel, out); this->writeLabel(falseLabel, out);
this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out); this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifFalse(), out), out);
this->writeInstruction(SpvOpBranch, end, out); this->writeInstruction(SpvOpBranch, end, out);
this->writeLabel(end, out); this->writeLabel(end, out);
SpvId result = this->nextId(); SpvId result = this->nextId();

View File

@ -31,7 +31,7 @@ static inline bool check_ref(const Expression& expr) {
return check_ref(*expr.as<Swizzle>().fBase); return check_ref(*expr.as<Swizzle>().fBase);
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
const TernaryExpression& t = expr.as<TernaryExpression>(); const TernaryExpression& t = expr.as<TernaryExpression>();
return check_ref(*t.fIfTrue) && check_ref(*t.fIfFalse); return check_ref(*t.ifTrue()) && check_ref(*t.ifFalse());
} }
case Expression::Kind::kVariableReference: { case Expression::Kind::kVariableReference: {
const VariableReference& ref = expr.as<VariableReference>(); const VariableReference& ref = expr.as<VariableReference>();

View File

@ -16,43 +16,66 @@ namespace SkSL {
/** /**
* A ternary expression (test ? ifTrue : ifFalse). * A ternary expression (test ? ifTrue : ifFalse).
*/ */
struct TernaryExpression : public Expression { class TernaryExpression : public Expression {
public:
static constexpr Kind kExpressionKind = Kind::kTernary; static constexpr Kind kExpressionKind = Kind::kTernary;
TernaryExpression(int offset, std::unique_ptr<Expression> test, TernaryExpression(int offset, std::unique_ptr<Expression> test,
std::unique_ptr<Expression> ifTrue, std::unique_ptr<Expression> ifFalse) std::unique_ptr<Expression> ifTrue, std::unique_ptr<Expression> ifFalse)
: INHERITED(offset, kExpressionKind, &ifTrue->type()) : INHERITED(offset, kExpressionKind, &ifTrue->type()) {
, fTest(std::move(test)) SkASSERT(ifTrue->type() == ifFalse->type());
, fIfTrue(std::move(ifTrue)) fExpressionChildren.reserve(3);
, fIfFalse(std::move(ifFalse)) { fExpressionChildren.push_back(std::move(test));
SkASSERT(fIfTrue->type() == fIfFalse->type()); fExpressionChildren.push_back(std::move(ifTrue));
fExpressionChildren.push_back(std::move(ifFalse));
}
std::unique_ptr<Expression>& test() {
return fExpressionChildren[0];
}
const std::unique_ptr<Expression>& test() const {
return fExpressionChildren[0];
}
std::unique_ptr<Expression>& ifTrue() {
return fExpressionChildren[1];
}
const std::unique_ptr<Expression>& ifTrue() const {
return fExpressionChildren[1];
}
std::unique_ptr<Expression>& ifFalse() {
return fExpressionChildren[2];
}
const std::unique_ptr<Expression>& ifFalse() const {
return fExpressionChildren[2];
} }
bool hasProperty(Property property) const override { bool hasProperty(Property property) const override {
return fTest->hasProperty(property) || fIfTrue->hasProperty(property) || return this->test()->hasProperty(property) || this->ifTrue()->hasProperty(property) ||
fIfFalse->hasProperty(property); this->ifFalse()->hasProperty(property);
} }
bool isConstantOrUniform() const override { bool isConstantOrUniform() const override {
return fTest->isConstantOrUniform() && fIfTrue->isConstantOrUniform() && return this->test()->isConstantOrUniform() && this->ifTrue()->isConstantOrUniform() &&
fIfFalse->isConstantOrUniform(); this->ifFalse()->isConstantOrUniform();
} }
std::unique_ptr<Expression> clone() const override { std::unique_ptr<Expression> clone() const override {
return std::unique_ptr<Expression>(new TernaryExpression(fOffset, fTest->clone(), return std::unique_ptr<Expression>(new TernaryExpression(fOffset, this->test()->clone(),
fIfTrue->clone(), this->ifTrue()->clone(),
fIfFalse->clone())); this->ifFalse()->clone()));
} }
String description() const override { String description() const override {
return "(" + fTest->description() + " ? " + fIfTrue->description() + " : " + return "(" + this->test()->description() + " ? " + this->ifTrue()->description() + " : " +
fIfFalse->description() + ")"; this->ifFalse()->description() + ")";
} }
std::unique_ptr<Expression> fTest; private:
std::unique_ptr<Expression> fIfTrue;
std::unique_ptr<Expression> fIfFalse;
using INHERITED = Expression; using INHERITED = Expression;
}; };