mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2025-01-13 09:50:06 +00:00
Add GetContinueBlock to loop class.
Previously, the loop class used the terms latch block and continue block interchangeably. This patch separates the two concepts, corrects some of the existing uses of GetLatchBlock, and adds tests for them.
This commit is contained in:
parent
70bb3c1cc2
commit
1c2cbaf569
@ -214,6 +214,7 @@ Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis,
|
|||||||
assert(context);
|
assert(context);
|
||||||
assert(dom_analysis);
|
assert(dom_analysis);
|
||||||
loop_preheader_ = FindLoopPreheader(dom_analysis);
|
loop_preheader_ = FindLoopPreheader(dom_analysis);
|
||||||
|
loop_latch_ = FindLatchBlock();
|
||||||
}
|
}
|
||||||
|
|
||||||
BasicBlock* Loop::FindLoopPreheader(opt::DominatorAnalysis* dom_analysis) {
|
BasicBlock* Loop::FindLoopPreheader(opt::DominatorAnalysis* dom_analysis) {
|
||||||
@ -280,6 +281,11 @@ BasicBlock* Loop::GetOrCreatePreHeaderBlock() {
|
|||||||
return loop_preheader_;
|
return loop_preheader_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Loop::SetContinueBlock(BasicBlock* continue_block) {
|
||||||
|
assert(IsInsideLoop(continue_block));
|
||||||
|
loop_continue_ = continue_block;
|
||||||
|
}
|
||||||
|
|
||||||
void Loop::SetLatchBlock(BasicBlock* latch) {
|
void Loop::SetLatchBlock(BasicBlock* latch) {
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
assert(latch->GetParent() && "The basic block does not belong to a function");
|
assert(latch->GetParent() && "The basic block does not belong to a function");
|
||||||
@ -321,6 +327,28 @@ void Loop::SetPreHeaderBlock(BasicBlock* preheader) {
|
|||||||
loop_preheader_ = preheader;
|
loop_preheader_ = preheader;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ir::BasicBlock* Loop::FindLatchBlock() {
|
||||||
|
ir::CFG* cfg = context_->cfg();
|
||||||
|
|
||||||
|
opt::DominatorAnalysis* dominator_analysis =
|
||||||
|
context_->GetDominatorAnalysis(loop_header_->GetParent());
|
||||||
|
|
||||||
|
// Look at the predecessors of the loop header to find a predecessor block
|
||||||
|
// which is dominated by the loop continue target. There should only be one
|
||||||
|
// block which meets this criteria and this is the latch block, as per the
|
||||||
|
// SPIR-V spec.
|
||||||
|
for (uint32_t block_id : cfg->preds(loop_header_->id())) {
|
||||||
|
if (dominator_analysis->Dominates(loop_continue_->id(), block_id)) {
|
||||||
|
return cfg->block(block_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(
|
||||||
|
false &&
|
||||||
|
"Every loop should have a latch block dominated by the continue target");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
void Loop::GetExitBlocks(std::unordered_set<uint32_t>* exit_blocks) const {
|
void Loop::GetExitBlocks(std::unordered_set<uint32_t>* exit_blocks) const {
|
||||||
ir::CFG* cfg = context_->cfg();
|
ir::CFG* cfg = context_->cfg();
|
||||||
exit_blocks->clear();
|
exit_blocks->clear();
|
||||||
@ -861,9 +889,9 @@ ir::Instruction* Loop::FindConditionVariable(
|
|||||||
|
|
||||||
// And make sure that the other is the latch block.
|
// And make sure that the other is the latch block.
|
||||||
if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
|
if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
|
||||||
loop_continue_->id() &&
|
loop_latch_->id() &&
|
||||||
variable_inst->GetSingleWordInOperand(operand_label_2) !=
|
variable_inst->GetSingleWordInOperand(operand_label_2) !=
|
||||||
loop_continue_->id()) {
|
loop_latch_->id()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -53,6 +53,7 @@ class Loop {
|
|||||||
loop_continue_(nullptr),
|
loop_continue_(nullptr),
|
||||||
loop_merge_(nullptr),
|
loop_merge_(nullptr),
|
||||||
loop_preheader_(nullptr),
|
loop_preheader_(nullptr),
|
||||||
|
loop_latch_(nullptr),
|
||||||
parent_(nullptr),
|
parent_(nullptr),
|
||||||
loop_is_marked_for_removal_(false) {}
|
loop_is_marked_for_removal_(false) {}
|
||||||
|
|
||||||
@ -82,17 +83,27 @@ class Loop {
|
|||||||
merge_inst->SetInOperand(0, {GetMergeBlock()->id()});
|
merge_inst->SetInOperand(0, {GetMergeBlock()->id()});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the continue target basic block. This is the block designated as
|
||||||
|
// the continue target by the OpLoopMerge instruction.
|
||||||
|
inline BasicBlock* GetContinueBlock() { return loop_continue_; }
|
||||||
|
inline const BasicBlock* GetContinueBlock() const { return loop_continue_; }
|
||||||
|
|
||||||
// Returns the latch basic block (basic block that holds the back-edge).
|
// Returns the latch basic block (basic block that holds the back-edge).
|
||||||
// These functions return nullptr if the loop is not structured (i.e. if it
|
// These functions return nullptr if the loop is not structured (i.e. if it
|
||||||
// has more than one backedge).
|
// has more than one backedge).
|
||||||
inline BasicBlock* GetLatchBlock() { return loop_continue_; }
|
inline BasicBlock* GetLatchBlock() { return loop_latch_; }
|
||||||
inline const BasicBlock* GetLatchBlock() const { return loop_continue_; }
|
inline const BasicBlock* GetLatchBlock() const { return loop_latch_; }
|
||||||
|
|
||||||
// Sets |latch| as the loop unique block branching back to the header.
|
// Sets |latch| as the loop unique block branching back to the header.
|
||||||
// A latch block must have the following properties:
|
// A latch block must have the following properties:
|
||||||
// - |latch| must be in the loop;
|
// - |latch| must be in the loop;
|
||||||
// - must be the only block branching back to the header block.
|
// - must be the only block branching back to the header block.
|
||||||
void SetLatchBlock(BasicBlock* latch);
|
void SetLatchBlock(BasicBlock* latch);
|
||||||
|
|
||||||
|
// Sets |continue_block| as the continue block of the loop. This should be the
|
||||||
|
// continue target of the OpLoopMerge and should dominate the latch block.
|
||||||
|
void SetContinueBlock(BasicBlock* continue_block);
|
||||||
|
|
||||||
// Returns the basic block which marks the end of the loop.
|
// Returns the basic block which marks the end of the loop.
|
||||||
// These functions return nullptr if the loop is not structured.
|
// These functions return nullptr if the loop is not structured.
|
||||||
inline BasicBlock* GetMergeBlock() { return loop_merge_; }
|
inline BasicBlock* GetMergeBlock() { return loop_merge_; }
|
||||||
@ -340,6 +351,12 @@ class Loop {
|
|||||||
// Returns the context associated with this loop.
|
// Returns the context associated with this loop.
|
||||||
IRContext* GetContext() const { return context_; }
|
IRContext* GetContext() const { return context_; }
|
||||||
|
|
||||||
|
// Looks at all the blocks with a branch to the header block to find one
|
||||||
|
// which is also dominated by the loop continue block. This block is the latch
|
||||||
|
// block. The specification mandates that this block should exist, therefore
|
||||||
|
// this function will assert if it is not found.
|
||||||
|
ir::BasicBlock* FindLatchBlock();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
IRContext* context_;
|
IRContext* context_;
|
||||||
// The block which marks the start of the loop.
|
// The block which marks the start of the loop.
|
||||||
@ -354,6 +371,9 @@ class Loop {
|
|||||||
// The block immediately before the loop header.
|
// The block immediately before the loop header.
|
||||||
BasicBlock* loop_preheader_;
|
BasicBlock* loop_preheader_;
|
||||||
|
|
||||||
|
// The block containing the backedge to the loop header.
|
||||||
|
BasicBlock* loop_latch_;
|
||||||
|
|
||||||
// A parent of a loop is the loop which contains it as a nested child loop.
|
// A parent of a loop is the loop which contains it as a nested child loop.
|
||||||
Loop* parent_;
|
Loop* parent_;
|
||||||
|
|
||||||
@ -372,9 +392,9 @@ class Loop {
|
|||||||
// Returns the loop preheader if it exists, returns nullptr otherwise.
|
// Returns the loop preheader if it exists, returns nullptr otherwise.
|
||||||
BasicBlock* FindLoopPreheader(opt::DominatorAnalysis* dom_analysis);
|
BasicBlock* FindLoopPreheader(opt::DominatorAnalysis* dom_analysis);
|
||||||
|
|
||||||
// Sets |latch| as the loop unique continue block. No checks are performed
|
// Sets |latch| as the loop unique latch block. No checks are performed
|
||||||
// here.
|
// here.
|
||||||
inline void SetLatchBlockImpl(BasicBlock* latch) { loop_continue_ = latch; }
|
inline void SetLatchBlockImpl(BasicBlock* latch) { loop_latch_ = latch; }
|
||||||
// Sets |merge| as the loop merge block. No checks are performed here.
|
// Sets |merge| as the loop merge block. No checks are performed here.
|
||||||
inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; }
|
inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; }
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ void AddInstructionsInBlock(std::vector<ir::Instruction*>* instructions,
|
|||||||
bool LoopFusion::UsedInContinueOrConditionBlock(
|
bool LoopFusion::UsedInContinueOrConditionBlock(
|
||||||
ir::Instruction* phi_instruction, ir::Loop* loop) {
|
ir::Instruction* phi_instruction, ir::Loop* loop) {
|
||||||
auto condition_block = loop->FindConditionBlock()->id();
|
auto condition_block = loop->FindConditionBlock()->id();
|
||||||
auto continue_block = loop->GetLatchBlock()->id();
|
auto continue_block = loop->GetContinueBlock()->id();
|
||||||
auto not_used = context_->get_def_use_mgr()->WhileEachUser(
|
auto not_used = context_->get_def_use_mgr()->WhileEachUser(
|
||||||
phi_instruction,
|
phi_instruction,
|
||||||
[this, condition_block, continue_block](ir::Instruction* instruction) {
|
[this, condition_block, continue_block](ir::Instruction* instruction) {
|
||||||
@ -125,8 +125,8 @@ bool LoopFusion::AreCompatible() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check there are no continues.
|
// Check there are no continues.
|
||||||
if (context_->cfg()->preds(loop_0_->GetLatchBlock()->id()).size() != 1 ||
|
if (context_->cfg()->preds(loop_0_->GetContinueBlock()->id()).size() != 1 ||
|
||||||
context_->cfg()->preds(loop_1_->GetLatchBlock()->id()).size() != 1) {
|
context_->cfg()->preds(loop_1_->GetContinueBlock()->id()).size() != 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -361,7 +361,7 @@ LoopFusion::GetLoadsAndStoresInLoop(ir::Loop* loop) {
|
|||||||
std::vector<ir::Instruction*> stores{};
|
std::vector<ir::Instruction*> stores{};
|
||||||
|
|
||||||
for (auto block_id : loop->GetBlocks()) {
|
for (auto block_id : loop->GetBlocks()) {
|
||||||
if (block_id == loop->GetLatchBlock()->id()) {
|
if (block_id == loop->GetContinueBlock()->id()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -535,8 +535,8 @@ void LoopFusion::Fuse() {
|
|||||||
// Save the pointers/ids, won't be found in the middle of doing modifications.
|
// Save the pointers/ids, won't be found in the middle of doing modifications.
|
||||||
auto header_1 = loop_1_->GetHeaderBlock()->id();
|
auto header_1 = loop_1_->GetHeaderBlock()->id();
|
||||||
auto condition_1 = loop_1_->FindConditionBlock()->id();
|
auto condition_1 = loop_1_->FindConditionBlock()->id();
|
||||||
auto continue_1 = loop_1_->GetLatchBlock()->id();
|
auto continue_1 = loop_1_->GetContinueBlock()->id();
|
||||||
auto continue_0 = loop_0_->GetLatchBlock()->id();
|
auto continue_0 = loop_0_->GetContinueBlock()->id();
|
||||||
auto condition_block_of_0 = loop_0_->FindConditionBlock();
|
auto condition_block_of_0 = loop_0_->FindConditionBlock();
|
||||||
|
|
||||||
// Find the blocks whose branches need updating.
|
// Find the blocks whose branches need updating.
|
||||||
@ -551,7 +551,7 @@ void LoopFusion::Fuse() {
|
|||||||
// Update the branch for the |last_block_of_loop_1| to go to the continue
|
// Update the branch for the |last_block_of_loop_1| to go to the continue
|
||||||
// block of |loop_0_|.
|
// block of |loop_0_|.
|
||||||
last_block_of_1->ForEachSuccessorLabel(
|
last_block_of_1->ForEachSuccessorLabel(
|
||||||
[this](uint32_t* succ) { *succ = loop_0_->GetLatchBlock()->id(); });
|
[this](uint32_t* succ) { *succ = loop_0_->GetContinueBlock()->id(); });
|
||||||
|
|
||||||
// Update merge block id in the header of |loop_0_| to the merge block of
|
// Update merge block id in the header of |loop_0_| to the merge block of
|
||||||
// |loop_1_|.
|
// |loop_1_|.
|
||||||
@ -594,8 +594,8 @@ void LoopFusion::Fuse() {
|
|||||||
ReplacePhiParentWith(i, loop_1_->GetPreHeaderBlock()->id(),
|
ReplacePhiParentWith(i, loop_1_->GetPreHeaderBlock()->id(),
|
||||||
loop_0_->GetPreHeaderBlock()->id());
|
loop_0_->GetPreHeaderBlock()->id());
|
||||||
|
|
||||||
ReplacePhiParentWith(i, loop_1_->GetLatchBlock()->id(),
|
ReplacePhiParentWith(i, loop_1_->GetContinueBlock()->id(),
|
||||||
loop_0_->GetLatchBlock()->id());
|
loop_0_->GetContinueBlock()->id());
|
||||||
});
|
});
|
||||||
|
|
||||||
// Update instruction to block mapping & DefUseManager.
|
// Update instruction to block mapping & DefUseManager.
|
||||||
@ -631,7 +631,7 @@ void LoopFusion::Fuse() {
|
|||||||
AddInstructionsInBlock(&instr_to_delete, loop_1_->GetPreHeaderBlock());
|
AddInstructionsInBlock(&instr_to_delete, loop_1_->GetPreHeaderBlock());
|
||||||
AddInstructionsInBlock(&instr_to_delete, loop_1_->GetHeaderBlock());
|
AddInstructionsInBlock(&instr_to_delete, loop_1_->GetHeaderBlock());
|
||||||
AddInstructionsInBlock(&instr_to_delete, loop_1_->FindConditionBlock());
|
AddInstructionsInBlock(&instr_to_delete, loop_1_->FindConditionBlock());
|
||||||
AddInstructionsInBlock(&instr_to_delete, loop_1_->GetLatchBlock());
|
AddInstructionsInBlock(&instr_to_delete, loop_1_->GetContinueBlock());
|
||||||
|
|
||||||
// There was an additional empty block between the loops, kill that too.
|
// There was an additional empty block between the loops, kill that too.
|
||||||
if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) {
|
if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) {
|
||||||
@ -644,18 +644,19 @@ void LoopFusion::Fuse() {
|
|||||||
cfg->ForgetBlock(loop_1_->GetPreHeaderBlock());
|
cfg->ForgetBlock(loop_1_->GetPreHeaderBlock());
|
||||||
cfg->ForgetBlock(loop_1_->GetHeaderBlock());
|
cfg->ForgetBlock(loop_1_->GetHeaderBlock());
|
||||||
cfg->ForgetBlock(loop_1_->FindConditionBlock());
|
cfg->ForgetBlock(loop_1_->FindConditionBlock());
|
||||||
cfg->ForgetBlock(loop_1_->GetLatchBlock());
|
cfg->ForgetBlock(loop_1_->GetContinueBlock());
|
||||||
|
|
||||||
if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) {
|
if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) {
|
||||||
cfg->ForgetBlock(loop_0_->GetMergeBlock());
|
cfg->ForgetBlock(loop_0_->GetMergeBlock());
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg->RemoveEdge(last_block_of_0->id(), loop_0_->GetLatchBlock()->id());
|
cfg->RemoveEdge(last_block_of_0->id(), loop_0_->GetContinueBlock()->id());
|
||||||
cfg->AddEdge(last_block_of_0->id(), first_block_of_1->id());
|
cfg->AddEdge(last_block_of_0->id(), first_block_of_1->id());
|
||||||
|
|
||||||
cfg->AddEdge(last_block_of_1->id(), loop_0_->GetLatchBlock()->id());
|
cfg->AddEdge(last_block_of_1->id(), loop_0_->GetContinueBlock()->id());
|
||||||
|
|
||||||
cfg->AddEdge(loop_0_->GetLatchBlock()->id(), loop_1_->GetHeaderBlock()->id());
|
cfg->AddEdge(loop_0_->GetContinueBlock()->id(),
|
||||||
|
loop_1_->GetHeaderBlock()->id());
|
||||||
|
|
||||||
cfg->AddEdge(condition_block_of_0->id(), loop_1_->GetMergeBlock()->id());
|
cfg->AddEdge(condition_block_of_0->id(), loop_1_->GetMergeBlock()->id());
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ static const uint32_t kLoopControlIndex = 2;
|
|||||||
struct LoopUnrollState {
|
struct LoopUnrollState {
|
||||||
LoopUnrollState()
|
LoopUnrollState()
|
||||||
: previous_phi_(nullptr),
|
: previous_phi_(nullptr),
|
||||||
previous_continue_block_(nullptr),
|
previous_latch_block_(nullptr),
|
||||||
previous_condition_block_(nullptr),
|
previous_condition_block_(nullptr),
|
||||||
new_phi(nullptr),
|
new_phi(nullptr),
|
||||||
new_continue_block(nullptr),
|
new_continue_block(nullptr),
|
||||||
@ -84,11 +84,11 @@ struct LoopUnrollState {
|
|||||||
new_header_block(nullptr) {}
|
new_header_block(nullptr) {}
|
||||||
|
|
||||||
// Initialize from the loop descriptor class.
|
// Initialize from the loop descriptor class.
|
||||||
LoopUnrollState(ir::Instruction* induction, ir::BasicBlock* continue_block,
|
LoopUnrollState(ir::Instruction* induction, ir::BasicBlock* latch_block,
|
||||||
ir::BasicBlock* condition,
|
ir::BasicBlock* condition,
|
||||||
std::vector<ir::Instruction*>&& phis)
|
std::vector<ir::Instruction*>&& phis)
|
||||||
: previous_phi_(induction),
|
: previous_phi_(induction),
|
||||||
previous_continue_block_(continue_block),
|
previous_latch_block_(latch_block),
|
||||||
previous_condition_block_(condition),
|
previous_condition_block_(condition),
|
||||||
new_phi(nullptr),
|
new_phi(nullptr),
|
||||||
new_continue_block(nullptr),
|
new_continue_block(nullptr),
|
||||||
@ -100,7 +100,7 @@ struct LoopUnrollState {
|
|||||||
// Swap the state so that the new nodes are now the previous nodes.
|
// Swap the state so that the new nodes are now the previous nodes.
|
||||||
void NextIterationState() {
|
void NextIterationState() {
|
||||||
previous_phi_ = new_phi;
|
previous_phi_ = new_phi;
|
||||||
previous_continue_block_ = new_continue_block;
|
previous_latch_block_ = new_latch_block;
|
||||||
previous_condition_block_ = new_condition_block;
|
previous_condition_block_ = new_condition_block;
|
||||||
previous_phis_ = std::move(new_phis_);
|
previous_phis_ = std::move(new_phis_);
|
||||||
|
|
||||||
@ -109,6 +109,7 @@ struct LoopUnrollState {
|
|||||||
new_continue_block = nullptr;
|
new_continue_block = nullptr;
|
||||||
new_condition_block = nullptr;
|
new_condition_block = nullptr;
|
||||||
new_header_block = nullptr;
|
new_header_block = nullptr;
|
||||||
|
new_latch_block = nullptr;
|
||||||
|
|
||||||
// Clear new block/instruction maps.
|
// Clear new block/instruction maps.
|
||||||
new_blocks.clear();
|
new_blocks.clear();
|
||||||
@ -123,9 +124,10 @@ struct LoopUnrollState {
|
|||||||
std::vector<ir::Instruction*> previous_phis_;
|
std::vector<ir::Instruction*> previous_phis_;
|
||||||
|
|
||||||
std::vector<ir::Instruction*> new_phis_;
|
std::vector<ir::Instruction*> new_phis_;
|
||||||
// The previous continue block. The backedge will be removed from this and
|
|
||||||
// added to the new continue block.
|
// The previous latch block. The backedge will be removed from this and
|
||||||
ir::BasicBlock* previous_continue_block_;
|
// added to the new latch block.
|
||||||
|
ir::BasicBlock* previous_latch_block_;
|
||||||
|
|
||||||
// The previous condition block. This may be folded to flatten the loop.
|
// The previous condition block. This may be folded to flatten the loop.
|
||||||
ir::BasicBlock* previous_condition_block_;
|
ir::BasicBlock* previous_condition_block_;
|
||||||
@ -142,6 +144,9 @@ struct LoopUnrollState {
|
|||||||
// The new header block.
|
// The new header block.
|
||||||
ir::BasicBlock* new_header_block;
|
ir::BasicBlock* new_header_block;
|
||||||
|
|
||||||
|
// The new latch block.
|
||||||
|
ir::BasicBlock* new_latch_block;
|
||||||
|
|
||||||
// A mapping of new block ids to the original blocks which they were copied
|
// A mapping of new block ids to the original blocks which they were copied
|
||||||
// from.
|
// from.
|
||||||
std::unordered_map<uint32_t, ir::BasicBlock*> new_blocks;
|
std::unordered_map<uint32_t, ir::BasicBlock*> new_blocks;
|
||||||
@ -546,7 +551,7 @@ void LoopUnrollerUtilsImpl::ReplaceInductionUseWithFinalValue(ir::Loop* loop) {
|
|||||||
|
|
||||||
for (size_t index = 0; index < inductions.size(); ++index) {
|
for (size_t index = 0; index < inductions.size(); ++index) {
|
||||||
uint32_t trip_step_id = GetPhiDefID(state_.previous_phis_[index],
|
uint32_t trip_step_id = GetPhiDefID(state_.previous_phis_[index],
|
||||||
state_.previous_continue_block_->id());
|
state_.previous_latch_block_->id());
|
||||||
context_->ReplaceAllUsesWith(inductions[index]->result_id(), trip_step_id);
|
context_->ReplaceAllUsesWith(inductions[index]->result_id(), trip_step_id);
|
||||||
invalidated_instructions_.push_back(inductions[index]);
|
invalidated_instructions_.push_back(inductions[index]);
|
||||||
}
|
}
|
||||||
@ -600,7 +605,7 @@ void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop,
|
|||||||
AssignNewResultIds(basic_block);
|
AssignNewResultIds(basic_block);
|
||||||
|
|
||||||
// If this is the continue block we are copying.
|
// If this is the continue block we are copying.
|
||||||
if (itr == loop->GetLatchBlock()) {
|
if (itr == loop->GetContinueBlock()) {
|
||||||
// Make the OpLoopMerge point to this block for the continue.
|
// Make the OpLoopMerge point to this block for the continue.
|
||||||
if (!preserve_instructions) {
|
if (!preserve_instructions) {
|
||||||
ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst();
|
ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst();
|
||||||
@ -621,6 +626,9 @@ void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If this is the latch block being copied, record it in the state.
|
||||||
|
if (itr == loop->GetLatchBlock()) state_.new_latch_block = basic_block;
|
||||||
|
|
||||||
// If this is the condition block we are copying.
|
// If this is the condition block we are copying.
|
||||||
if (itr == loop_condition_block_) {
|
if (itr == loop_condition_block_) {
|
||||||
state_.new_condition_block = basic_block;
|
state_.new_condition_block = basic_block;
|
||||||
@ -642,16 +650,16 @@ void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop,
|
|||||||
CopyBasicBlock(loop, itr, false);
|
CopyBasicBlock(loop, itr, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the previous continue block to point to the new header.
|
// Set the previous latch block to point to the new header.
|
||||||
ir::Instruction& continue_branch = *state_.previous_continue_block_->tail();
|
ir::Instruction& latch_branch = *state_.previous_latch_block_->tail();
|
||||||
continue_branch.SetInOperand(0, {state_.new_header_block->id()});
|
latch_branch.SetInOperand(0, {state_.new_header_block->id()});
|
||||||
|
|
||||||
// As the algorithm copies the original loop blocks exactly, the tail of the
|
// As the algorithm copies the original loop blocks exactly, the tail of the
|
||||||
// latch block on iterations after the first one will be a branch to the new
|
// latch block on iterations after the first one will be a branch to the new
|
||||||
// header and not the actual loop header. The last continue block in the loop
|
// header and not the actual loop header. The last latch block in the loop
|
||||||
// should always be a backedge to the global header.
|
// should always be a backedge to the global header.
|
||||||
ir::Instruction& new_continue_branch = *state_.new_continue_block->tail();
|
ir::Instruction& new_latch_branch = *state_.new_latch_block->tail();
|
||||||
new_continue_branch.SetInOperand(0, {loop->GetHeaderBlock()->id()});
|
new_latch_branch.SetInOperand(0, {loop->GetHeaderBlock()->id()});
|
||||||
|
|
||||||
std::vector<ir::Instruction*> inductions;
|
std::vector<ir::Instruction*> inductions;
|
||||||
loop->GetInductionVariables(inductions);
|
loop->GetInductionVariables(inductions);
|
||||||
@ -667,7 +675,7 @@ void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop,
|
|||||||
|
|
||||||
if (!state_.previous_phis_.empty()) {
|
if (!state_.previous_phis_.empty()) {
|
||||||
state_.new_inst[master_copy->result_id()] = GetPhiDefID(
|
state_.new_inst[master_copy->result_id()] = GetPhiDefID(
|
||||||
state_.previous_phis_[index], state_.previous_continue_block_->id());
|
state_.previous_phis_[index], state_.previous_latch_block_->id());
|
||||||
} else {
|
} else {
|
||||||
// Do not replace the first phi block ids.
|
// Do not replace the first phi block ids.
|
||||||
state_.new_inst[master_copy->result_id()] = master_copy->result_id();
|
state_.new_inst[master_copy->result_id()] = master_copy->result_id();
|
||||||
@ -724,7 +732,7 @@ void LoopUnrollerUtilsImpl::CloseUnrolledLoop(ir::Loop* loop) {
|
|||||||
|
|
||||||
// Remove the final backedge to the header and make it point instead to the
|
// Remove the final backedge to the header and make it point instead to the
|
||||||
// merge block.
|
// merge block.
|
||||||
state_.previous_continue_block_->tail()->SetInOperand(
|
state_.previous_latch_block_->tail()->SetInOperand(
|
||||||
0, {loop->GetMergeBlock()->id()});
|
0, {loop->GetMergeBlock()->id()});
|
||||||
|
|
||||||
// Remove all induction variables as the phis will now be invalid. Replace all
|
// Remove all induction variables as the phis will now be invalid. Replace all
|
||||||
@ -779,7 +787,8 @@ void LoopUnrollerUtilsImpl::DuplicateLoop(ir::Loop* old_loop,
|
|||||||
AddBlocksToLoop(new_loop);
|
AddBlocksToLoop(new_loop);
|
||||||
|
|
||||||
new_loop->SetHeaderBlock(state_.new_header_block);
|
new_loop->SetHeaderBlock(state_.new_header_block);
|
||||||
new_loop->SetLatchBlock(state_.new_continue_block);
|
new_loop->SetContinueBlock(state_.new_continue_block);
|
||||||
|
new_loop->SetLatchBlock(state_.new_latch_block);
|
||||||
new_loop->SetMergeBlock(new_merge);
|
new_loop->SetMergeBlock(new_merge);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -875,8 +884,8 @@ void LoopUnrollerUtilsImpl::LinkLastPhisToStart(ir::Loop* loop) const {
|
|||||||
for (size_t i = 0; i < inductions.size(); ++i) {
|
for (size_t i = 0; i < inductions.size(); ++i) {
|
||||||
ir::Instruction* last_phi_in_block = state_.previous_phis_[i];
|
ir::Instruction* last_phi_in_block = state_.previous_phis_[i];
|
||||||
|
|
||||||
uint32_t phi_index = GetPhiIndexFromLabel(state_.previous_continue_block_,
|
uint32_t phi_index =
|
||||||
last_phi_in_block);
|
GetPhiIndexFromLabel(state_.previous_latch_block_, last_phi_in_block);
|
||||||
uint32_t phi_variable =
|
uint32_t phi_variable =
|
||||||
last_phi_in_block->GetSingleWordInOperand(phi_index - 1);
|
last_phi_in_block->GetSingleWordInOperand(phi_index - 1);
|
||||||
uint32_t phi_label = last_phi_in_block->GetSingleWordInOperand(phi_index);
|
uint32_t phi_label = last_phi_in_block->GetSingleWordInOperand(phi_index);
|
||||||
@ -927,7 +936,7 @@ bool LoopUtils::CanPerformUnroll() {
|
|||||||
if (!loop_->FindNumberOfIterations(induction, &*condition->ctail(), nullptr))
|
if (!loop_->FindNumberOfIterations(induction, &*condition->ctail(), nullptr))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Make sure the continue block is an unconditional branch to the header
|
// Make sure the latch block is an unconditional branch to the header
|
||||||
// block.
|
// block.
|
||||||
const ir::Instruction& branch = *loop_->GetLatchBlock()->ctail();
|
const ir::Instruction& branch = *loop_->GetLatchBlock()->ctail();
|
||||||
bool branching_assumption =
|
bool branching_assumption =
|
||||||
@ -949,7 +958,7 @@ bool LoopUtils::CanPerformUnroll() {
|
|||||||
|
|
||||||
// Ban continues within the loop.
|
// Ban continues within the loop.
|
||||||
const std::vector<uint32_t>& continue_block_preds =
|
const std::vector<uint32_t>& continue_block_preds =
|
||||||
context_->cfg()->preds(loop_->GetLatchBlock()->id());
|
context_->cfg()->preds(loop_->GetContinueBlock()->id());
|
||||||
if (continue_block_preds.size() != 1) {
|
if (continue_block_preds.size() != 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -653,6 +653,9 @@ void LoopUtils::PopulateLoopDesc(
|
|||||||
if (old_loop->GetLatchBlock())
|
if (old_loop->GetLatchBlock())
|
||||||
new_loop->SetLatchBlock(
|
new_loop->SetLatchBlock(
|
||||||
cloning_result.old_to_new_bb_.at(old_loop->GetLatchBlock()->id()));
|
cloning_result.old_to_new_bb_.at(old_loop->GetLatchBlock()->id()));
|
||||||
|
if (old_loop->GetContinueBlock())
|
||||||
|
new_loop->SetContinueBlock(
|
||||||
|
cloning_result.old_to_new_bb_.at(old_loop->GetContinueBlock()->id()));
|
||||||
if (old_loop->GetMergeBlock()) {
|
if (old_loop->GetMergeBlock()) {
|
||||||
auto it =
|
auto it =
|
||||||
cloning_result.old_to_new_bb_.find(old_loop->GetMergeBlock()->id());
|
cloning_result.old_to_new_bb_.find(old_loop->GetMergeBlock()->id());
|
||||||
|
@ -297,4 +297,87 @@ TEST_F(PassClassTest, NoLoop) {
|
|||||||
EXPECT_EQ(ld.NumLoops(), 0u);
|
EXPECT_EQ(ld.NumLoops(), 0u);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Generated from following GLSL with latch block artificially inserted to be
|
||||||
|
separate from continue.
|
||||||
|
#version 430
|
||||||
|
void main(void) {
|
||||||
|
float x[10];
|
||||||
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
x[i] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
TEST_F(PassClassTest, LoopLatchNotContinue) {
|
||||||
|
const std::string text = R"(OpCapability Shader
|
||||||
|
%1 = OpExtInstImport "GLSL.std.450"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint Fragment %2 "main"
|
||||||
|
OpExecutionMode %2 OriginUpperLeft
|
||||||
|
OpSource GLSL 430
|
||||||
|
OpName %2 "main"
|
||||||
|
OpName %3 "i"
|
||||||
|
OpName %4 "x"
|
||||||
|
%5 = OpTypeVoid
|
||||||
|
%6 = OpTypeFunction %5
|
||||||
|
%7 = OpTypeInt 32 1
|
||||||
|
%8 = OpTypePointer Function %7
|
||||||
|
%9 = OpConstant %7 0
|
||||||
|
%10 = OpConstant %7 10
|
||||||
|
%11 = OpTypeBool
|
||||||
|
%12 = OpTypeFloat 32
|
||||||
|
%13 = OpTypeInt 32 0
|
||||||
|
%14 = OpConstant %13 10
|
||||||
|
%15 = OpTypeArray %12 %14
|
||||||
|
%16 = OpTypePointer Function %15
|
||||||
|
%17 = OpTypePointer Function %12
|
||||||
|
%18 = OpConstant %7 1
|
||||||
|
%2 = OpFunction %5 None %6
|
||||||
|
%19 = OpLabel
|
||||||
|
%3 = OpVariable %8 Function
|
||||||
|
%4 = OpVariable %16 Function
|
||||||
|
OpStore %3 %9
|
||||||
|
OpBranch %20
|
||||||
|
%20 = OpLabel
|
||||||
|
%21 = OpPhi %7 %9 %19 %22 %30
|
||||||
|
OpLoopMerge %24 %23 None
|
||||||
|
OpBranch %25
|
||||||
|
%25 = OpLabel
|
||||||
|
%26 = OpSLessThan %11 %21 %10
|
||||||
|
OpBranchConditional %26 %27 %24
|
||||||
|
%27 = OpLabel
|
||||||
|
%28 = OpConvertSToF %12 %21
|
||||||
|
%29 = OpAccessChain %17 %4 %21
|
||||||
|
OpStore %29 %28
|
||||||
|
OpBranch %23
|
||||||
|
%23 = OpLabel
|
||||||
|
%22 = OpIAdd %7 %21 %18
|
||||||
|
OpStore %3 %22
|
||||||
|
OpBranch %30
|
||||||
|
%30 = OpLabel
|
||||||
|
OpBranch %20
|
||||||
|
%24 = OpLabel
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
|
||||||
|
std::unique_ptr<ir::IRContext> context =
|
||||||
|
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
|
||||||
|
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
|
||||||
|
ir::Module* module = context->module();
|
||||||
|
EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
|
||||||
|
<< text << std::endl;
|
||||||
|
const ir::Function* f = spvtest::GetFunction(module, 2);
|
||||||
|
ir::LoopDescriptor ld{f};
|
||||||
|
|
||||||
|
EXPECT_EQ(ld.NumLoops(), 1u);
|
||||||
|
|
||||||
|
ir::Loop& loop = ld.GetLoopByIndex(0u);
|
||||||
|
|
||||||
|
EXPECT_NE(loop.GetLatchBlock(), loop.GetContinueBlock());
|
||||||
|
|
||||||
|
EXPECT_EQ(loop.GetContinueBlock()->id(), 23);
|
||||||
|
EXPECT_EQ(loop.GetLatchBlock()->id(), 30);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -2796,4 +2796,203 @@ OpFunctionEnd
|
|||||||
false);
|
false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Generated from following GLSL with latch block artificially inserted to be
|
||||||
|
separate from continue.
|
||||||
|
#version 430
|
||||||
|
void main(void) {
|
||||||
|
float x[10];
|
||||||
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
x[i] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
// Runs PartialUnrollerTestPass<3> on a loop whose back-edge block (%30) is a
// different block from the continue target named by OpLoopMerge (%23).
//
// The test has two halves:
//   1. Text check: the disassembled output of the pass must match `expected`.
//   2. Structure check: the pass is run again on a freshly built module and
//      the loop descriptor must report two loops, each with a latch block that
//      differs from its continue block.
//
// Fatal ASSERT_* checks guard every pointer or index that later statements
// dereference, so a failed assembly or an unexpected loop count is reported as
// a test failure instead of crashing the test binary.
TEST_F(PassClassTest, PartiallyUnrollLatchNotContinue) {
  // clang-format off
  const std::string text = R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main"
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 430
OpName %2 "main"
OpName %3 "i"
OpName %4 "x"
%5 = OpTypeVoid
%6 = OpTypeFunction %5
%7 = OpTypeInt 32 1
%8 = OpTypePointer Function %7
%9 = OpConstant %7 0
%10 = OpConstant %7 10
%11 = OpTypeBool
%12 = OpTypeFloat 32
%13 = OpTypeInt 32 0
%14 = OpConstant %13 10
%15 = OpTypeArray %12 %14
%16 = OpTypePointer Function %15
%17 = OpTypePointer Function %12
%18 = OpConstant %7 1
%2 = OpFunction %5 None %6
%19 = OpLabel
%3 = OpVariable %8 Function
%4 = OpVariable %16 Function
OpStore %3 %9
OpBranch %20
%20 = OpLabel
%21 = OpPhi %7 %9 %19 %22 %30
OpLoopMerge %24 %23 Unroll
OpBranch %25
%25 = OpLabel
%26 = OpSLessThan %11 %21 %10
OpBranchConditional %26 %27 %24
%27 = OpLabel
%28 = OpConvertSToF %12 %21
%29 = OpAccessChain %17 %4 %21
OpStore %29 %28
OpBranch %23
%23 = OpLabel
%22 = OpIAdd %7 %21 %18
OpStore %3 %22
OpBranch %30
%30 = OpLabel
OpBranch %20
%24 = OpLabel
OpReturn
OpFunctionEnd
)";

  const std::string expected = R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main"
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 430
OpName %2 "main"
OpName %3 "i"
OpName %4 "x"
%5 = OpTypeVoid
%6 = OpTypeFunction %5
%7 = OpTypeInt 32 1
%8 = OpTypePointer Function %7
%9 = OpConstant %7 0
%10 = OpConstant %7 10
%11 = OpTypeBool
%12 = OpTypeFloat 32
%13 = OpTypeInt 32 0
%14 = OpConstant %13 10
%15 = OpTypeArray %12 %14
%16 = OpTypePointer Function %15
%17 = OpTypePointer Function %12
%18 = OpConstant %7 1
%63 = OpConstant %13 1
%2 = OpFunction %5 None %6
%19 = OpLabel
%3 = OpVariable %8 Function
%4 = OpVariable %16 Function
OpStore %3 %9
OpBranch %20
%20 = OpLabel
%21 = OpPhi %7 %9 %19 %22 %23
OpLoopMerge %31 %25 Unroll
OpBranch %26
%26 = OpLabel
%27 = OpSLessThan %11 %21 %63
OpBranchConditional %27 %28 %31
%28 = OpLabel
%29 = OpConvertSToF %12 %21
%30 = OpAccessChain %17 %4 %21
OpStore %30 %29
OpBranch %25
%25 = OpLabel
%22 = OpIAdd %7 %21 %18
OpStore %3 %22
OpBranch %23
%23 = OpLabel
OpBranch %20
%31 = OpLabel
OpBranch %32
%32 = OpLabel
%33 = OpPhi %7 %21 %31 %61 %62
OpLoopMerge %42 %60 DontUnroll
OpBranch %34
%34 = OpLabel
%35 = OpSLessThan %11 %33 %10
OpBranchConditional %35 %36 %42
%36 = OpLabel
%37 = OpConvertSToF %12 %33
%38 = OpAccessChain %17 %4 %33
OpStore %38 %37
OpBranch %39
%39 = OpLabel
%40 = OpIAdd %7 %33 %18
OpStore %3 %40
OpBranch %41
%41 = OpLabel
OpBranch %43
%43 = OpLabel
OpBranch %45
%45 = OpLabel
%46 = OpSLessThan %11 %40 %10
OpBranch %47
%47 = OpLabel
%48 = OpConvertSToF %12 %40
%49 = OpAccessChain %17 %4 %40
OpStore %49 %48
OpBranch %50
%50 = OpLabel
%51 = OpIAdd %7 %40 %18
OpStore %3 %51
OpBranch %52
%52 = OpLabel
OpBranch %53
%53 = OpLabel
OpBranch %55
%55 = OpLabel
%56 = OpSLessThan %11 %51 %10
OpBranch %57
%57 = OpLabel
%58 = OpConvertSToF %12 %51
%59 = OpAccessChain %17 %4 %51
OpStore %59 %58
OpBranch %60
%60 = OpLabel
%61 = OpIAdd %7 %51 %18
OpStore %3 %61
OpBranch %62
%62 = OpLabel
OpBranch %32
%42 = OpLabel
OpReturn
%24 = OpLabel
OpReturn
OpFunctionEnd
)";
  // clang-format on
  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
  SinglePassRunAndCheck<PartialUnrollerTestPass<3>>(text, expected, true);

  // Make sure the latch block information is preserved and propagated correctly
  // by the pass.
  std::unique_ptr<ir::IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // The context is dereferenced immediately below, so a failed build must
  // abort the test here rather than reach the pass with a null pointer.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;

  PartialUnrollerTestPass<3> unroller;
  unroller.Process(context.get());

  ir::Module* module = context->module();
  // Fatal on purpose: `module` is handed to GetFunction on the next line.
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const ir::Function* f = spvtest::GetFunction(module, 2);
  ir::LoopDescriptor ld{f};

  // Fatal on purpose: indices 0 and 1 are accessed unconditionally below.
  ASSERT_EQ(ld.NumLoops(), 2u);

  ir::Loop& loop_1 = ld.GetLoopByIndex(0u);
  EXPECT_NE(loop_1.GetLatchBlock(), loop_1.GetContinueBlock());

  ir::Loop& loop_2 = ld.GetLoopByIndex(1u);
  EXPECT_NE(loop_2.GetLatchBlock(), loop_2.GetContinueBlock());
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user