Add unique_ptr visitation support to ProgramWriter.

This will allow a subclass to rewrite a Statement or Expression when
visiting it.

Change-Id: Ia8b3121dd0558f2fbbd035d38f7caec9414fe8c3
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/382756
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
This commit is contained in:
John Stiles 2021-03-11 14:26:42 -05:00 committed by Skia Commit-Bot
parent 89d6e79ec0
commit 48b255838c
2 changed files with 87 additions and 57 deletions

View File

@ -1097,8 +1097,7 @@ bool ProgramVisitor::visit(const Program& program) {
return false; return false;
} }
template <typename PROG, typename EXPR, typename STMT, typename ELEM> template <typename T> bool TProgramVisitor<T>::visitExpression(typename T::Expression& e) {
bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitExpression(EXPR e) {
switch (e.kind()) { switch (e.kind()) {
case Expression::Kind::kBoolLiteral: case Expression::Kind::kBoolLiteral:
case Expression::Kind::kDefined: case Expression::Kind::kDefined:
@ -1114,61 +1113,60 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitExpression(EXPR e) {
case Expression::Kind::kBinary: { case Expression::Kind::kBinary: {
auto& b = e.template as<BinaryExpression>(); auto& b = e.template as<BinaryExpression>();
return (b.left() && this->visitExpression(*b.left())) || return (b.left() && this->visitExpressionPtr(b.left())) ||
(b.right() && this->visitExpression(*b.right())); (b.right() && this->visitExpressionPtr(b.right()));
} }
case Expression::Kind::kConstructor: { case Expression::Kind::kConstructor: {
auto& c = e.template as<Constructor>(); auto& c = e.template as<Constructor>();
for (auto& arg : c.arguments()) { for (auto& arg : c.arguments()) {
if (this->visitExpression(*arg)) { return true; } if (this->visitExpressionPtr(arg)) { return true; }
} }
return false; return false;
} }
case Expression::Kind::kExternalFunctionCall: { case Expression::Kind::kExternalFunctionCall: {
auto& c = e.template as<ExternalFunctionCall>(); auto& c = e.template as<ExternalFunctionCall>();
for (auto& arg : c.arguments()) { for (auto& arg : c.arguments()) {
if (this->visitExpression(*arg)) { return true; } if (this->visitExpressionPtr(arg)) { return true; }
} }
return false; return false;
} }
case Expression::Kind::kFieldAccess: case Expression::Kind::kFieldAccess:
return this->visitExpression(*e.template as<FieldAccess>().base()); return this->visitExpressionPtr(e.template as<FieldAccess>().base());
case Expression::Kind::kFunctionCall: { case Expression::Kind::kFunctionCall: {
auto& c = e.template as<FunctionCall>(); auto& c = e.template as<FunctionCall>();
for (auto& arg : c.arguments()) { for (auto& arg : c.arguments()) {
if (arg && this->visitExpression(*arg)) { return true; } if (arg && this->visitExpressionPtr(arg)) { return true; }
} }
return false; return false;
} }
case Expression::Kind::kIndex: { case Expression::Kind::kIndex: {
auto& i = e.template as<IndexExpression>(); auto& i = e.template as<IndexExpression>();
return this->visitExpression(*i.base()) || this->visitExpression(*i.index()); return this->visitExpressionPtr(i.base()) || this->visitExpressionPtr(i.index());
} }
case Expression::Kind::kPostfix: case Expression::Kind::kPostfix:
return this->visitExpression(*e.template as<PostfixExpression>().operand()); return this->visitExpressionPtr(e.template as<PostfixExpression>().operand());
case Expression::Kind::kPrefix: case Expression::Kind::kPrefix:
return this->visitExpression(*e.template as<PrefixExpression>().operand()); return this->visitExpressionPtr(e.template as<PrefixExpression>().operand());
case Expression::Kind::kSwizzle: { case Expression::Kind::kSwizzle: {
auto& s = e.template as<Swizzle>(); auto& s = e.template as<Swizzle>();
return s.base() && this->visitExpression(*s.base()); return s.base() && this->visitExpressionPtr(s.base());
} }
case Expression::Kind::kTernary: { case Expression::Kind::kTernary: {
auto& t = e.template as<TernaryExpression>(); auto& t = e.template as<TernaryExpression>();
return this->visitExpression(*t.test()) || return this->visitExpressionPtr(t.test()) ||
(t.ifTrue() && this->visitExpression(*t.ifTrue())) || (t.ifTrue() && this->visitExpressionPtr(t.ifTrue())) ||
(t.ifFalse() && this->visitExpression(*t.ifFalse())); (t.ifFalse() && this->visitExpressionPtr(t.ifFalse()));
} }
default: default:
SkUNREACHABLE; SkUNREACHABLE;
} }
} }
template <typename PROG, typename EXPR, typename STMT, typename ELEM> template <typename T> bool TProgramVisitor<T>::visitStatement(typename T::Statement& s) {
bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
switch (s.kind()) { switch (s.kind()) {
case Statement::Kind::kBreak: case Statement::Kind::kBreak:
case Statement::Kind::kContinue: case Statement::Kind::kContinue:
@ -1180,7 +1178,7 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
case Statement::Kind::kBlock: case Statement::Kind::kBlock:
for (auto& stmt : s.template as<Block>().children()) { for (auto& stmt : s.template as<Block>().children()) {
if (stmt && this->visitStatement(*stmt)) { if (stmt && this->visitStatementPtr(stmt)) {
return true; return true;
} }
} }
@ -1188,42 +1186,42 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
case Statement::Kind::kSwitchCase: { case Statement::Kind::kSwitchCase: {
auto& sc = s.template as<SwitchCase>(); auto& sc = s.template as<SwitchCase>();
if (sc.value() && this->visitExpression(*sc.value())) { if (sc.value() && this->visitExpressionPtr(sc.value())) {
return true; return true;
} }
return this->visitStatement(*sc.statement()); return this->visitStatementPtr(sc.statement());
} }
case Statement::Kind::kDo: { case Statement::Kind::kDo: {
auto& d = s.template as<DoStatement>(); auto& d = s.template as<DoStatement>();
return this->visitExpression(*d.test()) || this->visitStatement(*d.statement()); return this->visitExpressionPtr(d.test()) || this->visitStatementPtr(d.statement());
} }
case Statement::Kind::kExpression: case Statement::Kind::kExpression:
return this->visitExpression(*s.template as<ExpressionStatement>().expression()); return this->visitExpressionPtr(s.template as<ExpressionStatement>().expression());
case Statement::Kind::kFor: { case Statement::Kind::kFor: {
auto& f = s.template as<ForStatement>(); auto& f = s.template as<ForStatement>();
return (f.initializer() && this->visitStatement(*f.initializer())) || return (f.initializer() && this->visitStatementPtr(f.initializer())) ||
(f.test() && this->visitExpression(*f.test())) || (f.test() && this->visitExpressionPtr(f.test())) ||
(f.next() && this->visitExpression(*f.next())) || (f.next() && this->visitExpressionPtr(f.next())) ||
this->visitStatement(*f.statement()); this->visitStatementPtr(f.statement());
} }
case Statement::Kind::kIf: { case Statement::Kind::kIf: {
auto& i = s.template as<IfStatement>(); auto& i = s.template as<IfStatement>();
return (i.test() && this->visitExpression(*i.test())) || return (i.test() && this->visitExpressionPtr(i.test())) ||
(i.ifTrue() && this->visitStatement(*i.ifTrue())) || (i.ifTrue() && this->visitStatementPtr(i.ifTrue())) ||
(i.ifFalse() && this->visitStatement(*i.ifFalse())); (i.ifFalse() && this->visitStatementPtr(i.ifFalse()));
} }
case Statement::Kind::kReturn: { case Statement::Kind::kReturn: {
auto& r = s.template as<ReturnStatement>(); auto& r = s.template as<ReturnStatement>();
return r.expression() && this->visitExpression(*r.expression()); return r.expression() && this->visitExpressionPtr(r.expression());
} }
case Statement::Kind::kSwitch: { case Statement::Kind::kSwitch: {
auto& sw = s.template as<SwitchStatement>(); auto& sw = s.template as<SwitchStatement>();
if (this->visitExpression(*sw.value())) { if (this->visitExpressionPtr(sw.value())) {
return true; return true;
} }
for (const auto& c : sw.cases()) { for (auto& c : sw.cases()) {
if (this->visitStatement(*c)) { if (this->visitStatementPtr(c)) {
return true; return true;
} }
} }
@ -1231,15 +1229,14 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
} }
case Statement::Kind::kVarDeclaration: { case Statement::Kind::kVarDeclaration: {
auto& v = s.template as<VarDeclaration>(); auto& v = s.template as<VarDeclaration>();
return v.value() && this->visitExpression(*v.value()); return v.value() && this->visitExpressionPtr(v.value());
} }
default: default:
SkUNREACHABLE; SkUNREACHABLE;
} }
} }
template <typename PROG, typename EXPR, typename STMT, typename ELEM> template <typename T> bool TProgramVisitor<T>::visitProgramElement(typename T::ProgramElement& pe) {
bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitProgramElement(ELEM pe) {
switch (pe.kind()) { switch (pe.kind()) {
case ProgramElement::Kind::kEnum: case ProgramElement::Kind::kEnum:
case ProgramElement::Kind::kExtension: case ProgramElement::Kind::kExtension:
@ -1252,21 +1249,17 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitProgramElement(ELEM pe) {
return false; return false;
case ProgramElement::Kind::kFunction: case ProgramElement::Kind::kFunction:
return this->visitStatement(*pe.template as<FunctionDefinition>().body()); return this->visitStatementPtr(pe.template as<FunctionDefinition>().body());
case ProgramElement::Kind::kGlobalVar: case ProgramElement::Kind::kGlobalVar:
if (this->visitStatement(*pe.template as<GlobalVarDeclaration>().declaration())) { return this->visitStatementPtr(pe.template as<GlobalVarDeclaration>().declaration());
return true;
}
return false;
default: default:
SkUNREACHABLE; SkUNREACHABLE;
} }
} }
template class TProgramVisitor<const Program&, const Expression&, template class TProgramVisitor<ProgramVisitorTypes>;
const Statement&, const ProgramElement&>; template class TProgramVisitor<ProgramWriterTypes>;
template class TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
} // namespace SkSL } // namespace SkSL

View File

@ -154,16 +154,37 @@ struct Analysis {
* any visit call returns true, the default behavior stops recursing and propagates true up the * any visit call returns true, the default behavior stops recursing and propagates true up the
* stack. * stack.
*/ */
template <typename T>
template <typename PROG, typename EXPR, typename STMT, typename ELEM>
class TProgramVisitor { class TProgramVisitor {
public: public:
virtual ~TProgramVisitor() = default; virtual ~TProgramVisitor() = default;
protected: protected:
virtual bool visitExpression(EXPR expression); virtual bool visitExpression(typename T::Expression& expression);
virtual bool visitStatement(STMT statement); virtual bool visitStatement(typename T::Statement& statement);
virtual bool visitProgramElement(ELEM programElement); virtual bool visitProgramElement(typename T::ProgramElement& programElement);
virtual bool visitExpressionPtr(typename T::UniquePtrExpression& expr) = 0;
virtual bool visitStatementPtr(typename T::UniquePtrStatement& stmt) = 0;
};
// ProgramVisitors take const types; ProgramWriters do not.
struct ProgramVisitorTypes {
using Program = const SkSL::Program;
using Expression = const SkSL::Expression;
using Statement = const SkSL::Statement;
using ProgramElement = const SkSL::ProgramElement;
using UniquePtrExpression = const std::unique_ptr<SkSL::Expression>;
using UniquePtrStatement = const std::unique_ptr<SkSL::Statement>;
};
struct ProgramWriterTypes {
using Program = SkSL::Program;
using Expression = SkSL::Expression;
using Statement = SkSL::Statement;
using ProgramElement = SkSL::ProgramElement;
using UniquePtrExpression = std::unique_ptr<SkSL::Expression>;
using UniquePtrStatement = std::unique_ptr<SkSL::Statement>;
}; };
// Squelch bogus Clang warning about template vtables: https://bugs.llvm.org/show_bug.cgi?id=18733 // Squelch bogus Clang warning about template vtables: https://bugs.llvm.org/show_bug.cgi?id=18733
@ -171,22 +192,38 @@ protected:
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wweak-template-vtables" #pragma clang diagnostic ignored "-Wweak-template-vtables"
#endif #endif
extern template class TProgramVisitor<const Program&, const Expression&, extern template class TProgramVisitor<ProgramVisitorTypes>;
const Statement&, const ProgramElement&>; extern template class TProgramVisitor<ProgramWriterTypes>;
extern template class TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
#if defined(__clang__) #if defined(__clang__)
#pragma clang diagnostic pop #pragma clang diagnostic pop
#endif #endif
class ProgramVisitor : public TProgramVisitor<const Program&, class ProgramVisitor : public TProgramVisitor<ProgramVisitorTypes> {
const Expression&,
const Statement&,
const ProgramElement&> {
public: public:
bool visit(const Program& program); bool visit(const Program& program);
private:
// ProgramVisitors shouldn't need access to unique_ptrs, and marking these as final should help
// these accessors inline away. Use ProgramWriter if you need the unique_ptrs.
bool visitExpressionPtr(const std::unique_ptr<Expression>& e) final {
return this->visitExpression(*e);
}
bool visitStatementPtr(const std::unique_ptr<Statement>& s) final {
return this->visitStatement(*s);
}
}; };
using ProgramWriter = TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>; class ProgramWriter : public TProgramVisitor<ProgramWriterTypes> {
public:
// Subclass these methods if you want access to the unique_ptrs of IRNodes in a program.
// This will allow statements or expressions to be replaced during a visit.
bool visitExpressionPtr(std::unique_ptr<Expression>& e) override {
return this->visitExpression(*e);
}
bool visitStatementPtr(std::unique_ptr<Statement>& s) override {
return this->visitStatement(*s);
}
};
} // namespace SkSL } // namespace SkSL