Add unique_ptr visitation support to ProgramWriter.

This will allow a subclass to rewrite a Statement or Expression when
visiting it.

Change-Id: Ia8b3121dd0558f2fbbd035d38f7caec9414fe8c3
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/382756
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
This commit is contained in:
John Stiles 2021-03-11 14:26:42 -05:00 committed by Skia Commit-Bot
parent 89d6e79ec0
commit 48b255838c
2 changed files with 87 additions and 57 deletions

View File

@ -1097,8 +1097,7 @@ bool ProgramVisitor::visit(const Program& program) {
return false;
}
template <typename PROG, typename EXPR, typename STMT, typename ELEM>
bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitExpression(EXPR e) {
template <typename T> bool TProgramVisitor<T>::visitExpression(typename T::Expression& e) {
switch (e.kind()) {
case Expression::Kind::kBoolLiteral:
case Expression::Kind::kDefined:
@ -1114,61 +1113,60 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitExpression(EXPR e) {
case Expression::Kind::kBinary: {
auto& b = e.template as<BinaryExpression>();
return (b.left() && this->visitExpression(*b.left())) ||
(b.right() && this->visitExpression(*b.right()));
return (b.left() && this->visitExpressionPtr(b.left())) ||
(b.right() && this->visitExpressionPtr(b.right()));
}
case Expression::Kind::kConstructor: {
auto& c = e.template as<Constructor>();
for (auto& arg : c.arguments()) {
if (this->visitExpression(*arg)) { return true; }
if (this->visitExpressionPtr(arg)) { return true; }
}
return false;
}
case Expression::Kind::kExternalFunctionCall: {
auto& c = e.template as<ExternalFunctionCall>();
for (auto& arg : c.arguments()) {
if (this->visitExpression(*arg)) { return true; }
if (this->visitExpressionPtr(arg)) { return true; }
}
return false;
}
case Expression::Kind::kFieldAccess:
return this->visitExpression(*e.template as<FieldAccess>().base());
return this->visitExpressionPtr(e.template as<FieldAccess>().base());
case Expression::Kind::kFunctionCall: {
auto& c = e.template as<FunctionCall>();
for (auto& arg : c.arguments()) {
if (arg && this->visitExpression(*arg)) { return true; }
if (arg && this->visitExpressionPtr(arg)) { return true; }
}
return false;
}
case Expression::Kind::kIndex: {
auto& i = e.template as<IndexExpression>();
return this->visitExpression(*i.base()) || this->visitExpression(*i.index());
return this->visitExpressionPtr(i.base()) || this->visitExpressionPtr(i.index());
}
case Expression::Kind::kPostfix:
return this->visitExpression(*e.template as<PostfixExpression>().operand());
return this->visitExpressionPtr(e.template as<PostfixExpression>().operand());
case Expression::Kind::kPrefix:
return this->visitExpression(*e.template as<PrefixExpression>().operand());
return this->visitExpressionPtr(e.template as<PrefixExpression>().operand());
case Expression::Kind::kSwizzle: {
auto& s = e.template as<Swizzle>();
return s.base() && this->visitExpression(*s.base());
return s.base() && this->visitExpressionPtr(s.base());
}
case Expression::Kind::kTernary: {
auto& t = e.template as<TernaryExpression>();
return this->visitExpression(*t.test()) ||
(t.ifTrue() && this->visitExpression(*t.ifTrue())) ||
(t.ifFalse() && this->visitExpression(*t.ifFalse()));
return this->visitExpressionPtr(t.test()) ||
(t.ifTrue() && this->visitExpressionPtr(t.ifTrue())) ||
(t.ifFalse() && this->visitExpressionPtr(t.ifFalse()));
}
default:
SkUNREACHABLE;
}
}
template <typename PROG, typename EXPR, typename STMT, typename ELEM>
bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
template <typename T> bool TProgramVisitor<T>::visitStatement(typename T::Statement& s) {
switch (s.kind()) {
case Statement::Kind::kBreak:
case Statement::Kind::kContinue:
@ -1180,7 +1178,7 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
case Statement::Kind::kBlock:
for (auto& stmt : s.template as<Block>().children()) {
if (stmt && this->visitStatement(*stmt)) {
if (stmt && this->visitStatementPtr(stmt)) {
return true;
}
}
@ -1188,42 +1186,42 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
case Statement::Kind::kSwitchCase: {
auto& sc = s.template as<SwitchCase>();
if (sc.value() && this->visitExpression(*sc.value())) {
if (sc.value() && this->visitExpressionPtr(sc.value())) {
return true;
}
return this->visitStatement(*sc.statement());
return this->visitStatementPtr(sc.statement());
}
case Statement::Kind::kDo: {
auto& d = s.template as<DoStatement>();
return this->visitExpression(*d.test()) || this->visitStatement(*d.statement());
return this->visitExpressionPtr(d.test()) || this->visitStatementPtr(d.statement());
}
case Statement::Kind::kExpression:
return this->visitExpression(*s.template as<ExpressionStatement>().expression());
return this->visitExpressionPtr(s.template as<ExpressionStatement>().expression());
case Statement::Kind::kFor: {
auto& f = s.template as<ForStatement>();
return (f.initializer() && this->visitStatement(*f.initializer())) ||
(f.test() && this->visitExpression(*f.test())) ||
(f.next() && this->visitExpression(*f.next())) ||
this->visitStatement(*f.statement());
return (f.initializer() && this->visitStatementPtr(f.initializer())) ||
(f.test() && this->visitExpressionPtr(f.test())) ||
(f.next() && this->visitExpressionPtr(f.next())) ||
this->visitStatementPtr(f.statement());
}
case Statement::Kind::kIf: {
auto& i = s.template as<IfStatement>();
return (i.test() && this->visitExpression(*i.test())) ||
(i.ifTrue() && this->visitStatement(*i.ifTrue())) ||
(i.ifFalse() && this->visitStatement(*i.ifFalse()));
return (i.test() && this->visitExpressionPtr(i.test())) ||
(i.ifTrue() && this->visitStatementPtr(i.ifTrue())) ||
(i.ifFalse() && this->visitStatementPtr(i.ifFalse()));
}
case Statement::Kind::kReturn: {
auto& r = s.template as<ReturnStatement>();
return r.expression() && this->visitExpression(*r.expression());
return r.expression() && this->visitExpressionPtr(r.expression());
}
case Statement::Kind::kSwitch: {
auto& sw = s.template as<SwitchStatement>();
if (this->visitExpression(*sw.value())) {
if (this->visitExpressionPtr(sw.value())) {
return true;
}
for (const auto& c : sw.cases()) {
if (this->visitStatement(*c)) {
for (auto& c : sw.cases()) {
if (this->visitStatementPtr(c)) {
return true;
}
}
@ -1231,15 +1229,14 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
}
case Statement::Kind::kVarDeclaration: {
auto& v = s.template as<VarDeclaration>();
return v.value() && this->visitExpression(*v.value());
return v.value() && this->visitExpressionPtr(v.value());
}
default:
SkUNREACHABLE;
}
}
template <typename PROG, typename EXPR, typename STMT, typename ELEM>
bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitProgramElement(ELEM pe) {
template <typename T> bool TProgramVisitor<T>::visitProgramElement(typename T::ProgramElement& pe) {
switch (pe.kind()) {
case ProgramElement::Kind::kEnum:
case ProgramElement::Kind::kExtension:
@ -1252,21 +1249,17 @@ bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitProgramElement(ELEM pe) {
return false;
case ProgramElement::Kind::kFunction:
return this->visitStatement(*pe.template as<FunctionDefinition>().body());
return this->visitStatementPtr(pe.template as<FunctionDefinition>().body());
case ProgramElement::Kind::kGlobalVar:
if (this->visitStatement(*pe.template as<GlobalVarDeclaration>().declaration())) {
return true;
}
return false;
return this->visitStatementPtr(pe.template as<GlobalVarDeclaration>().declaration());
default:
SkUNREACHABLE;
}
}
template class TProgramVisitor<const Program&, const Expression&,
const Statement&, const ProgramElement&>;
template class TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
template class TProgramVisitor<ProgramVisitorTypes>;
template class TProgramVisitor<ProgramWriterTypes>;
} // namespace SkSL

View File

@ -154,16 +154,37 @@ struct Analysis {
* any visit call returns true, the default behavior stops recursing and propagates true up the
* stack.
*/
template <typename PROG, typename EXPR, typename STMT, typename ELEM>
template <typename T>
class TProgramVisitor {
public:
virtual ~TProgramVisitor() = default;
protected:
virtual bool visitExpression(EXPR expression);
virtual bool visitStatement(STMT statement);
virtual bool visitProgramElement(ELEM programElement);
virtual bool visitExpression(typename T::Expression& expression);
virtual bool visitStatement(typename T::Statement& statement);
virtual bool visitProgramElement(typename T::ProgramElement& programElement);
virtual bool visitExpressionPtr(typename T::UniquePtrExpression& expr) = 0;
virtual bool visitStatementPtr(typename T::UniquePtrStatement& stmt) = 0;
};
// ProgramVisitors take const types; ProgramWriters do not.
struct ProgramVisitorTypes {
using Program = const SkSL::Program;
using Expression = const SkSL::Expression;
using Statement = const SkSL::Statement;
using ProgramElement = const SkSL::ProgramElement;
using UniquePtrExpression = const std::unique_ptr<SkSL::Expression>;
using UniquePtrStatement = const std::unique_ptr<SkSL::Statement>;
};
struct ProgramWriterTypes {
using Program = SkSL::Program;
using Expression = SkSL::Expression;
using Statement = SkSL::Statement;
using ProgramElement = SkSL::ProgramElement;
using UniquePtrExpression = std::unique_ptr<SkSL::Expression>;
using UniquePtrStatement = std::unique_ptr<SkSL::Statement>;
};
// Squelch bogus Clang warning about template vtables: https://bugs.llvm.org/show_bug.cgi?id=18733
@ -171,22 +192,38 @@ protected:
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wweak-template-vtables"
#endif
extern template class TProgramVisitor<const Program&, const Expression&,
const Statement&, const ProgramElement&>;
extern template class TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
extern template class TProgramVisitor<ProgramVisitorTypes>;
extern template class TProgramVisitor<ProgramWriterTypes>;
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
class ProgramVisitor : public TProgramVisitor<const Program&,
const Expression&,
const Statement&,
const ProgramElement&> {
class ProgramVisitor : public TProgramVisitor<ProgramVisitorTypes> {
public:
bool visit(const Program& program);
private:
// ProgramVisitors shouldn't need access to unique_ptrs, and marking these as final should help
// these accessors inline away. Use ProgramWriter if you need the unique_ptrs.
bool visitExpressionPtr(const std::unique_ptr<Expression>& e) final {
return this->visitExpression(*e);
}
bool visitStatementPtr(const std::unique_ptr<Statement>& s) final {
return this->visitStatement(*s);
}
};
using ProgramWriter = TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
class ProgramWriter : public TProgramVisitor<ProgramWriterTypes> {
public:
// Subclass these methods if you want access to the unique_ptrs of IRNodes in a program.
// This will allow statements or expressions to be replaced during a visit.
bool visitExpressionPtr(std::unique_ptr<Expression>& e) override {
return this->visitExpression(*e);
}
bool visitStatementPtr(std::unique_ptr<Statement>& s) override {
return this->visitStatement(*s);
}
};
} // namespace SkSL