// Copyright (c) 2019 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/reduce/remove_selection_reduction_opportunity_finder.h" #include "source/reduce/remove_selection_reduction_opportunity.h" namespace spvtools { namespace reduce { using opt::BasicBlock; using opt::IRContext; using opt::Instruction; namespace { const uint32_t kMergeNodeIndex = 0; const uint32_t kContinueNodeIndex = 1; } // namespace std::string RemoveSelectionReductionOpportunityFinder::GetName() const { return "RemoveSelectionReductionOpportunityFinder"; } std::vector> RemoveSelectionReductionOpportunityFinder::GetAvailableOpportunities( IRContext* context) const { // Get all loop merge and continue blocks so we can check for these later. std::unordered_set merge_and_continue_blocks_from_loops; for (auto& function : *context->module()) { for (auto& block : function) { if (auto merge_instruction = block.GetMergeInst()) { if (merge_instruction->opcode() == SpvOpLoopMerge) { uint32_t merge_block_id = merge_instruction->GetSingleWordOperand(kMergeNodeIndex); uint32_t continue_block_id = merge_instruction->GetSingleWordOperand(kContinueNodeIndex); merge_and_continue_blocks_from_loops.insert(merge_block_id); merge_and_continue_blocks_from_loops.insert(continue_block_id); } } } } // Return all selection headers where the OpSelectionMergeInstruction can be // removed. std::vector> result; for (auto& function : *context->module()) { for (auto& block : function) { if (auto merge_instruction = block.GetMergeInst()) { if (merge_instruction->opcode() == SpvOpSelectionMerge) { if (CanOpSelectionMergeBeRemoved( context, block, merge_instruction, merge_and_continue_blocks_from_loops)) { result.push_back( MakeUnique(&block)); } } } } } return result; } bool RemoveSelectionReductionOpportunityFinder::CanOpSelectionMergeBeRemoved( IRContext* context, const BasicBlock& header_block, Instruction* merge_instruction, std::unordered_set merge_and_continue_blocks_from_loops) { assert(header_block.GetMergeInst() == merge_instruction && "CanOpSelectionMergeBeRemoved(...): header block and merge " "instruction mismatch"); // The OpSelectionMerge instruction is needed if either of the following are // true. // // 1. The header block has at least two (unique) successors that are not // merge or continue blocks of a loop. // // 2. The predecessors of the merge block are "using" the merge block to avoid // divergence. In other words, there exists a predecessor of the merge block // that has a successor that is not the merge block of this construct and not // a merge or continue block of a loop. // 1. { uint32_t divergent_successor_count = 0; std::unordered_set seen_successors; header_block.ForEachSuccessorLabel( [&seen_successors, &merge_and_continue_blocks_from_loops, &divergent_successor_count](uint32_t successor) { // Not already seen. if (seen_successors.find(successor) == seen_successors.end()) { seen_successors.insert(successor); // Not a loop continue or merge. if (merge_and_continue_blocks_from_loops.find(successor) == merge_and_continue_blocks_from_loops.end()) { ++divergent_successor_count; } } }); if (divergent_successor_count > 1) { return false; } } // 2. { uint32_t merge_block_id = merge_instruction->GetSingleWordOperand(kMergeNodeIndex); for (uint32_t predecessor_block_id : context->cfg()->preds(merge_block_id)) { const BasicBlock* predecessor_block = context->cfg()->block(predecessor_block_id); assert(predecessor_block); bool found_divergent_successor = false; predecessor_block->ForEachSuccessorLabel( [&found_divergent_successor, merge_block_id, &merge_and_continue_blocks_from_loops](uint32_t successor_id) { // The successor is not the merge block, nor a loop merge or // continue. if (successor_id != merge_block_id && merge_and_continue_blocks_from_loops.find(successor_id) == merge_and_continue_blocks_from_loops.end()) { found_divergent_successor = true; } }); if (found_divergent_successor) { return false; } } } return true; } } // namespace reduce } // namespace spvtools