From 3dc0da6c6d7ed99fa82c28d02b7bcdc08110fa54 Mon Sep 17 00:00:00 2001 From: John Stiles Date: Wed, 19 Aug 2020 17:48:31 -0400 Subject: [PATCH] Add as to downcast ProgramElements more safely. The as() function asserts that the ProgramElement is of the correct kind before performing the downcast, and is also generally easier to read as function calls flow naturally from left-to-right, and C-style casts don't. This CL updates several downcasts throughout SkSL to the as syntax, but is not intended to exhaustively replace them all (although that would be ideal). In places where we SkASSERTed the element's fKind immediately before a cast, the assert has been removed because it would be redundant with the behavior of as(). Change-Id: I89a487aeaf56e56c720479fee0c2633377a202f1 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/312020 Commit-Queue: Ethan Nicholas Reviewed-by: Ethan Nicholas Auto-Submit: John Stiles --- src/sksl/SkSLAnalysis.cpp | 6 +-- src/sksl/SkSLByteCodeGenerator.cpp | 9 ++-- src/sksl/SkSLCPPCodeGenerator.cpp | 44 +++++++-------- src/sksl/SkSLCompiler.cpp | 10 ++-- src/sksl/SkSLDehydrator.cpp | 8 +-- src/sksl/SkSLGLSLCodeGenerator.cpp | 20 +++---- src/sksl/SkSLMetalCodeGenerator.cpp | 74 +++++++++++++------------- src/sksl/ir/SkSLEnum.h | 4 +- src/sksl/ir/SkSLExtension.h | 4 +- src/sksl/ir/SkSLFunctionDefinition.h | 4 +- src/sksl/ir/SkSLInterfaceBlock.h | 4 +- src/sksl/ir/SkSLModifiersDeclaration.h | 4 +- src/sksl/ir/SkSLProgramElement.h | 15 ++++++ src/sksl/ir/SkSLSection.h | 4 +- src/sksl/ir/SkSLVarDeclarations.h | 4 +- 15 files changed, 122 insertions(+), 92 deletions(-) diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp index ca092ebfd8..2b83e31044 100644 --- a/src/sksl/SkSLAnalysis.cpp +++ b/src/sksl/SkSLAnalysis.cpp @@ -316,14 +316,14 @@ bool ProgramVisitor::visitProgramElement(const ProgramElement& pe) { // Leaf program elements just return false by default return false; case ProgramElement::kFunction_Kind: - return this->visitStatement(*((const FunctionDefinition&) pe).fBody); + return this->visitStatement(*pe.as().fBody); case ProgramElement::kInterfaceBlock_Kind: - for (const auto& e : ((const InterfaceBlock&) pe).fSizes) { + for (const auto& e : pe.as().fSizes) { if (this->visitExpression(*e)) { return true; } } return false; case ProgramElement::kVar_Kind: - for (const auto& v : ((const VarDeclarations&) pe).fVars) { + for (const auto& v : pe.as().fVars) { if (this->visitStatement(*v)) { return true; } } return false; diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp index da59730c77..86c874ebe8 100644 --- a/src/sksl/SkSLByteCodeGenerator.cpp +++ b/src/sksl/SkSLByteCodeGenerator.cpp @@ -142,18 +142,19 @@ bool ByteCodeGenerator::generateCode() { for (const auto& e : fProgram) { switch (e.fKind) { case ProgramElement::kFunction_Kind: { - std::unique_ptr f = this->writeFunction((FunctionDefinition&) e); + std::unique_ptr f = + this->writeFunction(e.as()); if (!f) { return false; } fOutput->fFunctions.push_back(std::move(f)); - fFunctions.push_back(&(FunctionDefinition&)e); + fFunctions.push_back(&e.as()); break; } case ProgramElement::kVar_Kind: { - VarDeclarations& decl = (VarDeclarations&) e; + const VarDeclarations& decl = e.as(); for (const auto& v : decl.fVars) { - const Variable* declVar = ((VarDeclaration&) *v).fVar; + const Variable* declVar = v->as().fVar; if (declVar->fType == *fContext.fFragmentProcessor_Type) { fOutput->fChildFPCount++; } diff --git a/src/sksl/SkSLCPPCodeGenerator.cpp b/src/sksl/SkSLCPPCodeGenerator.cpp index 5244a91418..67cbe36ee7 100644 --- a/src/sksl/SkSLCPPCodeGenerator.cpp +++ b/src/sksl/SkSLCPPCodeGenerator.cpp @@ -113,7 +113,7 @@ void CPPCodeGenerator::writeBinaryExpression(const BinaryExpression& b, void CPPCodeGenerator::writeIndexExpression(const IndexExpression& i) { const Expression& base = *i.fBase; if (base.fKind == Expression::kVariableReference_Kind) { - int builtin = ((VariableReference&) base).fVariable.fModifiers.fLayout.fBuiltin; + int builtin = base.as().fVariable.fModifiers.fLayout.fBuiltin; if (SK_TEXTURESAMPLERS_BUILTIN == builtin) { this->write("%s"); if (i.fIndex->fKind != Expression::kIntLiteral_Kind) { @@ -385,7 +385,7 @@ void CPPCodeGenerator::writeFieldAccess(const FieldAccess& access) { } const Type::Field& field = fContext.fFragmentProcessor_Type->fields()[access.fFieldIndex]; - const Variable& var = ((const VariableReference&) *access.fBase).fVariable; + const Variable& var = access.fBase->as().fVariable; String cppAccess = String::printf("_outer.childProcessor(%d)->%s()", this->getChildFPIndex(var), String(field.fName).c_str()); @@ -405,9 +405,9 @@ int CPPCodeGenerator::getChildFPIndex(const Variable& var) const { bool found = false; for (const auto& p : fProgram) { if (ProgramElement::kVar_Kind == p.fKind) { - const VarDeclarations& decls = (const VarDeclarations&) p; + const VarDeclarations& decls = p.as(); for (const auto& raw : decls.fVars) { - const VarDeclaration& decl = (VarDeclaration&) *raw; + const VarDeclaration& decl = raw->as(); if (decl.fVar == &var) { found = true; } else if (decl.fVar->fType.nonnullable() == *fContext.fFragmentProcessor_Type) { @@ -439,7 +439,7 @@ void CPPCodeGenerator::writeFunctionCall(const FunctionCall& c) { "sample()'s fragmentProcessor argument must be a variable reference\n"); return; } - const Variable& child = ((const VariableReference&) *c.fArguments[0]).fVariable; + const Variable& child = c.fArguments[0]->as().fVariable; // Start a new extra emit code section so that the emitted child processor can depend on // sksl variables defined in earlier sksl code. @@ -508,7 +508,7 @@ void CPPCodeGenerator::writeFunctionCall(const FunctionCall& c) { this->write(".%s"); SkASSERT(c.fArguments.size() >= 1); SkASSERT(c.fArguments[0]->fKind == Expression::kVariableReference_Kind); - String sampler = this->getSamplerHandle(((VariableReference&) *c.fArguments[0]).fVariable); + String sampler = this->getSamplerHandle(c.fArguments[0]->as().fVariable); fFormatArgs.push_back("fragBuilder->getProgramBuilder()->samplerSwizzle(" + sampler + ").asString().c_str()"); } @@ -595,7 +595,7 @@ void CPPCodeGenerator::writeFunction(const FunctionDefinition& f) { fOut = &buffer; if (decl.fName == "main") { fInMain = true; - for (const auto& s : ((Block&) *f.fBody).fStatements) { + for (const auto& s : f.fBody->as().fStatements) { this->writeStatement(*s); this->writeLine(); } @@ -615,7 +615,7 @@ void CPPCodeGenerator::writeFunction(const FunctionDefinition& f) { } args += "};"; this->addExtraEmitCodeLine(args.c_str()); - for (const auto& s : ((Block&) *f.fBody).fStatements) { + for (const auto& s : f.fBody->as().fStatements) { this->writeStatement(*s); this->writeLine(); } @@ -650,11 +650,11 @@ void CPPCodeGenerator::writeProgramElement(const ProgramElement& p) { return; } if (p.fKind == ProgramElement::kVar_Kind) { - const VarDeclarations& decls = (const VarDeclarations&) p; + const VarDeclarations& decls = p.as(); if (!decls.fVars.size()) { return; } - const Variable& var = *((VarDeclaration&) *decls.fVars[0]).fVar; + const Variable& var = *decls.fVars[0]->as().fVar; if (var.fModifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kUniform_Flag) || -1 != var.fModifiers.fLayout.fBuiltin) { return; @@ -686,9 +686,9 @@ void CPPCodeGenerator::writeInputVars() { void CPPCodeGenerator::writePrivateVars() { for (const auto& p : fProgram) { if (ProgramElement::kVar_Kind == p.fKind) { - const VarDeclarations& decls = (const VarDeclarations&) p; + const VarDeclarations& decls = p.as(); for (const auto& raw : decls.fVars) { - VarDeclaration& decl = (VarDeclaration&) *raw; + VarDeclaration& decl = raw->as(); if (is_private(*decl.fVar)) { if (decl.fVar->fType == *fContext.fFragmentProcessor_Type) { fErrors.error(decl.fOffset, @@ -726,9 +726,9 @@ void CPPCodeGenerator::writePrivateVars() { void CPPCodeGenerator::writePrivateVarValues() { for (const auto& p : fProgram) { if (ProgramElement::kVar_Kind == p.fKind) { - const VarDeclarations& decls = (const VarDeclarations&) p; + const VarDeclarations& decls = p.as(); for (const auto& raw : decls.fVars) { - VarDeclaration& decl = (VarDeclaration&) *raw; + VarDeclaration& decl = raw->as(); if (is_private(*decl.fVar) && decl.fValue) { this->writef("%s = ", String(decl.fVar->fName).c_str()); fCPPMode = true; @@ -943,9 +943,9 @@ bool CPPCodeGenerator::writeEmitCode(std::vector& uniforms) { fFullName.c_str(), fFullName.c_str()); for (const auto& p : fProgram) { if (ProgramElement::kVar_Kind == p.fKind) { - const VarDeclarations& decls = (const VarDeclarations&) p; + const VarDeclarations& decls = p.as(); for (const auto& raw : decls.fVars) { - VarDeclaration& decl = (VarDeclaration&) *raw; + VarDeclaration& decl = raw->as(); String nameString(decl.fVar->fName); const char* name = nameString.c_str(); if (SectionAndParameterHelper::IsParameter(*decl.fVar) && @@ -1059,9 +1059,9 @@ void CPPCodeGenerator::writeSetData(std::vector& uniforms) { int samplerIndex = 0; for (const auto& p : fProgram) { if (ProgramElement::kVar_Kind == p.fKind) { - const VarDeclarations& decls = (const VarDeclarations&) p; + const VarDeclarations& decls = p.as(); for (const std::unique_ptr& raw : decls.fVars) { - const VarDeclaration& decl = static_cast(*raw); + const VarDeclaration& decl = raw->as(); const Variable& variable = *decl.fVar; String nameString(variable.fName); const char* name = nameString.c_str(); @@ -1233,9 +1233,9 @@ void CPPCodeGenerator::writeGetKey() { fFullName.c_str()); for (const auto& p : fProgram) { if (ProgramElement::kVar_Kind == p.fKind) { - const VarDeclarations& decls = (const VarDeclarations&) p; + const VarDeclarations& decls = p.as(); for (const auto& raw : decls.fVars) { - const VarDeclaration& decl = (VarDeclaration&) *raw; + const VarDeclaration& decl = raw->as(); const Variable& var = *decl.fVar; String nameString(var.fName); const char* name = nameString.c_str(); @@ -1311,9 +1311,9 @@ bool CPPCodeGenerator::generateCode() { std::vector uniforms; for (const auto& p : fProgram) { if (ProgramElement::kVar_Kind == p.fKind) { - const VarDeclarations& decls = (const VarDeclarations&) p; + const VarDeclarations& decls = p.as(); for (const auto& raw : decls.fVars) { - VarDeclaration& decl = (VarDeclaration&) *raw; + VarDeclaration& decl = raw->as(); if ((decl.fVar->fModifiers.fFlags & Modifiers::kUniform_Flag) && decl.fVar->fType.kind() != Type::kSampler_Kind) { uniforms.push_back(decl.fVar); diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp index b133fbfb3b..e3352feba2 100644 --- a/src/sksl/SkSLCompiler.cpp +++ b/src/sksl/SkSLCompiler.cpp @@ -75,7 +75,7 @@ static void grab_intrinsics(std::vector>* src, std::unique_ptr& element = *iter; switch (element->fKind) { case ProgramElement::kFunction_Kind: { - FunctionDefinition& f = (FunctionDefinition&) *element; + FunctionDefinition& f = element->as(); SkASSERT(f.fDeclaration.fBuiltin); String key = f.fDeclaration.description(); SkASSERT(target->find(key) == target->end()); @@ -84,7 +84,7 @@ static void grab_intrinsics(std::vector>* src, break; } case ProgramElement::kEnum_Kind: { - Enum& e = (Enum&) *element; + Enum& e = element->as(); StringFragment name = e.fTypeName; SkASSERT(target->find(name) == target->end()); (*target)[name] = std::make_pair(std::move(element), false); @@ -1637,7 +1637,7 @@ bool Compiler::optimize(Program& program) { fIRGenerator->fSettings = &program.fSettings; for (auto& element : program) { if (element.fKind == ProgramElement::kFunction_Kind) { - this->scanCFG((FunctionDefinition&) element); + this->scanCFG(element.as()); } } // we wait until after analysis to remove dead functions so that we still report errors @@ -1645,7 +1645,7 @@ bool Compiler::optimize(Program& program) { if (program.fSettings.fRemoveDeadFunctions) { for (auto iter = program.fElements.begin(); iter != program.fElements.end(); ) { if ((*iter)->fKind == ProgramElement::kFunction_Kind) { - const FunctionDefinition& f = (const FunctionDefinition&) **iter; + const FunctionDefinition& f = (*iter)->as(); if (!f.fDeclaration.fCallCount && f.fDeclaration.fName != "main") { iter = program.fElements.erase(iter); continue; @@ -1657,7 +1657,7 @@ bool Compiler::optimize(Program& program) { if (program.fKind != Program::kFragmentProcessor_Kind) { for (auto iter = program.fElements.begin(); iter != program.fElements.end();) { if ((*iter)->fKind == ProgramElement::kVar_Kind) { - VarDeclarations& vars = (VarDeclarations&) **iter; + VarDeclarations& vars = (*iter)->as(); for (auto varIter = vars.fVars.begin(); varIter != vars.fVars.end();) { const Variable& var = *((VarDeclaration&) **varIter).fVar; if (var.dead()) { diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp index 290805555c..a47d6fe2dd 100644 --- a/src/sksl/SkSLDehydrator.cpp +++ b/src/sksl/SkSLDehydrator.cpp @@ -496,7 +496,7 @@ void Dehydrator::write(const Statement* s) { void Dehydrator::write(const ProgramElement& e) { switch (e.fKind) { case ProgramElement::kEnum_Kind: { - Enum& en = (Enum&) e; + const Enum& en = e.as(); this->writeU8(Rehydrator::kEnum_Command); this->write(en.fTypeName); AutoDehydratorSymbolTable symbols(this, en.fSymbols); @@ -513,7 +513,7 @@ void Dehydrator::write(const ProgramElement& e) { SkASSERT(false); break; case ProgramElement::kFunction_Kind: { - FunctionDefinition& f = (FunctionDefinition&) e; + const FunctionDefinition& f = e.as(); this->writeU8(Rehydrator::kFunctionDefinition_Command); this->writeU16(this->symbolId(&f.fDeclaration)); this->write(f.fBody.get()); @@ -528,7 +528,7 @@ void Dehydrator::write(const ProgramElement& e) { break; } case ProgramElement::kInterfaceBlock_Kind: { - InterfaceBlock& i = (InterfaceBlock&) e; + const InterfaceBlock& i = e.as(); this->writeU8(Rehydrator::kInterfaceBlock_Command); this->write(i.fVariable); this->write(i.fTypeName); @@ -546,7 +546,7 @@ void Dehydrator::write(const ProgramElement& e) { SkASSERT(false); break; case ProgramElement::kVar_Kind: { - VarDeclarations& v = (VarDeclarations&) e; + const VarDeclarations& v = e.as(); this->writeU8(Rehydrator::kVarDeclarations_Command); this->write(v.fBaseType); this->writeU8(v.fVars.size()); diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp index 522b2febef..e1ced4cda8 100644 --- a/src/sksl/SkSLGLSLCodeGenerator.cpp +++ b/src/sksl/SkSLGLSLCodeGenerator.cpp @@ -1245,7 +1245,7 @@ void GLSLCodeGenerator::writeFunction(const FunctionDefinition& f) { OutputStream* oldOut = fOut; StringStream buffer; fOut = &buffer; - this->writeStatements(((Block&) *f.fBody).fStatements); + this->writeStatements(f.fBody->as().fStatements); if (fProgramKind != Program::kPipelineStage_Kind) { fIndentation--; this->writeLine("}"); @@ -1416,7 +1416,7 @@ void GLSLCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool g } bool wroteType = false; for (const auto& stmt : decl.fVars) { - VarDeclaration& var = (VarDeclaration&) *stmt; + const VarDeclaration& var = stmt->as(); if (wroteType) { this->write(", "); } else { @@ -1460,7 +1460,7 @@ void GLSLCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool g void GLSLCodeGenerator::writeStatement(const Statement& s) { switch (s.fKind) { case Statement::kBlock_Kind: - this->writeBlock((Block&) s); + this->writeBlock(s.as()); break; case Statement::kExpression_Kind: this->writeExpression(*s.as().fExpression, kTopLevel_Precedence); @@ -1666,19 +1666,19 @@ void GLSLCodeGenerator::writeHeader() { void GLSLCodeGenerator::writeProgramElement(const ProgramElement& e) { switch (e.fKind) { case ProgramElement::kExtension_Kind: - this->writeExtension(((Extension&) e).fName); + this->writeExtension(e.as().fName); break; case ProgramElement::kVar_Kind: { - VarDeclarations& decl = (VarDeclarations&) e; + const VarDeclarations& decl = e.as(); if (decl.fVars.size() > 0) { - int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin; + int builtin = decl.fVars[0]->as().fVar->fModifiers.fLayout.fBuiltin; if (builtin == -1) { // normal var this->writeVarDeclarations(decl, true); this->writeLine(); } else if (builtin == SK_FRAGCOLOR_BUILTIN && fProgram.fSettings.fCaps->mustDeclareFragmentShaderOutput() && - ((VarDeclaration&) *decl.fVars[0]).fVar->fWriteCount) { + decl.fVars[0]->as().fVar->fWriteCount) { if (fProgram.fSettings.fFragColorIsInOut) { this->write("inout "); } else { @@ -1693,13 +1693,13 @@ void GLSLCodeGenerator::writeProgramElement(const ProgramElement& e) { break; } case ProgramElement::kInterfaceBlock_Kind: - this->writeInterfaceBlock((InterfaceBlock&) e); + this->writeInterfaceBlock(e.as()); break; case ProgramElement::kFunction_Kind: - this->writeFunction((FunctionDefinition&) e); + this->writeFunction(e.as()); break; case ProgramElement::kModifiers_Kind: { - const Modifiers& modifiers = ((ModifiersDeclaration&) e).fModifiers; + const Modifiers& modifiers = e.as().fModifiers; if (!fFoundGSInvocations && modifiers.fLayout.fInvocations >= 0) { if (fProgram.fSettings.fCaps->gsInvocationsExtensionString()) { this->writeExtension(fProgram.fSettings.fCaps->gsInvocationsExtensionString()); diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp index 9fdee068f1..fa5d17489c 100644 --- a/src/sksl/SkSLMetalCodeGenerator.cpp +++ b/src/sksl/SkSLMetalCodeGenerator.cpp @@ -952,12 +952,12 @@ void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) { } for (const auto& e : fProgram) { if (ProgramElement::kVar_Kind == e.fKind) { - VarDeclarations& decls = (VarDeclarations&) e; + const VarDeclarations& decls = e.as(); if (!decls.fVars.size()) { continue; } for (const auto& stmt: decls.fVars) { - VarDeclaration& var = (VarDeclaration&) *stmt; + VarDeclaration& var = stmt->as(); if (var.fVar->fType.kind() == Type::kSampler_Kind) { if (var.fVar->fModifiers.fLayout.fBinding < 0) { fErrors.error(decls.fOffset, @@ -1392,11 +1392,11 @@ void MetalCodeGenerator::writeHeader() { void MetalCodeGenerator::writeUniformStruct() { for (const auto& e : fProgram) { if (ProgramElement::kVar_Kind == e.fKind) { - VarDeclarations& decls = (VarDeclarations&) e; + const VarDeclarations& decls = e.as(); if (!decls.fVars.size()) { continue; } - const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; + const Variable& first = *decls.fVars[0]->as().fVar; if (first.fModifiers.fFlags & Modifiers::kUniform_Flag && first.fType.kind() != Type::kSampler_Kind) { if (-1 == fUniformBuffer) { @@ -1415,7 +1415,7 @@ void MetalCodeGenerator::writeUniformStruct() { this->writeType(first.fType); this->write(" "); for (const auto& stmt : decls.fVars) { - VarDeclaration& var = (VarDeclaration&) *stmt; + const VarDeclaration& var = stmt->as(); this->writeName(var.fVar->fName); } this->write(";\n"); @@ -1431,18 +1431,18 @@ void MetalCodeGenerator::writeInputStruct() { this->write("struct Inputs {\n"); for (const auto& e : fProgram) { if (ProgramElement::kVar_Kind == e.fKind) { - VarDeclarations& decls = (VarDeclarations&) e; + const VarDeclarations& decls = e.as(); if (!decls.fVars.size()) { continue; } - const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; + const Variable& first = *decls.fVars[0]->as().fVar; if (first.fModifiers.fFlags & Modifiers::kIn_Flag && -1 == first.fModifiers.fLayout.fBuiltin) { this->write(" "); this->writeType(first.fType); this->write(" "); for (const auto& stmt : decls.fVars) { - VarDeclaration& var = (VarDeclaration&) *stmt; + const VarDeclaration& var = stmt->as(); this->writeName(var.fVar->fName); if (-1 != var.fVar->fModifiers.fLayout.fLocation) { if (fProgram.fKind == Program::kVertex_Kind) { @@ -1470,18 +1470,18 @@ void MetalCodeGenerator::writeOutputStruct() { } for (const auto& e : fProgram) { if (ProgramElement::kVar_Kind == e.fKind) { - VarDeclarations& decls = (VarDeclarations&) e; + const VarDeclarations& decls = e.as(); if (!decls.fVars.size()) { continue; } - const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; + const Variable& first = *decls.fVars[0]->as().fVar; if (first.fModifiers.fFlags & Modifiers::kOut_Flag && -1 == first.fModifiers.fLayout.fBuiltin) { this->write(" "); this->writeType(first.fType); this->write(" "); for (const auto& stmt : decls.fVars) { - VarDeclaration& var = (VarDeclaration&) *stmt; + const VarDeclaration& var = stmt->as(); this->writeName(var.fVar->fName); if (fProgram.fKind == Program::kVertex_Kind) { this->write(" [[user(locn" + @@ -1510,7 +1510,7 @@ void MetalCodeGenerator::writeInterfaceBlocks() { bool wroteInterfaceBlock = false; for (const auto& e : fProgram) { if (ProgramElement::kInterfaceBlock_Kind == e.fKind) { - this->writeInterfaceBlock((InterfaceBlock&) e); + this->writeInterfaceBlock(e.as()); wroteInterfaceBlock = true; } } @@ -1662,9 +1662,9 @@ void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) { case ProgramElement::kExtension_Kind: break; case ProgramElement::kVar_Kind: { - VarDeclarations& decl = (VarDeclarations&) e; + const VarDeclarations& decl = e.as(); if (decl.fVars.size() > 0) { - int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin; + int builtin = decl.fVars[0]->as().fVar->fModifiers.fLayout.fBuiltin; if (-1 == builtin) { // normal var this->writeVarDeclarations(decl, true); @@ -1679,10 +1679,10 @@ void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) { // handled in writeInterfaceBlocks, do nothing break; case ProgramElement::kFunction_Kind: - this->writeFunction((FunctionDefinition&) e); + this->writeFunction(e.as()); break; case ProgramElement::kModifiers_Kind: - this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true); + this->writeModifiers(e.as().fModifiers, true); this->writeLine(";"); break; default: @@ -1699,7 +1699,7 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expressi } switch (e->fKind) { case Expression::kFunctionCall_Kind: { - const FunctionCall& f = (const FunctionCall&) *e; + const FunctionCall& f = e->as(); Requirements result = this->requirements(f.fFunction); for (const auto& arg : f.fArguments) { result |= this->requirements(arg.get()); @@ -1707,7 +1707,7 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expressi return result; } case Expression::kConstructor_Kind: { - const Constructor& c = (const Constructor&) *e; + const Constructor& c = e->as(); Requirements result = kNo_Requirements; for (const auto& arg : c.fArguments) { result |= this->requirements(arg.get()); @@ -1715,33 +1715,33 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expressi return result; } case Expression::kFieldAccess_Kind: { - const FieldAccess& f = (const FieldAccess&) *e; + const FieldAccess& f = e->as(); if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) { return kGlobals_Requirement; } return this->requirements(f.fBase.get()); } case Expression::kSwizzle_Kind: - return this->requirements(((const Swizzle&) *e).fBase.get()); + return this->requirements(e->as().fBase.get()); case Expression::kBinary_Kind: { - const BinaryExpression& b = (const BinaryExpression&) *e; + const BinaryExpression& b = e->as(); return this->requirements(b.fLeft.get()) | this->requirements(b.fRight.get()); } case Expression::kIndex_Kind: { - const IndexExpression& idx = (const IndexExpression&) *e; + const IndexExpression& idx = e->as(); return this->requirements(idx.fBase.get()) | this->requirements(idx.fIndex.get()); } case Expression::kPrefix_Kind: - return this->requirements(((const PrefixExpression&) *e).fOperand.get()); + return this->requirements(e->as().fOperand.get()); case Expression::kPostfix_Kind: - return this->requirements(((const PostfixExpression&) *e).fOperand.get()); + return this->requirements(e->as().fOperand.get()); case Expression::kTernary_Kind: { - const TernaryExpression& t = (const TernaryExpression&) *e; + const TernaryExpression& t = e->as(); return this->requirements(t.fTest.get()) | this->requirements(t.fIfTrue.get()) | this->requirements(t.fIfFalse.get()); } case Expression::kVariableReference_Kind: { - const VariableReference& v = (const VariableReference&) *e; + const VariableReference& v = e->as(); Requirements result = kNo_Requirements; if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) { result = kGlobals_Requirement | kFragCoord_Requirement; @@ -1771,54 +1771,54 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statemen switch (s->fKind) { case Statement::kBlock_Kind: { Requirements result = kNo_Requirements; - for (const auto& child : ((const Block*) s)->fStatements) { + for (const auto& child : s->as().fStatements) { result |= this->requirements(child.get()); } return result; } case Statement::kVarDeclaration_Kind: { - const VarDeclaration& var = (const VarDeclaration&) *s; + const VarDeclaration& var = s->as(); return this->requirements(var.fValue.get()); } case Statement::kVarDeclarations_Kind: { Requirements result = kNo_Requirements; - const VarDeclarations& decls = *((const VarDeclarationsStatement&) *s).fDeclaration; + const VarDeclarations& decls = *s->as().fDeclaration; for (const auto& stmt : decls.fVars) { result |= this->requirements(stmt.get()); } return result; } case Statement::kExpression_Kind: - return this->requirements(((const ExpressionStatement&) *s).fExpression.get()); + return this->requirements(s->as().fExpression.get()); case Statement::kReturn_Kind: { - const ReturnStatement& r = (const ReturnStatement&) *s; + const ReturnStatement& r = s->as(); return this->requirements(r.fExpression.get()); } case Statement::kIf_Kind: { - const IfStatement& i = (const IfStatement&) *s; + const IfStatement& i = s->as(); return this->requirements(i.fTest.get()) | this->requirements(i.fIfTrue.get()) | this->requirements(i.fIfFalse.get()); } case Statement::kFor_Kind: { - const ForStatement& f = (const ForStatement&) *s; + const ForStatement& f = s->as(); return this->requirements(f.fInitializer.get()) | this->requirements(f.fTest.get()) | this->requirements(f.fNext.get()) | this->requirements(f.fStatement.get()); } case Statement::kWhile_Kind: { - const WhileStatement& w = (const WhileStatement&) *s; + const WhileStatement& w = s->as(); return this->requirements(w.fTest.get()) | this->requirements(w.fStatement.get()); } case Statement::kDo_Kind: { - const DoStatement& d = (const DoStatement&) *s; + const DoStatement& d = s->as(); return this->requirements(d.fTest.get()) | this->requirements(d.fStatement.get()); } case Statement::kSwitch_Kind: { - const SwitchStatement& sw = (const SwitchStatement&) *s; + const SwitchStatement& sw = s->as(); Requirements result = this->requirements(sw.fValue.get()); for (const auto& c : sw.fCases) { for (const auto& st : c->fStatements) { @@ -1841,7 +1841,7 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Function fRequirements[&f] = kNo_Requirements; for (const auto& e : fProgram) { if (ProgramElement::kFunction_Kind == e.fKind) { - const FunctionDefinition& def = (const FunctionDefinition&) e; + const FunctionDefinition& def = e.as(); if (&def.fDeclaration == &f) { Requirements reqs = this->requirements(def.fBody.get()); fRequirements[&f] = reqs; diff --git a/src/sksl/ir/SkSLEnum.h b/src/sksl/ir/SkSLEnum.h index 3b3e3bd3fd..feb30249e9 100644 --- a/src/sksl/ir/SkSLEnum.h +++ b/src/sksl/ir/SkSLEnum.h @@ -21,9 +21,11 @@ namespace SkSL { struct Symbol; struct Enum : public ProgramElement { + static constexpr Kind kProgramElementKind = kEnum_Kind; + Enum(int offset, StringFragment typeName, std::shared_ptr symbols, bool isBuiltin = true) - : INHERITED(offset, kEnum_Kind) + : INHERITED(offset, kProgramElementKind) , fTypeName(typeName) , fSymbols(std::move(symbols)) , fBuiltin(isBuiltin) {} diff --git a/src/sksl/ir/SkSLExtension.h b/src/sksl/ir/SkSLExtension.h index 170adeeea9..e0d418b176 100644 --- a/src/sksl/ir/SkSLExtension.h +++ b/src/sksl/ir/SkSLExtension.h @@ -16,8 +16,10 @@ namespace SkSL { * An extension declaration. */ struct Extension : public ProgramElement { + static constexpr Kind kProgramElementKind = kExtension_Kind; + Extension(int offset, String name) - : INHERITED(offset, kExtension_Kind) + : INHERITED(offset, kProgramElementKind) , fName(std::move(name)) {} std::unique_ptr clone() const override { diff --git a/src/sksl/ir/SkSLFunctionDefinition.h b/src/sksl/ir/SkSLFunctionDefinition.h index df65a1d4d6..323169c3fb 100644 --- a/src/sksl/ir/SkSLFunctionDefinition.h +++ b/src/sksl/ir/SkSLFunctionDefinition.h @@ -22,11 +22,13 @@ struct ASTNode; * A function definition (a declaration plus an associated block of code). */ struct FunctionDefinition : public ProgramElement { + static constexpr Kind kProgramElementKind = kFunction_Kind; + FunctionDefinition(int offset, const FunctionDeclaration& declaration, std::unique_ptr body, std::unordered_set referencedIntrinsics = {}) - : INHERITED(offset, kFunction_Kind) + : INHERITED(offset, kProgramElementKind) , fDeclaration(declaration) , fBody(std::move(body)) , fReferencedIntrinsics(std::move(referencedIntrinsics)) {} diff --git a/src/sksl/ir/SkSLInterfaceBlock.h b/src/sksl/ir/SkSLInterfaceBlock.h index ec2875ef03..56b7e2f5b8 100644 --- a/src/sksl/ir/SkSLInterfaceBlock.h +++ b/src/sksl/ir/SkSLInterfaceBlock.h @@ -25,10 +25,12 @@ namespace SkSL { * At the IR level, this is represented by a single variable of struct type. */ struct InterfaceBlock : public ProgramElement { + static constexpr Kind kProgramElementKind = kInterfaceBlock_Kind; + InterfaceBlock(int offset, const Variable* var, String typeName, String instanceName, std::vector> sizes, std::shared_ptr typeOwner) - : INHERITED(offset, kInterfaceBlock_Kind) + : INHERITED(offset, kProgramElementKind) , fVariable(*var) , fTypeName(std::move(typeName)) , fInstanceName(std::move(instanceName)) diff --git a/src/sksl/ir/SkSLModifiersDeclaration.h b/src/sksl/ir/SkSLModifiersDeclaration.h index 2e7c05db0e..411b4b97cb 100644 --- a/src/sksl/ir/SkSLModifiersDeclaration.h +++ b/src/sksl/ir/SkSLModifiersDeclaration.h @@ -19,8 +19,10 @@ namespace SkSL { * layout(blend_support_all_equations) out; */ struct ModifiersDeclaration : public ProgramElement { + static constexpr Kind kProgramElementKind = kModifiers_Kind; + ModifiersDeclaration(Modifiers modifiers) - : INHERITED(-1, kModifiers_Kind) + : INHERITED(-1, kProgramElementKind) , fModifiers(modifiers) {} std::unique_ptr clone() const override { diff --git a/src/sksl/ir/SkSLProgramElement.h b/src/sksl/ir/SkSLProgramElement.h index 0b60d95c20..74d2f4f8b2 100644 --- a/src/sksl/ir/SkSLProgramElement.h +++ b/src/sksl/ir/SkSLProgramElement.h @@ -32,6 +32,21 @@ struct ProgramElement : public IRNode { : INHERITED(offset) , fKind(kind) {} + /** + * Use as to downcast program elements. e.g. replace `(Enum&) el` with `el.as()`. + */ + template + const T& as() const { + SkASSERT(this->fKind == T::kProgramElementKind); + return static_cast(*this); + } + + template + T& as() { + SkASSERT(this->fKind == T::kProgramElementKind); + return static_cast(*this); + } + Kind fKind; virtual std::unique_ptr clone() const = 0; diff --git a/src/sksl/ir/SkSLSection.h b/src/sksl/ir/SkSLSection.h index f3828ff5b8..d3be6468b1 100644 --- a/src/sksl/ir/SkSLSection.h +++ b/src/sksl/ir/SkSLSection.h @@ -16,8 +16,10 @@ namespace SkSL { * A section declaration (e.g. @body { body code here }).. */ struct Section : public ProgramElement { + static constexpr Kind kProgramElementKind = kSection_Kind; + Section(int offset, String name, String arg, String text) - : INHERITED(offset, kSection_Kind) + : INHERITED(offset, kProgramElementKind) , fName(std::move(name)) , fArgument(std::move(arg)) , fText(std::move(text)) {} diff --git a/src/sksl/ir/SkSLVarDeclarations.h b/src/sksl/ir/SkSLVarDeclarations.h index 6ccf942607..1556a13fa9 100644 --- a/src/sksl/ir/SkSLVarDeclarations.h +++ b/src/sksl/ir/SkSLVarDeclarations.h @@ -83,9 +83,11 @@ struct VarDeclaration : public Statement { * A variable declaration statement, which may consist of one or more individual variables. */ struct VarDeclarations : public ProgramElement { + static constexpr Kind kProgramElementKind = kVar_Kind; + VarDeclarations(int offset, const Type* baseType, std::vector> vars) - : INHERITED(offset, kVar_Kind) + : INHERITED(offset, kProgramElementKind) , fBaseType(*baseType) { for (auto& var : vars) { fVars.push_back(std::unique_ptr(var.release()));