diff --git a/Android.mk b/Android.mk index 3a59c2076..f0ad67c65 100644 --- a/Android.mk +++ b/Android.mk @@ -111,6 +111,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/replace_invalid_opc.cpp \ source/opt/scalar_replacement_pass.cpp \ source/opt/set_spec_constant_default_value_pass.cpp \ + source/opt/simplification_pass.cpp \ source/opt/strength_reduction_pass.cpp \ source/opt/strip_debug_info_pass.cpp \ source/opt/type_manager.cpp \ diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index b866bc69b..becd024ec 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -504,6 +504,9 @@ Optimizer::PassToken CreateIfConversionPass(); // current shader stage by constants. Has no effect on non-shader modules. Optimizer::PassToken CreateReplaceInvalidOpcodePass(); +// Creates a pass that simplifies instructions using the instruction folder. +Optimizer::PassToken CreateSimplificationPass(); + } // namespace spvtools #endif // SPIRV_TOOLS_OPTIMIZER_HPP_ diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index c02086cc3..f43ac7779 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -70,6 +70,7 @@ add_library(SPIRV-Tools-opt replace_invalid_opc.h scalar_replacement_pass.h set_spec_constant_default_value_pass.h + simplification_pass.h strength_reduction_pass.h strip_debug_info_pass.h tree_iterator.h @@ -135,6 +136,7 @@ add_library(SPIRV-Tools-opt replace_invalid_opc.cpp scalar_replacement_pass.cpp set_spec_constant_default_value_pass.cpp + simplification_pass.cpp strength_reduction_pass.cpp strip_debug_info_pass.cpp type_manager.cpp diff --git a/source/opt/cfg.cpp b/source/opt/cfg.cpp index c2dd235b2..404d19249 100644 --- a/source/opt/cfg.cpp +++ b/source/opt/cfg.cpp @@ -87,6 +87,19 @@ void CFG::ComputeStructuredOrder(ir::Function* func, ir::BasicBlock* root, root, get_structured_successors, ignore_block, post_order, ignore_edge); } +void CFG::ForEachBlockInReversePostOrder( + 
BasicBlock* bb, const std::function& f) { + std::vector po; + std::unordered_set seen; + ComputePostOrderTraversal(bb, &po, &seen); + + for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) { + if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) { + f(*current_bb); + } + } +} + void CFG::ComputeStructuredSuccessors(ir::Function* func) { block2structured_succs_.clear(); for (auto& blk : *func) { @@ -111,5 +124,18 @@ void CFG::ComputeStructuredSuccessors(ir::Function* func) { } } +void CFG::ComputePostOrderTraversal(BasicBlock* bb, vector* order, + unordered_set* seen) { + seen->insert(bb); + static_cast(bb)->ForEachSuccessorLabel( + [&order, &seen, this](const uint32_t sbid) { + BasicBlock* succ_bb = id2block_[sbid]; + if (!seen->count(succ_bb)) { + ComputePostOrderTraversal(succ_bb, order, seen); + } + }); + order->push_back(bb); +} + } // namespace ir } // namespace spvtools diff --git a/source/opt/cfg.h b/source/opt/cfg.h index 138aa0a6b..53dddd234 100644 --- a/source/opt/cfg.h +++ b/source/opt/cfg.h @@ -19,6 +19,7 @@ #include #include +#include namespace spvtools { namespace ir { @@ -68,6 +69,12 @@ class CFG { void ComputeStructuredOrder(ir::Function* func, ir::BasicBlock* root, std::list* order); + // Applies |f| to the basic block in reverse post order starting with |bb|. + // Note that basic blocks that cannot be reached from |bb| node will not be + // processed. + void ForEachBlockInReversePostOrder( + BasicBlock* bb, const std::function& f); + // Registers |blk| as a basic block in the cfg, this also updates the // predecessor lists of each successor of |blk|. void RegisterBlock(ir::BasicBlock* blk) { @@ -101,6 +108,13 @@ class CFG { // ignored by DFS. void ComputeStructuredSuccessors(ir::Function* func); + // Computes the post-order traversal of the cfg starting at |bb| skipping + // nodes in |seen|. The order of the traversal is appended to |order|, and + // all nodes in the traversal are added to |seen|. 
+ void ComputePostOrderTraversal(BasicBlock* bb, + std::vector* order, + std::unordered_set* seen); + // Module for this CFG. ir::Module* module_; diff --git a/source/opt/dead_insert_elim_pass.h b/source/opt/dead_insert_elim_pass.h index df13c0cc9..5ce2aefae 100644 --- a/source/opt/dead_insert_elim_pass.h +++ b/source/opt/dead_insert_elim_pass.h @@ -36,7 +36,7 @@ namespace opt { class DeadInsertElimPass : public MemPass { public: DeadInsertElimPass(); - const char* name() const override { return "eliminate-dead-insert"; } + const char* name() const override { return "eliminate-dead-inserts"; } Status Process(ir::IRContext*) override; private: diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp index a6c93f89c..04c04469b 100644 --- a/source/opt/fold.cpp +++ b/source/opt/fold.cpp @@ -182,10 +182,10 @@ uint32_t OperateWords(SpvOp opcode, } } -bool FoldInstructionInternal(ir::Instruction* inst, - std::function id_map) { +bool FoldInstructionInternal(ir::Instruction* inst) { ir::IRContext* context = inst->context(); - ir::Instruction* folded_inst = FoldInstructionToConstant(inst, id_map); + auto identity_map = [](uint32_t id) { return id; }; + ir::Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map); if (folded_inst != nullptr) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}}); @@ -201,8 +201,7 @@ bool FoldInstructionInternal(ir::Instruction* inst, if (operand->type != SPV_OPERAND_TYPE_ID) { constants.push_back(nullptr); } else { - uint32_t id = id_map(operand->words[0]); - inst->SetInOperand(i, {id}); + uint32_t id = operand->words[0]; const analysis::Constant* constant = const_manger->FindDeclaredConstant(id); constants.push_back(constant); @@ -660,29 +659,13 @@ bool IsFoldableType(ir::Instruction* type_inst) { return false; } -ir::Instruction* FoldInstruction(ir::Instruction* inst, - std::function id_map) { - ir::IRContext* context = inst->context(); +bool 
FoldInstruction(ir::Instruction* inst) { bool modified = false; - std::unique_ptr folded_inst(inst->Clone(context)); - while (FoldInstructionInternal(&*folded_inst, id_map)) { + ir::Instruction* folded_inst(inst); + while (FoldInstructionInternal(&*folded_inst)) { modified = true; } - - if (modified) { - if (folded_inst->opcode() == SpvOpCopyObject) { - analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - return def_use_mgr->GetDef(folded_inst->GetSingleWordInOperand(0)); - } else { - InstructionBuilder ir_builder( - context, inst, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); - folded_inst->SetResultId(context->TakeNextId()); - return ir_builder.AddInstruction(std::move(folded_inst)); - } - } - return nullptr; + return modified; } } // namespace opt diff --git a/source/opt/fold.h b/source/opt/fold.h index 0438ccfe7..439ed2b6a 100644 --- a/source/opt/fold.h +++ b/source/opt/fold.h @@ -15,12 +15,12 @@ #ifndef LIBSPIRV_UTIL_FOLD_H_ #define LIBSPIRV_UTIL_FOLD_H_ -#include "constants.h" -#include "def_use_manager.h" - #include #include +#include "constants.h" +#include "def_use_manager.h" + namespace spvtools { namespace opt { @@ -75,26 +75,18 @@ bool IsFoldableType(ir::Instruction* type_inst); ir::Instruction* FoldInstructionToConstant( ir::Instruction* inst, std::function id_map); -// Tries to fold |inst| to a simpler instruction that computes the same value, -// when the input ids to |inst| have been substituted using |id_map|. Returns a -// pointer to the simplified instruction if successful. If necessary, a new -// instruction is created and placed in the global values section, for -// constants, or after |inst| for other instructions. +// Returns true if |inst| can be folded into a simpler instruction. +// If |inst| can be simplified, |inst| is overwritten with the simplified +// instruction reusing the same result id. // -// |inst| must be an instruction that exists in the body of a function. 
+// If |inst| is simplified, it is possible that the resulting code is invalid +// because the instruction is in a bad location. Callers of this function have +// to handle the following cases: // -// |id_map| is a function that takes one result id and returns another. It can -// be used for things like CCP where it is known that some ids contain a -// constant, but the instruction itself has not been updated yet. This can map -// those ids to the appropriate constants. -ir::Instruction* FoldInstruction(ir::Instruction* inst, - std::function id_map); - -// The same as above when |id_map| is the identity function. -inline ir::Instruction* FoldInstruction(ir::Instruction* inst) { - auto identity_map = [](uint32_t id) { return id; }; - return FoldInstruction(inst, identity_map); -} +// 1) An OpPhi becomes an OpCopyObject - If there are OpPhi instructions after +// |inst| in a basic block then this is invalid. The caller must fix this +// up. +bool FoldInstruction(ir::Instruction* inst); } // namespace opt } // namespace spvtools diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 728ad0cd2..ed33dfd61 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -31,7 +31,7 @@ FoldingRule IntMultipleBy1() { continue; } const analysis::IntConstant* int_constant = constants[i]->AsIntConstant(); - if (int_constant->GetU32BitValue() == 1) { + if (int_constant && int_constant->GetU32BitValue() == 1) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}}); @@ -42,6 +42,142 @@ FoldingRule IntMultipleBy1() { }; } +FoldingRule CompositeConstructFeedingExtract() { + return [](ir::Instruction* inst, + const std::vector&) { + // If the input to an OpCompositeExtract is an OpCompositeConstruct, + // then we can simply use the appropriate element in the construction. + assert(inst->opcode() == SpvOpCompositeExtract && + "Wrong opcode. 
Should be OpCompositeExtract."); + analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); + uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); + ir::Instruction* cinst = def_use_mgr->GetDef(cid); + + if (cinst->opcode() != SpvOpCompositeConstruct) { + return false; + } + + std::vector operands; + analysis::Type* composite_type = type_mgr->GetType(cinst->type_id()); + if (composite_type->AsVector() == nullptr) { + // Get the element being extracted from the OpCompositeConstruct + // Since it is not a vector, it is simple to extract the single element. + uint32_t element_index = inst->GetSingleWordInOperand(1); + uint32_t element_id = cinst->GetSingleWordInOperand(element_index); + operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}}); + + // Add the remaining indices for extraction (literal indexes, not ids). + for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { + operands.push_back( + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}}); + } + + } else { + // With vectors we have to handle the case where it is concatenating + // vectors. + assert(inst->NumInOperands() == 2 && + "Expecting a vector of scalar values."); + + uint32_t element_index = inst->GetSingleWordInOperand(1); + for (uint32_t construct_index = 0; + construct_index < cinst->NumInOperands(); ++construct_index) { + uint32_t element_id = cinst->GetSingleWordInOperand(construct_index); + ir::Instruction* element_def = def_use_mgr->GetDef(element_id); + analysis::Vector* element_type = + type_mgr->GetType(element_def->type_id())->AsVector(); + if (element_type) { + uint32_t vector_size = element_type->element_count(); + if (vector_size <= element_index) { + // The element we want comes after this vector (indices are 0-based). + element_index -= vector_size; + } else { + // We want an element of this vector. 
+ operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}}); + operands.push_back( + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}}); + break; + } + } else { + if (element_index == 0) { + // This is a scalar, and this is the element we are extracting. + operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}}); + break; + } else { + // Skip over this scalar value. + --element_index; + } + } + } + } + + // If there were no extra indices, then we have the final object. No need + // to extract even more. + if (operands.size() == 1) { + inst->SetOpcode(SpvOpCopyObject); + } + + inst->SetInOperands(std::move(operands)); + return true; + }; +} + +FoldingRule CompositeExtractFeedingConstruct() { + // If the OpCompositeConstruct is simply putting back together elements that + // were extracted from the same source, we can simply reuse the source. + // + // This is a common code pattern because of the way that scalar replacement + // works. + return [](ir::Instruction* inst, + const std::vector&) { + assert(inst->opcode() == SpvOpCompositeConstruct && + "Wrong opcode. Should be OpCompositeConstruct."); + analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); + uint32_t original_id = 0; + + // Check each element to make sure they are: + // - extractions + // - extracting the same position they are inserting + // - all extract from the same id. 
+ for (uint32_t i = 0; i < inst->NumInOperands(); ++i) { + uint32_t element_id = inst->GetSingleWordInOperand(i); + ir::Instruction* element_inst = def_use_mgr->GetDef(element_id); + + if (element_inst->opcode() != SpvOpCompositeExtract) { + return false; + } + + if (element_inst->NumInOperands() != 2) { + return false; + } + + if (element_inst->GetSingleWordInOperand(1) != i) { + return false; + } + + if (i == 0) { + original_id = + element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); + } else if (original_id != element_inst->GetSingleWordInOperand( + kExtractCompositeIdInIdx)) { + return false; + } + } + + // The last check is to see that the object being extracted from is the + // correct type. + ir::Instruction* original_inst = def_use_mgr->GetDef(original_id); + if (original_inst->type_id() != inst->type_id()) { + return false; + } + + // Simplify by using the original object. + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}}); + return true; + }; +} + FoldingRule InsertFeedingExtract() { return [](ir::Instruction* inst, const std::vector&) { @@ -113,6 +249,51 @@ FoldingRule InsertFeedingExtract() { return true; }; } + +FoldingRule RedundantPhi() { + // An OpPhi instruction where all values are the same or the result of the phi + // itself, can be replaced by the value itself. + return + [](ir::Instruction* inst, const std::vector&) { + assert(inst->opcode() == SpvOpPhi && "Wrong opcode. 
Should be OpPhi."); + + ir::IRContext* context = inst->context(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + + uint32_t incoming_value = 0; + + for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) { + uint32_t op_id = inst->GetSingleWordInOperand(i); + if (op_id == inst->result_id()) { + continue; + } + + ir::Instruction* op_inst = def_use_mgr->GetDef(op_id); + if (op_inst->opcode() == SpvOpUndef) { + // TODO: We should be able to still use op_id if we know that + // the definition of op_id dominates |inst|. + return false; + } + + if (incoming_value == 0) { + incoming_value = op_id; + } else if (op_id != incoming_value) { + // Found two possible values. Can't simplify. + return false; + } + } + + if (incoming_value == 0) { + // Code looks invalid. Don't do anything. + return false; + } + + // We have a single incoming value. Simplify using that value. + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}}); + return true; + }; +} } // namespace spvtools::opt::FoldingRules::FoldingRules() { @@ -121,9 +302,14 @@ spvtools::opt::FoldingRules::FoldingRules() { // applies to the instruction, the rest of the rules will not be attempted. // Take that into consideration. 
- rules[SpvOpIMul].push_back(IntMultipleBy1()); + rules[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct()); rules[SpvOpCompositeExtract].push_back(InsertFeedingExtract()); + rules[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract()); + + rules[SpvOpIMul].push_back(IntMultipleBy1()); + + rules[SpvOpPhi].push_back(RedundantPhi()); } } // namespace opt } // namespace spvtools diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 31f42ae04..bdcc68400 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -18,6 +18,7 @@ #include "make_unique.h" #include "pass_manager.h" #include "passes.h" +#include "simplification_pass.h" namespace spvtools { @@ -103,7 +104,7 @@ Optimizer& Optimizer::RegisterLegalizationPasses() { .RegisterPass(CreateLocalMultiStoreElimPass()) // Copy propagate members. Cleans up code sequences generated by // scalar replacement. - .RegisterPass(CreateInsertExtractElimPass()) + .RegisterPass(CreateSimplificationPass()) // May need loop unrolling here see // https://github.com/Microsoft/DirectXShaderCompiler/pull/930 .RegisterPass(CreateDeadBranchElimPass()) @@ -379,4 +380,9 @@ Optimizer::PassToken CreateReplaceInvalidOpcodePass() { return MakeUnique( MakeUnique()); } + +Optimizer::PassToken CreateSimplificationPass() { + return MakeUnique( + MakeUnique()); +} } // namespace spvtools diff --git a/source/opt/simplification_pass.cpp b/source/opt/simplification_pass.cpp new file mode 100644 index 000000000..47706e8f8 --- /dev/null +++ b/source/opt/simplification_pass.cpp @@ -0,0 +1,114 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "simplification_pass.h" + +#include +#include +#include + +#include "fold.h" + +namespace spvtools { +namespace opt { + +Pass::Status SimplificationPass::Process(ir::IRContext* c) { + InitializeProcessing(c); + bool modified = false; + + for (ir::Function& function : *get_module()) { + modified |= SimplifyFunction(&function); + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool SimplificationPass::SimplifyFunction(ir::Function* function) { + bool modified = false; + // Phase 1: Traverse all instructions in dominance order. + // The second phase will only be on the instructions whose inputs have changed + // after being processed during phase 1. Since OpPhi instructions are the + // only instructions whose inputs do not necessarily dominate the use, we keep + // track of the OpPhi instructions already seen, and add them to the work list + // for phase 2 when needed. 
+ std::vector work_list; + std::unordered_set process_phis; + std::unordered_set inst_to_kill; + std::unordered_set in_work_list; + + cfg()->ForEachBlockInReversePostOrder( + function->entry().get(), + [&modified, &process_phis, &work_list, &in_work_list, &inst_to_kill, + this](ir::BasicBlock* bb) { + for (ir::Instruction* inst = &*bb->begin(); inst; + inst = inst->NextNode()) { + if (inst->opcode() == SpvOpPhi) { + process_phis.insert(inst); + } + + if (inst->opcode() == SpvOpCopyObject || FoldInstruction(inst)) { + modified = true; + context()->AnalyzeUses(inst); + get_def_use_mgr()->ForEachUser(inst, [&work_list, &process_phis, + &in_work_list]( + ir::Instruction* use) { + if (process_phis.count(use) && in_work_list.insert(use).second) { + work_list.push_back(use); + } + }); + if (inst->opcode() == SpvOpCopyObject) { + context()->ReplaceAllUsesWith(inst->result_id(), + inst->GetSingleWordInOperand(0)); + inst_to_kill.insert(inst); + in_work_list.insert(inst); + } + } + } + }); + + // Phase 2: process the instructions in the work list until all of the work is + // done. This time we add all users to the work list because phase 1 + // has already finished. + for (size_t i = 0; i < work_list.size(); ++i) { + ir::Instruction* inst = work_list[i]; + in_work_list.erase(inst); + if (FoldInstruction(inst)) { + modified = true; + context()->AnalyzeUses(inst); + get_def_use_mgr()->ForEachUser( + inst, [&work_list, &in_work_list](ir::Instruction* use) { + if (!use->IsDecoration() && use->opcode() != SpvOpName && + in_work_list.insert(use).second) { + work_list.push_back(use); + } + }); + + if (inst->opcode() == SpvOpCopyObject) { + context()->ReplaceAllUsesWith(inst->result_id(), + inst->GetSingleWordInOperand(0)); + inst_to_kill.insert(inst); + in_work_list.insert(inst); + } + } + } + + // Phase 3: Kill instructions we know are no longer needed. 
+ for (ir::Instruction* inst : inst_to_kill) { + context()->KillInst(inst); + } + + return modified; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/simplification_pass.h b/source/opt/simplification_pass.h new file mode 100644 index 000000000..ff0e3be15 --- /dev/null +++ b/source/opt/simplification_pass.h @@ -0,0 +1,41 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIBSPIRV_OPT_SIMPLIFICATION_PASS_H_ +#define LIBSPIRV_OPT_SIMPLIFICATION_PASS_H_ + +#include "function.h" +#include "ir_context.h" +#include "pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class SimplificationPass : public Pass { + public: + const char* name() const override { return "simplify-instructions"; } + Status Process(ir::IRContext*) override; + + private: + // Returns true if the module was changed. The simplifier is called on every + // instruction in |function| until nothing else in the function can be + // simplified. 
+ bool SimplifyFunction(ir::Function* function); +}; + +} // namespace opt +} // namespace spvtools + +#endif // LIBSPIRV_OPT_SIMPLIFICATION_PASS_H_ diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index 63266bf3c..6029e8ea3 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -291,3 +291,8 @@ add_spvtools_unittest(TARGET replace_invalid_opc SRCS replace_invalid_opc_test.cpp pass_utils.cpp LIBS SPIRV-Tools-opt ) + +add_spvtools_unittest(TARGET simplification + SRCS simplification_test.cpp pass_utils.cpp + LIBS SPIRV-Tools-opt +) diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 8b4773fb2..7d2567d7e 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -57,11 +57,13 @@ TEST_P(IntegerInstructionFoldingTest, Case) { // Fold the instruction to test. opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - inst = opt::FoldInstruction(inst); + bool succeeded = opt::FoldInstruction(inst); // Make sure the instruction folded as expected. 
- EXPECT_NE(inst, nullptr); + EXPECT_TRUE(succeeded); if (inst != nullptr) { + EXPECT_EQ(inst->opcode(), SpvOpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); EXPECT_EQ(inst->opcode(), SpvOpConstant); opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); const opt::analysis::IntConstant* result = @@ -94,7 +96,9 @@ OpName %main "main" %int = OpTypeInt 32 1 %long = OpTypeInt 64 1 %uint = OpTypeInt 32 1 +%v2int = OpTypeVector %int 2 %v4int = OpTypeVector %int 4 +%struct_v2int_int_int = OpTypeStruct %v2int %int %int %_ptr_int = OpTypePointer Function %int %_ptr_uint = OpTypePointer Function %uint %_ptr_bool = OpTypePointer Function %bool @@ -112,6 +116,7 @@ OpName %main "main" %uint_3 = OpConstant %uint 3 %uint_32 = OpConstant %uint 32 %uint_max = OpConstant %uint -1 +%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int %v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 )"; @@ -309,11 +314,13 @@ TEST_P(BooleanInstructionFoldingTest, Case) { // Fold the instruction to test. opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - inst = opt::FoldInstruction(inst); + bool succeeded = opt::FoldInstruction(inst); // Make sure the instruction folded as expected. - EXPECT_NE(inst, nullptr); + EXPECT_TRUE(succeeded); if (inst != nullptr) { + EXPECT_EQ(inst->opcode(), SpvOpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); std::vector bool_opcodes = {SpvOpConstantTrue, SpvOpConstantFalse}; EXPECT_THAT(bool_opcodes, Contains(inst->opcode())); opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); @@ -560,7 +567,7 @@ TEST_P(IntegerInstructionFoldingTestWithMap, Case) { // Fold the instruction to test. 
opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - inst = opt::FoldInstruction(inst, tc.id_map); + inst = opt::FoldInstructionToConstant(inst, tc.id_map); // Make sure the instruction folded as expected. EXPECT_NE(inst, nullptr); @@ -607,7 +614,7 @@ TEST_P(BooleanInstructionFoldingTestWithMap, Case) { // Fold the instruction to test. opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - inst = opt::FoldInstruction(inst, tc.id_map); + inst = opt::FoldInstructionToConstant(inst, tc.id_map); // Make sure the instruction folded as expected. EXPECT_NE(inst, nullptr); @@ -656,18 +663,27 @@ TEST_P(GeneralInstructionFoldingTest, Case) { // Fold the instruction to test. opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - inst = opt::FoldInstruction(inst); + std::unique_ptr original_inst(inst->Clone(context.get())); + bool succeeded = opt::FoldInstruction(inst); // Make sure the instruction folded as expected. 
- EXPECT_TRUE((inst == nullptr) == (tc.expected_result == 0)); - if (inst != nullptr) { - EXPECT_EQ(inst->result_id(), tc.expected_result); + EXPECT_EQ(inst->result_id(), original_inst->result_id()); + EXPECT_EQ(inst->type_id(), original_inst->type_id()); + EXPECT_TRUE((!succeeded) == (tc.expected_result == 0)); + if (succeeded) { + EXPECT_EQ(inst->opcode(), SpvOpCopyObject); + EXPECT_EQ(inst->GetSingleWordInOperand(0), tc.expected_result); + } else { + EXPECT_EQ(inst->NumInOperands(), original_inst->NumInOperands()); + for (uint32_t i = 0; i < inst->NumInOperands(); ++i) { + EXPECT_EQ(inst->GetOperand(i), original_inst->GetOperand(i)); + } } } // clang-format off -INSTANTIATE_TEST_CASE_P(TestCase, GeneralInstructionFoldingTest, -::testing::Values( +INSTANTIATE_TEST_CASE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTest, + ::testing::Values( // Test case 0: Don't fold n * m InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + @@ -1123,8 +1139,12 @@ INSTANTIATE_TEST_CASE_P(TestCase, GeneralInstructionFoldingTest, "%2 = OpIMul %int %3 %int_1\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 3), - // Test case 42: fold Insert feeding extract + 2, 3) +)); + +INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: fold Insert feeding extract InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -1137,7 +1157,174 @@ INSTANTIATE_TEST_CASE_P(TestCase, GeneralInstructionFoldingTest, "%7 = OpCompositeExtract %int %6 0\n" + "OpReturn\n" + "OpFunctionEnd", - 7, 2) + 7, 2), + // Test case 1: fold Composite construct feeding extract (position 0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v4int %2 %int_0 %int_0 %int_0\n" + + "%4 = OpCompositeExtract %int %3 
0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, 2), + // Test case 2: fold Composite construct feeding extract (position 3) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v4int %2 %int_0 %int_0 %100\n" + + "%4 = OpCompositeExtract %int %3 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, INT_0_ID), + // Test case 3: fold Composite construct with vectors feeding extract (scalar element) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v2int %2 %int_0\n" + + "%4 = OpCompositeConstruct %v4int %3 %int_0 %100\n" + + "%5 = OpCompositeExtract %int %4 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, INT_0_ID), + // Test case 4: fold Composite construct with vectors feeding extract (start of vector element) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v2int %2 %int_0\n" + + "%4 = OpCompositeConstruct %v4int %3 %int_0 %100\n" + + "%5 = OpCompositeExtract %int %4 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, 2), + // Test case 5: fold Composite construct with vectors feeding extract (middle of vector element) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v2int %int_0 %2\n" + + "%4 = OpCompositeConstruct %v4int %3 %int_0 %100\n" + + "%5 = OpCompositeExtract %int %4 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, 2), + // Test case 6: fold Composite construct with multiple indices. 
+ InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v2int %int_0 %2\n" + + "%4 = OpCompositeConstruct %struct_v2int_int_int %3 %int_0 %100\n" + + "%5 = OpCompositeExtract %int %4 0 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, 2) +)); + +INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: fold Extracts feeding construct + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCopyObject %v4int %v4int_0_0_0_0\n" + + "%3 = OpCompositeExtract %int %2 0\n" + + "%4 = OpCompositeExtract %int %2 1\n" + + "%5 = OpCompositeExtract %int %2 2\n" + + "%6 = OpCompositeExtract %int %2 3\n" + + "%7 = OpCompositeConstruct %v4int %3 %4 %5 %6\n" + + "OpReturn\n" + + "OpFunctionEnd", + 7, 2), + // Test case 1: Don't fold Extracts feeding construct (Different source) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCopyObject %v4int %v4int_0_0_0_0\n" + + "%3 = OpCompositeExtract %int %2 0\n" + + "%4 = OpCompositeExtract %int %2 1\n" + + "%5 = OpCompositeExtract %int %2 2\n" + + "%6 = OpCompositeExtract %int %v4int_0_0_0_0 3\n" + + "%7 = OpCompositeConstruct %v4int %3 %4 %5 %6\n" + + "OpReturn\n" + + "OpFunctionEnd", + 7, 0), + // Test case 2: Don't fold Extracts feeding construct (bad indices) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCopyObject %v4int %v4int_0_0_0_0\n" + + "%3 = OpCompositeExtract %int %2 0\n" + + "%4 = OpCompositeExtract %int %2 0\n" + + "%5 = OpCompositeExtract %int %2 2\n" + + "%6 = OpCompositeExtract %int %2 3\n" + + "%7 = OpCompositeConstruct %v4int %3 %4 %5 %6\n" + + "OpReturn\n" + + "OpFunctionEnd", + 7, 0), 
+ // Test case 3: Don't fold Extracts feeding construct (different type) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCopyObject %struct_v2int_int_int %struct_v2int_int_int_null\n" + + "%3 = OpCompositeExtract %v2int %2 0\n" + + "%4 = OpCompositeExtract %int %2 1\n" + + "%5 = OpCompositeExtract %int %2 2\n" + + "%7 = OpCompositeConstruct %v4int %3 %4 %5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 7, 0) +)); + +INSTANTIATE_TEST_CASE_P(PhiFoldingTest, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: Fold phi with the same values for all edges. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + " OpBranchConditional %true %l1 %l2\n" + + "%l1 = OpLabel\n" + + " OpBranch %merge_lab\n" + + "%l2 = OpLabel\n" + + " OpBranch %merge_lab\n" + + "%merge_lab = OpLabel\n" + + "%2 = OpPhi %int %100 %l1 %100 %l2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, INT_0_ID), + // Test case 1: Fold phi in pass through loop. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + " OpBranch %l1\n" + + "%l1 = OpLabel\n" + + "%2 = OpPhi %int %100 %main_lab %2 %l1\n" + + " OpBranchConditional %true %l1 %merge_lab\n" + + "%merge_lab = OpLabel\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, INT_0_ID), + // Test case 2: Don't Fold phi because of different values. 
+ InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + " OpBranch %l1\n" + + "%l1 = OpLabel\n" + + "%2 = OpPhi %int %int_0 %main_lab %int_3 %l1\n" + + " OpBranchConditional %true %l1 %merge_lab\n" + + "%merge_lab = OpLabel\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0) )); // clang-format off } // anonymous namespace diff --git a/test/opt/simplification_test.cpp b/test/opt/simplification_test.cpp new file mode 100644 index 000000000..76531bc19 --- /dev/null +++ b/test/opt/simplification_test.cpp @@ -0,0 +1,205 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "opt/simplification_pass.h" + +#include "assembly_builder.h" +#include "gmock/gmock.h" +#include "pass_fixture.h" + +namespace { + +using namespace spvtools; + +using SimplificationTest = PassTest<::testing::Test>; + +#ifdef SPIRV_EFFCEE +TEST_F(SimplificationTest, StraightLineTest) { + // Testing that folding rules are combined in simple straight line code. 
+ const std::string text = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %i %o + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 430 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %i "i" + OpName %o "o" + OpDecorate %i Flat + OpDecorate %i Location 0 + OpDecorate %o Location 0 + %void = OpTypeVoid + %8 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %v4int = OpTypeVector %int 4 + %int_0 = OpConstant %int 0 + %13 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 + %int_1 = OpConstant %int 1 +%_ptr_Input_v4int = OpTypePointer Input %v4int + %i = OpVariable %_ptr_Input_v4int Input +%_ptr_Output_int = OpTypePointer Output %int + %o = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %8 + %21 = OpLabel + %31 = OpCompositeInsert %v4int %int_1 %13 0 +; CHECK: [[load:%[a-zA-Z_\d]+]] = OpLoad + %23 = OpLoad %v4int %i + %33 = OpCompositeInsert %v4int %int_0 %23 0 + %35 = OpCompositeExtract %int %31 0 +; CHECK: [[extract:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 1 + %37 = OpCompositeExtract %int %33 1 +; CHECK: [[add:%[a-zA-Z_\d]+]] = OpIAdd %int %int_1 [[extract]] + %29 = OpIAdd %int %35 %37 + OpStore %o %29 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(SimplificationTest, AcrossBasicBlocks) { + // Testing that folding rules are combined across basic blocks. 
+ const std::string text = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %i %o + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 430 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %i "i" + OpName %o "o" + OpDecorate %i Flat + OpDecorate %i Location 0 + OpDecorate %o Location 0 + %void = OpTypeVoid + %8 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %v4int = OpTypeVector %int 4 + %int_0 = OpConstant %int 0 +; CHECK: [[constant:%[a-zA-Z_\d]+]] = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 + %13 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 +%_ptr_Input_v4int = OpTypePointer Input %v4int + %i = OpVariable %_ptr_Input_v4int Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_int = OpTypePointer Input %int + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_1 = OpConstant %int 1 +%_ptr_Output_int = OpTypePointer Output %int + %o = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %8 + %24 = OpLabel +; CHECK: [[load:%[a-zA-Z_\d]+]] = OpLoad %v4int %i + %25 = OpLoad %v4int %i + %41 = OpCompositeInsert %v4int %int_0 %25 0 + %27 = OpAccessChain %_ptr_Input_int %i %uint_0 + %28 = OpLoad %int %27 + %29 = OpSGreaterThan %bool %28 %int_10 + OpSelectionMerge %30 None + OpBranchConditional %29 %31 %32 + %31 = OpLabel + %43 = OpCopyObject %v4int %13 + OpBranch %30 + %32 = OpLabel + %45 = OpCopyObject %v4int %13 + OpBranch %30 + %30 = OpLabel + %50 = OpPhi %v4int %43 %31 %45 %32 +; CHECK: [[extract1:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[constant]] 0 + %47 = OpCompositeExtract %int %50 0 +; CHECK: [[extract2:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 1 + %49 = OpCompositeExtract %int %41 1 +; CHECK: [[add:%[a-zA-Z_\d]+]] = OpIAdd %int [[extract1]] [[extract2]] + %39 = OpIAdd %int %47 %49 + OpStore %o %39 + OpReturn + 
OpFunctionEnd + +)"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(SimplificationTest, ThroughLoops) { + // Testing that folding rules are applied multiple times to instructions + // to be able to propagate across loop iterations. + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %o %i + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 430 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %o "o" + OpName %i "i" + OpDecorate %o Location 0 + OpDecorate %i Flat + OpDecorate %i Location 0 + %void = OpTypeVoid + %8 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %v4int = OpTypeVector %int 4 + %int_0 = OpConstant %int 0 +; CHECK: [[constant:%[a-zA-Z_\d]+]] = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 + %13 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 + %bool = OpTypeBool +%_ptr_Output_int = OpTypePointer Output %int + %o = OpVariable %_ptr_Output_int Output +%_ptr_Input_v4int = OpTypePointer Input %v4int + %i = OpVariable %_ptr_Input_v4int Input + %68 = OpUndef %v4int + %main = OpFunction %void None %8 + %23 = OpLabel + OpBranch %24 + %24 = OpLabel + %67 = OpPhi %v4int %13 %23 %64 %26 +; CHECK: OpLoopMerge [[merge_lab:%[a-zA-Z_\d]+]] + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %48 = OpCompositeExtract %int %67 0 + %30 = OpIEqual %bool %48 %int_0 + OpBranchConditional %30 %31 %25 + %31 = OpLabel + %50 = OpCompositeExtract %int %67 0 + %54 = OpCompositeExtract %int %67 1 + %58 = OpCompositeExtract %int %67 2 + %62 = OpCompositeExtract %int %67 3 + %64 = OpCompositeConstruct %v4int %50 %54 %58 %62 + OpBranch %26 + %26 = OpLabel + OpBranch %24 + %25 = OpLabel +; CHECK: [[merge_lab]] = OpLabel +; CHECK: [[extract:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[constant]] 0 + %66 = OpCompositeExtract %int %67 0 +; CHECK-NEXT: OpStore 
%o [[extract]] + OpStore %o %66 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, false); +} +#endif +} // anonymous namespace diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp index 709e517d5..c4ed61e6d 100644 --- a/tools/opt/opt.cpp +++ b/tools/opt/opt.cpp @@ -255,6 +255,9 @@ Options (in lexicographical order): blank spaces, and in each pair, spec id and default value must be separated with colon ':' without any blank spaces in between. e.g.: --set-spec-const-default-value "1:100 2:400" + --simplify-instructions + Will simplify all instructions in the function as much as + possible. --skip-validation Will not validate the SPIR-V before optimizing. If the SPIR-V is invalid, the optimizer may fail or generate incorrect code. @@ -465,6 +468,8 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer, options->relax_struct_store = true; } else if (0 == strcmp(cur_arg, "--replace-invalid-opcode")) { optimizer->RegisterPass(CreateReplaceInvalidOpcodePass()); + } else if (0 == strcmp(cur_arg, "--simplify-instructions")) { + optimizer->RegisterPass(CreateSimplificationPass()); } else if (0 == strcmp(cur_arg, "--skip-validation")) { *skip_validator = true; } else if (0 == strcmp(cur_arg, "-O")) {