// Copyright (c) 2023 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/invocation_interlock_placement_pass.h"

#include <cassert>
#include <deque>
#include <functional>
#include <memory>
#include <optional>
#include <queue>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/enum_set.h"
#include "source/enum_string_mapping.h"
#include "source/opt/ir_context.h"
#include "source/opt/reflect.h"
#include "source/spirv_target_env.h"
#include "source/util/string_utils.h"

namespace spvtools {
namespace opt {

namespace {
constexpr uint32_t kEntryPointExecutionModelInIdx = 0;
constexpr uint32_t kEntryPointFunctionIdInIdx = 1;
constexpr uint32_t kFunctionCallFunctionIdInIdx = 0;
}  // namespace

bool InvocationInterlockPlacementPass::hasSingleNextBlock(uint32_t block_id,
                                                          bool reverse_cfg) {
  if (reverse_cfg) {
    // We are traversing forward, so check whether there is a single successor.
    BasicBlock* block = cfg()->block(block_id);

    switch (block->tail()->opcode()) {
      case spv::Op::OpBranchConditional:
        return false;
      case spv::Op::OpSwitch:
        return block->tail()->NumInOperandWords() == 1;
      default:
        return !block->tail()->IsReturnOrAbort();
    }
  } else {
    // We are traversing backward, so check whether there is a single
    // predecessor.
    return cfg()->preds(block_id).size() == 1;
  }
}

void InvocationInterlockPlacementPass::forEachNext(
    uint32_t block_id, bool reverse_cfg, std::function<void(uint32_t)> f) {
  if (reverse_cfg) {
    BasicBlock* block = cfg()->block(block_id);

    block->ForEachSuccessorLabel([f](uint32_t succ_id) { f(succ_id); });
  } else {
    for (uint32_t pred_id : cfg()->preds(block_id)) {
      f(pred_id);
    }
  }
}

void InvocationInterlockPlacementPass::addInstructionAtBlockBoundary(
    BasicBlock* block, spv::Op opcode, bool at_end) {
  if (at_end) {
    assert(block->begin()->opcode() != spv::Op::OpPhi &&
           "addInstructionAtBlockBoundary expects to be called with at_end == "
           "true only if there is a single successor to block");
    // Insert a begin instruction at the end of the block.
    Instruction* begin_inst = new Instruction(context(), opcode);
    begin_inst->InsertAfter(&*--block->tail());
  } else {
    assert(block->begin()->opcode() != spv::Op::OpPhi &&
           "addInstructionAtBlockBoundary expects to be called with at_end == "
           "false only if there is a single predecessor to block");
    // Insert an end instruction at the beginning of the block.
    Instruction* end_inst = new Instruction(context(), opcode);
    end_inst->InsertBefore(&*block->begin());
  }
}

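// The next two helpers remove redundant interlock instructions within a
// single block: killDuplicateBegin keeps only the first
// OpBeginInvocationInterlockEXT, and killDuplicateEnd keeps only the last
// OpEndInvocationInterlockEXT, so that the begin/end pair encloses as much of
// the block as possible. Both return whether anything was removed.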
bool InvocationInterlockPlacementPass::killDuplicateBegin(BasicBlock* block) {
  bool found = false;

  return context()->KillInstructionIf(
      block->begin(), block->end(), [&found](Instruction* inst) {
        if (inst->opcode() == spv::Op::OpBeginInvocationInterlockEXT) {
          if (found) {
            return true;
          }
          found = true;
        }
        return false;
      });
}

bool InvocationInterlockPlacementPass::killDuplicateEnd(BasicBlock* block) {
  std::vector<Instruction*> to_kill;
  block->ForEachInst([&to_kill](Instruction* inst) {
    if (inst->opcode() == spv::Op::OpEndInvocationInterlockEXT) {
      to_kill.push_back(inst);
    }
  });

  if (to_kill.size() <= 1) {
    return false;
  }

  // Keep the last end instruction and kill the rest.
  to_kill.pop_back();

  for (Instruction* inst : to_kill) {
    context()->KillInst(inst);
  }
  return true;
}

void InvocationInterlockPlacementPass::recordBeginOrEndInFunction(
    Function* func) {
  if (extracted_functions_.count(func)) {
    return;
  }

  bool had_begin = false;
  bool had_end = false;

  func->ForEachInst([this, &had_begin, &had_end](Instruction* inst) {
    switch (inst->opcode()) {
      case spv::Op::OpBeginInvocationInterlockEXT:
        had_begin = true;
        break;
      case spv::Op::OpEndInvocationInterlockEXT:
        had_end = true;
        break;
      case spv::Op::OpFunctionCall: {
        uint32_t function_id =
            inst->GetSingleWordInOperand(kFunctionCallFunctionIdInIdx);
        Function* inner_func = context()->GetFunction(function_id);
        recordBeginOrEndInFunction(inner_func);
        ExtractionResult result = extracted_functions_[inner_func];
        had_begin = had_begin || result.had_begin;
        had_end = had_end || result.had_end;
        break;
      }
      default:
        break;
    }
  });

  ExtractionResult result = {had_begin, had_end};
  extracted_functions_[func] = result;
}

bool InvocationInterlockPlacementPass::
    removeBeginAndEndInstructionsFromFunction(Function* func) {
  bool modified = false;
  func->ForEachInst([this, &modified](Instruction* inst) {
    switch (inst->opcode()) {
      case spv::Op::OpBeginInvocationInterlockEXT:
        context()->KillInst(inst);
        modified = true;
        break;
      case spv::Op::OpEndInvocationInterlockEXT:
        context()->KillInst(inst);
        modified = true;
        break;
      default:
        break;
    }
  });
  return modified;
}

bool InvocationInterlockPlacementPass::extractInstructionsFromCalls(
    std::vector<BasicBlock*> blocks) {
  bool modified = false;

  for (BasicBlock* block : blocks) {
    block->ForEachInst([this, &modified](Instruction* inst) {
      if (inst->opcode() == spv::Op::OpFunctionCall) {
        uint32_t function_id =
            inst->GetSingleWordInOperand(kFunctionCallFunctionIdInIdx);
        Function* func = context()->GetFunction(function_id);
        ExtractionResult result = extracted_functions_[func];

        if (result.had_begin) {
          Instruction* new_inst = new Instruction(
              context(), spv::Op::OpBeginInvocationInterlockEXT);
          new_inst->InsertBefore(inst);
          modified = true;
        }
        if (result.had_end) {
          Instruction* new_inst =
              new Instruction(context(), spv::Op::OpEndInvocationInterlockEXT);
          new_inst->InsertAfter(inst);
          modified = true;
        }
      }
    });
  }
  return modified;
}

void InvocationInterlockPlacementPass::recordExistingBeginAndEndBlock(
    std::vector<BasicBlock*> blocks) {
  for (BasicBlock* block : blocks) {
    block->ForEachInst([this, block](Instruction* inst) {
      switch (inst->opcode()) {
        case spv::Op::OpBeginInvocationInterlockEXT:
          begin_.insert(block->id());
          break;
        case spv::Op::OpEndInvocationInterlockEXT:
          end_.insert(block->id());
          break;
        default:
          break;
      }
    });
  }
}

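// Breadth-first search from `starting_nodes`, following successors when
// `reverse_cfg` is true and predecessors otherwise. Returns every block
// reached, including the starting nodes themselves. As a side effect, any
// block that is the "next" block of some visited block is added to
// `previous_inside`; that is, `previous_inside` becomes the set of blocks
// with at least one previous block inside the returned region.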
InvocationInterlockPlacementPass::BlockSet
InvocationInterlockPlacementPass::computeReachableBlocks(
    BlockSet& previous_inside, const BlockSet& starting_nodes,
    bool reverse_cfg) {
  BlockSet inside = starting_nodes;

  std::deque<uint32_t> worklist;
  worklist.insert(worklist.begin(), starting_nodes.begin(),
                  starting_nodes.end());

  while (!worklist.empty()) {
    uint32_t block_id = worklist.front();
    worklist.pop_front();

    forEachNext(block_id, reverse_cfg,
                [&inside, &previous_inside, &worklist](uint32_t next_id) {
                  previous_inside.insert(next_id);
                  if (inside.insert(next_id).second) {
                    worklist.push_back(next_id);
                  }
                });
  }

  return inside;
}

bool InvocationInterlockPlacementPass::removeUnneededInstructions(
    BasicBlock* block) {
  bool modified = false;
  if (!predecessors_after_begin_.count(block->id()) &&
      after_begin_.count(block->id())) {
    // None of the previous blocks are in the critical section, but this block
    // is. This can only happen if this block already has at least one begin
    // instruction. Leave the first begin instruction, and remove any others.
    modified |= killDuplicateBegin(block);
  } else if (predecessors_after_begin_.count(block->id())) {
    // At least one previous block is in the critical section; remove all
    // begin instructions in this block.
    modified |= context()->KillInstructionIf(
        block->begin(), block->end(), [](Instruction* inst) {
          return inst->opcode() == spv::Op::OpBeginInvocationInterlockEXT;
        });
  }

  if (!successors_before_end_.count(block->id()) &&
      before_end_.count(block->id())) {
    // Same as above, but for the end of the critical section: leave the last
    // end instruction, and remove any others.
    modified |= killDuplicateEnd(block);
  } else if (successors_before_end_.count(block->id())) {
    // At least one following block is in the critical section; remove all
    // end instructions in this block.
    modified |= context()->KillInstructionIf(
        block->begin(), block->end(), [](Instruction* inst) {
          return inst->opcode() == spv::Op::OpEndInvocationInterlockEXT;
        });
  }
  return modified;
}

BasicBlock* InvocationInterlockPlacementPass::splitEdge(BasicBlock* block,
                                                        uint32_t succ_id) {
  // Create a new block to replace the critical edge.
  auto new_succ_temp = MakeUnique<BasicBlock>(
      MakeUnique<Instruction>(context(), spv::Op::OpLabel, 0, TakeNextId(),
                              std::initializer_list<Operand>{}));
  auto* new_succ = new_succ_temp.get();

  // Insert the new block into the function.
  block->GetParent()->InsertBasicBlockAfter(std::move(new_succ_temp), block);

  new_succ->AddInstruction(MakeUnique<Instruction>(
      context(), spv::Op::OpBranch, 0, 0,
      std::initializer_list<Operand>{
          Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {succ_id})}));

  assert(block->tail()->opcode() == spv::Op::OpBranchConditional ||
         block->tail()->opcode() == spv::Op::OpSwitch);

  // Update the first branch to `succ_id` so that it instead branches to the
  // new successor. If there are multiple edges, we arbitrarily choose the
  // first time it appears in the list. The other edges to `succ_id` will have
  // to be split by another call to `splitEdge`.
  block->tail()->WhileEachInId([new_succ, succ_id](uint32_t* branch_id) {
    if (*branch_id == succ_id) {
      *branch_id = new_succ->id();
      return false;
    }
    return true;
  });

  return new_succ;
}

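// Considers the edge from `block` to `next_id` and, if the critical section
// must start (or, for the backward traversal, end) on that edge, inserts
// `opcode` there. The edge needs an instruction when `next_id` already has a
// previous block inside the section (`previous_inside`) while `block` itself
// is outside (`inside`). If `block` has a single next block, the instruction
// is placed at the boundary of `block`; otherwise the edge is split and the
// instruction goes into the newly created block. Returns whether anything was
// inserted.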
bool InvocationInterlockPlacementPass::placeInstructionsForEdge(
    BasicBlock* block, uint32_t next_id, BlockSet& inside,
    BlockSet& previous_inside, spv::Op opcode, bool reverse_cfg) {
  bool modified = false;

  if (previous_inside.count(next_id) && !inside.count(block->id())) {
    // This block is not in the critical section, but the next block has at
    // least one other previous block that is, so this block should enter it
    // as well. We need to add begin or end instructions to the edge.

    modified = true;

    if (hasSingleNextBlock(block->id(), reverse_cfg)) {
      // This is the only next block.

      // Additionally, because `next_id` is in `previous_inside`, we know that
      // `next_id` has at least one previous block in `inside`. And because
      // `block` is not in `inside`, that means `next_id` has to have at least
      // one other previous block in `inside`.

      // This is solely for a debug assertion. It is essentially recomputing
      // the value of `previous_inside` to verify that it was computed
      // correctly such that the above statement is true.
      bool next_has_previous_inside = false;
      // By passing !reverse_cfg to forEachNext, we are actually iterating over
      // the previous blocks.
      forEachNext(next_id, !reverse_cfg,
                  [&next_has_previous_inside, inside](uint32_t previous_id) {
                    if (inside.count(previous_id)) {
                      next_has_previous_inside = true;
                    }
                  });
      assert(next_has_previous_inside &&
             "`previous_inside` must be the set of blocks with at least one "
             "previous block in `inside`");

      addInstructionAtBlockBoundary(block, opcode, reverse_cfg);
    } else {
      // This block has multiple next blocks. Split the edge and insert the
      // instruction in the new next block.
      BasicBlock* new_branch;
      if (reverse_cfg) {
        new_branch = splitEdge(block, next_id);
      } else {
        new_branch = splitEdge(cfg()->block(next_id), block->id());
      }

      auto inst = new Instruction(context(), opcode);
      inst->InsertBefore(&*new_branch->tail());
    }
  }

  return modified;
}

bool InvocationInterlockPlacementPass::placeInstructions(BasicBlock* block) {
  bool modified = false;

  block->ForEachSuccessorLabel([this, block, &modified](uint32_t succ_id) {
    modified |= placeInstructionsForEdge(
        block, succ_id, after_begin_, predecessors_after_begin_,
        spv::Op::OpBeginInvocationInterlockEXT, /* reverse_cfg= */ true);
    modified |= placeInstructionsForEdge(cfg()->block(succ_id), block->id(),
                                         before_end_, successors_before_end_,
                                         spv::Op::OpEndInvocationInterlockEXT,
                                         /* reverse_cfg= */ false);
  });

  return modified;
}

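// Rewrites a single fragment-shader entry point: begin/end instructions
// implied by function calls are materialized in the caller, the blocks that
// already begin or end the critical section are recorded, the sets of blocks
// reachable after a begin and before an end are computed, and finally
// redundant instructions are removed and missing ones are inserted on the
// boundary edges of the critical section.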
bool InvocationInterlockPlacementPass::processFragmentShaderEntry(
    Function* entry_func) {
  bool modified = false;

  // Save the original order of blocks in the function, so we don't iterate
  // over newly-added blocks.
  std::vector<BasicBlock*> original_blocks;
  for (auto bi = entry_func->begin(); bi != entry_func->end(); ++bi) {
    original_blocks.push_back(&*bi);
  }

  modified |= extractInstructionsFromCalls(original_blocks);
  recordExistingBeginAndEndBlock(original_blocks);

  after_begin_ = computeReachableBlocks(predecessors_after_begin_, begin_,
                                        /* reverse_cfg= */ true);
  before_end_ = computeReachableBlocks(successors_before_end_, end_,
                                       /* reverse_cfg= */ false);

  for (BasicBlock* block : original_blocks) {
    modified |= removeUnneededInstructions(block);
    modified |= placeInstructions(block);
  }
  return modified;
}

bool InvocationInterlockPlacementPass::isFragmentShaderInterlockEnabled() {
  if (!context()->get_feature_mgr()->HasExtension(
          kSPV_EXT_fragment_shader_interlock)) {
    return false;
  }

  if (context()->get_feature_mgr()->HasCapability(
          spv::Capability::FragmentShaderSampleInterlockEXT)) {
    return true;
  }

  if (context()->get_feature_mgr()->HasCapability(
          spv::Capability::FragmentShaderPixelInterlockEXT)) {
    return true;
  }

  if (context()->get_feature_mgr()->HasCapability(
          spv::Capability::FragmentShaderShadingRateInterlockEXT)) {
    return true;
  }

  return false;
}

Pass::Status InvocationInterlockPlacementPass::Process() {
  // Skip this pass if the necessary extension or capability is missing.
  if (!isFragmentShaderInterlockEnabled()) {
    return Status::SuccessWithoutChange;
  }

  bool modified = false;

  std::unordered_set<Function*> entry_points;
  for (Instruction& entry_inst : context()->module()->entry_points()) {
    uint32_t entry_id =
        entry_inst.GetSingleWordInOperand(kEntryPointFunctionIdInIdx);
    entry_points.insert(context()->GetFunction(entry_id));
  }

  for (auto fi = context()->module()->begin();
       fi != context()->module()->end(); ++fi) {
    Function* func = &*fi;
    recordBeginOrEndInFunction(func);
    if (!entry_points.count(func) && extracted_functions_.count(func)) {
      modified |= removeBeginAndEndInstructionsFromFunction(func);
    }
  }

  for (Instruction& entry_inst : context()->module()->entry_points()) {
    uint32_t entry_id =
        entry_inst.GetSingleWordInOperand(kEntryPointFunctionIdInIdx);
    Function* entry_func = context()->GetFunction(entry_id);

    auto execution_model = spv::ExecutionModel(
        entry_inst.GetSingleWordInOperand(kEntryPointExecutionModelInIdx));

    if (execution_model != spv::ExecutionModel::Fragment) {
      continue;
    }

    modified |= processFragmentShaderEntry(entry_func);
  }

  return modified ? Pass::Status::SuccessWithChange
                  : Pass::Status::SuccessWithoutChange;
}

}  // namespace opt
}  // namespace spvtools