From 9fbcce4ca17de7b2d8f6b322bcd1d43a7d6adc29 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 19 Sep 2018 16:40:09 -0400 Subject: [PATCH] Add unrolling to the legalization passes (#1903) Adds unrolling to the legalization passes. After enabling unrolling I found a bug when there is a self-referencing phi node. That has been fixed. The test that checks that the order of optimizations is correct also needed to be updated. --- source/opt/loop_unroller.cpp | 33 ++++++++----- source/opt/optimizer.cpp | 1 + test/opt/loop_optimizations/unroll_simple.cpp | 46 +++++++++++++++++++ test/tools/opt/flags.py | 1 + 4 files changed, 70 insertions(+), 11 deletions(-) diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp index 587615edf..24d3f5434 100644 --- a/source/opt/loop_unroller.cpp +++ b/source/opt/loop_unroller.cpp @@ -244,6 +244,10 @@ class LoopUnrollerUtilsImpl { // ect). void AssignNewResultIds(BasicBlock* basic_block); + // Using the map built by AssignNewResultIds, replace the uses in |inst| + // by the id that the use maps to. + void RemapOperands(Instruction* inst); + // Using the map built by AssignNewResultIds, for each instruction in // |basic_block| use // that map to substitute the IDs used by instructions (in the operands) with @@ -757,6 +761,11 @@ void LoopUnrollerUtilsImpl::CloseUnrolledLoop(Loop* loop) { for (BasicBlock* block : loop_blocks_inorder_) { RemapOperands(block); } + + // Rewrite the last phis, since they may still reference the original phi. + for (Instruction* last_phi : state_.previous_phis_) { + RemapOperands(last_phi); + } } // Uses the first loop to create a copy of the loop with new IDs. @@ -842,19 +851,21 @@ void LoopUnrollerUtilsImpl::AssignNewResultIds(BasicBlock* basic_block) { } } -// For all instructions in |basic_block| check if the operands used are from a -// copied instruction and if so swap out the operand for the copy of it. 
+void LoopUnrollerUtilsImpl::RemapOperands(Instruction* inst) { + auto remap_operands_to_new_ids = [this](uint32_t* id) { + auto itr = state_.new_inst.find(*id); + + if (itr != state_.new_inst.end()) { + *id = itr->second; + } + }; + + inst->ForEachInId(remap_operands_to_new_ids); +} + void LoopUnrollerUtilsImpl::RemapOperands(BasicBlock* basic_block) { for (Instruction& inst : *basic_block) { - auto remap_operands_to_new_ids = [this](uint32_t* id) { - auto itr = state_.new_inst.find(*id); - - if (itr != state_.new_inst.end()) { - *id = itr->second; - } - }; - - inst.ForEachInId(remap_operands_to_new_ids); + RemapOperands(&inst); } } diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 3a8b4d967..234141462 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -131,6 +131,7 @@ Optimizer& Optimizer::RegisterLegalizationPasses() { // Propagate constants to get as many constant conditions on branches // as possible. .RegisterPass(CreateCCPPass()) + .RegisterPass(CreateLoopUnrollPass(true)) .RegisterPass(CreateDeadBranchElimPass()) // Copy propagate members. Cleans up code sequences generated by // scalar replacement. Also important for removing OpPhi nodes. diff --git a/test/opt/loop_optimizations/unroll_simple.cpp b/test/opt/loop_optimizations/unroll_simple.cpp index 3b01fdc31..f551e7ca9 100644 --- a/test/opt/loop_optimizations/unroll_simple.cpp +++ b/test/opt/loop_optimizations/unroll_simple.cpp @@ -2952,6 +2952,52 @@ OpFunctionEnd EXPECT_NE(loop_2.GetLatchBlock(), loop_2.GetContinueBlock()); } +// Test that a loop with a self-referencing OpPhi instruction is handled +// correctly. 
+TEST_F(PassClassTest, OpPhiSelfReference) { + const std::string text = R"( + ; Find the two adds from the unrolled loop + ; CHECK: OpIAdd + ; CHECK: OpIAdd + ; CHECK: OpIAdd %uint %uint_0 %uint_1 + ; CHECK-NEXT: OpReturn + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %2 "main" + OpExecutionMode %2 LocalSize 8 8 1 + OpSource HLSL 600 + %uint = OpTypeInt 32 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 + %bool = OpTypeBool + %true = OpConstantTrue %bool + %2 = OpFunction %void None %5 + %10 = OpLabel + OpBranch %19 + %19 = OpLabel + %20 = OpPhi %uint %uint_0 %10 %20 %21 + %22 = OpPhi %uint %uint_0 %10 %23 %21 + %24 = OpULessThanEqual %bool %22 %uint_1 + OpLoopMerge %25 %21 Unroll + OpBranchConditional %24 %21 %25 + %21 = OpLabel + %23 = OpIAdd %uint %22 %uint_1 + OpBranch %19 + %25 = OpLabel + %14 = OpIAdd %uint %20 %uint_1 + OpReturn + OpFunctionEnd + )"; + + const bool kFullyUnroll = true; + const uint32_t kUnrollFactor = 0; + SinglePassRunAndMatch(text, true, kFullyUnroll, + kUnrollFactor); +} + } // namespace } // namespace opt } // namespace spvtools diff --git a/test/tools/opt/flags.py b/test/tools/opt/flags.py index 628d87108..ecdefcd75 100644 --- a/test/tools/opt/flags.py +++ b/test/tools/opt/flags.py @@ -235,6 +235,7 @@ class TestLegalizationPasses(expect.ValidObjectFile1_3, 'eliminate-local-multi-store', 'eliminate-dead-code-aggressive', 'ccp', + 'loop-unroll', 'eliminate-dead-branches', 'simplify-instructions', 'eliminate-dead-code-aggressive',