Added SkSL DSLFunction
Change-Id: Ibc995e908e5b4f8d1516e13d56854a4fcf5cc809 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/360556 Commit-Queue: Ethan Nicholas <ethannicholas@google.com> Reviewed-by: John Stiles <johnstiles@google.com> Reviewed-by: Brian Osman <brianosman@google.com>
This commit is contained in:
parent
0bad6cf145
commit
1ff760981d
@ -61,6 +61,7 @@ skia_sksl_sources = [
|
||||
"$_src/sksl/dsl/DSLBlock.cpp",
|
||||
"$_src/sksl/dsl/DSLCore.cpp",
|
||||
"$_src/sksl/dsl/DSLExpression.cpp",
|
||||
"$_src/sksl/dsl/DSLFunction.cpp",
|
||||
"$_src/sksl/dsl/DSLStatement.cpp",
|
||||
"$_src/sksl/dsl/DSLType.cpp",
|
||||
"$_src/sksl/dsl/DSLVar.cpp",
|
||||
|
@ -820,18 +820,14 @@ std::unique_ptr<Statement> IRGenerator::convertExpressionStatement(const ASTNode
|
||||
return std::unique_ptr<Statement>(new ExpressionStatement(std::move(e)));
|
||||
}
|
||||
|
||||
std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
|
||||
SkASSERT(r.fKind == ASTNode::Kind::kReturn);
|
||||
std::unique_ptr<Statement> IRGenerator::convertReturn(int offset,
|
||||
std::unique_ptr<Expression> result) {
|
||||
SkASSERT(fCurrentFunction);
|
||||
// early returns from a vertex main function will bypass the sk_Position normalization, so
|
||||
// SkASSERT that we aren't doing that. It is of course possible to fix this by adding a
|
||||
// normalization before each return, but it will probably never actually be necessary.
|
||||
SkASSERT(Program::kVertex_Kind != fKind || !fRTAdjust || "main" != fCurrentFunction->name());
|
||||
if (r.begin() != r.end()) {
|
||||
std::unique_ptr<Expression> result = this->convertExpression(*r.begin());
|
||||
if (!result) {
|
||||
return nullptr;
|
||||
}
|
||||
if (result) {
|
||||
if (fCurrentFunction->returnType() == *fContext.fTypes.fVoid) {
|
||||
this->errorReporter().error(result->fOffset,
|
||||
"may not return a value from a void function");
|
||||
@ -845,11 +841,24 @@ std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
|
||||
return std::make_unique<ReturnStatement>(std::move(result));
|
||||
} else {
|
||||
if (fCurrentFunction->returnType() != *fContext.fTypes.fVoid) {
|
||||
this->errorReporter().error(r.fOffset,
|
||||
"expected function to return '" +
|
||||
this->errorReporter().error(offset, "expected function to return '" +
|
||||
fCurrentFunction->returnType().displayName() + "'");
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_unique<ReturnStatement>(r.fOffset);
|
||||
return std::make_unique<ReturnStatement>(offset);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
|
||||
SkASSERT(r.fKind == ASTNode::Kind::kReturn);
|
||||
if (r.begin() != r.end()) {
|
||||
std::unique_ptr<Expression> value = this->convertExpression(*r.begin());
|
||||
if (!value) {
|
||||
return nullptr;
|
||||
}
|
||||
return this->convertReturn(r.fOffset, std::move(value));
|
||||
} else {
|
||||
return this->convertReturn(r.fOffset, /*result=*/nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,6 +34,7 @@ namespace SkSL {
|
||||
|
||||
namespace dsl {
|
||||
class DSLCore;
|
||||
class DSLFunction;
|
||||
class DSLVar;
|
||||
class DSLWriter;
|
||||
}
|
||||
@ -224,6 +225,7 @@ private:
|
||||
std::unique_ptr<InterfaceBlock> convertInterfaceBlock(const ASTNode& s);
|
||||
Modifiers convertModifiers(const Modifiers& m);
|
||||
std::unique_ptr<Expression> convertPrefixExpression(const ASTNode& expression);
|
||||
std::unique_ptr<Statement> convertReturn(int offset, std::unique_ptr<Expression> result);
|
||||
std::unique_ptr<Statement> convertReturn(const ASTNode& r);
|
||||
std::unique_ptr<Section> convertSection(const ASTNode& e);
|
||||
std::unique_ptr<Expression> convertCallExpression(const ASTNode& expression);
|
||||
@ -303,6 +305,7 @@ private:
|
||||
friend class AutoDisableInline;
|
||||
friend class Compiler;
|
||||
friend class dsl::DSLCore;
|
||||
friend class dsl::DSLFunction;
|
||||
friend class dsl::DSLVar;
|
||||
friend class dsl::DSLWriter;
|
||||
};
|
||||
|
@ -16,6 +16,7 @@ namespace dsl {
|
||||
|
||||
using Block = DSLBlock;
|
||||
using Expression = DSLExpression;
|
||||
using Function = DSLFunction;
|
||||
using Modifiers = DSLModifiers;
|
||||
using Statement = DSLStatement;
|
||||
using Var = DSLVar;
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include "src/sksl/ir/SkSLDoStatement.h"
|
||||
#include "src/sksl/ir/SkSLForStatement.h"
|
||||
#include "src/sksl/ir/SkSLIfStatement.h"
|
||||
#include "src/sksl/ir/SkSLReturnStatement.h"
|
||||
|
||||
namespace SkSL {
|
||||
|
||||
@ -53,6 +54,14 @@ static char swizzle_component(SwizzleComponent c) {
|
||||
|
||||
class DSLCore {
|
||||
public:
|
||||
static DSLVar sk_FragColor() {
|
||||
return DSLVar("sk_FragColor");
|
||||
}
|
||||
|
||||
static DSLVar sk_FragCoord() {
|
||||
return DSLVar("sk_FragCoord");
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
static DSLExpression Call(const char* name, Args... args) {
|
||||
SkSL::IRGenerator& ir = DSLWriter::IRGenerator();
|
||||
@ -95,6 +104,19 @@ public:
|
||||
ifTrue.release(), ifFalse.release());
|
||||
}
|
||||
|
||||
static DSLStatement Return(DSLExpression value) {
|
||||
// note that because Return is called before the function in which it resides exists, at
|
||||
// this point we do not know the function's return type. We therefore do not check for
|
||||
// errors, or coerce the value to the correct type, until the return statement is actually
|
||||
// added to a function
|
||||
std::unique_ptr<SkSL::Expression> expr = value.release();
|
||||
if (expr) {
|
||||
return std::unique_ptr<SkSL::Statement>(new ReturnStatement(std::move(expr)));
|
||||
} else {
|
||||
return std::unique_ptr<SkSL::Statement>(new ReturnStatement(/*offset=*/-1));
|
||||
}
|
||||
}
|
||||
|
||||
static DSLExpression Swizzle(DSLExpression base, SwizzleComponent a) {
|
||||
char mask[] = { swizzle_component(a), 0 };
|
||||
return DSLWriter::IRGenerator().convertSwizzle(base.release(), mask);
|
||||
@ -131,6 +153,14 @@ private:
|
||||
static void ignore(std::unique_ptr<SkSL::Expression>&) {}
|
||||
};
|
||||
|
||||
DSLVar sk_FragColor() {
|
||||
return DSLCore::sk_FragColor();
|
||||
}
|
||||
|
||||
DSLVar sk_FragCoord() {
|
||||
return DSLCore::sk_FragCoord();
|
||||
}
|
||||
|
||||
DSLStatement Declare(DSLVar& var, DSLExpression initialValue) {
|
||||
return DSLCore::Declare(var, std::move(initialValue));
|
||||
}
|
||||
@ -148,6 +178,10 @@ DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse) {
|
||||
return DSLCore::If(std::move(test), std::move(ifTrue), std::move(ifFalse));
|
||||
}
|
||||
|
||||
DSLStatement Return(DSLExpression expr) {
|
||||
return DSLCore::Return(std::move(expr));
|
||||
}
|
||||
|
||||
DSLExpression Ternary(DSLExpression test, DSLExpression ifTrue, DSLExpression ifFalse) {
|
||||
return DSLCore::Ternary(std::move(test), std::move(ifTrue), std::move(ifFalse));
|
||||
}
|
||||
|
@ -10,6 +10,7 @@
|
||||
|
||||
#include "src/sksl/dsl/DSLBlock.h"
|
||||
#include "src/sksl/dsl/DSLExpression.h"
|
||||
#include "src/sksl/dsl/DSLFunction.h"
|
||||
#include "src/sksl/dsl/DSLStatement.h"
|
||||
#include "src/sksl/dsl/DSLType.h"
|
||||
#include "src/sksl/dsl/DSLVar.h"
|
||||
@ -48,6 +49,10 @@ void End();
|
||||
*/
|
||||
void SetErrorHandler(ErrorHandler* errorHandler);
|
||||
|
||||
DSLVar sk_FragColor();
|
||||
|
||||
DSLVar sk_FragCoord();
|
||||
|
||||
/**
|
||||
* Creates a variable declaration statement with an initial value.
|
||||
*/
|
||||
@ -69,6 +74,11 @@ DSLStatement For(DSLStatement initializer, DSLExpression test, DSLExpression nex
|
||||
*/
|
||||
DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse = DSLStatement());
|
||||
|
||||
/**
|
||||
* return [value];
|
||||
*/
|
||||
DSLStatement Return(DSLExpression value = DSLExpression());
|
||||
|
||||
/**
|
||||
* test ? ifTrue : ifFalse
|
||||
*/
|
||||
|
56
src/sksl/dsl/DSLFunction.cpp
Normal file
56
src/sksl/dsl/DSLFunction.cpp
Normal file
@ -0,0 +1,56 @@
|
||||
/*
|
||||
* Copyright 2021 Google LLC.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license that can be
|
||||
* found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#include "src/sksl/dsl/DSLFunction.h"
|
||||
|
||||
#include "src/sksl/SkSLAnalysis.h"
|
||||
#include "src/sksl/SkSLCompiler.h"
|
||||
#include "src/sksl/SkSLIRGenerator.h"
|
||||
#include "src/sksl/ir/SkSLReturnStatement.h"
|
||||
|
||||
namespace SkSL {
|
||||
|
||||
namespace dsl {
|
||||
|
||||
void DSLFunction::define(DSLBlock block) {
|
||||
SkASSERT(fDecl);
|
||||
SkASSERT(!DSLWriter::CurrentFunction());
|
||||
DSLWriter::SetCurrentFunction(fDecl);
|
||||
class FinalizeReturns : public ProgramWriter {
|
||||
public:
|
||||
bool visitStatement(Statement& stmt) override {
|
||||
if (stmt.is<ReturnStatement>()) {
|
||||
ReturnStatement& r = stmt.as<ReturnStatement>();
|
||||
std::unique_ptr<Statement> finished = DSLWriter::IRGenerator().convertReturn(
|
||||
r.fOffset,
|
||||
std::move(r.expression()));
|
||||
if (finished) {
|
||||
r.setExpression(std::move(finished->as<ReturnStatement>().expression()));
|
||||
} else {
|
||||
SkASSERT(DSLWriter::Compiler().errorCount());
|
||||
DSLWriter::ReportError(
|
||||
DSLWriter::Compiler().errorText(/*showCount=*/false).c_str());
|
||||
}
|
||||
}
|
||||
return INHERITED::visitStatement(stmt);
|
||||
}
|
||||
|
||||
private:
|
||||
using INHERITED = ProgramWriter;
|
||||
};
|
||||
std::unique_ptr<Statement> body = block.release();
|
||||
FinalizeReturns().visitStatement(*body);
|
||||
DSLWriter::ProgramElements().emplace_back(new SkSL::FunctionDefinition(/*offset=*/-1,
|
||||
fDecl,
|
||||
/*builtin=*/false,
|
||||
std::move(body)));
|
||||
DSLWriter::SetCurrentFunction(nullptr);
|
||||
}
|
||||
|
||||
} // namespace dsl
|
||||
|
||||
} // namespace SkSL
|
63
src/sksl/dsl/DSLFunction.h
Normal file
63
src/sksl/dsl/DSLFunction.h
Normal file
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Copyright 2021 Google LLC.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license that can be
|
||||
* found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#ifndef SKSL_DSL_FUNCTION
|
||||
#define SKSL_DSL_FUNCTION
|
||||
|
||||
#include "src/sksl/SkSLString.h"
|
||||
#include "src/sksl/dsl/DSLBlock.h"
|
||||
#include "src/sksl/dsl/DSLType.h"
|
||||
#include "src/sksl/dsl/priv/DSLWriter.h"
|
||||
#include "src/sksl/ir/SkSLBlock.h"
|
||||
#include "src/sksl/ir/SkSLFunctionDefinition.h"
|
||||
|
||||
namespace SkSL {
|
||||
|
||||
class Block;
|
||||
class Variable;
|
||||
|
||||
namespace dsl {
|
||||
|
||||
class DSLType;
|
||||
|
||||
class DSLFunction {
|
||||
public:
|
||||
template<class... Parameters>
|
||||
DSLFunction(const DSLType& returnType, const char* name, Parameters&... parameters)
|
||||
: fReturnType(&returnType.skslType()) {
|
||||
std::vector<const Variable*> parameterArray;
|
||||
parameterArray.reserve(sizeof...(parameters));
|
||||
(parameterArray.push_back(&DSLWriter::Var(parameters)), ...);
|
||||
SkSL::SymbolTable& symbols = *DSLWriter::SymbolTable();
|
||||
fDecl = symbols.add(std::make_unique<SkSL::FunctionDeclaration>(
|
||||
/*offset=*/-1,
|
||||
DSLWriter::Modifiers(SkSL::Modifiers()),
|
||||
DSLWriter::Name(name),
|
||||
std::move(parameterArray), fReturnType,
|
||||
/*builtin=*/false));
|
||||
}
|
||||
|
||||
virtual ~DSLFunction() = default;
|
||||
|
||||
template<class... Stmt>
|
||||
void define(Stmt... stmts) {
|
||||
DSLBlock block = DSLBlock(DSLStatement(std::move(stmts))...);
|
||||
this->define(std::move(block));
|
||||
}
|
||||
|
||||
void define(DSLBlock block);
|
||||
|
||||
protected:
|
||||
const SkSL::Type* fReturnType;
|
||||
const SkSL::FunctionDeclaration* fDecl;
|
||||
};
|
||||
|
||||
} // namespace dsl
|
||||
|
||||
} // namespace SkSL
|
||||
|
||||
#endif
|
@ -57,6 +57,14 @@ const char* DSLWriter::Name(const char* name) {
|
||||
return name;
|
||||
}
|
||||
|
||||
const SkSL::FunctionDeclaration* DSLWriter::CurrentFunction() {
|
||||
return IRGenerator().fCurrentFunction;
|
||||
}
|
||||
|
||||
void DSLWriter::SetCurrentFunction(const SkSL::FunctionDeclaration* fn) {
|
||||
IRGenerator().fCurrentFunction = fn;
|
||||
}
|
||||
|
||||
std::unique_ptr<SkSL::Expression> DSLWriter::Check(std::unique_ptr<SkSL::Expression> expr) {
|
||||
if (expr == nullptr) {
|
||||
if (DSLWriter::Compiler().errorCount()) {
|
||||
|
@ -58,6 +58,13 @@ public:
|
||||
*/
|
||||
static const SkSL::Context& Context();
|
||||
|
||||
/**
|
||||
* Returns the collection to which DSL program elements in this thread should be appended.
|
||||
*/
|
||||
static std::vector<std::unique_ptr<SkSL::ProgramElement>>& ProgramElements() {
|
||||
return Instance().fProgramElements;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the SymbolTable of the current thread's IRGenerator.
|
||||
*/
|
||||
@ -80,6 +87,16 @@ public:
|
||||
*/
|
||||
static const char* Name(const char* name);
|
||||
|
||||
/**
|
||||
* Returns the current function for which we are generating nodes.
|
||||
*/
|
||||
static const SkSL::FunctionDeclaration* CurrentFunction();
|
||||
|
||||
/**
|
||||
* Specifies the function for which we are generating nodes.
|
||||
*/
|
||||
static void SetCurrentFunction(const SkSL::FunctionDeclaration* fn);
|
||||
|
||||
/**
|
||||
* Reports an error if the argument is null. Returns its argument unmodified.
|
||||
*/
|
||||
@ -128,6 +145,7 @@ public:
|
||||
private:
|
||||
SkSL::Program::Settings fSettings;
|
||||
SkSL::Compiler* fCompiler;
|
||||
std::vector<std::unique_ptr<SkSL::ProgramElement>> fProgramElements;
|
||||
ErrorHandler* fErrorHandler = nullptr;
|
||||
bool fMangle = true;
|
||||
Mangler fMangler;
|
||||
|
@ -35,6 +35,10 @@ public:
|
||||
return fExpression;
|
||||
}
|
||||
|
||||
void setExpression(std::unique_ptr<Expression> expr) {
|
||||
fExpression = std::move(expr);
|
||||
}
|
||||
|
||||
std::unique_ptr<Statement> clone() const override {
|
||||
if (this->expression()) {
|
||||
return std::unique_ptr<Statement>(new ReturnStatement(this->expression()->clone()));
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "src/sksl/SkSLIRGenerator.h"
|
||||
#include "src/sksl/dsl/DSL.h"
|
||||
#include "src/sksl/dsl/priv/DSLWriter.h"
|
||||
#include "src/sksl/ir/SkSLIRNode.h"
|
||||
|
||||
#include "tests/Test.h"
|
||||
|
||||
@ -76,6 +77,10 @@ static bool whitespace_insensitive_compare(DSLStatement& stmt, const char* descr
|
||||
return whitespace_insensitive_compare(stmt.release()->description().c_str(), description);
|
||||
}
|
||||
|
||||
static bool whitespace_insensitive_compare(SkSL::IRNode& node, const char* description) {
|
||||
return whitespace_insensitive_compare(node.description().c_str(), description);
|
||||
}
|
||||
|
||||
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLStartup, r, ctxInfo) {
|
||||
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
|
||||
Expression e1 = 1;
|
||||
@ -939,6 +944,60 @@ DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLFor, r, ctxInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
DEF_GPUTEST_FOR_ALL_CONTEXTS(DSLFunction, r, ctxInfo) {
|
||||
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
|
||||
DSLWriter::ProgramElements().clear();
|
||||
Var coords(kHalf2, "coords");
|
||||
DSLFunction(kVoid, "main", coords).define(
|
||||
sk_FragColor() = Half4(coords, 0, 1)
|
||||
);
|
||||
REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
|
||||
REPORTER_ASSERT(r, whitespace_insensitive_compare(*DSLWriter::ProgramElements()[0],
|
||||
"void main(half2 coords) { (sk_FragColor = half4(coords, 0.0, 1.0)); }"));
|
||||
|
||||
DSLWriter::ProgramElements().clear();
|
||||
Var x(kFloat, "x");
|
||||
DSLFunction(kFloat, "sqr", x).define(
|
||||
Return(x * x)
|
||||
);
|
||||
REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
|
||||
REPORTER_ASSERT(r, whitespace_insensitive_compare(*DSLWriter::ProgramElements()[0],
|
||||
"float sqr(float x) { return (x * x); }"));
|
||||
|
||||
{
|
||||
ExpectError error(r, "error: expected 'float', but found 'bool'\n");
|
||||
DSLWriter::ProgramElements().clear();
|
||||
DSLFunction(kFloat, "broken").define(
|
||||
Return(true)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
ExpectError error(r, "error: expected function to return 'float'\n");
|
||||
DSLWriter::ProgramElements().clear();
|
||||
DSLFunction(kFloat, "broken").define(
|
||||
Return()
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
ExpectError error(r, "error: may not return a value from a void function\n");
|
||||
DSLWriter::ProgramElements().clear();
|
||||
DSLFunction(kVoid, "broken").define(
|
||||
Return(0)
|
||||
);
|
||||
}
|
||||
|
||||
/* TODO: detect this case
|
||||
{
|
||||
ExpectError error(r, "error: expected function to return 'float'\n");
|
||||
DSLWriter::ProgramElements().clear();
|
||||
DSLFunction(kFloat, "broken").define(
|
||||
);
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLIf, r, ctxInfo) {
|
||||
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
|
||||
Var a(kFloat, "a"), b(kFloat, "b");
|
||||
@ -954,6 +1013,16 @@ DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLIf, r, ctxInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLReturn, r, ctxInfo) {
|
||||
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
|
||||
|
||||
Statement x = Return();
|
||||
REPORTER_ASSERT(r, whitespace_insensitive_compare(x, "return;"));
|
||||
|
||||
Statement y = Return(true);
|
||||
REPORTER_ASSERT(r, whitespace_insensitive_compare(y, "return true;"));
|
||||
}
|
||||
|
||||
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLSwizzle, r, ctxInfo) {
|
||||
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
|
||||
Var a(kFloat4, "a");
|
||||
|
Loading…
Reference in New Issue
Block a user