Added SkSL DSLFunction

Change-Id: Ibc995e908e5b4f8d1516e13d56854a4fcf5cc809
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/360556
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
This commit is contained in:
Ethan Nicholas 2021-01-28 10:02:43 -05:00 committed by Skia Commit-Bot
parent 0bad6cf145
commit 1ff760981d
12 changed files with 286 additions and 10 deletions

View File

@ -61,6 +61,7 @@ skia_sksl_sources = [
"$_src/sksl/dsl/DSLBlock.cpp", "$_src/sksl/dsl/DSLBlock.cpp",
"$_src/sksl/dsl/DSLCore.cpp", "$_src/sksl/dsl/DSLCore.cpp",
"$_src/sksl/dsl/DSLExpression.cpp", "$_src/sksl/dsl/DSLExpression.cpp",
"$_src/sksl/dsl/DSLFunction.cpp",
"$_src/sksl/dsl/DSLStatement.cpp", "$_src/sksl/dsl/DSLStatement.cpp",
"$_src/sksl/dsl/DSLType.cpp", "$_src/sksl/dsl/DSLType.cpp",
"$_src/sksl/dsl/DSLVar.cpp", "$_src/sksl/dsl/DSLVar.cpp",

View File

@ -820,18 +820,14 @@ std::unique_ptr<Statement> IRGenerator::convertExpressionStatement(const ASTNode
return std::unique_ptr<Statement>(new ExpressionStatement(std::move(e))); return std::unique_ptr<Statement>(new ExpressionStatement(std::move(e)));
} }
std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) { std::unique_ptr<Statement> IRGenerator::convertReturn(int offset,
SkASSERT(r.fKind == ASTNode::Kind::kReturn); std::unique_ptr<Expression> result) {
SkASSERT(fCurrentFunction); SkASSERT(fCurrentFunction);
// early returns from a vertex main function will bypass the sk_Position normalization, so // early returns from a vertex main function will bypass the sk_Position normalization, so
// SkASSERT that we aren't doing that. It is of course possible to fix this by adding a // SkASSERT that we aren't doing that. It is of course possible to fix this by adding a
// normalization before each return, but it will probably never actually be necessary. // normalization before each return, but it will probably never actually be necessary.
SkASSERT(Program::kVertex_Kind != fKind || !fRTAdjust || "main" != fCurrentFunction->name()); SkASSERT(Program::kVertex_Kind != fKind || !fRTAdjust || "main" != fCurrentFunction->name());
if (r.begin() != r.end()) { if (result) {
std::unique_ptr<Expression> result = this->convertExpression(*r.begin());
if (!result) {
return nullptr;
}
if (fCurrentFunction->returnType() == *fContext.fTypes.fVoid) { if (fCurrentFunction->returnType() == *fContext.fTypes.fVoid) {
this->errorReporter().error(result->fOffset, this->errorReporter().error(result->fOffset,
"may not return a value from a void function"); "may not return a value from a void function");
@ -845,11 +841,24 @@ std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
return std::make_unique<ReturnStatement>(std::move(result)); return std::make_unique<ReturnStatement>(std::move(result));
} else { } else {
if (fCurrentFunction->returnType() != *fContext.fTypes.fVoid) { if (fCurrentFunction->returnType() != *fContext.fTypes.fVoid) {
this->errorReporter().error(r.fOffset, this->errorReporter().error(offset, "expected function to return '" +
"expected function to return '" +
fCurrentFunction->returnType().displayName() + "'"); fCurrentFunction->returnType().displayName() + "'");
return nullptr;
} }
return std::make_unique<ReturnStatement>(r.fOffset); return std::make_unique<ReturnStatement>(offset);
}
}
std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
SkASSERT(r.fKind == ASTNode::Kind::kReturn);
if (r.begin() != r.end()) {
std::unique_ptr<Expression> value = this->convertExpression(*r.begin());
if (!value) {
return nullptr;
}
return this->convertReturn(r.fOffset, std::move(value));
} else {
return this->convertReturn(r.fOffset, /*result=*/nullptr);
} }
} }

View File

@ -34,6 +34,7 @@ namespace SkSL {
namespace dsl { namespace dsl {
class DSLCore; class DSLCore;
class DSLFunction;
class DSLVar; class DSLVar;
class DSLWriter; class DSLWriter;
} }
@ -224,6 +225,7 @@ private:
std::unique_ptr<InterfaceBlock> convertInterfaceBlock(const ASTNode& s); std::unique_ptr<InterfaceBlock> convertInterfaceBlock(const ASTNode& s);
Modifiers convertModifiers(const Modifiers& m); Modifiers convertModifiers(const Modifiers& m);
std::unique_ptr<Expression> convertPrefixExpression(const ASTNode& expression); std::unique_ptr<Expression> convertPrefixExpression(const ASTNode& expression);
std::unique_ptr<Statement> convertReturn(int offset, std::unique_ptr<Expression> result);
std::unique_ptr<Statement> convertReturn(const ASTNode& r); std::unique_ptr<Statement> convertReturn(const ASTNode& r);
std::unique_ptr<Section> convertSection(const ASTNode& e); std::unique_ptr<Section> convertSection(const ASTNode& e);
std::unique_ptr<Expression> convertCallExpression(const ASTNode& expression); std::unique_ptr<Expression> convertCallExpression(const ASTNode& expression);
@ -303,6 +305,7 @@ private:
friend class AutoDisableInline; friend class AutoDisableInline;
friend class Compiler; friend class Compiler;
friend class dsl::DSLCore; friend class dsl::DSLCore;
friend class dsl::DSLFunction;
friend class dsl::DSLVar; friend class dsl::DSLVar;
friend class dsl::DSLWriter; friend class dsl::DSLWriter;
}; };

View File

@ -16,6 +16,7 @@ namespace dsl {
using Block = DSLBlock; using Block = DSLBlock;
using Expression = DSLExpression; using Expression = DSLExpression;
using Function = DSLFunction;
using Modifiers = DSLModifiers; using Modifiers = DSLModifiers;
using Statement = DSLStatement; using Statement = DSLStatement;
using Var = DSLVar; using Var = DSLVar;

View File

@ -13,6 +13,7 @@
#include "src/sksl/ir/SkSLDoStatement.h" #include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLForStatement.h" #include "src/sksl/ir/SkSLForStatement.h"
#include "src/sksl/ir/SkSLIfStatement.h" #include "src/sksl/ir/SkSLIfStatement.h"
#include "src/sksl/ir/SkSLReturnStatement.h"
namespace SkSL { namespace SkSL {
@ -53,6 +54,14 @@ static char swizzle_component(SwizzleComponent c) {
class DSLCore { class DSLCore {
public: public:
static DSLVar sk_FragColor() {
return DSLVar("sk_FragColor");
}
static DSLVar sk_FragCoord() {
return DSLVar("sk_FragCoord");
}
template <typename... Args> template <typename... Args>
static DSLExpression Call(const char* name, Args... args) { static DSLExpression Call(const char* name, Args... args) {
SkSL::IRGenerator& ir = DSLWriter::IRGenerator(); SkSL::IRGenerator& ir = DSLWriter::IRGenerator();
@ -95,6 +104,19 @@ public:
ifTrue.release(), ifFalse.release()); ifTrue.release(), ifFalse.release());
} }
static DSLStatement Return(DSLExpression value) {
// note that because Return is called before the function in which it resides exists, at
// this point we do not know the function's return type. We therefore do not check for
// errors, or coerce the value to the correct type, until the return statement is actually
// added to a function
std::unique_ptr<SkSL::Expression> expr = value.release();
if (expr) {
return std::unique_ptr<SkSL::Statement>(new ReturnStatement(std::move(expr)));
} else {
return std::unique_ptr<SkSL::Statement>(new ReturnStatement(/*offset=*/-1));
}
}
static DSLExpression Swizzle(DSLExpression base, SwizzleComponent a) { static DSLExpression Swizzle(DSLExpression base, SwizzleComponent a) {
char mask[] = { swizzle_component(a), 0 }; char mask[] = { swizzle_component(a), 0 };
return DSLWriter::IRGenerator().convertSwizzle(base.release(), mask); return DSLWriter::IRGenerator().convertSwizzle(base.release(), mask);
@ -131,6 +153,14 @@ private:
static void ignore(std::unique_ptr<SkSL::Expression>&) {} static void ignore(std::unique_ptr<SkSL::Expression>&) {}
}; };
DSLVar sk_FragColor() {
return DSLCore::sk_FragColor();
}
DSLVar sk_FragCoord() {
return DSLCore::sk_FragCoord();
}
DSLStatement Declare(DSLVar& var, DSLExpression initialValue) { DSLStatement Declare(DSLVar& var, DSLExpression initialValue) {
return DSLCore::Declare(var, std::move(initialValue)); return DSLCore::Declare(var, std::move(initialValue));
} }
@ -148,6 +178,10 @@ DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse) {
return DSLCore::If(std::move(test), std::move(ifTrue), std::move(ifFalse)); return DSLCore::If(std::move(test), std::move(ifTrue), std::move(ifFalse));
} }
DSLStatement Return(DSLExpression expr) {
return DSLCore::Return(std::move(expr));
}
DSLExpression Ternary(DSLExpression test, DSLExpression ifTrue, DSLExpression ifFalse) { DSLExpression Ternary(DSLExpression test, DSLExpression ifTrue, DSLExpression ifFalse) {
return DSLCore::Ternary(std::move(test), std::move(ifTrue), std::move(ifFalse)); return DSLCore::Ternary(std::move(test), std::move(ifTrue), std::move(ifFalse));
} }

View File

@ -10,6 +10,7 @@
#include "src/sksl/dsl/DSLBlock.h" #include "src/sksl/dsl/DSLBlock.h"
#include "src/sksl/dsl/DSLExpression.h" #include "src/sksl/dsl/DSLExpression.h"
#include "src/sksl/dsl/DSLFunction.h"
#include "src/sksl/dsl/DSLStatement.h" #include "src/sksl/dsl/DSLStatement.h"
#include "src/sksl/dsl/DSLType.h" #include "src/sksl/dsl/DSLType.h"
#include "src/sksl/dsl/DSLVar.h" #include "src/sksl/dsl/DSLVar.h"
@ -48,6 +49,10 @@ void End();
*/ */
void SetErrorHandler(ErrorHandler* errorHandler); void SetErrorHandler(ErrorHandler* errorHandler);
DSLVar sk_FragColor();
DSLVar sk_FragCoord();
/** /**
* Creates a variable declaration statement with an initial value. * Creates a variable declaration statement with an initial value.
*/ */
@ -69,6 +74,11 @@ DSLStatement For(DSLStatement initializer, DSLExpression test, DSLExpression nex
*/ */
DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse = DSLStatement()); DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse = DSLStatement());
/**
* return [value];
*/
DSLStatement Return(DSLExpression value = DSLExpression());
/** /**
* test ? ifTrue : ifFalse * test ? ifTrue : ifFalse
*/ */

View File

@ -0,0 +1,56 @@
/*
* Copyright 2021 Google LLC.
*
* Use of this source code is governed by a BSD-style license that can be
* found in the LICENSE file.
*/
#include "src/sksl/dsl/DSLFunction.h"
#include "src/sksl/SkSLAnalysis.h"
#include "src/sksl/SkSLCompiler.h"
#include "src/sksl/SkSLIRGenerator.h"
#include "src/sksl/ir/SkSLReturnStatement.h"
namespace SkSL {
namespace dsl {
void DSLFunction::define(DSLBlock block) {
SkASSERT(fDecl);
SkASSERT(!DSLWriter::CurrentFunction());
DSLWriter::SetCurrentFunction(fDecl);
class FinalizeReturns : public ProgramWriter {
public:
bool visitStatement(Statement& stmt) override {
if (stmt.is<ReturnStatement>()) {
ReturnStatement& r = stmt.as<ReturnStatement>();
std::unique_ptr<Statement> finished = DSLWriter::IRGenerator().convertReturn(
r.fOffset,
std::move(r.expression()));
if (finished) {
r.setExpression(std::move(finished->as<ReturnStatement>().expression()));
} else {
SkASSERT(DSLWriter::Compiler().errorCount());
DSLWriter::ReportError(
DSLWriter::Compiler().errorText(/*showCount=*/false).c_str());
}
}
return INHERITED::visitStatement(stmt);
}
private:
using INHERITED = ProgramWriter;
};
std::unique_ptr<Statement> body = block.release();
FinalizeReturns().visitStatement(*body);
DSLWriter::ProgramElements().emplace_back(new SkSL::FunctionDefinition(/*offset=*/-1,
fDecl,
/*builtin=*/false,
std::move(body)));
DSLWriter::SetCurrentFunction(nullptr);
}
} // namespace dsl
} // namespace SkSL

View File

@ -0,0 +1,63 @@
/*
* Copyright 2021 Google LLC.
*
* Use of this source code is governed by a BSD-style license that can be
* found in the LICENSE file.
*/
#ifndef SKSL_DSL_FUNCTION
#define SKSL_DSL_FUNCTION
#include "src/sksl/SkSLString.h"
#include "src/sksl/dsl/DSLBlock.h"
#include "src/sksl/dsl/DSLType.h"
#include "src/sksl/dsl/priv/DSLWriter.h"
#include "src/sksl/ir/SkSLBlock.h"
#include "src/sksl/ir/SkSLFunctionDefinition.h"
namespace SkSL {
class Block;
class Variable;
namespace dsl {
class DSLType;
class DSLFunction {
public:
template<class... Parameters>
DSLFunction(const DSLType& returnType, const char* name, Parameters&... parameters)
: fReturnType(&returnType.skslType()) {
std::vector<const Variable*> parameterArray;
parameterArray.reserve(sizeof...(parameters));
(parameterArray.push_back(&DSLWriter::Var(parameters)), ...);
SkSL::SymbolTable& symbols = *DSLWriter::SymbolTable();
fDecl = symbols.add(std::make_unique<SkSL::FunctionDeclaration>(
/*offset=*/-1,
DSLWriter::Modifiers(SkSL::Modifiers()),
DSLWriter::Name(name),
std::move(parameterArray), fReturnType,
/*builtin=*/false));
}
virtual ~DSLFunction() = default;
template<class... Stmt>
void define(Stmt... stmts) {
DSLBlock block = DSLBlock(DSLStatement(std::move(stmts))...);
this->define(std::move(block));
}
void define(DSLBlock block);
protected:
const SkSL::Type* fReturnType;
const SkSL::FunctionDeclaration* fDecl;
};
} // namespace dsl
} // namespace SkSL
#endif

View File

@ -57,6 +57,14 @@ const char* DSLWriter::Name(const char* name) {
return name; return name;
} }
const SkSL::FunctionDeclaration* DSLWriter::CurrentFunction() {
return IRGenerator().fCurrentFunction;
}
void DSLWriter::SetCurrentFunction(const SkSL::FunctionDeclaration* fn) {
IRGenerator().fCurrentFunction = fn;
}
std::unique_ptr<SkSL::Expression> DSLWriter::Check(std::unique_ptr<SkSL::Expression> expr) { std::unique_ptr<SkSL::Expression> DSLWriter::Check(std::unique_ptr<SkSL::Expression> expr) {
if (expr == nullptr) { if (expr == nullptr) {
if (DSLWriter::Compiler().errorCount()) { if (DSLWriter::Compiler().errorCount()) {

View File

@ -58,6 +58,13 @@ public:
*/ */
static const SkSL::Context& Context(); static const SkSL::Context& Context();
/**
* Returns the collection to which DSL program elements in this thread should be appended.
*/
static std::vector<std::unique_ptr<SkSL::ProgramElement>>& ProgramElements() {
return Instance().fProgramElements;
}
/** /**
* Returns the SymbolTable of the current thread's IRGenerator. * Returns the SymbolTable of the current thread's IRGenerator.
*/ */
@ -80,6 +87,16 @@ public:
*/ */
static const char* Name(const char* name); static const char* Name(const char* name);
/**
* Returns the current function for which we are generating nodes.
*/
static const SkSL::FunctionDeclaration* CurrentFunction();
/**
* Specifies the function for which we are generating nodes.
*/
static void SetCurrentFunction(const SkSL::FunctionDeclaration* fn);
/** /**
* Reports an error if the argument is null. Returns its argument unmodified. * Reports an error if the argument is null. Returns its argument unmodified.
*/ */
@ -128,6 +145,7 @@ public:
private: private:
SkSL::Program::Settings fSettings; SkSL::Program::Settings fSettings;
SkSL::Compiler* fCompiler; SkSL::Compiler* fCompiler;
std::vector<std::unique_ptr<SkSL::ProgramElement>> fProgramElements;
ErrorHandler* fErrorHandler = nullptr; ErrorHandler* fErrorHandler = nullptr;
bool fMangle = true; bool fMangle = true;
Mangler fMangler; Mangler fMangler;

View File

@ -35,6 +35,10 @@ public:
return fExpression; return fExpression;
} }
void setExpression(std::unique_ptr<Expression> expr) {
fExpression = std::move(expr);
}
std::unique_ptr<Statement> clone() const override { std::unique_ptr<Statement> clone() const override {
if (this->expression()) { if (this->expression()) {
return std::unique_ptr<Statement>(new ReturnStatement(this->expression()->clone())); return std::unique_ptr<Statement>(new ReturnStatement(this->expression()->clone()));

View File

@ -10,6 +10,7 @@
#include "src/sksl/SkSLIRGenerator.h" #include "src/sksl/SkSLIRGenerator.h"
#include "src/sksl/dsl/DSL.h" #include "src/sksl/dsl/DSL.h"
#include "src/sksl/dsl/priv/DSLWriter.h" #include "src/sksl/dsl/priv/DSLWriter.h"
#include "src/sksl/ir/SkSLIRNode.h"
#include "tests/Test.h" #include "tests/Test.h"
@ -76,6 +77,10 @@ static bool whitespace_insensitive_compare(DSLStatement& stmt, const char* descr
return whitespace_insensitive_compare(stmt.release()->description().c_str(), description); return whitespace_insensitive_compare(stmt.release()->description().c_str(), description);
} }
static bool whitespace_insensitive_compare(SkSL::IRNode& node, const char* description) {
return whitespace_insensitive_compare(node.description().c_str(), description);
}
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLStartup, r, ctxInfo) { DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLStartup, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu()); AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Expression e1 = 1; Expression e1 = 1;
@ -939,6 +944,60 @@ DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLFor, r, ctxInfo) {
} }
} }
DEF_GPUTEST_FOR_ALL_CONTEXTS(DSLFunction, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
DSLWriter::ProgramElements().clear();
Var coords(kHalf2, "coords");
DSLFunction(kVoid, "main", coords).define(
sk_FragColor() = Half4(coords, 0, 1)
);
REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
REPORTER_ASSERT(r, whitespace_insensitive_compare(*DSLWriter::ProgramElements()[0],
"void main(half2 coords) { (sk_FragColor = half4(coords, 0.0, 1.0)); }"));
DSLWriter::ProgramElements().clear();
Var x(kFloat, "x");
DSLFunction(kFloat, "sqr", x).define(
Return(x * x)
);
REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
REPORTER_ASSERT(r, whitespace_insensitive_compare(*DSLWriter::ProgramElements()[0],
"float sqr(float x) { return (x * x); }"));
{
ExpectError error(r, "error: expected 'float', but found 'bool'\n");
DSLWriter::ProgramElements().clear();
DSLFunction(kFloat, "broken").define(
Return(true)
);
}
{
ExpectError error(r, "error: expected function to return 'float'\n");
DSLWriter::ProgramElements().clear();
DSLFunction(kFloat, "broken").define(
Return()
);
}
{
ExpectError error(r, "error: may not return a value from a void function\n");
DSLWriter::ProgramElements().clear();
DSLFunction(kVoid, "broken").define(
Return(0)
);
}
/* TODO: detect this case
{
ExpectError error(r, "error: expected function to return 'float'\n");
DSLWriter::ProgramElements().clear();
DSLFunction(kFloat, "broken").define(
);
}
*/
}
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLIf, r, ctxInfo) { DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLIf, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu()); AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Var a(kFloat, "a"), b(kFloat, "b"); Var a(kFloat, "a"), b(kFloat, "b");
@ -954,6 +1013,16 @@ DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLIf, r, ctxInfo) {
} }
} }
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLReturn, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Statement x = Return();
REPORTER_ASSERT(r, whitespace_insensitive_compare(x, "return;"));
Statement y = Return(true);
REPORTER_ASSERT(r, whitespace_insensitive_compare(y, "return true;"));
}
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLSwizzle, r, ctxInfo) { DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLSwizzle, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu()); AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Var a(kFloat4, "a"); Var a(kFloat4, "a");