diff --git a/Android.mk b/Android.mk index 63253966e..a6456f832 100644 --- a/Android.mk +++ b/Android.mk @@ -72,6 +72,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/insert_extract_elim.cpp \ source/opt/instruction.cpp \ source/opt/instruction_list.cpp \ + source/opt/ir_context.cpp \ source/opt/ir_loader.cpp \ source/opt/local_access_chain_convert_pass.cpp \ source/opt/local_single_block_elim_pass.cpp \ diff --git a/source/link/linker.cpp b/source/link/linker.cpp index 9d3536372..79d487270 100644 --- a/source/link/linker.cpp +++ b/source/link/linker.cpp @@ -104,11 +104,10 @@ static spv_result_t MergeModules( // applied to a single ID.) // TODO(pierremoreau): What should be the proper behaviour with built-in // symbols? -static spv_result_t GetImportExportPairs(const MessageConsumer& consumer, - const ir::IRContext& linked_context, - const DefUseManager& def_use_manager, - const DecorationManager& decoration_manager, - LinkageTable* linkings_to_do); +static spv_result_t GetImportExportPairs( + const MessageConsumer& consumer, const ir::IRContext& linked_context, + const DefUseManager& def_use_manager, + const DecorationManager& decoration_manager, LinkageTable* linkings_to_do); // Checks that for each pair of import and export, the import and export have // the same type as well as the same decorations. @@ -224,20 +223,20 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, libspirv::AssemblyGrammar grammar(impl_->context); res = MergeModules(consumer, modules, grammar, linked_module.get()); if (res != SPV_SUCCESS) return res; - ir::IRContext linked_context(std::move(linked_module)); - - DefUseManager def_use_manager(consumer, linked_context.module()); + ir::IRContext linked_context(std::move(linked_module), consumer); // Phase 4: Find the import/export pairs LinkageTable linkings_to_do; DecorationManager decoration_manager(linked_context.module()); - res = GetImportExportPairs(consumer, linked_context, def_use_manager, + res = GetImportExportPairs(consumer, linked_context, + *linked_context.get_def_use_mgr(), decoration_manager, &linkings_to_do); if (res != SPV_SUCCESS) return res; // Phase 5: Ensure the import and export have the same types and decorations. res = CheckImportExportCompatibility(consumer, linkings_to_do, - def_use_manager, decoration_manager); + *linked_context.get_def_use_mgr(), + decoration_manager); if (res != SPV_SUCCESS) return res; // Phase 6: Remove duplicates @@ -255,11 +254,9 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, if (res != SPV_SUCCESS) return res; // Phase 8: Rematch import variables/functions to export variables/functions - // TODO(pierremoreau): Keep the previous DefUseManager up-to-date - DefUseManager def_use_manager2(consumer, linked_context.module()); for (const auto& linking_entry : linkings_to_do) - def_use_manager2.ReplaceAllUsesWith(linking_entry.imported_symbol.id, - linking_entry.exported_symbol.id); + linked_context.ReplaceAllUsesWith(linking_entry.imported_symbol.id, + linking_entry.exported_symbol.id); // Phase 9: Compact the IDs used in the module manager.AddPass(); @@ -476,11 +473,10 @@ static spv_result_t MergeModules( return SPV_SUCCESS; } -static spv_result_t GetImportExportPairs(const MessageConsumer& consumer, - const ir::IRContext& linked_context, - const DefUseManager& def_use_manager, - const DecorationManager& decoration_manager, - LinkageTable* linkings_to_do) { +static spv_result_t GetImportExportPairs( + const MessageConsumer& consumer, const ir::IRContext& linked_context, + const DefUseManager& def_use_manager, + const DecorationManager& decoration_manager, LinkageTable* linkings_to_do) { spv_position_t position = {}; if (linkings_to_do == nullptr) @@ -500,14 +496,16 @@ static spv_result_t GetImportExportPairs(const MessageConsumer& consumer, const SpvId id = decoration.GetSingleWordInOperand(0u); // Ignore if the targeted symbol is a built-in bool is_built_in = false; - for (const auto& id_decoration : decoration_manager.GetDecorationsFor(id, false)) { + for (const auto& id_decoration : + decoration_manager.GetDecorationsFor(id, false)) { if (id_decoration->GetSingleWordInOperand(1u) == SpvDecorationBuiltIn) { is_built_in = true; break; } } - if (is_built_in) + if (is_built_in) { continue; + } const uint32_t type = decoration.GetSingleWordInOperand(3u); diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index 6be613a69..54e930aa1 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -109,7 +109,7 @@ bool AggressiveDCEPass::KillInstIfTargetDead(ir::Instruction* inst) { const uint32_t tId = inst->GetSingleWordInOperand(0); const ir::Instruction* tInst = get_def_use_mgr()->GetDef(tId); if (dead_insts_.find(tInst) != dead_insts_.end()) { - get_def_use_mgr()->KillInst(inst); + context()->KillInst(inst); return true; } return false; @@ -374,7 +374,7 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { if (ii->opcode() == SpvOpSelectionMerge) mergeBlockId = ii->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx); - get_def_use_mgr()->KillInst(&*ii); + context()->KillInst(&*ii); modified = true; } // If a structured if was deleted, add a branch to its merge block, diff --git a/source/opt/block_merge_pass.cpp b/source/opt/block_merge_pass.cpp index cf5f0626a..4b06a2eb3 100644 --- a/source/opt/block_merge_pass.cpp +++ b/source/opt/block_merge_pass.cpp @@ -43,11 +43,11 @@ void BlockMergePass::KillInstAndName(ir::Instruction* inst) { if (uses != nullptr) for (auto u : *uses) if (u.inst->opcode() == SpvOpName) { - get_def_use_mgr()->KillInst(u.inst); + context()->KillInst(u.inst); break; } } - get_def_use_mgr()->KillInst(inst); + context()->KillInst(inst); } bool BlockMergePass::MergeBlocks(ir::Function* func) { @@ -78,7 +78,7 @@ bool BlockMergePass::MergeBlocks(ir::Function* func) { continue; } // Merge blocks - get_def_use_mgr()->KillInst(br); + context()->KillInst(br); auto sbi = bi; for (; sbi != func->end(); ++sbi) if (sbi->id() == labId) diff --git a/source/opt/common_uniform_elim_pass.cpp b/source/opt/common_uniform_elim_pass.cpp index 3e2ac0cfd..aafc5f1a0 100644 --- a/source/opt/common_uniform_elim_pass.cpp +++ b/source/opt/common_uniform_elim_pass.cpp @@ -206,7 +206,7 @@ void CommonUniformElimPass::KillNamesAndDecorates(uint32_t id) { if (op != SpvOpName && !IsNonTypeDecorate(op)) continue; killList.push_back(u.inst); } - for (auto kip : killList) get_def_use_mgr()->KillInst(kip); + for (auto kip : killList) context()->KillInst(kip); } void CommonUniformElimPass::KillNamesAndDecorates(ir::Instruction* inst) { @@ -222,7 +222,7 @@ void CommonUniformElimPass::DeleteIfUseless(ir::Instruction* inst) { assert(resId != 0); if (HasOnlyNamesAndDecorates(resId)) { KillNamesAndDecorates(resId); - get_def_use_mgr()->KillInst(inst); + context()->KillInst(inst); } } @@ -231,9 +231,9 @@ void CommonUniformElimPass::ReplaceAndDeleteLoad(ir::Instruction* loadInst, ir::Instruction* ptrInst) { const uint32_t loadId = loadInst->result_id(); KillNamesAndDecorates(loadId); - (void)get_def_use_mgr()->ReplaceAllUsesWith(loadId, replId); + (void)context()->ReplaceAllUsesWith(loadId, replId); // remove load instruction - get_def_use_mgr()->KillInst(loadInst); + context()->KillInst(loadInst); // if access chain, see if it can be removed as well if (IsNonPtrAccessChain(ptrInst->opcode())) DeleteIfUseless(ptrInst); } @@ -491,8 +491,8 @@ bool CommonUniformElimPass::CommonExtractElimination(ir::Function* func) { for (auto instItr : idxItr.second) { uint32_t resId = instItr->result_id(); KillNamesAndDecorates(resId); - (void)get_def_use_mgr()->ReplaceAllUsesWith(resId, replId); - get_def_use_mgr()->KillInst(instItr); + (void)context()->ReplaceAllUsesWith(resId, replId); + context()->KillInst(instItr); } modified = true; } diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp index 2558c8167..ebca89e52 100644 --- a/source/opt/dead_branch_elim_pass.cpp +++ b/source/opt/dead_branch_elim_pass.cpp @@ -233,8 +233,8 @@ bool DeadBranchElimPass::EliminateDeadBranches(ir::Function* func) { const uint32_t mergeLabId = mergeInst->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx); AddBranch(liveLabId, *bi); - get_def_use_mgr()->KillInst(br); - get_def_use_mgr()->KillInst(mergeInst); + context()->KillInst(br); + context()->KillInst(mergeInst); modified = true; @@ -331,8 +331,8 @@ bool DeadBranchElimPass::EliminateDeadBranches(ir::Function* func) { } const uint32_t phiId = pii->result_id(); KillNamesAndDecorates(phiId); - (void)get_def_use_mgr()->ReplaceAllUsesWith(phiId, replId); - get_def_use_mgr()->KillInst(&*pii); + (void)context()->ReplaceAllUsesWith(phiId, replId); + context()->KillInst(&*pii); } } diff --git a/source/opt/dead_variable_elimination.cpp b/source/opt/dead_variable_elimination.cpp index da6b982a7..4fd6b2d00 100644 --- a/source/opt/dead_variable_elimination.cpp +++ b/source/opt/dead_variable_elimination.cpp @@ -109,7 +109,7 @@ void DeadVariableElimination::DeleteVariable(uint32_t result_id) { } } this->KillNamesAndDecorates(result_id); - get_def_use_mgr()->KillDef(result_id); + context()->KillDef(result_id); } } // namespace opt } // namespace spvtools diff --git a/source/opt/def_use_manager.cpp b/source/opt/def_use_manager.cpp index da4283a02..1444a27f8 100644 --- a/source/opt/def_use_manager.cpp +++ b/source/opt/def_use_manager.cpp @@ -51,7 +51,6 @@ void DefUseManager::AnalyzeInstUse(ir::Instruction* inst) { case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: case SPV_OPERAND_TYPE_SCOPE_ID: { uint32_t use_id = inst->GetSingleWordOperand(i); - // use_id is used by the instruction generating def_id. id_to_uses_[use_id].push_back({ inst, i }); inst_to_used_ids_[inst].push_back(use_id); } break; @@ -102,60 +101,6 @@ std::vector DefUseManager::GetAnnotations(uint32_t id) const { return annos; } -bool DefUseManager::KillDef(uint32_t id) { - auto iter = id_to_def_.find(id); - if (iter == id_to_def_.end()) return false; - KillInst(iter->second); - return true; -} - -void DefUseManager::KillInst(ir::Instruction* inst) { - if (!inst) return; - ClearInst(inst); - inst->ToNop(); -} - -bool DefUseManager::ReplaceAllUsesWith(uint32_t before, uint32_t after) { - if (before == after) return false; - if (id_to_uses_.count(before) == 0) return false; - - for (auto it = id_to_uses_[before].cbegin(); it != id_to_uses_[before].cend(); - ++it) { - const uint32_t type_result_id_count = - (it->inst->result_id() != 0) + (it->inst->type_id() != 0); - - if (it->operand_index < type_result_id_count) { - // Update the type_id. Note that result id is immutable so it should - // never be updated. - if (it->inst->type_id() != 0 && it->operand_index == 0) { - it->inst->SetResultType(after); - } else if (it->inst->type_id() == 0) { - SPIRV_ASSERT(consumer_, false, - "Result type id considered as use while the instruction " - "doesn't have a result type id."); - (void)consumer_; // Makes the compiler happy for release build. - } else { - SPIRV_ASSERT(consumer_, false, - "Trying setting the immutable result id."); - } - } else { - // Update an in-operand. - uint32_t in_operand_pos = it->operand_index - type_result_id_count; - // Make the modification in the instruction. - it->inst->SetInOperand(in_operand_pos, {after}); - } - // Update inst to used ids map - auto iter = inst_to_used_ids_.find(it->inst); - if (iter != inst_to_used_ids_.end()) - for (auto uit = iter->second.begin(); uit != iter->second.end(); uit++) - if (*uit == before) *uit = after; - // Register the use of |after| id into id_to_uses_. - // TODO(antiagainst): de-duplication. - id_to_uses_[after].push_back({it->inst, it->operand_index}); - } - id_to_uses_.erase(before); - return true; -} void DefUseManager::AnalyzeDefUse(ir::Module* module) { if (!module) return; diff --git a/source/opt/def_use_manager.h b/source/opt/def_use_manager.h index e4d8a3e1f..d4b7fc26e 100644 --- a/source/opt/def_use_manager.h +++ b/source/opt/def_use_manager.h @@ -49,8 +49,7 @@ class DefUseManager { // will be communicated to the outside via the given message |consumer|. This // instance only keeps a reference to the |consumer|, so the |consumer| should // outlive this instance. - DefUseManager(const MessageConsumer& consumer, ir::Module* module) - : consumer_(consumer) { + DefUseManager(ir::Module* module) { AnalyzeDefUse(module); } @@ -88,29 +87,6 @@ class DefUseManager { // Returns the map from ids to their uses in instructions. const IdToUsesMap& id_to_uses() const { return id_to_uses_; } - // Turns the instruction defining the given |id| into a Nop. Returns true on - // success, false if the given |id| is not defined at all. This method also - // erases both the uses of |id| and the information of this |id|-generating - // instruction's uses of its operands. - bool KillDef(uint32_t id); - // Turns the given instruction |inst| to a Nop. This method erases the - // information of the given instruction's uses of its operands. If |inst| - // defines an result id, the uses of the result id will also be erased. - void KillInst(ir::Instruction* inst); - // Replaces all uses of |before| id with |after| id. Returns true if any - // replacement happens. This method does not kill the definition of the - // |before| id. If |after| is the same as |before|, does nothing and returns - // false. - bool ReplaceAllUsesWith(uint32_t before, uint32_t after); - - private: - using InstToUsedIdsMap = - std::unordered_map>; - - // Analyzes the defs and uses in the given |module| and populates data - // structures in this class. Does nothing if |module| is nullptr. - void AnalyzeDefUse(ir::Module* module); - // Clear the internal def-use record of the given instruction |inst|. This // method will update the use information of the operand ids of |inst|. The // record: |inst| uses an |id|, will be removed from the use records of |id|. @@ -121,11 +97,17 @@ class DefUseManager { // Erases the records that a given instruction uses its operand ids. void EraseUseRecordsOfOperandIds(const ir::Instruction* inst); - const MessageConsumer& consumer_; // Message consumer. + private: + using InstToUsedIdsMap = + std::unordered_map>; + + // Analyzes the defs and uses in the given |module| and populates data + // structures in this class. Does nothing if |module| is nullptr. + void AnalyzeDefUse(ir::Module* module); + IdToDefMap id_to_def_; // Mapping from ids to their definitions IdToUsesMap id_to_uses_; // Mapping from ids to their uses - // Mapping from instructions to the ids used in the instructions generating - // the result ids. + // Mapping from instructions to the ids used in the instruction. InstToUsedIdsMap inst_to_used_ids_; }; diff --git a/source/opt/eliminate_dead_constant_pass.cpp b/source/opt/eliminate_dead_constant_pass.cpp index c77a70141..4d33ccf17 100644 --- a/source/opt/eliminate_dead_constant_pass.cpp +++ b/source/opt/eliminate_dead_constant_pass.cpp @@ -27,7 +27,6 @@ namespace spvtools { namespace opt { Pass::Status EliminateDeadConstantPass::Process(ir::IRContext* irContext) { - analysis::DefUseManager def_use(consumer(), irContext->module()); std::unordered_set working_list; // Traverse all the instructions to get the initial set of dead constants as // working list and count number of real uses for constants. Uses in @@ -37,7 +36,7 @@ Pass::Status EliminateDeadConstantPass::Process(ir::IRContext* irContext) { for (auto* c : constants) { uint32_t const_id = c->result_id(); size_t count = 0; - if (analysis::UseList* uses = def_use.GetUses(const_id)) { + if (analysis::UseList* uses = irContext->get_def_use_mgr()->GetUses(const_id)) { count = std::count_if(uses->begin(), uses->end(), [](const analysis::Use& u) { return !(ir::IsAnnotationInst(u.inst->opcode()) || @@ -69,7 +68,7 @@ Pass::Status EliminateDeadConstantPass::Process(ir::IRContext* irContext) { continue; } uint32_t operand_id = inst->GetSingleWordInOperand(i); - ir::Instruction* def_inst = def_use.GetDef(operand_id); + ir::Instruction* def_inst = irContext->get_def_use_mgr()->GetDef(operand_id); // If the use_count does not have any count for the def_inst, // def_inst must not be a constant, and should be ignored here. if (!use_counts.count(def_inst)) { @@ -95,7 +94,7 @@ Pass::Status EliminateDeadConstantPass::Process(ir::IRContext* irContext) { // constants. std::unordered_set dead_others; for (auto* dc : dead_consts) { - if (analysis::UseList* uses = def_use.GetUses(dc->result_id())) { + if (analysis::UseList* uses = irContext->get_def_use_mgr()->GetUses(dc->result_id())) { for (const auto& u : *uses) { if (ir::IsAnnotationInst(u.inst->opcode()) || ir::IsDebug1Inst(u.inst->opcode()) || @@ -109,7 +108,7 @@ Pass::Status EliminateDeadConstantPass::Process(ir::IRContext* irContext) { // Turn all dead instructions and uses of them to nop for (auto* dc : dead_consts) { - def_use.KillDef(dc->result_id()); + irContext->KillDef(dc->result_id()); } for (auto* da : dead_others) { da->ToNop(); diff --git a/source/opt/eliminate_dead_functions_pass.cpp b/source/opt/eliminate_dead_functions_pass.cpp index 96f999964..8eae3e9a6 100644 --- a/source/opt/eliminate_dead_functions_pass.cpp +++ b/source/opt/eliminate_dead_functions_pass.cpp @@ -52,7 +52,7 @@ void EliminateDeadFunctionsPass::EliminateFunction(ir::Function* func) { func->ForEachInst( [this](ir::Instruction* inst) { KillNamesAndDecorates(inst); - get_def_use_mgr()->KillInst(inst); + context()->KillInst(inst); }, true); } diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp index db9891e26..ad078fd90 100644 --- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp +++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -372,8 +372,8 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp( // original constant. uint32_t new_id = folded_inst->result_id(); uint32_t old_id = inst->result_id(); - get_def_use_mgr()->ReplaceAllUsesWith(old_id, new_id); - get_def_use_mgr()->KillDef(old_id); + context()->ReplaceAllUsesWith(old_id, new_id); + context()->KillDef(old_id); return true; } diff --git a/source/opt/insert_extract_elim.cpp b/source/opt/insert_extract_elim.cpp index c7ff9e692..c44ec48dd 100644 --- a/source/opt/insert_extract_elim.cpp +++ b/source/opt/insert_extract_elim.cpp @@ -109,8 +109,8 @@ bool InsertExtractElimPass::EliminateInsertExtract(ir::Function* func) { } if (replId != 0) { const uint32_t extId = ii->result_id(); - (void)get_def_use_mgr()->ReplaceAllUsesWith(extId, replId); - get_def_use_mgr()->KillInst(&*ii); + (void)context()->ReplaceAllUsesWith(extId, replId); + context()->KillInst(&*ii); modified = true; } } break; diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp index c966eea77..85d3298e0 100644 --- a/source/opt/ir_context.cpp +++ b/source/opt/ir_context.cpp @@ -13,3 +13,83 @@ // limitations under the License. #include "ir_context.h" +#include "log.h" + +namespace spvtools { +namespace ir { + +void IRContext::BuildInvalidAnalyses(IRContext::Analysis set) { + if (set & kAnalysisDefUse) { + BuildDefUseManager(); + } +} + +void IRContext::InvalidateAnalysesExceptFor(IRContext::Analysis preserved_analyses) { + uint32_t analyses_to_invalidate = valid_analyses_ & (~preserved_analyses); + if (analyses_to_invalidate & kAnalysisDefUse) { + def_use_mgr_.reset(nullptr); + } + valid_analyses_ = Analysis(valid_analyses_ & ~analyses_to_invalidate); +} + +void IRContext::KillInst(ir::Instruction* inst) { + if (!inst) { + return; + } + + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->ClearInst(inst); + } + inst->ToNop(); +} + +bool IRContext::KillDef(uint32_t id) { + ir::Instruction* def = get_def_use_mgr()->GetDef(id); + if (def != nullptr) { + KillInst(def); + return true; + } + return false; +} + +bool IRContext::ReplaceAllUsesWith(uint32_t before, uint32_t after) { + if (before == after) return false; + opt::analysis::UseList* uses = get_def_use_mgr()->GetUses(before); + if (uses == nullptr) return false; + + std::vector uses_to_update; + for (auto it = uses->cbegin(); it != uses->cend(); ++it) { + uses_to_update.push_back(*it); + } + + for (opt::analysis::Use& use : uses_to_update) { + get_def_use_mgr()->EraseUseRecordsOfOperandIds(use.inst); + const uint32_t type_result_id_count = + (use.inst->result_id() != 0) + (use.inst->type_id() != 0); + + if (use.operand_index < type_result_id_count) { + // Update the type_id. Note that result id is immutable so it should + // never be updated. + if (use.inst->type_id() != 0 && use.operand_index == 0) { + use.inst->SetResultType(after); + } else if (use.inst->type_id() == 0) { + SPIRV_ASSERT(consumer_, false, + "Result type id considered as use while the instruction " + "doesn't have a result type id."); + (void)consumer_; // Makes the compiler happy for release build. + } else { + SPIRV_ASSERT(consumer_, false, + "Trying setting the immutable result id."); + } + } else { + // Update an in-operand. + uint32_t in_operand_pos = use.operand_index - type_result_id_count; + // Make the modification in the instruction. + use.inst->SetInOperand(in_operand_pos, {after}); + } + get_def_use_mgr()->AnalyzeInstUse(use.inst); + } + return true; +} +} // namespace ir +} // namespace spvtools diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index d1b980077..7945a5616 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h @@ -15,6 +15,7 @@ #ifndef SPIRV_TOOLS_IR_CONTEXT_H #define SPIRV_TOOLS_IR_CONTEXT_H +#include "def_use_manager.h" #include "module.h" #include @@ -24,7 +25,23 @@ namespace ir { class IRContext { public: - IRContext(std::unique_ptr&& m) : module_(std::move(m)) {} + enum Analysis { + kAnalysisNone = 0x0, + kAnalysisBegin = 0x1, + kAnalysisDefUse = kAnalysisBegin, + kAnalysisEnd = 0x2 + }; + + friend inline Analysis operator|(Analysis lhs, Analysis rhs); + friend inline Analysis& operator|=(Analysis& lhs, Analysis rhs); + friend inline Analysis operator<<(Analysis a, int shift); + friend inline Analysis& operator<<=(Analysis& a, int shift); + + IRContext(std::unique_ptr&& m, spvtools::MessageConsumer c) + : module_(std::move(m)), + consumer_(std::move(c)), + def_use_mgr_(nullptr), + valid_analyses_(kAnalysisNone) {} Module* module() const { return module_.get(); } inline void SetIdBound(uint32_t i); @@ -121,10 +138,89 @@ class IRContext { // Appends a function to this module. inline void AddFunction(std::unique_ptr&& f); + // Returns a pointer to a def-use manager. If the def-use manager is + // invalid, it is rebuilt first. + opt::analysis::DefUseManager* get_def_use_mgr() { + if (!AreAnalysesValid(kAnalysisDefUse)) { + BuildDefUseManager(); + } + return def_use_mgr_.get(); + } + + // Builds the def-use manager from scratch, even if it was already valid. + void BuildDefUseManager() { + def_use_mgr_.reset(new opt::analysis::DefUseManager(module())); + valid_analyses_ = valid_analyses_ | kAnalysisDefUse; + } + + // Sets the message consumer to the given |consumer|. |consumer| which will be + // invoked every time there is a message to be communicated to the outside. + void SetMessageConsumer(spvtools::MessageConsumer c) { + consumer_ = std::move(c); + } + + // Returns the reference to the message consumer for this pass. + const spvtools::MessageConsumer& consumer() const { return consumer_; } + + // Rebuilds the analyses in |set| that are invalid. + void BuildInvalidAnalyses(Analysis set); + + // Invalidates all of the analyses except for those in |preserved_analyses|. + void InvalidateAnalysesExceptFor(Analysis preserved_analyses); + + // Turns the instruction defining the given |id| into a Nop. Returns true on + // success, false if the given |id| is not defined at all. This method also + // erases both the uses of |id| and the information of this |id|-generating + // instruction's uses of its operands. + bool KillDef(uint32_t id); + + // Turns the given instruction |inst| to a Nop. This method erases the + // information of the given instruction's uses of its operands. If |inst| + // defines an result id, the uses of the result id will also be erased. + void KillInst(ir::Instruction* inst); + + // Returns true if all of the given analyses are valid. + bool AreAnalysesValid(Analysis set) { return (set & valid_analyses_) == set; } + + // Replaces all uses of |before| id with |after| id. Returns true if any + // replacement happens. This method does not kill the definition of the + // |before| id. If |after| is the same as |before|, does nothing and returns + // false. + bool ReplaceAllUsesWith(uint32_t before, uint32_t after); + private: std::unique_ptr module_; + spvtools::MessageConsumer consumer_; + std::unique_ptr def_use_mgr_; + + // A bitset indicating which analyes are currently valid. + Analysis valid_analyses_; }; +inline ir::IRContext::Analysis operator|(ir::IRContext::Analysis lhs, + ir::IRContext::Analysis rhs) { + return static_cast(static_cast(lhs) | + static_cast(rhs)); +} + +inline ir::IRContext::Analysis& operator|=(ir::IRContext::Analysis& lhs, + ir::IRContext::Analysis rhs) { + lhs = static_cast(static_cast(lhs) | + static_cast(rhs)); + return lhs; +} + +inline ir::IRContext::Analysis operator<<(ir::IRContext::Analysis a, + int shift) { + return static_cast(static_cast(a) << shift); +} + +inline ir::IRContext::Analysis& operator<<=(ir::IRContext::Analysis& a, + int shift) { + a = static_cast(static_cast(a) << shift); + return a; +} + void IRContext::SetIdBound(uint32_t i) { module_->SetIdBound(i); } uint32_t IRContext::IdBound() const { return module()->IdBound(); } diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index f50c81513..1ea522510 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -35,7 +35,7 @@ void LocalAccessChainConvertPass::DeleteIfUseless(ir::Instruction* inst) { assert(resId != 0); if (HasOnlyNamesAndDecorates(resId)) { KillNamesAndDecorates(resId); - get_def_use_mgr()->KillInst(inst); + context()->KillInst(inst); } } @@ -233,7 +233,7 @@ bool LocalAccessChainConvertPass::ConvertLocalAccessChains(ir::Function* func) { std::vector> newInsts; uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx); GenAccessChainStoreReplacement(ptrInst, valId, &newInsts); - get_def_use_mgr()->KillInst(&*ii); + context()->KillInst(&*ii); DeleteIfUseless(ptrInst); ++ii; ii = ii.InsertBefore(std::move(newInsts)); diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp index c175cc49f..04916d863 100644 --- a/source/opt/local_single_block_elim_pass.cpp +++ b/source/opt/local_single_block_elim_pass.cpp @@ -68,7 +68,7 @@ bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim( if (pinned_vars_.find(varId) == pinned_vars_.end()) { auto si = var2store_.find(varId); if (si != var2store_.end()) { - get_def_use_mgr()->KillInst(si->second); + context()->KillInst(si->second); } } var2store_[varId] = &*ii; diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index e303c7ea7..c4bfd127d 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp @@ -149,7 +149,7 @@ void MemPass::KillNamesAndDecorates(uint32_t id) { const SpvOp op = u.inst->opcode(); if (op == SpvOpName || IsNonTypeDecorate(op)) killList.push_back(u.inst); } - for (auto kip : killList) get_def_use_mgr()->KillInst(kip); + for (auto kip : killList) context()->KillInst(kip); } void MemPass::KillNamesAndDecorates(ir::Instruction* inst) { @@ -161,7 +161,7 @@ void MemPass::KillNamesAndDecorates(ir::Instruction* inst) { void MemPass::KillAllInsts(ir::BasicBlock* bp) { bp->ForEachInst([this](ir::Instruction* ip) { KillNamesAndDecorates(ip); - get_def_use_mgr()->KillInst(ip); + context()->KillInst(ip); }); } @@ -230,7 +230,7 @@ void MemPass::DCEInst(ir::Instruction* inst) { // Remember variable if dead load if (di->opcode() == SpvOpLoad) (void)GetPtr(di, &varId); KillNamesAndDecorates(di); - get_def_use_mgr()->KillInst(di); + context()->KillInst(di); // For all operands with no remaining uses, add their instruction // to the dead instruction queue. for (auto id : ids) @@ -246,7 +246,7 @@ void MemPass::DCEInst(ir::Instruction* inst) { void MemPass::ReplaceAndDeleteLoad(ir::Instruction* loadInst, uint32_t replId) { const uint32_t loadId = loadInst->result_id(); KillNamesAndDecorates(loadId); - (void)get_def_use_mgr()->ReplaceAllUsesWith(loadId, replId); + (void)context()->ReplaceAllUsesWith(loadId, replId); DCEInst(loadInst); } @@ -633,8 +633,8 @@ Pass::Status MemPass::InsertPhiInstructions(ir::Function* func) { // replacing to prevent incorrect replacement in those instructions. const uint32_t loadId = ii->result_id(); KillNamesAndDecorates(loadId); - (void)get_def_use_mgr()->ReplaceAllUsesWith(loadId, replId); - get_def_use_mgr()->KillInst(&*ii); + (void)context()->ReplaceAllUsesWith(loadId, replId); + context()->KillInst(&*ii); } break; default: { } break; } } @@ -767,14 +767,14 @@ void MemPass::RemoveBlock(ir::Function::iterator* bi) { // removal of phi operands. if (inst != rm_block.GetLabelInst()) { KillNamesAndDecorates(inst); - get_def_use_mgr()->KillInst(inst); + context()->KillInst(inst); } }); // Remove the label instruction last. auto label = rm_block.GetLabelInst(); KillNamesAndDecorates(label); - get_def_use_mgr()->KillInst(label); + context()->KillInst(label); *bi = bi->Erase(); } diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 804314011..53692d48e 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -105,7 +105,7 @@ bool Optimizer::Run(const uint32_t* original_binary, BuildModule(impl_->target_env, impl_->pass_manager.consumer(), original_binary, original_binary_size); if (module == nullptr) return false; - ir::IRContext context(std::move(module)); + ir::IRContext context(std::move(module), impl_->pass_manager.consumer()); auto status = impl_->pass_manager.Run(&context); if (status == opt::Pass::Status::SuccessWithChange || diff --git a/source/opt/pass.cpp b/source/opt/pass.cpp index d0af4af6c..96e2ab182 100644 --- a/source/opt/pass.cpp +++ b/source/opt/pass.cpp @@ -28,11 +28,7 @@ const uint32_t kTypePointerTypeIdInIdx = 1; } // namespace -Pass::Pass() - : consumer_(nullptr), - def_use_mgr_(nullptr), - next_id_(0), - context_(nullptr) {} +Pass::Pass() : consumer_(nullptr), next_id_(0), context_(nullptr) {} void Pass::AddCalls(ir::Function* func, std::queue* todo) { for (auto bi = func->begin(); bi != func->end(); ++bi) @@ -106,6 +102,14 @@ bool Pass::ProcessCallTreeFromRoots( return modified; } +Pass::Status Pass::Run(ir::IRContext* ctx) { + Pass::Status status = Process(ctx); + if (status == Status::SuccessWithChange) { + ctx->InvalidateAnalysesExceptFor(GetPreservedAnalyses()); + } + return status; +} + uint32_t Pass::GetPointeeTypeId(const ir::Instruction* ptrInst) const { const uint32_t ptrTypeId = ptrInst->type_id(); const ir::Instruction* ptrTypeInst = get_def_use_mgr()->GetDef(ptrTypeId); @@ -114,4 +118,3 @@ uint32_t Pass::GetPointeeTypeId(const ir::Instruction* ptrInst) const { } // namespace opt } // namespace spvtools - diff --git a/source/opt/pass.h b/source/opt/pass.h index 427a4f69b..9b231d6ca 100644 --- a/source/opt/pass.h +++ b/source/opt/pass.h @@ -76,7 +76,7 @@ class Pass { // Returns the def-use manager used for this pass. TODO(dnovillo): This should // be handled by the pass manager. analysis::DefUseManager* get_def_use_mgr() const { - return def_use_mgr_.get(); + return context()->get_def_use_mgr(); } // Returns a pointer to the current module for this pass. @@ -111,10 +111,18 @@ class Pass { const std::unordered_map& id2function, std::queue* roots); - // Processes the given |module|. Returns Status::Failure if errors occur when + + // Run the pass on the given |module|. Returns Status::Failure if errors occur when // processing. Returns the corresponding Status::Success if processing is - // succesful to indicate whether changes are made to the module. - virtual Status Process(ir::IRContext* context) = 0; + // successful to indicate whether changes are made to the module. If there + // were any changes it will also invalidate the analyses in the IRContext + // that are not preserved. + virtual Status Run(ir::IRContext* ctx) final; + + // Returns the set of analyses that the pass is guaranteed to preserve. + virtual ir::IRContext::Analysis GetPreservedAnalyses() { + return ir::IRContext::kAnalysisNone; + } protected: // Initialize basic data structures for the pass. This sets up the def-use @@ -124,10 +132,14 @@ class Pass { virtual void InitializeProcessing(ir::IRContext* c) { context_ = c; next_id_ = context_->IdBound(); - def_use_mgr_.reset(new analysis::DefUseManager(consumer(), get_module())); cfg_.reset(new ir::CFG(get_module())); } + // Processes the given |module|. Returns Status::Failure if errors occur when + // processing. Returns the corresponding Status::Success if processing is + // succesful to indicate whether changes are made to the module. + virtual Status Process(ir::IRContext* context) = 0; + // Return type id for |ptrInst|'s pointee uint32_t GetPointeeTypeId(const ir::Instruction* ptrInst) const; @@ -142,9 +154,6 @@ class Pass { private: MessageConsumer consumer_; // Message consumer. - // Def-Uses for the module we are processing - std::unique_ptr def_use_mgr_; - // Next unused ID uint32_t next_id_; diff --git a/source/opt/pass_manager.cpp b/source/opt/pass_manager.cpp index 4260a78f7..ce6256e69 100644 --- a/source/opt/pass_manager.cpp +++ b/source/opt/pass_manager.cpp @@ -21,7 +21,7 @@ namespace opt { Pass::Status PassManager::Run(ir::IRContext* context) { auto status = Pass::Status::SuccessWithoutChange; for (const auto& pass : passes_) { - const auto one_status = pass->Process(context); + const auto one_status = pass->Run(context); if (one_status == Pass::Status::Failure) return one_status; if (one_status == Pass::Status::SuccessWithChange) status = one_status; } diff --git a/source/opt/remove_duplicates_pass.cpp b/source/opt/remove_duplicates_pass.cpp index 6b78f3eb0..3ee78f4d4 100644 --- a/source/opt/remove_duplicates_pass.cpp +++ b/source/opt/remove_duplicates_pass.cpp @@ -36,12 +36,11 @@ using opt::analysis::DefUseManager; using opt::analysis::DecorationManager; Pass::Status RemoveDuplicatesPass::Process(ir::IRContext* irContext) { - DefUseManager defUseManager(consumer(), irContext->module()); DecorationManager decManager(irContext->module()); bool modified = RemoveDuplicateCapabilities(irContext); - modified |= RemoveDuplicatesExtInstImports(irContext, defUseManager); - modified |= RemoveDuplicateTypes(irContext, defUseManager, decManager); + modified |= RemoveDuplicatesExtInstImports(irContext); + modified |= RemoveDuplicateTypes(irContext, decManager); modified |= RemoveDuplicateDecorations(irContext); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; @@ -67,8 +66,8 @@ bool RemoveDuplicatesPass::RemoveDuplicateCapabilities(ir::IRContext* irContext) return modified; } -bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports( - ir::IRContext* irContext, analysis::DefUseManager& defUseManager) const { +bool +RemoveDuplicatesPass::RemoveDuplicatesExtInstImports(ir::IRContext* irContext) const { bool modified = false; std::unordered_map extInstImports; @@ -82,7 +81,7 @@ bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports( ++i; } else { // It's a duplicate, remove it. - defUseManager.ReplaceAllUsesWith(i->result_id(), res.first->second); + irContext->ReplaceAllUsesWith(i->result_id(), res.first->second); i = i.Erase(); modified = true; } @@ -91,9 +90,8 @@ bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports( return modified; } -bool RemoveDuplicatesPass::RemoveDuplicateTypes( - ir::IRContext* irContext, DefUseManager& defUseManager, - DecorationManager& decManager) const { +bool RemoveDuplicatesPass::RemoveDuplicateTypes(ir::IRContext* irContext, + DecorationManager& decManager) const { bool modified = false; std::vector visitedTypes; @@ -110,7 +108,7 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes( // Is the current type equal to one of the types we have aready visited? SpvId idToKeep = 0u; for (auto j : visitedTypes) { - if (AreTypesEqual(*i, j, defUseManager, decManager)) { + if (AreTypesEqual(*i, j, *irContext->get_def_use_mgr(), decManager)) { idToKeep = j.result_id(); break; } @@ -122,7 +120,7 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes( ++i; } else { // The same type has already been seen before, remove this one. - defUseManager.ReplaceAllUsesWith(i->result_id(), idToKeep); + irContext->ReplaceAllUsesWith(i->result_id(), idToKeep); modified = true; i = i.Erase(); } diff --git a/source/opt/remove_duplicates_pass.h b/source/opt/remove_duplicates_pass.h index 02570c67e..82e5f1db9 100644 --- a/source/opt/remove_duplicates_pass.h +++ b/source/opt/remove_duplicates_pass.h @@ -42,11 +42,9 @@ class RemoveDuplicatesPass : public Pass { private: bool RemoveDuplicateCapabilities(ir::IRContext* irContext) const; - bool RemoveDuplicatesExtInstImports( - ir::IRContext* irContext, analysis::DefUseManager& defUseManager) const; + bool RemoveDuplicatesExtInstImports(ir::IRContext* irContext) const; bool RemoveDuplicateTypes(ir::IRContext* irContext, - analysis::DefUseManager& defUseManager, - analysis::DecorationManager& decManager) const; + analysis::DecorationManager& decManager) const; bool RemoveDuplicateDecorations(ir::IRContext* irContext) const; }; diff --git a/source/opt/set_spec_constant_default_value_pass.cpp b/source/opt/set_spec_constant_default_value_pass.cpp index 422051c8b..d78a4138f 100644 --- a/source/opt/set_spec_constant_default_value_pass.cpp +++ b/source/opt/set_spec_constant_default_value_pass.cpp @@ -202,7 +202,7 @@ Pass::Status SetSpecConstantDefaultValuePass::Process(ir::IRContext* irContext) const uint32_t kOpSpecConstantLiteralInOperandIndex = 0; bool modified = false; - analysis::DefUseManager def_use_mgr(consumer(), irContext->module()); + analysis::DefUseManager def_use_mgr(irContext->module()); analysis::TypeManager type_mgr(consumer(), *irContext->module()); // Scan through all the annotation instructions to find 'OpDecorate SpecId' // instructions. Then extract the decoration target of those instructions. diff --git a/source/opt/strength_reduction_pass.cpp b/source/opt/strength_reduction_pass.cpp index 28c80fdf7..689034aca 100644 --- a/source/opt/strength_reduction_pass.cpp +++ b/source/opt/strength_reduction_pass.cpp @@ -107,10 +107,10 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2( get_def_use_mgr()->AnalyzeInstDefUse(&*newInstruction); inst = inst.InsertBefore(std::move(newInstruction)); ++inst; - get_def_use_mgr()->ReplaceAllUsesWith(inst->result_id(), newResultId); + context()->ReplaceAllUsesWith(inst->result_id(), newResultId); // Remove the old instruction. - get_def_use_mgr()->KillInst(&*inst); + context()->KillInst(&*inst); // We do not want to replace the instruction twice if both operands // are constants that are a power of 2. So we break here. diff --git a/source/opt/unify_const_pass.cpp b/source/opt/unify_const_pass.cpp index 93ae273b2..56ed894a3 100644 --- a/source/opt/unify_const_pass.cpp +++ b/source/opt/unify_const_pass.cpp @@ -107,12 +107,11 @@ Pass::Status UnifyConstantPass::Process(ir::IRContext* c) { InitializeProcessing(c); bool modified = false; ResultIdTrie defined_constants; - analysis::DefUseManager def_use_mgr(consumer(), get_module()); for (ir::Instruction& inst : context()->types_values()) { // Do not handle the instruction when there are decorations upon the result // id. - if (def_use_mgr.GetAnnotations(inst.result_id()).size() != 0) { + if (get_def_use_mgr()->GetAnnotations(inst.result_id()).size() != 0) { continue; } @@ -156,8 +155,8 @@ Pass::Status UnifyConstantPass::Process(ir::IRContext* c) { if (id != inst.result_id()) { // The constant is a duplicated one, use the cached constant to // replace the uses of this duplicated one, then turn it to nop. - def_use_mgr.ReplaceAllUsesWith(inst.result_id(), id); - def_use_mgr.KillInst(&inst); + context()->ReplaceAllUsesWith(inst.result_id(), id); + context()->KillInst(&inst); modified = true; } break; diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index f19ff7f5f..ecff063f5 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -198,3 +198,9 @@ add_spvtools_unittest(TARGET cfg_cleanup SRCS cfg_cleanup_test.cpp pass_utils.cpp LIBS SPIRV-Tools-opt ) + +add_spvtools_unittest(TARGET ir_context + SRCS ir_context_test.cpp pass_utils.cpp + LIBS SPIRV-Tools-opt +) + diff --git a/test/opt/def_use_test.cpp b/test/opt/def_use_test.cpp index 0ce7b4ca0..88483a752 100644 --- a/test/opt/def_use_test.cpp +++ b/test/opt/def_use_test.cpp @@ -20,12 +20,14 @@ #include "opt/build_module.h" #include "opt/def_use_manager.h" +#include "opt/ir_context.h" #include "pass_utils.h" #include "spirv-tools/libspirv.hpp" namespace { -using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; +using ::testing::UnorderedElementsAreArray; using namespace spvtools; using spvtools::opt::analysis::DefUseManager; @@ -49,8 +51,8 @@ std::string DisassembleInst(ir::Instruction* inst) { // A struct for holding expected id defs and uses. struct InstDefUse { - using IdInstPair = std::pair; - using IdInstsPair = std::pair>; + using IdInstPair = std::pair; + using IdInstsPair = std::pair>; // Ids and their corresponding def instructions. std::vector defs; @@ -86,12 +88,13 @@ void CheckUse(const InstDefUse& expected_defs_uses, ASSERT_EQ(expected_uses.size(), uses.size()) << "id [" << id << "] # uses: expected: " << expected_uses.size() << " actual: " << uses.size(); - auto it = uses.cbegin(); - for (const auto expected_use : expected_uses) { - EXPECT_EQ(expected_use, DisassembleInst(it->inst)) - << "id [" << id << "] use instruction mismatch"; - ++it; + + std::vector actual_uses_disassembled; + for (const auto actual_use : uses) { + actual_uses_disassembled.emplace_back(DisassembleInst(actual_use.inst)); } + EXPECT_THAT(actual_uses_disassembled, + UnorderedElementsAreArray(expected_uses)); } } @@ -133,7 +136,7 @@ TEST_P(ParseDefUseTest, Case) { ASSERT_NE(nullptr, module); // Analyze def and use. - opt::analysis::DefUseManager manager(nullptr, module.get()); + opt::analysis::DefUseManager manager(module.get()); CheckDef(tc.du, manager.id_to_defs()); CheckUse(tc.du, manager.id_to_uses()); @@ -512,18 +515,19 @@ TEST_P(ReplaceUseTest, Case) { std::unique_ptr module = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text)); ASSERT_NE(nullptr, module); + ir::IRContext context(std::move(module), spvtools::MessageConsumer()); // Analyze def and use. - opt::analysis::DefUseManager manager(nullptr, module.get()); + context.BuildDefUseManager(); // Do the substitution. for (const auto& candiate : tc.candidates) { - manager.ReplaceAllUsesWith(candiate.first, candiate.second); + context.ReplaceAllUsesWith(candiate.first, candiate.second); } - EXPECT_EQ(tc.after, DisassembleModule(module.get())); - CheckDef(tc.du, manager.id_to_defs()); - CheckUse(tc.du, manager.id_to_uses()); + EXPECT_EQ(tc.after, DisassembleModule(context.module())); + CheckDef(tc.du, context.get_def_use_mgr()->id_to_defs()); + CheckUse(tc.du, context.get_def_use_mgr()->id_to_uses()); } // clang-format off @@ -814,16 +818,17 @@ TEST_P(KillDefTest, Case) { std::unique_ptr module = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text)); ASSERT_NE(nullptr, module); + ir::IRContext context(std::move(module), spvtools::MessageConsumer()); // Analyze def and use. - opt::analysis::DefUseManager manager(nullptr, module.get()); + opt::analysis::DefUseManager manager(module.get()); // Do the substitution. - for (const auto id : tc.ids_to_kill) manager.KillDef(id); + for (const auto id : tc.ids_to_kill) context.KillDef(id); - EXPECT_EQ(tc.after, DisassembleModule(module.get())); - CheckDef(tc.du, manager.id_to_defs()); - CheckUse(tc.du, manager.id_to_uses()); + EXPECT_EQ(tc.after, DisassembleModule(context.module())); + CheckDef(tc.du, context.get_def_use_mgr()->id_to_defs()); + CheckUse(tc.du, context.get_def_use_mgr()->id_to_uses()); } // clang-format off @@ -1064,14 +1069,15 @@ TEST(DefUseTest, OpSwitch) { std::unique_ptr module = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original_text); ASSERT_NE(nullptr, module); + ir::IRContext context(std::move(module), spvtools::MessageConsumer()); // Analyze def and use. - opt::analysis::DefUseManager manager(nullptr, module.get()); + context.BuildDefUseManager(); // Do a bunch replacements. - manager.ReplaceAllUsesWith(9, 900); // to unused id - manager.ReplaceAllUsesWith(10, 1000); // to unused id - manager.ReplaceAllUsesWith(11, 7); // to existing id + context.ReplaceAllUsesWith(9, 900); // to unused id + context.ReplaceAllUsesWith(10, 1000); // to unused id + context.ReplaceAllUsesWith(11, 7); // to existing id // clang-format off const char modified_text[] = @@ -1095,7 +1101,7 @@ TEST(DefUseTest, OpSwitch) { "OpFunctionEnd"; // clang-format on - EXPECT_EQ(modified_text, DisassembleModule(module.get())); + EXPECT_EQ(modified_text, DisassembleModule(context.module())); InstDefUse def_uses = {}; def_uses.defs = { @@ -1110,17 +1116,18 @@ TEST(DefUseTest, OpSwitch) { {10, "%10 = OpLabel"}, {11, "%11 = OpLabel"}, }; - CheckDef(def_uses, manager.id_to_defs()); + CheckDef(def_uses, context.get_def_use_mgr()->id_to_defs()); { - auto* use_list = manager.GetUses(6); + auto* use_list = context.get_def_use_mgr()->GetUses(6); ASSERT_NE(nullptr, use_list); EXPECT_EQ(2u, use_list->size()); - EXPECT_EQ(SpvOpSwitch, use_list->front().inst->opcode()); - EXPECT_EQ(SpvOpReturnValue, use_list->back().inst->opcode()); + std::vector opcodes = {use_list->front().inst->opcode(), + use_list->back().inst->opcode()}; + EXPECT_THAT(opcodes, UnorderedElementsAre(SpvOpSwitch, SpvOpReturnValue)); } { - auto* use_list = manager.GetUses(7); + auto* use_list = context.get_def_use_mgr()->GetUses(7); ASSERT_NE(nullptr, use_list); EXPECT_EQ(6u, use_list->size()); std::vector opcodes; @@ -1128,13 +1135,13 @@ TEST(DefUseTest, OpSwitch) { opcodes.push_back(use.inst->opcode()); } // OpSwitch is now a user of %7. - EXPECT_THAT(opcodes, - ElementsAre(SpvOpSelectionMerge, SpvOpBranch, SpvOpBranch, - SpvOpBranch, SpvOpBranch, SpvOpSwitch)); + EXPECT_THAT(opcodes, UnorderedElementsAre(SpvOpSelectionMerge, SpvOpBranch, + SpvOpBranch, SpvOpBranch, + SpvOpBranch, SpvOpSwitch)); } // Check all ids only used by OpSwitch after replacement. for (const auto id : {8, 900, 1000}) { - auto* use_list = manager.GetUses(id); + auto* use_list = context.get_def_use_mgr()->GetUses(id); ASSERT_NE(nullptr, use_list); EXPECT_EQ(1u, use_list->size()); EXPECT_EQ(SpvOpSwitch, use_list->front().inst->opcode()); @@ -1189,7 +1196,7 @@ TEST_P(AnalyzeInstDefUseTest, Case) { ASSERT_NE(nullptr, module); // Analyze the instructions. - opt::analysis::DefUseManager manager(nullptr, module.get()); + opt::analysis::DefUseManager manager(module.get()); for (ir::Instruction& inst : tc.insts) { manager.AnalyzeInstDefUse(&inst); } @@ -1310,20 +1317,21 @@ TEST_P(KillInstTest, Case) { std::unique_ptr module = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.before); ASSERT_NE(nullptr, module); + ir::IRContext context(std::move(module), spvtools::MessageConsumer()); + context.BuildDefUseManager(); // KillInst uint32_t index = 0; - opt::analysis::DefUseManager manager(nullptr, module.get()); - module->ForEachInst([&index, &tc, &manager](ir::Instruction* inst) { + context.module()->ForEachInst([&index, &tc, &context](ir::Instruction* inst) { if (tc.indices_for_inst_to_kill.count(index) != 0) { - manager.KillInst(inst); + context.KillInst(inst); } index++; }); - EXPECT_EQ(tc.after, DisassembleModule(module.get())); - CheckDef(tc.expected_define_use, manager.id_to_defs()); - CheckUse(tc.expected_define_use, manager.id_to_uses()); + EXPECT_EQ(tc.after, DisassembleModule(context.module())); + CheckDef(tc.expected_define_use, context.get_def_use_mgr()->id_to_defs()); + CheckUse(tc.expected_define_use, context.get_def_use_mgr()->id_to_uses()); } // clang-format off @@ -1420,7 +1428,7 @@ TEST_P(GetAnnotationsTest, Case) { ASSERT_NE(nullptr, module); // Get annotations - opt::analysis::DefUseManager manager(nullptr, module.get()); + opt::analysis::DefUseManager manager(module.get()); auto insts = manager.GetAnnotations(tc.id); // Check diff --git a/test/opt/ir_context_test.cpp b/test/opt/ir_context_test.cpp new file mode 100644 index 000000000..457a7c563 --- /dev/null +++ b/test/opt/ir_context_test.cpp @@ -0,0 +1,156 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "opt/ir_context.h" +#include "opt/pass.h" +#include "pass_fixture.h" +#include "pass_utils.h" + +namespace { + +using namespace spvtools; +using ir::IRContext; +using Analysis = IRContext::Analysis; + +class DummyPassPreservesNothing : public opt::Pass { + public: + DummyPassPreservesNothing(Status s) : opt::Pass(), status_to_return_(s) {} + const char* name() const override { return "dummy-pass"; } + Status Process(IRContext*) override { return status_to_return_; } + Status status_to_return_; +}; + +class DummyPassPreservesAll : public opt::Pass { + public: + DummyPassPreservesAll(Status s) : opt::Pass(), status_to_return_(s) {} + const char* name() const override { return "dummy-pass"; } + Status Process(IRContext*) override { return status_to_return_; } + Status status_to_return_; + virtual Analysis GetPreservedAnalyses() override { + return Analysis(IRContext::kAnalysisEnd - 1); + } +}; + +class DummyPassPreservesFirst : public opt::Pass { + public: + DummyPassPreservesFirst(Status s) : opt::Pass(), status_to_return_(s) {} + const char* name() const override { return "dummy-pass"; } + Status Process(IRContext*) override { return status_to_return_; } + Status status_to_return_; + virtual Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisBegin; + } +}; + +using IRContextTest = PassTest<::testing::Test>; + +TEST_F(IRContextTest, IndividualValidAfterBuild) { + std::unique_ptr module(new ir::Module()); + IRContext context(std::move(module), spvtools::MessageConsumer()); + + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + context.BuildInvalidAnalyses(i); + EXPECT_TRUE(context.AreAnalysesValid(i)); + } +} + +TEST_F(IRContextTest, AllValidAfterBuild) { + std::unique_ptr module = MakeUnique(); + IRContext context(std::move(module), spvtools::MessageConsumer()); + + Analysis built_analyses = IRContext::kAnalysisNone; + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + context.BuildInvalidAnalyses(i); + built_analyses |= i; + } + EXPECT_TRUE(context.AreAnalysesValid(built_analyses)); +} + +TEST_F(IRContextTest, AllValidAfterPassNoChange) { + std::unique_ptr module = MakeUnique(); + IRContext context(std::move(module), spvtools::MessageConsumer()); + + Analysis built_analyses = IRContext::kAnalysisNone; + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + context.BuildInvalidAnalyses(i); + built_analyses |= i; + } + + DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithoutChange); + opt::Pass::Status s = pass.Run(&context); + EXPECT_EQ(s, opt::Pass::Status::SuccessWithoutChange); + EXPECT_TRUE(context.AreAnalysesValid(built_analyses)); +} + +TEST_F(IRContextTest, NoneValidAfterPassWithChange) { + std::unique_ptr module = MakeUnique(); + IRContext context(std::move(module), spvtools::MessageConsumer()); + + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + context.BuildInvalidAnalyses(i); + } + + DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithChange); + opt::Pass::Status s = pass.Run(&context); + EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + EXPECT_FALSE(context.AreAnalysesValid(i)); + } +} + +TEST_F(IRContextTest, AllPreservedAfterPassWithChange) { + std::unique_ptr module = MakeUnique(); + IRContext context(std::move(module), spvtools::MessageConsumer()); + + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + context.BuildInvalidAnalyses(i); + } + + DummyPassPreservesAll pass(opt::Pass::Status::SuccessWithChange); + opt::Pass::Status s = pass.Run(&context); + EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + EXPECT_TRUE(context.AreAnalysesValid(i)); + } +} + +TEST_F(IRContextTest, AllPreserveFirstOnlyAfterPassWithChange) { + std::unique_ptr module = MakeUnique(); + IRContext context(std::move(module), spvtools::MessageConsumer()); + + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + context.BuildInvalidAnalyses(i); + } + + DummyPassPreservesAll pass(opt::Pass::Status::SuccessWithChange); + opt::Pass::Status s = pass.Run(&context); + EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); + EXPECT_TRUE(context.AreAnalysesValid(IRContext::kAnalysisBegin)); + for (Analysis i = IRContext::kAnalysisBegin << 1; i < IRContext::kAnalysisEnd; + i <<= 1) { + EXPECT_FALSE(context.AreAnalysesValid(i)); + } +} +} // anonymous namespace diff --git a/test/opt/pass_fixture.h b/test/opt/pass_fixture.h index bf5964861..93ef916a6 100644 --- a/test/opt/pass_fixture.h +++ b/test/opt/pass_fixture.h @@ -61,9 +61,9 @@ class PassTest : public TestT { opt::Pass::Status::Failure); } - ir::IRContext context(std::move(module)); + ir::IRContext context(std::move(module), consumer()); - const auto status = pass->Process(&context); + const auto status = pass->Run(&context); std::vector binary; context.module()->ToBinary(&binary, skip_nop); @@ -171,7 +171,7 @@ class PassTest : public TestT { std::unique_ptr module = BuildModule( SPV_ENV_UNIVERSAL_1_1, nullptr, original, assemble_options_); ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module)); + ir::IRContext context(std::move(module), consumer()); manager_->Run(&context); @@ -192,6 +192,7 @@ class PassTest : public TestT { disassemble_options_ = disassemble_options; } + MessageConsumer consumer() { return consumer_;} private: MessageConsumer consumer_; // Message consumer. SpirvTools tools_; // An instance for calling SPIRV-Tools functionalities. diff --git a/test/opt/pass_manager_test.cpp b/test/opt/pass_manager_test.cpp index d54e74533..43d700557 100644 --- a/test/opt/pass_manager_test.cpp +++ b/test/opt/pass_manager_test.cpp @@ -151,12 +151,12 @@ class AppendTypeVoidInstPass : public opt::Pass { }; TEST(PassManager, RecomputeIdBoundAutomatically) { + opt::PassManager manager; std::unique_ptr module(new ir::Module()); - ir::IRContext context(std::move(module)); + ir::IRContext context(std::move(module), manager.consumer()); EXPECT_THAT(GetIdBound(*context.module()), Eq(0u)); - opt::PassManager manager; manager.Run(&context); manager.AddPass(); // With no ID changes, the ID bound does not change. diff --git a/test/opt/pass_test.cpp b/test/opt/pass_test.cpp index 012ea7239..8c62b289b 100644 --- a/test/opt/pass_test.cpp +++ b/test/opt/pass_test.cpp @@ -137,7 +137,7 @@ TEST_F(PassClassTest, BasicVisitReachable) { SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::IRContext context(std::move(module)); + ir::IRContext context(std::move(module), consumer()); DummyPass testPass; std::vector processed; @@ -190,7 +190,7 @@ TEST_F(PassClassTest, BasicVisitOnlyOnce) { SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::IRContext context(std::move(module)); + ir::IRContext context(std::move(module), consumer()); DummyPass testPass; std::vector processed; @@ -233,7 +233,7 @@ TEST_F(PassClassTest, BasicDontVisitExportedVariable) { SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::IRContext context(std::move(module)); + ir::IRContext context(std::move(module), consumer()); DummyPass testPass; std::vector processed;