From 642cde289d3eab7b5775cc8fc739b402f19ed9bb Mon Sep 17 00:00:00 2001 From: John Stiles Date: Tue, 23 Feb 2021 14:57:01 -0500 Subject: [PATCH] Optimize @switch statements in SwitchStatement::Make. At IR generation time, this CL limits our optimizations to only @switch statements. A regular switch statement will only be optimized during the optimization phase even if the switch-value is a known compile-time constant. This is done to avoid upsetting our reachability analysis. Most of this CL is moving existing logic from SkSLCompiler into SkSLAnalysis and SkSLSwitchStatement. Although the diffs look large, the actual changes are very small. Change-Id: I90920f41bc386dfa7a980ae7510f6681231a5120 Bug: skia:11340, skia:11342, skia:11319 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/372679 Commit-Queue: John Stiles Auto-Submit: John Stiles Reviewed-by: Brian Osman --- resources/sksl/shared/Discard.sksl | 4 +- resources/sksl/shared/Enum.sksl | 4 +- resources/sksl/shared/StaticSwitch.sksl | 3 +- src/sksl/SkSLAnalysis.cpp | 70 +++++++++ src/sksl/SkSLAnalysis.h | 12 ++ src/sksl/SkSLCompiler.cpp | 186 +----------------------- src/sksl/ir/SkSLSwitchStatement.cpp | 151 ++++++++++++++++++- src/sksl/ir/SkSLSwitchStatement.h | 22 ++- tests/SkSLDSLTest.cpp | 12 +- 9 files changed, 268 insertions(+), 196 deletions(-) diff --git a/resources/sksl/shared/Discard.sksl b/resources/sksl/shared/Discard.sksl index 331c36c6be..695427f7a1 100644 --- a/resources/sksl/shared/Discard.sksl +++ b/resources/sksl/shared/Discard.sksl @@ -1,6 +1,6 @@ void main() { -half x; - @switch (1) { + half x; + switch (1) { case 0: x = 0; break; default: x = 1; discard; } diff --git a/resources/sksl/shared/Enum.sksl b/resources/sksl/shared/Enum.sksl index 7d10b01a6c..6b80f0c5ff 100644 --- a/resources/sksl/shared/Enum.sksl +++ b/resources/sksl/shared/Enum.sksl @@ -25,7 +25,7 @@ void main() { case E::kZero: sk_FragColor = half4(11); break; case E::kOne: sk_FragColor = half4(12); break; } - @switch (e) { + @switch (E::kZero) { case E::kZero: sk_FragColor = half4(13); break; case E::kOne: sk_FragColor = half4(14); break; } @@ -40,7 +40,7 @@ void main() { sk_FragColor = (m == SkBlendMode::kClear) ? half4(19) : half4(-19); sk_FragColor = (m != SkBlendMode::kSrc) ? half4(20) : half4(-20); - @switch (m) { + @switch (SkBlendMode::kClear) { case SkBlendMode::kClear: sk_FragColor = half4(21); break; case SkBlendMode::kSrc: sk_FragColor = half4(22); break; case SkBlendMode::kDst: sk_FragColor = half4(23); break; diff --git a/resources/sksl/shared/StaticSwitch.sksl b/resources/sksl/shared/StaticSwitch.sksl index 0c02ec8c81..0d28b6c468 100644 --- a/resources/sksl/shared/StaticSwitch.sksl +++ b/resources/sksl/shared/StaticSwitch.sksl @@ -1,6 +1,5 @@ void main() { - int x = 1; - @switch (x) { + @switch (1) { case 1: sk_FragColor = half4(1); break; default: sk_FragColor = half4(0); } diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp index e203002dbe..dbb95f66aa 100644 --- a/src/sksl/SkSLAnalysis.cpp +++ b/src/sksl/SkSLAnalysis.cpp @@ -329,6 +329,68 @@ private: using INHERITED = ProgramVisitor; }; +class SwitchCaseContainsExit : public ProgramVisitor { +public: + SwitchCaseContainsExit(bool conditionalExits) : fConditionalExits(conditionalExits) {} + + bool visitStatement(const Statement& stmt) override { + switch (stmt.kind()) { + case Statement::Kind::kBlock: + return INHERITED::visitStatement(stmt); + + case Statement::Kind::kReturn: + // Returns are an early exit regardless of the surrounding control structures. + return fConditionalExits ? fInConditional : !fInConditional; + + case Statement::Kind::kContinue: + // Continues are an early exit from switches, but not loops. + return !fInLoop && + (fConditionalExits ? fInConditional : !fInConditional); + + case Statement::Kind::kBreak: + // Breaks cannot escape from switches or loops. + return !fInLoop && !fInSwitch && + (fConditionalExits ? fInConditional : !fInConditional); + + case Statement::Kind::kIf: { + ++fInConditional; + bool result = INHERITED::visitStatement(stmt); + --fInConditional; + return result; + } + + case Statement::Kind::kFor: + case Statement::Kind::kDo: { + // Loops are treated as conditionals because a loop could potentially execute zero + // times. We don't have a straightforward way to determine that a loop definitely + // executes at least once. + ++fInConditional; + ++fInLoop; + bool result = INHERITED::visitStatement(stmt); + --fInLoop; + --fInConditional; + return result; + } + + case Statement::Kind::kSwitch: { + ++fInSwitch; + bool result = INHERITED::visitStatement(stmt); + --fInSwitch; + return result; + } + + default: + return false; + } + } + + bool fConditionalExits = false; + int fInConditional = 0; + int fInLoop = 0; + int fInSwitch = 0; + using INHERITED = ProgramVisitor; +}; + } // namespace //////////////////////////////////////////////////////////////////////////////// @@ -356,6 +418,14 @@ int Analysis::NodeCountUpToLimit(const FunctionDefinition& function, int limit) return NodeCountVisitor{limit}.visit(*function.body()); } +bool Analysis::SwitchCaseContainsUnconditionalExit(Statement& stmt) { + return SwitchCaseContainsExit{/*conditionalExits=*/false}.visitStatement(stmt); +} + +bool Analysis::SwitchCaseContainsConditionalExit(Statement& stmt) { + return SwitchCaseContainsExit{/*conditionalExits=*/true}.visitStatement(stmt); +} + std::unique_ptr Analysis::GetUsage(const Program& program) { auto usage = std::make_unique(); ProgramUsageVisitor addRefs(usage.get(), /*delta=*/+1); diff --git a/src/sksl/SkSLAnalysis.h b/src/sksl/SkSLAnalysis.h index 0e217717b3..ee782746cd 100644 --- a/src/sksl/SkSLAnalysis.h +++ b/src/sksl/SkSLAnalysis.h @@ -42,6 +42,18 @@ struct Analysis { static int NodeCountUpToLimit(const FunctionDefinition& function, int limit); + /** + * Finds unconditional exits from a switch-case. Returns true if this statement unconditionally + * causes an exit from this switch (via continue, break or return). + */ + static bool SwitchCaseContainsUnconditionalExit(Statement& stmt); + + /** + * Finds conditional exits from a switch-case. Returns true if this statement contains a + * conditional that wraps a potential exit from the switch (via continue, break or return). + */ + static bool SwitchCaseContainsConditionalExit(Statement& stmt); + static std::unique_ptr GetUsage(const Program& program); static std::unique_ptr GetUsage(const LoadedModule& module); diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp index 68226e2eec..c4df1826b7 100644 --- a/src/sksl/SkSLCompiler.cpp +++ b/src/sksl/SkSLCompiler.cpp @@ -1172,184 +1172,6 @@ void Compiler::simplifyExpression(DefinitionMap& definitions, } } -static bool contains_exit(Statement& stmt, bool conditionalExits) { - class ContainsExit : public ProgramVisitor { - public: - ContainsExit(bool e) : fConditionalExits(e) {} - - bool visitStatement(const Statement& stmt) override { - switch (stmt.kind()) { - case Statement::Kind::kBlock: - return INHERITED::visitStatement(stmt); - - case Statement::Kind::kReturn: - // Returns are an early exit regardless of the surrounding control structures. - return fConditionalExits ? fInConditional : !fInConditional; - - case Statement::Kind::kContinue: - // Continues are an early exit from switches, but not loops. - return !fInLoop && - (fConditionalExits ? fInConditional : !fInConditional); - - case Statement::Kind::kBreak: - // Breaks cannot escape from switches or loops. - return !fInLoop && !fInSwitch && - (fConditionalExits ? fInConditional : !fInConditional); - - case Statement::Kind::kIf: { - ++fInConditional; - bool result = INHERITED::visitStatement(stmt); - --fInConditional; - return result; - } - - case Statement::Kind::kFor: - case Statement::Kind::kDo: { - // Loops are treated as conditionals because a loop could potentially execute - // zero times. We don't have a straightforward way to determine that a loop - // definitely executes at least once. - ++fInConditional; - ++fInLoop; - bool result = INHERITED::visitStatement(stmt); - --fInLoop; - --fInConditional; - return result; - } - - case Statement::Kind::kSwitch: { - ++fInSwitch; - bool result = INHERITED::visitStatement(stmt); - --fInSwitch; - return result; - } - - default: - return false; - } - } - - bool fConditionalExits = false; - int fInConditional = 0; - int fInLoop = 0; - int fInSwitch = 0; - using INHERITED = ProgramVisitor; - }; - - return ContainsExit{conditionalExits}.visitStatement(stmt); -} - -// Finds unconditional exits from a switch-case. Returns true if this statement unconditionally -// causes an exit from this switch (via continue, break or return). -static bool contains_unconditional_exit(Statement& stmt) { - return contains_exit(stmt, /*conditionalExits=*/false); -} - -// Finds conditional exits from a switch-case. Returns true if this statement contains a conditional -// that wraps a potential exit from the switch (via continue, break or return). -static bool contains_conditional_exit(Statement& stmt) { - return contains_exit(stmt, /*conditionalExits=*/true); -} - -static void move_all_but_break(std::unique_ptr& stmt, StatementArray* target) { - switch (stmt->kind()) { - case Statement::Kind::kBlock: { - // Recurse into the block. - Block& block = static_cast(*stmt); - - StatementArray blockStmts; - blockStmts.reserve_back(block.children().size()); - for (std::unique_ptr& stmt : block.children()) { - move_all_but_break(stmt, &blockStmts); - } - - target->push_back(std::make_unique(block.fOffset, std::move(blockStmts), - block.symbolTable(), block.isScope())); - break; - } - - case Statement::Kind::kBreak: - // Do not append a break to the target. - break; - - default: - // Append normal statements to the target. - target->push_back(std::move(stmt)); - break; - } -} - -// Returns a block containing all of the statements that will be run if the given case matches -// (which, owing to the statements being owned by unique_ptrs, means the switch itself will be -// broken by this call and must then be discarded). -// Returns null (and leaves the switch unmodified) if no such simple reduction is possible, such as -// when break statements appear inside conditionals. -static std::unique_ptr block_for_case(SwitchStatement* switchStatement, - SwitchCase* caseToCapture) { - // We have to be careful to not move any of the pointers until after we're sure we're going to - // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan - // of action. First, find the switch-case we are interested in. - auto iter = switchStatement->cases().begin(); - for (; iter != switchStatement->cases().end(); ++iter) { - if (iter->get() == caseToCapture) { - break; - } - } - - // Next, walk forward through the rest of the switch. If we find a conditional break, we're - // stuck and can't simplify at all. If we find an unconditional break, we have a range of - // statements that we can use for simplification. - auto startIter = iter; - Statement* stripBreakStmt = nullptr; - for (; iter != switchStatement->cases().end(); ++iter) { - for (std::unique_ptr& stmt : (*iter)->statements()) { - if (contains_conditional_exit(*stmt)) { - // We can't reduce switch-cases to a block when they have conditional exits. - return nullptr; - } - if (contains_unconditional_exit(*stmt)) { - // We found an unconditional exit. We can use this block, but we need to strip - // out a break statement if it has one. - stripBreakStmt = stmt.get(); - break; - } - } - - if (stripBreakStmt) { - break; - } - } - - // We fell off the bottom of the switch or encountered a break. We know the range of statements - // that we need to move over, and we know it's safe to do so. - StatementArray caseStmts; - - // We can move over most of the statements as-is. - while (startIter != iter) { - for (std::unique_ptr& stmt : (*startIter)->statements()) { - caseStmts.push_back(std::move(stmt)); - } - ++startIter; - } - - // For the last statement, we need to move what we can, stopping at a break if there is one. - if (stripBreakStmt != nullptr) { - for (std::unique_ptr& stmt : (*startIter)->statements()) { - if (stmt.get() == stripBreakStmt) { - move_all_but_break(stmt, &caseStmts); - stripBreakStmt = nullptr; - break; - } - - caseStmts.push_back(std::move(stmt)); - } - } - - SkASSERT(stripBreakStmt == nullptr); // Verify that we fixed the unconditional break. - - // Return our newly-synthesized block. - return std::make_unique(/*offset=*/-1, std::move(caseStmts), switchStatement->symbols()); -} - void Compiler::simplifyStatement(DefinitionMap& definitions, BasicBlock& b, std::vector::iterator* iter, @@ -1419,6 +1241,8 @@ void Compiler::simplifyStatement(DefinitionMap& definitions, break; } case Statement::Kind::kSwitch: { + // TODO(skia:11319): this optimization logic is redundant with the static-switch + // optimization code found in SwitchStatement.cpp. SwitchStatement& s = stmt->as(); int64_t switchValue; if (ConstantFolder::GetConstantInt(*s.value(), &switchValue)) { @@ -1433,7 +1257,8 @@ void Compiler::simplifyStatement(DefinitionMap& definitions, int64_t caseValue; SkAssertResult(ConstantFolder::GetConstantInt(*c->value(), &caseValue)); if (caseValue == switchValue) { - std::unique_ptr newBlock = block_for_case(&s, c.get()); + std::unique_ptr newBlock = + SwitchStatement::BlockForCase(&s.cases(), c.get(), s.symbols()); if (newBlock) { (*iter)->setStatement(std::move(newBlock), usage); found = true; @@ -1454,7 +1279,8 @@ void Compiler::simplifyStatement(DefinitionMap& definitions, if (!found) { // no matching case. use default if it exists, or kill the whole thing if (defaultCase) { - std::unique_ptr newBlock = block_for_case(&s, defaultCase); + std::unique_ptr newBlock = + SwitchStatement::BlockForCase(&s.cases(), defaultCase, s.symbols()); if (newBlock) { (*iter)->setStatement(std::move(newBlock), usage); } else { diff --git a/src/sksl/ir/SkSLSwitchStatement.cpp b/src/sksl/ir/SkSLSwitchStatement.cpp index 20499283df..9b68f55143 100644 --- a/src/sksl/ir/SkSLSwitchStatement.cpp +++ b/src/sksl/ir/SkSLSwitchStatement.cpp @@ -10,9 +10,12 @@ #include #include "include/private/SkTHash.h" +#include "src/sksl/SkSLAnalysis.h" #include "src/sksl/SkSLConstantFolder.h" #include "src/sksl/SkSLContext.h" #include "src/sksl/SkSLProgramSettings.h" +#include "src/sksl/ir/SkSLBlock.h" +#include "src/sksl/ir/SkSLNop.h" #include "src/sksl/ir/SkSLSymbolTable.h" #include "src/sksl/ir/SkSLType.h" @@ -76,6 +79,106 @@ static std::forward_list find_duplicate_case_values( return duplicateCases; } +static void move_all_but_break(std::unique_ptr& stmt, StatementArray* target) { + switch (stmt->kind()) { + case Statement::Kind::kBlock: { + // Recurse into the block. + Block& block = stmt->as(); + + StatementArray blockStmts; + blockStmts.reserve_back(block.children().size()); + for (std::unique_ptr& stmt : block.children()) { + move_all_but_break(stmt, &blockStmts); + } + + target->push_back(std::make_unique(block.fOffset, std::move(blockStmts), + block.symbolTable(), block.isScope())); + break; + } + + case Statement::Kind::kBreak: + // Do not append a break to the target. + break; + + default: + // Append normal statements to the target. + target->push_back(std::move(stmt)); + break; + } +} + +std::unique_ptr SwitchStatement::BlockForCase( + std::vector>* cases, + SwitchCase* caseToCapture, + std::shared_ptr symbolTable) { + // We have to be careful to not move any of the pointers until after we're sure we're going to + // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan + // of action. First, find the switch-case we are interested in. + auto iter = cases->begin(); + for (; iter != cases->end(); ++iter) { + if (iter->get() == caseToCapture) { + break; + } + } + + // Next, walk forward through the rest of the switch. If we find a conditional break, we're + // stuck and can't simplify at all. If we find an unconditional break, we have a range of + // statements that we can use for simplification. + auto startIter = iter; + Statement* stripBreakStmt = nullptr; + for (; iter != cases->end(); ++iter) { + for (std::unique_ptr& stmt : (*iter)->statements()) { + if (Analysis::SwitchCaseContainsConditionalExit(*stmt)) { + // We can't reduce switch-cases to a block when they have conditional exits. + return nullptr; + } + if (Analysis::SwitchCaseContainsUnconditionalExit(*stmt)) { + // We found an unconditional exit. We can use this block, but we'll need to strip + // out the break statement if there is one. + stripBreakStmt = stmt.get(); + break; + } + } + + if (stripBreakStmt) { + break; + } + } + + // We fell off the bottom of the switch or encountered a break. We know the range of statements + // that we need to move over, and we know it's safe to do so. + StatementArray caseStmts; + caseStmts.reserve_back(std::distance(startIter, iter)); + + // We can move over most of the statements as-is. + while (startIter != iter) { + for (std::unique_ptr& stmt : (*startIter)->statements()) { + caseStmts.push_back(std::move(stmt)); + } + ++startIter; + } + + // If we found an unconditional break at the end, we need to move what we can while avoiding + // that break. + if (stripBreakStmt != nullptr) { + for (std::unique_ptr& stmt : (*startIter)->statements()) { + if (stmt.get() == stripBreakStmt) { + move_all_but_break(stmt, &caseStmts); + stripBreakStmt = nullptr; + break; + } + + caseStmts.push_back(std::move(stmt)); + } + } + + SkASSERT(stripBreakStmt == nullptr); // Verify that we stripped any unconditional break. + + // Return our newly-synthesized block. + return std::make_unique(caseToCapture->fOffset, std::move(caseStmts), + std::move(symbolTable)); +} + std::unique_ptr SwitchStatement::Make(const Context& context, int offset, bool isStatic, @@ -157,7 +260,53 @@ std::unique_ptr SwitchStatement::Make(const Context& context, // Confirm that every switch-case value is unique. SkASSERT(find_duplicate_case_values(cases).empty()); - // TODO(skia:11340): Optimize static switches. + // Flatten @switch statements. + if (isStatic) { + SKSL_INT switchValue; + if (ConstantFolder::GetConstantInt(*value, &switchValue)) { + SwitchCase* defaultCase = nullptr; + SwitchCase* matchingCase = nullptr; + for (const std::unique_ptr& sc : cases) { + if (!sc->value()) { + defaultCase = sc.get(); + continue; + } + + SKSL_INT caseValue; + SkAssertResult(ConstantFolder::GetConstantInt(*sc->value(), &caseValue)); + if (caseValue == switchValue) { + matchingCase = sc.get(); + break; + } + } + + if (!matchingCase) { + // No case value matches the switch value. + if (!defaultCase) { + // No default switch-case exists; the switch had no effect. + // We can eliminate the entire switch! + return std::make_unique(); + } + // We had a default case; that's what we matched with. + matchingCase = defaultCase; + } + + // Convert the switch-case that we matched with into a block. + std::unique_ptr newBlock = BlockForCase(&cases, matchingCase, symbolTable); + if (newBlock) { + return newBlock; + } + + // Report an error if this was a static switch and BlockForCase failed us. + if (!context.fConfig->fSettings.fPermitInvalidStaticTests) { + context.fErrors.error(value->fOffset, + "static switch contains non-static conditional exit"); + return nullptr; + } + } + } + + // The switch couldn't be optimized away; emit it normally. return std::make_unique(offset, isStatic, std::move(value), std::move(cases), std::move(symbolTable)); } diff --git a/src/sksl/ir/SkSLSwitchStatement.h b/src/sksl/ir/SkSLSwitchStatement.h index e4714d75fc..f120b493ee 100644 --- a/src/sksl/ir/SkSLSwitchStatement.h +++ b/src/sksl/ir/SkSLSwitchStatement.h @@ -36,8 +36,9 @@ public: , fCases(std::move(cases)) , fSymbols(std::move(symbols)) {} - // Create a `switch` statement with an array of case-values and case-statements. - // Coerces case values to the proper type and reports an error if cases are duplicated. + /** Create a `switch` statement with an array of case-values and case-statements. + * Coerces case values to the proper type and reports an error if cases are duplicated. + */ static std::unique_ptr Make(const Context& context, int offset, bool isStatic, @@ -46,8 +47,10 @@ public: SkTArray caseStatements, std::shared_ptr symbolTable); - // Create a `switch` statement with an array of SwitchCases. - // The array of SwitchCases must already contain non-overlapping, correctly-typed case values. + /** + * Create a `switch` statement with an array of SwitchCases. + * The array of SwitchCases must already contain non-overlapping, correctly-typed case values. + */ static std::unique_ptr Make(const Context& context, int offset, bool isStatic, @@ -55,6 +58,17 @@ public: std::vector> cases, std::shared_ptr symbolTable); + /** + * Returns a block containing all of the statements that will be run if the given case matches + * (which, owing to the statements being owned by unique_ptrs, means the switch itself will be + * disassembled by this call and must then be discarded). + * Returns null (and leaves the switch unmodified) if no such simple reduction is possible, such + * as when break statements appear inside conditionals. + */ + static std::unique_ptr BlockForCase(std::vector>* cases, + SwitchCase* caseToCapture, + std::shared_ptr symbolTable); + std::unique_ptr& value() { return fValue; } diff --git a/tests/SkSLDSLTest.cpp b/tests/SkSLDSLTest.cpp index 1ee9f791e0..0fde84789b 100644 --- a/tests/SkSLDSLTest.cpp +++ b/tests/SkSLDSLTest.cpp @@ -39,7 +39,8 @@ public: } ~ExpectError() override { - REPORTER_ASSERT(fReporter, !fMsg); + REPORTER_ASSERT(fReporter, !fMsg, + "Error mismatch: expected:\n%sbut no error occurred\n", fMsg); SetErrorHandler(nullptr); } @@ -1149,11 +1150,12 @@ DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLSwitch, r, ctxInfo) { default: discard; } )"); - Statement y = Switch(b); - EXPECT_EQUAL(y, "switch (b) {}"); - Statement z = Switch(b, Default(), Case(0), Case(1)); - EXPECT_EQUAL(z, "switch (b) { default: case 0: case 1: }"); + EXPECT_EQUAL(Switch(b), + "switch (b) {}"); + + EXPECT_EQUAL(Switch(b, Default(), Case(0), Case(1)), + "switch (b) { default: case 0: case 1: }"); { ExpectError error(r, "error: duplicate case value '0'\n");