// Copyright (c) 2016 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "spirv-tools/optimizer.hpp" #include #include #include #include #include #include #include #include "source/opt/build_module.h" #include "source/opt/graphics_robust_access_pass.h" #include "source/opt/log.h" #include "source/opt/pass_manager.h" #include "source/opt/passes.h" #include "source/spirv_optimizer_options.h" #include "source/util/make_unique.h" #include "source/util/string_utils.h" namespace spvtools { std::vector GetVectorOfStrings(const char** strings, const size_t string_count) { std::vector result; for (uint32_t i = 0; i < string_count; i++) { result.emplace_back(strings[i]); } return result; } struct Optimizer::PassToken::Impl { Impl(std::unique_ptr p) : pass(std::move(p)) {} std::unique_ptr pass; // Internal implementation pass. }; Optimizer::PassToken::PassToken( std::unique_ptr impl) : impl_(std::move(impl)) {} Optimizer::PassToken::PassToken(std::unique_ptr&& pass) : impl_(MakeUnique(std::move(pass))) {} Optimizer::PassToken::PassToken(PassToken&& that) : impl_(std::move(that.impl_)) {} Optimizer::PassToken& Optimizer::PassToken::operator=(PassToken&& that) { impl_ = std::move(that.impl_); return *this; } Optimizer::PassToken::~PassToken() {} struct Optimizer::Impl { explicit Impl(spv_target_env env) : target_env(env), pass_manager() {} spv_target_env target_env; // Target environment. opt::PassManager pass_manager; // Internal implementation pass manager. 
std::unordered_set live_locs; // Arg to debug dead output passes }; Optimizer::Optimizer(spv_target_env env) : impl_(new Impl(env)) { assert(env != SPV_ENV_WEBGPU_0); } Optimizer::~Optimizer() {} void Optimizer::SetMessageConsumer(MessageConsumer c) { // All passes' message consumer needs to be updated. for (uint32_t i = 0; i < impl_->pass_manager.NumPasses(); ++i) { impl_->pass_manager.GetPass(i)->SetMessageConsumer(c); } impl_->pass_manager.SetMessageConsumer(std::move(c)); } const MessageConsumer& Optimizer::consumer() const { return impl_->pass_manager.consumer(); } Optimizer& Optimizer::RegisterPass(PassToken&& p) { // Change to use the pass manager's consumer. p.impl_->pass->SetMessageConsumer(consumer()); impl_->pass_manager.AddPass(std::move(p.impl_->pass)); return *this; } // The legalization passes take a spir-v shader generated by an HLSL front-end // and turn it into a valid vulkan spir-v shader. There are two ways in which // the code will be invalid at the start: // // 1) There will be opaque objects, like images, which will be passed around // in intermediate objects. Valid spir-v will have to replace the use of // the opaque object with an intermediate object that is the result of the // load of the global opaque object. // // 2) There will be variables that contain pointers to structured or uniform // buffers. It be legal, the variables must be eliminated, and the // references to the structured buffers must use the result of OpVariable // in the Uniform storage class. // // Optimization in this list must accept shaders with these relaxation of the // rules. There is not guarantee that this list of optimizations is able to // legalize all inputs, but it is on a best effort basis. // // The legalization problem is essentially a very general copy propagation // problem. The optimization we use are all used to either do copy propagation // or enable more copy propagation. 
// NOTE(review): from here on the text has lost its original line breaks and
// every `<...>` template argument list (for example `std::vector` with no
// element type, `MakeUnique(MakeUnique())`, `static_cast(...)`). Because a
// `//` comment now swallows whatever follows it on the same physical line,
// this region cannot compile as written; restore it from version control
// before building. The code below is intentionally left untouched.
//
// Contents, in order:
//  * RegisterLegalizationPasses / RegisterPerformancePasses /
//    RegisterSizePasses: each appends a fixed sequence of passes through
//    chained RegisterPass() calls and forwards |preserve_interface| to every
//    CreateAggressiveDCEPass() call. The zero-argument overloads pass false.
//  * RegisterPassesFromFlags: registers the flags in order and returns false
//    at the first flag that fails to register.
//  * FlagHasValidForm: accepts "-O", "-Os", or any string longer than two
//    characters that starts with "--"; anything else is reported through the
//    message consumer and rejected.
//  * RegisterPassFromFlag: splits the flag with utils::SplitFlagArgs() and
//    dispatches on the pass name through one long if/else chain. Unknown names
//    and invalid arguments are reported with Error()/Errorf() and yield false.
//    NOTE(review): loop-fission, loop-fusion, loop-unroll-partial and
//    loop-peeling-threshold parse their argument with atoi(), which stops at
//    the first non-digit, so an argument such as "3abc" is taken as 3.
//    reduce-load-size only checks that the argument consists of digits and
//    dots before calling atof(). std::from_chars(), already used for
//    switch-descriptorset, would reject such input.
//  * SetTargetEnv and the Run() overloads: the final overload optionally
//    validates the input, builds an IR context with BuildModule(), runs the
//    pass manager, and on success clears |optimized_binary| and writes the
//    module into it. It returns false when validation fails, when no context
//    could be built, or when the pass manager reports Failure.
//  * SetPrintAll / SetTimeReport / SetValidateAfterAll: forward to the pass
//    manager and return *this.
//  * The Create*Pass() factories, each returning a PassToken, and
//    GetPassNames(), which collects name() of every registered pass.
//  * The extern "C" entry points, which reinterpret_cast the spv_optimizer_t
//    handle and forward to the Optimizer member of the corresponding name.
Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) { return // Wrap OpKill instructions so all other code can be inlined. RegisterPass(CreateWrapOpKillPass()) // Remove unreachable block so that merge return works. .RegisterPass(CreateDeadBranchElimPass()) // Merge the returns so we can inline. .RegisterPass(CreateMergeReturnPass()) // Make sure uses and definitions are in the same function. .RegisterPass(CreateInlineExhaustivePass()) // Make private variable function scope .RegisterPass(CreateEliminateDeadFunctionsPass()) .RegisterPass(CreatePrivateToLocalPass()) // Fix up the storage classes that DXC may have purposely generated // incorrectly. All functions are inlined, and a lot of dead code has // been removed. .RegisterPass(CreateFixStorageClassPass()) // Propagate the value stored to the loads in very simple cases. .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) // Split up aggregates so they are easier to deal with. .RegisterPass(CreateScalarReplacementPass(0)) // Remove loads and stores so everything is in intermediate values. // Takes care of copy propagation of non-members. .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateLocalMultiStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) // Propagate constants to get as many constant conditions on branches // as possible. .RegisterPass(CreateCCPPass()) .RegisterPass(CreateLoopUnrollPass(true)) .RegisterPass(CreateDeadBranchElimPass()) // Copy propagate members. Cleans up code sequences generated by // scalar replacement. Also important for removing OpPhi nodes. 
.RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateCopyPropagateArraysPass()) // May need loop unrolling here see // https://github.com/Microsoft/DirectXShaderCompiler/pull/930 // Get rid of unused code that contain traces of illegal code // or unused references to unbound external objects .RegisterPass(CreateVectorDCEPass()) .RegisterPass(CreateDeadInsertElimPass()) .RegisterPass(CreateReduceLoadSizePass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateRemoveUnusedInterfaceVariablesPass()) .RegisterPass(CreateInterpolateFixupPass()) .RegisterPass(CreateInvocationInterlockPlacementPass()); } Optimizer& Optimizer::RegisterLegalizationPasses() { return RegisterLegalizationPasses(false); } Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) { return RegisterPass(CreateWrapOpKillPass()) .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateMergeReturnPass()) .RegisterPass(CreateInlineExhaustivePass()) .RegisterPass(CreateEliminateDeadFunctionsPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreatePrivateToLocalPass()) .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateScalarReplacementPass()) .RegisterPass(CreateLocalAccessChainConvertPass()) .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateLocalMultiStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateCCPPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateLoopUnrollPass(true)) .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateRedundancyEliminationPass()) .RegisterPass(CreateCombineAccessChainsPass()) 
.RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateScalarReplacementPass()) .RegisterPass(CreateLocalAccessChainConvertPass()) .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateSSARewritePass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateVectorDCEPass()) .RegisterPass(CreateDeadInsertElimPass()) .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateIfConversionPass()) .RegisterPass(CreateCopyPropagateArraysPass()) .RegisterPass(CreateReduceLoadSizePass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateBlockMergePass()) .RegisterPass(CreateRedundancyEliminationPass()) .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateBlockMergePass()) .RegisterPass(CreateSimplificationPass()); } Optimizer& Optimizer::RegisterPerformancePasses() { return RegisterPerformancePasses(false); } Optimizer& Optimizer::RegisterSizePasses(bool preserve_interface) { return RegisterPass(CreateWrapOpKillPass()) .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateMergeReturnPass()) .RegisterPass(CreateInlineExhaustivePass()) .RegisterPass(CreateEliminateDeadFunctionsPass()) .RegisterPass(CreatePrivateToLocalPass()) .RegisterPass(CreateScalarReplacementPass(0)) .RegisterPass(CreateLocalMultiStoreElimPass()) .RegisterPass(CreateCCPPass()) .RegisterPass(CreateLoopUnrollPass(true)) .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateScalarReplacementPass(0)) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateIfConversionPass()) .RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateBlockMergePass()) 
.RegisterPass(CreateLocalAccessChainConvertPass()) .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateCopyPropagateArraysPass()) .RegisterPass(CreateVectorDCEPass()) .RegisterPass(CreateDeadInsertElimPass()) .RegisterPass(CreateEliminateDeadMembersPass()) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateBlockMergePass()) .RegisterPass(CreateLocalMultiStoreElimPass()) .RegisterPass(CreateRedundancyEliminationPass()) .RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) .RegisterPass(CreateCFGCleanupPass()); } Optimizer& Optimizer::RegisterSizePasses() { return RegisterSizePasses(false); } bool Optimizer::RegisterPassesFromFlags(const std::vector& flags) { return RegisterPassesFromFlags(flags, false); } bool Optimizer::RegisterPassesFromFlags(const std::vector& flags, bool preserve_interface) { for (const auto& flag : flags) { if (!RegisterPassFromFlag(flag, preserve_interface)) { return false; } } return true; } bool Optimizer::FlagHasValidForm(const std::string& flag) const { if (flag == "-O" || flag == "-Os") { return true; } else if (flag.size() > 2 && flag.substr(0, 2) == "--") { return true; } Errorf(consumer(), nullptr, {}, "%s is not a valid flag. Flag passes should have the form " "'--pass_name[=pass_args]'. Special flag names also accepted: -O " "and -Os.", flag.c_str()); return false; } bool Optimizer::RegisterPassFromFlag(const std::string& flag) { return RegisterPassFromFlag(flag, false); } bool Optimizer::RegisterPassFromFlag(const std::string& flag, bool preserve_interface) { if (!FlagHasValidForm(flag)) { return false; } // Split flags of the form --pass_name=pass_args. 
// |p| is used as a pair: |p.first| becomes the pass name matched by the chain below and |p.second| the argument string handed to that pass's own parsing.
auto p = utils::SplitFlagArgs(flag); std::string pass_name = p.first; std::string pass_args = p.second; // FIXME(dnovillo): This should be re-factored so that pass names can be // automatically checked against Pass::name() and PassToken instances created // via a template function. Additionally, class Pass should have a desc() // method that describes the pass (so it can be used in --help). // // Both Pass::name() and Pass::desc() should be static class members so they // can be invoked without creating a pass instance. if (pass_name == "strip-debug") { RegisterPass(CreateStripDebugInfoPass()); } else if (pass_name == "strip-reflect") { RegisterPass(CreateStripReflectInfoPass()); } else if (pass_name == "strip-nonsemantic") { RegisterPass(CreateStripNonSemanticInfoPass()); } else if (pass_name == "set-spec-const-default-value") { if (pass_args.size() > 0) { auto spec_ids_vals = opt::SetSpecConstantDefaultValuePass::ParseDefaultValuesString( pass_args.c_str()); if (!spec_ids_vals) { Errorf(consumer(), nullptr, {}, "Invalid argument for --set-spec-const-default-value: %s", pass_args.c_str()); return false; } RegisterPass( CreateSetSpecConstantDefaultValuePass(std::move(*spec_ids_vals))); } else { Errorf(consumer(), nullptr, {}, "Invalid spec constant value string '%s'. 
Expected a string of " ": pairs.", pass_args.c_str()); return false; } } else if (pass_name == "if-conversion") { RegisterPass(CreateIfConversionPass()); } else if (pass_name == "freeze-spec-const") { RegisterPass(CreateFreezeSpecConstantValuePass()); } else if (pass_name == "inline-entry-points-exhaustive") { RegisterPass(CreateInlineExhaustivePass()); } else if (pass_name == "inline-entry-points-opaque") { RegisterPass(CreateInlineOpaquePass()); } else if (pass_name == "combine-access-chains") { RegisterPass(CreateCombineAccessChainsPass()); } else if (pass_name == "convert-local-access-chains") { RegisterPass(CreateLocalAccessChainConvertPass()); } else if (pass_name == "replace-desc-array-access-using-var-index") { RegisterPass(CreateReplaceDescArrayAccessUsingVarIndexPass()); } else if (pass_name == "spread-volatile-semantics") { RegisterPass(CreateSpreadVolatileSemanticsPass()); } else if (pass_name == "descriptor-scalar-replacement") { RegisterPass(CreateDescriptorScalarReplacementPass()); } else if (pass_name == "eliminate-dead-code-aggressive") { RegisterPass(CreateAggressiveDCEPass(preserve_interface)); } else if (pass_name == "eliminate-insert-extract") { RegisterPass(CreateInsertExtractElimPass()); } else if (pass_name == "eliminate-local-single-block") { RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()); } else if (pass_name == "eliminate-local-single-store") { RegisterPass(CreateLocalSingleStoreElimPass()); } else if (pass_name == "merge-blocks") { RegisterPass(CreateBlockMergePass()); } else if (pass_name == "merge-return") { RegisterPass(CreateMergeReturnPass()); } else if (pass_name == "eliminate-dead-branches") { RegisterPass(CreateDeadBranchElimPass()); } else if (pass_name == "eliminate-dead-functions") { RegisterPass(CreateEliminateDeadFunctionsPass()); } else if (pass_name == "eliminate-local-multi-store") { RegisterPass(CreateLocalMultiStoreElimPass()); } else if (pass_name == "eliminate-dead-const") { 
RegisterPass(CreateEliminateDeadConstantPass()); } else if (pass_name == "eliminate-dead-inserts") { RegisterPass(CreateDeadInsertElimPass()); } else if (pass_name == "eliminate-dead-variables") { RegisterPass(CreateDeadVariableEliminationPass()); } else if (pass_name == "eliminate-dead-members") { RegisterPass(CreateEliminateDeadMembersPass()); } else if (pass_name == "fold-spec-const-op-composite") { RegisterPass(CreateFoldSpecConstantOpAndCompositePass()); } else if (pass_name == "loop-unswitch") { RegisterPass(CreateLoopUnswitchPass()); } else if (pass_name == "scalar-replacement") { if (pass_args.size() == 0) { RegisterPass(CreateScalarReplacementPass()); } else { int limit = -1; if (pass_args.find_first_not_of("0123456789") == std::string::npos) { limit = atoi(pass_args.c_str()); } if (limit >= 0) { RegisterPass(CreateScalarReplacementPass(limit)); } else { Error(consumer(), nullptr, {}, "--scalar-replacement must have no arguments or a non-negative " "integer argument"); return false; } } } else if (pass_name == "strength-reduction") { RegisterPass(CreateStrengthReductionPass()); } else if (pass_name == "unify-const") { RegisterPass(CreateUnifyConstantPass()); } else if (pass_name == "flatten-decorations") { RegisterPass(CreateFlattenDecorationPass()); } else if (pass_name == "compact-ids") { RegisterPass(CreateCompactIdsPass()); } else if (pass_name == "cfg-cleanup") { RegisterPass(CreateCFGCleanupPass()); } else if (pass_name == "local-redundancy-elimination") { RegisterPass(CreateLocalRedundancyEliminationPass()); } else if (pass_name == "loop-invariant-code-motion") { RegisterPass(CreateLoopInvariantCodeMotionPass()); } else if (pass_name == "reduce-load-size") { if (pass_args.size() == 0) { RegisterPass(CreateReduceLoadSizePass()); } else { double load_replacement_threshold = 0.9; if (pass_args.find_first_not_of(".0123456789") == std::string::npos) { load_replacement_threshold = atof(pass_args.c_str()); } if (load_replacement_threshold >= 0) { 
RegisterPass(CreateReduceLoadSizePass(load_replacement_threshold)); } else { Error(consumer(), nullptr, {}, "--reduce-load-size must have no arguments or a non-negative " "double argument"); return false; } } } else if (pass_name == "redundancy-elimination") { RegisterPass(CreateRedundancyEliminationPass()); } else if (pass_name == "private-to-local") { RegisterPass(CreatePrivateToLocalPass()); } else if (pass_name == "remove-duplicates") { RegisterPass(CreateRemoveDuplicatesPass()); } else if (pass_name == "workaround-1209") { RegisterPass(CreateWorkaround1209Pass()); } else if (pass_name == "replace-invalid-opcode") { RegisterPass(CreateReplaceInvalidOpcodePass()); } else if (pass_name == "convert-relaxed-to-half") { RegisterPass(CreateConvertRelaxedToHalfPass()); } else if (pass_name == "relax-float-ops") { RegisterPass(CreateRelaxFloatOpsPass()); } else if (pass_name == "inst-debug-printf") { // This private option is not for user consumption. // It is here to assist in debugging and fixing the debug printf // instrumentation pass. // For users who wish to utilize debug printf, see the white paper at // https://www.lunarg.com/wp-content/uploads/2021/08/Using-Debug-Printf-02August2021.pdf RegisterPass(CreateInstDebugPrintfPass(7, 23)); } else if (pass_name == "simplify-instructions") { RegisterPass(CreateSimplificationPass()); } else if (pass_name == "ssa-rewrite") { RegisterPass(CreateSSARewritePass()); } else if (pass_name == "copy-propagate-arrays") { RegisterPass(CreateCopyPropagateArraysPass()); } else if (pass_name == "loop-fission") { int register_threshold_to_split = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : -1; if (register_threshold_to_split > 0) { RegisterPass(CreateLoopFissionPass( static_cast(register_threshold_to_split))); } else { Error(consumer(), nullptr, {}, "--loop-fission must have a positive integer argument"); return false; } } else if (pass_name == "loop-fusion") { int max_registers_per_loop = (pass_args.size() > 0) ? 
atoi(pass_args.c_str()) : -1; if (max_registers_per_loop > 0) { RegisterPass( CreateLoopFusionPass(static_cast(max_registers_per_loop))); } else { Error(consumer(), nullptr, {}, "--loop-fusion must have a positive integer argument"); return false; } } else if (pass_name == "loop-unroll") { RegisterPass(CreateLoopUnrollPass(true)); } else if (pass_name == "upgrade-memory-model") { RegisterPass(CreateUpgradeMemoryModelPass()); } else if (pass_name == "vector-dce") { RegisterPass(CreateVectorDCEPass()); } else if (pass_name == "loop-unroll-partial") { int factor = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : 0; if (factor > 0) { RegisterPass(CreateLoopUnrollPass(false, factor)); } else { Error(consumer(), nullptr, {}, "--loop-unroll-partial must have a positive integer argument"); return false; } } else if (pass_name == "loop-peeling") { RegisterPass(CreateLoopPeelingPass()); } else if (pass_name == "loop-peeling-threshold") { int factor = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : 0; if (factor > 0) { opt::LoopPeelingPass::SetLoopPeelingThreshold(factor); } else { Error(consumer(), nullptr, {}, "--loop-peeling-threshold must have a positive integer argument"); return false; } } else if (pass_name == "ccp") { RegisterPass(CreateCCPPass()); } else if (pass_name == "code-sink") { RegisterPass(CreateCodeSinkingPass()); } else if (pass_name == "fix-storage-class") { RegisterPass(CreateFixStorageClassPass()); } else if (pass_name == "O") { RegisterPerformancePasses(preserve_interface); } else if (pass_name == "Os") { RegisterSizePasses(preserve_interface); } else if (pass_name == "legalize-hlsl") { RegisterLegalizationPasses(preserve_interface); } else if (pass_name == "remove-unused-interface-variables") { RegisterPass(CreateRemoveUnusedInterfaceVariablesPass()); } else if (pass_name == "graphics-robust-access") { RegisterPass(CreateGraphicsRobustAccessPass()); } else if (pass_name == "wrap-opkill") { RegisterPass(CreateWrapOpKillPass()); } else if 
(pass_name == "amd-ext-to-khr") { RegisterPass(CreateAmdExtToKhrPass()); } else if (pass_name == "interpolate-fixup") { RegisterPass(CreateInterpolateFixupPass()); } else if (pass_name == "remove-dont-inline") { RegisterPass(CreateRemoveDontInlinePass()); } else if (pass_name == "eliminate-dead-input-components") { RegisterPass(CreateEliminateDeadInputComponentsSafePass()); } else if (pass_name == "fix-func-call-param") { RegisterPass(CreateFixFuncCallArgumentsPass()); } else if (pass_name == "convert-to-sampled-image") { if (pass_args.size() > 0) { auto descriptor_set_binding_pairs = opt::ConvertToSampledImagePass::ParseDescriptorSetBindingPairsString( pass_args.c_str()); if (!descriptor_set_binding_pairs) { Errorf(consumer(), nullptr, {}, "Invalid argument for --convert-to-sampled-image: %s", pass_args.c_str()); return false; } RegisterPass(CreateConvertToSampledImagePass( std::move(*descriptor_set_binding_pairs))); } else { Errorf(consumer(), nullptr, {}, "Invalid pairs of descriptor set and binding '%s'. 
Expected a " "string of : pairs.", pass_args.c_str()); return false; } } else if (pass_name == "switch-descriptorset") { if (pass_args.size() == 0) { Error(consumer(), nullptr, {}, "--switch-descriptorset requires a from:to argument."); return false; } uint32_t from_set = 0, to_set = 0; const char* start = pass_args.data(); const char* end = pass_args.data() + pass_args.size(); auto result = std::from_chars(start, end, from_set); if (result.ec != std::errc()) { Errorf(consumer(), nullptr, {}, "Invalid argument for --switch-descriptorset: %s", pass_args.c_str()); return false; } start = result.ptr; if (start[0] != ':') { Errorf(consumer(), nullptr, {}, "Invalid argument for --switch-descriptorset: %s", pass_args.c_str()); return false; } start++; result = std::from_chars(start, end, to_set); if (result.ec != std::errc() || result.ptr != end) { Errorf(consumer(), nullptr, {}, "Invalid argument for --switch-descriptorset: %s", pass_args.c_str()); return false; } RegisterPass(CreateSwitchDescriptorSetPass(from_set, to_set)); } else if (pass_name == "modify-maximal-reconvergence") { if (pass_args.size() == 0) { Error(consumer(), nullptr, {}, "--modify-maximal-reconvergence requires an argument"); return false; } if (pass_args == "add") { RegisterPass(CreateModifyMaximalReconvergencePass(true)); } else if (pass_args == "remove") { RegisterPass(CreateModifyMaximalReconvergencePass(false)); } else { Errorf(consumer(), nullptr, {}, "Invalid argument for --modify-maximal-reconvergence: %s (must be " "'add' or 'remove')", pass_args.c_str()); return false; } } else if (pass_name == "trim-capabilities") { RegisterPass(CreateTrimCapabilitiesPass()); } else { Errorf(consumer(), nullptr, {}, "Unknown flag '--%s'. 
Use --help for a list of valid flags", pass_name.c_str()); return false; } return true; } void Optimizer::SetTargetEnv(const spv_target_env env) { impl_->target_env = env; } bool Optimizer::Run(const uint32_t* original_binary, const size_t original_binary_size, std::vector* optimized_binary) const { return Run(original_binary, original_binary_size, optimized_binary, OptimizerOptions()); } bool Optimizer::Run(const uint32_t* original_binary, const size_t original_binary_size, std::vector* optimized_binary, const ValidatorOptions& validator_options, bool skip_validation) const { OptimizerOptions opt_options; opt_options.set_run_validator(!skip_validation); opt_options.set_validator_options(validator_options); return Run(original_binary, original_binary_size, optimized_binary, opt_options); } bool Optimizer::Run(const uint32_t* original_binary, const size_t original_binary_size, std::vector* optimized_binary, const spv_optimizer_options opt_options) const { spvtools::SpirvTools tools(impl_->target_env); tools.SetMessageConsumer(impl_->pass_manager.consumer()); if (opt_options->run_validator_ && !tools.Validate(original_binary, original_binary_size, &opt_options->val_options_)) { return false; } std::unique_ptr context = BuildModule( impl_->target_env, consumer(), original_binary, original_binary_size); if (context == nullptr) return false; context->set_max_id_bound(opt_options->max_id_bound_); context->set_preserve_bindings(opt_options->preserve_bindings_); context->set_preserve_spec_constants(opt_options->preserve_spec_constants_); impl_->pass_manager.SetValidatorOptions(&opt_options->val_options_); impl_->pass_manager.SetTargetEnv(impl_->target_env); auto status = impl_->pass_manager.Run(context.get()); if (status == opt::Pass::Status::Failure) { return false; } #ifndef NDEBUG // We do not keep the result id of DebugScope in struct DebugScope. // Instead, we assign random ids for them, which results in integrity // check failures. 
In addition, propagating the OpLine/OpNoLine to preserve // the debug information through transformations results in integrity // check failures. We want to skip the integrity check when the module // contains DebugScope or OpLine/OpNoLine instructions. if (status == opt::Pass::Status::SuccessWithoutChange && !context->module()->ContainsDebugInfo()) { std::vector optimized_binary_with_nop; context->module()->ToBinary(&optimized_binary_with_nop, /* skip_nop = */ false); assert(optimized_binary_with_nop.size() == original_binary_size && "Binary size unexpectedly changed despite the optimizer saying " "there was no change"); // Compare the magic number to make sure the binaries were encoded in the // same endianness. If not, the contents of the binaries will be different, so // do not check the contents. if (optimized_binary_with_nop[0] == original_binary[0]) { assert(memcmp(optimized_binary_with_nop.data(), original_binary, original_binary_size) == 0 && "Binary content unexpectedly changed despite the optimizer saying " "there was no change"); } } #endif // !NDEBUG // Note that |original_binary| and |optimized_binary| may share the same // buffer and the below will invalidate |original_binary|. 
optimized_binary->clear(); context->module()->ToBinary(optimized_binary, /* skip_nop = */ true); return true; } Optimizer& Optimizer::SetPrintAll(std::ostream* out) { impl_->pass_manager.SetPrintAll(out); return *this; } Optimizer& Optimizer::SetTimeReport(std::ostream* out) { impl_->pass_manager.SetTimeReport(out); return *this; } Optimizer& Optimizer::SetValidateAfterAll(bool validate) { impl_->pass_manager.SetValidateAfterAll(validate); return *this; } Optimizer::PassToken CreateNullPass() { return MakeUnique(MakeUnique()); } Optimizer::PassToken CreateStripDebugInfoPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateStripReflectInfoPass() { return CreateStripNonSemanticInfoPass(); } Optimizer::PassToken CreateStripNonSemanticInfoPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateEliminateDeadFunctionsPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateEliminateDeadMembersPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateSetSpecConstantDefaultValuePass( const std::unordered_map& id_value_map) { return MakeUnique( MakeUnique(id_value_map)); } Optimizer::PassToken CreateSetSpecConstantDefaultValuePass( const std::unordered_map>& id_value_map) { return MakeUnique( MakeUnique(id_value_map)); } Optimizer::PassToken CreateFlattenDecorationPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateFreezeSpecConstantValuePass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateFoldSpecConstantOpAndCompositePass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateUnifyConstantPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateEliminateDeadConstantPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateDeadVariableEliminationPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateStrengthReductionPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateBlockMergePass() { return 
MakeUnique( MakeUnique()); } Optimizer::PassToken CreateInlineExhaustivePass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateInlineOpaquePass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateLocalAccessChainConvertPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateLocalSingleBlockLoadStoreElimPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateLocalSingleStoreElimPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateInsertExtractElimPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateDeadInsertElimPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateDeadBranchElimPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateLocalMultiStoreElimPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateAggressiveDCEPass() { return MakeUnique( MakeUnique(false, false)); } Optimizer::PassToken CreateAggressiveDCEPass(bool preserve_interface) { return MakeUnique( MakeUnique(preserve_interface, false)); } Optimizer::PassToken CreateAggressiveDCEPass(bool preserve_interface, bool remove_outputs) { return MakeUnique( MakeUnique(preserve_interface, remove_outputs)); } Optimizer::PassToken CreateRemoveUnusedInterfaceVariablesPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreatePropagateLineInfoPass() { return MakeUnique(MakeUnique()); } Optimizer::PassToken CreateRedundantLineInfoElimPass() { return MakeUnique(MakeUnique()); } Optimizer::PassToken CreateCompactIdsPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateMergeReturnPass() { return MakeUnique( MakeUnique()); } std::vector Optimizer::GetPassNames() const { std::vector v; for (uint32_t i = 0; i < impl_->pass_manager.NumPasses(); i++) { v.push_back(impl_->pass_manager.GetPass(i)->name()); } return v; } Optimizer::PassToken CreateCFGCleanupPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken 
CreateLocalRedundancyEliminationPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateLoopFissionPass(size_t threshold) { return MakeUnique( MakeUnique(threshold)); } Optimizer::PassToken CreateLoopFusionPass(size_t max_registers_per_loop) { return MakeUnique( MakeUnique(max_registers_per_loop)); } Optimizer::PassToken CreateLoopInvariantCodeMotionPass() { return MakeUnique(MakeUnique()); } Optimizer::PassToken CreateLoopPeelingPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateLoopUnswitchPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateRedundancyEliminationPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateRemoveDuplicatesPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateScalarReplacementPass(uint32_t size_limit) { return MakeUnique( MakeUnique(size_limit)); } Optimizer::PassToken CreatePrivateToLocalPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateCCPPass() { return MakeUnique(MakeUnique()); } Optimizer::PassToken CreateWorkaround1209Pass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateIfConversionPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateReplaceInvalidOpcodePass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateSimplificationPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor) { return MakeUnique( MakeUnique(fully_unroll, factor)); } Optimizer::PassToken CreateSSARewritePass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateCopyPropagateArraysPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateVectorDCEPass() { return MakeUnique(MakeUnique()); } Optimizer::PassToken CreateReduceLoadSizePass( double load_replacement_threshold) { return MakeUnique( MakeUnique(load_replacement_threshold)); } Optimizer::PassToken CreateCombineAccessChainsPass() { return MakeUnique( 
MakeUnique()); } Optimizer::PassToken CreateUpgradeMemoryModelPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateInstDebugPrintfPass(uint32_t desc_set, uint32_t shader_id) { return MakeUnique( MakeUnique(desc_set, shader_id)); } Optimizer::PassToken CreateConvertRelaxedToHalfPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateRelaxFloatOpsPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateCodeSinkingPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateFixStorageClassPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateGraphicsRobustAccessPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateReplaceDescArrayAccessUsingVarIndexPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateSpreadVolatileSemanticsPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateDescriptorScalarReplacementPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateWrapOpKillPass() { return MakeUnique(MakeUnique()); } Optimizer::PassToken CreateAmdExtToKhrPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateInterpolateFixupPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateEliminateDeadInputComponentsPass() { return MakeUnique( MakeUnique(spv::StorageClass::Input, /* safe_mode */ false)); } Optimizer::PassToken CreateEliminateDeadOutputComponentsPass() { return MakeUnique( MakeUnique(spv::StorageClass::Output, /* safe_mode */ false)); } Optimizer::PassToken CreateEliminateDeadInputComponentsSafePass() { return MakeUnique( MakeUnique(spv::StorageClass::Input, /* safe_mode */ true)); } Optimizer::PassToken CreateAnalyzeLiveInputPass( std::unordered_set* live_locs, std::unordered_set* live_builtins) { return MakeUnique( MakeUnique(live_locs, live_builtins)); } Optimizer::PassToken CreateEliminateDeadOutputStoresPass( std::unordered_set* live_locs, std::unordered_set* live_builtins) { return 
MakeUnique( MakeUnique(live_locs, live_builtins)); } Optimizer::PassToken CreateConvertToSampledImagePass( const std::vector& descriptor_set_binding_pairs) { return MakeUnique( MakeUnique(descriptor_set_binding_pairs)); } Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateRemoveDontInlinePass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateFixFuncCallArgumentsPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateTrimCapabilitiesPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateSwitchDescriptorSetPass(uint32_t from, uint32_t to) { return MakeUnique( MakeUnique(from, to)); } Optimizer::PassToken CreateInvocationInterlockPlacementPass() { return MakeUnique( MakeUnique()); } Optimizer::PassToken CreateModifyMaximalReconvergencePass(bool add) { return MakeUnique( MakeUnique(add)); } } // namespace spvtools extern "C" { SPIRV_TOOLS_EXPORT spv_optimizer_t* spvOptimizerCreate(spv_target_env env) { return reinterpret_cast(new spvtools::Optimizer(env)); } SPIRV_TOOLS_EXPORT void spvOptimizerDestroy(spv_optimizer_t* optimizer) { delete reinterpret_cast(optimizer); } SPIRV_TOOLS_EXPORT void spvOptimizerSetMessageConsumer( spv_optimizer_t* optimizer, spv_message_consumer consumer) { reinterpret_cast(optimizer)-> SetMessageConsumer( [consumer](spv_message_level_t level, const char* source, const spv_position_t& position, const char* message) { return consumer(level, source, &position, message); }); } SPIRV_TOOLS_EXPORT void spvOptimizerRegisterLegalizationPasses( spv_optimizer_t* optimizer) { reinterpret_cast(optimizer)-> RegisterLegalizationPasses(); } SPIRV_TOOLS_EXPORT void spvOptimizerRegisterPerformancePasses( spv_optimizer_t* optimizer) { reinterpret_cast(optimizer)-> RegisterPerformancePasses(); } SPIRV_TOOLS_EXPORT void spvOptimizerRegisterSizePasses( spv_optimizer_t* optimizer) { reinterpret_cast(optimizer)->RegisterSizePasses(); } 
SPIRV_TOOLS_EXPORT bool spvOptimizerRegisterPassFromFlag( spv_optimizer_t* optimizer, const char* flag) { return reinterpret_cast(optimizer)-> RegisterPassFromFlag(flag); } SPIRV_TOOLS_EXPORT bool spvOptimizerRegisterPassesFromFlags( spv_optimizer_t* optimizer, const char** flags, const size_t flag_count) { std::vector opt_flags = spvtools::GetVectorOfStrings(flags, flag_count); return reinterpret_cast(optimizer) ->RegisterPassesFromFlags(opt_flags, false); } SPIRV_TOOLS_EXPORT bool spvOptimizerRegisterPassesFromFlagsWhilePreservingTheInterface( spv_optimizer_t* optimizer, const char** flags, const size_t flag_count) { std::vector opt_flags = spvtools::GetVectorOfStrings(flags, flag_count); return reinterpret_cast(optimizer) ->RegisterPassesFromFlags(opt_flags, true); } SPIRV_TOOLS_EXPORT spv_result_t spvOptimizerRun(spv_optimizer_t* optimizer, const uint32_t* binary, const size_t word_count, spv_binary* optimized_binary, const spv_optimizer_options options) { std::vector optimized; if (!reinterpret_cast(optimizer)-> Run(binary, word_count, &optimized, options)) { return SPV_ERROR_INTERNAL; } auto result_binary = new spv_binary_t(); if (!result_binary) { *optimized_binary = nullptr; return SPV_ERROR_OUT_OF_MEMORY; } result_binary->code = new uint32_t[optimized.size()]; if (!result_binary->code) { delete result_binary; *optimized_binary = nullptr; return SPV_ERROR_OUT_OF_MEMORY; } result_binary->wordCount = optimized.size(); memcpy(result_binary->code, optimized.data(), optimized.size() * sizeof(uint32_t)); *optimized_binary = result_binary; return SPV_SUCCESS; } } // extern "C"