From 0e9401dafeb34b3459b528e78b6a9c47fe996089 Mon Sep 17 00:00:00 2001 From: Ethan Nicholas Date: Thu, 21 Mar 2019 11:05:37 -0400 Subject: [PATCH] Initial checkin of new SkSL interpreter. Not quite feature complete yet, but at a point where it's worth checking in. Bug: skia: Change-Id: I21141d30e8582a79e94450d84e56bacc067249e0 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/201685 Commit-Queue: Ethan Nicholas Reviewed-by: Brian Salomon --- gm/runtimecolorfilter.cpp | 24 +- gn/sksl.gni | 1 + gn/tests.gni | 1 + src/core/SkColorFilter.cpp | 63 ++- src/sksl/SkSLByteCode.h | 115 ++++ src/sksl/SkSLByteCodeGenerator.cpp | 787 ++++++++++++++++++++++++++ src/sksl/SkSLByteCodeGenerator.h | 241 ++++++++ src/sksl/SkSLCompiler.cpp | 13 + src/sksl/SkSLCompiler.h | 3 + src/sksl/SkSLInterpreter.cpp | 836 ++++++++++++++-------------- src/sksl/SkSLInterpreter.h | 72 ++- src/sksl/SkSLSPIRVCodeGenerator.cpp | 22 - src/sksl/SkSLUtil.cpp | 41 ++ src/sksl/SkSLUtil.h | 8 + tests/SkSLInterpreterTest.cpp | 200 +++++++ 15 files changed, 1949 insertions(+), 478 deletions(-) create mode 100644 src/sksl/SkSLByteCode.h create mode 100644 src/sksl/SkSLByteCodeGenerator.cpp create mode 100644 src/sksl/SkSLByteCodeGenerator.h create mode 100644 tests/SkSLInterpreterTest.cpp diff --git a/gm/runtimecolorfilter.cpp b/gm/runtimecolorfilter.cpp index ebc6a20619..0f84dffffe 100644 --- a/gm/runtimecolorfilter.cpp +++ b/gm/runtimecolorfilter.cpp @@ -33,7 +33,7 @@ DEF_SIMPLE_GPU_GM(runtimecolorfilter, context, rtc, canvas, 768, 256) { canvas->drawImage(img, 0, 0, nullptr); float b = 0.75; - sk_sp data = SkData::MakeWithoutCopy(&b, sizeof(b)); + sk_sp data = SkData::MakeWithCopy(&b, sizeof(b)); auto cf1 = SkRuntimeColorFilterFactory(SkString(SKSL_TEST_SRC), runtimeCpuFunc).make(data); SkPaint p; p.setColorFilter(cf1); @@ -49,3 +49,25 @@ DEF_SIMPLE_GPU_GM(runtimecolorfilter, context, rtc, canvas, 768, 256) { p.setColorFilter(cf2); canvas->drawImage(img, 512, 0, &p); } + +DEF_SIMPLE_GM(runtimecolorfilter_interpreted, canvas, 768, 256) { + auto img = GetResourceAsImage("images/mandrill_256.png"); + canvas->drawImage(img, 0, 0, nullptr); + + float b = 0.75; + sk_sp data = SkData::MakeWithCopy(&b, sizeof(b)); + auto cf1 = SkRuntimeColorFilterFactory(SkString(SKSL_TEST_SRC), nullptr).make(data); + SkPaint p; + p.setColorFilter(cf1); + canvas->drawImage(img, 256, 0, &p); + + static constexpr size_t kBufferSize = 512; + char buffer[kBufferSize]; + SkBinaryWriteBuffer wb(buffer, kBufferSize); + wb.writeFlattenable(cf1.get()); + SkReadBuffer rb(buffer, kBufferSize); + auto cf2 = rb.readColorFilter(); + SkASSERT(cf2); + p.setColorFilter(cf2); + canvas->drawImage(img, 512, 0, &p); +} diff --git a/gn/sksl.gni b/gn/sksl.gni index 755fa95f5d..38c706a937 100644 --- a/gn/sksl.gni +++ b/gn/sksl.gni @@ -7,6 +7,7 @@ _src = get_path_info("../src", "abspath") skia_sksl_sources = [ + "$_src/sksl/SkSLByteCodeGenerator.cpp", "$_src/sksl/SkSLCFGGenerator.cpp", "$_src/sksl/SkSLCompiler.cpp", "$_src/sksl/SkSLCPPCodeGenerator.cpp", diff --git a/gn/tests.gni b/gn/tests.gni index 14738e3175..3ec84ec152 100644 --- a/gn/tests.gni +++ b/gn/tests.gni @@ -239,6 +239,7 @@ tests_sources = [ "$_tests/SkSLErrorTest.cpp", "$_tests/SkSLFPTest.cpp", "$_tests/SkSLGLSLTest.cpp", + "$_tests/SkSLInterpreterTest.cpp", "$_tests/SkSLJITTest.cpp", "$_tests/SkSLMemoryLayoutTest.cpp", "$_tests/SkSLMetalTest.cpp", diff --git a/src/core/SkColorFilter.cpp b/src/core/SkColorFilter.cpp index ab337b1312..8959c1b623 100644 --- a/src/core/SkColorFilter.cpp +++ b/src/core/SkColorFilter.cpp @@ -382,6 +382,8 @@ sk_sp SkColorFilter::MakeLerp(sk_sp cf0, #if SK_SUPPORT_GPU #include "effects/GrSkSLFP.h" #include "GrRecordingContext.h" +#include "SkSLByteCode.h" +#include "SkSLInterpreter.h" class SkRuntimeColorFilter : public SkColorFilter { public: @@ -402,24 +404,51 @@ public: #endif void onAppendStages(const SkStageRec& rec, bool shaderIsOpaque) const override { - // if this assert fails, it either means no CPU function was provided when this filter was - // created, or we have flattened and unflattened the filter, which nulls out this pointer. - // We don't currently have a means to flatten colorfilters containing CPU functions. - SkASSERT(fCpuFunction); - struct Ctx : public SkRasterPipeline_CallbackCtx { - SkRuntimeColorFilterFn cpuFn; - const void* inputs; - }; - auto ctx = rec.fAlloc->make(); - ctx->inputs = fInputs->data(); - ctx->cpuFn = fCpuFunction; - ctx->fn = [](SkRasterPipeline_CallbackCtx* arg, int active_pixels) { - auto ctx = (Ctx*)arg; - for (int i = 0; i < active_pixels; i++) { - ctx->cpuFn(ctx->rgba + i * 4, ctx->inputs); + if (fCpuFunction) { + struct CpuFuncCtx : public SkRasterPipeline_CallbackCtx { + SkRuntimeColorFilterFn cpuFn; + const void* inputs; + }; + auto ctx = rec.fAlloc->make(); + ctx->inputs = fInputs->data(); + ctx->cpuFn = fCpuFunction; + ctx->fn = [](SkRasterPipeline_CallbackCtx* arg, int active_pixels) { + auto ctx = (CpuFuncCtx*)arg; + for (int i = 0; i < active_pixels; i++) { + ctx->cpuFn(ctx->rgba + i * 4, ctx->inputs); + } + }; + rec.fPipeline->append(SkRasterPipeline::callback, ctx); + } else { + struct InterpreterCtx : public SkRasterPipeline_CallbackCtx { + SkSL::ByteCodeFunction* main; + std::unique_ptr interpreter; + const void* inputs; + }; + auto ctx = rec.fAlloc->make(); + ctx->inputs = fInputs->data(); + SkSL::Compiler c; + std::unique_ptr prog = + c.convertProgram(SkSL::Program::kPipelineStage_Kind, + SkSL::String(fSkSL.c_str()), + SkSL::Program::Settings()); + if (c.errorCount()) { + SkDebugf("%s\n", c.errorText().c_str()); + SkASSERT(false); } - }; - rec.fPipeline->append(SkRasterPipeline::callback, ctx); + std::unique_ptr byteCode = c.toByteCode(*prog); + ctx->main = byteCode->fFunctions[0].get(); + ctx->interpreter.reset(new SkSL::Interpreter(std::move(prog), std::move(byteCode))); + ctx->fn = [](SkRasterPipeline_CallbackCtx* arg, int active_pixels) { + auto ctx = (InterpreterCtx*)arg; + for (int i = 0; i < active_pixels; i++) { + ctx->interpreter->run(*ctx->main, + (SkSL::Interpreter::Value*) (ctx->rgba + i * 4), + (SkSL::Interpreter::Value*) ctx->inputs); + } + }; + rec.fPipeline->append(SkRasterPipeline::callback, ctx); + } } protected: diff --git a/src/sksl/SkSLByteCode.h b/src/sksl/SkSLByteCode.h new file mode 100644 index 0000000000..df097e4146 --- /dev/null +++ b/src/sksl/SkSLByteCode.h @@ -0,0 +1,115 @@ +/* + * Copyright 2019 Google LLC + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#ifndef SKSL_BYTECODE +#define SKSL_BYTECODE + +#include "ir/SkSLFunctionDeclaration.h" + +namespace SkSL { + +enum class ByteCodeInstruction : uint8_t { + kInvalid, + // B = bool, F = float, I = int, S = signed, U = unsigned + kAddF, + kAddI, + kAndB, + kAndI, + kBranch, + kCompareIEQ, + kCompareINEQ, + kCompareFEQ, + kCompareFGT, + kCompareFGTEQ, + kCompareFLT, + kCompareFLTEQ, + kCompareFNEQ, + kCompareSGT, + kCompareSGTEQ, + kCompareSLT, + kCompareSLTEQ, + kCompareUGT, + kCompareUGTEQ, + kCompareULT, + kCompareULTEQ, + // Followed by a 16 bit address + kConditionalBranch, + // Pops and prints the top value from the stack + kDebugPrint, + kDivideF, + kDivideS, + kDivideU, + // Duplicates the top stack value + kDup, + // Followed by a byte indicating number of slots to copy below the underlying element. + // dupdown 2 yields: ... value3 value2 value1 => .. value2 value1 value3 value2 value2 + kDupDown, + kFloatToInt, + kSignedToFloat, + kUnsignedToFloat, + kLoad, + // Followed by a byte indicating global slot to load + kLoadGlobal, + // Followed by a count byte (1-4), and then one byte per swizzle component (0-3). + kLoadSwizzle, + kNegateF, + kNegateS, + kMultiplyF, + kMultiplyS, + kMultiplyU, + kNot, + kOrB, + kOrI, + // Followed by a byte indicating parameter slot to load + kParameter, + kPop, + // Followed by a 32 bit value containing the value to push + kPushImmediate, + kRemainderS, + kRemainderU, + kStore, + kStoreGlobal, + // Followed by a count byte (1-4), and then one byte per swizzle component (0-3). Expects the + // stack to look like: ... target v1 v2 v3 v4, where the number of 'v's is equal to the number + // of swizzle components. After the store, the target and all v's are popped from the stack. + kStoreSwizzle, + // Followed by two count bytes (1-4), and then one byte per swizzle component (0-3). The first + // count byte provides the current vector size (the vector is the top n stack elements), and the + // second count byte provides the swizzle component count. + kSwizzle, + kSubtractF, + kSubtractI, + // Followed by a byte indicating vector count. Modifies the next instruction to operate on the + // indicated number of columns, e.g. kVector 2 kMultiplyf performs a float2 * float2 operation. + kVector, +}; + +struct ByteCode; + +struct ByteCodeFunction { + ByteCodeFunction(const ByteCode* owner, const FunctionDeclaration* declaration) + : fOwner(*owner) + , fDeclaration(*declaration) {} + + const ByteCode& fOwner; + const FunctionDeclaration& fDeclaration; + int fParameterCount = 0; + int fLocalCount = 0; + std::vector fCode; +}; + +struct ByteCode { + int fGlobalCount = 0; + int fInputCount = 0; + // one entry per input slot, contains the global slot to which the input slot maps + std::vector fInputSlots; + std::vector> fFunctions; +}; + +} + +#endif diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp new file mode 100644 index 0000000000..5a1299db5c --- /dev/null +++ b/src/sksl/SkSLByteCodeGenerator.cpp @@ -0,0 +1,787 @@ +/* + * Copyright 2019 Google LLC + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#include "SkSLByteCodeGenerator.h" + +namespace SkSL { + +static int slot_count(const Type& type) { + return type.columns() * type.rows(); +} + +bool ByteCodeGenerator::generateCode() { + for (const auto& e : fProgram) { + switch (e.fKind) { + case ProgramElement::kFunction_Kind: { + std::unique_ptr f = this->writeFunction((FunctionDefinition&) e); + if (!f) { + return false; + } + fOutput->fFunctions.push_back(std::move(f)); + break; + } + case ProgramElement::kVar_Kind: { + VarDeclarations& decl = (VarDeclarations&) e; + for (const auto& v : decl.fVars) { + const Variable* declVar = ((VarDeclaration&) *v).fVar; + if (declVar->fModifiers.fLayout.fBuiltin >= 0) { + continue; + } + if (declVar->fModifiers.fFlags & Modifiers::kIn_Flag) { + for (int i = slot_count(declVar->fType); i > 0; --i) { + fOutput->fInputSlots.push_back(fOutput->fGlobalCount++); + } + } else { + fOutput->fGlobalCount += slot_count(declVar->fType); + } + } + break; + } + default: + ; // ignore + } + } + return true; +} + +std::unique_ptr ByteCodeGenerator::writeFunction(const FunctionDefinition& f) { + fFunction = &f; + std::unique_ptr result(new ByteCodeFunction(fOutput, &f.fDeclaration)); + fParameterCount = 0; + for (const auto& p : f.fDeclaration.fParameters) { + fParameterCount += p->fType.columns() * p->fType.rows(); + } + fCode = &result->fCode; + this->writeStatement(*f.fBody); + result->fParameterCount = fParameterCount; + result->fLocalCount = fLocals.size(); + fLocals.clear(); + fFunction = nullptr; + return result; +} + +enum class TypeCategory { + kBool, + kSigned, + kUnsigned, + kFloat, +}; + +static TypeCategory type_category(const Type& type) { + switch (type.kind()) { + case Type::Kind::kVector_Kind: + case Type::Kind::kMatrix_Kind: + return type_category(type.componentType()); + default: + if (type.fName == "bool") { + return TypeCategory::kBool; + } else if (type.fName == "int" || type.fName == "short") { + return TypeCategory::kSigned; + } else if (type.fName == "uint" || type.fName == "ushort") { + return TypeCategory::kUnsigned; + } else { + SkASSERT(type.fName == "float" || type.fName == "half"); + return TypeCategory::kFloat; + } + ABORT("unsupported type: %s\n", type.description().c_str()); + } +} + +int ByteCodeGenerator::getLocation(const Variable& var) { + // given that we seldom have more than a couple of variables, linear search is probably the most + // efficient way to handle lookups + switch (var.fStorage) { + case Variable::kLocal_Storage: { + for (int i = fLocals.size() - 1; i >= 0; --i) { + if (fLocals[i] == &var) { + return fParameterCount + i; + } + } + int result = fParameterCount + fLocals.size(); + fLocals.push_back(&var); + for (int i = 0; i < slot_count(var.fType) - 1; ++i) { + fLocals.push_back(nullptr); + } + return result; + } + case Variable::kParameter_Storage: { + int offset = 0; + for (const auto& p : fFunction->fDeclaration.fParameters) { + if (p == &var) { + return offset; + } + offset += slot_count(p->fType); + } + SkASSERT(false); + return -1; + } + case Variable::kGlobal_Storage: { + int offset = 0; + for (const auto& e : fProgram) { + if (e.fKind == ProgramElement::kVar_Kind) { + VarDeclarations& decl = (VarDeclarations&) e; + for (const auto& v : decl.fVars) { + const Variable* declVar = ((VarDeclaration&) *v).fVar; + if (declVar->fModifiers.fLayout.fBuiltin >= 0) { + continue; + } + if (declVar == &var) { + return offset; + } + offset += slot_count(declVar->fType); + } + } + } + SkASSERT(false); + return -1; + } + default: + SkASSERT(false); + return 0; + } +} + +void ByteCodeGenerator::write8(uint8_t b) { + fCode->push_back(b); +} + +void ByteCodeGenerator::write16(uint16_t i) { + this->write8(i >> 8); + this->write8(i); +} + +void ByteCodeGenerator::write32(uint32_t i) { + this->write8(i >> 24); + this->write8(i >> 16); + this->write8(i >> 8); + this->write8(i); +} + +void ByteCodeGenerator::write(ByteCodeInstruction i) { + this->write8((uint8_t) i); +} + +void ByteCodeGenerator::writeTypedInstruction(const Type& type, ByteCodeInstruction s, + ByteCodeInstruction u, ByteCodeInstruction f) { + switch (type_category(type)) { + case TypeCategory::kSigned: + this->write(s); + break; + case TypeCategory::kUnsigned: + this->write(u); + break; + case TypeCategory::kFloat: + this->write(f); + break; + default: + SkASSERT(false); + } +} + +void ByteCodeGenerator::writeBinaryExpression(const BinaryExpression& b) { + if (b.fOperator == Token::Kind::EQ) { + std::unique_ptr lvalue = this->getLValue(*b.fLeft); + this->writeExpression(*b.fRight); + this->write(ByteCodeInstruction::kDupDown); + this->write8(slot_count(b.fRight->fType)); + lvalue->store(); + return; + } + Token::Kind op; + std::unique_ptr lvalue; + if (is_assignment(b.fOperator)) { + lvalue = this->getLValue(*b.fLeft); + lvalue->load(); + op = remove_assignment(b.fOperator); + } else { + this->writeExpression(*b.fLeft); + op = b.fOperator; + if (b.fLeft->fType.kind() == Type::kScalar_Kind && + b.fRight->fType.kind() == Type::kVector_Kind) { + for (int i = b.fRight->fType.columns(); i > 1; --i) { + this->write(ByteCodeInstruction::kDup); + } + } + } + this->writeExpression(*b.fRight); + if (b.fLeft->fType.kind() == Type::kVector_Kind && + b.fRight->fType.kind() == Type::kScalar_Kind) { + for (int i = b.fLeft->fType.columns(); i > 1; --i) { + this->write(ByteCodeInstruction::kDup); + } + } + int count = slot_count(b.fType); + if (count > 1) { + this->write(ByteCodeInstruction::kVector); + this->write8(count); + } + switch (op) { + case Token::Kind::EQEQ: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareIEQ, + ByteCodeInstruction::kCompareIEQ, + ByteCodeInstruction::kCompareFEQ); + break; + case Token::Kind::GT: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSGT, + ByteCodeInstruction::kCompareUGT, + ByteCodeInstruction::kCompareFGT); + break; + case Token::Kind::GTEQ: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSGTEQ, + ByteCodeInstruction::kCompareUGTEQ, + ByteCodeInstruction::kCompareFGTEQ); + break; + case Token::Kind::LT: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSLT, + ByteCodeInstruction::kCompareULT, + ByteCodeInstruction::kCompareFLT); + break; + case Token::Kind::LTEQ: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSLTEQ, + ByteCodeInstruction::kCompareULTEQ, + ByteCodeInstruction::kCompareFLTEQ); + break; + case Token::Kind::MINUS: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kSubtractI, + ByteCodeInstruction::kSubtractI, + ByteCodeInstruction::kSubtractF); + break; + case Token::Kind::NEQ: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareINEQ, + ByteCodeInstruction::kCompareINEQ, + ByteCodeInstruction::kCompareFNEQ); + break; + case Token::Kind::PERCENT: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kRemainderS, + ByteCodeInstruction::kRemainderU, + ByteCodeInstruction::kInvalid); + break; + case Token::Kind::PLUS: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kAddI, + ByteCodeInstruction::kAddI, + ByteCodeInstruction::kAddF); + break; + case Token::Kind::SLASH: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kDivideS, + ByteCodeInstruction::kDivideU, + ByteCodeInstruction::kDivideF); + break; + case Token::Kind::STAR: + this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kMultiplyS, + ByteCodeInstruction::kMultiplyU, + ByteCodeInstruction::kMultiplyF); + break; + default: + SkASSERT(false); + } + if (lvalue) { + this->write(ByteCodeInstruction::kDupDown); + this->write8(slot_count(b.fType)); + lvalue->store(); + } +} + +void ByteCodeGenerator::writeBoolLiteral(const BoolLiteral& b) { + this->write(ByteCodeInstruction::kPushImmediate); + this->write32(1); +} + +void ByteCodeGenerator::writeConstructor(const Constructor& c) { + if (c.fArguments.size() == 1 && + type_category(c.fType) == type_category(c.fArguments[0]->fType)) { + // cast from float to half or similar no-op + this->writeExpression(*c.fArguments[0]); + return; + } + for (const auto& arg : c.fArguments) { + this->writeExpression(*arg); + } + if (c.fArguments.size() == 1) { + TypeCategory inCategory = type_category(c.fArguments[0]->fType); + TypeCategory outCategory = type_category(c.fType); + if (inCategory != outCategory) { + int count = c.fType.columns(); + if (count > 1) { + this->write(ByteCodeInstruction::kVector); + this->write8(count); + } + if (inCategory == TypeCategory::kFloat) { + SkASSERT(outCategory == TypeCategory::kSigned || + outCategory == TypeCategory::kUnsigned); + this->write(ByteCodeInstruction::kFloatToInt); + } else if (outCategory == TypeCategory::kFloat) { + if (inCategory == TypeCategory::kSigned) { + this->write(ByteCodeInstruction::kSignedToFloat); + } else { + SkASSERT(inCategory == TypeCategory::kUnsigned); + this->write(ByteCodeInstruction::kUnsignedToFloat); + } + } else { + SkASSERT(false); + } + } + } +} + +void ByteCodeGenerator::writeFieldAccess(const FieldAccess& f) { + // not yet implemented + abort(); +} + +void ByteCodeGenerator::writeFloatLiteral(const FloatLiteral& f) { + this->write(ByteCodeInstruction::kPushImmediate); + union { float f; uint32_t u; } pun = { (float) f.fValue }; + this->write32(pun.u); +} + +void ByteCodeGenerator::writeFunctionCall(const FunctionCall& f) { + // not yet implemented + abort(); +} + +void ByteCodeGenerator::writeIndexExpression(const IndexExpression& i) { + // not yet implemented + abort(); +} + +void ByteCodeGenerator::writeIntLiteral(const IntLiteral& i) { + this->write(ByteCodeInstruction::kPushImmediate); + this->write32(i.fValue); +} + +void ByteCodeGenerator::writeNullLiteral(const NullLiteral& n) { + // not yet implemented + abort(); +} + +void ByteCodeGenerator::writePrefixExpression(const PrefixExpression& p) { + switch (p.fOperator) { + case Token::Kind::PLUSPLUS: // fall through + case Token::Kind::MINUSMINUS: { + std::unique_ptr lvalue = this->getLValue(*p.fOperand); + lvalue->load(); + this->write(ByteCodeInstruction::kPushImmediate); + this->write32(1); + if (p.fOperator == Token::Kind::PLUSPLUS) { + this->writeTypedInstruction(p.fType, + ByteCodeInstruction::kAddI, + ByteCodeInstruction::kAddI, + ByteCodeInstruction::kAddF); + } else { + this->writeTypedInstruction(p.fType, + ByteCodeInstruction::kSubtractI, + ByteCodeInstruction::kSubtractI, + ByteCodeInstruction::kSubtractF); + } + this->write(ByteCodeInstruction::kDupDown); + this->write8(slot_count(p.fType)); + lvalue->store(); + break; + } + case Token::Kind::MINUS: + this->writeTypedInstruction(p.fType, + ByteCodeInstruction::kNegateS, + ByteCodeInstruction::kInvalid, + ByteCodeInstruction::kNegateF); + break; + default: + SkASSERT(false); + } +} + +void ByteCodeGenerator::writePostfixExpression(const PostfixExpression& p) { + // not yet implemented + abort(); +} + +void ByteCodeGenerator::writeSwizzle(const Swizzle& s) { + switch (s.fBase->fKind) { + case Expression::kVariableReference_Kind: { + const Variable& var = ((VariableReference&) *s.fBase).fVariable; + int location = this->getLocation(var); + this->write(ByteCodeInstruction::kPushImmediate); + this->write32(location); + this->write(ByteCodeInstruction::kLoadSwizzle); + this->write8(s.fComponents.size()); + for (int c : s.fComponents) { + this->write8(c); + } + break; + } + default: + this->writeExpression(*s.fBase); + this->write(ByteCodeInstruction::kSwizzle); + this->write8(s.fBase->fType.columns()); + this->write8(s.fComponents.size()); + for (int c : s.fComponents) { + this->write8(c); + } + } +} + +void ByteCodeGenerator::writeVariableReference(const VariableReference& v) { + if (v.fVariable.fStorage == Variable::kGlobal_Storage) { + this->write(ByteCodeInstruction::kLoadGlobal); + int location = this->getLocation(v.fVariable); + SkASSERT(location <= 255); + this->write8(location); + } else { + this->write(ByteCodeInstruction::kPushImmediate); + this->write32(this->getLocation(v.fVariable)); + int count = slot_count(v.fType); + if (count > 1) { + this->write(ByteCodeInstruction::kVector); + this->write8(count); + } + this->write(ByteCodeInstruction::kLoad); + } +} + +void ByteCodeGenerator::writeTernaryExpression(const TernaryExpression& t) { + // not yet implemented + abort(); +} + +void ByteCodeGenerator::writeExpression(const Expression& e) { + switch (e.fKind) { + case Expression::kBinary_Kind: + this->writeBinaryExpression((BinaryExpression&) e); + break; + case Expression::kBoolLiteral_Kind: + this->writeBoolLiteral((BoolLiteral&) e); + break; + case Expression::kConstructor_Kind: + this->writeConstructor((Constructor&) e); + break; + case Expression::kFieldAccess_Kind: + this->writeFieldAccess((FieldAccess&) e); + break; + case Expression::kFloatLiteral_Kind: + this->writeFloatLiteral((FloatLiteral&) e); + break; + case Expression::kFunctionCall_Kind: + this->writeFunctionCall((FunctionCall&) e); + break; + case Expression::kIndex_Kind: + this->writeIndexExpression((IndexExpression&) e); + break; + case Expression::kIntLiteral_Kind: + this->writeIntLiteral((IntLiteral&) e); + break; + case Expression::kNullLiteral_Kind: + this->writeNullLiteral((NullLiteral&) e); + break; + case Expression::kPrefix_Kind: + this->writePrefixExpression((PrefixExpression&) e); + break; + case Expression::kPostfix_Kind: + this->writePostfixExpression((PostfixExpression&) e); + break; + case Expression::kSwizzle_Kind: + this->writeSwizzle((Swizzle&) e); + break; + case Expression::kVariableReference_Kind: + this->writeVariableReference((VariableReference&) e); + break; + case Expression::kTernary_Kind: + this->writeTernaryExpression((TernaryExpression&) e); + break; + default: + printf("unsupported expression %s\n", e.description().c_str()); + SkASSERT(false); + } +} + +void ByteCodeGenerator::writeTarget(const Expression& e) { + switch (e.fKind) { + case Expression::kVariableReference_Kind: + this->write(ByteCodeInstruction::kPushImmediate); + this->write32(this->getLocation(((VariableReference&) e).fVariable)); + break; + case Expression::kIndex_Kind: + case Expression::kTernary_Kind: + default: + printf("unsupported target %s\n", e.description().c_str()); + SkASSERT(false); + } +} + +class ByteCodeSwizzleLValue : public ByteCodeGenerator::LValue { +public: + ByteCodeSwizzleLValue(ByteCodeGenerator* generator, const Swizzle& swizzle) + : INHERITED(*generator) + , fSwizzle(swizzle) { + fGenerator.writeTarget(*swizzle.fBase); + } + + void load() override { + fGenerator.write(ByteCodeInstruction::kDup); + fGenerator.write(ByteCodeInstruction::kLoadSwizzle); + fGenerator.write8(fSwizzle.fComponents.size()); + for (int c : fSwizzle.fComponents) { + fGenerator.write8(c); + } + } + + void store() override { + fGenerator.write(ByteCodeInstruction::kStoreSwizzle); + fGenerator.write8(fSwizzle.fComponents.size()); + for (int c : fSwizzle.fComponents) { + fGenerator.write8(c); + } + } + +private: + const Swizzle& fSwizzle; + + typedef LValue INHERITED; +}; + +class ByteCodeVariableLValue : public ByteCodeGenerator::LValue { +public: + ByteCodeVariableLValue(ByteCodeGenerator* generator, const Variable& var) + : INHERITED(*generator) + , fCount(slot_count(var.fType)) + , fIsGlobal(var.fStorage == Variable::kGlobal_Storage) { + fGenerator.write(ByteCodeInstruction::kPushImmediate); + fGenerator.write32(generator->getLocation(var)); + } + + void load() override { + fGenerator.write(ByteCodeInstruction::kDup); + if (fCount > 1) { + fGenerator.write(ByteCodeInstruction::kVector); + fGenerator.write8(fCount); + } + fGenerator.write(fIsGlobal ? ByteCodeInstruction::kLoadGlobal : ByteCodeInstruction::kLoad); + } + + void store() override { + if (fCount > 1) { + fGenerator.write(ByteCodeInstruction::kVector); + fGenerator.write8(fCount); + } + fGenerator.write(fIsGlobal ? ByteCodeInstruction::kStoreGlobal + : ByteCodeInstruction::kStore); + } + +private: + typedef LValue INHERITED; + + int fCount; + + bool fIsGlobal; +}; + +std::unique_ptr ByteCodeGenerator::getLValue(const Expression& e) { + switch (e.fKind) { + case Expression::kIndex_Kind: + // not yet implemented + abort(); + case Expression::kVariableReference_Kind: + return std::unique_ptr(new ByteCodeVariableLValue(this, + ((VariableReference&) e).fVariable)); + case Expression::kSwizzle_Kind: + return std::unique_ptr(new ByteCodeSwizzleLValue(this, (Swizzle&) e)); + case Expression::kTernary_Kind: + default: + printf("unsupported lvalue %s\n", e.description().c_str()); + return nullptr; + } +} + +void ByteCodeGenerator::writeBlock(const Block& b) { + for (const auto& s : b.fStatements) { + this->writeStatement(*s); + } +} + +void ByteCodeGenerator::setBreakTargets() { + std::vector& breaks = fBreakTargets.top(); + for (DeferredLocation& b : breaks) { + b.set(); + } + fBreakTargets.pop(); +} + +void ByteCodeGenerator::setContinueTargets() { + std::vector& continues = fContinueTargets.top(); + for (DeferredLocation& c : continues) { + c.set(); + } + fContinueTargets.pop(); +} + +void ByteCodeGenerator::writeBreakStatement(const BreakStatement& b) { + this->write(ByteCodeInstruction::kBranch); + fBreakTargets.top().emplace_back(this); +} + +void ByteCodeGenerator::writeContinueStatement(const ContinueStatement& c) { + this->write(ByteCodeInstruction::kBranch); + fContinueTargets.top().emplace_back(this); +} + +void ByteCodeGenerator::writeDoStatement(const DoStatement& d) { + fContinueTargets.emplace(); + fBreakTargets.emplace(); + size_t start = fCode->size(); + this->writeStatement(*d.fStatement); + this->setContinueTargets(); + this->writeExpression(*d.fTest); + this->write(ByteCodeInstruction::kConditionalBranch); + this->write16(start); + this->setBreakTargets(); +} + +void ByteCodeGenerator::writeForStatement(const ForStatement& f) { + fContinueTargets.emplace(); + fBreakTargets.emplace(); + if (f.fInitializer) { + this->writeStatement(*f.fInitializer); + } + size_t start = fCode->size(); + if (f.fTest) { + this->writeExpression(*f.fTest); + this->write(ByteCodeInstruction::kNot); + this->write(ByteCodeInstruction::kConditionalBranch); + DeferredLocation endLocation(this); + this->writeStatement(*f.fStatement); + this->setContinueTargets(); + if (f.fNext) { + this->writeExpression(*f.fNext); + this->write(ByteCodeInstruction::kPop); + this->write8(slot_count(f.fNext->fType)); + } + this->write(ByteCodeInstruction::kBranch); + this->write16(start); + endLocation.set(); + } else { + this->writeStatement(*f.fStatement); + this->setContinueTargets(); + if (f.fNext) { + this->writeExpression(*f.fNext); + this->write(ByteCodeInstruction::kPop); + this->write8(slot_count(f.fNext->fType)); + } + this->write(ByteCodeInstruction::kBranch); + this->write16(start); + } + this->setBreakTargets(); +} + +void ByteCodeGenerator::writeIfStatement(const IfStatement& i) { + this->writeExpression(*i.fTest); + this->write(ByteCodeInstruction::kNot); + this->write(ByteCodeInstruction::kConditionalBranch); + DeferredLocation elseLocation(this); + this->writeStatement(*i.fIfTrue); + this->write(ByteCodeInstruction::kBranch); + DeferredLocation endLocation(this); + elseLocation.set(); + if (i.fIfFalse) { + this->writeStatement(*i.fIfFalse); + } + endLocation.set(); +} + +void ByteCodeGenerator::writeReturnStatement(const ReturnStatement& r) { + // not yet implemented + abort(); +} + +void ByteCodeGenerator::writeSwitchStatement(const SwitchStatement& r) { + // not yet implemented + abort(); +} + +void ByteCodeGenerator::writeVarDeclarations(const VarDeclarations& v) { + for (const auto& declStatement : v.fVars) { + const VarDeclaration& decl = (VarDeclaration&) *declStatement; + // we need to grab the location even if we don't use it, to ensure it + // has been allocated + int location = getLocation(*decl.fVar); + if (decl.fValue) { + this->write(ByteCodeInstruction::kPushImmediate); + this->write32(location); + this->writeExpression(*decl.fValue); + int count = slot_count(decl.fValue->fType); + if (count > 1) { + this->write(ByteCodeInstruction::kVector); + this->write8(count); + } + this->write(ByteCodeInstruction::kStore); + } + } +} + +void ByteCodeGenerator::writeWhileStatement(const WhileStatement& w) { + fContinueTargets.emplace(); + fBreakTargets.emplace(); + size_t start = fCode->size(); + this->writeExpression(*w.fTest); + this->write(ByteCodeInstruction::kNot); + this->write(ByteCodeInstruction::kConditionalBranch); + DeferredLocation endLocation(this); + this->writeStatement(*w.fStatement); + this->setContinueTargets(); + this->write(ByteCodeInstruction::kBranch); + this->write16(start); + endLocation.set(); + this->setBreakTargets(); +} + +void ByteCodeGenerator::writeStatement(const Statement& s) { + switch (s.fKind) { + case Statement::kBlock_Kind: + this->writeBlock((Block&) s); + break; + case Statement::kBreak_Kind: + this->writeBreakStatement((BreakStatement&) s); + break; + case Statement::kContinue_Kind: + this->writeContinueStatement((ContinueStatement&) s); + break; + case Statement::kDiscard_Kind: + // not yet implemented + abort(); + case Statement::kDo_Kind: + this->writeDoStatement((DoStatement&) s); + break; + case Statement::kExpression_Kind: { + const Expression& expr = *((ExpressionStatement&) s).fExpression; + this->writeExpression(expr); + this->write(ByteCodeInstruction::kPop); + this->write8(slot_count(expr.fType)); + break; + } + case Statement::kFor_Kind: + this->writeForStatement((ForStatement&) s); + break; + case Statement::kIf_Kind: + this->writeIfStatement((IfStatement&) s); + break; + case Statement::kNop_Kind: + break; + case Statement::kReturn_Kind: + this->writeReturnStatement((ReturnStatement&) s); + break; + case Statement::kSwitch_Kind: + this->writeSwitchStatement((SwitchStatement&) s); + break; + case Statement::kVarDeclarations_Kind: + this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration); + break; + case Statement::kWhile_Kind: + this->writeWhileStatement((WhileStatement&) s); + break; + default: + SkASSERT(false); + } +} + +} diff --git a/src/sksl/SkSLByteCodeGenerator.h b/src/sksl/SkSLByteCodeGenerator.h new file mode 100644 index 0000000000..e518861488 --- /dev/null +++ b/src/sksl/SkSLByteCodeGenerator.h @@ -0,0 +1,241 @@ +/* + * Copyright 2019 Google LLC + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#ifndef SKSL_BYTECODEGENERATOR +#define SKSL_BYTECODEGENERATOR + +#include +#include +#include + +#include "SkSLByteCode.h" +#include "SkSLCodeGenerator.h" +#include "SkSLMemoryLayout.h" +#include "ir/SkSLBinaryExpression.h" +#include "ir/SkSLBoolLiteral.h" +#include "ir/SkSLBlock.h" +#include "ir/SkSLBreakStatement.h" +#include "ir/SkSLConstructor.h" +#include "ir/SkSLContinueStatement.h" +#include "ir/SkSLDoStatement.h" +#include "ir/SkSLExpressionStatement.h" +#include "ir/SkSLFloatLiteral.h" +#include "ir/SkSLIfStatement.h" +#include "ir/SkSLIndexExpression.h" +#include "ir/SkSLInterfaceBlock.h" +#include "ir/SkSLIntLiteral.h" +#include "ir/SkSLFieldAccess.h" +#include "ir/SkSLForStatement.h" +#include "ir/SkSLFunctionCall.h" +#include "ir/SkSLFunctionDeclaration.h" +#include "ir/SkSLFunctionDefinition.h" +#include "ir/SkSLNullLiteral.h" +#include "ir/SkSLPrefixExpression.h" +#include "ir/SkSLPostfixExpression.h" +#include "ir/SkSLProgramElement.h" +#include "ir/SkSLReturnStatement.h" +#include "ir/SkSLStatement.h" +#include "ir/SkSLSwitchStatement.h" +#include "ir/SkSLSwizzle.h" +#include "ir/SkSLTernaryExpression.h" +#include "ir/SkSLVarDeclarations.h" +#include "ir/SkSLVarDeclarationsStatement.h" +#include "ir/SkSLVariableReference.h" +#include "ir/SkSLWhileStatement.h" +#include "spirv.h" + +namespace SkSL { + +class ByteCodeGenerator : public CodeGenerator { +public: + class LValue { + public: + LValue(ByteCodeGenerator& generator) + : fGenerator(generator) {} + + virtual ~LValue() {} + + /** + * Stack before call: ... lvalue + * Stack after call: ... lvalue load + */ + virtual void load() = 0; + + /** + * Stack before call: ... lvalue value + * Stack after call: ... + */ + virtual void store() = 0; + + protected: + ByteCodeGenerator& fGenerator; + }; + + ByteCodeGenerator(const Context* context, const Program* program, ErrorReporter* errors, + ByteCode* output) + : INHERITED(program, errors, nullptr) + , fContext(*context) + , fOutput(output) {} + + bool generateCode() override; + + void write8(uint8_t b); + + void write16(uint16_t b); + + void write32(uint32_t b); + + void write(ByteCodeInstruction inst); + + /** + * Based on 'type', writes the s (signed), u (unsigned), or f (float) instruction. + */ + void writeTypedInstruction(const Type& type, ByteCodeInstruction s, ByteCodeInstruction u, + ByteCodeInstruction f); + + /** + * Pushes the storage location of an lvalue to the stack. + */ + void writeTarget(const Expression& expr); + +private: + // reserves 16 bits in the output code, to be filled in later with an address once we determine + // it + class DeferredLocation { + public: + DeferredLocation(ByteCodeGenerator* generator) + : fGenerator(*generator) + , fOffset(generator->fCode->size()) { + generator->write16(0); + } + +#ifdef SK_DEBUG + ~DeferredLocation() { + SkASSERT(fSet); + } +#endif + + void set() { + int target = fGenerator.fCode->size(); + SkASSERT(target <= 65535); + (*fGenerator.fCode)[fOffset] = target >> 8; + (*fGenerator.fCode)[fOffset + 1] = target; +#ifdef SK_DEBUG + fSet = true; +#endif + } + + private: + ByteCodeGenerator& fGenerator; + size_t fOffset; +#ifdef SK_DEBUG + bool fSet = false; +#endif + }; + + /** + * Returns the local slot into which var should be stored, allocating a new slot if it has not + * already been assigned one. Compound variables (e.g. vectors) will consume more than one local + * slot, with the getLocation return value indicating where the first element should be stored. + */ + int getLocation(const Variable& var); + + std::unique_ptr writeFunction(const FunctionDefinition& f); + + void writeVarDeclarations(const VarDeclarations& decl); + + void writeVariableReference(const VariableReference& ref); + + void writeExpression(const Expression& expr); + + /** + * Pushes whatever values are required by the lvalue onto the stack, and returns an LValue + * permitting loads and stores to it. + */ + std::unique_ptr getLValue(const Expression& expr); + + void writeFunctionCall(const FunctionCall& c); + + void writeConstructor(const Constructor& c); + + void writeFieldAccess(const FieldAccess& f); + + void writeSwizzle(const Swizzle& swizzle); + + void writeBinaryExpression(const BinaryExpression& b); + + void writeTernaryExpression(const TernaryExpression& t); + + void writeIndexExpression(const IndexExpression& expr); + + void writeLogicalAnd(const BinaryExpression& b); + + void writeLogicalOr(const BinaryExpression& o); + + void writeNullLiteral(const NullLiteral& n); + + void writePrefixExpression(const PrefixExpression& p); + + void writePostfixExpression(const PostfixExpression& p); + + void writeBoolLiteral(const BoolLiteral& b); + + void writeIntLiteral(const IntLiteral& i); + + void writeFloatLiteral(const FloatLiteral& f); + + void writeStatement(const Statement& s); + + void writeBlock(const Block& b); + + void writeBreakStatement(const BreakStatement& b); + + void writeContinueStatement(const ContinueStatement& c); + + void writeIfStatement(const IfStatement& stmt); + + void writeForStatement(const ForStatement& f); + + void writeWhileStatement(const WhileStatement& w); + + void writeDoStatement(const DoStatement& d); + + void writeSwitchStatement(const SwitchStatement& s); + + void writeReturnStatement(const ReturnStatement& r); + + // updates the current set of breaks to branch to the current location + void setBreakTargets(); + + // updates the current set of continues to branch to the current location + void setContinueTargets(); + + const Context& fContext; + + ByteCode* fOutput; + + const FunctionDefinition* fFunction; + + std::vector* fCode; + + std::vector fLocals; + + std::stack> fContinueTargets; + + std::stack> fBreakTargets; + + int fParameterCount; + + friend class DeferredLocation; + friend class ByteCodeVariableLValue; + + typedef CodeGenerator INHERITED; +}; + +} + +#endif diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp index b19c94a4d1..720fcb663a 100644 --- a/src/sksl/SkSLCompiler.cpp +++ b/src/sksl/SkSLCompiler.cpp @@ -7,6 +7,7 @@ #include "SkSLCompiler.h" +#include "SkSLByteCodeGenerator.h" #include "SkSLCFGGenerator.h" #include "SkSLCPPCodeGenerator.h" #include "SkSLGLSLCodeGenerator.h" @@ -1467,6 +1468,18 @@ bool Compiler::toPipelineStage(const Program& program, String* out, return result; } +std::unique_ptr Compiler::toByteCode(Program& program) { + if (!this->optimize(program)) { + return nullptr; + } + std::unique_ptr result(new ByteCode()); + ByteCodeGenerator cg(fContext.get(), &program, this, result.get()); + if (cg.generateCode()) { + return result; + } + return nullptr; +} + const char* Compiler::OperatorName(Token::Kind kind) { switch (kind) { case Token::PLUS: return "+"; diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h index eefcf2a6c5..9b87925896 100644 --- a/src/sksl/SkSLCompiler.h +++ b/src/sksl/SkSLCompiler.h @@ -13,6 +13,7 @@ #include #include "ir/SkSLProgram.h" #include "ir/SkSLSymbolTable.h" +#include "SkSLByteCode.h" #include "SkSLCFGGenerator.h" #include "SkSLContext.h" #include "SkSLErrorReporter.h" @@ -111,6 +112,8 @@ public: bool toH(Program& program, String name, OutputStream& out); + std::unique_ptr toByteCode(Program& program); + bool toPipelineStage(const Program& program, String* out, std::vector* outFormatArgs); diff --git a/src/sksl/SkSLInterpreter.cpp b/src/sksl/SkSLInterpreter.cpp index c39c7b2395..f427cf212c 100644 --- a/src/sksl/SkSLInterpreter.cpp +++ b/src/sksl/SkSLInterpreter.cpp @@ -27,189 +27,42 @@ namespace SkSL { -void Interpreter::run() { - for (const auto& e : *fProgram) { - if (ProgramElement::kFunction_Kind == e.fKind) { - const FunctionDefinition& f = (const FunctionDefinition&) e; - if ("appendStages" == f.fDeclaration.fName) { - this->run(f); - return; +static constexpr int UNINITIALIZED = 0xDEADBEEF; + +Interpreter::Value Interpreter::run(const ByteCodeFunction& f, Interpreter::Value args[], + Interpreter::Value inputs[]) { + fIP = 0; + fCurrentFunction = &f; + fStack.clear(); + fGlobals.clear(); +#ifdef TRACE + this->disassemble(f); +#endif + for (int i = 0; i < f.fParameterCount; ++i) { + this->push(args[i]); + } + for (int i = 0; i < f.fLocalCount; ++i) { + this->push(Value((int) UNINITIALIZED)); + } + for (int i = 0; i < f.fOwner.fGlobalCount; ++i) { + fGlobals.push_back(Value((int) UNINITIALIZED)); + } + for (int i = f.fOwner.fInputSlots.size() - 1; i >= 0; --i) { + fGlobals[f.fOwner.fInputSlots[i]] = inputs[i]; + } + run(); + int offset = 0; + for (const auto& p : f.fDeclaration.fParameters) { + if (p->fModifiers.fFlags & Modifiers::kOut_Flag) { + for (int i = p->fType.columns() * p->fType.rows() - 1; i >= 0; --i) { + args[offset] = fStack[offset]; + ++offset; } + } else { + offset += p->fType.columns() * p->fType.rows(); } } - SkASSERT(false); -} - -static int SizeOf(const Type& type) { - return 1; -} - -void Interpreter::run(const FunctionDefinition& f) { - fVars.emplace_back(); - StackIndex current = (StackIndex) fStack.size(); - for (int i = f.fDeclaration.fParameters.size() - 1; i >= 0; --i) { - current -= SizeOf(f.fDeclaration.fParameters[i]->fType); - fVars.back()[f.fDeclaration.fParameters[i]] = current; - } - fCurrentIndex.push_back({ f.fBody.get(), 0 }); - while (fCurrentIndex.size()) { - this->runStatement(); - } -} - -void Interpreter::push(Value value) { - fStack.push_back(value); -} - -Interpreter::Value Interpreter::pop() { - auto iter = fStack.end() - 1; - Value result = *iter; - fStack.erase(iter); - return result; -} - - Interpreter::StackIndex Interpreter::stackAlloc(int count) { - int result = fStack.size(); - for (int i = 0; i < count; ++i) { - fStack.push_back(Value((int) 0xDEADBEEF)); - } - return result; -} - -void Interpreter::runStatement() { - const Statement& stmt = *fCurrentIndex.back().fStatement; - const size_t index = fCurrentIndex.back().fIndex; - fCurrentIndex.pop_back(); - switch (stmt.fKind) { - case Statement::kBlock_Kind: { - const Block& b = (const Block&) stmt; - if (!b.fStatements.size()) { - break; - } - SkASSERT(index < b.fStatements.size()); - if (index < b.fStatements.size() - 1) { - fCurrentIndex.push_back({ &b, index + 1 }); - } - fCurrentIndex.push_back({ b.fStatements[index].get(), 0 }); - break; - } - case Statement::kBreak_Kind: - SkASSERT(index == 0); - abort(); - case Statement::kContinue_Kind: - SkASSERT(index == 0); - abort(); - case Statement::kDiscard_Kind: - SkASSERT(index == 0); - abort(); - case Statement::kDo_Kind: - abort(); - case Statement::kExpression_Kind: - SkASSERT(index == 0); - this->evaluate(*((const ExpressionStatement&) stmt).fExpression); - break; - case Statement::kFor_Kind: { - ForStatement& f = (ForStatement&) stmt; - switch (index) { - case 0: - // initializer - fCurrentIndex.push_back({ &f, 1 }); - if (f.fInitializer) { - fCurrentIndex.push_back({ f.fInitializer.get(), 0 }); - } - break; - case 1: - // test & body - if (f.fTest && !evaluate(*f.fTest).fBool) { - break; - } else { - fCurrentIndex.push_back({ &f, 2 }); - fCurrentIndex.push_back({ f.fStatement.get(), 0 }); - } - break; - case 2: - // next - if (f.fNext) { - this->evaluate(*f.fNext); - } - fCurrentIndex.push_back({ &f, 1 }); - break; - default: - SkASSERT(false); - } - break; - } - case Statement::kGroup_Kind: - abort(); - case Statement::kIf_Kind: { - IfStatement& i = (IfStatement&) stmt; - if (evaluate(*i.fTest).fBool) { - fCurrentIndex.push_back({ i.fIfTrue.get(), 0 }); - } else if (i.fIfFalse) { - fCurrentIndex.push_back({ i.fIfFalse.get(), 0 }); - } - break; - } - case Statement::kNop_Kind: - SkASSERT(index == 0); - break; - case Statement::kReturn_Kind: - SkASSERT(index == 0); - abort(); - case Statement::kSwitch_Kind: - abort(); - case Statement::kVarDeclarations_Kind: - SkASSERT(index == 0); - for (const auto& decl :((const VarDeclarationsStatement&) stmt).fDeclaration->fVars) { - const Variable* var = ((VarDeclaration&) *decl).fVar; - StackIndex pos = this->stackAlloc(SizeOf(var->fType)); - fVars.back()[var] = pos; - if (var->fInitialValue) { - fStack[pos] = this->evaluate(*var->fInitialValue); - } - } - break; - case Statement::kWhile_Kind: - abort(); - default: - abort(); - } -} - -static Interpreter::TypeKind type_kind(const Type& type) { - if (type.fName == "int") { - return Interpreter::kInt_TypeKind; - } else if (type.fName == "float") { - return Interpreter::kFloat_TypeKind; - } - ABORT("unsupported type: %s\n", type.description().c_str()); -} - -Interpreter::StackIndex Interpreter::getLValue(const Expression& expr) { - switch (expr.fKind) { - case Expression::kFieldAccess_Kind: - break; - case Expression::kIndex_Kind: { - const IndexExpression& idx = (const IndexExpression&) expr; - return this->evaluate(*idx.fBase).fInt + this->evaluate(*idx.fIndex).fInt; - } - case Expression::kSwizzle_Kind: - break; - case Expression::kVariableReference_Kind: - SkASSERT(fVars.size()); - SkASSERT(fVars.back().find(&((VariableReference&) expr).fVariable) != - fVars.back().end()); - return fVars.back()[&((VariableReference&) expr).fVariable]; - case Expression::kTernary_Kind: { - const TernaryExpression& t = (const TernaryExpression&) expr; - return this->getLValue(this->evaluate(*t.fTest).fBool ? *t.fIfTrue : *t.fIfFalse); - } - case Expression::kTypeReference_Kind: - break; - default: - break; - } - ABORT("unsupported lvalue"); + return fReturnValue; } struct CallbackCtx : public SkRasterPipeline_CallbackCtx { @@ -217,254 +70,421 @@ struct CallbackCtx : public SkRasterPipeline_CallbackCtx { const FunctionDefinition* fFunction; }; -static void do_callback(SkRasterPipeline_CallbackCtx* raw, int activePixels) { - CallbackCtx& ctx = (CallbackCtx&) *raw; - for (int i = 0; i < activePixels; ++i) { - ctx.fInterpreter->push(Interpreter::Value(ctx.rgba[i * 4 + 0])); - ctx.fInterpreter->push(Interpreter::Value(ctx.rgba[i * 4 + 1])); - ctx.fInterpreter->push(Interpreter::Value(ctx.rgba[i * 4 + 2])); - ctx.fInterpreter->run(*ctx.fFunction); - ctx.read_from[i * 4 + 2] = ctx.fInterpreter->pop().fFloat; - ctx.read_from[i * 4 + 1] = ctx.fInterpreter->pop().fFloat; - ctx.read_from[i * 4 + 0] = ctx.fInterpreter->pop().fFloat; +uint8_t Interpreter::read8() { + return fCurrentFunction->fCode[fIP++]; +} + +uint16_t Interpreter::read16() { + uint16_t result = (fCurrentFunction->fCode[fIP ] << 8) + + fCurrentFunction->fCode[fIP + 1]; + fIP += 2; + return result; +} + +uint32_t Interpreter::read32() { + uint32_t result = (fCurrentFunction->fCode[fIP] << 24) + + (fCurrentFunction->fCode[fIP + 1] << 16) + + (fCurrentFunction->fCode[fIP + 2] << 8) + + fCurrentFunction->fCode[fIP + 3]; + fIP += 4; + return result; +} + +void Interpreter::push(Value v) { + fStack.push_back(v); +} + +Interpreter::Value Interpreter::pop() { + Value v = fStack.back(); + fStack.pop_back(); + return v; +} + +static String value_string(uint32_t v) { + union { uint32_t u; float f; } pun = { v }; + return to_string(v) + "(" + to_string(pun.f) + ")"; +} + +void Interpreter::disassemble(const ByteCodeFunction& f) { + SkASSERT(fIP == 0); + while (fIP < (int) f.fCode.size()) { + printf("%d: ", fIP); + switch ((ByteCodeInstruction) this->read8()) { + case ByteCodeInstruction::kAddF: printf("addf"); break; + case ByteCodeInstruction::kAddI: printf("addi"); break; + case ByteCodeInstruction::kAndB: printf("andb"); break; + case ByteCodeInstruction::kAndI: printf("andi"); break; + case ByteCodeInstruction::kBranch: printf("branch %d", this->read16()); break; + case ByteCodeInstruction::kCompareIEQ: printf("comparei eq"); break; + case ByteCodeInstruction::kCompareINEQ: printf("comparei neq"); break; + case ByteCodeInstruction::kCompareFEQ: printf("comparef eq"); break; + case ByteCodeInstruction::kCompareFGT: printf("comparef gt"); break; + case ByteCodeInstruction::kCompareFGTEQ: printf("comparef gteq"); break; + case ByteCodeInstruction::kCompareFLT: printf("comparef lt"); break; + case ByteCodeInstruction::kCompareFLTEQ: printf("comparef lteq"); break; + case ByteCodeInstruction::kCompareFNEQ: printf("comparef neq"); break; + case ByteCodeInstruction::kCompareSGT: printf("compares sgt"); break; + case ByteCodeInstruction::kCompareSGTEQ: printf("compares sgteq"); break; + case ByteCodeInstruction::kCompareSLT: printf("compares lt"); break; + case ByteCodeInstruction::kCompareSLTEQ: printf("compares lteq"); break; + case ByteCodeInstruction::kCompareUGT: printf("compareu gt"); break; + case ByteCodeInstruction::kCompareUGTEQ: printf("compareu gteq"); break; + case ByteCodeInstruction::kCompareULT: printf("compareu lt"); break; + case ByteCodeInstruction::kCompareULTEQ: printf("compareu lteq"); break; + case ByteCodeInstruction::kConditionalBranch: + printf("conditionalbranch %d", this->read16()); + break; + case ByteCodeInstruction::kDebugPrint: printf("debugprint"); break; + case ByteCodeInstruction::kDivideF: printf("dividef"); break; + case ByteCodeInstruction::kDivideS: printf("divides"); break; + case ByteCodeInstruction::kDivideU: printf("divideu"); break; + case ByteCodeInstruction::kDup: printf("dup"); break; + case ByteCodeInstruction::kDupDown: printf("dupdown %d", this->read8()); break; + case ByteCodeInstruction::kFloatToInt: printf("floattoint"); break; + case ByteCodeInstruction::kLoad: printf("load"); break; + case ByteCodeInstruction::kLoadGlobal: printf("loadglobal"); break; + case ByteCodeInstruction::kLoadSwizzle: { + int count = this->read8(); + printf("loadswizzle %d", count); + for (int i = 0; i < count; ++i) { + printf(", %d", this->read8()); + } + break; + } + case ByteCodeInstruction::kMultiplyF: printf("multiplyf"); break; + case ByteCodeInstruction::kMultiplyS: printf("multiplys"); break; + case ByteCodeInstruction::kMultiplyU: printf("multiplyu"); break; + case ByteCodeInstruction::kNegateF: printf("negatef"); break; + case ByteCodeInstruction::kNegateS: printf("negates"); break; + case ByteCodeInstruction::kNot: printf("not"); break; + case ByteCodeInstruction::kOrB: printf("orb"); break; + case ByteCodeInstruction::kOrI: printf("ori"); break; + case ByteCodeInstruction::kParameter: printf("parameter"); break; + case ByteCodeInstruction::kPop: printf("pop %d", this->read8()); break; + case ByteCodeInstruction::kPushImmediate: + printf("pushimmediate %s", value_string(this->read32()).c_str()); + break; + case ByteCodeInstruction::kRemainderS: printf("remainders"); break; + case ByteCodeInstruction::kRemainderU: printf("remainderu"); break; + case ByteCodeInstruction::kSignedToFloat: printf("signedtofloat"); break; + case ByteCodeInstruction::kStore: printf("store"); break; + case ByteCodeInstruction::kStoreSwizzle: { + int count = this->read8(); + printf("storeswizzle %d", count); + for (int i = 0; i < count; ++i) { + printf(", %d", this->read8()); + } + break; + } + case ByteCodeInstruction::kSubtractF: printf("subtractf"); break; + case ByteCodeInstruction::kSubtractI: printf("subtracti"); break; + case ByteCodeInstruction::kSwizzle: { + printf("swizzle %d, ", this->read8()); + int count = this->read8(); + printf("%d", count); + for (int i = 0; i < count; ++i) { + printf(", %d", this->read8()); + } + break; + } + case ByteCodeInstruction::kUnsignedToFloat: printf("unsignedtofloat"); break; + case ByteCodeInstruction::kVector: printf("vector%d", this->read8()); break; + default: SkASSERT(false); + } + printf("\n"); } + fIP = 0; } -void Interpreter::appendStage(const AppendStage& a) { - switch (a.fStage) { - case SkRasterPipeline::matrix_4x5: { - SkASSERT(a.fArguments.size() == 1); - StackIndex transpose = evaluate(*a.fArguments[0]).fInt; - fPipeline.append(SkRasterPipeline::matrix_4x5, &fStack[transpose]); - break; - } - case SkRasterPipeline::callback: { - SkASSERT(a.fArguments.size() == 1); - CallbackCtx* ctx = new CallbackCtx(); - ctx->fInterpreter = this; - ctx->fn = do_callback; - for (const auto& e : *fProgram) { - if (ProgramElement::kFunction_Kind == e.fKind) { - const FunctionDefinition& f = (const FunctionDefinition&) e; - if (&f.fDeclaration == - ((const FunctionReference&) *a.fArguments[0]).fFunctions[0]) { - ctx->fFunction = &f; - } - } - } - fPipeline.append(SkRasterPipeline::callback, ctx); - break; - } - default: - fPipeline.append(a.fStage); +void Interpreter::dumpStack() { + printf("STACK:"); + for (size_t i = 0; i < fStack.size(); ++i) { + printf(" %d(%f)", fStack[i].fSigned, fStack[i].fFloat); } + printf("\n"); } -Interpreter::Value Interpreter::call(const FunctionCall& c) { - abort(); -} +#define BINARY_OP(inst, type, field, op) \ + case ByteCodeInstruction::inst: { \ + type b = this->pop().field; \ + type a = this->pop().field; \ + this->push(Value(a op b)); \ + break; \ + } -Interpreter::Value Interpreter::evaluate(const Expression& expr) { - switch (expr.fKind) { - case Expression::kAppendStage_Kind: - this->appendStage((const AppendStage&) expr); - return Value((int) 0xDEADBEEF); - case Expression::kBinary_Kind: { - #define ARITHMETIC(op) { \ - Value left = this->evaluate(*b.fLeft); \ - Value right = this->evaluate(*b.fRight); \ - switch (type_kind(b.fLeft->fType)) { \ - case kFloat_TypeKind: \ - return Value(left.fFloat op right.fFloat); \ - case kInt_TypeKind: \ - return Value(left.fInt op right.fInt); \ - default: \ - abort(); \ - } \ - } - #define BITWISE(op) { \ - Value left = this->evaluate(*b.fLeft); \ - Value right = this->evaluate(*b.fRight); \ - switch (type_kind(b.fLeft->fType)) { \ - case kInt_TypeKind: \ - return Value(left.fInt op right.fInt); \ - default: \ - abort(); \ - } \ - } - #define LOGIC(op) { \ - Value left = this->evaluate(*b.fLeft); \ - Value right = this->evaluate(*b.fRight); \ - switch (type_kind(b.fLeft->fType)) { \ - case kFloat_TypeKind: \ - return Value(left.fFloat op right.fFloat); \ - case kInt_TypeKind: \ - return Value(left.fInt op right.fInt); \ - default: \ - abort(); \ - } \ - } - #define COMPOUND_ARITHMETIC(op) { \ - StackIndex left = this->getLValue(*b.fLeft); \ - Value right = this->evaluate(*b.fRight); \ - Value result = fStack[left]; \ - switch (type_kind(b.fLeft->fType)) { \ - case kFloat_TypeKind: \ - result.fFloat op right.fFloat; \ - break; \ - case kInt_TypeKind: \ - result.fInt op right.fInt; \ - break; \ - default: \ - abort(); \ - } \ - fStack[left] = result; \ - return result; \ - } - #define COMPOUND_BITWISE(op) { \ - StackIndex left = this->getLValue(*b.fLeft); \ - Value right = this->evaluate(*b.fRight); \ - Value result = fStack[left]; \ - switch (type_kind(b.fLeft->fType)) { \ - case kInt_TypeKind: \ - result.fInt op right.fInt; \ - break; \ - default: \ - abort(); \ - } \ - fStack[left] = result; \ - return result; \ - } - const BinaryExpression& b = (const BinaryExpression&) expr; - switch (b.fOperator) { - case Token::PLUS: ARITHMETIC(+) - case Token::MINUS: ARITHMETIC(-) - case Token::STAR: ARITHMETIC(*) - case Token::SLASH: ARITHMETIC(/) - case Token::BITWISEAND: BITWISE(&) - case Token::BITWISEOR: BITWISE(|) - case Token::BITWISEXOR: BITWISE(^) - case Token::LT: LOGIC(<) - case Token::GT: LOGIC(>) - case Token::LTEQ: LOGIC(<=) - case Token::GTEQ: LOGIC(>=) - case Token::LOGICALAND: { - Value result = this->evaluate(*b.fLeft); - if (result.fBool) { - result = this->evaluate(*b.fRight); - } - return result; - } - case Token::LOGICALOR: { - Value result = this->evaluate(*b.fLeft); - if (!result.fBool) { - result = this->evaluate(*b.fRight); - } - return result; - } - case Token::EQ: { - StackIndex left = this->getLValue(*b.fLeft); - Value right = this->evaluate(*b.fRight); - fStack[left] = right; - return right; - } - case Token::PLUSEQ: COMPOUND_ARITHMETIC(+=) - case Token::MINUSEQ: COMPOUND_ARITHMETIC(-=) - case Token::STAREQ: COMPOUND_ARITHMETIC(*=) - case Token::SLASHEQ: COMPOUND_ARITHMETIC(/=) - case Token::BITWISEANDEQ: COMPOUND_BITWISE(&=) - case Token::BITWISEOREQ: COMPOUND_BITWISE(|=) - case Token::BITWISEXOREQ: COMPOUND_BITWISE(^=) - default: - ABORT("unsupported operator: %s\n", expr.description().c_str()); +void Interpreter::next() { +#ifdef TRACE + printf("at %d\n", fIP); +#endif + ByteCodeInstruction inst = (ByteCodeInstruction) this->read8(); + switch (inst) { + BINARY_OP(kAddI, int32_t, fSigned, +) + BINARY_OP(kAddF, float, fFloat, +) + case ByteCodeInstruction::kBranch: + fIP = this->read16(); + break; + BINARY_OP(kCompareIEQ, int32_t, fSigned, ==) + BINARY_OP(kCompareFEQ, float, fFloat, ==) + BINARY_OP(kCompareINEQ, int32_t, fSigned, !=) + BINARY_OP(kCompareFNEQ, float, fFloat, !=) + BINARY_OP(kCompareSGT, int32_t, fSigned, >) + BINARY_OP(kCompareUGT, uint32_t, fUnsigned, >) + BINARY_OP(kCompareFGT, float, fFloat, >) + BINARY_OP(kCompareSGTEQ, int32_t, fSigned, >=) + BINARY_OP(kCompareUGTEQ, uint32_t, fUnsigned, >=) + BINARY_OP(kCompareFGTEQ, float, fFloat, >=) + BINARY_OP(kCompareSLT, int32_t, fSigned, <) + BINARY_OP(kCompareULT, uint32_t, fUnsigned, <) + BINARY_OP(kCompareFLT, float, fFloat, <) + BINARY_OP(kCompareSLTEQ, int32_t, fSigned, <=) + BINARY_OP(kCompareULTEQ, uint32_t, fUnsigned, <=) + BINARY_OP(kCompareFLTEQ, float, fFloat, <=) + case ByteCodeInstruction::kConditionalBranch: { + int target = this->read16(); + if (this->pop().fBool) { + fIP = target; } break; } - case Expression::kBoolLiteral_Kind: - return Value(((const BoolLiteral&) expr).fValue); - case Expression::kConstructor_Kind: + case ByteCodeInstruction::kDebugPrint: { + Value v = this->pop(); + printf("Debug: %d(int), %d(uint), %f(float)\n", v.fSigned, v.fUnsigned, v.fFloat); break; - case Expression::kIntLiteral_Kind: - return Value((int) ((const IntLiteral&) expr).fValue); - case Expression::kFieldAccess_Kind: - break; - case Expression::kFloatLiteral_Kind: - return Value((float) ((const FloatLiteral&) expr).fValue); - case Expression::kFunctionCall_Kind: - return this->call((const FunctionCall&) expr); - case Expression::kIndex_Kind: { - const IndexExpression& idx = (const IndexExpression&) expr; - StackIndex pos = this->evaluate(*idx.fBase).fInt + - this->evaluate(*idx.fIndex).fInt; - return fStack[pos]; } - case Expression::kPrefix_Kind: { - const PrefixExpression& p = (const PrefixExpression&) expr; - switch (p.fOperator) { - case Token::MINUS: { - Value base = this->evaluate(*p.fOperand); - switch (type_kind(p.fType)) { - case kFloat_TypeKind: - return Value(-base.fFloat); - case kInt_TypeKind: - return Value(-base.fInt); - default: - abort(); - } - } - case Token::LOGICALNOT: { - Value base = this->evaluate(*p.fOperand); - return Value(!base.fBool); - } - default: - abort(); + BINARY_OP(kDivideS, int32_t, fSigned, /) + BINARY_OP(kDivideU, uint32_t, fUnsigned, /) + BINARY_OP(kDivideF, float, fFloat, /) + case ByteCodeInstruction::kDup: + this->push(fStack.back()); + break; + case ByteCodeInstruction::kDupDown: { + int count = this->read8(); + for (int i = 0; i < count; ++i) { + fStack.insert(fStack.end() - i - count - 1, fStack[fStack.size() - i - 1]); } + break; } - case Expression::kPostfix_Kind: { - const PostfixExpression& p = (const PostfixExpression&) expr; - StackIndex lvalue = this->getLValue(*p.fOperand); - Value result = fStack[lvalue]; - switch (type_kind(p.fType)) { - case kFloat_TypeKind: - if (Token::PLUSPLUS == p.fOperator) { - ++fStack[lvalue].fFloat; - } else { - SkASSERT(Token::MINUSMINUS == p.fOperator); - --fStack[lvalue].fFloat; - } - break; - case kInt_TypeKind: - if (Token::PLUSPLUS == p.fOperator) { - ++fStack[lvalue].fInt; - } else { - SkASSERT(Token::MINUSMINUS == p.fOperator); - --fStack[lvalue].fInt; - } - break; - default: - abort(); + case ByteCodeInstruction::kFloatToInt: { + Value& top = fStack.back(); + top.fSigned = (int) top.fFloat; + break; + } + case ByteCodeInstruction::kSignedToFloat: { + Value& top = fStack.back(); + top.fFloat = (float) top.fSigned; + break; + } + case ByteCodeInstruction::kUnsignedToFloat: { + Value& top = fStack.back(); + top.fFloat = (float) top.fUnsigned; + break; + } + case ByteCodeInstruction::kLoad: { + int target = this->pop().fSigned; + SkASSERT(target < (int) fStack.size()); + this->push(fStack[target]); + break; + } + case ByteCodeInstruction::kLoadGlobal: { + int target = this->read8(); + SkASSERT(target < (int) fGlobals.size()); + this->push(fGlobals[target]); + break; + } + case ByteCodeInstruction::kLoadSwizzle: { + Value target = this->pop(); + int count = read8(); + for (int i = 0; i < count; ++i) { + SkASSERT(target.fSigned + fCurrentFunction->fCode[fIP + i] < (int) fStack.size()); + this->push(fStack[target.fSigned + fCurrentFunction->fCode[fIP + i]]); } - return result; - } - case Expression::kSetting_Kind: + fIP += count; break; - case Expression::kSwizzle_Kind: - break; - case Expression::kVariableReference_Kind: - SkASSERT(fVars.size()); - SkASSERT(fVars.back().find(&((VariableReference&) expr).fVariable) != - fVars.back().end()); - return fStack[fVars.back()[&((VariableReference&) expr).fVariable]]; - case Expression::kTernary_Kind: { - const TernaryExpression& t = (const TernaryExpression&) expr; - return this->evaluate(this->evaluate(*t.fTest).fBool ? *t.fIfTrue : *t.fIfFalse); } - case Expression::kTypeReference_Kind: + BINARY_OP(kMultiplyS, int32_t, fSigned, *) + BINARY_OP(kMultiplyU, uint32_t, fUnsigned, *) + BINARY_OP(kMultiplyF, float, fFloat, *) + case ByteCodeInstruction::kNot: { + Value& top = fStack.back(); + top.fBool = !top.fBool; + break; + } + case ByteCodeInstruction::kNegateF: + this->push(-this->pop().fFloat); + case ByteCodeInstruction::kNegateS: + this->push(-this->pop().fSigned); + case ByteCodeInstruction::kPop: + for (int i = read8(); i > 0; --i) { + this->pop(); + } + break; + case ByteCodeInstruction::kPushImmediate: + this->push(Value((int) read32())); + break; + BINARY_OP(kRemainderS, int32_t, fSigned, %) + BINARY_OP(kRemainderU, uint32_t, fUnsigned, %) + case ByteCodeInstruction::kStore: { + Value value = this->pop(); + int target = this->pop().fSigned; + SkASSERT(target < (int) fStack.size()); + fStack[target] = value; + break; + } + case ByteCodeInstruction::kStoreGlobal: { + Value value = this->pop(); + int target = this->pop().fSigned; + SkASSERT(target < (int) fGlobals.size()); + fGlobals[target] = value; + break; + } + case ByteCodeInstruction::kStoreSwizzle: { + int count = read8(); + int target = fStack[fStack.size() - count - 1].fSigned; + for (int i = count - 1; i >= 0; --i) { + SkASSERT(target + fCurrentFunction->fCode[fIP + i] < (int) fStack.size()); + fStack[target + fCurrentFunction->fCode[fIP + i]] = this->pop(); + } + this->pop(); + fIP += count; + break; + } + BINARY_OP(kSubtractI, int32_t, fSigned, -) + BINARY_OP(kSubtractF, float, fFloat, -) + case ByteCodeInstruction::kSwizzle: { + Value vec[4]; + for (int i = this->read8() - 1; i >= 0; --i) { + vec[i] = this->pop(); + } + for (int i = this->read8() - 1; i >= 0; --i) { + this->push(vec[this->read8()]); + } + break; + } + case ByteCodeInstruction::kVector: + this->nextVector(this->read8()); break; default: - break; + printf("unsupported instruction %d\n", (int) inst); + SkASSERT(false); + } +#ifdef TRACE + this->dumpStack(); +#endif +} + +static constexpr int VECTOR_MAX = 16; + +#define VECTOR_BINARY_OP(inst, type, field, op) \ + case ByteCodeInstruction::inst: { \ + Value result[VECTOR_MAX]; \ + for (int i = count - 1; i >= 0; --i) { \ + result[i] = this->pop(); \ + } \ + for (int i = count - 1; i >= 0; --i) { \ + result[i] = this->pop().field op result[i].field; \ + } \ + for (int i = 0; i < count; ++i) { \ + this->push(result[i]); \ + } \ + break; \ + } + +void Interpreter::nextVector(int count) { + ByteCodeInstruction inst = (ByteCodeInstruction) this->read8(); + switch (inst) { + VECTOR_BINARY_OP(kAddI, int32_t, fSigned, +) + VECTOR_BINARY_OP(kAddF, float, fFloat, +) + case ByteCodeInstruction::kBranch: + fIP = this->read16(); + break; + VECTOR_BINARY_OP(kCompareIEQ, int32_t, fSigned, ==) + VECTOR_BINARY_OP(kCompareFEQ, float, fFloat, ==) + VECTOR_BINARY_OP(kCompareINEQ, int32_t, fSigned, !=) + VECTOR_BINARY_OP(kCompareFNEQ, float, fFloat, !=) + VECTOR_BINARY_OP(kCompareSGT, int32_t, fSigned, >) + VECTOR_BINARY_OP(kCompareUGT, uint32_t, fUnsigned, >) + VECTOR_BINARY_OP(kCompareFGT, float, fFloat, >) + VECTOR_BINARY_OP(kCompareSGTEQ, int32_t, fSigned, >=) + VECTOR_BINARY_OP(kCompareUGTEQ, uint32_t, fUnsigned, >=) + VECTOR_BINARY_OP(kCompareFGTEQ, float, fFloat, >=) + VECTOR_BINARY_OP(kCompareSLT, int32_t, fSigned, <) + VECTOR_BINARY_OP(kCompareULT, uint32_t, fUnsigned, <) + VECTOR_BINARY_OP(kCompareFLT, float, fFloat, <) + VECTOR_BINARY_OP(kCompareSLTEQ, int32_t, fSigned, <=) + VECTOR_BINARY_OP(kCompareULTEQ, uint32_t, fUnsigned, <=) + VECTOR_BINARY_OP(kCompareFLTEQ, float, fFloat, <=) + case ByteCodeInstruction::kConditionalBranch: { + int target = this->read16(); + if (this->pop().fBool) { + fIP = target; + } + break; + } + VECTOR_BINARY_OP(kDivideS, int32_t, fSigned, /) + VECTOR_BINARY_OP(kDivideU, uint32_t, fUnsigned, /) + VECTOR_BINARY_OP(kDivideF, float, fFloat, /) + case ByteCodeInstruction::kFloatToInt: { + for (int i = 0; i < count; ++i) { + Value& v = fStack[fStack.size() - i - 1]; + v.fSigned = (int) v.fFloat; + } + break; + } + case ByteCodeInstruction::kSignedToFloat: { + for (int i = 0; i < count; ++i) { + Value& v = fStack[fStack.size() - i - 1]; + v.fFloat = (float) v.fSigned; + } + break; + } + case ByteCodeInstruction::kUnsignedToFloat: { + for (int i = 0; i < count; ++i) { + Value& v = fStack[fStack.size() - i - 1]; + v.fFloat = (float) v.fUnsigned; + } + break; + } + case ByteCodeInstruction::kLoad: { + int target = this->pop().fSigned; + for (int i = 0; i < count; ++i) { + SkASSERT(target < (int) fStack.size()); + this->push(fStack[target++]); + } + break; + } + case ByteCodeInstruction::kLoadGlobal: { + int target = this->read8(); + SkASSERT(target < (int) fGlobals.size()); + this->push(fGlobals[target]); + break; + } + VECTOR_BINARY_OP(kMultiplyS, int32_t, fSigned, *) + VECTOR_BINARY_OP(kMultiplyU, uint32_t, fUnsigned, *) + VECTOR_BINARY_OP(kMultiplyF, float, fFloat, *) + VECTOR_BINARY_OP(kRemainderS, int32_t, fSigned, %) + VECTOR_BINARY_OP(kRemainderU, uint32_t, fUnsigned, %) + case ByteCodeInstruction::kStore: { + int target = fStack[fStack.size() - count - 1].fSigned + count; + for (int i = count - 1; i >= 0; --i) { + SkASSERT(target < (int) fStack.size()); + fStack[--target] = this->pop(); + } + break; + } + VECTOR_BINARY_OP(kSubtractI, int32_t, fSigned, -) + VECTOR_BINARY_OP(kSubtractF, float, fFloat, -) + case ByteCodeInstruction::kVector: + this->nextVector(this->read8()); + default: + printf("unsupported instruction %d\n", (int) inst); + SkASSERT(false); + } +} + +void Interpreter::run() { + while (fIP < (int) fCurrentFunction->fCode.size()) { + next(); } - ABORT("unsupported expression: %s\n", expr.description().c_str()); } } // namespace diff --git a/src/sksl/SkSLInterpreter.h b/src/sksl/SkSLInterpreter.h index 449e380541..018768abb9 100644 --- a/src/sksl/SkSLInterpreter.h +++ b/src/sksl/SkSLInterpreter.h @@ -8,6 +8,7 @@ #ifndef SKSL_INTERPRETER #define SKSL_INTERPRETER +#include "SkSLByteCode.h" #include "ir/SkSLAppendStage.h" #include "ir/SkSLExpression.h" #include "ir/SkSLFunctionCall.h" @@ -17,31 +18,30 @@ #include -class SkRasterPipeline; - namespace SkSL { class Interpreter { typedef int StackIndex; - struct StatementIndex { - const Statement* fStatement; - size_t fIndex; - }; - public: union Value { + Value() {} + Value(float f) : fFloat(f) {} - Value(int i) - : fInt(i) {} + Value(int32_t s) + : fSigned(s) {} + + Value(uint32_t u) + : fUnsigned(u) {} Value(bool b) : fBool(b) {} float fFloat; - int fInt; + int32_t fSigned; + uint32_t fUnsigned; bool fBool; }; @@ -51,37 +51,49 @@ public: kBool_TypeKind }; - Interpreter(std::unique_ptr program, SkRasterPipeline* pipeline, std::vector* stack) + Interpreter(std::unique_ptr program, std::unique_ptr byteCode) : fProgram(std::move(program)) - , fPipeline(*pipeline) - , fStack(*stack) {} + , fByteCode(std::move(byteCode)) + , fReturnValue(0) {} + + /** + * Invokes the specified function with the given arguments, returning its return value. 'out' + * and 'inout' parameters will result in the 'args' array being modified. + */ + Value run(const ByteCodeFunction& f, Value args[], Value inputs[]); + +private: + StackIndex stackAlloc(int count); + + uint8_t read8(); + + uint16_t read16(); + + uint32_t read32(); + + void next(); + + void nextVector(int count); void run(); - void run(const FunctionDefinition& f); - - void push(Value value); + void push(Value v); Value pop(); - StackIndex stackAlloc(int count); + void swizzle(); - void runStatement(); + void disassemble(const ByteCodeFunction& f); - StackIndex getLValue(const Expression& expr); + void dumpStack(); - Value call(const FunctionCall& c); - - void appendStage(const AppendStage& c); - - Value evaluate(const Expression& expr); - -private: std::unique_ptr fProgram; - SkRasterPipeline& fPipeline; - std::vector fCurrentIndex; - std::vector> fVars; - std::vector &fStack; + std::unique_ptr fByteCode; + int fIP; + const ByteCodeFunction* fCurrentFunction; + std::vector fGlobals; + std::vector fStack; + Value fReturnValue; }; } // namespace diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp index 27787bbc62..7500ea05c4 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.cpp +++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp @@ -1915,28 +1915,6 @@ SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, return result; } -bool is_assignment(Token::Kind op) { - switch (op) { - case Token::EQ: // fall through - case Token::PLUSEQ: // fall through - case Token::MINUSEQ: // fall through - case Token::STAREQ: // fall through - case Token::SLASHEQ: // fall through - case Token::PERCENTEQ: // fall through - case Token::SHLEQ: // fall through - case Token::SHREQ: // fall through - case Token::BITWISEOREQ: // fall through - case Token::BITWISEXOREQ: // fall through - case Token::BITWISEANDEQ: // fall through - case Token::LOGICALOREQ: // fall through - case Token::LOGICALXOREQ: // fall through - case Token::LOGICALANDEQ: - return true; - default: - return false; - } -} - SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out) { if (operandType.kind() == Type::kVector_Kind) { diff --git a/src/sksl/SkSLUtil.cpp b/src/sksl/SkSLUtil.cpp index 49d37e3056..4684df4093 100644 --- a/src/sksl/SkSLUtil.cpp +++ b/src/sksl/SkSLUtil.cpp @@ -30,4 +30,45 @@ void write_stringstream(const StringStream& s, OutputStream& out) { out.write(s.str().c_str(), s.str().size()); } +bool is_assignment(Token::Kind op) { + switch (op) { + case Token::EQ: // fall through + case Token::PLUSEQ: // fall through + case Token::MINUSEQ: // fall through + case Token::STAREQ: // fall through + case Token::SLASHEQ: // fall through + case Token::PERCENTEQ: // fall through + case Token::SHLEQ: // fall through + case Token::SHREQ: // fall through + case Token::BITWISEOREQ: // fall through + case Token::BITWISEXOREQ: // fall through + case Token::BITWISEANDEQ: // fall through + case Token::LOGICALOREQ: // fall through + case Token::LOGICALXOREQ: // fall through + case Token::LOGICALANDEQ: + return true; + default: + return false; + } +} + +Token::Kind remove_assignment(Token::Kind op) { + switch (op) { + case Token::PLUSEQ: return Token::PLUS; + case Token::MINUSEQ: return Token::MINUS; + case Token::STAREQ: return Token::STAR; + case Token::SLASHEQ: return Token::SLASH; + case Token::PERCENTEQ: return Token::PERCENT; + case Token::SHLEQ: return Token::SHL; + case Token::SHREQ: return Token::SHR; + case Token::BITWISEOREQ: return Token::BITWISEOR; + case Token::BITWISEXOREQ: return Token::BITWISEXOR; + case Token::BITWISEANDEQ: return Token::BITWISEAND; + case Token::LOGICALOREQ: return Token::LOGICALOR; + case Token::LOGICALXOREQ: return Token::LOGICALXOR; + case Token::LOGICALANDEQ: return Token::LOGICALAND; + default: return Token::INVALID; + } +} + } // namespace diff --git a/src/sksl/SkSLUtil.h b/src/sksl/SkSLUtil.h index ca00c9bdbc..d221f4ed16 100644 --- a/src/sksl/SkSLUtil.h +++ b/src/sksl/SkSLUtil.h @@ -12,6 +12,7 @@ #include #include "stdlib.h" #include "string.h" +#include "SkSLLexer.h" #include "SkSLDefines.h" #include "SkSLString.h" #include "SkSLStringStream.h" @@ -396,6 +397,13 @@ public: void write_stringstream(const StringStream& d, OutputStream& out); +// Returns true if op is '=' or any compound assignment operator ('+=', '-=', etc.) +bool is_assignment(Token::Kind op); + +// Given a compound assignment operator, returns the non-assignment version of the operator (e.g. +// '+=' becomes '+') +Token::Kind remove_assignment(Token::Kind op); + NORETURN void sksl_abort(); } // namespace diff --git a/tests/SkSLInterpreterTest.cpp b/tests/SkSLInterpreterTest.cpp new file mode 100644 index 0000000000..10d2f3299e --- /dev/null +++ b/tests/SkSLInterpreterTest.cpp @@ -0,0 +1,200 @@ +/* + * Copyright 2019 Google LLC + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#include "SkSLCompiler.h" +#include "SkSLInterpreter.h" + +#include "Test.h" + +void test(skiatest::Reporter* r, const char* src, float inR, float inG, float inB, float inA, + float expectedR, float expectedG, float expectedB, float expectedA) { + SkSL::Compiler compiler; + SkSL::Program::Settings settings; + std::unique_ptr program = compiler.convertProgram( + SkSL::Program::kPipelineStage_Kind, + SkSL::String(src), settings); + REPORTER_ASSERT(r, program); + if (program) { + std::unique_ptr byteCode = compiler.toByteCode(*program); + REPORTER_ASSERT(r, !compiler.errorCount()); + if (compiler.errorCount() > 0) { + printf("%s\n%s", src, compiler.errorText().c_str()); + return; + } + SkSL::ByteCodeFunction* main = byteCode->fFunctions[0].get(); + SkSL::Interpreter interpreter(std::move(program), std::move(byteCode)); + float inoutColor[4] = { inR, inG, inB, inA }; + interpreter.run(*main, (SkSL::Interpreter::Value*) inoutColor, nullptr); + if (inoutColor[0] != expectedR || inoutColor[1] != expectedG || + inoutColor[2] != expectedB || inoutColor[3] != expectedA) { + printf("for program: %s\n", src); + printf(" expected (%f, %f, %f, %f), but received (%f, %f, %f, %f)\n", expectedR, + expectedG, expectedB, expectedA, inoutColor[0], inoutColor[1], inoutColor[2], + inoutColor[3]); + } + REPORTER_ASSERT(r, inoutColor[0] == expectedR); + REPORTER_ASSERT(r, inoutColor[1] == expectedG); + REPORTER_ASSERT(r, inoutColor[2] == expectedB); + REPORTER_ASSERT(r, inoutColor[3] == expectedA); + } else { + printf("%s\n%s", src, compiler.errorText().c_str()); + } +} + +DEF_TEST(SkSLInterpreterTEMP_TEST, r) { + test(r, "void main(inout half4 color) { half4 c = color; color += c; }", 0.25, 0.5, 0.75, 1, + 0.5, 1, 1.5, 2); +} + +DEF_TEST(SkSLInterpreterAdd, r) { + test(r, "void main(inout half4 color) { color.r = color.r + color.g; }", 0.25, 0.75, 0, 0, 1, + 0.75, 0, 0); + test(r, "void main(inout half4 color) { color += half4(1, 2, 3, 4); }", 4, 3, 2, 1, 5, 5, 5, 5); + test(r, "void main(inout half4 color) { half4 c = color; color += c; }", 0.25, 0.5, 0.75, 1, + 0.5, 1, 1.5, 2); + test(r, "void main(inout half4 color) { int a = 1; int b = 3; color.r = a + b; }", 1, 2, 3, 4, + 4, 2, 3, 4); +} + +DEF_TEST(SkSLInterpreterSubtract, r) { + test(r, "void main(inout half4 color) { color.r = color.r - color.g; }", 1, 0.75, 0, 0, 0.25, + 0.75, 0, 0); + test(r, "void main(inout half4 color) { color -= half4(1, 2, 3, 4); }", 5, 5, 5, 5, 4, 3, 2, 1); + test(r, "void main(inout half4 color) { half4 c = color; color -= c; }", 4, 3, 2, 1, + 0, 0, 0, 0); + test(r, "void main(inout half4 color) { int a = 3; int b = 1; color.r = a - b; }", 0, 0, 0, 0, + 2, 0, 0, 0); +} + +DEF_TEST(SkSLInterpreterMultiply, r) { + test(r, "void main(inout half4 color) { color.r = color.r * color.g; }", 2, 3, 0, 0, 6, 3, 0, + 0); + test(r, "void main(inout half4 color) { color *= half4(1, 2, 3, 4); }", 2, 3, 4, 5, 2, 6, 12, + 20); + test(r, "void main(inout half4 color) { half4 c = color; color *= c; }", 4, 3, 2, 1, + 16, 9, 4, 1); + test(r, "void main(inout half4 color) { int a = 3; int b = -2; color.r = a * b; }", 0, 0, 0, 0, + -6, 0, 0, 0); +} + +DEF_TEST(SkSLInterpreterDivide, r) { + test(r, "void main(inout half4 color) { color.r = color.r / color.g; }", 1, 2, 0, 0, 0.5, 2, 0, + 0); + test(r, "void main(inout half4 color) { color /= half4(1, 2, 3, 4); }", 12, 12, 12, 12, 12, 6, + 4, 3); + test(r, "void main(inout half4 color) { half4 c = color; color /= c; }", 4, 3, 2, 1, + 1, 1, 1, 1); + test(r, "void main(inout half4 color) { int a = 8; int b = -2; color.r = a / b; }", 0, 0, 0, 0, + -4, 0, 0, 0); +} + +DEF_TEST(SkSLInterpreterRemainder, r) { + test(r, "void main(inout half4 color) { int a = 8; int b = 3; a %= b; color.r = a; }", 0, 0, 0, + 0, 2, 0, 0, 0); + test(r, "void main(inout half4 color) { int a = 8; int b = 3; color.r = a % b; }", 0, 0, 0, 0, + 2, 0, 0, 0); + test(r, "void main(inout half4 color) { int2 a = int2(8, 10); a %= 6; color.rg = a; }", 0, 0, 0, + 0, 2, 4, 0, 0); +} + +DEF_TEST(SkSLInterpreterIf, r) { + test(r, "void main(inout half4 color) { if (color.r > color.g) color.a = 1; }", 5, 3, 0, 0, + 5, 3, 0, 1); + test(r, "void main(inout half4 color) { if (color.r > color.g) color.a = 1; }", 5, 5, 0, 0, + 5, 5, 0, 0); + test(r, "void main(inout half4 color) { if (color.r > color.g) color.a = 1; }", 5, 6, 0, 0, + 5, 6, 0, 0); + test(r, "void main(inout half4 color) { if (color.r < color.g) color.a = 1; }", 3, 5, 0, 0, + 3, 5, 0, 1); + test(r, "void main(inout half4 color) { if (color.r < color.g) color.a = 1; }", 5, 5, 0, 0, + 5, 5, 0, 0); + test(r, "void main(inout half4 color) { if (color.r < color.g) color.a = 1; }", 6, 5, 0, 0, + 6, 5, 0, 0); + test(r, "void main(inout half4 color) { if (color.r >= color.g) color.a = 1; }", 5, 3, 0, 0, + 5, 3, 0, 1); + test(r, "void main(inout half4 color) { if (color.r >= color.g) color.a = 1; }", 5, 5, 0, 0, + 5, 5, 0, 1); + test(r, "void main(inout half4 color) { if (color.r >= color.g) color.a = 1; }", 5, 6, 0, 0, + 5, 6, 0, 0); + test(r, "void main(inout half4 color) { if (color.r <= color.g) color.a = 1; }", 3, 5, 0, 0, + 3, 5, 0, 1); + test(r, "void main(inout half4 color) { if (color.r <= color.g) color.a = 1; }", 5, 5, 0, 0, + 5, 5, 0, 1); + test(r, "void main(inout half4 color) { if (color.r <= color.g) color.a = 1; }", 6, 5, 0, 0, + 6, 5, 0, 0); + test(r, "void main(inout half4 color) { if (color.r == color.g) color.a = 1; }", 2, 2, 0, 0, + 2, 2, 0, 1); + test(r, "void main(inout half4 color) { if (color.r == color.g) color.a = 1; }", 2, -2, 0, 0, + 2, -2, 0, 0); + test(r, "void main(inout half4 color) { if (color.r != color.g) color.a = 1; }", 2, 2, 0, 0, + 2, 2, 0, 0); + test(r, "void main(inout half4 color) { if (color.r != color.g) color.a = 1; }", 2, -2, 0, 0, + 2, -2, 0, 1); + test(r, "void main(inout half4 color) { if (color.r == color.g) color.a = 1; else " + "color.a = 2; }", 1, 1, 0, 0, 1, 1, 0, 1); + test(r, "void main(inout half4 color) { if (color.r == color.g) color.a = 1; else " + "color.a = 2; }", 2, -2, 0, 0, 2, -2, 0, 2); +} + +DEF_TEST(SkSLInterpreterWhile, r) { + test(r, "void main(inout half4 color) { while (color.r < 1) color.r += 0.25; }", 0, 0, 0, 0, 1, + 0, 0, 0); + test(r, "void main(inout half4 color) { while (color.r > 1) color.r += 0.25; }", 0, 0, 0, 0, 0, + 0, 0, 0); + test(r, "void main(inout half4 color) { while (true) { color.r += 0.5; " + "if (color.r > 1) break; } }", 0, 0, 0, 0, 1.5, 0, 0, 0); + test(r, "void main(inout half4 color) { while (color.r < 10) { color.r += 0.5; " + "if (color.r < 5) continue; break; } }", 0, 0, 0, 0, 5, 0, 0, 0); +} + +DEF_TEST(SkSLInterpreterDo, r) { + test(r, "void main(inout half4 color) { do color.r += 0.25; while (color.r < 1); }", 0, 0, 0, 0, + 1, 0, 0, 0); + test(r, "void main(inout half4 color) { do color.r += 0.25; while (color.r > 1); }", 0, 0, 0, 0, + 0.25, 0, 0, 0); + test(r, "void main(inout half4 color) { do { color.r += 0.5; if (color.r > 1) break; } while " + "(true); }", 0, 0, 0, 0, 1.5, 0, 0, 0); + test(r, "void main(inout half4 color) {do { color.r += 0.5; if (color.r < 5) " + "continue; if (color.r >= 5) break; } while (true); }", 0, 0, 0, 0, 5, 0, 0, 0); +} + +DEF_TEST(SkSLInterpreterFor, r) { + test(r, "void main(inout half4 color) { for (int i = 1; i <= 10; ++i) color.r += i; }", 0, 0, 0, + 0, 55, 0, 0, 0); + test(r, + "void main(inout half4 color) {" + " for (int i = 1; i <= 10; ++i)" + " for (int j = i; j <= 10; ++j)" + " color.r += j;" + "}", + 0, 0, 0, 0, + 385, 0, 0, 0); + test(r, + "void main(inout half4 color) {" + " for (int i = 1; i <= 10; ++i)" + " for (int j = 1; ; ++j) {" + " if (i == j) continue;" + " if (j > 10) break;" + " color.r += j;" + " }" + "}", + 0, 0, 0, 0, + 495, 0, 0, 0); +} + +DEF_TEST(SkSLInterpreterSwizzle, r) { + test(r, "void main(inout half4 color) { color = color.abgr; }", 1, 2, 3, 4, 4, 3, 2, 1); + test(r, "void main(inout half4 color) { color.rgb = half4(5, 6, 7, 8).bbg; }", 1, 2, 3, 4, 7, 7, + 6, 4); + test(r, "void main(inout half4 color) { color.bgr = int3(5, 6, 7); }", 1, 2, 3, 4, 7, 6, + 5, 4); +} + +DEF_TEST(SkSLInterpreterGlobal, r) { + test(r, "int x; void main(inout half4 color) { x = 10; color.b = x; }", 1, 2, 3, 4, 1, 2, 10, + 4); +}