SkSL DSL matrix support

Change-Id: I9d43346df1a7611726f69ea54b4236e32d11d20c
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/395696
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
This commit is contained in:
Ethan Nicholas 2021-04-12 16:56:37 -04:00 committed by Skia Commit-Bot
parent b0fed1dc21
commit 8455893840
5 changed files with 216 additions and 95 deletions

View File

@ -8,6 +8,7 @@
#ifndef SKSL_DSL_TYPE
#define SKSL_DSL_TYPE
#include "include/sksl/DSLExpression.h"
#include "include/sksl/DSLModifiers.h"
#include <cstdint>
@ -30,11 +31,29 @@ enum TypeConstant : uint8_t {
kHalf2_Type,
kHalf3_Type,
kHalf4_Type,
kHalf2x2_Type,
kHalf3x2_Type,
kHalf4x2_Type,
kHalf2x3_Type,
kHalf3x3_Type,
kHalf4x3_Type,
kHalf2x4_Type,
kHalf3x4_Type,
kHalf4x4_Type,
kFloat_Type,
kFloat2_Type,
kFloat3_Type,
kFloat4_Type,
kFragmentProcessor_Type,
kFloat2x2_Type,
kFloat3x2_Type,
kFloat4x2_Type,
kFloat2x3_Type,
kFloat3x3_Type,
kFloat4x3_Type,
kFloat2x4_Type,
kFloat3x4_Type,
kFloat4x4_Type,
kInt_Type,
kInt2_Type,
kInt3_Type,
@ -54,11 +73,37 @@ public:
DSLType(const SkSL::Type* type)
: fSkSLType(type) {}
template<typename... Args>
static DSLExpression Construct(TypeConstant type, Args&&... args) {
SkTArray<DSLExpression> argArray;
argArray.reserve_back(sizeof...(args));
CollectArgs(argArray, std::forward<Args>(args)...);
return Construct(type, std::move(argArray));
}
static DSLExpression Construct(TypeConstant type, SkTArray<DSLExpression> argArray);
private:
const SkSL::Type& skslType() const;
const SkSL::Type* fSkSLType = nullptr;
static void CollectArgs(SkTArray<DSLExpression>& args) {}
template<class... RemainingArgs>
static void CollectArgs(SkTArray<DSLExpression>& args, DSLVar& var,
RemainingArgs&&... remaining) {
args.push_back(var);
CollectArgs(args, std::forward<RemainingArgs>(remaining)...);
}
template<class... RemainingArgs>
static void CollectArgs(SkTArray<DSLExpression>& args, DSLExpression expr,
RemainingArgs&&... remaining) {
args.push_back(std::move(expr));
CollectArgs(args, std::forward<RemainingArgs>(remaining)...);
}
TypeConstant fTypeConstant;
friend DSLType Array(const DSLType& base, int count);
@ -69,24 +114,40 @@ private:
};
#define TYPE(T) \
DSLExpression T(DSLExpression expr); \
DSLExpression T##2(DSLExpression expr); \
DSLExpression T##2(DSLExpression x, DSLExpression y); \
DSLExpression T##3(DSLExpression expr); \
DSLExpression T##3(DSLExpression x, DSLExpression y); \
DSLExpression T##3(DSLExpression x, DSLExpression y, DSLExpression z); \
DSLExpression T##4(DSLExpression expr); \
DSLExpression T##4(DSLExpression x, DSLExpression y); \
DSLExpression T##4(DSLExpression x, DSLExpression y, DSLExpression z); \
DSLExpression T##4(DSLExpression x, DSLExpression y, DSLExpression z, DSLExpression w);
template<typename... Args> \
DSLExpression T(Args&&... args) { \
return DSLType::Construct(k ## T ## _Type, std::forward<Args>(args)...); \
}
TYPE(Bool)
TYPE(Float)
TYPE(Half)
TYPE(Int)
TYPE(Short)
#define VECTOR_TYPE(T) \
TYPE(T) \
TYPE(T ## 2) \
TYPE(T ## 3) \
TYPE(T ## 4)
#define MATRIX_TYPE(T) \
TYPE(T ## 2x2) \
TYPE(T ## 3x2) \
TYPE(T ## 4x2) \
TYPE(T ## 2x3) \
TYPE(T ## 3x3) \
TYPE(T ## 4x3) \
TYPE(T ## 2x4) \
TYPE(T ## 3x4) \
TYPE(T ## 4x4)
VECTOR_TYPE(Bool)
VECTOR_TYPE(Float)
VECTOR_TYPE(Half)
VECTOR_TYPE(Int)
VECTOR_TYPE(Short)
MATRIX_TYPE(Float)
MATRIX_TYPE(Half)
#undef TYPE
#undef VECTOR_TYPE
#undef MATRIX_TYPE
DSLType Array(const DSLType& base, int count);

View File

@ -37,6 +37,24 @@ const SkSL::Type& DSLType::skslType() const {
return *context.fTypes.fHalf3;
case kHalf4_Type:
return *context.fTypes.fHalf4;
case kHalf2x2_Type:
return *context.fTypes.fHalf2x2;
case kHalf3x2_Type:
return *context.fTypes.fHalf3x2;
case kHalf4x2_Type:
return *context.fTypes.fHalf4x2;
case kHalf2x3_Type:
return *context.fTypes.fHalf2x3;
case kHalf3x3_Type:
return *context.fTypes.fHalf3x3;
case kHalf4x3_Type:
return *context.fTypes.fHalf4x3;
case kHalf2x4_Type:
return *context.fTypes.fHalf2x4;
case kHalf3x4_Type:
return *context.fTypes.fHalf3x4;
case kHalf4x4_Type:
return *context.fTypes.fHalf4x4;
case kFloat_Type:
return *context.fTypes.fFloat;
case kFloat2_Type:
@ -47,6 +65,24 @@ const SkSL::Type& DSLType::skslType() const {
return *context.fTypes.fFloat4;
case kFragmentProcessor_Type:
return *context.fTypes.fFragmentProcessor;
case kFloat2x2_Type:
return *context.fTypes.fFloat2x2;
case kFloat3x2_Type:
return *context.fTypes.fFloat3x2;
case kFloat4x2_Type:
return *context.fTypes.fFloat4x2;
case kFloat2x3_Type:
return *context.fTypes.fFloat2x3;
case kFloat3x3_Type:
return *context.fTypes.fFloat3x3;
case kFloat4x3_Type:
return *context.fTypes.fFloat4x3;
case kFloat2x4_Type:
return *context.fTypes.fFloat2x4;
case kFloat3x4_Type:
return *context.fTypes.fFloat3x4;
case kFloat4x4_Type:
return *context.fTypes.fFloat4x4;
case kInt_Type:
return *context.fTypes.fInt;
case kInt2_Type:
@ -70,86 +106,10 @@ const SkSL::Type& DSLType::skslType() const {
}
}
static DSLExpression construct1(const SkSL::Type& type, DSLExpression a) {
std::vector<DSLExpression> args;
args.push_back(std::move(a));
return DSLWriter::Construct(type, std::move(args));
DSLExpression DSLType::Construct(TypeConstant type, SkTArray<DSLExpression> argArray) {
return DSLWriter::Construct(DSLType(type).skslType(), std::move(argArray));
}
static DSLExpression construct2(const SkSL::Type& type, DSLExpression a,
DSLExpression b) {
std::vector<DSLExpression> args;
args.push_back(std::move(a));
args.push_back(std::move(b));
return DSLWriter::Construct(type, std::move(args));
}
static DSLExpression construct3(const SkSL::Type& type, DSLExpression a,
DSLExpression b,
DSLExpression c) {
std::vector<DSLExpression> args;
args.push_back(std::move(a));
args.push_back(std::move(b));
args.push_back(std::move(c));
return DSLWriter::Construct(type, std::move(args));
}
static DSLExpression construct4(const SkSL::Type& type, DSLExpression a, DSLExpression b,
DSLExpression c, DSLExpression d) {
std::vector<DSLExpression> args;
args.push_back(std::move(a));
args.push_back(std::move(b));
args.push_back(std::move(c));
args.push_back(std::move(d));
return DSLWriter::Construct(type, std::move(args));
}
#define TYPE(T) \
DSLExpression T(DSLExpression a) { \
return construct1(*DSLWriter::Context().fTypes.f ## T, std::move(a)); \
} \
DSLExpression T ## 2(DSLExpression a) { \
return construct1(*DSLWriter::Context().fTypes.f ## T ## 2, std::move(a)); \
} \
DSLExpression T ## 2(DSLExpression a, DSLExpression b) { \
return construct2(*DSLWriter::Context().fTypes.f ## T ## 2, std::move(a), \
std::move(b)); \
} \
DSLExpression T ## 3(DSLExpression a) { \
return construct1(*DSLWriter::Context().fTypes.f ## T ## 3, std::move(a)); \
} \
DSLExpression T ## 3(DSLExpression a, DSLExpression b) { \
return construct2(*DSLWriter::Context().fTypes.f ## T ## 3, std::move(a), \
std::move(b)); \
} \
DSLExpression T ## 3(DSLExpression a, DSLExpression b, DSLExpression c) { \
return construct3(*DSLWriter::Context().fTypes.f ## T ## 3, std::move(a), \
std::move(b), std::move(c)); \
} \
DSLExpression T ## 4(DSLExpression a) { \
return construct1(*DSLWriter::Context().fTypes.f ## T ## 4, std::move(a)); \
} \
DSLExpression T ## 4(DSLExpression a, DSLExpression b) { \
return construct2(*DSLWriter::Context().fTypes.f ## T ## 4, std::move(a), \
std::move(b)); \
} \
DSLExpression T ## 4(DSLExpression a, DSLExpression b, DSLExpression c) { \
return construct3(*DSLWriter::Context().fTypes.f ## T ## 4, std::move(a), std::move(b), \
std::move(c)); \
} \
DSLExpression T ## 4(DSLExpression a, DSLExpression b, DSLExpression c, DSLExpression d) { \
return construct4(*DSLWriter::Context().fTypes.f ## T ## 4, std::move(a), std::move(b), \
std::move(c), std::move(d)); \
}
TYPE(Bool)
TYPE(Float)
TYPE(Half)
TYPE(Int)
TYPE(Short)
#undef TYPE
DSLType Array(const DSLType& base, int count) {
SkASSERT(count >= 1);
return DSLWriter::SymbolTable()->addArrayDimension(&base.skslType(), count);

View File

@ -129,7 +129,7 @@ DSLPossibleExpression DSLWriter::Coerce(std::unique_ptr<Expression> left, const
}
DSLPossibleExpression DSLWriter::Construct(const SkSL::Type& type,
std::vector<DSLExpression> rawArgs) {
SkTArray<DSLExpression> rawArgs) {
SkSL::ExpressionArray args;
args.reserve_back(rawArgs.size());

View File

@ -157,7 +157,7 @@ public:
static DSLPossibleExpression Coerce(std::unique_ptr<Expression> left, const SkSL::Type& type);
static DSLPossibleExpression Construct(const SkSL::Type& type,
std::vector<DSLExpression> rawArgs);
SkTArray<DSLExpression> rawArgs);
static std::unique_ptr<Expression> ConvertBinary(std::unique_ptr<Expression> left, Operator op,
std::unique_ptr<Expression> right);

View File

@ -346,6 +346,103 @@ DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLBool, r, ctxInfo) {
}
}
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLMatrices, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Var f22(kFloat2x2_Type, "f22");
EXPECT_EQUAL(f22 = Float2x2(1), "(f22 = float2x2(1.0))");
Var f32(kFloat3x2_Type, "f32");
EXPECT_EQUAL(f32 = Float3x2(1, 2, 3, 4, 5, 6),
"(f32 = float3x2(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))");
Var f42(kFloat4x2_Type, "f42");
EXPECT_EQUAL(f42 = Float4x2(Float4(1, 2, 3, 4), 5, 6, 7, 8),
"(f42 = float4x2(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0))");
Var f23(kFloat2x3_Type, "f23");
EXPECT_EQUAL(f23 = Float2x3(1, Float2(2, 3), 4, Float2(5, 6)),
"(f23 = float2x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))");
Var f33(kFloat3x3_Type, "f33");
EXPECT_EQUAL(f33 = Float3x3(Float3(1, 2, 3), 4, Float2(5, 6), 7, 8, 9),
"(f33 = float3x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))");
Var f43(kFloat4x3_Type, "f43");
EXPECT_EQUAL(f43 = Float4x3(Float4(1, 2, 3, 4), Float4(5, 6, 7, 8), Float4(9, 10, 11, 12)),
"(f43 = float4x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0))");
Var f24(kFloat2x4_Type, "f24");
EXPECT_EQUAL(f24 = Float2x4(1, 2, 3, 4, 5, 6, 7, 8),
"(f24 = float2x4(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0))");
Var f34(kFloat3x4_Type, "f34");
EXPECT_EQUAL(f34 = Float3x4(1, 2, 3, 4, 5, 6, 7, 8, 9, Float3(10, 11, 12)),
"(f34 = float3x4(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0))");
Var f44(kFloat4x4_Type, "f44");
EXPECT_EQUAL(f44 = Float4x4(1), "(f44 = float4x4(1.0))");
Var h22(kHalf2x2_Type, "h22");
EXPECT_EQUAL(h22 = Half2x2(1), "(h22 = half2x2(1.0))");
Var h32(kHalf3x2_Type, "h32");
EXPECT_EQUAL(h32 = Half3x2(1, 2, 3, 4, 5, 6),
"(h32 = half3x2(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))");
Var h42(kHalf4x2_Type, "h42");
EXPECT_EQUAL(h42 = Half4x2(Half4(1, 2, 3, 4), 5, 6, 7, 8),
"(h42 = half4x2(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0))");
Var h23(kHalf2x3_Type, "h23");
EXPECT_EQUAL(h23 = Half2x3(1, Half2(2, 3), 4, Half2(5, 6)),
"(h23 = half2x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))");
Var h33(kHalf3x3_Type, "h33");
EXPECT_EQUAL(h33 = Half3x3(Half3(1, 2, 3), 4, Half2(5, 6), 7, 8, 9),
"(h33 = half3x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))");
Var h43(kHalf4x3_Type, "h43");
EXPECT_EQUAL(h43 = Half4x3(Half4(1, 2, 3, 4), Half4(5, 6, 7, 8), Half4(9, 10, 11, 12)),
"(h43 = half4x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0))");
Var h24(kHalf2x4_Type, "h24");
EXPECT_EQUAL(h24 = Half2x4(1, 2, 3, 4, 5, 6, 7, 8),
"(h24 = half2x4(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0))");
Var h34(kHalf3x4_Type, "h34");
EXPECT_EQUAL(h34 = Half3x4(1, 2, 3, 4, 5, 6, 7, 8, 9, Half3(10, 11, 12)),
"(h34 = half3x4(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0))");
Var h44(kHalf4x4_Type, "h44");
EXPECT_EQUAL(h44 = Half4x4(1), "(h44 = half4x4(1.0))");
EXPECT_EQUAL(f22 * 2, "(f22 * 2.0)");
EXPECT_EQUAL(f22 == Float2x2(1), "(f22 == float2x2(1.0))");
EXPECT_EQUAL(h42[0][1], "h42[0].y");
EXPECT_EQUAL(f43 * Float4(0), "(f43 * float4(0.0))");
EXPECT_EQUAL(h23 * 2, "(h23 * 2.0)");
EXPECT_EQUAL(Inverse(f44), "inverse(f44)");
{
ExpectError error(r, "error: invalid arguments to 'float3x3' constructor (expected 9 "
"scalars, but found 2)\n");
DSLExpression(Float3x3(Float2(1))).release();
}
{
ExpectError error(r, "error: invalid arguments to 'half2x2' constructor (expected 4 "
"scalars, but found 5)\n");
DSLExpression(Half2x2(1, 2, 3, 4, 5)).release();
}
{
ExpectError error(r, "error: type mismatch: '*' cannot operate on 'float4x3', 'float3'\n");
DSLExpression(f43 * Float3(1)).release();
}
{
ExpectError error(r, "error: type mismatch: '=' cannot operate on 'float4x3', "
"'float3x3'\n");
DSLExpression(f43 = f33).release();
}
{
ExpectError error(r, "error: type mismatch: '=' cannot operate on 'half2x2', "
"'float2x2'\n");
DSLExpression(h22 = f22).release();
}
{
ExpectError error(r,
"error: no match for inverse(float4x3)\n");
DSLExpression(Inverse(f43)).release();
}
}
DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLPlus, r, ctxInfo) {
AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
Var a(kFloat_Type, "a"), b(kFloat_Type, "b");
@ -1401,8 +1498,11 @@ DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLSample, r, ctxInfo) {
DSLVar child(kUniform_Modifier, kFragmentProcessor_Type, "child");
EXPECT_EQUAL(Sample(child), "sample(child)");
EXPECT_EQUAL(Sample(child, Float2(0, 0)), "sample(child, float2(0.0, 0.0))");
EXPECT_EQUAL(Sample(child, Float3x3(1.0)), "sample(child, float3x3(1.0))");
EXPECT_EQUAL(Sample(child, Half4(1)), "sample(child, half4(1.0))");
EXPECT_EQUAL(Sample(child, Half4(1), Float2(0)), "sample(child, half4(1.0), float2(0.0))");
EXPECT_EQUAL(Sample(child, Half4(1), Float3x3(1.0)),
"sample(child, half4(1.0), float3x3(1.0))");
{
ExpectError error(r, "error: no match for sample(fragmentProcessor, bool)\n");