Add a loop peeling pass.

For each loop in a function, the pass walks the loops from inner to outer most loop
and tries to peel loop for which a certain amount of iteration can be done before or after the loop.

To limit code growth, peeling will not happen if the growth in code size goes above a configurable threshold.
This commit is contained in:
Victor Lomuller 2018-03-29 12:22:42 +01:00
parent 61b50b3bfa
commit 10e5d7cf13
15 changed files with 2471 additions and 37 deletions

View File

@ -483,6 +483,14 @@ Optimizer::PassToken CreateLocalRedundancyEliminationPass();
// the loops preheader.
Optimizer::PassToken CreateLoopInvariantCodeMotionPass();
// Creates a loop peeling pass.
// This pass will look for conditions inside a loop that are true or false only
// for the N first or last iteration. For loop with such condition, those N
// iterations of the loop will be executed outside of the main loop.
// To limit code size explosion, the loop peeling can only happen if the code
// size growth for each loop is under |code_growth_threshold|.
Optimizer::PassToken CreateLoopPeelingPass();
// Creates a loop unswitch pass.
// This pass will look for loop independent branch conditions and move the
// condition out of the loop and version the loop based on the taken branch.

View File

@ -552,12 +552,27 @@ void LoopDescriptor::PopulateList(const Function* f) {
}
ir::BasicBlock* Loop::FindConditionBlock() const {
const ir::Function& function = *loop_merge_->GetParent();
if (!loop_merge_) {
return nullptr;
}
ir::BasicBlock* condition_block = nullptr;
const opt::DominatorAnalysis* dom_analysis =
context_->GetDominatorAnalysis(&function, *context_->cfg());
ir::BasicBlock* bb = dom_analysis->ImmediateDominator(loop_merge_);
uint32_t in_loop_pred = 0;
for (uint32_t p : context_->cfg()->preds(loop_merge_->id())) {
if (IsInsideLoop(p)) {
if (in_loop_pred) {
// 2 in-loop predecessors.
return nullptr;
}
in_loop_pred = p;
}
}
if (!in_loop_pred) {
// Merge block is unreachable.
return nullptr;
}
ir::BasicBlock* bb = context_->cfg()->block(in_loop_pred);
if (!bb) return nullptr;
@ -604,6 +619,10 @@ bool Loop::FindNumberOfIterations(const ir::Instruction* induction,
const opt::analysis::Integer* type =
upper_bound->AsIntConstant()->type()->AsInteger();
if (type->width() > 32) {
return false;
}
if (type->IsSigned()) {
condition_value = upper_bound->AsIntConstant()->GetS32BitValue();
} else {
@ -811,10 +830,10 @@ ir::Instruction* Loop::FindConditionVariable(
uint32_t operand_label_2 = 3;
// Make sure one of them is the preheader.
if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
loop_preheader_->id() &&
variable_inst->GetSingleWordInOperand(operand_label_2) !=
loop_preheader_->id()) {
if (!IsInsideLoop(
variable_inst->GetSingleWordInOperand(operand_label_1)) &&
!IsInsideLoop(
variable_inst->GetSingleWordInOperand(operand_label_2))) {
return nullptr;
}

View File

@ -338,6 +338,9 @@ class Loop {
// Returns nullptr if it can't be found.
ir::Instruction* GetConditionInst() const;
// Returns the context associated this loop.
IRContext* GetContext() const { return context_; }
private:
IRContext* context_;
// The block which marks the start of the loop.

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
@ -23,9 +24,12 @@
#include "loop_descriptor.h"
#include "loop_peeling.h"
#include "loop_utils.h"
#include "scalar_analysis.h"
#include "scalar_analysis_nodes.h"
namespace spvtools {
namespace opt {
size_t LoopPeelingPass::code_grow_threshold_ = 1000;
void LoopPeeling::DuplicateAndConnectLoop(
LoopUtils::LoopCloningResult* clone_results) {
@ -130,7 +134,15 @@ void LoopPeeling::DuplicateAndConnectLoop(
cloned_loop_->SetMergeBlock(loop_->GetOrCreatePreHeaderBlock());
}
void LoopPeeling::InsertCanonicalInductionVariable() {
void LoopPeeling::InsertCanonicalInductionVariable(
LoopUtils::LoopCloningResult* clone_results) {
if (original_loop_canonical_induction_variable_) {
canonical_induction_variable_ =
context_->get_def_use_mgr()->GetDef(clone_results->value_map_.at(
original_loop_canonical_induction_variable_->result_id()));
return;
}
ir::BasicBlock::iterator insert_point =
GetClonedLoop()->GetLatchBlock()->tail();
if (GetClonedLoop()->GetLatchBlock()->GetMergeInst()) {
@ -369,7 +381,7 @@ ir::BasicBlock* LoopPeeling::CreateBlockBefore(ir::BasicBlock* bb) {
ir::IRContext::kAnalysisDefUse |
ir::IRContext::kAnalysisInstrToBlockMapping)
.AddBranch(bb->id());
cfg.AddEdge(new_bb->id(), bb->id());
cfg.RegisterBlock(new_bb.get());
// Add the basic block to the function.
ir::Function::iterator it = loop_utils_.GetFunction()->FindBlock(bb->id());
@ -407,7 +419,7 @@ void LoopPeeling::PeelBefore(uint32_t peel_factor) {
DuplicateAndConnectLoop(&clone_results);
// Add a canonical induction variable "canonical_induction_variable_".
InsertCanonicalInductionVariable();
InsertCanonicalInductionVariable(&clone_results);
InstructionBuilder builder(context_,
&*cloned_loop_->GetPreHeaderBlock()->tail(),
@ -471,7 +483,7 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
DuplicateAndConnectLoop(&clone_results);
// Add a canonical induction variable "canonical_induction_variable_".
InsertCanonicalInductionVariable();
InsertCanonicalInductionVariable(&clone_results);
InstructionBuilder builder(context_,
&*cloned_loop_->GetPreHeaderBlock()->tail(),
@ -559,5 +571,525 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
ir::IRContext::kAnalysisLoopAnalysis | ir::IRContext::kAnalysisCFG);
}
Pass::Status LoopPeelingPass::Process(ir::IRContext* c) {
InitializeProcessing(c);
bool modified = false;
ir::Module* module = c->module();
// Process each function in the module
for (ir::Function& f : *module) {
modified |= ProcessFunction(&f);
}
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
bool LoopPeelingPass::ProcessFunction(ir::Function* f) {
bool modified = false;
ir::LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f);
std::vector<ir::Loop*> to_process_loop;
to_process_loop.reserve(loop_descriptor.NumLoops());
for (ir::Loop& l : loop_descriptor) {
to_process_loop.push_back(&l);
}
opt::ScalarEvolutionAnalysis scev_analysis(context());
for (ir::Loop* loop : to_process_loop) {
CodeMetrics loop_size;
loop_size.Analyze(*loop);
auto try_peel = [&loop_size, &modified,
this](ir::Loop* loop_to_peel) -> ir::Loop* {
if (!loop_to_peel->IsLCSSA()) {
LoopUtils(context(), loop_to_peel).MakeLoopClosedSSA();
}
bool peeled_loop;
ir::Loop* still_peelable_loop;
std::tie(peeled_loop, still_peelable_loop) =
ProcessLoop(loop_to_peel, &loop_size);
if (peeled_loop) {
modified = true;
}
return still_peelable_loop;
};
ir::Loop* still_peelable_loop = try_peel(loop);
// The pass is working out the maximum factor by which a loop can be peeled.
// If the loop can potentially be peeled again, then there is only one
// possible direction, so only one call is still needed.
if (still_peelable_loop) {
try_peel(loop);
}
}
return modified;
}
std::pair<bool, ir::Loop*> LoopPeelingPass::ProcessLoop(
ir::Loop* loop, CodeMetrics* loop_size) {
opt::ScalarEvolutionAnalysis* scev_analysis =
context()->GetScalarEvolutionAnalysis();
// Default values for bailing out.
std::pair<bool, ir::Loop*> bail_out{false, nullptr};
ir::BasicBlock* exit_block = loop->FindConditionBlock();
if (!exit_block) {
return bail_out;
}
ir::Instruction* exiting_iv = loop->FindConditionVariable(exit_block);
if (!exiting_iv) {
return bail_out;
}
size_t iterations = 0;
if (!loop->FindNumberOfIterations(exiting_iv, &*exit_block->tail(),
&iterations)) {
return bail_out;
}
if (!iterations) {
return bail_out;
}
ir::Instruction* canonical_induction_variable = nullptr;
loop->GetHeaderBlock()->WhileEachPhiInst([&canonical_induction_variable,
scev_analysis,
this](ir::Instruction* insn) {
if (const SERecurrentNode* iv =
scev_analysis->AnalyzeInstruction(insn)->AsSERecurrentNode()) {
const SEConstantNode* offset = iv->GetOffset()->AsSEConstantNode();
const SEConstantNode* coeff = iv->GetCoefficient()->AsSEConstantNode();
if (offset && coeff && offset->FoldToSingleValue() == 0 &&
coeff->FoldToSingleValue() == 1) {
if (context()->get_type_mgr()->GetType(insn->type_id())->AsInteger()) {
canonical_induction_variable = insn;
return false;
}
}
}
return true;
});
bool is_signed = canonical_induction_variable
? context()
->get_type_mgr()
->GetType(canonical_induction_variable->type_id())
->AsInteger()
->IsSigned()
: false;
LoopPeeling peeler(
loop,
InstructionBuilder(context(), loop->GetHeaderBlock(),
ir::IRContext::kAnalysisDefUse |
ir::IRContext::kAnalysisInstrToBlockMapping)
.Add32BitConstantInteger<uint32_t>(static_cast<uint32_t>(iterations),
is_signed),
canonical_induction_variable);
if (!peeler.CanPeelLoop()) {
return bail_out;
}
// For each basic block in the loop, check if it can be peeled. If it
// can, get the direction (before/after) and by which factor.
LoopPeelingInfo peel_info(loop, iterations, scev_analysis);
uint32_t peel_before_factor = 0;
uint32_t peel_after_factor = 0;
for (uint32_t block : loop->GetBlocks()) {
if (block == exit_block->id()) {
continue;
}
ir::BasicBlock* bb = cfg()->block(block);
PeelDirection direction;
uint32_t factor;
std::tie(direction, factor) = peel_info.GetPeelingInfo(bb);
if (direction == PeelDirection::kNone) {
continue;
}
if (direction == PeelDirection::kBefore) {
peel_before_factor = std::max(peel_before_factor, factor);
} else {
assert(direction == PeelDirection::kAfter);
peel_after_factor = std::max(peel_after_factor, factor);
}
}
PeelDirection direction = PeelDirection::kNone;
uint32_t factor = 0;
// Find which direction we should peel.
if (peel_before_factor) {
factor = peel_before_factor;
direction = PeelDirection::kBefore;
}
if (peel_after_factor) {
if (peel_before_factor < peel_after_factor) {
// Favor a peel after here and give the peel before another shot later.
factor = peel_after_factor;
direction = PeelDirection::kAfter;
}
}
// Do the peel if we can.
if (direction == PeelDirection::kNone) return bail_out;
// This does not take into account branch elimination opportunities and
// the unrolling. It assumes the peeled loop will be unrolled as well.
if (factor * loop_size->roi_size_ > code_grow_threshold_) {
return bail_out;
}
loop_size->roi_size_ *= factor;
// Find if a loop should be peeled again.
ir::Loop* extra_opportunity = nullptr;
if (direction == PeelDirection::kBefore) {
peeler.PeelBefore(factor);
if (stats_) {
stats_->peeled_loops_.emplace_back(loop, PeelDirection::kBefore, factor);
}
if (peel_after_factor) {
// We could have peeled after, give it another try.
extra_opportunity = peeler.GetOriginalLoop();
}
} else {
peeler.PeelAfter(factor);
if (stats_) {
stats_->peeled_loops_.emplace_back(loop, PeelDirection::kAfter, factor);
}
if (peel_before_factor) {
// We could have peeled before, give it another try.
extra_opportunity = peeler.GetClonedLoop();
}
}
return {true, extra_opportunity};
}
uint32_t LoopPeelingPass::LoopPeelingInfo::GetFirstLoopInvariantOperand(
ir::Instruction* condition) const {
for (uint32_t i = 0; i < condition->NumInOperands(); i++) {
ir::BasicBlock* bb =
context_->get_instr_block(condition->GetSingleWordInOperand(i));
if (bb && loop_->IsInsideLoop(bb)) {
return condition->GetSingleWordInOperand(i);
}
}
return 0;
}
uint32_t LoopPeelingPass::LoopPeelingInfo::GetFirstNonLoopInvariantOperand(
ir::Instruction* condition) const {
for (uint32_t i = 0; i < condition->NumInOperands(); i++) {
ir::BasicBlock* bb =
context_->get_instr_block(condition->GetSingleWordInOperand(i));
if (!bb || !loop_->IsInsideLoop(bb)) {
return condition->GetSingleWordInOperand(i);
}
}
return 0;
}
static bool IsHandledCondition(SpvOp opcode) {
switch (opcode) {
case SpvOpIEqual:
case SpvOpINotEqual:
case SpvOpUGreaterThan:
case SpvOpSGreaterThan:
case SpvOpUGreaterThanEqual:
case SpvOpSGreaterThanEqual:
case SpvOpULessThan:
case SpvOpSLessThan:
case SpvOpULessThanEqual:
case SpvOpSLessThanEqual:
return true;
default:
return false;
}
}
LoopPeelingPass::LoopPeelingInfo::Direction
LoopPeelingPass::LoopPeelingInfo::GetPeelingInfo(ir::BasicBlock* bb) const {
if (bb->terminator()->opcode() != SpvOpBranchConditional) {
return GetNoneDirection();
}
opt::analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
ir::Instruction* condition =
def_use_mgr->GetDef(bb->terminator()->GetSingleWordInOperand(0));
if (!IsHandledCondition(condition->opcode())) {
return GetNoneDirection();
}
if (!GetFirstLoopInvariantOperand(condition)) {
// No loop invariant, it cannot be peeled by this pass.
return GetNoneDirection();
}
if (!GetFirstNonLoopInvariantOperand(condition)) {
// Seems to be a job for the unswitch pass.
return GetNoneDirection();
}
// Left hand-side.
SExpression lhs = scev_analysis_->AnalyzeInstruction(
def_use_mgr->GetDef(condition->GetSingleWordInOperand(0)));
if (lhs->GetType() == SENode::CanNotCompute) {
// Can't make any conclusion.
return GetNoneDirection();
}
// Right hand-side.
SExpression rhs = scev_analysis_->AnalyzeInstruction(
def_use_mgr->GetDef(condition->GetSingleWordInOperand(1)));
if (rhs->GetType() == SENode::CanNotCompute) {
// Can't make any conclusion.
return GetNoneDirection();
}
// Only take into account recurrent expression over the current loop.
bool is_lhs_rec = !scev_analysis_->IsLoopInvariant(loop_, lhs);
bool is_rhs_rec = !scev_analysis_->IsLoopInvariant(loop_, rhs);
if ((is_lhs_rec && is_rhs_rec) || (!is_lhs_rec && !is_rhs_rec)) {
return GetNoneDirection();
}
if (is_lhs_rec) {
if (!lhs->AsSERecurrentNode() ||
lhs->AsSERecurrentNode()->GetLoop() != loop_) {
return GetNoneDirection();
}
}
if (is_rhs_rec) {
if (!rhs->AsSERecurrentNode() ||
rhs->AsSERecurrentNode()->GetLoop() != loop_) {
return GetNoneDirection();
}
}
// If the op code is ==, then we try a peel before or after.
// If opcode is not <, >, <= or >=, we bail out.
//
// For the remaining cases, we canonicalize the expression so that the
// constant expression is on the left hand side and the recurring expression
// is on the right hand side. If we swap hand side, then < becomes >, <=
// becomes >= etc.
// If the opcode is <=, then we add 1 to the right hand side and do the peel
// check on <.
// If the opcode is >=, then we add 1 to the left hand side and do the peel
// check on >.
CmpOperator cmp_operator;
switch (condition->opcode()) {
default:
return GetNoneDirection();
case SpvOpIEqual:
case SpvOpINotEqual:
return HandleEquality(lhs, rhs);
case SpvOpUGreaterThan:
case SpvOpSGreaterThan: {
cmp_operator = CmpOperator::kGT;
break;
}
case SpvOpULessThan:
case SpvOpSLessThan: {
cmp_operator = CmpOperator::kLT;
break;
}
// We add one to transform >= into > and <= into <.
case SpvOpUGreaterThanEqual:
case SpvOpSGreaterThanEqual: {
cmp_operator = CmpOperator::kGE;
break;
}
case SpvOpULessThanEqual:
case SpvOpSLessThanEqual: {
cmp_operator = CmpOperator::kLE;
break;
}
}
// Force the left hand side to be the non recurring expression.
if (is_lhs_rec) {
std::swap(lhs, rhs);
switch (cmp_operator) {
case CmpOperator::kLT: {
cmp_operator = CmpOperator::kGT;
break;
}
case CmpOperator::kGT: {
cmp_operator = CmpOperator::kLT;
break;
}
case CmpOperator::kLE: {
cmp_operator = CmpOperator::kGE;
break;
}
case CmpOperator::kGE: {
cmp_operator = CmpOperator::kLE;
break;
}
}
}
return HandleInequality(cmp_operator, lhs, rhs->AsSERecurrentNode());
}
SExpression LoopPeelingPass::LoopPeelingInfo::GetValueAtFirstIteration(
SERecurrentNode* rec) const {
return rec->GetOffset();
}
SExpression LoopPeelingPass::LoopPeelingInfo::GetValueAtIteration(
SERecurrentNode* rec, int64_t iteration) const {
SExpression coeff = rec->GetCoefficient();
SExpression offset = rec->GetOffset();
return (coeff * iteration) + offset;
}
SExpression LoopPeelingPass::LoopPeelingInfo::GetValueAtLastIteration(
SERecurrentNode* rec) const {
return GetValueAtIteration(rec, loop_max_iterations_ - 1);
}
bool LoopPeelingPass::LoopPeelingInfo::EvalOperator(CmpOperator cmp_op,
SExpression lhs,
SExpression rhs,
bool* result) const {
assert(scev_analysis_->IsLoopInvariant(loop_, lhs));
assert(scev_analysis_->IsLoopInvariant(loop_, rhs));
// We perform the test: 0 cmp_op rhs - lhs
// What is left is then to determine the sign of the expression.
switch (cmp_op) {
case CmpOperator::kLT: {
return scev_analysis_->IsAlwaysGreaterThanZero(rhs - lhs, result);
}
case CmpOperator::kGT: {
return scev_analysis_->IsAlwaysGreaterThanZero(lhs - rhs, result);
}
case CmpOperator::kLE: {
return scev_analysis_->IsAlwaysGreaterOrEqualToZero(rhs - lhs, result);
}
case CmpOperator::kGE: {
return scev_analysis_->IsAlwaysGreaterOrEqualToZero(lhs - rhs, result);
}
}
return false;
}
LoopPeelingPass::LoopPeelingInfo::Direction
LoopPeelingPass::LoopPeelingInfo::HandleEquality(SExpression lhs,
SExpression rhs) const {
{
// Try peel before opportunity.
SExpression lhs_cst = lhs;
if (SERecurrentNode* rec_node = lhs->AsSERecurrentNode()) {
lhs_cst = rec_node->GetOffset();
}
SExpression rhs_cst = rhs;
if (SERecurrentNode* rec_node = rhs->AsSERecurrentNode()) {
rhs_cst = rec_node->GetOffset();
}
if (lhs_cst == rhs_cst) {
return Direction{LoopPeelingPass::PeelDirection::kBefore, 1};
}
}
{
// Try peel after opportunity.
SExpression lhs_cst = lhs;
if (SERecurrentNode* rec_node = lhs->AsSERecurrentNode()) {
// rec_node(x) = a * x + b
// assign to lhs: a * (loop_max_iterations_ - 1) + b
lhs_cst = GetValueAtLastIteration(rec_node);
}
SExpression rhs_cst = rhs;
if (SERecurrentNode* rec_node = rhs->AsSERecurrentNode()) {
// rec_node(x) = a * x + b
// assign to lhs: a * (loop_max_iterations_ - 1) + b
rhs_cst = GetValueAtLastIteration(rec_node);
}
if (lhs_cst == rhs_cst) {
return Direction{LoopPeelingPass::PeelDirection::kAfter, 1};
}
}
return GetNoneDirection();
}
LoopPeelingPass::LoopPeelingInfo::Direction
LoopPeelingPass::LoopPeelingInfo::HandleInequality(CmpOperator cmp_op,
SExpression lhs,
SERecurrentNode* rhs) const {
SExpression offset = rhs->GetOffset();
SExpression coefficient = rhs->GetCoefficient();
// Compute (cst - B) / A.
std::pair<SExpression, int64_t> flip_iteration = (lhs - offset) / coefficient;
if (!flip_iteration.first->AsSEConstantNode()) {
return GetNoneDirection();
}
// note: !!flip_iteration.second normalize to 0/1 (via bool cast).
int64_t iteration =
flip_iteration.first->AsSEConstantNode()->FoldToSingleValue() +
!!flip_iteration.second;
if (iteration <= 0 ||
loop_max_iterations_ <= static_cast<uint64_t>(iteration)) {
// Always true or false within the loop bounds.
return GetNoneDirection();
}
// If this is a <= or >= operator and the iteration, make sure |iteration| is
// the one flipping the condition.
// If (cst - B) and A are not divisible, this equivalent to a < or > check, so
// we skip this test.
if (!flip_iteration.second &&
(cmp_op == CmpOperator::kLE || cmp_op == CmpOperator::kGE)) {
bool first_iteration;
bool current_iteration;
if (!EvalOperator(cmp_op, lhs, offset, &first_iteration) ||
!EvalOperator(cmp_op, lhs, GetValueAtIteration(rhs, iteration),
&current_iteration)) {
return GetNoneDirection();
}
// If the condition did not flip the next will.
if (first_iteration == current_iteration) {
iteration++;
}
}
uint32_t cast_iteration = 0;
// sanity check: can we fit |iteration| in a uint32_t ?
if (static_cast<uint64_t>(iteration) < std::numeric_limits<uint32_t>::max()) {
cast_iteration = static_cast<uint32_t>(iteration);
}
if (cast_iteration) {
// Peel before if we are closer to the start, after if closer to the end.
if (loop_max_iterations_ / 2 > cast_iteration) {
return Direction{LoopPeelingPass::PeelDirection::kBefore, cast_iteration};
} else {
return Direction{
LoopPeelingPass::PeelDirection::kAfter,
static_cast<uint32_t>(loop_max_iterations_ - cast_iteration)};
}
}
return GetNoneDirection();
}
} // namespace opt
} // namespace spvtools

View File

@ -26,6 +26,8 @@
#include "opt/ir_context.h"
#include "opt/loop_descriptor.h"
#include "opt/loop_utils.h"
#include "opt/pass.h"
#include "opt/scalar_analysis.h"
namespace spvtools {
namespace opt {
@ -61,13 +63,6 @@ namespace opt {
// - The loop must not have any ambiguous iterators updates (see
// "CanPeelLoop").
// The method "CanPeelLoop" checks that those constrained are met.
//
// FIXME(Victor): Allow the utility it accept an canonical induction variable
// rather than automatically create one.
// FIXME(Victor): When possible, evaluate the initial value of the second loop
// iterating values rather than using the exit value of the first loop.
// FIXME(Victor): Make the utility work-out the upper bound without having to
// provide it. This should become easy once the scalar evolution is in.
class LoopPeeling {
public:
// LoopPeeling constructor.
@ -75,20 +70,33 @@ class LoopPeeling {
// |loop_iteration_count| is the instruction holding the |loop| iteration
// count, must be invariant for |loop| and must be of an int 32 type (signed
// or unsigned).
LoopPeeling(ir::IRContext* context, ir::Loop* loop,
ir::Instruction* loop_iteration_count)
: context_(context),
loop_utils_(context, loop),
// |canonical_induction_variable| is an induction variable that can be used to
// count the number of iterations, must be of the same type as
// |loop_iteration_count| and start at 0 and increase by step of one at each
// iteration. The value nullptr is interpreted as no suitable variable exists
// and one will be created.
LoopPeeling(ir::Loop* loop, ir::Instruction* loop_iteration_count,
ir::Instruction* canonical_induction_variable = nullptr)
: context_(loop->GetContext()),
loop_utils_(loop->GetContext(), loop),
loop_(loop),
loop_iteration_count_(!loop->IsInsideLoop(loop_iteration_count)
? loop_iteration_count
: nullptr),
int_type_(nullptr),
original_loop_canonical_induction_variable_(
canonical_induction_variable),
canonical_induction_variable_(nullptr) {
if (loop_iteration_count_) {
int_type_ = context_->get_type_mgr()
->GetType(loop_iteration_count_->type_id())
->AsInteger();
if (canonical_induction_variable_) {
assert(canonical_induction_variable_->type_id() ==
loop_iteration_count_->type_id() &&
"loop_iteration_count and canonical_induction_variable do not "
"have the same type");
}
}
GetIteratingExitValues();
}
@ -164,11 +172,11 @@ class LoopPeeling {
// This is set to true when the exit and back-edge branch instruction is the
// same.
bool do_while_form_;
// The canonical induction variable from the original loop if it exists.
ir::Instruction* original_loop_canonical_induction_variable_;
// The canonical induction variable of the cloned loop. The induction variable
// is initialized to 0 and incremented by step of 1.
ir::Instruction* canonical_induction_variable_;
// Map between loop iterators and exit values. Loop iterators
std::unordered_map<uint32_t, ir::Instruction*> exit_value_;
@ -179,7 +187,8 @@ class LoopPeeling {
// Insert the canonical induction variable into the first loop as a simplified
// counter.
void InsertCanonicalInductionVariable();
void InsertCanonicalInductionVariable(
LoopUtils::LoopCloningResult* clone_results);
// Fixes the exit condition of the before loop. The function calls
// |condition_builder| to get the condition to use in the conditional branch
@ -217,6 +226,111 @@ class LoopPeeling {
ir::BasicBlock* if_merge);
};
// Implements a loop peeling optimization.
// For each loop, the pass will try to peel it if there is conditions that
// are true for the "N" first or last iterations of the loop.
// To avoid code size explosion, too large loops will not be peeled.
class LoopPeelingPass : public Pass {
public:
// Describes the peeling direction.
enum class PeelDirection {
kNone, // Cannot peel
kBefore, // Can peel before
kAfter // Can peel last
};
// Holds some statistics about peeled function.
struct LoopPeelingStats {
std::vector<std::tuple<const ir::Loop*, PeelDirection, uint32_t>>
peeled_loops_;
};
LoopPeelingPass(LoopPeelingStats* stats = nullptr) : stats_(stats) {}
// Sets the loop peeling growth threshold. If the code size increase is above
// |code_grow_threshold|, the loop will not be peeled. The code size is
// measured in terms of SPIR-V instructions.
static void SetLoopPeelingThreshold(size_t code_grow_threshold) {
code_grow_threshold_ = code_grow_threshold;
}
// Returns the loop peeling code growth threshold.
static size_t GetLoopPeelingThreshold() { return code_grow_threshold_; }
const char* name() const override { return "loop-peeling"; }
// Processes the given |module|. Returns Status::Failure if errors occur when
// processing. Returns the corresponding Status::Success if processing is
// succesful to indicate whether changes have been made to the modue.
Pass::Status Process(ir::IRContext* context) override;
private:
// Describes the peeling direction.
enum class CmpOperator {
kLT, // less than
kGT, // greater than
kLE, // less than or equal
kGE, // greater than or equal
};
class LoopPeelingInfo {
public:
using Direction = std::pair<PeelDirection, uint32_t>;
LoopPeelingInfo(ir::Loop* loop, size_t loop_max_iterations,
opt::ScalarEvolutionAnalysis* scev_analysis)
: context_(loop->GetContext()),
loop_(loop),
scev_analysis_(scev_analysis),
loop_max_iterations_(loop_max_iterations) {}
// Returns by how much and to which direction a loop should be peeled to
// make the conditional branch of the basic block |bb| an unconditional
// branch. If |bb|'s terminator is not a conditional branch or the condition
// is not workable then it returns PeelDirection::kNone and a 0 factor.
Direction GetPeelingInfo(ir::BasicBlock* bb) const;
private:
// Returns the id of the loop invariant operand of the conditional
// expression |condition|. It returns if no operand is invariant.
uint32_t GetFirstLoopInvariantOperand(ir::Instruction* condition) const;
// Returns the id of the non loop invariant operand of the conditional
// expression |condition|. It returns if all operands are invariant.
uint32_t GetFirstNonLoopInvariantOperand(ir::Instruction* condition) const;
// Returns the value of |rec| at the first loop iteration.
SExpression GetValueAtFirstIteration(SERecurrentNode* rec) const;
// Returns the value of |rec| at the given |iteration|.
SExpression GetValueAtIteration(SERecurrentNode* rec,
int64_t iteration) const;
// Returns the value of |rec| at the last loop iteration.
SExpression GetValueAtLastIteration(SERecurrentNode* rec) const;
bool EvalOperator(CmpOperator cmp_op, SExpression lhs, SExpression rhs,
bool* result) const;
Direction HandleEquality(SExpression lhs, SExpression rhs) const;
Direction HandleInequality(CmpOperator cmp_op, SExpression lhs,
SERecurrentNode* rhs) const;
static Direction GetNoneDirection() {
return Direction{LoopPeelingPass::PeelDirection::kNone, 0};
}
ir::IRContext* context_;
ir::Loop* loop_;
opt::ScalarEvolutionAnalysis* scev_analysis_;
size_t loop_max_iterations_;
};
// Peel profitable loops in |f|.
bool ProcessFunction(ir::Function* f);
// Peel |loop| if profitable.
std::pair<bool, ir::Loop*> ProcessLoop(ir::Loop* loop,
CodeMetrics* loop_size);
static size_t code_grow_threshold_;
LoopPeelingStats* stats_;
};
} // namespace opt
} // namespace spvtools

View File

@ -476,7 +476,6 @@ void LoopUtils::MakeLoopClosedSSA() {
}
context_->InvalidateAnalysesExceptFor(
ir::IRContext::Analysis::kAnalysisDefUse |
ir::IRContext::Analysis::kAnalysisCFG |
ir::IRContext::Analysis::kAnalysisDominatorAnalysis |
ir::IRContext::Analysis::kAnalysisLoopAnalysis);
@ -488,7 +487,6 @@ ir::Loop* LoopUtils::CloneLoop(
analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
std::unique_ptr<ir::Loop> new_loop = MakeUnique<ir::Loop>(context_);
if (loop_->HasParent()) new_loop->SetParent(loop_->GetParent());
ir::CFG& cfg = *context_->cfg();
@ -598,5 +596,26 @@ void LoopUtils::PopulateLoopDesc(
}
}
// Class to gather some metrics about a region of interest.
void CodeMetrics::Analyze(const ir::Loop& loop) {
ir::CFG& cfg = *loop.GetContext()->cfg();
roi_size_ = 0;
block_sizes_.clear();
for (uint32_t id : loop.GetBlocks()) {
const ir::BasicBlock* bb = cfg.block(id);
size_t bb_size = 0;
bb->ForEachInst([&bb_size](const ir::Instruction* insn) {
if (insn->opcode() == SpvOpLabel) return;
if (insn->IsNop()) return;
if (insn->opcode() == SpvOpPhi) return;
bb_size++;
});
block_sizes_[bb->id()] = bb_size;
roi_size_ += bb_size;
}
}
} // namespace opt
} // namespace spvtools

View File

@ -24,6 +24,19 @@ namespace spvtools {
namespace opt {
// Class to gather some metrics about a Region Of Interest (ROI).
// So far it counts the number of instructions in a ROI (excluding debug
// and label instructions) per basic block and in total.
struct CodeMetrics {
void Analyze(const ir::Loop& loop);
// The number of instructions per basic block in the ROI.
std::unordered_map<uint32_t, size_t> block_sizes_;
// Number of instruction in the ROI.
size_t roi_size_;
};
// LoopUtils is used to encapsulte loop optimizations and from the passes which
// use them. Any pass which needs a loop optimization should do it through this
// or through a pass which is using this.

View File

@ -381,6 +381,11 @@ Optimizer::PassToken CreateLoopInvariantCodeMotionPass() {
return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::LICMPass>());
}
Optimizer::PassToken CreateLoopPeelingPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::LoopPeelingPass>());
}
Optimizer::PassToken CreateLoopUnswitchPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::LoopUnswitchPass>());

View File

@ -42,6 +42,7 @@
#include "local_single_block_elim_pass.h"
#include "local_single_store_elim_pass.h"
#include "local_ssa_elim_pass.h"
#include "loop_peeling.h"
#include "loop_unroller.h"
#include "loop_unswitch_pass.h"
#include "merge_return_pass.h"

View File

@ -634,5 +634,343 @@ void SENode::DumpDot(std::ostream& out, bool recurse) const {
}
}
namespace {
class IsGreaterThanZero {
public:
explicit IsGreaterThanZero(ir::IRContext* context) : context_(context) {}
// Determine if the value of |node| is always strictly greater than zero if
// |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is
// true. It returns true is the evaluation was able to conclude something, in
// which case the result is stored in |result|.
// The algorithm work by going through all the nodes and determine the
// sign of each of them.
bool Eval(const SENode* node, bool or_equal_zero, bool* result) {
*result = false;
switch (Visit(node)) {
case Signedness::kPositiveOrNegative: {
return false;
}
case Signedness::kStrictlyNegative: {
*result = false;
break;
}
case Signedness::kNegative: {
if (!or_equal_zero) {
return false;
}
*result = false;
break;
}
case Signedness::kStrictlyPositive: {
*result = true;
break;
}
case Signedness::kPositive: {
if (!or_equal_zero) {
return false;
}
*result = true;
break;
}
}
return true;
}
private:
enum class Signedness {
kPositiveOrNegative, // Yield a value positive or negative.
kStrictlyNegative, // Yield a value strictly less than 0.
kNegative, // Yield a value less or equal to 0.
kStrictlyPositive, // Yield a value strictly greater than 0.
kPositive // Yield a value greater or equal to 0.
};
// Combine the signedness according to arithmetic rules of a given operator.
using Combiner = std::function<Signedness(Signedness, Signedness)>;
// Returns a functor to interpret the signedness of 2 expressions as if they
// were added.
Combiner GetAddCombiner() const {
return [](Signedness lhs, Signedness rhs) {
switch (lhs) {
case Signedness::kPositiveOrNegative:
break;
case Signedness::kStrictlyNegative:
if (rhs == Signedness::kStrictlyNegative ||
rhs == Signedness::kNegative)
return lhs;
break;
case Signedness::kNegative: {
if (rhs == Signedness::kStrictlyNegative)
return Signedness::kStrictlyNegative;
if (rhs == Signedness::kNegative) return Signedness::kNegative;
break;
}
case Signedness::kStrictlyPositive: {
if (rhs == Signedness::kStrictlyPositive ||
rhs == Signedness::kPositive) {
return Signedness::kStrictlyPositive;
}
break;
}
case Signedness::kPositive: {
if (rhs == Signedness::kStrictlyPositive)
return Signedness::kStrictlyPositive;
if (rhs == Signedness::kPositive) return Signedness::kPositive;
break;
}
}
return Signedness::kPositiveOrNegative;
};
}
// Returns a functor to interpret the signedness of 2 expressions as if they
// were multiplied.
Combiner GetMulCombiner() const {
return [](Signedness lhs, Signedness rhs) {
switch (lhs) {
case Signedness::kPositiveOrNegative:
break;
case Signedness::kStrictlyNegative: {
switch (rhs) {
case Signedness::kPositiveOrNegative: {
break;
}
case Signedness::kStrictlyNegative: {
return Signedness::kStrictlyPositive;
}
case Signedness::kNegative: {
return Signedness::kPositive;
}
case Signedness::kStrictlyPositive: {
return Signedness::kStrictlyNegative;
}
case Signedness::kPositive: {
return Signedness::kNegative;
}
}
break;
}
case Signedness::kNegative: {
switch (rhs) {
case Signedness::kPositiveOrNegative: {
break;
}
case Signedness::kStrictlyNegative:
case Signedness::kNegative: {
return Signedness::kPositive;
}
case Signedness::kStrictlyPositive:
case Signedness::kPositive: {
return Signedness::kNegative;
}
}
break;
}
case Signedness::kStrictlyPositive: {
return rhs;
}
case Signedness::kPositive: {
switch (rhs) {
case Signedness::kPositiveOrNegative: {
break;
}
case Signedness::kStrictlyNegative:
case Signedness::kNegative: {
return Signedness::kNegative;
}
case Signedness::kStrictlyPositive:
case Signedness::kPositive: {
return Signedness::kPositive;
}
}
break;
}
}
return Signedness::kPositiveOrNegative;
};
}
Signedness Visit(const SENode* node) {
switch (node->GetType()) {
case SENode::Constant:
return Visit(node->AsSEConstantNode());
break;
case SENode::RecurrentAddExpr:
return Visit(node->AsSERecurrentNode());
break;
case SENode::Negative:
return Visit(node->AsSENegative());
break;
case SENode::CanNotCompute:
return Visit(node->AsSECantCompute());
break;
case SENode::ValueUnknown:
return Visit(node->AsSEValueUnknown());
break;
case SENode::Add:
return VisitExpr(node, GetAddCombiner());
break;
case SENode::Multiply:
return VisitExpr(node, GetMulCombiner());
break;
}
return Signedness::kPositiveOrNegative;
}
// Returns the signedness of a constant |node|.
Signedness Visit(const SEConstantNode* node) {
if (0 == node->FoldToSingleValue()) return Signedness::kPositive;
if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive;
if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative;
return Signedness::kPositiveOrNegative;
}
// Returns the signedness of an unknown |node| based on its type.
Signedness Visit(const SEValueUnknown* node) {
ir::Instruction* insn =
context_->get_def_use_mgr()->GetDef(node->ResultId());
analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id());
assert(type && "Can't retrieve a type for the instruction");
analysis::Integer* int_type = type->AsInteger();
assert(type && "Can't retrieve an integer type for the instruction");
return int_type->IsSigned() ? Signedness::kPositiveOrNegative
: Signedness::kPositive;
}
// Returns the signedness of a recurring expression.
Signedness Visit(const SERecurrentNode* node) {
Signedness coeff_sign = Visit(node->GetCoefficient());
// SERecurrentNode represent an affine expression in the range [0,
// loop_bound], so the result cannot be strictly positive or negative.
switch (coeff_sign) {
default:
break;
case Signedness::kStrictlyNegative:
coeff_sign = Signedness::kNegative;
break;
case Signedness::kStrictlyPositive:
coeff_sign = Signedness::kPositive;
break;
}
return GetAddCombiner()(coeff_sign, Visit(node->GetOffset()));
}
// Returns the signedness of a negation |node|.
Signedness Visit(const SENegative* node) {
switch (Visit(*node->begin())) {
case Signedness::kPositiveOrNegative: {
return Signedness::kPositiveOrNegative;
}
case Signedness::kStrictlyNegative: {
return Signedness::kStrictlyPositive;
}
case Signedness::kNegative: {
return Signedness::kPositive;
}
case Signedness::kStrictlyPositive: {
return Signedness::kStrictlyNegative;
}
case Signedness::kPositive: {
return Signedness::kNegative;
}
}
return Signedness::kPositiveOrNegative;
}
Signedness Visit(const SECantCompute*) {
return Signedness::kPositiveOrNegative;
}
// Returns the signedness of a binary expression by using the combiner
// |reduce|.
Signedness VisitExpr(
const SENode* node,
std::function<Signedness(Signedness, Signedness)> reduce) {
Signedness result = Visit(*node->begin());
for (const SENode* operand : ir::make_range(++node->begin(), node->end())) {
if (result == Signedness::kPositiveOrNegative) {
return Signedness::kPositiveOrNegative;
}
result = reduce(result, Visit(operand));
}
return result;
}
ir::IRContext* context_;
};
} // namespace
bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node,
bool* is_gt_zero) const {
return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero);
}
bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero(
SENode* node, bool* is_ge_zero) const {
return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero);
}
namespace {
// Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z),
// if |node| is not in the chain, returns the original chain.
static SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul,
const SENode* node) {
SENode* lhs = mul->GetChildren()[0];
SENode* rhs = mul->GetChildren()[1];
if (lhs == node) {
return rhs;
}
if (rhs == node) {
return lhs;
}
if (lhs->AsSEMultiplyNode()) {
SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node);
if (res != lhs)
return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
}
if (rhs->AsSEMultiplyNode()) {
SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node);
if (res != rhs)
return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
}
return mul;
}
} // namespace
std::pair<SExpression, int64_t> SExpression::operator/(
SExpression rhs_wrapper) const {
SENode* lhs = node_;
SENode* rhs = rhs_wrapper.node_;
// Check for division by 0.
if (rhs->AsSEConstantNode() &&
!rhs->AsSEConstantNode()->FoldToSingleValue()) {
return {scev_->CreateCantComputeNode(), 0};
}
// Trivial case.
if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) {
int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue();
int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue();
return {scev_->CreateConstant(lhs_value / rhs_value),
lhs_value % rhs_value};
}
// look for a "c U / U" pattern.
if (lhs->AsSEMultiplyNode()) {
assert(lhs->GetChildren().size() == 2 &&
"More than 2 operand for a multiply node.");
SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs);
if (res != lhs) {
return {res, 0};
}
}
return {scev_->CreateCantComputeNode(), 0};
}
} // namespace opt
} // namespace spvtools

View File

@ -95,6 +95,15 @@ class ScalarEvolutionAnalysis {
// Checks that the graph starting from |node| is invariant to the |loop|.
bool IsLoopInvariant(const ir::Loop* loop, const SENode* node) const;
// Sets |is_gt_zero| to true if |node| represent a value always strictly
// greater than 0. The result of |is_gt_zero| is valid only if the function
// returns true.
bool IsAlwaysGreaterThanZero(SENode* node, bool* is_gt_zero) const;
// Sets |is_ge_zero| to true if |node| represent a value greater or equals to
// 0. The result of |is_ge_zero| is valid only if the function returns true.
bool IsAlwaysGreaterOrEqualToZero(SENode* node, bool* is_ge_zero) const;
// Find the recurrent term belonging to |loop| in the graph starting from
// |node| and return the coefficient of that recurrent term. Constant zero
// will be returned if no recurrent could be found. |node| should be in
@ -151,6 +160,140 @@ class ScalarEvolutionAnalysis {
node_cache_;
};
// Wrapping class to manipulate SENode pointer using + - * / operators.
class SExpression {
public:
// Implicit on purpose !
SExpression(SENode* node)
: node_(node->GetParentAnalysis()->SimplifyExpression(node)),
scev_(node->GetParentAnalysis()) {}
inline operator SENode*() const { return node_; }
inline SENode* operator->() const { return node_; }
const SENode& operator*() const { return *node_; }
inline ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() const {
return scev_;
}
inline SExpression operator+(SENode* rhs) const;
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline SExpression operator+(T integer) const;
inline SExpression operator+(SExpression rhs) const;
inline SExpression operator-() const;
inline SExpression operator-(SENode* rhs) const;
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline SExpression operator-(T integer) const;
inline SExpression operator-(SExpression rhs) const;
inline SExpression operator*(SENode* rhs) const;
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline SExpression operator*(T integer) const;
inline SExpression operator*(SExpression rhs) const;
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline std::pair<SExpression, int64_t> operator/(T integer) const;
// Try to perform a division. Returns the pair <this.node_ / rhs, division
// remainder>. If it fails to simplify it, the function returns a
// CanNotCompute node.
std::pair<SExpression, int64_t> operator/(SExpression rhs) const;
private:
SENode* node_;
ScalarEvolutionAnalysis* scev_;
};
inline SExpression SExpression::operator+(SENode* rhs) const {
return scev_->CreateAddNode(node_, rhs);
}
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type>
inline SExpression SExpression::operator+(T integer) const {
return *this + scev_->CreateConstant(integer);
}
inline SExpression SExpression::operator+(SExpression rhs) const {
return *this + rhs.node_;
}
inline SExpression SExpression::operator-() const {
return scev_->CreateNegation(node_);
}
inline SExpression SExpression::operator-(SENode* rhs) const {
return *this + scev_->CreateNegation(rhs);
}
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type>
inline SExpression SExpression::operator-(T integer) const {
return *this - scev_->CreateConstant(integer);
}
inline SExpression SExpression::operator-(SExpression rhs) const {
return *this - rhs.node_;
}
inline SExpression SExpression::operator*(SENode* rhs) const {
return scev_->CreateMultiplyNode(node_, rhs);
}
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type>
inline SExpression SExpression::operator*(T integer) const {
return *this * scev_->CreateConstant(integer);
}
inline SExpression SExpression::operator*(SExpression rhs) const {
return *this * rhs.node_;
}
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type>
inline std::pair<SExpression, int64_t> SExpression::operator/(T integer) const {
return *this / scev_->CreateConstant(integer);
}
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type>
inline SExpression operator+(T lhs, SExpression rhs) {
return rhs + lhs;
}
inline SExpression operator+(SENode* lhs, SExpression rhs) { return rhs + lhs; }
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type>
inline SExpression operator-(T lhs, SExpression rhs) {
return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} -
rhs;
}
inline SExpression operator-(SENode* lhs, SExpression rhs) {
return SExpression{lhs} - rhs;
}
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type>
inline SExpression operator*(T lhs, SExpression rhs) {
return rhs * lhs;
}
inline SExpression operator*(SENode* lhs, SExpression rhs) { return rhs * lhs; }
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type>
inline std::pair<SExpression, int64_t> operator/(T lhs, SExpression rhs) {
return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} /
rhs;
}
inline std::pair<SExpression, int64_t> operator/(SENode* lhs, SExpression rhs) {
return SExpression{lhs} / rhs;
}
} // namespace opt
} // namespace spvtools
#endif // SOURCE_OPT_SCALAR_ANALYSIS_H__

View File

@ -91,6 +91,12 @@ add_spvtools_unittest(TARGET peeling_test
LIBS SPIRV-Tools-opt
)
add_spvtools_unittest(TARGET peeling_pass_test
SRCS ../function_utils.h
peeling_pass.cpp
LIBS SPIRV-Tools-opt
)
add_spvtools_unittest(TARGET loop_dependence_analysis
SRCS ../function_utils.h
dependence_analysis.cpp

View File

@ -134,7 +134,7 @@ TEST_F(PeelingTest, CannotPeel) {
loop_count = builder.Add32BitSignedIntegerConstant(10);
}
opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count);
opt::LoopPeeling peel(&*ld.begin(), loop_count);
EXPECT_FALSE(peel.CanPeelLoop());
};
{
@ -495,7 +495,7 @@ TEST_F(PeelingTest, SimplePeeling) {
// Exit condition.
ir::Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10);
opt::LoopPeeling peel(context.get(), &*ld.begin(), ten_cst);
opt::LoopPeeling peel(&*ld.begin(), ten_cst);
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelBefore(2);
@ -549,7 +549,7 @@ CHECK-NEXT: OpLoopMerge
// Exit condition.
ir::Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10);
opt::LoopPeeling peel(context.get(), &*ld.begin(), ten_cst);
opt::LoopPeeling peel(&*ld.begin(), ten_cst);
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelAfter(2);
@ -580,6 +580,114 @@ CHECK: [[AFTER_LOOP]] = OpLabel
CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[TMP]] [[IF_MERGE]]
CHECK-NEXT: OpLoopMerge
)";
Match(check, context.get());
}
// Same as above, but reuse the induction variable.
// Peel before.
{
SCOPED_TRACE("Peel before with IV reuse");
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ir::Module* module = context->module();
EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
<< text << std::endl;
ir::Function& f = *module->begin();
ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
EXPECT_EQ(ld.NumLoops(), 1u);
opt::InstructionBuilder builder(context.get(), &*f.begin());
// Exit condition.
ir::Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10);
opt::LoopPeeling peel(&*ld.begin(), ten_cst,
context->get_def_use_mgr()->GetDef(22));
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelBefore(2);
const std::string check = R"(
CHECK: [[CST_TEN:%\w+]] = OpConstant {{%\w+}} 10
CHECK: [[CST_TWO:%\w+]] = OpConstant {{%\w+}} 2
CHECK: OpFunction
CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel
CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} [[CST_TWO]] [[CST_TEN]]
CHECK-NEXT: [[LOOP_COUNT:%\w+]] = OpSelect {{%\w+}} [[MIN_LOOP_COUNT]] [[CST_TWO]] [[CST_TEN]]
CHECK: [[BEFORE_LOOP:%\w+]] = OpLabel
CHECK-NEXT: [[i:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE:%\w+]]
CHECK-NEXT: OpLoopMerge [[AFTER_LOOP_PREHEADER:%\w+]] [[BE]] None
CHECK: [[COND_BLOCK:%\w+]] = OpLabel
CHECK-NEXT: OpSLessThan
CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[i]]
CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[AFTER_LOOP_PREHEADER]]
CHECK: [[I_1]] = OpIAdd {{%\w+}} [[i]]
CHECK-NEXT: OpBranch [[BEFORE_LOOP]]
CHECK: [[AFTER_LOOP_PREHEADER]] = OpLabel
CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]]
CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[AFTER_LOOP:%\w+]] [[IF_MERGE]]
CHECK: [[AFTER_LOOP]] = OpLabel
CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[i]] [[AFTER_LOOP_PREHEADER]]
CHECK-NEXT: OpLoopMerge
)";
Match(check, context.get());
}
// Peel after.
{
SCOPED_TRACE("Peel after IV reuse");
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ir::Module* module = context->module();
EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
<< text << std::endl;
ir::Function& f = *module->begin();
ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
EXPECT_EQ(ld.NumLoops(), 1u);
opt::InstructionBuilder builder(context.get(), &*f.begin());
// Exit condition.
ir::Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10);
opt::LoopPeeling peel(&*ld.begin(), ten_cst,
context->get_def_use_mgr()->GetDef(22));
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelAfter(2);
const std::string check = R"(
CHECK: OpFunction
CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel
CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}}
CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]]
CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[BEFORE_LOOP:%\w+]] [[IF_MERGE]]
CHECK: [[BEFORE_LOOP]] = OpLabel
CHECK-NEXT: [[I:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE:%\w+]]
CHECK-NEXT: OpLoopMerge [[BEFORE_LOOP_MERGE:%\w+]] [[BE]] None
CHECK: [[COND_BLOCK:%\w+]] = OpLabel
CHECK-NEXT: OpSLessThan
CHECK-NEXT: [[TMP:%\w+]] = OpIAdd {{%\w+}} [[I]] {{%\w+}}
CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[TMP]]
CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[BEFORE_LOOP_MERGE]]
CHECK: [[I_1]] = OpIAdd {{%\w+}} [[I]]
CHECK-NEXT: OpBranch [[BEFORE_LOOP]]
CHECK: [[IF_MERGE]] = OpLabel
CHECK-NEXT: [[TMP:%\w+]] = OpPhi {{%\w+}} [[I]] [[BEFORE_LOOP_MERGE]]
CHECK-NEXT: OpBranch [[AFTER_LOOP:%\w+]]
CHECK: [[AFTER_LOOP]] = OpLabel
CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[TMP]] [[IF_MERGE]]
CHECK-NEXT: OpLoopMerge
)";
Match(check, context.get());
@ -658,7 +766,7 @@ TEST_F(PeelingTest, PeelingUncountable) {
ir::Instruction* loop_count = context->get_def_use_mgr()->GetDef(16);
EXPECT_EQ(loop_count->opcode(), SpvOpLoad);
opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count);
opt::LoopPeeling peel(&*ld.begin(), loop_count);
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelBefore(1);
@ -710,7 +818,7 @@ CHECK-NEXT: OpLoopMerge
ir::Instruction* loop_count = context->get_def_use_mgr()->GetDef(16);
EXPECT_EQ(loop_count->opcode(), SpvOpLoad);
opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count);
opt::LoopPeeling peel(&*ld.begin(), loop_count);
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelAfter(1);
@ -811,7 +919,7 @@ TEST_F(PeelingTest, DoWhilePeeling) {
// Exit condition.
ir::Instruction* ten_cst = builder.Add32BitUnsignedIntegerConstant(10);
opt::LoopPeeling peel(context.get(), &*ld.begin(), ten_cst);
opt::LoopPeeling peel(&*ld.begin(), ten_cst);
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelBefore(2);
@ -861,7 +969,7 @@ CHECK-NEXT: OpLoopMerge
// Exit condition.
ir::Instruction* ten_cst = builder.Add32BitUnsignedIntegerConstant(10);
opt::LoopPeeling peel(context.get(), &*ld.begin(), ten_cst);
opt::LoopPeeling peel(&*ld.begin(), ten_cst);
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelAfter(2);
@ -983,7 +1091,7 @@ TEST_F(PeelingTest, PeelingLoopWithStore) {
ir::Instruction* loop_count = context->get_def_use_mgr()->GetDef(15);
EXPECT_EQ(loop_count->opcode(), SpvOpLoad);
opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count);
opt::LoopPeeling peel(&*ld.begin(), loop_count);
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelBefore(1);
@ -1035,7 +1143,7 @@ CHECK-NEXT: OpLoopMerge
ir::Instruction* loop_count = context->get_def_use_mgr()->GetDef(15);
EXPECT_EQ(loop_count->opcode(), SpvOpLoad);
opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count);
opt::LoopPeeling peel(&*ld.begin(), loop_count);
EXPECT_TRUE(peel.CanPeelLoop());
peel.PeelAfter(1);

File diff suppressed because it is too large Load Diff

View File

@ -22,6 +22,7 @@
#include <sstream>
#include <vector>
#include "opt/loop_peeling.h"
#include "opt/set_spec_constant_default_value_pass.h"
#include "spirv-tools/optimizer.hpp"
@ -181,6 +182,14 @@ Options (in lexicographical order):
Partially unrolls loops marked with the Unroll flag. Takes an
additional non-0 integer argument to set the unroll factor, or
how many times a loop body should be duplicated
--loop-peeling
Execute few first (respectively last) iterations before
(respectively after) the loop if it can elide some branches.
--loop-peeling-threshold
Takes a non-0 integer argument to set the loop peeling code size
growth threshold. The threshold prevents the loop peeling
from happening if the code size increase created by
the optimization is above the threshold.
--merge-blocks
Join two blocks into a single block if the second has the
first as its only predecessor. Performed only on entry point
@ -405,6 +414,20 @@ OptStatus ParseLoopUnrollPartialArg(int argc, const char** argv, int argi,
return {OPT_STOP, 1};
}
OptStatus ParseLoopPeelingThresholdArg(int argc, const char** argv, int argi) {
if (argi < argc) {
int factor = atoi(argv[argi]);
if (factor > 0) {
opt::LoopPeelingPass::SetLoopPeelingThreshold(factor);
return {OPT_CONTINUE, 0};
}
}
fprintf(
stderr,
"error: --loop-peeling-threshold must be followed by a non-0 integer\n");
return {OPT_STOP, 1};
}
// Parses command-line flags. |argc| contains the number of command-line flags.
// |argv| points to an array of strings holding the flags. |optimizer| is the
// Optimizer instance used to optimize the program.
@ -538,6 +561,13 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer,
if (status.action != OPT_CONTINUE) {
return status;
}
} else if (0 == strcmp(cur_arg, "--loop-peeling")) {
optimizer->RegisterPass(CreateLoopPeelingPass());
} else if (0 == strcmp(cur_arg, "--loop-peeling-threshold")) {
OptStatus status = ParseLoopPeelingThresholdArg(argc, argv, ++argi);
if (status.action != OPT_CONTINUE) {
return status;
}
} else if (0 == strcmp(cur_arg, "--skip-validation")) {
*skip_validator = true;
} else if (0 == strcmp(cur_arg, "-O")) {