moved SkSL ForStatement data into IRNode

Change-Id: I87039ae982c7f5b6ed7a4cc236470f049606c45e
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/321468
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
This commit is contained in:
Ethan Nicholas 2020-10-05 14:47:09 -04:00 committed by Skia Commit-Bot
parent ae4bb98f13
commit 0d31ed5068
13 changed files with 154 additions and 87 deletions

View File

@ -445,10 +445,10 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
case Statement::Kind::kFor: {
auto& f = s.template as<ForStatement>();
return (f.fInitializer && this->visitStatement(*f.fInitializer)) ||
(f.fTest && this->visitExpression(*f.fTest)) ||
(f.fNext && this->visitExpression(*f.fNext)) ||
this->visitStatement(*f.fStatement);
return (f.initializer() && this->visitStatement(*f.initializer())) ||
(f.test() && this->visitExpression(*f.test())) ||
(f.next() && this->visitExpression(*f.next())) ||
this->visitStatement(*f.statement());
}
case Statement::Kind::kIf: {
auto& i = s.template as<IfStatement>();

View File

@ -1703,21 +1703,21 @@ void ByteCodeGenerator::writeDoStatement(const DoStatement& d) {
void ByteCodeGenerator::writeForStatement(const ForStatement& f) {
fContinueTargets.emplace();
fBreakTargets.emplace();
if (f.fInitializer) {
this->writeStatement(*f.fInitializer);
if (f.initializer()) {
this->writeStatement(*f.initializer());
}
this->write(ByteCodeInstruction::kLoopBegin);
size_t start = fCode->size();
if (f.fTest) {
this->writeExpression(*f.fTest);
if (f.test()) {
this->writeExpression(*f.test());
this->write(ByteCodeInstruction::kLoopMask);
}
this->write(ByteCodeInstruction::kBranchIfAllFalse);
DeferredLocation endLocation(this);
this->writeStatement(*f.fStatement);
this->writeStatement(*f.statement());
this->write(ByteCodeInstruction::kLoopNext);
if (f.fNext) {
this->writeExpression(*f.fNext, true);
if (f.next()) {
this->writeExpression(*f.next(), true);
}
this->write(ByteCodeInstruction::kBranch);
this->write16(start);

View File

@ -582,16 +582,16 @@ void CFGGenerator::addStatement(CFG& cfg, std::unique_ptr<Statement>* s) {
}
case Statement::Kind::kFor: {
ForStatement& f = (*s)->as<ForStatement>();
if (f.fInitializer) {
this->addStatement(cfg, &f.fInitializer);
if (f.initializer()) {
this->addStatement(cfg, &f.initializer());
}
BlockId loopStart = cfg.newBlock();
BlockId next = cfg.newIsolatedBlock();
fLoopContinues.push(next);
BlockId loopExit = cfg.newIsolatedBlock();
fLoopExits.push(loopExit);
if (f.fTest) {
this->addExpression(cfg, &f.fTest, /*constantPropagate=*/true);
if (f.test()) {
this->addExpression(cfg, &f.test(), /*constantPropagate=*/true);
// this isn't quite right; we should have an exit from here to the loop exit, and
// remove the exit from the loop body to the loop exit. Structuring it like this
// forces the optimizer to believe that the loop body is always executed at least
@ -601,11 +601,11 @@ void CFGGenerator::addStatement(CFG& cfg, std::unique_ptr<Statement>* s) {
// guaranteed to happen, but for the time being we take the easy way out.
}
cfg.newBlock();
this->addStatement(cfg, &f.fStatement);
this->addStatement(cfg, &f.statement());
cfg.addExit(cfg.fCurrent, next);
cfg.fCurrent = next;
if (f.fNext) {
this->addExpression(cfg, &f.fNext, /*constantPropagate=*/true);
if (f.next()) {
this->addExpression(cfg, &f.next(), /*constantPropagate=*/true);
}
cfg.addExit(cfg.fCurrent, loopStart);
cfg.addExit(cfg.fCurrent, loopExit);

View File

@ -426,11 +426,11 @@ void Dehydrator::write(const Statement* s) {
case Statement::Kind::kFor: {
const ForStatement& f = s->as<ForStatement>();
this->writeU8(Rehydrator::kFor_Command);
this->write(f.fInitializer.get());
this->write(f.fTest.get());
this->write(f.fNext.get());
this->write(f.fStatement.get());
this->write(f.fSymbols);
this->write(f.initializer().get());
this->write(f.test().get());
this->write(f.next().get());
this->write(f.statement().get());
this->write(f.symbols());
break;
}
case Statement::Kind::kIf: {

View File

@ -1370,28 +1370,28 @@ void GLSLCodeGenerator::writeIfStatement(const IfStatement& stmt) {
void GLSLCodeGenerator::writeForStatement(const ForStatement& f) {
this->write("for (");
if (f.fInitializer && !f.fInitializer->isEmpty()) {
this->writeStatement(*f.fInitializer);
if (f.initializer() && !f.initializer()->isEmpty()) {
this->writeStatement(*f.initializer());
} else {
this->write("; ");
}
if (f.fTest) {
if (f.test()) {
if (fProgram.fSettings.fCaps->addAndTrueToLoopCondition()) {
std::unique_ptr<Expression> and_true(new BinaryExpression(
-1, f.fTest->clone(), Token::Kind::TK_LOGICALAND,
-1, f.test()->clone(), Token::Kind::TK_LOGICALAND,
std::make_unique<BoolLiteral>(fContext, -1, true),
fContext.fBool_Type.get()));
this->writeExpression(*and_true, kTopLevel_Precedence);
} else {
this->writeExpression(*f.fTest, kTopLevel_Precedence);
this->writeExpression(*f.test(), kTopLevel_Precedence);
}
}
this->write("; ");
if (f.fNext) {
this->writeExpression(*f.fNext, kTopLevel_Precedence);
if (f.next()) {
this->writeExpression(*f.next(), kTopLevel_Precedence);
}
this->write(") ");
this->writeStatement(*f.fStatement);
this->writeStatement(*f.statement());
}
void GLSLCodeGenerator::writeWhileStatement(const WhileStatement& w) {

View File

@ -587,7 +587,7 @@ std::unique_ptr<Statement> IRGenerator::convertFor(const ASTNode& f) {
auto forStmt = std::make_unique<ForStatement>(f.fOffset, std::move(initializer),
std::move(test), std::move(next),
std::move(statement), fSymbolTable);
fInliner->ensureScopedBlocks(forStmt->fStatement.get(), forStmt.get());
fInliner->ensureScopedBlocks(forStmt->statement().get(), forStmt.get());
return std::move(forStmt);
}

View File

@ -483,9 +483,9 @@ std::unique_ptr<Statement> Inliner::inlineStatement(int offset,
const ForStatement& f = statement.as<ForStatement>();
// need to ensure initializer is evaluated first so that we've already remapped its
// declarations by the time we evaluate test & next
std::unique_ptr<Statement> initializer = stmt(f.fInitializer);
return std::make_unique<ForStatement>(offset, std::move(initializer), expr(f.fTest),
expr(f.fNext), stmt(f.fStatement), f.fSymbols);
std::unique_ptr<Statement> initializer = stmt(f.initializer());
return std::make_unique<ForStatement>(offset, std::move(initializer), expr(f.test()),
expr(f.next()), stmt(f.statement()), f.symbols());
}
case Statement::Kind::kIf: {
const IfStatement& i = statement.as<IfStatement>();
@ -884,14 +884,14 @@ public:
}
case Statement::Kind::kFor: {
ForStatement& forStmt = (*stmt)->as<ForStatement>();
if (forStmt.fSymbols) {
fSymbolTableStack.push_back(forStmt.fSymbols.get());
if (forStmt.symbols()) {
fSymbolTableStack.push_back(forStmt.symbols().get());
}
// The initializer and loop body are candidates for inlining.
this->visitStatement(&forStmt.fInitializer,
this->visitStatement(&forStmt.initializer(),
/*isViableAsEnclosingStatement=*/false);
this->visitStatement(&forStmt.fStatement);
this->visitStatement(&forStmt.statement());
// The inliner isn't smart enough to inline the test- or increment-expressions
// of a for loop loop at this time. There are a handful of limitations:

View File

@ -1332,20 +1332,20 @@ void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
this->write("for (");
if (f.fInitializer && !f.fInitializer->isEmpty()) {
this->writeStatement(*f.fInitializer);
if (f.initializer() && !f.initializer()->isEmpty()) {
this->writeStatement(*f.initializer());
} else {
this->write("; ");
}
if (f.fTest) {
this->writeExpression(*f.fTest, kTopLevel_Precedence);
if (f.test()) {
this->writeExpression(*f.test(), kTopLevel_Precedence);
}
this->write("; ");
if (f.fNext) {
this->writeExpression(*f.fNext, kTopLevel_Precedence);
if (f.next()) {
this->writeExpression(*f.next(), kTopLevel_Precedence);
}
this->write(") ");
this->writeStatement(*f.fStatement);
this->writeStatement(*f.statement());
}
void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) {
@ -1816,10 +1816,10 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statemen
}
case Statement::Kind::kFor: {
const ForStatement& f = s->as<ForStatement>();
return this->requirements(f.fInitializer.get()) |
this->requirements(f.fTest.get()) |
this->requirements(f.fNext.get()) |
this->requirements(f.fStatement.get());
return this->requirements(f.initializer().get()) |
this->requirements(f.test().get()) |
this->requirements(f.next().get()) |
this->requirements(f.statement().get());
}
case Statement::Kind::kWhile: {
const WhileStatement& w = s->as<WhileStatement>();

View File

@ -2952,8 +2952,8 @@ void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream&
}
void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
if (f.fInitializer) {
this->writeStatement(*f.fInitializer, out);
if (f.initializer()) {
this->writeStatement(*f.initializer(), out);
}
SpvId header = this->nextId();
SpvId start = this->nextId();
@ -2967,18 +2967,18 @@ void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream&
this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
this->writeInstruction(SpvOpBranch, start, out);
this->writeLabel(start, out);
if (f.fTest) {
SpvId test = this->writeExpression(*f.fTest, out);
if (f.test()) {
SpvId test = this->writeExpression(*f.test(), out);
this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
}
this->writeLabel(body, out);
this->writeStatement(*f.fStatement, out);
this->writeStatement(*f.statement(), out);
if (fCurrentBlock) {
this->writeInstruction(SpvOpBranch, next, out);
}
this->writeLabel(next, out);
if (f.fNext) {
this->writeExpression(*f.fNext, out);
if (f.next()) {
this->writeExpression(*f.next(), out);
}
this->writeInstruction(SpvOpBranch, header, out);
this->writeLabel(end, out);

View File

@ -17,55 +17,88 @@ namespace SkSL {
/**
* A 'for' statement.
*/
struct ForStatement : public Statement {
class ForStatement : public Statement {
public:
static constexpr Kind kStatementKind = Kind::kFor;
ForStatement(int offset, std::unique_ptr<Statement> initializer,
std::unique_ptr<Expression> test, std::unique_ptr<Expression> next,
std::unique_ptr<Statement> statement, std::shared_ptr<SymbolTable> symbols)
: INHERITED(offset, kStatementKind)
, fSymbols(symbols)
, fInitializer(std::move(initializer))
, fTest(std::move(test))
, fNext(std::move(next))
, fStatement(std::move(statement)) {}
: INHERITED(offset, kStatementKind, ForStatementData{std::move(symbols)}) {
fStatementChildren.reserve(2);
fStatementChildren.push_back(std::move(initializer));
fStatementChildren.push_back(std::move(statement));
fExpressionChildren.reserve(2);
fExpressionChildren.push_back(std::move(test));
fExpressionChildren.push_back(std::move(next));
}
std::unique_ptr<Statement>& initializer() {
return fStatementChildren[0];
}
const std::unique_ptr<Statement>& initializer() const {
return fStatementChildren[0];
}
std::unique_ptr<Expression>& test() {
return fExpressionChildren[0];
}
const std::unique_ptr<Expression>& test() const {
return fExpressionChildren[0];
}
std::unique_ptr<Expression>& next() {
return fExpressionChildren[1];
}
const std::unique_ptr<Expression>& next() const {
return fExpressionChildren[1];
}
std::unique_ptr<Statement>& statement() {
return fStatementChildren[1];
}
const std::unique_ptr<Statement>& statement() const {
return fStatementChildren[1];
}
std::shared_ptr<SymbolTable> symbols() const {
return this->forStatementData().fSymbolTable;
}
std::unique_ptr<Statement> clone() const override {
return std::unique_ptr<Statement>(new ForStatement(fOffset,
fInitializer ? fInitializer->clone() : nullptr,
fTest ? fTest->clone() : nullptr,
fNext ? fNext->clone() : nullptr,
fStatement->clone(),
fSymbols));
return std::unique_ptr<Statement>(new ForStatement(
fOffset,
this->initializer() ? this->initializer()->clone() : nullptr,
this->test() ? this->test()->clone() : nullptr,
this->next() ? this->next()->clone() : nullptr,
this->statement()->clone(),
this->symbols()));
}
String description() const override {
String result("for (");
if (fInitializer) {
result += fInitializer->description();
if (this->initializer()) {
result += this->initializer()->description();
} else {
result += ";";
}
result += " ";
if (fTest) {
result += fTest->description();
if (this->test()) {
result += this->test()->description();
}
result += "; ";
if (fNext) {
result += fNext->description();
if (this->next()) {
result += this->next()->description();
}
result += ") " + fStatement->description();
result += ") " + this->statement()->description();
return result;
}
// it's important to keep fSymbols defined first (and thus destroyed last) because destroying
// the other fields can update symbol reference counts
const std::shared_ptr<SymbolTable> fSymbols;
std::unique_ptr<Statement> fInitializer;
std::unique_ptr<Expression> fTest;
std::unique_ptr<Expression> fNext;
std::unique_ptr<Statement> fStatement;
private:
using INHERITED = Statement;
};

View File

@ -48,6 +48,11 @@ IRNode::IRNode(int offset, int kind, const FloatLiteralData& data)
, fKind(kind)
, fData(data) {}
IRNode::IRNode(int offset, int kind, const ForStatementData& data)
: fOffset(offset)
, fKind(kind)
, fData(data) {}
IRNode::IRNode(int offset, int kind, const String& data)
: fOffset(offset)
, fKind(kind)

View File

@ -110,6 +110,10 @@ protected:
float fValue;
};
struct ForStatementData {
std::shared_ptr<SymbolTable> fSymbolTable;
};
struct IntLiteralData {
const Type* fType;
int64_t fValue;
@ -133,6 +137,7 @@ protected:
kExternalValue,
kField,
kFloatLiteral,
kForStatement,
kIntLiteral,
kString,
kSymbol,
@ -148,6 +153,7 @@ protected:
ExternalValueData fExternalValue;
FieldData fField;
FloatLiteralData fFloatLiteral;
ForStatementData fForStatement;
IntLiteralData fIntLiteral;
String fString;
SymbolData fSymbol;
@ -189,6 +195,11 @@ protected:
*(new(&fContents) FloatLiteralData) = data;
}
NodeData(const ForStatementData& data)
: fKind(Kind::kForStatement) {
*(new(&fContents) ForStatementData) = data;
}
NodeData(IntLiteralData data)
: fKind(Kind::kIntLiteral) {
*(new(&fContents) IntLiteralData) = data;
@ -240,6 +251,9 @@ protected:
case Kind::kFloatLiteral:
*(new(&fContents) FloatLiteralData) = other.fContents.fFloatLiteral;
break;
case Kind::kForStatement:
*(new(&fContents) ForStatementData) = other.fContents.fForStatement;
break;
case Kind::kIntLiteral:
*(new(&fContents) IntLiteralData) = other.fContents.fIntLiteral;
break;
@ -284,6 +298,9 @@ protected:
case Kind::kFloatLiteral:
fContents.fFloatLiteral.~FloatLiteralData();
break;
case Kind::kForStatement:
fContents.fForStatement.~ForStatementData();
break;
case Kind::kIntLiteral:
fContents.fIntLiteral.~IntLiteralData();
break;
@ -313,10 +330,12 @@ protected:
IRNode(int offset, int kind, const FieldData& data);
IRNode(int offset, int kind, const IntLiteralData& data);
IRNode(int offset, int kind, const FloatLiteralData& data);
IRNode(int offset, int kind, const ForStatementData& data);
IRNode(int offset, int kind, const IntLiteralData& data);
IRNode(int offset, int kind, const String& data);
IRNode(int offset, int kind, const SymbolData& data);
@ -401,6 +420,11 @@ protected:
return fData.fContents.fFloatLiteral;
}
const ForStatementData& forStatementData() const {
SkASSERT(fData.fKind == NodeData::Kind::kForStatement);
return fData.fContents.fForStatement;
}
const IntLiteralData& intLiteralData() const {
SkASSERT(fData.fKind == NodeData::Kind::kIntLiteral);
return fData.fContents.fIntLiteral;

View File

@ -49,6 +49,11 @@ struct Statement : public IRNode {
SkASSERT(kind >= Kind::kFirst && kind <= Kind::kLast);
}
Statement(int offset, Kind kind, const ForStatementData& data)
: INHERITED(offset, (int) kind, data) {
SkASSERT(kind >= Kind::kFirst && kind <= Kind::kLast);
}
Kind kind() const {
return (Kind) fKind;
}