// Copyright (c) 2017 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/opt/scalar_replacement_pass.h" #include #include #include #include #include "source/enum_string_mapping.h" #include "source/extensions.h" #include "source/opt/reflect.h" #include "source/opt/types.h" #include "source/util/make_unique.h" #include "types.h" static const uint32_t kDebugValueOperandValueIndex = 5; static const uint32_t kDebugValueOperandExpressionIndex = 6; static const uint32_t kDebugDeclareOperandVariableIndex = 5; namespace spvtools { namespace opt { Pass::Status ScalarReplacementPass::Process() { Status status = Status::SuccessWithoutChange; for (auto& f : *get_module()) { if (f.IsDeclaration()) { continue; } Status functionStatus = ProcessFunction(&f); if (functionStatus == Status::Failure) return functionStatus; else if (functionStatus == Status::SuccessWithChange) status = functionStatus; } return status; } Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) { std::queue worklist; BasicBlock& entry = *function->begin(); for (auto iter = entry.begin(); iter != entry.end(); ++iter) { // Function storage class OpVariables must appear as the first instructions // of the entry block. if (iter->opcode() != SpvOpVariable) break; Instruction* varInst = &*iter; if (CanReplaceVariable(varInst)) { worklist.push(varInst); } } Status status = Status::SuccessWithoutChange; while (!worklist.empty()) { Instruction* varInst = worklist.front(); worklist.pop(); Status var_status = ReplaceVariable(varInst, &worklist); if (var_status == Status::Failure) return var_status; else if (var_status == Status::SuccessWithChange) status = var_status; } return status; } Pass::Status ScalarReplacementPass::ReplaceVariable( Instruction* inst, std::queue* worklist) { std::vector replacements; if (!CreateReplacementVariables(inst, &replacements)) { return Status::Failure; } std::vector dead; bool replaced_all_uses = get_def_use_mgr()->WhileEachUser( inst, [this, &replacements, &dead](Instruction* user) { if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) { if (ReplaceWholeDebugDeclare(user, replacements)) { dead.push_back(user); return true; } return false; } if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) { if (ReplaceWholeDebugValue(user, replacements)) { dead.push_back(user); return true; } return false; } if (!IsAnnotationInst(user->opcode())) { switch (user->opcode()) { case SpvOpLoad: if (ReplaceWholeLoad(user, replacements)) { dead.push_back(user); } else { return false; } break; case SpvOpStore: if (ReplaceWholeStore(user, replacements)) { dead.push_back(user); } else { return false; } break; case SpvOpAccessChain: case SpvOpInBoundsAccessChain: if (ReplaceAccessChain(user, replacements)) dead.push_back(user); else return false; break; case SpvOpName: case SpvOpMemberName: break; default: assert(false && "Unexpected opcode"); break; } } return true; }); if (replaced_all_uses) { dead.push_back(inst); } else { return Status::Failure; } // If there are no dead instructions to clean up, return with no changes. if (dead.empty()) return Status::SuccessWithoutChange; // Clean up some dead code. while (!dead.empty()) { Instruction* toKill = dead.back(); dead.pop_back(); context()->KillInst(toKill); } // Attempt to further scalarize. for (auto var : replacements) { if (var->opcode() == SpvOpVariable) { if (get_def_use_mgr()->NumUsers(var) == 0) { context()->KillInst(var); } else if (CanReplaceVariable(var)) { worklist->push(var); } } } return Status::SuccessWithChange; } bool ScalarReplacementPass::ReplaceWholeDebugDeclare( Instruction* dbg_decl, const std::vector& replacements) { // Insert Deref operation to the front of the operation list of |dbg_decl|. Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef( dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex)); auto* deref_expr = context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr); // Add DebugValue instruction with Indexes operand and Deref operation. int32_t idx = 0; for (const auto* var : replacements) { Instruction* insert_before = var->NextNode(); while (insert_before->opcode() == SpvOpVariable) insert_before = insert_before->NextNode(); assert(insert_before != nullptr && "unexpected end of list"); Instruction* added_dbg_value = context()->get_debug_info_mgr()->AddDebugValueForDecl( dbg_decl, /*value_id=*/var->result_id(), /*insert_before=*/insert_before, /*scope_and_line=*/dbg_decl); if (added_dbg_value == nullptr) return false; added_dbg_value->AddOperand( {SPV_OPERAND_TYPE_ID, {context()->get_constant_mgr()->GetSIntConst(idx)}}); added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex, {deref_expr->result_id()}); if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) { context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value); } ++idx; } return true; } bool ScalarReplacementPass::ReplaceWholeDebugValue( Instruction* dbg_value, const std::vector& replacements) { int32_t idx = 0; BasicBlock* block = context()->get_instr_block(dbg_value); for (auto var : replacements) { // Clone the DebugValue. std::unique_ptr new_dbg_value(dbg_value->Clone(context())); uint32_t new_id = TakeNextId(); if (new_id == 0) return false; new_dbg_value->SetResultId(new_id); // Update 'Value' operand to the |replacements|. new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()}); // Append 'Indexes' operand. new_dbg_value->AddOperand( {SPV_OPERAND_TYPE_ID, {context()->get_constant_mgr()->GetSIntConst(idx)}}); // Insert the new DebugValue to the basic block. auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value)); get_def_use_mgr()->AnalyzeInstDefUse(added_instr); context()->set_instr_block(added_instr, block); ++idx; } return true; } bool ScalarReplacementPass::ReplaceWholeLoad( Instruction* load, const std::vector& replacements) { // Replaces the load of the entire composite with a load from each replacement // variable followed by a composite construction. BasicBlock* block = context()->get_instr_block(load); std::vector loads; loads.reserve(replacements.size()); BasicBlock::iterator where(load); for (auto var : replacements) { // Create a load of each replacement variable. if (var->opcode() != SpvOpVariable) { loads.push_back(var); continue; } Instruction* type = GetStorageType(var); uint32_t loadId = TakeNextId(); if (loadId == 0) { return false; } std::unique_ptr newLoad( new Instruction(context(), SpvOpLoad, type->result_id(), loadId, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); // Copy memory access attributes which start at index 1. Index 0 is the // pointer to load. for (uint32_t i = 1; i < load->NumInOperands(); ++i) { Operand copy(load->GetInOperand(i)); newLoad->AddOperand(std::move(copy)); } where = where.InsertBefore(std::move(newLoad)); get_def_use_mgr()->AnalyzeInstDefUse(&*where); context()->set_instr_block(&*where, block); where->UpdateDebugInfoFrom(load); loads.push_back(&*where); } // Construct a new composite. uint32_t compositeId = TakeNextId(); if (compositeId == 0) { return false; } where = load; std::unique_ptr compositeConstruct(new Instruction( context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {})); for (auto l : loads) { Operand op(SPV_OPERAND_TYPE_ID, std::initializer_list{l->result_id()}); compositeConstruct->AddOperand(std::move(op)); } where = where.InsertBefore(std::move(compositeConstruct)); get_def_use_mgr()->AnalyzeInstDefUse(&*where); where->UpdateDebugInfoFrom(load); context()->set_instr_block(&*where, block); context()->ReplaceAllUsesWith(load->result_id(), compositeId); return true; } bool ScalarReplacementPass::ReplaceWholeStore( Instruction* store, const std::vector& replacements) { // Replaces a store to the whole composite with a series of extract and stores // to each element. uint32_t storeInput = store->GetSingleWordInOperand(1u); BasicBlock* block = context()->get_instr_block(store); BasicBlock::iterator where(store); uint32_t elementIndex = 0; for (auto var : replacements) { // Create the extract. if (var->opcode() != SpvOpVariable) { elementIndex++; continue; } Instruction* type = GetStorageType(var); uint32_t extractId = TakeNextId(); if (extractId == 0) { return false; } std::unique_ptr extract(new Instruction( context(), SpvOpCompositeExtract, type->result_id(), extractId, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {storeInput}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}})); auto iter = where.InsertBefore(std::move(extract)); iter->UpdateDebugInfoFrom(store); get_def_use_mgr()->AnalyzeInstDefUse(&*iter); context()->set_instr_block(&*iter, block); // Create the store. std::unique_ptr newStore( new Instruction(context(), SpvOpStore, 0, 0, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {var->result_id()}}, {SPV_OPERAND_TYPE_ID, {extractId}}})); // Copy memory access attributes which start at index 2. Index 0 is the // pointer and index 1 is the data. for (uint32_t i = 2; i < store->NumInOperands(); ++i) { Operand copy(store->GetInOperand(i)); newStore->AddOperand(std::move(copy)); } iter = where.InsertBefore(std::move(newStore)); iter->UpdateDebugInfoFrom(store); get_def_use_mgr()->AnalyzeInstDefUse(&*iter); context()->set_instr_block(&*iter, block); } return true; } bool ScalarReplacementPass::ReplaceAccessChain( Instruction* chain, const std::vector& replacements) { // Replaces the access chain with either another access chain (with one fewer // indexes) or a direct use of the replacement variable. uint32_t indexId = chain->GetSingleWordInOperand(1u); const Instruction* index = get_def_use_mgr()->GetDef(indexId); int64_t indexValue = context() ->get_constant_mgr() ->GetConstantFromInst(index) ->GetSignExtendedValue(); if (indexValue < 0 || indexValue >= static_cast(replacements.size())) { // Out of bounds access, this is illegal IR. Notice that OpAccessChain // indexing is 0-based, so we should also reject index == size-of-array. return false; } else { const Instruction* var = replacements[static_cast(indexValue)]; if (chain->NumInOperands() > 2) { // Replace input access chain with another access chain. BasicBlock::iterator chainIter(chain); uint32_t replacementId = TakeNextId(); if (replacementId == 0) { return false; } std::unique_ptr replacementChain(new Instruction( context(), chain->opcode(), chain->type_id(), replacementId, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); // Add the remaining indexes. for (uint32_t i = 2; i < chain->NumInOperands(); ++i) { Operand copy(chain->GetInOperand(i)); replacementChain->AddOperand(std::move(copy)); } replacementChain->UpdateDebugInfoFrom(chain); auto iter = chainIter.InsertBefore(std::move(replacementChain)); get_def_use_mgr()->AnalyzeInstDefUse(&*iter); context()->set_instr_block(&*iter, context()->get_instr_block(chain)); context()->ReplaceAllUsesWith(chain->result_id(), replacementId); } else { // Replace with a use of the variable. context()->ReplaceAllUsesWith(chain->result_id(), var->result_id()); } } return true; } bool ScalarReplacementPass::CreateReplacementVariables( Instruction* inst, std::vector* replacements) { Instruction* type = GetStorageType(inst); std::unique_ptr> components_used = GetUsedComponents(inst); uint32_t elem = 0; switch (type->opcode()) { case SpvOpTypeStruct: type->ForEachInOperand( [this, inst, &elem, replacements, &components_used](uint32_t* id) { if (!components_used || components_used->count(elem)) { CreateVariable(*id, inst, elem, replacements); } else { replacements->push_back(GetUndef(*id)); } elem++; }); break; case SpvOpTypeArray: for (uint32_t i = 0; i != GetArrayLength(type); ++i) { if (!components_used || components_used->count(i)) { CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); } else { uint32_t element_type_id = type->GetSingleWordInOperand(0); replacements->push_back(GetUndef(element_type_id)); } } break; case SpvOpTypeMatrix: case SpvOpTypeVector: for (uint32_t i = 0; i != GetNumElements(type); ++i) { CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); } break; default: assert(false && "Unexpected type."); break; } TransferAnnotations(inst, replacements); return std::find(replacements->begin(), replacements->end(), nullptr) == replacements->end(); } Instruction* ScalarReplacementPass::GetUndef(uint32_t type_id) { return get_def_use_mgr()->GetDef(Type2Undef(type_id)); } void ScalarReplacementPass::TransferAnnotations( const Instruction* source, std::vector* replacements) { // Only transfer invariant and restrict decorations on the variable. There are // no type or member decorations that are necessary to transfer. for (auto inst : get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) { assert(inst->opcode() == SpvOpDecorate); uint32_t decoration = inst->GetSingleWordInOperand(1u); if (decoration == SpvDecorationInvariant || decoration == SpvDecorationRestrict) { for (auto var : *replacements) { if (var == nullptr) { continue; } std::unique_ptr annotation( new Instruction(context(), SpvOpDecorate, 0, 0, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {var->result_id()}}, {SPV_OPERAND_TYPE_DECORATION, {decoration}}})); for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { Operand copy(inst->GetInOperand(i)); annotation->AddOperand(std::move(copy)); } context()->AddAnnotationInst(std::move(annotation)); get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end()); } } } } void ScalarReplacementPass::CreateVariable( uint32_t typeId, Instruction* varInst, uint32_t index, std::vector* replacements) { uint32_t ptrId = GetOrCreatePointerType(typeId); uint32_t id = TakeNextId(); if (id == 0) { replacements->push_back(nullptr); } std::unique_ptr variable(new Instruction( context(), SpvOpVariable, ptrId, id, std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); BasicBlock* block = context()->get_instr_block(varInst); block->begin().InsertBefore(std::move(variable)); Instruction* inst = &*block->begin(); // If varInst was initialized, make sure to initialize its replacement. GetOrCreateInitialValue(varInst, index, inst); get_def_use_mgr()->AnalyzeInstDefUse(inst); context()->set_instr_block(inst, block); // Copy decorations from the member to the new variable. Instruction* typeInst = GetStorageType(varInst); for (auto dec_inst : get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { uint32_t decoration; if (dec_inst->opcode() != SpvOpMemberDecorate) { continue; } if (dec_inst->GetSingleWordInOperand(1) != index) { continue; } decoration = dec_inst->GetSingleWordInOperand(2u); switch (decoration) { case SpvDecorationRelaxedPrecision: { std::unique_ptr new_dec_inst( new Instruction(context(), SpvOpDecorate, 0, 0, {})); new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id})); for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) { new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i))); } context()->AddAnnotationInst(std::move(new_dec_inst)); } break; default: break; } } // Update the DebugInfo debug information. inst->UpdateDebugInfoFrom(varInst); replacements->push_back(inst); } uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) { auto iter = pointee_to_pointer_.find(id); if (iter != pointee_to_pointer_.end()) return iter->second; analysis::Type* pointeeTy; std::unique_ptr pointerTy; std::tie(pointeeTy, pointerTy) = context()->get_type_mgr()->GetTypeAndPointerType(id, SpvStorageClassFunction); uint32_t ptrId = 0; if (pointeeTy->IsUniqueType()) { // Non-ambiguous type, just ask the type manager for an id. ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get()); pointee_to_pointer_[id] = ptrId; return ptrId; } // Ambiguous type. We must perform a linear search to try and find the right // type. for (auto global : context()->types_values()) { if (global.opcode() == SpvOpTypePointer && global.GetSingleWordInOperand(0u) == SpvStorageClassFunction && global.GetSingleWordInOperand(1u) == id) { if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) { // Only reuse a decoration-less pointer of the correct type. ptrId = global.result_id(); break; } } } if (ptrId != 0) { pointee_to_pointer_[id] = ptrId; return ptrId; } ptrId = TakeNextId(); context()->AddType(MakeUnique( context(), SpvOpTypePointer, 0, ptrId, std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, {SPV_OPERAND_TYPE_ID, {id}}})); Instruction* ptr = &*--context()->types_values_end(); get_def_use_mgr()->AnalyzeInstDefUse(ptr); pointee_to_pointer_[id] = ptrId; // Register with the type manager if necessary. context()->get_type_mgr()->RegisterType(ptrId, *pointerTy); return ptrId; } void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source, uint32_t index, Instruction* newVar) { assert(source->opcode() == SpvOpVariable); if (source->NumInOperands() < 2) return; uint32_t initId = source->GetSingleWordInOperand(1u); uint32_t storageId = GetStorageType(newVar)->result_id(); Instruction* init = get_def_use_mgr()->GetDef(initId); uint32_t newInitId = 0; // TODO(dnovillo): Refactor this with constant propagation. if (init->opcode() == SpvOpConstantNull) { // Initialize to appropriate NULL. auto iter = type_to_null_.find(storageId); if (iter == type_to_null_.end()) { newInitId = TakeNextId(); type_to_null_[storageId] = newInitId; context()->AddGlobalValue( MakeUnique(context(), SpvOpConstantNull, storageId, newInitId, std::initializer_list{})); Instruction* newNull = &*--context()->types_values_end(); get_def_use_mgr()->AnalyzeInstDefUse(newNull); } else { newInitId = iter->second; } } else if (IsSpecConstantInst(init->opcode())) { // Create a new constant extract. newInitId = TakeNextId(); context()->AddGlobalValue(MakeUnique( context(), SpvOpSpecConstantOp, storageId, newInitId, std::initializer_list{ {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}}, {SPV_OPERAND_TYPE_ID, {init->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}})); Instruction* newSpecConst = &*--context()->types_values_end(); get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst); } else if (init->opcode() == SpvOpConstantComposite) { // Get the appropriate index constant. newInitId = init->GetSingleWordInOperand(index); Instruction* element = get_def_use_mgr()->GetDef(newInitId); if (element->opcode() == SpvOpUndef) { // Undef is not a valid initializer for a variable. newInitId = 0; } } else { assert(false); } if (newInitId != 0) { newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}}); } } uint64_t ScalarReplacementPass::GetArrayLength( const Instruction* arrayType) const { assert(arrayType->opcode() == SpvOpTypeArray); const Instruction* length = get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u)); return context() ->get_constant_mgr() ->GetConstantFromInst(length) ->GetZeroExtendedValue(); } uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const { assert(type->opcode() == SpvOpTypeVector || type->opcode() == SpvOpTypeMatrix); const Operand& op = type->GetInOperand(1u); assert(op.words.size() <= 2); uint64_t len = 0; for (size_t i = 0; i != op.words.size(); ++i) { len |= (static_cast(op.words[i]) << (32ull * i)); } return len; } bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const { const Instruction* inst = get_def_use_mgr()->GetDef(id); assert(inst); return spvOpcodeIsSpecConstant(inst->opcode()); } Instruction* ScalarReplacementPass::GetStorageType( const Instruction* inst) const { assert(inst->opcode() == SpvOpVariable); uint32_t ptrTypeId = inst->type_id(); uint32_t typeId = get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u); return get_def_use_mgr()->GetDef(typeId); } bool ScalarReplacementPass::CanReplaceVariable( const Instruction* varInst) const { assert(varInst->opcode() == SpvOpVariable); // Can only replace function scope variables. if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) { return false; } if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) { return false; } const Instruction* typeInst = GetStorageType(varInst); if (!CheckType(typeInst)) { return false; } if (!CheckAnnotations(varInst)) { return false; } if (!CheckUses(varInst)) { return false; } return true; } bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const { if (!CheckTypeAnnotations(typeInst)) { return false; } switch (typeInst->opcode()) { case SpvOpTypeStruct: // Don't bother with empty structs or very large structs. if (typeInst->NumInOperands() == 0 || IsLargerThanSizeLimit(typeInst->NumInOperands())) { return false; } return true; case SpvOpTypeArray: if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) { return false; } if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) { return false; } return true; // TODO(alanbaker): Develop some heuristics for when this should be // re-enabled. //// Specifically including matrix and vector in an attempt to reduce the //// number of vector registers required. // case SpvOpTypeMatrix: // case SpvOpTypeVector: // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false; // return true; case SpvOpTypeRuntimeArray: default: return false; } } bool ScalarReplacementPass::CheckTypeAnnotations( const Instruction* typeInst) const { for (auto inst : get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { uint32_t decoration; if (inst->opcode() == SpvOpDecorate) { decoration = inst->GetSingleWordInOperand(1u); } else { assert(inst->opcode() == SpvOpMemberDecorate); decoration = inst->GetSingleWordInOperand(2u); } switch (decoration) { case SpvDecorationRowMajor: case SpvDecorationColMajor: case SpvDecorationArrayStride: case SpvDecorationMatrixStride: case SpvDecorationCPacked: case SpvDecorationInvariant: case SpvDecorationRestrict: case SpvDecorationOffset: case SpvDecorationAlignment: case SpvDecorationAlignmentId: case SpvDecorationMaxByteOffset: case SpvDecorationRelaxedPrecision: break; default: return false; } } return true; } bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const { for (auto inst : get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) { assert(inst->opcode() == SpvOpDecorate); uint32_t decoration = inst->GetSingleWordInOperand(1u); switch (decoration) { case SpvDecorationInvariant: case SpvDecorationRestrict: case SpvDecorationAlignment: case SpvDecorationAlignmentId: case SpvDecorationMaxByteOffset: break; default: return false; } } return true; } bool ScalarReplacementPass::CheckUses(const Instruction* inst) const { VariableStats stats = {0, 0}; bool ok = CheckUses(inst, &stats); // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when // SRoA is costly, such as when the structure has many (unaccessed?) // members. return ok; } bool ScalarReplacementPass::CheckUses(const Instruction* inst, VariableStats* stats) const { uint64_t max_legal_index = GetMaxLegalIndex(inst); bool ok = true; get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok]( const Instruction* user, uint32_t index) { if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare || user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) { // TODO: include num_partial_accesses if it uses Fragment operation or // DebugValue has Indexes operand. stats->num_full_accesses++; return; } // Annotations are check as a group separately. if (!IsAnnotationInst(user->opcode())) { switch (user->opcode()) { case SpvOpAccessChain: case SpvOpInBoundsAccessChain: if (index == 2u && user->NumInOperands() > 1) { uint32_t id = user->GetSingleWordInOperand(1u); const Instruction* opInst = get_def_use_mgr()->GetDef(id); const auto* constant = context()->get_constant_mgr()->GetConstantFromInst(opInst); if (!constant) { ok = false; } else if (constant->GetZeroExtendedValue() >= max_legal_index) { ok = false; } else { if (!CheckUsesRelaxed(user)) ok = false; } stats->num_partial_accesses++; } else { ok = false; } break; case SpvOpLoad: if (!CheckLoad(user, index)) ok = false; stats->num_full_accesses++; break; case SpvOpStore: if (!CheckStore(user, index)) ok = false; stats->num_full_accesses++; break; case SpvOpName: case SpvOpMemberName: break; default: ok = false; break; } } }); return ok; } bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const { bool ok = true; get_def_use_mgr()->ForEachUse( inst, [this, &ok](const Instruction* user, uint32_t index) { switch (user->opcode()) { case SpvOpAccessChain: case SpvOpInBoundsAccessChain: if (index != 2u) { ok = false; } else { if (!CheckUsesRelaxed(user)) ok = false; } break; case SpvOpLoad: if (!CheckLoad(user, index)) ok = false; break; case SpvOpStore: if (!CheckStore(user, index)) ok = false; break; case SpvOpImageTexelPointer: if (!CheckImageTexelPointer(index)) ok = false; break; case SpvOpExtInst: if (user->GetCommonDebugOpcode() != CommonDebugInfoDebugDeclare || !CheckDebugDeclare(index)) ok = false; break; default: ok = false; break; } }); return ok; } bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const { return index == 2u; } bool ScalarReplacementPass::CheckLoad(const Instruction* inst, uint32_t index) const { if (index != 2u) return false; if (inst->NumInOperands() >= 2 && inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask) return false; return true; } bool ScalarReplacementPass::CheckStore(const Instruction* inst, uint32_t index) const { if (index != 0u) return false; if (inst->NumInOperands() >= 3 && inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask) return false; return true; } bool ScalarReplacementPass::CheckDebugDeclare(uint32_t index) const { if (index != kDebugDeclareOperandVariableIndex) return false; return true; } bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const { if (max_num_elements_ == 0) { return false; } return length > max_num_elements_; } std::unique_ptr> ScalarReplacementPass::GetUsedComponents(Instruction* inst) { std::unique_ptr> result( new std::unordered_set()); analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr, this](Instruction* use) { switch (use->opcode()) { case SpvOpLoad: { // Look for extract from the load. std::vector t; if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) { if (use2->opcode() != SpvOpCompositeExtract || use2->NumInOperands() <= 1) { return false; } t.push_back(use2->GetSingleWordInOperand(1)); return true; })) { result->insert(t.begin(), t.end()); return true; } else { result.reset(nullptr); return false; } } case SpvOpName: case SpvOpMemberName: case SpvOpStore: // No components are used. return true; case SpvOpAccessChain: case SpvOpInBoundsAccessChain: { // Add the first index it if is a constant. // TODO: Could be improved by checking if the address is used in a load. analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); uint32_t index_id = use->GetSingleWordInOperand(1); const analysis::Constant* index_const = const_mgr->FindDeclaredConstant(index_id); if (index_const) { result->insert(index_const->GetSignExtendedValue()); return true; } else { // Could be any element. Assuming all are used. result.reset(nullptr); return false; } } default: // We do not know what is happening. Have to assume the worst. result.reset(nullptr); return false; } }); return result; } uint64_t ScalarReplacementPass::GetMaxLegalIndex( const Instruction* var_inst) const { assert(var_inst->opcode() == SpvOpVariable && "|var_inst| must be a variable instruction."); Instruction* type = GetStorageType(var_inst); switch (type->opcode()) { case SpvOpTypeStruct: return type->NumInOperands(); case SpvOpTypeArray: return GetArrayLength(type); case SpvOpTypeMatrix: case SpvOpTypeVector: return GetNumElements(type); default: return 0; } return 0; } } // namespace opt } // namespace spvtools