// Copyright (c) 2021 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/lint/divergence_analysis.h" #include "source/opt/basic_block.h" #include "source/opt/control_dependence.h" #include "source/opt/dataflow.h" #include "source/opt/function.h" #include "source/opt/instruction.h" #include "spirv/unified1/spirv.h" namespace spvtools { namespace lint { void DivergenceAnalysis::EnqueueSuccessors(opt::Instruction* inst) { // Enqueue control dependents of block, if applicable. // There are two ways for a dependence source to be updated: // 1. control -> control: source block is marked divergent. // 2. data -> control: branch condition is marked divergent. uint32_t block_id; if (inst->IsBlockTerminator()) { block_id = context().get_instr_block(inst)->id(); } else if (inst->opcode() == SpvOpLabel) { block_id = inst->result_id(); opt::BasicBlock* bb = context().cfg()->block(block_id); // Only enqueue phi instructions, as other uses don't affect divergence. bb->ForEachPhiInst([this](opt::Instruction* phi) { Enqueue(phi); }); } else { opt::ForwardDataFlowAnalysis::EnqueueUsers(inst); return; } if (!cd_.HasBlock(block_id)) { return; } for (const spvtools::opt::ControlDependence& dep : cd_.GetDependenceTargets(block_id)) { opt::Instruction* target_inst = context().cfg()->block(dep.target_bb_id())->GetLabelInst(); Enqueue(target_inst); } } opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::Visit( opt::Instruction* inst) { if (inst->opcode() == SpvOpLabel) { return VisitBlock(inst->result_id()); } else { return VisitInstruction(inst); } } opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitBlock(uint32_t id) { if (!cd_.HasBlock(id)) { return opt::DataFlowAnalysis::VisitResult::kResultFixed; } DivergenceLevel& cur_level = divergence_[id]; if (cur_level == DivergenceLevel::kDivergent) { return opt::DataFlowAnalysis::VisitResult::kResultFixed; } DivergenceLevel orig = cur_level; for (const spvtools::opt::ControlDependence& dep : cd_.GetDependenceSources(id)) { if (divergence_[dep.source_bb_id()] > cur_level) { cur_level = divergence_[dep.source_bb_id()]; divergence_source_[id] = dep.source_bb_id(); } else if (dep.source_bb_id() != 0) { uint32_t condition_id = dep.GetConditionID(*context().cfg()); DivergenceLevel dep_level = divergence_[condition_id]; // Check if we are along the chain of unconditional branches starting from // the branch target. if (follow_unconditional_branches_[dep.branch_target_bb_id()] != follow_unconditional_branches_[dep.target_bb_id()]) { // We must have reconverged in order to reach this block. // Promote partially uniform to divergent. if (dep_level == DivergenceLevel::kPartiallyUniform) { dep_level = DivergenceLevel::kDivergent; } } if (dep_level > cur_level) { cur_level = dep_level; divergence_source_[id] = condition_id; divergence_dependence_source_[id] = dep.source_bb_id(); } } } return cur_level > orig ? VisitResult::kResultChanged : VisitResult::kResultFixed; } opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitInstruction( opt::Instruction* inst) { if (inst->IsBlockTerminator()) { // This is called only when the condition has changed, so return changed. return VisitResult::kResultChanged; } if (!inst->HasResultId()) { return VisitResult::kResultFixed; } uint32_t id = inst->result_id(); DivergenceLevel& cur_level = divergence_[id]; if (cur_level == DivergenceLevel::kDivergent) { return opt::DataFlowAnalysis::VisitResult::kResultFixed; } DivergenceLevel orig = cur_level; cur_level = ComputeInstructionDivergence(inst); return cur_level > orig ? VisitResult::kResultChanged : VisitResult::kResultFixed; } DivergenceAnalysis::DivergenceLevel DivergenceAnalysis::ComputeInstructionDivergence(opt::Instruction* inst) { // TODO(kuhar): Check to see if inst is decorated with Uniform or UniformId // and use that to short circuit other checks. Uniform is for subgroups which // would satisfy derivative groups too. UniformId takes a scope, so if it is // subgroup or greater it could satisfy derivative group and // Device/QueueFamily could satisfy fully uniform. uint32_t id = inst->result_id(); // Handle divergence roots. if (inst->opcode() == SpvOpFunctionParameter) { divergence_source_[id] = 0; return divergence_[id] = DivergenceLevel::kDivergent; } else if (inst->IsLoad()) { spvtools::opt::Instruction* var = inst->GetBaseAddress(); if (var->opcode() != SpvOpVariable) { // Assume divergent. divergence_source_[id] = 0; return DivergenceLevel::kDivergent; } DivergenceLevel ret = ComputeVariableDivergence(var); if (ret > DivergenceLevel::kUniform) { divergence_source_[inst->result_id()] = 0; } return divergence_[id] = ret; } // Get the maximum divergence of the operands. DivergenceLevel ret = DivergenceLevel::kUniform; inst->ForEachInId([this, inst, &ret](const uint32_t* op) { if (!op) return; if (divergence_[*op] > ret) { divergence_source_[inst->result_id()] = *op; ret = divergence_[*op]; } }); divergence_[inst->result_id()] = ret; return ret; } DivergenceAnalysis::DivergenceLevel DivergenceAnalysis::ComputeVariableDivergence(opt::Instruction* var) { uint32_t type_id = var->type_id(); spvtools::opt::analysis::Pointer* type = context().get_type_mgr()->GetType(type_id)->AsPointer(); assert(type != nullptr); uint32_t def_id = var->result_id(); DivergenceLevel ret; switch (type->storage_class()) { case SpvStorageClassFunction: case SpvStorageClassGeneric: case SpvStorageClassAtomicCounter: case SpvStorageClassStorageBuffer: case SpvStorageClassPhysicalStorageBuffer: case SpvStorageClassOutput: case SpvStorageClassWorkgroup: case SpvStorageClassImage: // Image atomics probably aren't uniform. case SpvStorageClassPrivate: ret = DivergenceLevel::kDivergent; break; case SpvStorageClassInput: ret = DivergenceLevel::kDivergent; // If this variable has a Flat decoration, it is partially uniform. // TODO(kuhar): Track access chain indices and also consider Flat members // of a structure. context().get_decoration_mgr()->WhileEachDecoration( def_id, SpvDecorationFlat, [&ret](const opt::Instruction&) { ret = DivergenceLevel::kPartiallyUniform; return false; }); break; case SpvStorageClassUniformConstant: // May be a storage image which is also written to; mark those as // divergent. if (!var->IsVulkanStorageImage() || var->IsReadOnlyPointer()) { ret = DivergenceLevel::kUniform; } else { ret = DivergenceLevel::kDivergent; } break; case SpvStorageClassUniform: case SpvStorageClassPushConstant: case SpvStorageClassCrossWorkgroup: // Not for shaders; default uniform. default: ret = DivergenceLevel::kUniform; break; } return ret; } void DivergenceAnalysis::Setup(opt::Function* function) { // TODO(kuhar): Run functions called by |function| so we can detect // reconvergence caused by multiple returns. cd_.ComputeControlDependenceGraph( *context().cfg(), *context().GetPostDominatorAnalysis(function)); context().cfg()->ForEachBlockInPostOrder( function->entry().get(), [this](const opt::BasicBlock* bb) { uint32_t id = bb->id(); if (bb->terminator() == nullptr || bb->terminator()->opcode() != SpvOpBranch) { follow_unconditional_branches_[id] = id; } else { uint32_t target_id = bb->terminator()->GetSingleWordInOperand(0); // Target is guaranteed to have been visited before us in postorder. follow_unconditional_branches_[id] = follow_unconditional_branches_[target_id]; } }); } std::ostream& operator<<(std::ostream& os, DivergenceAnalysis::DivergenceLevel level) { switch (level) { case DivergenceAnalysis::DivergenceLevel::kUniform: return os << "uniform"; case DivergenceAnalysis::DivergenceLevel::kPartiallyUniform: return os << "partially uniform"; case DivergenceAnalysis::DivergenceLevel::kDivergent: return os << "divergent"; default: return os << ""; } } } // namespace lint } // namespace spvtools