SPIRV-Tools/source/lint/divergence_analysis.cpp
dong-ja 937227c761
Add divergence analysis to linter (#4465)
Currently, it handles promotion of divergence due to reconvergence rules, but it doesn't handle "late merges" caused by a later-than-necessary declared merge block.

Co-authored-by: Jakub Kuderski <kubak@google.com>
2021-08-23 17:03:28 -04:00

246 lines
9.0 KiB
C++

// Copyright (c) 2021 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/lint/divergence_analysis.h"
#include "source/opt/basic_block.h"
#include "source/opt/control_dependence.h"
#include "source/opt/dataflow.h"
#include "source/opt/function.h"
#include "source/opt/instruction.h"
#include "spirv/unified1/spirv.h"
namespace spvtools {
namespace lint {
void DivergenceAnalysis::EnqueueSuccessors(opt::Instruction* inst) {
  // Decide what must be revisited after |inst| changed.
  //
  // Ordinary values only affect their data users. Labels and terminators
  // affect the blocks that are control dependent on their block, because a
  // dependence source can become more divergent in two ways:
  //   1. control -> control: the source block itself is marked divergent.
  //   2. data -> control: the source block's branch condition is marked
  //      divergent.
  const bool is_terminator = inst->IsBlockTerminator();
  const bool is_label = inst->opcode() == SpvOpLabel;
  if (!is_terminator && !is_label) {
    opt::ForwardDataFlowAnalysis::EnqueueUsers(inst);
    return;
  }

  uint32_t block_id = 0;
  if (is_terminator) {
    block_id = context().get_instr_block(inst)->id();
  } else {
    block_id = inst->result_id();
    // Of the block's uses, only phi instructions are revisited; other uses
    // don't affect divergence.
    context().cfg()->block(block_id)->ForEachPhiInst(
        [this](opt::Instruction* phi) { Enqueue(phi); });
  }

  // Blocks unknown to the control dependence graph have no dependents.
  if (!cd_.HasBlock(block_id)) return;

  for (const spvtools::opt::ControlDependence& dep :
       cd_.GetDependenceTargets(block_id)) {
    Enqueue(context().cfg()->block(dep.target_bb_id())->GetLabelInst());
  }
}
opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::Visit(
    opt::Instruction* inst) {
  // An OpLabel stands for its whole block; every other instruction is
  // analyzed as an individual value.
  const bool is_block = inst->opcode() == SpvOpLabel;
  return is_block ? VisitBlock(inst->result_id()) : VisitInstruction(inst);
}
// Recomputes the divergence of block |id| from the blocks it is control
// dependent on. Returns kResultChanged if the block became more divergent.
//
// For every dependence source, two contributions are joined (maximum):
//   * control -> control: the divergence of the source block itself;
//   * data -> control: the divergence of the source block's branch condition,
//     promoted from partially uniform to divergent when this block can only
//     be reached after reconvergence.
opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitBlock(uint32_t id) {
  if (!cd_.HasBlock(id)) {
    return opt::DataFlowAnalysis::VisitResult::kResultFixed;
  }
  DivergenceLevel& cur_level = divergence_[id];
  if (cur_level == DivergenceLevel::kDivergent) {
    return opt::DataFlowAnalysis::VisitResult::kResultFixed;
  }
  DivergenceLevel orig = cur_level;
  for (const spvtools::opt::ControlDependence& dep :
       cd_.GetDependenceSources(id)) {
    const uint32_t source_id = dep.source_bb_id();

    // control -> control.
    if (divergence_[source_id] > cur_level) {
      cur_level = divergence_[source_id];
      divergence_source_[id] = source_id;
      // This is not a data -> control edge, so a dependence source recorded
      // by an earlier visit would now be stale.
      divergence_dependence_source_.erase(id);
    }

    // data -> control. This is checked even when the source block raised
    // |cur_level| above: the branch condition may be more divergent than the
    // block containing it (e.g. a divergent condition in a partially uniform
    // block), and the reconvergence promotion below must still apply.
    // A source id of 0 has no branch condition to inspect.
    if (source_id == 0) {
      continue;
    }
    uint32_t condition_id = dep.GetConditionID(*context().cfg());
    DivergenceLevel dep_level = divergence_[condition_id];
    // Check if we are along the chain of unconditional branches starting from
    // the branch target.
    if (follow_unconditional_branches_[dep.branch_target_bb_id()] !=
        follow_unconditional_branches_[dep.target_bb_id()]) {
      // We must have reconverged in order to reach this block.
      // Promote partially uniform to divergent.
      if (dep_level == DivergenceLevel::kPartiallyUniform) {
        dep_level = DivergenceLevel::kDivergent;
      }
    }
    if (dep_level > cur_level) {
      cur_level = dep_level;
      divergence_source_[id] = condition_id;
      divergence_dependence_source_[id] = source_id;
    }
  }
  return cur_level > orig ? VisitResult::kResultChanged
                          : VisitResult::kResultFixed;
}
opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitInstruction(
    opt::Instruction* inst) {
  // A terminator is only visited when its condition has changed, so it is
  // always reported as changed.
  if (inst->IsBlockTerminator()) return VisitResult::kResultChanged;

  // Instructions without a result carry no divergence of their own.
  if (!inst->HasResultId()) return VisitResult::kResultFixed;

  DivergenceLevel& level = divergence_[inst->result_id()];
  // Already divergent: nothing further can change.
  if (level == DivergenceLevel::kDivergent) {
    return opt::DataFlowAnalysis::VisitResult::kResultFixed;
  }

  const DivergenceLevel before = level;
  level = ComputeInstructionDivergence(inst);
  if (level > before) return VisitResult::kResultChanged;
  return VisitResult::kResultFixed;
}
DivergenceAnalysis::DivergenceLevel
DivergenceAnalysis::ComputeInstructionDivergence(opt::Instruction* inst) {
  // TODO(kuhar): Check to see if inst is decorated with Uniform or UniformId
  // and use that to short circuit other checks. Uniform is for subgroups which
  // would satisfy derivative groups too. UniformId takes a scope, so if it is
  // subgroup or greater it could satisfy derivative group and
  // Device/QueueFamily could satisfy fully uniform.
  //
  // Computes the divergence of |inst|'s result, stores it in divergence_ and
  // returns it. divergence_source_ records the cause: 0 for a divergence root
  // (a function parameter, or a load whose storage is not uniform), otherwise
  // the id of the most divergent in-operand.
  const uint32_t id = inst->result_id();

  // Raises |level| to the divergence of the most divergent in-operand of
  // |inst|, recording that operand as the source whenever it is strictly more
  // divergent than the level so far.
  auto join_operands = [this, inst, id](DivergenceLevel level)
      -> DivergenceLevel {
    inst->ForEachInId([this, id, &level](const uint32_t* op) {
      if (!op) return;
      if (divergence_[*op] > level) {
        divergence_source_[id] = *op;
        level = divergence_[*op];
      }
    });
    return level;
  };

  // Handle divergence roots.
  if (inst->opcode() == SpvOpFunctionParameter) {
    divergence_source_[id] = 0;
    return divergence_[id] = DivergenceLevel::kDivergent;
  }
  if (inst->IsLoad()) {
    spvtools::opt::Instruction* var = inst->GetBaseAddress();
    if (var->opcode() != SpvOpVariable) {
      // Assume divergent.
      divergence_source_[id] = 0;
      return divergence_[id] = DivergenceLevel::kDivergent;
    }
    DivergenceLevel ret = ComputeVariableDivergence(var);
    if (ret > DivergenceLevel::kUniform) {
      divergence_source_[id] = 0;
    }
    // The loaded value also depends on which element is addressed: loading
    // through a pointer computed with a divergent index (e.g. via an access
    // chain) is divergent even if the variable's storage class is uniform.
    const DivergenceLevel level = join_operands(ret);
    return divergence_[id] = level;
  }
  // Get the maximum divergence of the operands.
  const DivergenceLevel level = join_operands(DivergenceLevel::kUniform);
  return divergence_[id] = level;
}
DivergenceAnalysis::DivergenceLevel
DivergenceAnalysis::ComputeVariableDivergence(opt::Instruction* var) {
  // Classifies the value stored in |var| by the storage class of its pointer
  // type.
  spvtools::opt::analysis::Pointer* ptr_type =
      context().get_type_mgr()->GetType(var->type_id())->AsPointer();
  assert(ptr_type != nullptr);

  switch (ptr_type->storage_class()) {
    // These storage classes are treated as divergent.
    case SpvStorageClassFunction:
    case SpvStorageClassGeneric:
    case SpvStorageClassAtomicCounter:
    case SpvStorageClassStorageBuffer:
    case SpvStorageClassPhysicalStorageBuffer:
    case SpvStorageClassOutput:
    case SpvStorageClassWorkgroup:
    case SpvStorageClassImage:  // Image atomics probably aren't uniform.
    case SpvStorageClassPrivate:
      return DivergenceLevel::kDivergent;

    case SpvStorageClassInput: {
      // Divergent, unless the variable has a Flat decoration, in which case
      // it is partially uniform.
      // TODO(kuhar): Track access chain indices and also consider Flat members
      // of a structure.
      DivergenceLevel level = DivergenceLevel::kDivergent;
      context().get_decoration_mgr()->WhileEachDecoration(
          var->result_id(), SpvDecorationFlat,
          [&level](const opt::Instruction&) {
            level = DivergenceLevel::kPartiallyUniform;
            return false;  // One Flat decoration is enough; stop iterating.
          });
      return level;
    }

    case SpvStorageClassUniformConstant: {
      // May be a storage image which is also written to; mark those as
      // divergent.
      const bool writable_storage_image =
          var->IsVulkanStorageImage() && !var->IsReadOnlyPointer();
      return writable_storage_image ? DivergenceLevel::kDivergent
                                    : DivergenceLevel::kUniform;
    }

    case SpvStorageClassUniform:
    case SpvStorageClassPushConstant:
    case SpvStorageClassCrossWorkgroup:  // Not for shaders; default uniform.
    default:
      return DivergenceLevel::kUniform;
  }
}
void DivergenceAnalysis::Setup(opt::Function* function) {
  // TODO(kuhar): Run functions called by |function| so we can detect
  // reconvergence caused by multiple returns.
  cd_.ComputeControlDependenceGraph(
      *context().cfg(), *context().GetPostDominatorAnalysis(function));
  // For every block reachable from the entry, record the block at which the
  // chain of unconditional branches (OpBranch) starting from it ends.
  // VisitBlock compares these entries to detect whether a block can only be
  // reached after reconvergence.
  context().cfg()->ForEachBlockInPostOrder(
      function->entry().get(), [this](const opt::BasicBlock* bb) {
        const uint32_t id = bb->id();
        if (bb->terminator() == nullptr ||
            bb->terminator()->opcode() != SpvOpBranch) {
          follow_unconditional_branches_[id] = id;
          return;
        }
        const uint32_t target_id =
            bb->terminator()->GetSingleWordInOperand(0);
        // In postorder a forward-edge target has already been visited, but a
        // back-edge target (a DFS ancestor, such as a loop header) is visited
        // after us. Reading it with operator[] would default-insert 0 and make
        // unrelated chains compare equal, so end the chain at the target.
        const auto it = follow_unconditional_branches_.find(target_id);
        const uint32_t chain_end =
            it != follow_unconditional_branches_.end() ? it->second
                                                       : target_id;
        follow_unconditional_branches_[id] = chain_end;
      });
}
// Writes a human-readable name for |level| to |os|.
std::ostream& operator<<(std::ostream& os,
                         DivergenceAnalysis::DivergenceLevel level) {
  const char* name = "<invalid divergence level>";
  switch (level) {
    case DivergenceAnalysis::DivergenceLevel::kUniform:
      name = "uniform";
      break;
    case DivergenceAnalysis::DivergenceLevel::kPartiallyUniform:
      name = "partially uniform";
      break;
    case DivergenceAnalysis::DivergenceLevel::kDivergent:
      name = "divergent";
      break;
    default:
      break;
  }
  return os << name;
}
} // namespace lint
} // namespace spvtools