diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp index 9db311199..21c4566fe 100644 --- a/source/opt/merge_return_pass.cpp +++ b/source/opt/merge_return_pass.cpp @@ -95,7 +95,7 @@ void MergeReturnPass::ProcessStructured( continue; } - auto blockId = block->GetLabelInst()->result_id(); + auto blockId = block->id(); if (blockId == CurrentState().CurrentMergeId()) { // Pop the current state as we've hit the merge state_.pop_back(); @@ -104,7 +104,7 @@ void MergeReturnPass::ProcessStructured( // Predicate successors of the original return blocks as necessary. if (std::find(return_blocks.begin(), return_blocks.end(), block) != return_blocks.end()) { - PredicateBlocks(block, &predicated); + PredicateBlocks(block, &predicated, &order); } // Generate state for next block @@ -202,8 +202,6 @@ void MergeReturnPass::BranchToBlock(BasicBlock* block, uint32_t target) { void MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source, BasicBlock* target) { - // A new edge is being added from |new_source| to |target|, so go through - // |target|'s phi nodes add an undef incoming value for |new_source|. target->ForEachPhiInst([this, new_source](Instruction* inst) { uint32_t undefId = Type2Undef(inst->type_id()); inst->AddOperand({SPV_OPERAND_TYPE_ID, {undefId}}); @@ -277,7 +275,8 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block, } void MergeReturnPass::PredicateBlocks( - BasicBlock* return_block, std::unordered_set* predicated) { + BasicBlock* return_block, std::unordered_set* predicated, + std::list* order) { // The CFG is being modified as the function proceeds so avoid caching // successors. @@ -308,7 +307,6 @@ void MergeReturnPass::PredicateBlocks( while (block != nullptr && block != final_return_block_) { if (!predicated->insert(block).second) break; - // Skip structured subgraphs. 
BasicBlock* next = nullptr; if (state->InLoop()) { @@ -316,11 +314,11 @@ void MergeReturnPass::PredicateBlocks( while (state->LoopMergeId() == next->id()) { state++; } - BreakFromConstruct(block, next, predicated); + BreakFromConstruct(block, next, predicated, order); } else if (state->InStructuredFlow()) { next = context()->get_instr_block(state->CurrentMergeId()); state++; - BreakFromConstruct(block, next, predicated); + BreakFromConstruct(block, next, predicated, order); } else { BasicBlock* tail = block; while (tail->GetMergeInst()) { @@ -340,7 +338,7 @@ void MergeReturnPass::PredicateBlocks( next = succ_block; }); - PredicateBlock(block, tail, predicated); + PredicateBlock(block, tail, predicated, order); } block = next; } @@ -364,7 +362,8 @@ bool MergeReturnPass::RequiresPredication(const BasicBlock* block, void MergeReturnPass::PredicateBlock( BasicBlock* block, BasicBlock* tail_block, - std::unordered_set* predicated) { + std::unordered_set* predicated, + std::list* order) { if (!RequiresPredication(block, tail_block)) { return; } @@ -403,6 +402,9 @@ void MergeReturnPass::PredicateBlock( function_->InsertBasicBlockAfter(std::move(new_block), block); predicated->insert(old_body); + // Update |order| so |old_body| will be traversed. + InsertAfterElement(block, old_body, order); + if (tail_block == block) { tail_block = old_body; } @@ -425,6 +427,9 @@ void MergeReturnPass::PredicateBlock( predicated->insert(new_merge); new_merge->SetParent(function_); + // Update |order| so |new_merge| will be traversed. + InsertAfterElement(tail_block, new_merge, order); + // Register the new label. 
get_def_use_mgr()->AnalyzeInstDef(new_merge->GetLabelInst()); context()->set_instr_block(new_merge->GetLabelInst(), new_merge); @@ -502,7 +507,8 @@ void MergeReturnPass::PredicateBlock( void MergeReturnPass::BreakFromConstruct( BasicBlock* block, BasicBlock* merge_block, - std::unordered_set* predicated) { + std::unordered_set* predicated, + std::list* order) { // Make sure the cfg is build here. If we don't then it becomes very hard // to know which new blocks need to be updated. context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG); @@ -537,9 +543,12 @@ void MergeReturnPass::BreakFromConstruct( function_->InsertBasicBlockAfter(std::move(new_block), block); predicated->insert(old_body); + // Update |order| so |old_body| will be traversed. + InsertAfterElement(block, old_body, order); + // Within the new header we need the following: // 1. Load of the return status flag - // 2. Branch to new merge (true) or old body (false) + // 2. Branch to |merge_block| (true) or old body (false) // 3. Update OpPhi instructions in |merge_block|. // // Sine we are branching to the merge block of the current construct, there is @@ -793,5 +802,14 @@ void MergeReturnPass::MarkForNewPhiNodes(BasicBlock* block, new_merge_nodes_[block] = single_original_pred; } +void MergeReturnPass::InsertAfterElement(BasicBlock* element, + BasicBlock* new_element, + std::list* list) { + auto pos = std::find(list->begin(), list->end(), element); + assert(pos != list->end()); + ++pos; + list->insert(pos, new_element); +} + } // namespace opt } // namespace spvtools diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h index 0a77b1b86..472d059fe 100644 --- a/source/opt/merge_return_pass.h +++ b/source/opt/merge_return_pass.h @@ -209,25 +209,42 @@ class MergeReturnPass : public MemPass { // |AddReturnFlag| and |AddReturnValue| must have already been called. 
void BranchToBlock(BasicBlock* block, uint32_t target); - // Returns true if we need to pridicate |block| where |tail_block| is the + // Returns true if we need to predicate |block| where |tail_block| is the // merge point. (See |PredicateBlocks|). There is no need to predicate if // there is no code that could be executed. bool RequiresPredication(const BasicBlock* block, const BasicBlock* tail_block) const; - // For every basic block that is reachable from a basic block in - // |return_blocks|, extra code is added to jump around any code that should - // not be executed because the original code would have already returned. This - // involves adding new selections constructs to jump around these - // instructions. + // For every basic block that is reachable from |return_block|, extra code is + // added to jump around any code that should not be executed because the + // original code would have already returned. This involves adding new + // selection constructs to jump around these instructions. + // + // New blocks that are created will be added to |order|. This way a caller + // can traverse these new blocks in structured order. void PredicateBlocks(BasicBlock* return_block, - std::unordered_set* pSet); + std::unordered_set* pSet, + std::list* order); + + // Add a conditional branch at the start of |block| that either jumps to + // |merge_block| or the original code in |block| depending on the value in + // |return_flag_|. + // + // New blocks that are created will be added to |order|. This way a caller + // can traverse these new blocks in structured order. + void BreakFromConstruct(BasicBlock* block, BasicBlock* merge_block, + std::unordered_set* predicated, + std::list* order); // Add the predication code (see |PredicateBlocks|) to |tail_block| if it // requires predication. |tail_block| and any new blocks that are known to // not require predication will be added to |predicated|. + // + // New blocks that are created will be added to |order|. 
This way a caller + // can traverse these new blocks in structured order. void PredicateBlock(BasicBlock* block, BasicBlock* tail_block, - std::unordered_set* predicated); + std::unordered_set* predicated, + std::list* order); // Add an |OpReturn| or |OpReturnValue| to the end of |block|. If an // |OpReturnValue| is needed, the return value is loaded from |return_value_|. @@ -270,8 +287,19 @@ class MergeReturnPass : public MemPass { } } + // Modifies existing OpPhi instructions in |target| block to account for the + // new edge from |new_source|. The value for that edge will be an Undef. If + // |target| only had a single predecessor, then it is marked as needing new + // phi nodes. See |MarkForNewPhiNodes|. + void UpdatePhiNodes(BasicBlock* new_source, BasicBlock* target); + StructuredControlState& CurrentState() { return state_.back(); } + // Inserts |new_element| into |list| after the first occurrence of |element|. + // |element| must be in |list| at least once. + void InsertAfterElement(BasicBlock* element, BasicBlock* new_element, + std::list* list); + // A stack used to keep track of the innermost contain loop and selection // constructs. std::vector state_; @@ -294,14 +322,12 @@ class MergeReturnPass : public MemPass { // The basic block that is suppose to become the contain the only return value // after processing the current function. BasicBlock* final_return_block_; + // This map contains the set of nodes that use to have a single predcessor, // but now have more. They will need new OpPhi nodes. For each of the nodes, // it is mapped to it original single predcessor. It is assumed there are no // values that will need a phi on the new edges. 
std::unordered_map new_merge_nodes_; - void BreakFromConstruct(BasicBlock* block, BasicBlock* merge_block, - std::unordered_set* predicated); - void UpdatePhiNodes(BasicBlock* new_source, BasicBlock* target); }; } // namespace opt diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp index fc1f112bd..cd30671f6 100644 --- a/test/opt/pass_merge_return_test.cpp +++ b/test/opt/pass_merge_return_test.cpp @@ -488,6 +488,71 @@ TEST_F(MergeReturnPassTest, SplitBlockUsedInPhi) { SinglePassRunAndMatch(before, false); } + +TEST_F(MergeReturnPassTest, UpdateOrderWhenPredicating) { + const std::string before = + R"( +; CHECK: OpFunction +; CHECK: OpFunction +; CHECK: OpSelectionMerge [[m1:%\w+]] None +; CHECK-NOT: OpReturn +; CHECK: [[m1]] = OpLabel +; CHECK: OpSelectionMerge [[m2:%\w+]] None +; CHECK: OpSelectionMerge [[m3:%\w+]] None +; CHECK: OpSelectionMerge [[m4:%\w+]] None +; CHECK: OpLabel +; CHECK-NEXT: OpStore +; CHECK-NEXT: OpBranch [[m4]] +; CHECK: [[m4]] = OpLabel +; CHECK-NEXT: [[ld4:%\w+]] = OpLoad %bool +; CHECK-NEXT: OpBranchConditional [[ld4]] [[m3]] +; CHECK: [[m3]] = OpLabel +; CHECK-NEXT: [[ld3:%\w+]] = OpLoad %bool +; CHECK-NEXT: OpBranchConditional [[ld3]] [[m2]] +; CHECK: [[m2]] = OpLabel + OpCapability SampledBuffer + OpCapability StorageImageExtendedFormats + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "PS_DebugTiles" + OpExecutionMode %1 OriginUpperLeft + OpSource HLSL 600 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool + %1 = OpFunction %void None %3 + %5 = OpLabel + %6 = OpFunctionCall %void %7 + OpReturn + OpFunctionEnd + %7 = OpFunction %void None %3 + %8 = OpLabel + %9 = OpUndef %bool + OpSelectionMerge %10 None + OpBranchConditional %9 %11 %10 + %11 = OpLabel + OpReturn + %10 = OpLabel + %12 = OpUndef %bool + OpSelectionMerge %13 None + OpBranchConditional %12 %14 %15 + %15 = OpLabel + %16 = OpUndef %bool + OpSelectionMerge %17 None + OpBranchConditional %16 
%18 %17 + %18 = OpLabel + OpReturn + %17 = OpLabel + OpBranch %13 + %14 = OpLabel + OpReturn + %13 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(before, false); +} #endif TEST_F(MergeReturnPassTest, StructuredControlFlowBothMergeAndHeader) {