SPIRV-Tools/source/opt/copy_prop_arrays.cpp
Steven Perron 93c4c184d5 Handle types with self references.
By using forward pointers, we are able to define a struct that has a
pointer to itself.  This could be directly or indirectly.  The current
implementation of the type manager did not handle this case.  There are
three changes that are made in this commit in order to handle this case:

1) Change the handling of OpTypeForwardPointer

The current handling of OpTypeForwardPointer is broken if there is a
reference to the pointer before the real definition.  When building the
type that contains the forward-declared pointer, the type manager will ask
for the type for that ID, and will get a nullptr because it does not
exist yet.  This nullptr is not handled very well.

The change is to keep track of the incomplete types the first time
through all of the types.  An incomplete type is a ForwardPointer or any
type that references an incomplete type.

Then we implement a second pass through the incomplete types that will
complete them.

2) Hashing types.

When hashing a type, we want to use all of the subtypes as part of the
hash.  However, with types that reference themselves, this creates an
infinite recursion.  To get around this, we keep track of which types
have been seen on the path from the root type.  If we have seen the
current type already then we can stop the recursion.

3) Comparing types.

In order to check if two types are the same, we must check that all of
their subtypes are the same as well.  This also causes an infinite
recursion.  The solution is to stop comparing the subtypes if we are
trying to compare two pointer types that we are already in the middle of
comparing.  The idea is that if the two pointers are different, then the
in-progress comparison will return false itself.

Fixes https://github.com/KhronosGroup/SPIRV-Tools/issues/1578.
2018-05-30 15:48:38 -04:00

849 lines
31 KiB
C++

// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "copy_prop_arrays.h"
#include "ir_builder.h"
namespace {

// In-operand positions (operand indices that skip the result type and result
// id) of the instruction operands this pass reads.
constexpr uint32_t kLoadPointerInOperand = 0;             // OpLoad: Pointer
constexpr uint32_t kStorePointerInOperand = 0;            // OpStore: Pointer
constexpr uint32_t kStoreObjectInOperand = 1;             // OpStore: Object
constexpr uint32_t kCompositeExtractObjectInOperand = 0;  // OpCompositeExtract: Composite

}  // namespace
namespace spvtools {
namespace opt {
// Runs the pass on every function of the module.
//
// For each OpVariable at the start of a function's entry block whose type is
// accepted by |IsPointerToArrayType|, find the single store to the whole
// variable.  If the stored value can be traced back to another memory object,
// and every use of the variable can be rewritten, replace the variable with a
// pointer to that memory object.
Pass::Status CopyPropagateArrays::Process(ir::IRContext* ctx) {
  InitializeProcessing(ctx);

  bool modified = false;
  for (ir::Function& function : *get_module()) {
    // A function declaration (e.g. an imported function) has no basic blocks,
    // so there is no entry block; dereferencing |begin()| would be undefined.
    if (function.begin() == function.end()) {
      continue;
    }

    ir::BasicBlock* entry_bb = &*function.begin();

    // Only the leading run of OpVariable instructions is inspected; the loop
    // stops at the first instruction that is not an OpVariable.
    for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
         ++var_inst) {
      if (!IsPointerToArrayType(var_inst->type_id())) {
        continue;
      }

      // Find the only store to the entire memory location, if it exists.
      ir::Instruction* store_inst = FindStoreInstruction(&*var_inst);

      if (!store_inst) {
        continue;
      }

      std::unique_ptr<MemoryObject> source_object =
          FindSourceObjectIfPossible(&*var_inst, store_inst);

      if (source_object != nullptr) {
        if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId())) {
          modified = true;
          PropagateObject(&*var_inst, source_object.get(), store_inst);
        }
      }
    }
  }
  return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}
// Returns the memory object whose value |store_inst| copies into |var_inst|,
// or |nullptr| when the copy cannot be propagated.
//
// A copy is propagated only when all of the following hold:
//  * |store_inst| exists, and every reference to |var_inst| is acceptable
//    (see |HasValidReferencesOnly|);
//  * the stored value can be traced back to a memory object;
//  * the variable owning that memory object is never stored to.
std::unique_ptr<CopyPropagateArrays::MemoryObject>
CopyPropagateArrays::FindSourceObjectIfPossible(ir::Instruction* var_inst,
                                                ir::Instruction* store_inst) {
  assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");

  if (store_inst == nullptr || !HasValidReferencesOnly(var_inst, store_inst)) {
    return nullptr;
  }

  const uint32_t stored_value_id =
      store_inst->GetSingleWordInOperand(kStoreObjectInOperand);
  std::unique_ptr<MemoryObject> copied_from =
      GetSourceObjectIfAny(stored_value_id);

  // The source must not change between its original load and any later load
  // of |var_inst|.  This is checked conservatively: the whole source variable
  // (not just the copied component) must never be written.
  if (copied_from && HasNoStores(copied_from->GetVariable())) {
    return copied_from;
  }
  return nullptr;
}
// Returns the OpStore whose Pointer operand is |var_inst| itself, provided
// there is exactly one such store.  Returns |nullptr| when there is none, or
// when a second one is found.
ir::Instruction* CopyPropagateArrays::FindStoreInstruction(
    const ir::Instruction* var_inst) const {
  ir::Instruction* unique_store = nullptr;
  const uint32_t var_id = var_inst->result_id();

  get_def_use_mgr()->WhileEachUser(
      var_inst, [&unique_store, var_id](ir::Instruction* user) {
        const bool stores_to_var =
            user->opcode() == SpvOpStore &&
            user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
        if (!stores_to_var) {
          return true;
        }
        if (unique_store != nullptr) {
          // Second store to the variable: there is no single store.
          unique_store = nullptr;
          return false;
        }
        unique_store = user;
        return true;
      });

  return unique_store;
}
// Redirects every use of the variable |var_inst| to the memory object
// |source|.  The replacement pointer is built just before |insertion_point|
// (see |BuildNewAccessChain|), and names/decorations of |var_inst| are
// removed.
void CopyPropagateArrays::PropagateObject(ir::Instruction* var_inst,
                                          MemoryObject* source,
                                          ir::Instruction* insertion_point) {
  assert(var_inst->opcode() == SpvOpVariable &&
         "This function propagates variables.");

  ir::Instruction* replacement_ptr =
      BuildNewAccessChain(insertion_point, source);

  context()->KillNamesAndDecorates(var_inst);
  UpdateUses(var_inst, replacement_ptr);
}
// Returns an instruction producing a pointer to |source|.
//
// If |source| has an empty access chain, its variable is returned directly.
// Otherwise a new OpAccessChain is inserted before |insertion_point|.
ir::Instruction* CopyPropagateArrays::BuildNewAccessChain(
    ir::Instruction* insertion_point,
    CopyPropagateArrays::MemoryObject* source) const {
  InstructionBuilder builder(context(), insertion_point,
                             ir::IRContext::kAnalysisDefUse |
                                 ir::IRContext::kAnalysisInstrToBlockMapping);

  ir::Instruction* base_variable = source->GetVariable();
  if (source->AccessChain().empty()) {
    // The memory object is the whole variable; no OpAccessChain is needed.
    return base_variable;
  }

  return builder.AddAccessChain(source->GetPointerTypeId(),
                                base_variable->result_id(),
                                source->AccessChain());
}
// Returns true if neither |ptr_inst| nor any OpAccessChain derived from it is
// used by an OpStore.
//
// Loads, OpImageTexelPointer, OpName and decorations are harmless.  Any other
// user is treated conservatively as a potential write, so the function
// returns false.
bool CopyPropagateArrays::HasNoStores(ir::Instruction* ptr_inst) {
  return get_def_use_mgr()->WhileEachUser(
      ptr_inst, [this](ir::Instruction* user) {
        switch (user->opcode()) {
          case SpvOpLoad:
          case SpvOpImageTexelPointer:
          case SpvOpName:
            return true;
          case SpvOpAccessChain:
            return HasNoStores(user);
          case SpvOpStore:
            return false;
          default:
            return user->IsDecoration();
        }
      });
}
// Returns true if every (transitive) user of |ptr_inst| allows the value
// written by |store_inst| to be propagated.  The rules per user are:
//  * OpLoad / OpImageTexelPointer: must be dominated by |store_inst|;
//  * OpAccessChain: its own users are checked recursively;
//  * OpName and decorations: ignored;
//  * OpStore: accepted only when |ptr_inst| is an OpVariable and
//    |store_inst| writes to it (stores through an access chain are
//    rejected);
//  * anything else: rejected conservatively.
bool CopyPropagateArrays::HasValidReferencesOnly(ir::Instruction* ptr_inst,
                                                 ir::Instruction* store_inst) {
  ir::BasicBlock* block_of_store = context()->get_instr_block(store_inst);
  opt::DominatorAnalysis* dom_tree =
      context()->GetDominatorAnalysis(block_of_store->GetParent());

  return get_def_use_mgr()->WhileEachUser(
      ptr_inst,
      [this, store_inst, dom_tree, ptr_inst](ir::Instruction* user) {
        switch (user->opcode()) {
          case SpvOpLoad:
          case SpvOpImageTexelPointer:
            // TODO: If there are many loads in the same BB as |store_inst|,
            // the time to do the multiple traversals can add up.  Consider
            // collecting those loads and doing a single traversal.
            return dom_tree->Dominates(store_inst, user);
          case SpvOpAccessChain:
            return HasValidReferencesOnly(user, store_inst);
          case SpvOpName:
            return true;
          case SpvOpStore:
            // Storing to only part of the object disqualifies it.
            return ptr_inst->opcode() == SpvOpVariable &&
                   store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
                       ptr_inst->result_id();
          default:
            return user->IsDecoration();
        }
      });
}
std::unique_ptr<CopyPropagateArrays::MemoryObject>
CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
ir::Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
switch (result_inst->opcode()) {
case SpvOpLoad:
return BuildMemoryObjectFromLoad(result_inst);
case SpvOpCompositeExtract:
return BuildMemoryObjectFromExtract(result_inst);
case SpvOpCompositeConstruct:
return BuildMemoryObjectFromCompositeConstruct(result_inst);
case SpvOpCopyObject:
return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
case SpvOpCompositeInsert:
return BuildMemoryObjectFromInsert(result_inst);
default:
return nullptr;
}
}
// Builds the memory object read by |load_inst|: the OpVariable at the root of
// the loaded pointer, plus the index ids of every OpAccessChain between that
// variable and the load.
//
// Index ids are recorded as they appear, whether they are constants or not.
// Returns |nullptr| if the pointer is not rooted at an OpVariable through
// OpAccessChain instructions only.
std::unique_ptr<CopyPropagateArrays::MemoryObject>
CopyPropagateArrays::BuildMemoryObjectFromLoad(ir::Instruction* load_inst) {
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();

  // The chains are visited from the outermost (applied last) to the innermost,
  // so indices are collected in reverse: each chain's indices go in
  // last-to-first, and chains closer to the variable are appended later.
  std::vector<uint32_t> reversed_indices;
  ir::Instruction* base = def_use_mgr->GetDef(
      load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
  while (base->opcode() == SpvOpAccessChain) {
    // In-operand 0 is the base pointer; in-operands 1.. are the indices.
    for (uint32_t operand = base->NumInOperands(); operand > 1; --operand) {
      reversed_indices.push_back(base->GetSingleWordInOperand(operand - 1));
    }
    base = def_use_mgr->GetDef(base->GetSingleWordInOperand(0));
  }

  // Without an OpVariable root, the owner and access path cannot be
  // identified exactly.
  if (base->opcode() != SpvOpVariable) {
    return nullptr;
  }

  // Reverse iterators put the indices back into application order.
  return std::unique_ptr<CopyPropagateArrays::MemoryObject>(new MemoryObject(
      base, reversed_indices.rbegin(), reversed_indices.rend()));
}
// Builds the memory object read by |extract_inst|: the memory object of its
// Composite operand, narrowed by the extract's literal indices.  Returns
// |nullptr| if the composite cannot be traced to a memory object.
std::unique_ptr<CopyPropagateArrays::MemoryObject>
CopyPropagateArrays::BuildMemoryObjectFromExtract(
    ir::Instruction* extract_inst) {
  assert(extract_inst->opcode() == SpvOpCompositeExtract &&
         "Expecting an OpCompositeExtract instruction.");
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();

  std::unique_ptr<MemoryObject> composite = GetSourceObjectIfAny(
      extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
  if (!composite) {
    return nullptr;
  }

  // OpCompositeExtract uses literal indices while an access chain needs ids,
  // so every literal is materialized as a 32-bit unsigned integer constant.
  analysis::Integer uint32_ty(32, false);
  const analysis::Type* registered_uint32 =
      context()->get_type_mgr()->GetRegisteredType(&uint32_ty);

  std::vector<uint32_t> index_ids;
  const uint32_t operand_count = extract_inst->NumInOperands();
  for (uint32_t operand = 1; operand < operand_count; ++operand) {
    const uint32_t literal = extract_inst->GetSingleWordInOperand(operand);
    const analysis::Constant* as_constant =
        const_mgr->GetConstant(registered_uint32, {literal});
    index_ids.push_back(
        const_mgr->GetDefiningInstruction(as_constant)->result_id());
  }

  composite->GetMember(index_ids);
  return composite;
}
// Returns the memory object rebuilt by |conststruct_inst|, or |nullptr|.
//
// The construct is recognized as a copy of a memory object |M| when:
//  * operand |i| is a copy of the direct member |M[i]| for every |i|, and
//  * the number of operands equals the number of members of |M|.
std::unique_ptr<CopyPropagateArrays::MemoryObject>
CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
    ir::Instruction* conststruct_inst) {
  assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
         "Expecting an OpCompositeConstruct instruction.");

  // If every operand in the instruction is part of the same memory object, and
  // they are combined in the same order, then the result is the same as the
  // parent.  Operand 0 identifies the candidate parent: it must be element 0
  // of some memory object.
  std::unique_ptr<MemoryObject> memory_object =
      GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
  if (!memory_object) {
    return nullptr;
  }

  if (!memory_object->IsMember()) {
    return nullptr;
  }

  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  const analysis::Constant* last_access =
      const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
  if (!last_access ||
      (!last_access->AsIntConstant() && !last_access->AsNullConstant())) {
    return nullptr;
  }

  if (last_access->GetU32() != 0) {
    return nullptr;
  }

  // Turn the element-0 object into its parent.  The size check inside the
  // loop assumes this drops exactly one trailing index (the same assumption
  // |BuildMemoryObjectFromInsert| makes).
  memory_object->GetParent();

  if (memory_object->GetNumberOfMembers() !=
      conststruct_inst->NumInOperands()) {
    return nullptr;
  }

  for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
    std::unique_ptr<MemoryObject> member_object =
        GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));

    // The operand may not be traceable to any memory object at all.
    if (!member_object) {
      return nullptr;
    }

    if (!member_object->IsMember()) {
      return nullptr;
    }

    if (!memory_object->Contains(member_object.get())) {
      return nullptr;
    }

    // The operand must be a *direct* member of the parent.  A deeper member
    // (e.g. |M[0][1]|) shares the prefix and can end in index |i| without
    // being |M[i]|.
    if (member_object->AccessChain().size() !=
        memory_object->AccessChain().size() + 1) {
      return nullptr;
    }

    last_access =
        const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
    if (!last_access || !last_access->AsIntConstant()) {
      return nullptr;
    }

    if (last_access->GetU32() != i) {
      return nullptr;
    }
  }
  return memory_object;
}
// Returns the memory object rebuilt by |insert_inst|, or |nullptr|.
//
// Recognizes a chain of single-index OpCompositeInsert instructions with these
// properties:
//  * Starting at |insert_inst| and following the Composite operand, the chain
//    writes indices n-1, n-2, ..., 0 of an n-element result type.
//  * The object inserted at index |i| is the direct member |M[i]| of one
//    memory object |M|, and |M| has exactly n members.
// Every element of the result is then overwritten, so the result equals |M|.
std::unique_ptr<CopyPropagateArrays::MemoryObject>
CopyPropagateArrays::BuildMemoryObjectFromInsert(ir::Instruction* insert_inst) {
  assert(insert_inst->opcode() == SpvOpCompositeInsert &&
         "Expecting an OpCompositeInsert instruction.");

  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());

  uint32_t number_of_elements = 0;
  if (const analysis::Struct* struct_type = result_type->AsStruct()) {
    number_of_elements =
        static_cast<uint32_t>(struct_type->element_types().size());
  } else if (const analysis::Array* array_type = result_type->AsArray()) {
    const analysis::Constant* length_const =
        const_mgr->FindDeclaredConstant(array_type->LengthId());
    // If the length is not an integer constant known to the constant manager,
    // leave |number_of_elements| at 0 so the insert is rejected below.
    if (length_const && length_const->AsIntConstant()) {
      number_of_elements = length_const->AsIntConstant()->GetU32();
    }
  } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
    number_of_elements = vector_type->element_count();
  } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
    number_of_elements = matrix_type->element_count();
  }

  if (number_of_elements == 0) {
    return nullptr;
  }

  // In-operands: Object, Composite, and exactly one literal index.
  if (insert_inst->NumInOperands() != 3) {
    return nullptr;
  }

  if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
    return nullptr;
  }

  std::unique_ptr<MemoryObject> memory_object =
      GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));

  if (!memory_object) {
    return nullptr;
  }

  if (!memory_object->IsMember()) {
    return nullptr;
  }

  const analysis::Constant* last_access =
      const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
  if (!last_access || !last_access->AsIntConstant()) {
    return nullptr;
  }

  if (last_access->GetU32() != number_of_elements - 1) {
    return nullptr;
  }

  // Turn the last-element object into its parent.  The size check in the loop
  // assumes this drops exactly one trailing index.
  memory_object->GetParent();

  // The parent must have exactly as many members as the result has elements.
  // Otherwise the inserts rebuild only part of it, and the two are not the
  // same value.
  if (memory_object->GetNumberOfMembers() != number_of_elements) {
    return nullptr;
  }

  ir::Instruction* current_insert =
      def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
  for (uint32_t i = number_of_elements - 1; i > 0; --i) {
    if (current_insert->opcode() != SpvOpCompositeInsert) {
      return nullptr;
    }

    if (current_insert->NumInOperands() != 3) {
      return nullptr;
    }

    if (current_insert->GetSingleWordInOperand(2) != i - 1) {
      return nullptr;
    }

    std::unique_ptr<MemoryObject> current_memory_object =
        GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));

    if (!current_memory_object) {
      return nullptr;
    }

    if (!current_memory_object->IsMember()) {
      return nullptr;
    }

    // Must be a direct member of the parent, not a deeper one.
    if (memory_object->AccessChain().size() + 1 !=
        current_memory_object->AccessChain().size()) {
      return nullptr;
    }

    if (!memory_object->Contains(current_memory_object.get())) {
      return nullptr;
    }

    const analysis::Constant* current_last_access =
        const_mgr->FindDeclaredConstant(
            current_memory_object->AccessChain().back());
    if (!current_last_access || !current_last_access->AsIntConstant()) {
      return nullptr;
    }

    if (current_last_access->GetU32() != i - 1) {
      return nullptr;
    }
    current_insert =
        def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
  }

  return memory_object;
}
bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
if (pointer_type) {
return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
pointer_type->pointee_type()->kind() == analysis::Type::kImage;
}
return false;
}
// Returns true if every (transitive) use of |original_ptr_inst| can be
// rewritten so that it operates on a value of type |type_id|.
//
// Rules per use:
//  * Loads, access chains and composite extracts whose result type would
//    change are checked recursively with that new result type.
//  * Stores are accepted: |UpdateUses| leaves a store through the original
//    pointer alone, and wraps a stored object in a copy of the required type.
//  * OpImageTexelPointer, OpName and decorations are accepted.
//  * Anything else is rejected.
// A runtime array type is never rewritable.
bool CopyPropagateArrays::CanUpdateUses(ir::Instruction* original_ptr_inst,
                                        uint32_t type_id) {
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();

  analysis::Type* type = type_mgr->GetType(type_id);
  if (type->AsRuntimeArray()) {
    return false;
  }

  if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
    // If the type is not an aggregate, then the desired type must be the
    // same as the current type.  No work to do, and we can do that.
    return true;
  }

  return def_use_mgr->WhileEachUse(
      original_ptr_inst,
      [this, type_mgr, const_mgr, type](ir::Instruction* use, uint32_t) {
        switch (use->opcode()) {
          case SpvOpLoad: {
            analysis::Pointer* pointer_type = type->AsPointer();
            uint32_t new_type_id =
                type_mgr->GetId(pointer_type->pointee_type());

            if (new_type_id != use->type_id()) {
              return CanUpdateUses(use, new_type_id);
            }
            return true;
          }
          case SpvOpAccessChain: {
            analysis::Pointer* pointer_type = type->AsPointer();
            const analysis::Type* pointee_type = pointer_type->pointee_type();

            std::vector<uint32_t> access_chain;
            for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
              const analysis::Constant* index_const =
                  const_mgr->FindDeclaredConstant(
                      use->GetSingleWordInOperand(i));
              if (index_const && index_const->AsIntConstant()) {
                access_chain.push_back(index_const->AsIntConstant()->GetU32());
              } else {
                // Either a variable index, or a constant that is not an
                // integer constant (e.g. a null constant, whose value is 0).
                // A variable index implies every element has the same type,
                // so element 0 gives the right type in both cases.
                access_chain.push_back(0);
              }
            }

            const analysis::Type* new_pointee_type =
                type_mgr->GetMemberType(pointee_type, access_chain);
            opt::analysis::Pointer pointerTy(new_pointee_type,
                                             pointer_type->storage_class());
            uint32_t new_pointer_type_id =
                context()->get_type_mgr()->GetTypeInstruction(&pointerTy);

            if (new_pointer_type_id != use->type_id()) {
              return CanUpdateUses(use, new_pointer_type_id);
            }
            return true;
          }
          case SpvOpCompositeExtract: {
            std::vector<uint32_t> access_chain;
            for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
              access_chain.push_back(use->GetSingleWordInOperand(i));
            }

            const analysis::Type* new_type =
                type_mgr->GetMemberType(type, access_chain);
            uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);

            if (new_type_id != use->type_id()) {
              return CanUpdateUses(use, new_type_id);
            }
            return true;
          }
          case SpvOpStore:
            // The store itself never has to change type.  |UpdateUses| leaves
            // a store through the original pointer alone (it is the single
            // store to the variable and becomes dead).  When the value is the
            // stored object, it is replaced with a copy converted to the
            // pointer's pointee type (see |GenerateCopy|).
            return true;
          case SpvOpImageTexelPointer:
          case SpvOpName:
            return true;
          default:
            return use->IsDecoration();
        }
      });
}
// Rewrites every use of |original_ptr_inst| to use |new_ptr_inst| instead,
// updating result types along the way.
//
// When a rewritten load, access chain or composite extract ends up with a
// different result type, its own uses are rewritten recursively; in that call
// the instruction stands in for both the original and the new value.  For
// OpStore:
//  * a use as the stored object is replaced by a copy converted to the type
//    the store's pointer expects;
//  * a use as the store pointer is left alone.
void CopyPropagateArrays::UpdateUses(ir::Instruction* original_ptr_inst,
                                     ir::Instruction* new_ptr_inst) {
  // TODO (s-perron): Keep the def-use manager up to date.  Not done now because
  // it can cause problems for the |ForEachUse| traversals.  Can be done by
  // keeping a list of instructions that need updating, and then updating them
  // in |PropagateObject|.

  analysis::TypeManager* type_mgr = context()->get_type_mgr();
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();

  // Snapshot the uses first: the loop below forgets and re-analyzes uses, which
  // changes the def-use information being traversed.
  std::vector<std::pair<ir::Instruction*, uint32_t> > uses;
  def_use_mgr->ForEachUse(original_ptr_inst,
                          [&uses](ir::Instruction* use, uint32_t index) {
                            uses.push_back({use, index});
                          });

  for (const auto& use_and_index : uses) {
    ir::Instruction* use = use_and_index.first;
    uint32_t index = use_and_index.second;
    analysis::Pointer* pointer_type = nullptr;
    switch (use->opcode()) {
      case SpvOpLoad: {
        // Replace the actual use.
        context()->ForgetUses(use);
        use->SetOperand(index, {new_ptr_inst->result_id()});

        // Update the type.
        pointer_type = type_mgr->GetType(new_ptr_inst->type_id())->AsPointer();
        uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
        if (new_type_id != use->type_id()) {
          use->SetResultType(new_type_id);
          context()->AnalyzeUses(use);
          UpdateUses(use, use);
        } else {
          context()->AnalyzeUses(use);
        }
      } break;
      case SpvOpAccessChain: {
        // Update the actual use.
        context()->ForgetUses(use);
        use->SetOperand(index, {new_ptr_inst->result_id()});

        // Update the result type.
        pointer_type = type_mgr->GetType(new_ptr_inst->type_id())->AsPointer();
        const analysis::Type* pointee_type = pointer_type->pointee_type();

        // Convert the ids on the OpAccessChain to indices that can be used to
        // get the specific member.
        std::vector<uint32_t> access_chain;
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
          const analysis::Constant* index_const =
              const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
          if (index_const && index_const->AsIntConstant()) {
            access_chain.push_back(index_const->AsIntConstant()->GetU32());
          } else {
            // Either a variable index, or a constant that is not an integer
            // constant (e.g. a null constant, whose value is 0).  A variable
            // index implies every element has the same type, so element 0
            // gives the right type in both cases.
            access_chain.push_back(0);
          }
        }
        const analysis::Type* new_pointee_type =
            type_mgr->GetMemberType(pointee_type, access_chain);

        // Now build a pointer to the type of the member.
        opt::analysis::Pointer new_pointer_type(new_pointee_type,
                                                pointer_type->storage_class());
        uint32_t new_pointer_type_id =
            context()->get_type_mgr()->GetTypeInstruction(&new_pointer_type);

        if (new_pointer_type_id != use->type_id()) {
          use->SetResultType(new_pointer_type_id);
          context()->AnalyzeUses(use);
          UpdateUses(use, use);
        } else {
          context()->AnalyzeUses(use);
        }
      } break;
      case SpvOpCompositeExtract: {
        // Update the actual use.
        context()->ForgetUses(use);
        use->SetOperand(index, {new_ptr_inst->result_id()});

        std::vector<uint32_t> access_chain;
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
          access_chain.push_back(use->GetSingleWordInOperand(i));
        }

        const analysis::Type* type = type_mgr->GetType(new_ptr_inst->type_id());
        const analysis::Type* new_type =
            type_mgr->GetMemberType(type, access_chain);
        uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);

        if (new_type_id != use->type_id()) {
          use->SetResultType(new_type_id);
          context()->AnalyzeUses(use);
          UpdateUses(use, use);
        } else {
          context()->AnalyzeUses(use);
        }
      } break;
      case SpvOpStore:
        // If the use is the pointer, then it is the single store to that
        // variable.  We do not want to replace it.  Instead, it will become
        // dead after all of the loads are removed, and ADCE will get rid of it.
        //
        // If the use is the object being stored, we create a copy of the
        // object converted to the correct type.  The copy is made by
        // decomposing the object into its base types, which must be the same,
        // and then rebuilding them.
        if (index == 1) {
          ir::Instruction* target_pointer = def_use_mgr->GetDef(
              use->GetSingleWordInOperand(kStorePointerInOperand));
          pointer_type =
              type_mgr->GetType(target_pointer->type_id())->AsPointer();
          uint32_t copy =
              GenerateCopy(original_ptr_inst,
                           type_mgr->GetId(pointer_type->pointee_type()), use);

          context()->ForgetUses(use);
          use->SetInOperand(index, {copy});
          context()->AnalyzeUses(use);
        }
        break;
      case SpvOpImageTexelPointer:
        // We treat an OpImageTexelPointer as a load.  The result type should
        // always have the Image storage class, and should not need to be
        // updated.

        // Replace the actual use.
        context()->ForgetUses(use);
        use->SetOperand(index, {new_ptr_inst->result_id()});
        context()->AnalyzeUses(use);
        break;
      default:
        assert(false && "Don't know how to rewrite instruction");
        break;
    }
  }
}
// Returns the id of a value equal to |object_inst| but of type |new_type_id|.
//
// If the types already match, |object_inst|'s own id is returned.  Otherwise
// the object is taken apart element by element with OpCompositeExtract.  Each
// element is converted recursively, and the result is reassembled with
// OpCompositeConstruct.  New instructions are inserted before
// |insertion_position|.  Only arrays and structs can be converted.  Any other
// mismatch indicates illegal code: it asserts in debug builds and returns 0
// otherwise.
uint32_t CopyPropagateArrays::GenerateCopy(
    ir::Instruction* object_inst, uint32_t new_type_id,
    ir::Instruction* insertion_position) {
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();

  uint32_t original_type_id = object_inst->type_id();
  if (original_type_id == new_type_id) {
    return object_inst->result_id();
  }

  opt::InstructionBuilder ir_builder(
      context(), insertion_position,
      ir::IRContext::kAnalysisInstrToBlockMapping |
          ir::IRContext::kAnalysisDefUse);

  analysis::Type* original_type = type_mgr->GetType(original_type_id);
  analysis::Type* new_type = type_mgr->GetType(new_type_id);

  if (const analysis::Array* original_array_type = original_type->AsArray()) {
    uint32_t original_element_type_id =
        type_mgr->GetId(original_array_type->element_type());

    analysis::Array* new_array_type = new_type->AsArray();
    assert(new_array_type != nullptr && "Can't copy an array to a non-array.");
    uint32_t new_element_type_id =
        type_mgr->GetId(new_array_type->element_type());

    std::vector<uint32_t> element_ids;
    const analysis::Constant* length_const =
        const_mgr->FindDeclaredConstant(original_array_type->LengthId());
    assert(length_const && length_const->AsIntConstant() &&
           "Array length must be a known integer constant.");
    uint32_t array_length = length_const->AsIntConstant()->GetU32();
    for (uint32_t i = 0; i < array_length; i++) {
      ir::Instruction* extract = ir_builder.AddCompositeExtract(
          original_element_type_id, object_inst->result_id(), {i});
      element_ids.push_back(
          GenerateCopy(extract, new_element_type_id, insertion_position));
    }

    return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
        ->result_id();
  } else if (const analysis::Struct* original_struct_type =
                 original_type->AsStruct()) {
    analysis::Struct* new_struct_type = new_type->AsStruct();
    assert(new_struct_type != nullptr && "Can't copy a struct to a non-struct.");

    const std::vector<const analysis::Type*>& original_types =
        original_struct_type->element_types();
    const std::vector<const analysis::Type*>& new_types =
        new_struct_type->element_types();
    // |new_types| is indexed with every index of |original_types| below.
    assert(original_types.size() == new_types.size() &&
           "Can't copy between structs with different numbers of members.");

    std::vector<uint32_t> element_ids;
    for (uint32_t i = 0; i < original_types.size(); i++) {
      ir::Instruction* extract = ir_builder.AddCompositeExtract(
          type_mgr->GetId(original_types[i]), object_inst->result_id(), {i});
      element_ids.push_back(GenerateCopy(extract, type_mgr->GetId(new_types[i]),
                                         insertion_position));
    }

    return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
        ->result_id();
  } else {
    // If we do not have an aggregate type, then we have a problem.  Either we
    // found multiple instances of the same type, or we are copying to an
    // incompatible type.  Either way the code is illegal.
    assert(false &&
           "Don't know how to copy this type.  Code is likely illegal.");
  }
  return 0;
}
// Narrows this memory object to one of its members by appending the index ids
// in |access_chain| to the end of its own access chain.
void CopyPropagateArrays::MemoryObject::GetMember(
    const std::vector<uint32_t>& access_chain) {
  for (uint32_t index_id : access_chain) {
    access_chain_.push_back(index_id);
  }
}
// Returns how many members the type of this memory object has:
//  * struct: number of elements;
//  * array: length;
//  * vector / matrix: element count.
// Returns 0 for any other type, and for arrays whose length is not an integer
// constant known to the constant manager.  Callers compare the result against
// an expected count, so 0 makes them reject the object.
uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
  ir::IRContext* context = variable_inst_->context();
  analysis::TypeManager* type_mgr = context->get_type_mgr();

  const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
  type = type->AsPointer()->pointee_type();

  std::vector<uint32_t> access_indices = GetAccessIds();
  type = type_mgr->GetMemberType(type, access_indices);

  if (const analysis::Struct* struct_type = type->AsStruct()) {
    return static_cast<uint32_t>(struct_type->element_types().size());
  } else if (const analysis::Array* array_type = type->AsArray()) {
    const analysis::Constant* length_const =
        context->get_constant_mgr()->FindDeclaredConstant(
            array_type->LengthId());
    // Do not dereference an unknown length; report "unknown" instead.
    if (!length_const || !length_const->AsIntConstant()) {
      return 0;
    }
    return length_const->AsIntConstant()->GetU32();
  } else if (const analysis::Vector* vector_type = type->AsVector()) {
    return vector_type->element_count();
  } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
    return matrix_type->element_count();
  } else {
    return 0;
  }
}
// Creates a memory object rooted at the variable |var_inst| whose access chain
// is the sequence of index ids in [first, last).
template <class IndexIterator>
CopyPropagateArrays::MemoryObject::MemoryObject(ir::Instruction* var_inst,
                                                IndexIterator first,
                                                IndexIterator last)
    : variable_inst_(var_inst), access_chain_(first, last) {}
// Converts the index ids of the access chain into literal index values,
// suitable for |TypeManager::GetMemberType|.
//
// An integer constant index maps to its value.  Any other id maps to 0:
//  * a variable index — every element then has the same type, so element 0
//    stands in for all of them;
//  * a non-integer-constant constant such as a null constant, whose value is
//    0.
std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
  analysis::ConstantManager* const_mgr =
      variable_inst_->context()->get_constant_mgr();

  std::vector<uint32_t> access_indices;
  for (uint32_t id : AccessChain()) {
    const analysis::Constant* element_index_const =
        const_mgr->FindDeclaredConstant(id);
    if (element_index_const && element_index_const->AsIntConstant()) {
      access_indices.push_back(element_index_const->AsIntConstant()->GetU32());
    } else {
      access_indices.push_back(0);
    }
  }
  return access_indices;
}
// Returns true if |other| is this memory object or lies inside it.  That is
// the case when both share the same variable and this object's access chain
// is a prefix of |other|'s.
bool CopyPropagateArrays::MemoryObject::Contains(
    CopyPropagateArrays::MemoryObject* other) {
  if (GetVariable() != other->GetVariable()) {
    return false;
  }

  const auto& own_chain = AccessChain();
  const auto& other_chain = other->AccessChain();
  if (own_chain.size() > other_chain.size()) {
    return false;
  }

  for (uint32_t pos = 0; pos < own_chain.size(); ++pos) {
    if (own_chain[pos] != other_chain[pos]) {
      return false;
    }
  }
  return true;
}
} // namespace opt
} // namespace spvtools