diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp index afbb90081..a12b2ca3d 100644 --- a/source/opt/merge_return_pass.cpp +++ b/source/opt/merge_return_pass.cpp @@ -197,6 +197,7 @@ void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) { tail_opcode == SpvOpUnreachable) { assert(CurrentState().InLoop() && "Should be in the dummy loop."); BranchToBlock(block, CurrentState().LoopMergeId()); + return_blocks_.insert(block->id()); } } @@ -232,11 +233,19 @@ void MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source, const auto& target_pred = cfg()->preds(target->id()); if (target_pred.size() == 1) { MarkForNewPhiNodes(target, context()->get_instr_block(target_pred[0])); + } else { + // If the loop contained a break and a return, OpPhi instructions may be + // required starting from the dominator of the loop merge. + DominatorAnalysis* dom_tree = + context()->GetDominatorAnalysis(target->GetParent()); + auto idom = dom_tree->ImmediateDominator(target); + if (idom) { + MarkForNewPhiNodes(target, idom); + } } } void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block, - uint32_t predecessor, Instruction& inst) { DominatorAnalysis* dom_tree = context()->GetDominatorAnalysis(merge_block->GetParent()); @@ -281,17 +290,16 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block, uint32_t undef_id = Type2Undef(inst.type_id()); std::vector phi_operands; - // Add the operands for the defining instructions. - phi_operands.push_back(inst.result_id()); - phi_operands.push_back(predecessor); - - // Add undef from all other blocks. + // Add the OpPhi operands. If the predecessor is a return block use undef, + // otherwise use |inst|'s id. 
std::vector preds = cfg()->preds(merge_block->id()); for (uint32_t pred_id : preds) { - if (pred_id != predecessor) { + if (return_blocks_.count(pred_id)) { phi_operands.push_back(undef_id); - phi_operands.push_back(pred_id); + } else { + phi_operands.push_back(inst.result_id()); } + phi_operands.push_back(pred_id); } Instruction* new_phi = builder.AddPhi(inst.type_id(), phi_operands); @@ -400,8 +408,14 @@ bool MergeReturnPass::BreakFromConstruct( // Forget about the edges leaving block. They will be removed. cfg()->RemoveSuccessorEdges(block); - BasicBlock* old_body = block->SplitBasicBlock(context(), TakeNextId(), iter); + auto old_body_id = TakeNextId(); + BasicBlock* old_body = block->SplitBasicBlock(context(), old_body_id, iter); predicated->insert(old_body); + // If a return block is being split, mark the new body block also as a return + // block. + if (return_blocks_.count(block->id())) { + return_blocks_.insert(old_body_id); + } // If |block| was a continue target for a loop |old_body| is now the correct // continue target. @@ -660,7 +674,7 @@ void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb, BasicBlock* pred, BasicBlock* current_bb = pred; while (current_bb != nullptr && current_bb->id() != header_id) { for (Instruction& inst : *current_bb) { - CreatePhiNodesForInst(bb, pred->id(), inst); + CreatePhiNodesForInst(bb, inst); } current_bb = dom_tree->ImmediateDominator(current_bb); } diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h index d7e18f0c5..63094b798 100644 --- a/source/opt/merge_return_pass.h +++ b/source/opt/merge_return_pass.h @@ -240,12 +240,11 @@ class MergeReturnPass : public MemPass { // return block at the end of the pass. void CreateReturnBlock(); - // Creates a Phi node in |merge_block| for the result of |inst| coming from - // |predecessor|. Any uses of the result of |inst| that are no longer + // Creates a Phi node in |merge_block| for the result of |inst|. 
+ // Any uses of the result of |inst| that are no longer // dominated by |inst|, are replaced with the result of the new |OpPhi| // instruction. - void CreatePhiNodesForInst(BasicBlock* merge_block, uint32_t predecessor, - Instruction& inst); + void CreatePhiNodesForInst(BasicBlock* merge_block, Instruction& inst); // Traverse the nodes in |new_merge_nodes_|, and adds the OpPhi instructions // that are needed to make the code correct. It is assumed that at this point @@ -331,6 +330,11 @@ class MergeReturnPass : public MemPass { // values that will need a phi on the new edges. std::unordered_map new_merge_nodes_; bool HasNontrivialUnreachableBlocks(Function* function); + + // Contains all return blocks that are merged. This set is populated while + // processing structured blocks and used to properly construct OpPhi + // instructions. + std::unordered_set return_blocks_; }; } // namespace opt diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp index f985c8925..2f2e74a99 100644 --- a/test/opt/pass_merge_return_test.cpp +++ b/test/opt/pass_merge_return_test.cpp @@ -1206,7 +1206,7 @@ TEST_F(MergeReturnPassTest, StructuredControlFlowPartialReplacePhi) { ; CHECK: [[bb:%\w+]] = OpLabel ; CHECK-NEXT: [[val:%\w+]] = OpUndef %bool ; CHECK: [[merge]] = OpLabel -; CHECK-NEXT: [[phi1:%\w+]] = OpPhi %bool [[val]] [[bb]] {{%\w+}} [[old_ret_block]] +; CHECK-NEXT: [[phi1:%\w+]] = OpPhi %bool {{%\w+}} [[old_ret_block]] [[val]] [[bb]] ; CHECK: OpBranchConditional {{%\w+}} {{%\w+}} [[bb2:%\w+]] ; CHECK: [[bb2]] = OpLabel ; CHECK: OpBranch [[header2:%\w+]] @@ -1263,7 +1263,7 @@ TEST_F(MergeReturnPassTest, GeneratePhiInOuterLoop) { ; CHECK: [[continue]] = OpLabel ; CHECK-NEXT: [[undef:%\w+]] = OpUndef ; CHECK: [[merge]] = OpLabel - ; CHECK-NEXT: [[phi:%\w+]] = OpPhi %bool [[undef]] [[continue]] {{%\w+}} {{%\w+}} + ; CHECK-NEXT: [[phi:%\w+]] = OpPhi %bool {{%\w+}} {{%\w+}} [[undef]] [[continue]] ; CHECK: OpCopyObject %bool [[phi]] OpCapability Shader 
%1 = OpExtInstImport "GLSL.std.450" @@ -1328,7 +1328,7 @@ TEST_F(MergeReturnPassTest, SerialLoopsUpdateBlockMapping) { ; CHECK: OpLoopMerge [[merge:%\w+]] ; CHECK: [[def:%\w+]] = OpFOrdLessThan ; CHECK: [[merge]] = OpLabel -; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} [[def]] +; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[def]] ; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] ; CHECK: [[cont]] = OpLabel ; CHECK-NEXT: OpBranchConditional [[phi]] @@ -1480,6 +1480,198 @@ TEST_F(MergeReturnPassTest, InnerLoopMergeIsOuterLoopContinue) { SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SinglePassRunAndMatch(before, false); } + +TEST_F(MergeReturnPassTest, BreakFromLoopUseNoLongerDominated) { + const std::string spirv = R"( +; CHECK: [[undef:%\w+]] = OpUndef +; CHECK: OpLoopMerge +; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] +; CHECK-NEXT: OpBranch [[body:%\w+]] +; CHECK: [[body]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[non_ret:%\w+]] +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[ret:%\w+]] [[non_ret]] +; CHECK: [[ret]] = OpLabel +; CHECK-NEXT: OpStore +; CHECK-NEXT: OpBranch [[merge]] +; CHECK: [[non_ret]] = OpLabel +; CHECK-NEXT: [[def:%\w+]] = OpLogicalNot +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[break:%\w+]] [[cont]] +; CHECK: [[break]] = OpLabel +; CHECK-NEXT: OpBranch [[merge]] +; CHECK: [[cont]] = OpLabel +; CHECK-NEXT: OpBranchConditional {{%\w+}} {{%\w+}} [[merge]] +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} [[undef]] [[ret]] [[def]] [[break]] [[def]] [[cont]] +; CHECK: OpLogicalNot {{%\w+}} [[phi]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %func "func" +OpExecutionMode %func LocalSize 1 1 1 +%void = OpTypeVoid +%void_fn = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%func = OpFunction %void None %void_fn +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %3 +%3 = OpLabel +OpSelectionMerge %5 None 
+OpBranchConditional %true %4 %5 +%4 = OpLabel +OpReturn +%5 = OpLabel +%def = OpLogicalNot %bool %true +OpBranchConditional %true %6 %7 +%6 = OpLabel +OpBranch %8 +%7 = OpLabel +OpBranchConditional %true %2 %8 +%8 = OpLabel +OpBranch %9 +%9 = OpLabel +%use = OpLogicalNot %bool %def +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(spirv, true); +} + +TEST_F(MergeReturnPassTest, TwoBreaksFromLoopUsesNoLongerDominated) { + const std::string spirv = R"( +; CHECK: [[undef:%\w+]] = OpUndef +; CHECK: OpLoopMerge +; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] +; CHECK-NEXT: OpBranch [[body:%\w+]] +; CHECK: [[body]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[body2:%\w+]] +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[ret1:%\w+]] [[body2]] +; CHECK: [[ret1]] = OpLabel +; CHECK-NEXT: OpStore +; CHECK-NEXT: OpBranch [[merge]] +; CHECK: [[body2]] = OpLabel +; CHECK-NEXT: [[def1:%\w+]] = OpLogicalNot +; CHECK-NEXT: OpSelectionMerge [[body3:%\w+]] +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[ret2:%\w+]] [[body3:%\w+]] +; CHECK: [[ret2]] = OpLabel +; CHECK-NEXT: OpStore +; CHECK-NEXT: OpBranch [[merge]] +; CHECK: [[body3]] = OpLabel +; CHECK-NEXT: [[def2:%\w+]] = OpLogicalAnd +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[break:%\w+]] [[cont]] +; CHECK: [[break]] = OpLabel +; CHECK-NEXT: OpBranch [[merge]] +; CHECK: [[cont]] = OpLabel +; CHECK-NEXT: OpBranchConditional {{%\w+}} {{%\w+}} [[merge]] +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: [[phi1:%\w+]] = OpPhi {{%\w+}} [[undef]] [[ret1]] [[undef]] [[ret2]] [[def1]] [[break]] [[def1]] [[cont]] +; CHECK-NEXT: [[phi2:%\w+]] = OpPhi {{%\w+}} [[undef]] [[ret1]] [[undef]] [[ret2]] [[def2]] [[break]] [[def2]] [[cont]] +; CHECK: OpLogicalNot {{%\w+}} [[phi1]] +; CHECK: OpLogicalAnd {{%\w+}} [[phi2]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %func "func" +OpExecutionMode %func LocalSize 1 1 1 +%void = OpTypeVoid +%void_fn = 
OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%func = OpFunction %void None %void_fn +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %10 %9 None +OpBranch %3 +%3 = OpLabel +OpSelectionMerge %5 None +OpBranchConditional %true %4 %5 +%4 = OpLabel +OpReturn +%5 = OpLabel +%def1 = OpLogicalNot %bool %true +OpSelectionMerge %7 None +OpBranchConditional %true %6 %7 +%6 = OpLabel +OpReturn +%7 = OpLabel +%def2 = OpLogicalAnd %bool %true %true +OpBranchConditional %true %8 %9 +%8 = OpLabel +OpBranch %10 +%9 = OpLabel +OpBranchConditional %true %2 %10 +%10 = OpLabel +OpBranch %11 +%11 = OpLabel +%use1 = OpLogicalNot %bool %def1 +%use2 = OpLogicalAnd %bool %def2 %true +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(spirv, true); +} + +TEST_F(MergeReturnPassTest, PredicateBreakBlock) { + const std::string spirv = R"( +; IDs are being preserved so we can rely on basic block labels. +; CHECK: [[undef:%\w+]] = OpUndef +; CHECK: [[undef:%\w+]] = OpUndef +; CHECK: %13 = OpLabel +; CHECK-NEXT: [[def:%\w+]] = OpLogicalNot +; CHECK: %8 = OpLabel +; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} [[undef]] {{%\w+}} [[undef]] {{%\w+}} [[def]] %13 [[undef]] {{%\w+}} +; CHECK: OpLogicalAnd {{%\w+}} [[phi]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "func" +OpExecutionMode %1 LocalSize 1 1 1 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpUndef %bool +%1 = OpFunction %void None %3 +%6 = OpLabel +OpBranch %7 +%7 = OpLabel +OpLoopMerge %8 %9 None +OpBranch %10 +%10 = OpLabel +OpSelectionMerge %11 None +OpBranchConditional %true %12 %13 +%12 = OpLabel +OpLoopMerge %14 %15 None +OpBranch %16 +%16 = OpLabel +OpReturn +%15 = OpLabel +OpBranch %12 +%14 = OpLabel +OpUnreachable +%13 = OpLabel +%17 = OpLogicalNot %bool %true +OpBranch %8 +%11 = OpLabel +OpUnreachable +%9 = OpLabel +OpBranch %7 +%8 = OpLabel +OpBranch %18 +%18 = 
OpLabel +%19 = OpLogicalAnd %bool %17 %true +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(spirv, true); +} + } // namespace } // namespace opt } // namespace spvtools