// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "reduce_load_size.h" #include #include "instruction.h" #include "ir_builder.h" #include "ir_context.h" namespace { const uint32_t kExtractCompositeIdInIdx = 0; const uint32_t kVariableStorageClassInIdx = 0; const uint32_t kLoadPointerInIdx = 0; const double kThreshold = 0.9; } // namespace namespace spvtools { namespace opt { Pass::Status ReduceLoadSize::Process() { bool modified = false; for (auto& func : *get_module()) { func.ForEachInst([&modified, this](Instruction* inst) { if (inst->opcode() == SpvOpCompositeExtract) { if (ShouldReplaceExtract(inst)) { modified |= ReplaceExtract(inst); } } }); } return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } bool ReduceLoadSize::ReplaceExtract(Instruction* inst) { assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); analysis::TypeManager* type_mgr = context()->get_type_mgr(); analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); uint32_t composite_id = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); Instruction* composite_inst = def_use_mgr->GetDef(composite_id); if (composite_inst->opcode() != SpvOpLoad) { return false; } analysis::Type* composite_type = type_mgr->GetType(composite_inst->type_id()); if (composite_type->kind() == analysis::Type::kVector || composite_type->kind() == analysis::Type::kMatrix) { return false; } Instruction* var = composite_inst->GetBaseAddress(); if (var == nullptr || var->opcode() != SpvOpVariable) { return false; } SpvStorageClass storage_class = static_cast( var->GetSingleWordInOperand(kVariableStorageClassInIdx)); switch (storage_class) { case SpvStorageClassUniform: case SpvStorageClassUniformConstant: case SpvStorageClassInput: break; default: return false; } // Create a new access chain and load just after the old load. // We cannot create the new access chain load in the position of the extract // because the storage may have been written to in between. InstructionBuilder ir_builder( inst->context(), composite_inst, IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse); uint32_t pointer_to_result_type_id = type_mgr->FindPointerToType(inst->type_id(), storage_class); assert(pointer_to_result_type_id != 0 && "We did not find the pointer type that we need."); analysis::Integer int_type(32, false); const analysis::Type* uint32_type = type_mgr->GetRegisteredType(&int_type); std::vector ids; for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { uint32_t index = inst->GetSingleWordInOperand(i); const analysis::Constant* index_const = const_mgr->GetConstant(uint32_type, {index}); ids.push_back(const_mgr->GetDefiningInstruction(index_const)->result_id()); } Instruction* new_access_chain = ir_builder.AddAccessChain( pointer_to_result_type_id, composite_inst->GetSingleWordInOperand(kLoadPointerInIdx), ids); Instruction* new_laod = ir_builder.AddLoad(inst->type_id(), new_access_chain->result_id()); context()->ReplaceAllUsesWith(inst->result_id(), new_laod->result_id()); context()->KillInst(inst); return true; } bool ReduceLoadSize::ShouldReplaceExtract(Instruction* inst) { analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); Instruction* op_inst = def_use_mgr->GetDef( inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)); if (op_inst->opcode() != SpvOpLoad) { return false; } auto cached_result = should_replace_cache_.find(op_inst->result_id()); if (cached_result != should_replace_cache_.end()) { return cached_result->second; } bool all_elements_used = false; std::set elements_used; all_elements_used = !def_use_mgr->WhileEachUser(op_inst, [&elements_used](Instruction* use) { if (use->opcode() != SpvOpCompositeExtract) { return false; } elements_used.insert(use->GetSingleWordInOperand(1)); return true; }); bool should_replace = false; if (all_elements_used) { should_replace = false; } else { analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); analysis::TypeManager* type_mgr = context()->get_type_mgr(); analysis::Type* load_type = type_mgr->GetType(op_inst->type_id()); uint32_t total_size = 1; switch (load_type->kind()) { case analysis::Type::kArray: { const analysis::Constant* size_const = const_mgr->FindDeclaredConstant(load_type->AsArray()->LengthId()); assert(size_const->AsIntConstant()); total_size = size_const->GetU32(); } break; case analysis::Type::kStruct: total_size = static_cast( load_type->AsStruct()->element_types().size()); break; default: break; } double percent_used = static_cast(elements_used.size()) / static_cast(total_size); should_replace = (percent_used < kThreshold); } should_replace_cache_[op_inst->result_id()] = should_replace; return should_replace; } } // namespace opt } // namespace spvtools