From d8ca09821db1f1cf9ceab61300798861cad81512 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Tue, 27 Mar 2018 15:23:53 -0400 Subject: [PATCH] Handle non-constant accesses in memory objects (copy prop arrays) The first implementation of MemroyObject, which is used in copy propagate arrays, forced the access chain to be like the access chains in OpCompositeExtract. This excluded the possibility of the memory object from representing an array element that was extracted with a variable index. Looking at the code, that restriction is not neccessary. I also see some opportunities for doing this in some real shaders. Contributes to #1430. --- source/opt/copy_prop_arrays.cpp | 153 ++++++++++++++++++------------ source/opt/copy_prop_arrays.h | 30 ++++-- source/opt/type_manager.cpp | 3 + test/opt/copy_prop_array_test.cpp | 12 ++- 4 files changed, 121 insertions(+), 77 deletions(-) diff --git a/source/opt/copy_prop_arrays.cpp b/source/opt/copy_prop_arrays.cpp index 7587e60e1..f883ae322 100644 --- a/source/opt/copy_prop_arrays.cpp +++ b/source/opt/copy_prop_arrays.cpp @@ -38,13 +38,20 @@ Pass::Status CopyPropagateArrays::Process(ir::IRContext* ctx) { continue; } + // Find the only store to the entire memory location, if it exists. + ir::Instruction* store_inst = FindStoreInstruction(&*var_inst); + + if (!store_inst) { + continue; + } + std::unique_ptr source_object = - FindSourceObjectIfPossible(&*var_inst); + FindSourceObjectIfPossible(&*var_inst, store_inst); if (source_object != nullptr) { if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId())) { modified = true; - PropagateObject(&*var_inst, source_object.get()); + PropagateObject(&*var_inst, source_object.get(), store_inst); } } } @@ -53,29 +60,12 @@ Pass::Status CopyPropagateArrays::Process(ir::IRContext* ctx) { } std::unique_ptr -CopyPropagateArrays::FindSourceObjectIfPossible(ir::Instruction* var_inst) { +CopyPropagateArrays::FindSourceObjectIfPossible(ir::Instruction* var_inst, + ir::Instruction* store_inst) { assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable."); - // Check that the variable is a composite object with single store that + // Check that the variable is a composite object where |store_inst| // dominates all of its loads. - - // Find the only store to the entire memory location, if it exists. - ir::Instruction* store_inst = nullptr; - get_def_use_mgr()->WhileEachUser( - var_inst, [&store_inst, var_inst](ir::Instruction* use) { - if (use->opcode() == SpvOpStore && - use->GetSingleWordInOperand(kStorePointerInOperand) == - var_inst->result_id()) { - if (store_inst == nullptr) { - store_inst = use; - } else { - store_inst = nullptr; - return false; - } - } - return true; - }); - if (!store_inst) { return nullptr; } @@ -105,15 +95,32 @@ CopyPropagateArrays::FindSourceObjectIfPossible(ir::Instruction* var_inst) { return source; } +ir::Instruction* CopyPropagateArrays::FindStoreInstruction( + const ir::Instruction* var_inst) const { + ir::Instruction* store_inst = nullptr; + get_def_use_mgr()->WhileEachUser( + var_inst, [&store_inst, var_inst](ir::Instruction* use) { + if (use->opcode() == SpvOpStore && + use->GetSingleWordInOperand(kStorePointerInOperand) == + var_inst->result_id()) { + if (store_inst == nullptr) { + store_inst = use; + } else { + store_inst = nullptr; + return false; + } + } + return true; + }); + return store_inst; +} + void CopyPropagateArrays::PropagateObject(ir::Instruction* var_inst, - MemoryObject* source) { + MemoryObject* source, + ir::Instruction* insertion_point) { assert(var_inst->opcode() == SpvOpVariable && "This function propagates variables."); - ir::Instruction* insertion_point = var_inst->NextNode(); - while (insertion_point->opcode() == SpvOpVariable) { - insertion_point = insertion_point->NextNode(); - } ir::Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source); context()->KillNamesAndDecorates(var_inst); @@ -123,32 +130,13 @@ void CopyPropagateArrays::PropagateObject(ir::Instruction* var_inst, ir::Instruction* CopyPropagateArrays::BuildNewAccessChain( ir::Instruction* insertion_point, CopyPropagateArrays::MemoryObject* source) const { - analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); - InstructionBuilder builder(context(), insertion_point, ir::IRContext::kAnalysisDefUse | ir::IRContext::kAnalysisInstrToBlockMapping); - analysis::Integer int_type(32, false); - const analysis::Type* uint32_type = - context()->get_type_mgr()->GetRegisteredType(&int_type); - - // Convert the access chain in the source to a series of ids that can be used - // by the |OpAccessChain| instruction. - std::vector index_ids; - for (uint32_t index : source->AccessChain()) { - const analysis::Constant* index_const = - const_mgr->GetConstant(uint32_type, {index}); - index_ids.push_back( - const_mgr->GetDefiningInstruction(index_const)->result_id()); - } - - // Get the type for the result of the OpAccessChain - uint32_t pointer_type_id = source->GetPointerTypeId(); - - // Build the access chain instruction. - return builder.AddAccessChain(pointer_type_id, - source->GetVariable()->result_id(), index_ids); + return builder.AddAccessChain(source->GetPointerTypeId(), + source->GetVariable()->result_id(), + source->AccessChain()); } bool CopyPropagateArrays::HasNoStores(ir::Instruction* ptr_inst) { @@ -219,7 +207,6 @@ std::unique_ptr CopyPropagateArrays::BuildMemoryObjectFromLoad(ir::Instruction* load_inst) { std::vector components_in_reverse; analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); - analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); ir::Instruction* current_inst = def_use_mgr->GetDef( load_inst->GetSingleWordInOperand(kLoadPointerInOperand)); @@ -234,14 +221,7 @@ CopyPropagateArrays::BuildMemoryObjectFromLoad(ir::Instruction* load_inst) { while (current_inst->opcode() == SpvOpAccessChain) { for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) { uint32_t element_index_id = current_inst->GetSingleWordInOperand(i); - const analysis::Constant* element_index_const = - const_mgr->FindDeclaredConstant(element_index_id); - if (!element_index_const) { - return nullptr; - } - assert(element_index_const->AsIntConstant()); - components_in_reverse.push_back( - element_index_const->AsIntConstant()->GetU32()); + components_in_reverse.push_back(element_index_id); } current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0)); } @@ -266,14 +246,25 @@ CopyPropagateArrays::BuildMemoryObjectFromExtract( ir::Instruction* extract_inst) { assert(extract_inst->opcode() == SpvOpCompositeExtract && "Expecting an OpCompositeExtract instruction."); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); std::unique_ptr result = GetSourceObjectIfAny( extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand)); if (result) { + analysis::Integer int_type(32, false); + const analysis::Type* uint32_type = + context()->get_type_mgr()->GetRegisteredType(&int_type); + std::vector components; + // Convert the indices in the extract instruction to a series of ids that + // can be used by the |OpAccessChain| instruction. for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) { - components.emplace_back(extract_inst->GetSingleWordInOperand(i)); + uint32_t index = extract_inst->GetSingleWordInOperand(1); + const analysis::Constant* index_const = + const_mgr->GetConstant(uint32_type, {index}); + components.push_back( + const_mgr->GetDefiningInstruction(index_const)->result_id()); } result->GetMember(components); return result; @@ -302,7 +293,15 @@ CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct( return nullptr; } - if (memory_object->AccessChain().back() != 0) { + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + const analysis::Constant* last_access = + const_mgr->FindDeclaredConstant(memory_object->AccessChain().back()); + if (!last_access || + (!last_access->AsIntConstant() && !last_access->AsNullConstant())) { + return nullptr; + } + + if (last_access->GetU32() != 0) { return nullptr; } @@ -321,7 +320,17 @@ CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct( return nullptr; } - if (member_object->AccessChain().back() != i) { + if (!member_object->IsMember()) { + return nullptr; + } + + last_access = + const_mgr->FindDeclaredConstant(member_object->AccessChain().back()); + if (!last_access || !last_access->AsIntConstant()) { + return nullptr; + } + + if (last_access->GetU32() != i) { return nullptr; } } @@ -639,7 +648,9 @@ uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() { const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id()); type = type->AsPointer()->pointee_type(); - type = type_mgr->GetMemberType(type, access_chain_); + + std::vector access_indices = GetAccessIds(); + type = type_mgr->GetMemberType(type, access_indices); if (const analysis::Struct* struct_type = type->AsStruct()) { return static_cast(struct_type->element_types().size()); @@ -663,5 +674,23 @@ CopyPropagateArrays::MemoryObject::MemoryObject(ir::Instruction* var_inst, iterator begin, iterator end) : variable_inst_(var_inst), access_chain_(begin, end) {} +std::vector CopyPropagateArrays::MemoryObject::GetAccessIds() const { + analysis::ConstantManager* const_mgr = + variable_inst_->context()->get_constant_mgr(); + + std::vector access_indices; + for (uint32_t id : AccessChain()) { + const analysis::Constant* element_index_const = + const_mgr->FindDeclaredConstant(id); + if (!element_index_const) { + access_indices.push_back(0); + } else { + assert(element_index_const->AsIntConstant()); + access_indices.push_back(element_index_const->AsIntConstant()->GetU32()); + } + } + return access_indices; +} + } // namespace opt } // namespace spvtools diff --git a/source/opt/copy_prop_arrays.h b/source/opt/copy_prop_arrays.h index ac77f33b0..f55a6158e 100644 --- a/source/opt/copy_prop_arrays.h +++ b/source/opt/copy_prop_arrays.h @@ -59,14 +59,14 @@ class CopyPropagateArrays : public MemPass { // Construction a memory object that is owned by |var_inst|. The iterator // |begin| and |end| traverse a container of integers that identify which // member of |var_inst| this memory object will represent. These integers - // are interpreted the same way they would be in an |OpCompositeExtract| + // are interpreted the same way they would be in an |OpAccessChain| // instruction. template MemoryObject(ir::Instruction* var_inst, iterator begin, iterator end); // Change |this| to now point to the member identified by |access_chain| // (starting from the current member). The elements in |access_chain| are - // interpreted the same as the indicies in the |OpCompositeExtract| + // interpreted the same as the indices in the |OpAccessChain| // instruction. void GetMember(const std::vector& access_chain); @@ -91,8 +91,8 @@ class CopyPropagateArrays : public MemPass { // Returns a vector of integers that can be used to access the specific // member that |this| represents starting from the owning variable. These - // values are to be interpreted the same way the indicies are in an - // |OpCompositeExtract| instruction. + // values are to be interpreted the same way the indices are in an + // |OpAccessChain| instruction. const std::vector& AccessChain() const { return access_chain_; } // Returns the type id of the pointer type that can be used to point to this @@ -104,7 +104,7 @@ class CopyPropagateArrays : public MemPass { type_mgr->GetType(GetVariable()->type_id())->AsPointer(); const analysis::Type* var_type = pointer_type->pointee_type(); const analysis::Type* member_type = - type_mgr->GetMemberType(var_type, AccessChain()); + type_mgr->GetMemberType(var_type, GetAccessIds()); uint32_t member_type_id = type_mgr->GetId(member_type); assert(member_type != 0); uint32_t member_pointer_type_id = type_mgr->FindPointerToType( @@ -127,19 +127,24 @@ class CopyPropagateArrays : public MemPass { // The access chain to reach the particular member the memory object // represents. It should be interpreted the same way the indices in an - // |OpCompositeExtract| are interpreted. + // |OpAccessChain| are interpreted. std::vector access_chain_; + std::vector GetAccessIds() const; }; - // Returns a memory object, if one exists, that can be used in place of + // Returns the memory object being stored to |var_inst| in the store + // instruction |store_inst|, if one exists, that can be used in place of // |var_inst| in all of the loads of |var_inst|. This code is conservative // and only identifies very simple cases. If no such memory object can be // found, the return value is |nullptr|. - std::unique_ptr FindSourceObjectIfPossible( - ir::Instruction* var_inst); + std::unique_ptr FindSourceObjectIfPossible( + ir::Instruction* var_inst, ir::Instruction* store_inst); // Replaces all loads of |var_inst| with a load from |source| instead. - void PropagateObject(ir::Instruction* var_inst, MemoryObject* source); + // |insertion_pos| is a position where it is possible to construct the + // address of |source| and also dominates all of the loads of |var_inst|. + void PropagateObject(ir::Instruction* var_inst, MemoryObject* source, + ir::Instruction* insertion_pos); // Returns true if all of the references to |ptr_inst| can be rewritten and // are dominated by |store_inst|. @@ -200,6 +205,11 @@ class CopyPropagateArrays : public MemPass { // inserted before |insertion_position|. uint32_t GenerateCopy(ir::Instruction* object_to_copy, uint32_t new_type_id, ir::Instruction* insertion_position); + + // Returns a store to |var_inst| that writes to the entire variable, and is + // the only store that does so. Note it does not look through OpAccessChain + // instruction, so partial stores are not considered. + ir::Instruction* FindStoreInstruction(const ir::Instruction* var_inst) const; }; } // namespace opt diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp index 34e157d98..9e2cd8652 100644 --- a/source/opt/type_manager.cpp +++ b/source/opt/type_manager.cpp @@ -676,6 +676,9 @@ const Type* TypeManager::GetMemberType( parent_type = struct_type->element_types()[element_index]; } else if (const analysis::Array* array_type = parent_type->AsArray()) { parent_type = array_type->element_type(); + } else if (const analysis::RuntimeArray* runtime_array_type = + parent_type->AsRuntimeArray()) { + parent_type = runtime_array_type->element_type(); } else if (const analysis::Vector* vector_type = parent_type->AsVector()) { parent_type = vector_type->element_type(); } else if (const analysis::Matrix* matrix_type = parent_type->AsMatrix()) { diff --git a/test/opt/copy_prop_array_test.cpp b/test/opt/copy_prop_array_test.cpp index b2a685c59..44ae1c299 100644 --- a/test/opt/copy_prop_array_test.cpp +++ b/test/opt/copy_prop_array_test.cpp @@ -74,7 +74,7 @@ OpDecorate %MyCBuffer Binding 0 ; CHECK: OpFunction ; CHECK: OpLabel ; CHECK: OpVariable -; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %uint_0 +; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 %main = OpFunction %void None %13 %22 = OpLabel %23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function @@ -156,7 +156,8 @@ OpDecorate %MyCBuffer Binding 0 ; CHECK: OpLabel ; CHECK: OpVariable ; CHECK: OpVariable -; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 %MyCBuffer %uint_0 +; CHECK: OpAccessChain +; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 %MyCBuffer %int_0 %main = OpFunction %void None %14 %25 = OpLabel %26 = OpVariable %_ptr_Function__arr_v4float_uint_2_0 Function @@ -174,13 +175,14 @@ OpDecorate %MyCBuffer Binding 0 %38 = OpCompositeConstruct %_arr_v4float_uint_2_0 %36 %37 %39 = OpCompositeConstruct %_arr__arr_v4float_uint_2_0_uint_2 %34 %38 ; CHECK: OpStore -; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_2 [[new_address]] %28 -; CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[ac1]] %28 -; CHECK: OpLoad %v4float [[ac2]] OpStore %27 %39 %40 = OpAccessChain %_ptr_Function__arr_v4float_uint_2_0 %27 %28 %42 = OpAccessChain %_ptr_Function_v4float %40 %28 %43 = OpLoad %v4float %42 +; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_2 [[new_address]] %28 +; CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[ac1]] %28 +; CHECK: [[load:%\w+]] = OpLoad %v4float [[ac2]] +; CHECK: OpStore %out_var_SV_Target [[load]] OpStore %out_var_SV_Target %43 OpReturn OpFunctionEnd