SkSL now supports ternary lvalues

Bug: skia:
Change-Id: I859b756fe016f80c7a94f812623a16b4865204ba
Reviewed-on: https://skia-review.googlesource.com/96680
Reviewed-by: Greg Daniel <egdaniel@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
This commit is contained in:
Ethan Nicholas 2018-01-18 13:32:11 -05:00 committed by Skia Commit-Bot
parent 7b6ea19c9c
commit a583b813b9
5 changed files with 103 additions and 2 deletions

View File

@ -137,6 +137,15 @@ bool BasicBlock::tryRemoveLValueBefore(std::vector<BasicBlock::Node>::iterator*
return false;
}
return this->tryRemoveExpressionBefore(iter, ((IndexExpression*) lvalue)->fIndex.get());
case Expression::kTernary_Kind:
if (!this->tryRemoveExpressionBefore(iter,
((TernaryExpression*) lvalue)->fTest.get())) {
return false;
}
if (!this->tryRemoveLValueBefore(iter, ((TernaryExpression*) lvalue)->fIfTrue.get())) {
return false;
}
return this->tryRemoveLValueBefore(iter, ((TernaryExpression*) lvalue)->fIfFalse.get());
default:
ABORT("invalid lvalue: %s\n", lvalue->description().c_str());
}
@ -425,6 +434,14 @@ void CFGGenerator::addLValue(CFG& cfg, std::unique_ptr<Expression>* e) {
break;
case Expression::kVariableReference_Kind:
break;
case Expression::kTernary_Kind:
this->addExpression(cfg, &((TernaryExpression&) **e).fTest, true);
// Technically we will of course only evaluate one or the other, but if the test turns
// out to be constant, the ternary will get collapsed down to just one branch anyway. So
// it should be ok to pretend that we always evaluate both branches here.
this->addLValue(cfg, &((TernaryExpression&) **e).fIfTrue);
this->addLValue(cfg, &((TernaryExpression&) **e).fIfFalse);
break;
default:
// not an lvalue, can't happen
ASSERT(false);

View File

@ -247,6 +247,17 @@ void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expressio
(std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
definitions);
break;
case Expression::kTernary_Kind:
// To simplify analysis, we just pretend that we write to both sides of the ternary.
// This allows for false positives (meaning we fail to detect that a variable might not
// have been assigned), but is preferable to false negatives.
this->addDefinition(((TernaryExpression*) lvalue)->fIfTrue.get(),
(std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
definitions);
this->addDefinition(((TernaryExpression*) lvalue)->fIfFalse.get(),
(std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
definitions);
break;
default:
// not an lvalue, can't happen
ASSERT(false);
@ -396,6 +407,10 @@ static bool is_dead(const Expression& lvalue) {
const IndexExpression& idx = (IndexExpression&) lvalue;
return is_dead(*idx.fBase) && !idx.fIndex->hasSideEffects();
}
case Expression::kTernary_Kind: {
const TernaryExpression& t = (TernaryExpression&) lvalue;
return !t.fTest->hasSideEffects() && is_dead(*t.fIfTrue) && is_dead(*t.fIfFalse);
}
default:
ABORT("invalid lvalue: %s\n", lvalue.description().c_str());
}

View File

@ -2094,6 +2094,12 @@ void IRGenerator::markWrittenTo(const Expression& expr, bool readWrite) {
case Expression::kIndex_Kind:
this->markWrittenTo(*((IndexExpression&) expr).fBase, readWrite);
break;
case Expression::kTernary_Kind: {
TernaryExpression& t = (TernaryExpression&) expr;
this->markWrittenTo(*t.fIfTrue, readWrite);
this->markWrittenTo(*t.fIfFalse, readWrite);
break;
}
default:
fErrors.error(expr.fOffset, "cannot assign to '" + expr.description() + "'");
break;

View File

@ -1452,7 +1452,6 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const
member,
this->getType(expr.fType)));
}
case Expression::kSwizzle_Kind: {
Swizzle& swizzle = (Swizzle&) expr;
size_t count = swizzle.fComponents.size();
@ -1481,7 +1480,31 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const
expr.fType));
}
}
case Expression::kTernary_Kind: {
TernaryExpression& t = (TernaryExpression&) expr;
SpvId test = this->writeExpression(*t.fTest, out);
SpvId end = this->nextId();
SpvId ifTrueLabel = this->nextId();
SpvId ifFalseLabel = this->nextId();
this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
this->writeLabel(ifTrueLabel, out);
SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
ASSERT(ifTrue);
this->writeInstruction(SpvOpBranch, end, out);
ifTrueLabel = fCurrentBlock;
SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
ASSERT(ifFalse);
ifFalseLabel = fCurrentBlock;
this->writeInstruction(SpvOpBranch, end, out);
SpvId result = this->nextId();
this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
ifTrueLabel, ifFalse, ifFalseLabel, out);
return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
*this,
result,
this->getType(expr.fType)));
}
default:
// expr isn't actually an lvalue, create a dummy variable for it. This case happens due
// to the need to store values in temporary variables during function calls (see

View File

@ -1803,4 +1803,44 @@ DEF_TEST(SkSLNormalization, r) {
SkSL::Program::kGeometry_Kind);
}
DEF_TEST(SkSLTernaryLValue, r) {
test(r,
"void main() { half r, g; (true ? r : g) = 1; (false ? r : g) = 0; "
"sk_FragColor = half4(r, g, 1, 1); }",
*SkSL::ShaderCapsFactory::Default(),
"#version 400\n"
"out vec4 sk_FragColor;\n"
"void main() {\n"
" sk_FragColor = vec4(1.0, 0.0, 1.0, 1.0);\n"
"}\n");
test(r,
"void main() { half r, g; (true ? r : g) = sqrt(1); (false ? r : g) = sqrt(0); "
"sk_FragColor = half4(r, g, 1, 1); }",
*SkSL::ShaderCapsFactory::Default(),
"#version 400\n"
"out vec4 sk_FragColor;\n"
"void main() {\n"
" float r, g;\n"
" r = sqrt(1.0);\n"
" g = sqrt(0.0);\n"
" sk_FragColor = vec4(r, g, 1.0, 1.0);\n"
"}\n");
test(r,
"void main() {"
"half r, g;"
"(sqrt(1) > 0 ? r : g) = sqrt(1);"
"(sqrt(0) > 0 ? r : g) = sqrt(0);"
"sk_FragColor = half4(r, g, 1, 1);"
"}",
*SkSL::ShaderCapsFactory::Default(),
"#version 400\n"
"out vec4 sk_FragColor;\n"
"void main() {\n"
" float r, g;\n"
" sqrt(1.0) > 0.0 ? r : g = sqrt(1.0);\n"
" sqrt(0.0) > 0.0 ? r : g = sqrt(0.0);\n"
" sk_FragColor = vec4(r, g, 1.0, 1.0);\n"
"}\n");
}
#endif