diff --git a/source/opt/dead_insert_elim_pass.cpp b/source/opt/dead_insert_elim_pass.cpp index 55f4efe7b..ed0c7b8fe 100644 --- a/source/opt/dead_insert_elim_pass.cpp +++ b/source/opt/dead_insert_elim_pass.cpp @@ -65,9 +65,9 @@ uint32_t DeadInsertElimPass::NumComponents(ir::Instruction* typeInst) { } } -void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, - std::vector* pExtIndices, - uint32_t extOffset) { +void DeadInsertElimPass::MarkInsertChain( + ir::Instruction* insertChain, std::vector* pExtIndices, + uint32_t extOffset, std::unordered_set* visited_phis) { // Not currently optimizing array inserts. ir::Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id()); if (typeInst->opcode() == SpvOpTypeArray) return; @@ -84,7 +84,8 @@ void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, for (uint32_t i = 0; i < cnum; i++) { extIndices.clear(); extIndices.push_back(i); - MarkInsertChain(insertChain, &extIndices, 0); + std::unordered_set sub_visited_phis; + MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis); } return; } @@ -101,14 +102,18 @@ void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, if (pExtIndices == nullptr) { liveInserts_.insert(insInst->result_id()); uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); - MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0); + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, + &obj_visited_phis); } // If extract indices match insert, we are done. Mark insert and // inserted object. else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) { liveInserts_.insert(insInst->result_id()); uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); - MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0); + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, + &obj_visited_phis); break; } // If non-matching intersection, mark insert @@ -119,15 +124,18 @@ void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, uint32_t numInsertIndices = insInst->NumInOperands() - 2; if (pExtIndices->size() - extOffset > numInsertIndices) { uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); + std::unordered_set obj_visited_phis; MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices, - extOffset + numInsertIndices); + extOffset + numInsertIndices, &obj_visited_phis); break; } // If fewer extract indices than insert, also mark inserted object and // continue up chain. else { uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); - MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0); + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, + &obj_visited_phis); } } // Get next insert in chain @@ -139,14 +147,8 @@ void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, if (insInst->opcode() != SpvOpPhi) return; // Mark phi visited to prevent potential infinite loop. If phi is already // visited, return to avoid infinite loop. - auto iter = visitedPhis_.find(insInst->result_id()); - if (iter == visitedPhis_.end()) { - iter = visitedPhis_.emplace(insInst->result_id(), true).first; - } else if (iter->second) { - return; - } else { - iter->second = true; - } + if (visited_phis->count(insInst->result_id()) != 0) return; + visited_phis->insert(insInst->result_id()); // Phis may have duplicate inputs values for different edges, prune incoming // ids lists before recursing. @@ -158,11 +160,8 @@ void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, auto new_end = std::unique(ids.begin(), ids.end()); for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) { ir::Instruction* pi = get_def_use_mgr()->GetDef(*id_iter); - MarkInsertChain(pi, pExtIndices, extOffset); + MarkInsertChain(pi, pExtIndices, extOffset, visited_phis); } - - // Unmark phi when done visiting. - iter->second = false; } bool DeadInsertElimPass::EliminateDeadInserts(ir::Function* func) { @@ -216,11 +215,12 @@ bool DeadInsertElimPass::EliminateDeadInsertsOnePass(ir::Function* func) { ++icnt; }); // Mark all inserts in chain that intersect with extract - MarkInsertChain(&*ii, &extIndices, 0); + std::unordered_set visited_phis; + MarkInsertChain(&*ii, &extIndices, 0, &visited_phis); } break; default: { // Mark inserts in chain for all components - MarkInsertChain(&*ii, nullptr, 0); + MarkInsertChain(&*ii, nullptr, 0, nullptr); } break; } }); diff --git a/source/opt/dead_insert_elim_pass.h b/source/opt/dead_insert_elim_pass.h index 97a725d9c..f7ee46a83 100644 --- a/source/opt/dead_insert_elim_pass.h +++ b/source/opt/dead_insert_elim_pass.h @@ -50,7 +50,8 @@ class DeadInsertElimPass : public MemPass { // index at |extOffset|. Chains are composed solely of Inserts and Phis. // Mark all inserts in chain if |extIndices| is nullptr. void MarkInsertChain(ir::Instruction* insertChain, - std::vector* extIndices, uint32_t extOffset); + std::vector* extIndices, uint32_t extOffset, + std::unordered_set* visited_phis); // Perform EliminateDeadInsertsOnePass(|func|) until no modification is // made. Return true if modified.