Add as<ProgramElementSubclass> to downcast ProgramElements more safely.

The as<T>() function asserts that the ProgramElement is of the correct
kind before performing the downcast, and is also generally easier to
read as function calls flow naturally from left-to-right, and C-style
casts don't.

This CL updates several downcasts throughout SkSL to the as<T>
syntax, but is not intended to exhaustively replace them all (although
that would be ideal). In places where we SkASSERTed the element's
fKind immediately before a cast, the assert has been removed because it
would be redundant with the behavior of as<T>().

Change-Id: I89a487aeaf56e56c720479fee0c2633377a202f1
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/312020
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
This commit is contained in:
John Stiles 2020-08-19 17:48:31 -04:00 committed by Skia Commit-Bot
parent f3c8f5df31
commit 3dc0da6c6d
15 changed files with 122 additions and 92 deletions

View File

@ -316,14 +316,14 @@ bool ProgramVisitor::visitProgramElement(const ProgramElement& pe) {
// Leaf program elements just return false by default
return false;
case ProgramElement::kFunction_Kind:
return this->visitStatement(*((const FunctionDefinition&) pe).fBody);
return this->visitStatement(*pe.as<FunctionDefinition>().fBody);
case ProgramElement::kInterfaceBlock_Kind:
for (const auto& e : ((const InterfaceBlock&) pe).fSizes) {
for (const auto& e : pe.as<InterfaceBlock>().fSizes) {
if (this->visitExpression(*e)) { return true; }
}
return false;
case ProgramElement::kVar_Kind:
for (const auto& v : ((const VarDeclarations&) pe).fVars) {
for (const auto& v : pe.as<VarDeclarations>().fVars) {
if (this->visitStatement(*v)) { return true; }
}
return false;

View File

@ -142,18 +142,19 @@ bool ByteCodeGenerator::generateCode() {
for (const auto& e : fProgram) {
switch (e.fKind) {
case ProgramElement::kFunction_Kind: {
std::unique_ptr<ByteCodeFunction> f = this->writeFunction((FunctionDefinition&) e);
std::unique_ptr<ByteCodeFunction> f =
this->writeFunction(e.as<FunctionDefinition>());
if (!f) {
return false;
}
fOutput->fFunctions.push_back(std::move(f));
fFunctions.push_back(&(FunctionDefinition&)e);
fFunctions.push_back(&e.as<FunctionDefinition>());
break;
}
case ProgramElement::kVar_Kind: {
VarDeclarations& decl = (VarDeclarations&) e;
const VarDeclarations& decl = e.as<VarDeclarations>();
for (const auto& v : decl.fVars) {
const Variable* declVar = ((VarDeclaration&) *v).fVar;
const Variable* declVar = v->as<VarDeclaration>().fVar;
if (declVar->fType == *fContext.fFragmentProcessor_Type) {
fOutput->fChildFPCount++;
}

View File

@ -113,7 +113,7 @@ void CPPCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
void CPPCodeGenerator::writeIndexExpression(const IndexExpression& i) {
const Expression& base = *i.fBase;
if (base.fKind == Expression::kVariableReference_Kind) {
int builtin = ((VariableReference&) base).fVariable.fModifiers.fLayout.fBuiltin;
int builtin = base.as<VariableReference>().fVariable.fModifiers.fLayout.fBuiltin;
if (SK_TEXTURESAMPLERS_BUILTIN == builtin) {
this->write("%s");
if (i.fIndex->fKind != Expression::kIntLiteral_Kind) {
@ -385,7 +385,7 @@ void CPPCodeGenerator::writeFieldAccess(const FieldAccess& access) {
}
const Type::Field& field = fContext.fFragmentProcessor_Type->fields()[access.fFieldIndex];
const Variable& var = ((const VariableReference&) *access.fBase).fVariable;
const Variable& var = access.fBase->as<VariableReference>().fVariable;
String cppAccess = String::printf("_outer.childProcessor(%d)->%s()",
this->getChildFPIndex(var),
String(field.fName).c_str());
@ -405,9 +405,9 @@ int CPPCodeGenerator::getChildFPIndex(const Variable& var) const {
bool found = false;
for (const auto& p : fProgram) {
if (ProgramElement::kVar_Kind == p.fKind) {
const VarDeclarations& decls = (const VarDeclarations&) p;
const VarDeclarations& decls = p.as<VarDeclarations>();
for (const auto& raw : decls.fVars) {
const VarDeclaration& decl = (VarDeclaration&) *raw;
const VarDeclaration& decl = raw->as<VarDeclaration>();
if (decl.fVar == &var) {
found = true;
} else if (decl.fVar->fType.nonnullable() == *fContext.fFragmentProcessor_Type) {
@ -439,7 +439,7 @@ void CPPCodeGenerator::writeFunctionCall(const FunctionCall& c) {
"sample()'s fragmentProcessor argument must be a variable reference\n");
return;
}
const Variable& child = ((const VariableReference&) *c.fArguments[0]).fVariable;
const Variable& child = c.fArguments[0]->as<VariableReference>().fVariable;
// Start a new extra emit code section so that the emitted child processor can depend on
// sksl variables defined in earlier sksl code.
@ -508,7 +508,7 @@ void CPPCodeGenerator::writeFunctionCall(const FunctionCall& c) {
this->write(".%s");
SkASSERT(c.fArguments.size() >= 1);
SkASSERT(c.fArguments[0]->fKind == Expression::kVariableReference_Kind);
String sampler = this->getSamplerHandle(((VariableReference&) *c.fArguments[0]).fVariable);
String sampler = this->getSamplerHandle(c.fArguments[0]->as<VariableReference>().fVariable);
fFormatArgs.push_back("fragBuilder->getProgramBuilder()->samplerSwizzle(" + sampler +
").asString().c_str()");
}
@ -595,7 +595,7 @@ void CPPCodeGenerator::writeFunction(const FunctionDefinition& f) {
fOut = &buffer;
if (decl.fName == "main") {
fInMain = true;
for (const auto& s : ((Block&) *f.fBody).fStatements) {
for (const auto& s : f.fBody->as<Block>().fStatements) {
this->writeStatement(*s);
this->writeLine();
}
@ -615,7 +615,7 @@ void CPPCodeGenerator::writeFunction(const FunctionDefinition& f) {
}
args += "};";
this->addExtraEmitCodeLine(args.c_str());
for (const auto& s : ((Block&) *f.fBody).fStatements) {
for (const auto& s : f.fBody->as<Block>().fStatements) {
this->writeStatement(*s);
this->writeLine();
}
@ -650,11 +650,11 @@ void CPPCodeGenerator::writeProgramElement(const ProgramElement& p) {
return;
}
if (p.fKind == ProgramElement::kVar_Kind) {
const VarDeclarations& decls = (const VarDeclarations&) p;
const VarDeclarations& decls = p.as<VarDeclarations>();
if (!decls.fVars.size()) {
return;
}
const Variable& var = *((VarDeclaration&) *decls.fVars[0]).fVar;
const Variable& var = *decls.fVars[0]->as<VarDeclaration>().fVar;
if (var.fModifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kUniform_Flag) ||
-1 != var.fModifiers.fLayout.fBuiltin) {
return;
@ -686,9 +686,9 @@ void CPPCodeGenerator::writeInputVars() {
void CPPCodeGenerator::writePrivateVars() {
for (const auto& p : fProgram) {
if (ProgramElement::kVar_Kind == p.fKind) {
const VarDeclarations& decls = (const VarDeclarations&) p;
const VarDeclarations& decls = p.as<VarDeclarations>();
for (const auto& raw : decls.fVars) {
VarDeclaration& decl = (VarDeclaration&) *raw;
VarDeclaration& decl = raw->as<VarDeclaration>();
if (is_private(*decl.fVar)) {
if (decl.fVar->fType == *fContext.fFragmentProcessor_Type) {
fErrors.error(decl.fOffset,
@ -726,9 +726,9 @@ void CPPCodeGenerator::writePrivateVars() {
void CPPCodeGenerator::writePrivateVarValues() {
for (const auto& p : fProgram) {
if (ProgramElement::kVar_Kind == p.fKind) {
const VarDeclarations& decls = (const VarDeclarations&) p;
const VarDeclarations& decls = p.as<VarDeclarations>();
for (const auto& raw : decls.fVars) {
VarDeclaration& decl = (VarDeclaration&) *raw;
VarDeclaration& decl = raw->as<VarDeclaration>();
if (is_private(*decl.fVar) && decl.fValue) {
this->writef("%s = ", String(decl.fVar->fName).c_str());
fCPPMode = true;
@ -943,9 +943,9 @@ bool CPPCodeGenerator::writeEmitCode(std::vector<const Variable*>& uniforms) {
fFullName.c_str(), fFullName.c_str());
for (const auto& p : fProgram) {
if (ProgramElement::kVar_Kind == p.fKind) {
const VarDeclarations& decls = (const VarDeclarations&) p;
const VarDeclarations& decls = p.as<VarDeclarations>();
for (const auto& raw : decls.fVars) {
VarDeclaration& decl = (VarDeclaration&) *raw;
VarDeclaration& decl = raw->as<VarDeclaration>();
String nameString(decl.fVar->fName);
const char* name = nameString.c_str();
if (SectionAndParameterHelper::IsParameter(*decl.fVar) &&
@ -1059,9 +1059,9 @@ void CPPCodeGenerator::writeSetData(std::vector<const Variable*>& uniforms) {
int samplerIndex = 0;
for (const auto& p : fProgram) {
if (ProgramElement::kVar_Kind == p.fKind) {
const VarDeclarations& decls = (const VarDeclarations&) p;
const VarDeclarations& decls = p.as<VarDeclarations>();
for (const std::unique_ptr<Statement>& raw : decls.fVars) {
const VarDeclaration& decl = static_cast<VarDeclaration&>(*raw);
const VarDeclaration& decl = raw->as<VarDeclaration>();
const Variable& variable = *decl.fVar;
String nameString(variable.fName);
const char* name = nameString.c_str();
@ -1233,9 +1233,9 @@ void CPPCodeGenerator::writeGetKey() {
fFullName.c_str());
for (const auto& p : fProgram) {
if (ProgramElement::kVar_Kind == p.fKind) {
const VarDeclarations& decls = (const VarDeclarations&) p;
const VarDeclarations& decls = p.as<VarDeclarations>();
for (const auto& raw : decls.fVars) {
const VarDeclaration& decl = (VarDeclaration&) *raw;
const VarDeclaration& decl = raw->as<VarDeclaration>();
const Variable& var = *decl.fVar;
String nameString(var.fName);
const char* name = nameString.c_str();
@ -1311,9 +1311,9 @@ bool CPPCodeGenerator::generateCode() {
std::vector<const Variable*> uniforms;
for (const auto& p : fProgram) {
if (ProgramElement::kVar_Kind == p.fKind) {
const VarDeclarations& decls = (const VarDeclarations&) p;
const VarDeclarations& decls = p.as<VarDeclarations>();
for (const auto& raw : decls.fVars) {
VarDeclaration& decl = (VarDeclaration&) *raw;
VarDeclaration& decl = raw->as<VarDeclaration>();
if ((decl.fVar->fModifiers.fFlags & Modifiers::kUniform_Flag) &&
decl.fVar->fType.kind() != Type::kSampler_Kind) {
uniforms.push_back(decl.fVar);

View File

@ -75,7 +75,7 @@ static void grab_intrinsics(std::vector<std::unique_ptr<ProgramElement>>* src,
std::unique_ptr<ProgramElement>& element = *iter;
switch (element->fKind) {
case ProgramElement::kFunction_Kind: {
FunctionDefinition& f = (FunctionDefinition&) *element;
FunctionDefinition& f = element->as<FunctionDefinition>();
SkASSERT(f.fDeclaration.fBuiltin);
String key = f.fDeclaration.description();
SkASSERT(target->find(key) == target->end());
@ -84,7 +84,7 @@ static void grab_intrinsics(std::vector<std::unique_ptr<ProgramElement>>* src,
break;
}
case ProgramElement::kEnum_Kind: {
Enum& e = (Enum&) *element;
Enum& e = element->as<Enum>();
StringFragment name = e.fTypeName;
SkASSERT(target->find(name) == target->end());
(*target)[name] = std::make_pair(std::move(element), false);
@ -1637,7 +1637,7 @@ bool Compiler::optimize(Program& program) {
fIRGenerator->fSettings = &program.fSettings;
for (auto& element : program) {
if (element.fKind == ProgramElement::kFunction_Kind) {
this->scanCFG((FunctionDefinition&) element);
this->scanCFG(element.as<FunctionDefinition>());
}
}
// we wait until after analysis to remove dead functions so that we still report errors
@ -1645,7 +1645,7 @@ bool Compiler::optimize(Program& program) {
if (program.fSettings.fRemoveDeadFunctions) {
for (auto iter = program.fElements.begin(); iter != program.fElements.end(); ) {
if ((*iter)->fKind == ProgramElement::kFunction_Kind) {
const FunctionDefinition& f = (const FunctionDefinition&) **iter;
const FunctionDefinition& f = (*iter)->as<FunctionDefinition>();
if (!f.fDeclaration.fCallCount && f.fDeclaration.fName != "main") {
iter = program.fElements.erase(iter);
continue;
@ -1657,7 +1657,7 @@ bool Compiler::optimize(Program& program) {
if (program.fKind != Program::kFragmentProcessor_Kind) {
for (auto iter = program.fElements.begin(); iter != program.fElements.end();) {
if ((*iter)->fKind == ProgramElement::kVar_Kind) {
VarDeclarations& vars = (VarDeclarations&) **iter;
VarDeclarations& vars = (*iter)->as<VarDeclarations>();
for (auto varIter = vars.fVars.begin(); varIter != vars.fVars.end();) {
const Variable& var = *((VarDeclaration&) **varIter).fVar;
if (var.dead()) {

View File

@ -496,7 +496,7 @@ void Dehydrator::write(const Statement* s) {
void Dehydrator::write(const ProgramElement& e) {
switch (e.fKind) {
case ProgramElement::kEnum_Kind: {
Enum& en = (Enum&) e;
const Enum& en = e.as<Enum>();
this->writeU8(Rehydrator::kEnum_Command);
this->write(en.fTypeName);
AutoDehydratorSymbolTable symbols(this, en.fSymbols);
@ -513,7 +513,7 @@ void Dehydrator::write(const ProgramElement& e) {
SkASSERT(false);
break;
case ProgramElement::kFunction_Kind: {
FunctionDefinition& f = (FunctionDefinition&) e;
const FunctionDefinition& f = e.as<FunctionDefinition>();
this->writeU8(Rehydrator::kFunctionDefinition_Command);
this->writeU16(this->symbolId(&f.fDeclaration));
this->write(f.fBody.get());
@ -528,7 +528,7 @@ void Dehydrator::write(const ProgramElement& e) {
break;
}
case ProgramElement::kInterfaceBlock_Kind: {
InterfaceBlock& i = (InterfaceBlock&) e;
const InterfaceBlock& i = e.as<InterfaceBlock>();
this->writeU8(Rehydrator::kInterfaceBlock_Command);
this->write(i.fVariable);
this->write(i.fTypeName);
@ -546,7 +546,7 @@ void Dehydrator::write(const ProgramElement& e) {
SkASSERT(false);
break;
case ProgramElement::kVar_Kind: {
VarDeclarations& v = (VarDeclarations&) e;
const VarDeclarations& v = e.as<VarDeclarations>();
this->writeU8(Rehydrator::kVarDeclarations_Command);
this->write(v.fBaseType);
this->writeU8(v.fVars.size());

View File

@ -1245,7 +1245,7 @@ void GLSLCodeGenerator::writeFunction(const FunctionDefinition& f) {
OutputStream* oldOut = fOut;
StringStream buffer;
fOut = &buffer;
this->writeStatements(((Block&) *f.fBody).fStatements);
this->writeStatements(f.fBody->as<Block>().fStatements);
if (fProgramKind != Program::kPipelineStage_Kind) {
fIndentation--;
this->writeLine("}");
@ -1416,7 +1416,7 @@ void GLSLCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool g
}
bool wroteType = false;
for (const auto& stmt : decl.fVars) {
VarDeclaration& var = (VarDeclaration&) *stmt;
const VarDeclaration& var = stmt->as<VarDeclaration>();
if (wroteType) {
this->write(", ");
} else {
@ -1460,7 +1460,7 @@ void GLSLCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool g
void GLSLCodeGenerator::writeStatement(const Statement& s) {
switch (s.fKind) {
case Statement::kBlock_Kind:
this->writeBlock((Block&) s);
this->writeBlock(s.as<Block>());
break;
case Statement::kExpression_Kind:
this->writeExpression(*s.as<ExpressionStatement>().fExpression, kTopLevel_Precedence);
@ -1666,19 +1666,19 @@ void GLSLCodeGenerator::writeHeader() {
void GLSLCodeGenerator::writeProgramElement(const ProgramElement& e) {
switch (e.fKind) {
case ProgramElement::kExtension_Kind:
this->writeExtension(((Extension&) e).fName);
this->writeExtension(e.as<Extension>().fName);
break;
case ProgramElement::kVar_Kind: {
VarDeclarations& decl = (VarDeclarations&) e;
const VarDeclarations& decl = e.as<VarDeclarations>();
if (decl.fVars.size() > 0) {
int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
int builtin = decl.fVars[0]->as<VarDeclaration>().fVar->fModifiers.fLayout.fBuiltin;
if (builtin == -1) {
// normal var
this->writeVarDeclarations(decl, true);
this->writeLine();
} else if (builtin == SK_FRAGCOLOR_BUILTIN &&
fProgram.fSettings.fCaps->mustDeclareFragmentShaderOutput() &&
((VarDeclaration&) *decl.fVars[0]).fVar->fWriteCount) {
decl.fVars[0]->as<VarDeclaration>().fVar->fWriteCount) {
if (fProgram.fSettings.fFragColorIsInOut) {
this->write("inout ");
} else {
@ -1693,13 +1693,13 @@ void GLSLCodeGenerator::writeProgramElement(const ProgramElement& e) {
break;
}
case ProgramElement::kInterfaceBlock_Kind:
this->writeInterfaceBlock((InterfaceBlock&) e);
this->writeInterfaceBlock(e.as<InterfaceBlock>());
break;
case ProgramElement::kFunction_Kind:
this->writeFunction((FunctionDefinition&) e);
this->writeFunction(e.as<FunctionDefinition>());
break;
case ProgramElement::kModifiers_Kind: {
const Modifiers& modifiers = ((ModifiersDeclaration&) e).fModifiers;
const Modifiers& modifiers = e.as<ModifiersDeclaration>().fModifiers;
if (!fFoundGSInvocations && modifiers.fLayout.fInvocations >= 0) {
if (fProgram.fSettings.fCaps->gsInvocationsExtensionString()) {
this->writeExtension(fProgram.fSettings.fCaps->gsInvocationsExtensionString());

View File

@ -952,12 +952,12 @@ void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
}
for (const auto& e : fProgram) {
if (ProgramElement::kVar_Kind == e.fKind) {
VarDeclarations& decls = (VarDeclarations&) e;
const VarDeclarations& decls = e.as<VarDeclarations>();
if (!decls.fVars.size()) {
continue;
}
for (const auto& stmt: decls.fVars) {
VarDeclaration& var = (VarDeclaration&) *stmt;
VarDeclaration& var = stmt->as<VarDeclaration>();
if (var.fVar->fType.kind() == Type::kSampler_Kind) {
if (var.fVar->fModifiers.fLayout.fBinding < 0) {
fErrors.error(decls.fOffset,
@ -1392,11 +1392,11 @@ void MetalCodeGenerator::writeHeader() {
void MetalCodeGenerator::writeUniformStruct() {
for (const auto& e : fProgram) {
if (ProgramElement::kVar_Kind == e.fKind) {
VarDeclarations& decls = (VarDeclarations&) e;
const VarDeclarations& decls = e.as<VarDeclarations>();
if (!decls.fVars.size()) {
continue;
}
const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
const Variable& first = *decls.fVars[0]->as<VarDeclaration>().fVar;
if (first.fModifiers.fFlags & Modifiers::kUniform_Flag &&
first.fType.kind() != Type::kSampler_Kind) {
if (-1 == fUniformBuffer) {
@ -1415,7 +1415,7 @@ void MetalCodeGenerator::writeUniformStruct() {
this->writeType(first.fType);
this->write(" ");
for (const auto& stmt : decls.fVars) {
VarDeclaration& var = (VarDeclaration&) *stmt;
const VarDeclaration& var = stmt->as<VarDeclaration>();
this->writeName(var.fVar->fName);
}
this->write(";\n");
@ -1431,18 +1431,18 @@ void MetalCodeGenerator::writeInputStruct() {
this->write("struct Inputs {\n");
for (const auto& e : fProgram) {
if (ProgramElement::kVar_Kind == e.fKind) {
VarDeclarations& decls = (VarDeclarations&) e;
const VarDeclarations& decls = e.as<VarDeclarations>();
if (!decls.fVars.size()) {
continue;
}
const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
const Variable& first = *decls.fVars[0]->as<VarDeclaration>().fVar;
if (first.fModifiers.fFlags & Modifiers::kIn_Flag &&
-1 == first.fModifiers.fLayout.fBuiltin) {
this->write(" ");
this->writeType(first.fType);
this->write(" ");
for (const auto& stmt : decls.fVars) {
VarDeclaration& var = (VarDeclaration&) *stmt;
const VarDeclaration& var = stmt->as<VarDeclaration>();
this->writeName(var.fVar->fName);
if (-1 != var.fVar->fModifiers.fLayout.fLocation) {
if (fProgram.fKind == Program::kVertex_Kind) {
@ -1470,18 +1470,18 @@ void MetalCodeGenerator::writeOutputStruct() {
}
for (const auto& e : fProgram) {
if (ProgramElement::kVar_Kind == e.fKind) {
VarDeclarations& decls = (VarDeclarations&) e;
const VarDeclarations& decls = e.as<VarDeclarations>();
if (!decls.fVars.size()) {
continue;
}
const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
const Variable& first = *decls.fVars[0]->as<VarDeclaration>().fVar;
if (first.fModifiers.fFlags & Modifiers::kOut_Flag &&
-1 == first.fModifiers.fLayout.fBuiltin) {
this->write(" ");
this->writeType(first.fType);
this->write(" ");
for (const auto& stmt : decls.fVars) {
VarDeclaration& var = (VarDeclaration&) *stmt;
const VarDeclaration& var = stmt->as<VarDeclaration>();
this->writeName(var.fVar->fName);
if (fProgram.fKind == Program::kVertex_Kind) {
this->write(" [[user(locn" +
@ -1510,7 +1510,7 @@ void MetalCodeGenerator::writeInterfaceBlocks() {
bool wroteInterfaceBlock = false;
for (const auto& e : fProgram) {
if (ProgramElement::kInterfaceBlock_Kind == e.fKind) {
this->writeInterfaceBlock((InterfaceBlock&) e);
this->writeInterfaceBlock(e.as<InterfaceBlock>());
wroteInterfaceBlock = true;
}
}
@ -1662,9 +1662,9 @@ void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
case ProgramElement::kExtension_Kind:
break;
case ProgramElement::kVar_Kind: {
VarDeclarations& decl = (VarDeclarations&) e;
const VarDeclarations& decl = e.as<VarDeclarations>();
if (decl.fVars.size() > 0) {
int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
int builtin = decl.fVars[0]->as<VarDeclaration>().fVar->fModifiers.fLayout.fBuiltin;
if (-1 == builtin) {
// normal var
this->writeVarDeclarations(decl, true);
@ -1679,10 +1679,10 @@ void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
// handled in writeInterfaceBlocks, do nothing
break;
case ProgramElement::kFunction_Kind:
this->writeFunction((FunctionDefinition&) e);
this->writeFunction(e.as<FunctionDefinition>());
break;
case ProgramElement::kModifiers_Kind:
this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true);
this->writeModifiers(e.as<ModifiersDeclaration>().fModifiers, true);
this->writeLine(";");
break;
default:
@ -1699,7 +1699,7 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expressi
}
switch (e->fKind) {
case Expression::kFunctionCall_Kind: {
const FunctionCall& f = (const FunctionCall&) *e;
const FunctionCall& f = e->as<FunctionCall>();
Requirements result = this->requirements(f.fFunction);
for (const auto& arg : f.fArguments) {
result |= this->requirements(arg.get());
@ -1707,7 +1707,7 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expressi
return result;
}
case Expression::kConstructor_Kind: {
const Constructor& c = (const Constructor&) *e;
const Constructor& c = e->as<Constructor>();
Requirements result = kNo_Requirements;
for (const auto& arg : c.fArguments) {
result |= this->requirements(arg.get());
@ -1715,33 +1715,33 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expressi
return result;
}
case Expression::kFieldAccess_Kind: {
const FieldAccess& f = (const FieldAccess&) *e;
const FieldAccess& f = e->as<FieldAccess>();
if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) {
return kGlobals_Requirement;
}
return this->requirements(f.fBase.get());
}
case Expression::kSwizzle_Kind:
return this->requirements(((const Swizzle&) *e).fBase.get());
return this->requirements(e->as<Swizzle>().fBase.get());
case Expression::kBinary_Kind: {
const BinaryExpression& b = (const BinaryExpression&) *e;
const BinaryExpression& b = e->as<BinaryExpression>();
return this->requirements(b.fLeft.get()) | this->requirements(b.fRight.get());
}
case Expression::kIndex_Kind: {
const IndexExpression& idx = (const IndexExpression&) *e;
const IndexExpression& idx = e->as<IndexExpression>();
return this->requirements(idx.fBase.get()) | this->requirements(idx.fIndex.get());
}
case Expression::kPrefix_Kind:
return this->requirements(((const PrefixExpression&) *e).fOperand.get());
return this->requirements(e->as<PrefixExpression>().fOperand.get());
case Expression::kPostfix_Kind:
return this->requirements(((const PostfixExpression&) *e).fOperand.get());
return this->requirements(e->as<PostfixExpression>().fOperand.get());
case Expression::kTernary_Kind: {
const TernaryExpression& t = (const TernaryExpression&) *e;
const TernaryExpression& t = e->as<TernaryExpression>();
return this->requirements(t.fTest.get()) | this->requirements(t.fIfTrue.get()) |
this->requirements(t.fIfFalse.get());
}
case Expression::kVariableReference_Kind: {
const VariableReference& v = (const VariableReference&) *e;
const VariableReference& v = e->as<VariableReference>();
Requirements result = kNo_Requirements;
if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
result = kGlobals_Requirement | kFragCoord_Requirement;
@ -1771,54 +1771,54 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statemen
switch (s->fKind) {
case Statement::kBlock_Kind: {
Requirements result = kNo_Requirements;
for (const auto& child : ((const Block*) s)->fStatements) {
for (const auto& child : s->as<Block>().fStatements) {
result |= this->requirements(child.get());
}
return result;
}
case Statement::kVarDeclaration_Kind: {
const VarDeclaration& var = (const VarDeclaration&) *s;
const VarDeclaration& var = s->as<VarDeclaration>();
return this->requirements(var.fValue.get());
}
case Statement::kVarDeclarations_Kind: {
Requirements result = kNo_Requirements;
const VarDeclarations& decls = *((const VarDeclarationsStatement&) *s).fDeclaration;
const VarDeclarations& decls = *s->as<VarDeclarationsStatement>().fDeclaration;
for (const auto& stmt : decls.fVars) {
result |= this->requirements(stmt.get());
}
return result;
}
case Statement::kExpression_Kind:
return this->requirements(((const ExpressionStatement&) *s).fExpression.get());
return this->requirements(s->as<ExpressionStatement>().fExpression.get());
case Statement::kReturn_Kind: {
const ReturnStatement& r = (const ReturnStatement&) *s;
const ReturnStatement& r = s->as<ReturnStatement>();
return this->requirements(r.fExpression.get());
}
case Statement::kIf_Kind: {
const IfStatement& i = (const IfStatement&) *s;
const IfStatement& i = s->as<IfStatement>();
return this->requirements(i.fTest.get()) |
this->requirements(i.fIfTrue.get()) |
this->requirements(i.fIfFalse.get());
}
case Statement::kFor_Kind: {
const ForStatement& f = (const ForStatement&) *s;
const ForStatement& f = s->as<ForStatement>();
return this->requirements(f.fInitializer.get()) |
this->requirements(f.fTest.get()) |
this->requirements(f.fNext.get()) |
this->requirements(f.fStatement.get());
}
case Statement::kWhile_Kind: {
const WhileStatement& w = (const WhileStatement&) *s;
const WhileStatement& w = s->as<WhileStatement>();
return this->requirements(w.fTest.get()) |
this->requirements(w.fStatement.get());
}
case Statement::kDo_Kind: {
const DoStatement& d = (const DoStatement&) *s;
const DoStatement& d = s->as<DoStatement>();
return this->requirements(d.fTest.get()) |
this->requirements(d.fStatement.get());
}
case Statement::kSwitch_Kind: {
const SwitchStatement& sw = (const SwitchStatement&) *s;
const SwitchStatement& sw = s->as<SwitchStatement>();
Requirements result = this->requirements(sw.fValue.get());
for (const auto& c : sw.fCases) {
for (const auto& st : c->fStatements) {
@ -1841,7 +1841,7 @@ MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Function
fRequirements[&f] = kNo_Requirements;
for (const auto& e : fProgram) {
if (ProgramElement::kFunction_Kind == e.fKind) {
const FunctionDefinition& def = (const FunctionDefinition&) e;
const FunctionDefinition& def = e.as<FunctionDefinition>();
if (&def.fDeclaration == &f) {
Requirements reqs = this->requirements(def.fBody.get());
fRequirements[&f] = reqs;

View File

@ -21,9 +21,11 @@ namespace SkSL {
struct Symbol;
struct Enum : public ProgramElement {
static constexpr Kind kProgramElementKind = kEnum_Kind;
Enum(int offset, StringFragment typeName, std::shared_ptr<SymbolTable> symbols,
bool isBuiltin = true)
: INHERITED(offset, kEnum_Kind)
: INHERITED(offset, kProgramElementKind)
, fTypeName(typeName)
, fSymbols(std::move(symbols))
, fBuiltin(isBuiltin) {}

View File

@ -16,8 +16,10 @@ namespace SkSL {
* An extension declaration.
*/
struct Extension : public ProgramElement {
static constexpr Kind kProgramElementKind = kExtension_Kind;
Extension(int offset, String name)
: INHERITED(offset, kExtension_Kind)
: INHERITED(offset, kProgramElementKind)
, fName(std::move(name)) {}
std::unique_ptr<ProgramElement> clone() const override {

View File

@ -22,11 +22,13 @@ struct ASTNode;
* A function definition (a declaration plus an associated block of code).
*/
struct FunctionDefinition : public ProgramElement {
static constexpr Kind kProgramElementKind = kFunction_Kind;
FunctionDefinition(int offset,
const FunctionDeclaration& declaration,
std::unique_ptr<Statement> body,
std::unordered_set<const FunctionDeclaration*> referencedIntrinsics = {})
: INHERITED(offset, kFunction_Kind)
: INHERITED(offset, kProgramElementKind)
, fDeclaration(declaration)
, fBody(std::move(body))
, fReferencedIntrinsics(std::move(referencedIntrinsics)) {}

View File

@ -25,10 +25,12 @@ namespace SkSL {
* At the IR level, this is represented by a single variable of struct type.
*/
struct InterfaceBlock : public ProgramElement {
static constexpr Kind kProgramElementKind = kInterfaceBlock_Kind;
InterfaceBlock(int offset, const Variable* var, String typeName, String instanceName,
std::vector<std::unique_ptr<Expression>> sizes,
std::shared_ptr<SymbolTable> typeOwner)
: INHERITED(offset, kInterfaceBlock_Kind)
: INHERITED(offset, kProgramElementKind)
, fVariable(*var)
, fTypeName(std::move(typeName))
, fInstanceName(std::move(instanceName))

View File

@ -19,8 +19,10 @@ namespace SkSL {
* layout(blend_support_all_equations) out;
*/
struct ModifiersDeclaration : public ProgramElement {
static constexpr Kind kProgramElementKind = kModifiers_Kind;
ModifiersDeclaration(Modifiers modifiers)
: INHERITED(-1, kModifiers_Kind)
: INHERITED(-1, kProgramElementKind)
, fModifiers(modifiers) {}
std::unique_ptr<ProgramElement> clone() const override {

View File

@ -32,6 +32,21 @@ struct ProgramElement : public IRNode {
: INHERITED(offset)
, fKind(kind) {}
/**
* Use as<T> to downcast program elements. e.g. replace `(Enum&) el` with `el.as<Enum>()`.
*/
template <typename T>
const T& as() const {
SkASSERT(this->fKind == T::kProgramElementKind);
return static_cast<const T&>(*this);
}
template <typename T>
T& as() {
SkASSERT(this->fKind == T::kProgramElementKind);
return static_cast<T&>(*this);
}
Kind fKind;
virtual std::unique_ptr<ProgramElement> clone() const = 0;

View File

@ -16,8 +16,10 @@ namespace SkSL {
* A section declaration (e.g. @body { body code here })..
*/
struct Section : public ProgramElement {
static constexpr Kind kProgramElementKind = kSection_Kind;
Section(int offset, String name, String arg, String text)
: INHERITED(offset, kSection_Kind)
: INHERITED(offset, kProgramElementKind)
, fName(std::move(name))
, fArgument(std::move(arg))
, fText(std::move(text)) {}

View File

@ -83,9 +83,11 @@ struct VarDeclaration : public Statement {
* A variable declaration statement, which may consist of one or more individual variables.
*/
struct VarDeclarations : public ProgramElement {
static constexpr Kind kProgramElementKind = kVar_Kind;
VarDeclarations(int offset, const Type* baseType,
std::vector<std::unique_ptr<VarDeclaration>> vars)
: INHERITED(offset, kVar_Kind)
: INHERITED(offset, kProgramElementKind)
, fBaseType(*baseType) {
for (auto& var : vars) {
fVars.push_back(std::unique_ptr<Statement>(var.release()));