Added SkSL DSLFunction

Change-Id: Ibc995e908e5b4f8d1516e13d56854a4fcf5cc809
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/360556
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
This commit is contained in:
Ethan Nicholas 2021-01-28 10:02:43 -05:00 committed by Skia Commit-Bot
parent 0bad6cf145
commit 1ff760981d
12 changed files with 286 additions and 10 deletions

View File

@ -61,6 +61,7 @@ skia_sksl_sources = [
"$_src/sksl/dsl/DSLBlock.cpp",
"$_src/sksl/dsl/DSLCore.cpp",
"$_src/sksl/dsl/DSLExpression.cpp",
"$_src/sksl/dsl/DSLFunction.cpp",
"$_src/sksl/dsl/DSLStatement.cpp",
"$_src/sksl/dsl/DSLType.cpp",
"$_src/sksl/dsl/DSLVar.cpp",

View File

@ -820,18 +820,14 @@ std::unique_ptr<Statement> IRGenerator::convertExpressionStatement(const ASTNode
return std::unique_ptr<Statement>(new ExpressionStatement(std::move(e)));
}
std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
SkASSERT(r.fKind == ASTNode::Kind::kReturn);
std::unique_ptr<Statement> IRGenerator::convertReturn(int offset,
std::unique_ptr<Expression> result) {
SkASSERT(fCurrentFunction);
// early returns from a vertex main function will bypass the sk_Position normalization, so
// SkASSERT that we aren't doing that. It is of course possible to fix this by adding a
// normalization before each return, but it will probably never actually be necessary.
SkASSERT(Program::kVertex_Kind != fKind || !fRTAdjust || "main" != fCurrentFunction->name());
if (r.begin() != r.end()) {
std::unique_ptr<Expression> result = this->convertExpression(*r.begin());
if (!result) {
return nullptr;
}
if (result) {
if (fCurrentFunction->returnType() == *fContext.fTypes.fVoid) {
this->errorReporter().error(result->fOffset,
"may not return a value from a void function");
@ -845,11 +841,24 @@ std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
return std::make_unique<ReturnStatement>(std::move(result));
} else {
if (fCurrentFunction->returnType() != *fContext.fTypes.fVoid) {
this->errorReporter().error(r.fOffset,
"expected function to return '" +
this->errorReporter().error(offset, "expected function to return '" +
fCurrentFunction->returnType().displayName() + "'");
return nullptr;
}
return std::make_unique<ReturnStatement>(r.fOffset);
return std::make_unique<ReturnStatement>(offset);
}
}
std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
SkASSERT(r.fKind == ASTNode::Kind::kReturn);
if (r.begin() != r.end()) {
std::unique_ptr<Expression> value = this->convertExpression(*r.begin());
if (!value) {
return nullptr;
}
return this->convertReturn(r.fOffset, std::move(value));
} else {
return this->convertReturn(r.fOffset, /*result=*/nullptr);
}
}

View File

@ -34,6 +34,7 @@ namespace SkSL {
namespace dsl {
class DSLCore;
class DSLFunction;
class DSLVar;
class DSLWriter;
}
@ -224,6 +225,7 @@ private:
std::unique_ptr<InterfaceBlock> convertInterfaceBlock(const ASTNode& s);
Modifiers convertModifiers(const Modifiers& m);
std::unique_ptr<Expression> convertPrefixExpression(const ASTNode& expression);
std::unique_ptr<Statement> convertReturn(int offset, std::unique_ptr<Expression> result);
std::unique_ptr<Statement> convertReturn(const ASTNode& r);
std::unique_ptr<Section> convertSection(const ASTNode& e);
std::unique_ptr<Expression> convertCallExpression(const ASTNode& expression);
@ -303,6 +305,7 @@ private:
friend class AutoDisableInline;
friend class Compiler;
friend class dsl::DSLCore;
friend class dsl::DSLFunction;
friend class dsl::DSLVar;
friend class dsl::DSLWriter;
};

View File

@ -16,6 +16,7 @@ namespace dsl {
using Block = DSLBlock;
using Expression = DSLExpression;
using Function = DSLFunction;
using Modifiers = DSLModifiers;
using Statement = DSLStatement;
using Var = DSLVar;

View File

@ -13,6 +13,7 @@
#include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLForStatement.h"
#include "src/sksl/ir/SkSLIfStatement.h"
#include "src/sksl/ir/SkSLReturnStatement.h"
namespace SkSL {
@ -53,6 +54,14 @@ static char swizzle_component(SwizzleComponent c) {
class DSLCore {
public:
static DSLVar sk_FragColor() {
return DSLVar("sk_FragColor");
}
static DSLVar sk_FragCoord() {
return DSLVar("sk_FragCoord");
}
template <typename... Args>
static DSLExpression Call(const char* name, Args... args) {
SkSL::IRGenerator& ir = DSLWriter::IRGenerator();
@ -95,6 +104,19 @@ public:
ifTrue.release(), ifFalse.release());
}
static DSLStatement Return(DSLExpression value) {
// note that because Return is called before the function in which it resides exists, at
// this point we do not know the function's return type. We therefore do not check for
// errors, or coerce the value to the correct type, until the return statement is actually
// added to a function
std::unique_ptr<SkSL::Expression> expr = value.release();
if (expr) {
return std::unique_ptr<SkSL::Statement>(new ReturnStatement(std::move(expr)));
} else {
return std::unique_ptr<SkSL::Statement>(new ReturnStatement(/*offset=*/-1));
}
}
static DSLExpression Swizzle(DSLExpression base, SwizzleComponent a) {
char mask[] = { swizzle_component(a), 0 };
return DSLWriter::IRGenerator().convertSwizzle(base.release(), mask);
@ -131,6 +153,14 @@ private:
static void ignore(std::unique_ptr<SkSL::Expression>&) {}
};
DSLVar sk_FragColor() {
return DSLCore::sk_FragColor();
}
DSLVar sk_FragCoord() {
return DSLCore::sk_FragCoord();
}
DSLStatement Declare(DSLVar& var, DSLExpression initialValue) {
return DSLCore::Declare(var, std::move(initialValue));
}
@ -148,6 +178,10 @@ DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse) {
return DSLCore::If(std::move(test), std::move(ifTrue), std::move(ifFalse));
}
DSLStatement Return(DSLExpression expr) {
return DSLCore::Return(std::move(expr));
}
DSLExpression Ternary(DSLExpression test, DSLExpression ifTrue, DSLExpression ifFalse) {
return DSLCore::Ternary(std::move(test), std::move(ifTrue), std::move(ifFalse));
}

View File

@ -10,6 +10,7 @@
#include "src/sksl/dsl/DSLBlock.h"
#include "src/sksl/dsl/DSLExpression.h"
#include "src/sksl/dsl/DSLFunction.h"
#include "src/sksl/dsl/DSLStatement.h"
#include "src/sksl/dsl/DSLType.h"
#include "src/sksl/dsl/DSLVar.h"
@ -48,6 +49,10 @@ void End();
*/
void SetErrorHandler(ErrorHandler* errorHandler);
DSLVar sk_FragColor();
DSLVar sk_FragCoord();
/**
* Creates a variable declaration statement with an initial value.
*/
@ -69,6 +74,11 @@ DSLStatement For(DSLStatement initializer, DSLExpression test, DSLExpression nex
*/
DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse = DSLStatement());
/**
* return [value];
*/
DSLStatement Return(DSLExpression value = DSLExpression());
/**
* test ? ifTrue : ifFalse
*/

View File

@ -0,0 +1,56 @@
/*
* Copyright 2021 Google LLC.
*
* Use of this source code is governed by a BSD-style license that can be
* found in the LICENSE file.
*/
#include "src/sksl/dsl/DSLFunction.h"
#include "src/sksl/SkSLAnalysis.h"
#include "src/sksl/SkSLCompiler.h"
#include "src/sksl/SkSLIRGenerator.h"
#include "src/sksl/ir/SkSLReturnStatement.h"
namespace SkSL {
namespace dsl {
void DSLFunction::define(DSLBlock block) {
SkASSERT(fDecl);
SkASSERT(!DSLWriter::CurrentFunction());
DSLWriter::SetCurrentFunction(fDecl);
class FinalizeReturns : public ProgramWriter {
public:
bool visitStatement(Statement& stmt) override {
if (stmt.is<ReturnStatement>()) {
ReturnStatement& r = stmt.as<ReturnStatement>();
std::unique_ptr<Statement> finished = DSLWriter::IRGenerator().convertReturn(
r.fOffset,
std::move(r.expression()));
if (finished) {
r.setExpression(std::move(finished->as<ReturnStatement>().expression()));
} else {
SkASSERT(DSLWriter::Compiler().errorCount());
DSLWriter::ReportError(
DSLWriter::Compiler().errorText(/*showCount=*/false).c_str());
}
}
return INHERITED::visitStatement(stmt);
}
private:
using INHERITED = ProgramWriter;
};
std::unique_ptr<Statement> body = block.release();
FinalizeReturns().visitStatement(*body);
DSLWriter::ProgramElements().emplace_back(new SkSL::FunctionDefinition(/*offset=*/-1,
fDecl,
/*builtin=*/false,
std::move(body)));
DSLWriter::SetCurrentFunction(nullptr);
}
} // namespace dsl
} // namespace SkSL

View File

@ -0,0 +1,63 @@
/*
* Copyright 2021 Google LLC.
*
* Use of this source code is governed by a BSD-style license that can be
* found in the LICENSE file.
*/
#ifndef SKSL_DSL_FUNCTION
#define SKSL_DSL_FUNCTION
#include "src/sksl/SkSLString.h"
#include "src/sksl/dsl/DSLBlock.h"
#include "src/sksl/dsl/DSLType.h"
#include "src/sksl/dsl/priv/DSLWriter.h"
#include "src/sksl/ir/SkSLBlock.h"
#include "src/sksl/ir/SkSLFunctionDefinition.h"
namespace SkSL {
class Block;
class Variable;
namespace dsl {
class DSLType;
class DSLFunction {
public:
template<class... Parameters>
DSLFunction(const DSLType& returnType, const char* name, Parameters&... parameters)
: fReturnType(&returnType.skslType()) {
std::vector<const Variable*> parameterArray;
parameterArray.reserve(sizeof...(parameters));
(parameterArray.push_back(&DSLWriter::Var(parameters)), ...);
SkSL::SymbolTable& symbols = *DSLWriter::SymbolTable();
fDecl = symbols.add(std::make_unique<SkSL::FunctionDeclaration>(
/*offset=*/-1,
DSLWriter::Modifiers(SkSL::Modifiers()),
DSLWriter::Name(name),
std::move(parameterArray), fReturnType,
/*builtin=*/false));
}
virtual ~DSLFunction() = default;
template<class... Stmt>
void define(Stmt... stmts) {
DSLBlock block = DSLBlock(DSLStatement(std::move(stmts))...);
this->define(std::move(block));
}
void define(DSLBlock block);
protected:
const SkSL::Type* fReturnType;
const SkSL::FunctionDeclaration* fDecl;
};
} // namespace dsl
} // namespace SkSL
#endif

View File

@ -57,6 +57,14 @@ const char* DSLWriter::Name(const char* name) {
return name;
}
const SkSL::FunctionDeclaration* DSLWriter::CurrentFunction() {
return IRGenerator().fCurrentFunction;
}
void DSLWriter::SetCurrentFunction(const SkSL::FunctionDeclaration* fn) {
IRGenerator().fCurrentFunction = fn;
}
std::unique_ptr<SkSL::Expression> DSLWriter::Check(std::unique_ptr<SkSL::Expression> expr) {
if (expr == nullptr) {
if (DSLWriter::Compiler().errorCount()) {

View File

@ -58,6 +58,13 @@ public:
*/
static const SkSL::Context& Context();
/**
* Returns the collection to which DSL program elements in this thread should be appended.
*/
static std::vector<std::unique_ptr<SkSL::ProgramElement>>& ProgramElements() {
return Instance().fProgramElements;
}
/**
* Returns the SymbolTable of the current thread's IRGenerator.
*/
@ -80,6 +87,16 @@ public:
*/
static const char* Name(const char* name);
/**
* Returns the current function for which we are generating nodes.
*/
static const SkSL::FunctionDeclaration* CurrentFunction();
/**
* Specifies the function for which we are generating nodes.
*/
static void SetCurrentFunction(const SkSL::FunctionDeclaration* fn);
/**
* Reports an error if the argument is null. Returns its argument unmodified.
*/
@ -128,6 +145,7 @@ public:
private:
SkSL::Program::Settings fSettings;
SkSL::Compiler* fCompiler;
std::vector<std::unique_ptr<SkSL::ProgramElement>> fProgramElements;
ErrorHandler* fErrorHandler = nullptr;
bool fMangle = true;
Mangler fMangler;

View File

@ -35,6 +35,10 @@ public:
return fExpression;
}
void setExpression(std::unique_ptr<Expression> expr) {
fExpression = std::move(expr);
}
std::unique_ptr<Statement> clone() const override {
if (this->expression()) {
return std::unique_ptr<Statement>(new ReturnStatement(this->expression()->clone()));

View File

@ -10,6 +10,7 @@
#include "src/sksl/SkSLIRGenerator.h"
#include "src/sksl/dsl/DSL.h"
#include "src/sksl/dsl/priv/DSLWriter.h"
#include "src/sksl/ir/SkSLIRNode.h"
#include "tests/Test.h"
@ -76,6 +77,10 @@ static bool whitespace_insensitive_compare(DSLStatement& stmt, const char* descr
return whitespace_insensitive_compare(stmt.release()->description().c_str(), description);
}
static bool whitespace_insensitive_compare(SkSL::IRNode& node, const char* description) {
return whitespace_insensitive_compare(node.description().c_str(), description);
}
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLStartup, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Expression e1 = 1;
@ -939,6 +944,60 @@ DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLFor, r, ctxInfo) {
}
}
DEF_GPUTEST_FOR_ALL_CONTEXTS(DSLFunction, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
DSLWriter::ProgramElements().clear();
Var coords(kHalf2, "coords");
DSLFunction(kVoid, "main", coords).define(
sk_FragColor() = Half4(coords, 0, 1)
);
REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
REPORTER_ASSERT(r, whitespace_insensitive_compare(*DSLWriter::ProgramElements()[0],
"void main(half2 coords) { (sk_FragColor = half4(coords, 0.0, 1.0)); }"));
DSLWriter::ProgramElements().clear();
Var x(kFloat, "x");
DSLFunction(kFloat, "sqr", x).define(
Return(x * x)
);
REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
REPORTER_ASSERT(r, whitespace_insensitive_compare(*DSLWriter::ProgramElements()[0],
"float sqr(float x) { return (x * x); }"));
{
ExpectError error(r, "error: expected 'float', but found 'bool'\n");
DSLWriter::ProgramElements().clear();
DSLFunction(kFloat, "broken").define(
Return(true)
);
}
{
ExpectError error(r, "error: expected function to return 'float'\n");
DSLWriter::ProgramElements().clear();
DSLFunction(kFloat, "broken").define(
Return()
);
}
{
ExpectError error(r, "error: may not return a value from a void function\n");
DSLWriter::ProgramElements().clear();
DSLFunction(kVoid, "broken").define(
Return(0)
);
}
/* TODO: detect this case
{
ExpectError error(r, "error: expected function to return 'float'\n");
DSLWriter::ProgramElements().clear();
DSLFunction(kFloat, "broken").define(
);
}
*/
}
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLIf, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Var a(kFloat, "a"), b(kFloat, "b");
@ -954,6 +1013,16 @@ DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLIf, r, ctxInfo) {
}
}
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLReturn, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Statement x = Return();
REPORTER_ASSERT(r, whitespace_insensitive_compare(x, "return;"));
Statement y = Return(true);
REPORTER_ASSERT(r, whitespace_insensitive_compare(y, "return true;"));
}
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLSwizzle, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Var a(kFloat4, "a");