sksl can now fold constant vector or matrix equality expressions
Bug: skia: Change-Id: Icaddae68e53ed3629bcdc04b5f0b541d9e4398e2 Reviewed-on: https://skia-review.googlesource.com/14260 Commit-Queue: Ethan Nicholas <ethannicholas@google.com> Reviewed-by: Ben Wagner <benjaminwagner@google.com>
This commit is contained in:
parent
7ab6a7f40b
commit
3deaeb2dc0
@ -1043,6 +1043,12 @@ std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left,
|
||||
return std::unique_ptr<Expression>(new Constructor(Position(), left.fType, \
|
||||
std::move(args)));
|
||||
switch (op) {
|
||||
case Token::EQEQ:
|
||||
return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(),
|
||||
left.compareConstant(fContext, right)));
|
||||
case Token::NEQ:
|
||||
return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(),
|
||||
!left.compareConstant(fContext, right)));
|
||||
case Token::PLUS: RETURN_VEC_COMPONENTWISE_RESULT(+);
|
||||
case Token::MINUS: RETURN_VEC_COMPONENTWISE_RESULT(-);
|
||||
case Token::STAR: RETURN_VEC_COMPONENTWISE_RESULT(*);
|
||||
@ -1050,6 +1056,20 @@ std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left,
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
if (left.fType.kind() == Type::kMatrix_Kind &&
|
||||
right.fType.kind() == Type::kMatrix_Kind &&
|
||||
left.fKind == right.fKind) {
|
||||
switch (op) {
|
||||
case Token::EQEQ:
|
||||
return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(),
|
||||
left.compareConstant(fContext, right)));
|
||||
case Token::NEQ:
|
||||
return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(),
|
||||
!left.compareConstant(fContext, right)));
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
#undef RESULT
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -33,6 +33,11 @@ struct BoolLiteral : public Expression {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compareConstant(const Context& context, const Expression& other) const override {
|
||||
BoolLiteral& b = (BoolLiteral&) other;
|
||||
return fValue == b.fValue;
|
||||
}
|
||||
|
||||
const bool fValue;
|
||||
|
||||
typedef Expression INHERITED;
|
||||
|
@ -81,6 +81,44 @@ struct Constructor : public Expression {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compareConstant(const Context& context, const Expression& other) const override {
|
||||
ASSERT(other.fKind == Expression::kConstructor_Kind && other.fType == fType);
|
||||
Constructor& c = (Constructor&) other;
|
||||
if (c.fType.kind() == Type::kVector_Kind) {
|
||||
for (int i = 0; i < fType.columns(); i++) {
|
||||
if (!this->getVecComponent(i).compareConstant(context, c.getVecComponent(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// shouldn't be possible to have a constant constructor that isn't a vector or matrix;
|
||||
// a constant scalar constructor should have been collapsed down to the appropriate
|
||||
// literal
|
||||
ASSERT(fType.kind() == Type::kMatrix_Kind);
|
||||
const FloatLiteral fzero(context, Position(), 0);
|
||||
const IntLiteral izero(context, Position(), 0);
|
||||
const Expression* zero;
|
||||
if (fType.componentType() == *context.fFloat_Type) {
|
||||
zero = &fzero;
|
||||
} else {
|
||||
ASSERT(fType.componentType() == *context.fInt_Type);
|
||||
zero = &izero;
|
||||
}
|
||||
for (int col = 0; col < fType.columns(); col++) {
|
||||
for (int row = 0; row < fType.rows(); row++) {
|
||||
const Expression* component1 = getMatComponent(col, row);
|
||||
const Expression* component2 = c.getMatComponent(col, row);
|
||||
if (!(component1 ? component1 : zero)->compareConstant(
|
||||
context,
|
||||
component2 ? *component2 : *zero)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
const Expression& getVecComponent(int index) const {
|
||||
ASSERT(fType.kind() == Type::kVector_Kind);
|
||||
if (fArguments.size() == 1 && fArguments[0]->fType.kind() == Type::kScalar_Kind) {
|
||||
@ -118,6 +156,51 @@ struct Constructor : public Expression {
|
||||
return ((IntLiteral&) c).fValue;
|
||||
}
|
||||
|
||||
// null return should be interpreted as zero
|
||||
const Expression* getMatComponent(int col, int row) const {
|
||||
ASSERT(this->isConstant());
|
||||
ASSERT(fType.kind() == Type::kMatrix_Kind);
|
||||
ASSERT(col < fType.columns() && row < fType.rows());
|
||||
if (fArguments.size() == 1) {
|
||||
if (fArguments[0]->fType.kind() == Type::kScalar_Kind) {
|
||||
// single scalar argument, so matrix is of the form:
|
||||
// x 0 0
|
||||
// 0 x 0
|
||||
// 0 0 x
|
||||
// return x if col == row
|
||||
return col == row ? fArguments[0].get() : nullptr;
|
||||
}
|
||||
if (fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
|
||||
ASSERT(fArguments[0]->fKind == Expression::kConstructor_Kind);
|
||||
// single matrix argument. make sure we're within the argument's bounds.
|
||||
const Type& argType = ((Constructor&) *fArguments[0]).fType;
|
||||
if (col < argType.columns() && row < argType.rows()) {
|
||||
// within bounds, defer to argument
|
||||
return ((Constructor&) *fArguments[0]).getMatComponent(col, row);
|
||||
}
|
||||
// out of bounds, return 0
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
int currentIndex = 0;
|
||||
int targetIndex = col * fType.rows() + row;
|
||||
for (const auto& arg : fArguments) {
|
||||
ASSERT(targetIndex >= currentIndex);
|
||||
ASSERT(arg->fType.rows() == 1);
|
||||
if (currentIndex + arg->fType.columns() > targetIndex) {
|
||||
if (arg->fType.columns() == 1) {
|
||||
return arg.get();
|
||||
} else {
|
||||
ASSERT(arg->fType.kind() == Type::kVector_Kind);
|
||||
ASSERT(arg->fKind == Expression::kConstructor_Kind);
|
||||
return &((Constructor&) *arg).getVecComponent(targetIndex - currentIndex);
|
||||
}
|
||||
}
|
||||
currentIndex += arg->fType.columns();
|
||||
}
|
||||
ABORT("can't happen, matrix component out of bounds");
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<Expression>> fArguments;
|
||||
|
||||
typedef Expression INHERITED;
|
||||
|
@ -48,10 +48,23 @@ struct Expression : public IRNode {
|
||||
, fKind(kind)
|
||||
, fType(std::move(type)) {}
|
||||
|
||||
/**
|
||||
* Returns true if this expression is constant. compareConstant must be implemented for all
|
||||
* constants!
|
||||
*/
|
||||
virtual bool isConstant() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compares this constant expression against another constant expression of the same type. It is
|
||||
* an error to call this on non-constant expressions, or if the types of the expressions do not
|
||||
* match.
|
||||
*/
|
||||
virtual bool compareConstant(const Context& context, const Expression& other) const {
|
||||
ABORT("cannot call compareConstant on this type");
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if evaluating the expression potentially has side effects. Expressions may never
|
||||
* return false if they actually have side effects, but it is legal (though suboptimal) to
|
||||
|
@ -34,6 +34,11 @@ struct FloatLiteral : public Expression {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compareConstant(const Context& context, const Expression& other) const override {
|
||||
FloatLiteral& f = (FloatLiteral&) other;
|
||||
return fValue == f.fValue;
|
||||
}
|
||||
|
||||
const double fValue;
|
||||
|
||||
typedef Expression INHERITED;
|
||||
|
@ -35,6 +35,11 @@ struct IntLiteral : public Expression {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compareConstant(const Context& context, const Expression& other) const override {
|
||||
IntLiteral& i = (IntLiteral&) other;
|
||||
return fValue == i.fValue;
|
||||
}
|
||||
|
||||
const int64_t fValue;
|
||||
|
||||
typedef Expression INHERITED;
|
||||
|
@ -125,6 +125,9 @@ DEF_TEST(SkSLConstructorTypeMismatch, r) {
|
||||
test_failure(r,
|
||||
"struct foo { int x; } foo; void main() { vec2 x = vec2(foo); }",
|
||||
"error: 1: 'foo' is not a valid parameter to 'vec2' constructor\n1 error\n");
|
||||
test_failure(r,
|
||||
"void main() { mat2 x = mat2(true); }",
|
||||
"error: 1: expected 'float', but found 'bool'\n1 error\n");
|
||||
}
|
||||
|
||||
DEF_TEST(SkSLConstructorArgumentCount, r) {
|
||||
|
@ -573,6 +573,34 @@ DEF_TEST(SkSLConstantFolding, r) {
|
||||
"sk_FragColor = vec4(2) * vec4(1, 2, 3, 4);"
|
||||
"sk_FragColor = vec4(12) / vec4(1, 2, 3, 4);"
|
||||
"sk_FragColor.r = (vec4(12) / vec4(1, 2, 3, 4)).y;"
|
||||
"sk_FragColor.x = vec4(1) == vec4(1) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = vec4(1) == vec4(2) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = vec2(1) == vec2(1, 1) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = vec2(1, 1) == vec2(1, 1) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = vec2(1) == vec2(1, 0) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = vec4(1) == vec4(vec2(1), vec2(1)) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = vec4(vec3(1), 1) == vec4(vec2(1), vec2(1)) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = vec4(vec3(1), 1) == vec4(vec2(1), 1, 0) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) == "
|
||||
"mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(vec2(1.0, 0.0), vec2(1.0, 1.0)) == "
|
||||
"mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(1) == mat2(1) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(1) == mat2(0) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(1) == mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(2) == mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat3x2(2) == mat3x2(vec2(2.0, 0.0), vec2(0.0, 2.0), vec2(0.0)) ? "
|
||||
"1.0 : 0.0;"
|
||||
"sk_FragColor.x = vec2(1) != vec2(1, 0) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = vec4(1) != vec4(vec2(1), vec2(1)) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(1) != mat2(1) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(1) != mat2(0) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat3(vec3(1.0, 0.0, 0.0), vec3(0.0, 1.0, 0.0), vec3(0.0, 0.0, 0.0)) == "
|
||||
"mat3(mat2(1.0)) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(mat3(1.0)) == mat2(1.0) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(vec4(1.0, 0.0, 0.0, 1.0)) == mat2(1.0) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(1.0, 0.0, vec2(0.0, 1.0)) == mat2(1.0) ? 1.0 : 0.0;"
|
||||
"sk_FragColor.x = mat2(vec2(1.0, 0.0), 0.0, 1.0) == mat2(1.0) ? 1.0 : 0.0;"
|
||||
"}",
|
||||
*SkSL::ShaderCapsFactory::Default(),
|
||||
"#version 400\n"
|
||||
@ -617,6 +645,30 @@ DEF_TEST(SkSLConstantFolding, r) {
|
||||
" sk_FragColor = vec4(2.0, 4.0, 6.0, 8.0);\n"
|
||||
" sk_FragColor = vec4(12.0, 6.0, 4.0, 3.0);\n"
|
||||
" sk_FragColor.x = 6.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 0.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 0.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 0.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 0.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 0.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 0.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 0.0;\n"
|
||||
" sk_FragColor.x = 0.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
" sk_FragColor.x = 1.0;\n"
|
||||
"}\n");
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user