From d997c83b103ed1f3af09ed65e1cbf89fbc6d9451 Mon Sep 17 00:00:00 2001 From: Jaebaek Seo Date: Tue, 26 Oct 2021 17:20:58 -0400 Subject: [PATCH] Add spirv-opt pass to replace descriptor accesses based on variable indices (#4574) This commit adds a spirv-opt pass to replace accesses to descriptor array based on variable indices with constant elements. Before: ``` %descriptor = OpVariable %_ptr_array_Image Uniform ... %ac = OpAccessChain %_ptr_Image %descriptor %variable_index (some image instructions using %ac) ``` After: ``` %descriptor = OpVariable %_ptr_array_Image Uniform ... OpSwitch %variable_index 0 %case0 1 %case1 ... ... %case0 = OpLabel %ac = OpAccessChain %_ptr_Image %descriptor %uint_0 ... %case1 = OpLabel %ac = OpAccessChain %_ptr_Image %descriptor %uint_1 ... (use OpPhi for value with concrete type) ``` --- Android.mk | 2 + BUILD.gn | 4 + include/spirv-tools/optimizer.hpp | 7 + source/opt/CMakeLists.txt | 4 + source/opt/decoration_manager.cpp | 8 + source/opt/decoration_manager.h | 4 + source/opt/desc_sroa.cpp | 119 +---- source/opt/desc_sroa.h | 9 - source/opt/desc_sroa_util.cpp | 117 +++++ source/opt/desc_sroa_util.h | 54 +++ source/opt/optimizer.cpp | 7 + source/opt/passes.h | 1 + ...lace_desc_array_access_using_var_index.cpp | 423 ++++++++++++++++++ ...eplace_desc_array_access_using_var_index.h | 204 +++++++++ test/opt/CMakeLists.txt | 1 + ...desc_array_access_using_var_index_test.cpp | 411 +++++++++++++++++ tools/opt/opt.cpp | 5 + 17 files changed, 1265 insertions(+), 115 deletions(-) create mode 100644 source/opt/desc_sroa_util.cpp create mode 100644 source/opt/desc_sroa_util.h create mode 100644 source/opt/replace_desc_array_access_using_var_index.cpp create mode 100644 source/opt/replace_desc_array_access_using_var_index.h create mode 100644 test/opt/replace_desc_array_access_using_var_index_test.cpp diff --git a/Android.mk b/Android.mk index b616654e3..bc748e535 100644 --- a/Android.mk +++ b/Android.mk @@ -100,6 +100,7 @@ SPVTOOLS_OPT_SRC_FILES 
:= \ source/opt/debug_info_manager.cpp \ source/opt/def_use_manager.cpp \ source/opt/desc_sroa.cpp \ + source/opt/desc_sroa_util.cpp \ source/opt/dominator_analysis.cpp \ source/opt/dominator_tree.cpp \ source/opt/eliminate_dead_constant_pass.cpp \ @@ -157,6 +158,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/relax_float_ops_pass.cpp \ source/opt/remove_duplicates_pass.cpp \ source/opt/remove_unused_interface_variables_pass.cpp \ + source/opt/replace_desc_array_access_using_var_index.cpp \ source/opt/replace_invalid_opc.cpp \ source/opt/scalar_analysis.cpp \ source/opt/scalar_analysis_simplification.cpp \ diff --git a/BUILD.gn b/BUILD.gn index 4bf3f79e0..309d51373 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -597,6 +597,8 @@ static_library("spvtools_opt") { "source/opt/def_use_manager.h", "source/opt/desc_sroa.cpp", "source/opt/desc_sroa.h", + "source/opt/desc_sroa_util.cpp", + "source/opt/desc_sroa_util.h", "source/opt/dominator_analysis.cpp", "source/opt/dominator_analysis.h", "source/opt/dominator_tree.cpp", @@ -716,6 +718,8 @@ static_library("spvtools_opt") { "source/opt/remove_duplicates_pass.h", "source/opt/remove_unused_interface_variables_pass.cpp", "source/opt/remove_unused_interface_variables_pass.h", + "source/opt/replace_desc_array_access_using_var_index.cpp", + "source/opt/replace_desc_array_access_using_var_index.h", "source/opt/replace_invalid_opc.cpp", "source/opt/replace_invalid_opc.h", "source/opt/scalar_analysis.cpp", diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index 42eb6442e..21059cbe3 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -833,6 +833,13 @@ Optimizer::PassToken CreateFixStorageClassPass(); // inclusive. Optimizer::PassToken CreateGraphicsRobustAccessPass(); +// Create a pass to replace a descriptor access using variable index. 
+// This pass replaces every access using a variable index to array variable +// |desc| that has a DescriptorSet and Binding decorations with a constant +// element of the array. In order to replace the access using a variable index +// with the constant element, it uses a switch statement. +Optimizer::PassToken CreateReplaceDescArrayAccessUsingVarIndexPass(); + // Create descriptor scalar replacement pass. // This pass replaces every array variable |desc| that has a DescriptorSet and // Binding decorations with a new variable for each element of the array. diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 63af5c1da..7d522fb5e 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -39,6 +39,7 @@ set(SPIRV_TOOLS_OPT_SOURCES debug_info_manager.h def_use_manager.h desc_sroa.h + desc_sroa_util.h dominator_analysis.h dominator_tree.h eliminate_dead_constant_pass.h @@ -100,6 +101,7 @@ set(SPIRV_TOOLS_OPT_SOURCES relax_float_ops_pass.h remove_duplicates_pass.h remove_unused_interface_variables_pass.h + replace_desc_array_access_using_var_index.h replace_invalid_opc.h scalar_analysis.h scalar_analysis_nodes.h @@ -148,6 +150,7 @@ set(SPIRV_TOOLS_OPT_SOURCES debug_info_manager.cpp def_use_manager.cpp desc_sroa.cpp + desc_sroa_util.cpp dominator_analysis.cpp dominator_tree.cpp eliminate_dead_constant_pass.cpp @@ -205,6 +208,7 @@ set(SPIRV_TOOLS_OPT_SOURCES relax_float_ops_pass.cpp remove_duplicates_pass.cpp remove_unused_interface_variables_pass.cpp + replace_desc_array_access_using_var_index.cpp replace_invalid_opc.cpp scalar_analysis.cpp scalar_analysis_simplification.cpp diff --git a/source/opt/decoration_manager.cpp b/source/opt/decoration_manager.cpp index 4bf026efd..2146c359d 100644 --- a/source/opt/decoration_manager.cpp +++ b/source/opt/decoration_manager.cpp @@ -490,6 +490,14 @@ void DecorationManager::ForEachDecoration( }); } +bool DecorationManager::HasDecoration(uint32_t id, uint32_t decoration) { + bool has_decoration = 
false; + ForEachDecoration(id, decoration, [&has_decoration](const Instruction&) { + has_decoration = true; + }); + return has_decoration; +} + bool DecorationManager::FindDecoration( uint32_t id, uint32_t decoration, std::function f) { diff --git a/source/opt/decoration_manager.h b/source/opt/decoration_manager.h index b753e6be1..fe78f2ce6 100644 --- a/source/opt/decoration_manager.h +++ b/source/opt/decoration_manager.h @@ -90,6 +90,10 @@ class DecorationManager { bool AreDecorationsTheSame(const Instruction* inst1, const Instruction* inst2, bool ignore_target) const; + // Returns whether a decoration instruction for |id| with decoration + // |decoration| exists or not. + bool HasDecoration(uint32_t id, uint32_t decoration); + // |f| is run on each decoration instruction for |id| with decoration // |decoration|. Processed are all decorations which target |id| either // directly or indirectly by Decoration Groups. diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp index 5e950069d..0b8393705 100644 --- a/source/opt/desc_sroa.cpp +++ b/source/opt/desc_sroa.cpp @@ -14,6 +14,7 @@ #include "source/opt/desc_sroa.h" +#include "source/opt/desc_sroa_util.h" #include "source/util/string_utils.h" namespace spvtools { @@ -25,7 +26,7 @@ Pass::Status DescriptorScalarReplacement::Process() { std::vector vars_to_kill; for (Instruction& var : context()->types_values()) { - if (IsCandidate(&var)) { + if (descsroautil::IsDescriptorArray(context(), &var)) { modified = true; if (!ReplaceCandidate(&var)) { return Status::Failure; @@ -41,72 +42,6 @@ Pass::Status DescriptorScalarReplacement::Process() { return (modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange); } -bool DescriptorScalarReplacement::IsCandidate(Instruction* var) { - if (var->opcode() != SpvOpVariable) { - return false; - } - - uint32_t ptr_type_id = var->type_id(); - Instruction* ptr_type_inst = - context()->get_def_use_mgr()->GetDef(ptr_type_id); - if (ptr_type_inst->opcode() != SpvOpTypePointer) { - return false; - } - - uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1); - Instruction* var_type_inst = - context()->get_def_use_mgr()->GetDef(var_type_id); - if (var_type_inst->opcode() != SpvOpTypeArray && - var_type_inst->opcode() != SpvOpTypeStruct) { - return false; - } - - // All structures with descriptor assignments must be replaced by variables, - // one for each of their members - with the exceptions of buffers. - if (IsTypeOfStructuredBuffer(var_type_inst)) { - return false; - } - - bool has_desc_set_decoration = false; - context()->get_decoration_mgr()->ForEachDecoration( - var->result_id(), SpvDecorationDescriptorSet, - [&has_desc_set_decoration](const Instruction&) { - has_desc_set_decoration = true; - }); - if (!has_desc_set_decoration) { - return false; - } - - bool has_binding_decoration = false; - context()->get_decoration_mgr()->ForEachDecoration( - var->result_id(), SpvDecorationBinding, - [&has_binding_decoration](const Instruction&) { - has_binding_decoration = true; - }); - if (!has_binding_decoration) { - return false; - } - - return true; -} - -bool DescriptorScalarReplacement::IsTypeOfStructuredBuffer( - const Instruction* type) const { - if (type->opcode() != SpvOpTypeStruct) { - return false; - } - - // All buffers have offset decorations for members of their structure types. - // This is how we distinguish it from a structure of descriptors. 
- bool has_offset_decoration = false; - context()->get_decoration_mgr()->ForEachDecoration( - type->result_id(), SpvDecorationOffset, - [&has_offset_decoration](const Instruction&) { - has_offset_decoration = true; - }); - return has_offset_decoration; -} - bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) { std::vector access_chain_work_list; std::vector load_work_list; @@ -162,16 +97,15 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var, return false; } - uint32_t idx_id = use->GetSingleWordInOperand(1); - const analysis::Constant* idx_const = - context()->get_constant_mgr()->FindDeclaredConstant(idx_id); - if (idx_const == nullptr) { + const analysis::Constant* const_index = + descsroautil::GetAccessChainIndexAsConst(context(), use); + if (const_index == nullptr) { context()->EmitErrorMessage("Variable cannot be replaced: invalid index", use); return false; } - uint32_t idx = idx_const->GetU32(); + uint32_t idx = const_index->GetU32(); uint32_t replacement_var = GetReplacementVariable(var, idx); if (use->NumInOperands() == 2) { @@ -208,39 +142,12 @@ uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var, uint32_t idx) { auto replacement_vars = replacement_variables_.find(var); if (replacement_vars == replacement_variables_.end()) { - uint32_t ptr_type_id = var->type_id(); - Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id); - assert(ptr_type_inst->opcode() == SpvOpTypePointer && - "Variable should be a pointer to an array or structure."); - uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1); - Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(pointee_type_id); - const bool is_array = pointee_type_inst->opcode() == SpvOpTypeArray; - const bool is_struct = pointee_type_inst->opcode() == SpvOpTypeStruct; - assert((is_array || is_struct) && - "Variable should be a pointer to an array or structure."); - - // For arrays, each array element should be replaced 
with a new replacement - // variable - if (is_array) { - uint32_t array_len_id = pointee_type_inst->GetSingleWordInOperand(1); - const analysis::Constant* array_len_const = - context()->get_constant_mgr()->FindDeclaredConstant(array_len_id); - assert(array_len_const != nullptr && "Array length must be a constant."); - uint32_t array_len = array_len_const->GetU32(); - - replacement_vars = replacement_variables_ - .insert({var, std::vector(array_len, 0)}) - .first; - } - // For structures, each member should be replaced with a new replacement - // variable - if (is_struct) { - const uint32_t num_members = pointee_type_inst->NumInOperands(); - replacement_vars = - replacement_variables_ - .insert({var, std::vector(num_members, 0)}) - .first; - } + uint32_t number_of_elements = + descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var); + replacement_vars = + replacement_variables_ + .insert({var, std::vector(number_of_elements, 0)}) + .first; } if (replacement_vars->second[idx] == 0) { @@ -377,7 +284,7 @@ uint32_t DescriptorScalarReplacement::GetNumBindingsUsedByType( // The number of bindings consumed by a structure is the sum of the bindings // used by its members. if (type_inst->opcode() == SpvOpTypeStruct && - !IsTypeOfStructuredBuffer(type_inst)) { + !descsroautil::IsTypeOfStructuredBuffer(context(), type_inst)) { uint32_t sum = 0; for (uint32_t i = 0; i < type_inst->NumInOperands(); i++) sum += GetNumBindingsUsedByType(type_inst->GetSingleWordInOperand(i)); diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h index cd72fd301..70fd381af 100644 --- a/source/opt/desc_sroa.h +++ b/source/opt/desc_sroa.h @@ -46,10 +46,6 @@ class DescriptorScalarReplacement : public Pass { } private: - // Returns true if |var| is an OpVariable instruction that represents a - // descriptor array. These are the variables that we want to replace. 
- bool IsCandidate(Instruction* var); - // Replaces all references to |var| by new variables, one for each element of // the array |var|. The binding for the new variables corresponding to // element i will be the binding of |var| plus i. Returns true if successful. @@ -93,11 +89,6 @@ class DescriptorScalarReplacement : public Pass { // bindings used by its members. uint32_t GetNumBindingsUsedByType(uint32_t type_id); - // Returns true if |type| is a type that could be used for a structured buffer - // as opposed to a type that would be used for a structure of resource - // descriptors. - bool IsTypeOfStructuredBuffer(const Instruction* type) const; - // A map from an OpVariable instruction to the set of variables that will be // used to replace it. The entry |replacement_variables_[var][i]| is the id of // a variable that will be used in the place of the the ith element of the diff --git a/source/opt/desc_sroa_util.cpp b/source/opt/desc_sroa_util.cpp new file mode 100644 index 000000000..1954e2cc3 --- /dev/null +++ b/source/opt/desc_sroa_util.cpp @@ -0,0 +1,117 @@ +// Copyright (c) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/desc_sroa_util.h" + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kOpAccessChainInOperandIndexes = 1; + +// Returns the length of array type |type|. 
+uint32_t GetLengthOfArrayType(IRContext* context, Instruction* type) { + assert(type->opcode() == SpvOpTypeArray && "type must be array"); + uint32_t length_id = type->GetSingleWordInOperand(1); + const analysis::Constant* length_const = + context->get_constant_mgr()->FindDeclaredConstant(length_id); + assert(length_const != nullptr); + return length_const->GetU32(); +} + +} // namespace + +namespace descsroautil { + +bool IsDescriptorArray(IRContext* context, Instruction* var) { + if (var->opcode() != SpvOpVariable) { + return false; + } + + uint32_t ptr_type_id = var->type_id(); + Instruction* ptr_type_inst = context->get_def_use_mgr()->GetDef(ptr_type_id); + if (ptr_type_inst->opcode() != SpvOpTypePointer) { + return false; + } + + uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1); + Instruction* var_type_inst = context->get_def_use_mgr()->GetDef(var_type_id); + if (var_type_inst->opcode() != SpvOpTypeArray && + var_type_inst->opcode() != SpvOpTypeStruct) { + return false; + } + + // All structures with descriptor assignments must be replaced by variables, + // one for each of their members - with the exceptions of buffers. + if (IsTypeOfStructuredBuffer(context, var_type_inst)) { + return false; + } + + if (!context->get_decoration_mgr()->HasDecoration( + var->result_id(), SpvDecorationDescriptorSet)) { + return false; + } + + return context->get_decoration_mgr()->HasDecoration(var->result_id(), + SpvDecorationBinding); +} + +bool IsTypeOfStructuredBuffer(IRContext* context, const Instruction* type) { + if (type->opcode() != SpvOpTypeStruct) { + return false; + } + + // All buffers have offset decorations for members of their structure types. + // This is how we distinguish it from a structure of descriptors. 
+ return context->get_decoration_mgr()->HasDecoration(type->result_id(), + SpvDecorationOffset); +} + +const analysis::Constant* GetAccessChainIndexAsConst( + IRContext* context, Instruction* access_chain) { + if (access_chain->NumInOperands() <= 1) { + return nullptr; + } + uint32_t idx_id = GetFirstIndexOfAccessChain(access_chain); + const analysis::Constant* idx_const = + context->get_constant_mgr()->FindDeclaredConstant(idx_id); + return idx_const; +} + +uint32_t GetFirstIndexOfAccessChain(Instruction* access_chain) { + assert(access_chain->NumInOperands() > 1 && + "OpAccessChain does not have Indexes operand"); + return access_chain->GetSingleWordInOperand(kOpAccessChainInOperandIndexes); +} + +uint32_t GetNumberOfElementsForArrayOrStruct(IRContext* context, + Instruction* var) { + uint32_t ptr_type_id = var->type_id(); + Instruction* ptr_type_inst = context->get_def_use_mgr()->GetDef(ptr_type_id); + assert(ptr_type_inst->opcode() == SpvOpTypePointer && + "Variable should be a pointer to an array or structure."); + uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1); + Instruction* pointee_type_inst = + context->get_def_use_mgr()->GetDef(pointee_type_id); + if (pointee_type_inst->opcode() == SpvOpTypeArray) { + return GetLengthOfArrayType(context, pointee_type_inst); + } + assert(pointee_type_inst->opcode() == SpvOpTypeStruct && + "Variable should be a pointer to an array or structure."); + return pointee_type_inst->NumInOperands(); +} + +} // namespace descsroautil +} // namespace opt +} // namespace spvtools diff --git a/source/opt/desc_sroa_util.h b/source/opt/desc_sroa_util.h new file mode 100644 index 000000000..2f45c0c2f --- /dev/null +++ b/source/opt/desc_sroa_util.h @@ -0,0 +1,54 @@ +// Copyright (c) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_DESC_SROA_UTIL_H_ +#define SOURCE_OPT_DESC_SROA_UTIL_H_ + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +// Provides functions for the descriptor array SROA. +namespace descsroautil { + +// Returns true if |var| is an OpVariable instruction that represents a +// descriptor array. +bool IsDescriptorArray(IRContext* context, Instruction* var); + +// Returns true if |type| is a type that could be used for a structured buffer +// as opposed to a type that would be used for a structure of resource +// descriptors. +bool IsTypeOfStructuredBuffer(IRContext* context, const Instruction* type); + +// Returns the first index of the OpAccessChain instruction |access_chain| as +// a constant. Returns nullptr if it is not a constant. +const analysis::Constant* GetAccessChainIndexAsConst(IRContext* context, + Instruction* access_chain); + +// Returns the number of elements of an OpVariable instruction |var| whose type +// must be a pointer to an array or a struct. +uint32_t GetNumberOfElementsForArrayOrStruct(IRContext* context, + Instruction* var); + +// Returns the first Indexes operand id of the OpAccessChain or +// OpInBoundsAccessChain instruction |access_chain|. The access chain must have +// at least 1 index. 
+uint32_t GetFirstIndexOfAccessChain(Instruction* access_chain); + +} // namespace descsroautil +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_DESC_SROA_UTIL_H_ diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 990be2e43..e74db26f0 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -320,6 +320,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) { RegisterPass(CreateCombineAccessChainsPass()); } else if (pass_name == "convert-local-access-chains") { RegisterPass(CreateLocalAccessChainConvertPass()); + } else if (pass_name == "replace-desc-array-access-using-var-index") { + RegisterPass(CreateReplaceDescArrayAccessUsingVarIndexPass()); } else if (pass_name == "descriptor-scalar-replacement") { RegisterPass(CreateDescriptorScalarReplacementPass()); } else if (pass_name == "eliminate-dead-code-aggressive") { @@ -958,6 +960,11 @@ Optimizer::PassToken CreateGraphicsRobustAccessPass() { MakeUnique()); } +Optimizer::PassToken CreateReplaceDescArrayAccessUsingVarIndexPass() { + return MakeUnique( + MakeUnique()); +} + Optimizer::PassToken CreateDescriptorScalarReplacementPass() { return MakeUnique( MakeUnique()); diff --git a/source/opt/passes.h b/source/opt/passes.h index da837f2ab..f3c30d578 100644 --- a/source/opt/passes.h +++ b/source/opt/passes.h @@ -66,6 +66,7 @@ #include "source/opt/relax_float_ops_pass.h" #include "source/opt/remove_duplicates_pass.h" #include "source/opt/remove_unused_interface_variables_pass.h" +#include "source/opt/replace_desc_array_access_using_var_index.h" #include "source/opt/replace_invalid_opc.h" #include "source/opt/scalar_replacement_pass.h" #include "source/opt/set_spec_constant_default_value_pass.h" diff --git a/source/opt/replace_desc_array_access_using_var_index.cpp b/source/opt/replace_desc_array_access_using_var_index.cpp new file mode 100644 index 000000000..1082e679b --- /dev/null +++ b/source/opt/replace_desc_array_access_using_var_index.cpp @@ -0,0 
+1,423 @@
+// Copyright (c) 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/replace_desc_array_access_using_var_index.h"
+
+#include "source/opt/desc_sroa_util.h"
+#include "source/opt/ir_builder.h"
+#include "source/util/string_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+const uint32_t kOpAccessChainInOperandIndexes = 1;
+const uint32_t kOpTypePointerInOperandType = 1;
+const uint32_t kOpTypeArrayInOperandType = 0;
+const uint32_t kOpTypeStructInOperandMember = 0;
+const IRContext::Analysis kAnalysisDefUseAndInstrToBlockMapping =
+    IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping;
+
+uint32_t GetValueWithKeyExistenceCheck(
+    uint32_t key, const std::unordered_map<uint32_t, uint32_t>& map) {
+  auto itr = map.find(key);
+  assert(itr != map.end() && "Key does not exist");
+  return itr->second;
+}
+
+}  // namespace
+
+Pass::Status ReplaceDescArrayAccessUsingVarIndex::Process() {
+  Status status = Status::SuccessWithoutChange;
+  for (Instruction& var : context()->types_values()) {
+    if (descsroautil::IsDescriptorArray(context(), &var)) {
+      if (ReplaceVariableAccessesWithConstantElements(&var))
+        status = Status::SuccessWithChange;
+    }
+  }
+  return status;
+}
+
+bool ReplaceDescArrayAccessUsingVarIndex::
+    ReplaceVariableAccessesWithConstantElements(Instruction* var) const {
+  std::vector<Instruction*> work_list;
+  get_def_use_mgr()->ForEachUser(var, [&work_list](Instruction* use) {
+    switch (use->opcode()) {
+      case 
SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + work_list.push_back(use); + break; + default: + break; + } + }); + + bool updated = false; + for (Instruction* access_chain : work_list) { + if (descsroautil::GetAccessChainIndexAsConst(context(), access_chain) == + nullptr) { + ReplaceAccessChain(var, access_chain); + updated = true; + } + } + // Note that we do not consider OpLoad and OpCompositeExtract because + // OpCompositeExtract always has constant literals for indices. + return updated; +} + +void ReplaceDescArrayAccessUsingVarIndex::ReplaceAccessChain( + Instruction* var, Instruction* access_chain) const { + uint32_t number_of_elements = + descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var); + assert(number_of_elements != 0 && "Number of element is 0"); + if (number_of_elements == 1) { + UseConstIndexForAccessChain(access_chain, 0); + get_def_use_mgr()->AnalyzeInstUse(access_chain); + return; + } + ReplaceUsersOfAccessChain(access_chain, number_of_elements); +} + +void ReplaceDescArrayAccessUsingVarIndex::ReplaceUsersOfAccessChain( + Instruction* access_chain, uint32_t number_of_elements) const { + std::vector final_users; + CollectRecursiveUsersWithConcreteType(access_chain, &final_users); + for (auto* inst : final_users) { + std::deque insts_to_be_cloned = + CollectRequiredImageInsts(inst); + ReplaceNonUniformAccessWithSwitchCase( + inst, access_chain, number_of_elements, insts_to_be_cloned); + } +} + +void ReplaceDescArrayAccessUsingVarIndex::CollectRecursiveUsersWithConcreteType( + Instruction* access_chain, std::vector* final_users) const { + std::queue work_list; + work_list.push(access_chain); + while (!work_list.empty()) { + auto* inst_from_work_list = work_list.front(); + work_list.pop(); + get_def_use_mgr()->ForEachUser( + inst_from_work_list, [this, final_users, &work_list](Instruction* use) { + // TODO: Support Boolean type as well. 
+          if (!use->HasResultId() || IsConcreteType(use->type_id())) {
+            final_users->push_back(use);
+          } else {
+            work_list.push(use);
+          }
+        });
+  }
+}
+
+std::deque<Instruction*>
+ReplaceDescArrayAccessUsingVarIndex::CollectRequiredImageInsts(
+    Instruction* user_of_image_insts) const {
+  std::unordered_set<uint32_t> seen_inst_ids;
+  std::queue<Instruction*> work_list;
+
+  auto decision_to_include_operand = [this, &seen_inst_ids,
+                                      &work_list](uint32_t* idp) {
+    if (!seen_inst_ids.insert(*idp).second) return;
+    Instruction* operand = get_def_use_mgr()->GetDef(*idp);
+    if (context()->get_instr_block(operand) != nullptr &&
+        HasImageOrImagePtrType(operand)) {
+      work_list.push(operand);
+    }
+  };
+
+  std::deque<Instruction*> required_image_insts;
+  required_image_insts.push_front(user_of_image_insts);
+  user_of_image_insts->ForEachInId(decision_to_include_operand);
+  while (!work_list.empty()) {
+    auto* inst_from_work_list = work_list.front();
+    work_list.pop();
+    required_image_insts.push_front(inst_from_work_list);
+    inst_from_work_list->ForEachInId(decision_to_include_operand);
+  }
+  return required_image_insts;
+}
+
+bool ReplaceDescArrayAccessUsingVarIndex::HasImageOrImagePtrType(
+    const Instruction* inst) const {
+  assert(inst != nullptr && inst->type_id() != 0 && "Invalid instruction");
+  return IsImageOrImagePtrType(get_def_use_mgr()->GetDef(inst->type_id()));
+}
+
+bool ReplaceDescArrayAccessUsingVarIndex::IsImageOrImagePtrType(
+    const Instruction* type_inst) const {
+  if (type_inst->opcode() == SpvOpTypeImage ||
+      type_inst->opcode() == SpvOpTypeSampler ||
+      type_inst->opcode() == SpvOpTypeSampledImage) {
+    return true;
+  }
+  if (type_inst->opcode() == SpvOpTypePointer) {
+    Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(
+        type_inst->GetSingleWordInOperand(kOpTypePointerInOperandType));
+    return IsImageOrImagePtrType(pointee_type_inst);
+  }
+  if (type_inst->opcode() == SpvOpTypeArray) {
+    Instruction* element_type_inst = get_def_use_mgr()->GetDef(
+        type_inst->GetSingleWordInOperand(kOpTypeArrayInOperandType));
+    return IsImageOrImagePtrType(element_type_inst);
+  }
+  if (type_inst->opcode() != SpvOpTypeStruct) return false;
+  for (uint32_t in_operand_idx = kOpTypeStructInOperandMember;
+       in_operand_idx < type_inst->NumInOperands(); ++in_operand_idx) {
+    Instruction* member_type_inst = get_def_use_mgr()->GetDef(
+        type_inst->GetSingleWordInOperand(in_operand_idx));  // each member
+    if (IsImageOrImagePtrType(member_type_inst)) return true;
+  }
+  return false;
+}
+
+bool ReplaceDescArrayAccessUsingVarIndex::IsConcreteType(
+    uint32_t type_id) const {
+  Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
+  if (type_inst->opcode() == SpvOpTypeInt ||
+      type_inst->opcode() == SpvOpTypeFloat) {
+    return true;
+  }
+  if (type_inst->opcode() == SpvOpTypeVector ||
+      type_inst->opcode() == SpvOpTypeMatrix ||
+      type_inst->opcode() == SpvOpTypeArray) {
+    return IsConcreteType(type_inst->GetSingleWordInOperand(0));
+  }
+  if (type_inst->opcode() == SpvOpTypeStruct) {
+    for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
+      if (!IsConcreteType(type_inst->GetSingleWordInOperand(i))) return false;
+    }
+    return true;
+  }
+  return false;
+}
+
+BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateCaseBlock(
+    Instruction* access_chain, uint32_t element_index,
+    const std::deque<Instruction*>& insts_to_be_cloned,
+    uint32_t branch_target_id,
+    std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const {
+  auto* case_block = CreateNewBlock();
+  AddConstElementAccessToCaseBlock(case_block, access_chain, element_index,
+                                   old_ids_to_new_ids);
+  CloneInstsToBlock(case_block, access_chain, insts_to_be_cloned,
+                    old_ids_to_new_ids);
+  AddBranchToBlock(case_block, branch_target_id);
+  UseNewIdsInBlock(case_block, *old_ids_to_new_ids);
+  return case_block;
+}
+
+void ReplaceDescArrayAccessUsingVarIndex::CloneInstsToBlock(
+    BasicBlock* block, Instruction* inst_to_skip_cloning,
+    const std::deque<Instruction*>& insts_to_be_cloned,
+    std::unordered_map<uint32_t, uint32_t>* 
old_ids_to_new_ids) const { + for (auto* inst_to_be_cloned : insts_to_be_cloned) { + if (inst_to_be_cloned == inst_to_skip_cloning) continue; + std::unique_ptr clone(inst_to_be_cloned->Clone(context())); + if (inst_to_be_cloned->HasResultId()) { + uint32_t new_id = context()->TakeNextId(); + clone->SetResultId(new_id); + (*old_ids_to_new_ids)[inst_to_be_cloned->result_id()] = new_id; + } + get_def_use_mgr()->AnalyzeInstDefUse(clone.get()); + context()->set_instr_block(clone.get(), block); + block->AddInstruction(std::move(clone)); + } +} + +void ReplaceDescArrayAccessUsingVarIndex::UseNewIdsInBlock( + BasicBlock* block, + const std::unordered_map& old_ids_to_new_ids) const { + for (auto block_itr = block->begin(); block_itr != block->end(); + ++block_itr) { + (&*block_itr)->ForEachInId([&old_ids_to_new_ids](uint32_t* idp) { + auto old_ids_to_new_ids_itr = old_ids_to_new_ids.find(*idp); + if (old_ids_to_new_ids_itr == old_ids_to_new_ids.end()) return; + *idp = old_ids_to_new_ids_itr->second; + }); + get_def_use_mgr()->AnalyzeInstUse(&*block_itr); + } +} + +void ReplaceDescArrayAccessUsingVarIndex::ReplaceNonUniformAccessWithSwitchCase( + Instruction* access_chain_final_user, Instruction* access_chain, + uint32_t number_of_elements, + const std::deque& insts_to_be_cloned) const { + // Create merge block and add terminator + auto* block = context()->get_instr_block(access_chain_final_user); + auto* merge_block = SeparateInstructionsIntoNewBlock( + block, access_chain_final_user->NextNode()); + + auto* function = block->GetParent(); + + // Add case blocks + std::vector phi_operands; + std::vector case_block_ids; + for (uint32_t idx = 0; idx < number_of_elements; ++idx) { + std::unordered_map old_ids_to_new_ids_for_cloned_insts; + std::unique_ptr case_block(CreateCaseBlock( + access_chain, idx, insts_to_be_cloned, merge_block->id(), + &old_ids_to_new_ids_for_cloned_insts)); + case_block_ids.push_back(case_block->id()); + 
function->InsertBasicBlockBefore(std::move(case_block), merge_block); + + // Keep the operand for OpPhi: when the final user produces a value, + // remember the id of this case's clone of it. + if (!access_chain_final_user->HasResultId()) continue; + uint32_t phi_operand = + GetValueWithKeyExistenceCheck(access_chain_final_user->result_id(), + old_ids_to_new_ids_for_cloned_insts); + phi_operands.push_back(phi_operand); + } + + // Create default block. When a phi is needed this call also appends the + // default-case value to |phi_operands|. + std::unique_ptr default_block( + CreateDefaultBlock(access_chain_final_user->HasResultId(), &phi_operands, + merge_block->id())); + uint32_t default_block_id = default_block->id(); + function->InsertBasicBlockBefore(std::move(default_block), merge_block); + + // Create OpSwitch on the first index of |access_chain|, appended to |block|. + uint32_t access_chain_index_var_id = + descsroautil::GetFirstIndexOfAccessChain(access_chain); + AddSwitchForAccessChain(block, access_chain_index_var_id, default_block_id, + merge_block->id(), case_block_ids); + + // Create phi instructions and make every use of the original result id + // refer to the phi instead. + if (!phi_operands.empty()) { + uint32_t phi_id = CreatePhiInstruction(merge_block, phi_operands, + case_block_ids, default_block_id); + context()->ReplaceAllUsesWith(access_chain_final_user->result_id(), phi_id); + } + + // Replace OpPhi incoming block operand that uses |block| with |merge_block| + ReplacePhiIncomingBlock(block->id(), merge_block->id()); +} + +BasicBlock* +ReplaceDescArrayAccessUsingVarIndex::SeparateInstructionsIntoNewBlock( + BasicBlock* block, Instruction* separation_begin_inst) const { + // Linear scan for the iterator that points at |separation_begin_inst|; it + // ends up at block->end() when the instruction is not in |block|. + auto separation_begin = block->begin(); + while (separation_begin != block->end() && + &*separation_begin != separation_begin_inst) { + ++separation_begin; + } + return block->SplitBasicBlock(context(), context()->TakeNextId(), + separation_begin); +} + +BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateNewBlock() const { + // The block starts out with only a fresh OpLabel. It is returned as a raw + // pointer; the callers in this file wrap it in std::unique_ptr. + auto* new_block = new BasicBlock(std::unique_ptr( + new Instruction(context(), SpvOpLabel, 0, context()->TakeNextId(), {}))); + get_def_use_mgr()->AnalyzeInstDefUse(new_block->GetLabelInst()); + context()->set_instr_block(new_block->GetLabelInst(), new_block); 
+ return new_block; +} + +void ReplaceDescArrayAccessUsingVarIndex::UseConstIndexForAccessChain( + Instruction* access_chain, uint32_t const_element_idx) const { + uint32_t const_element_idx_id = + context()->get_constant_mgr()->GetUIntConst(const_element_idx); + access_chain->SetInOperand(kOpAccessChainInOperandIndexes, + {const_element_idx_id}); +} + +void ReplaceDescArrayAccessUsingVarIndex::AddConstElementAccessToCaseBlock( + BasicBlock* case_block, Instruction* access_chain, + uint32_t const_element_idx, + std::unordered_map* old_ids_to_new_ids) const { + std::unique_ptr access_clone(access_chain->Clone(context())); + UseConstIndexForAccessChain(access_clone.get(), const_element_idx); + + uint32_t new_access_id = context()->TakeNextId(); + (*old_ids_to_new_ids)[access_clone->result_id()] = new_access_id; + access_clone->SetResultId(new_access_id); + get_def_use_mgr()->AnalyzeInstDefUse(access_clone.get()); + + context()->set_instr_block(access_clone.get(), case_block); + case_block->AddInstruction(std::move(access_clone)); +} + +void ReplaceDescArrayAccessUsingVarIndex::AddBranchToBlock( + BasicBlock* parent_block, uint32_t branch_destination) const { + InstructionBuilder builder{context(), parent_block, + kAnalysisDefUseAndInstrToBlockMapping}; + builder.AddBranch(branch_destination); +} + +BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateDefaultBlock( + bool null_const_for_phi_is_needed, std::vector* phi_operands, + uint32_t merge_block_id) const { + auto* default_block = CreateNewBlock(); + AddBranchToBlock(default_block, merge_block_id); + if (!null_const_for_phi_is_needed) return default_block; + + // Create null value for OpPhi. Its type is taken from the first value that + // is already in |phi_operands|, and its id is appended as the last operand. + // NOTE(review): (*phi_operands)[0] is read unchecked, so this assumes at + // least one case block has pushed an operand -- confirm with the callers. + Instruction* inst = context()->get_def_use_mgr()->GetDef((*phi_operands)[0]); + auto* null_const_inst = GetConstNull(inst->type_id()); + phi_operands->push_back(null_const_inst->result_id()); + return default_block; +} + +Instruction* ReplaceDescArrayAccessUsingVarIndex::GetConstNull( + uint32_t type_id) const { + 
assert(type_id != 0 && "Result type is expected"); + auto* type = context()->get_type_mgr()->GetType(type_id); + auto* null_const = context()->get_constant_mgr()->GetConstant(type, {}); + return context()->get_constant_mgr()->GetDefiningInstruction(null_const); +} + +void ReplaceDescArrayAccessUsingVarIndex::AddSwitchForAccessChain( + BasicBlock* parent_block, uint32_t access_chain_index_var_id, + uint32_t default_id, uint32_t merge_id, + const std::vector& case_block_ids) const { + InstructionBuilder builder{context(), parent_block, + kAnalysisDefUseAndInstrToBlockMapping}; + // Case literal i selects case_block_ids[i]. The number of cases is known + // up front, so the storage is reserved once. + std::vector> cases; + cases.reserve(case_block_ids.size()); + for (uint32_t i = 0; i < static_cast(case_block_ids.size()); ++i) { + cases.emplace_back(Operand::OperandData{i}, case_block_ids[i]); + } + builder.AddSwitch(access_chain_index_var_id, default_id, cases, merge_id); +} + +uint32_t ReplaceDescArrayAccessUsingVarIndex::CreatePhiInstruction( + BasicBlock* parent_block, const std::vector& phi_operands, + const std::vector& case_block_ids, + uint32_t default_block_id) const { + std::vector incomings; + assert(case_block_ids.size() + 1 == phi_operands.size() && + "Number of Phi operands must be exactly 1 bigger than the one of case " + "blocks"); + // |incomings| holds one (value id, incoming block id) pair per phi operand. + incomings.reserve(phi_operands.size() * 2); + for (size_t i = 0; i < case_block_ids.size(); ++i) { + incomings.push_back(phi_operands[i]); + incomings.push_back(case_block_ids[i]); + } + // The last phi operand is the value that arrives from the default block. + incomings.push_back(phi_operands.back()); + incomings.push_back(default_block_id); + + // The builder is positioned at the first instruction of |parent_block|. + InstructionBuilder builder{context(), &*parent_block->begin(), + kAnalysisDefUseAndInstrToBlockMapping}; + uint32_t phi_result_type_id = + context()->get_def_use_mgr()->GetDef(phi_operands[0])->type_id(); + auto* phi = builder.AddPhi(phi_result_type_id, incomings); + return phi->result_id(); +} + +void ReplaceDescArrayAccessUsingVarIndex::ReplacePhiIncomingBlock( + uint32_t old_incoming_block_id, uint32_t new_incoming_block_id) const { + context()->ReplaceAllUsesWithPredicate( + old_incoming_block_id, new_incoming_block_id, + [](Instruction* use) { return 
use->opcode() == SpvOpPhi; }); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/replace_desc_array_access_using_var_index.h b/source/opt/replace_desc_array_access_using_var_index.h new file mode 100644 index 000000000..e18222c85 --- /dev/null +++ b/source/opt/replace_desc_array_access_using_var_index.h @@ -0,0 +1,204 @@ +// Copyright (c) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_REPLACE_DESC_VAR_INDEX_ACCESS_H_ +#define SOURCE_OPT_REPLACE_DESC_VAR_INDEX_ACCESS_H_ + +#include +#include +#include +#include +#include +#include + +#include "source/opt/function.h" +#include "source/opt/pass.h" +#include "source/opt/type_manager.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class ReplaceDescArrayAccessUsingVarIndex : public Pass { + public: + ReplaceDescArrayAccessUsingVarIndex() {} + + const char* name() const override { + return "replace-desc-array-access-using-var-index"; + } + + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Replaces all accesses to |var| using variable indices with constant + // elements of the array |var|. Creates switch-case statements to determine + // the value of the variable index for all the possible cases. 
Returns + // whether replacement is done or not. + bool ReplaceVariableAccessesWithConstantElements(Instruction* var) const; + + // Replaces the OpAccessChain or OpInBoundsAccessChain instruction |use| that + // uses the descriptor variable |var| with the OpAccessChain or + // OpInBoundsAccessChain instruction with a constant Indexes operand. + void ReplaceAccessChain(Instruction* var, Instruction* use) const; + + // Updates the first Indexes operand of the OpAccessChain or + // OpInBoundsAccessChain instruction |access_chain| to let it use a constant + // index |const_element_idx|. + void UseConstIndexForAccessChain(Instruction* access_chain, + uint32_t const_element_idx) const; + + // Replaces users of the OpAccessChain or OpInBoundsAccessChain instruction + // |access_chain| that accesses an array descriptor variable using variable + // indices with constant elements. |number_of_elements| is the number + // of array elements. + void ReplaceUsersOfAccessChain(Instruction* access_chain, + uint32_t number_of_elements) const; + + // Puts all the recursive users of |access_chain| with concrete result types + // or the ones without a result id in |final_users|. + void CollectRecursiveUsersWithConcreteType( + Instruction* access_chain, std::vector* final_users) const; + + // Recursively collects the operands of |user_of_image_insts| (and operands + // of the operands) whose result types are images/samplers or pointers/array/ + // struct of them and returns them. + std::deque CollectRequiredImageInsts( + Instruction* user_of_image_insts) const; + + // Returns whether result type of |inst| is an image/sampler/pointer of image + // or sampler or not. + bool HasImageOrImagePtrType(const Instruction* inst) const; + + // Returns whether |type_inst| is an image/sampler or pointer/array/struct of + // image or sampler or not. + bool IsImageOrImagePtrType(const Instruction* type_inst) const; + + // Returns whether the type with |type_id| is a concrete type or not. Integer + // and floating-point types are concrete, and so are vectors, matrices, + // arrays and structs that are built only from concrete types. 
+ bool IsConcreteType(uint32_t type_id) const; + + // Replaces the non-uniform access to a descriptor variable + // |access_chain_final_user| with OpSwitch instruction and case blocks. Each + // case block will contain a clone of |access_chain| and clones of + // |non_uniform_accesses_to_clone| that are recursively used by + // |access_chain_final_user|. The clone of |access_chain| (or + // OpInBoundsAccessChain) will have a constant index for its first index. The + // OpSwitch instruction will have the cases for the variable index of + // |access_chain| from 0 to |number_of_elements| - 1. + void ReplaceNonUniformAccessWithSwitchCase( + Instruction* access_chain_final_user, Instruction* access_chain, + uint32_t number_of_elements, + const std::deque& non_uniform_accesses_to_clone) const; + + // Creates and returns a new basic block by splitting |block| at + // |separation_begin_inst|. NOTE(review): the split point handed to + // BasicBlock::SplitBasicBlock is the iterator at |separation_begin_inst| + // itself, so that instruction presumably moves to the new block as well -- + // confirm. The new basic block is added to the function in this method. + BasicBlock* SeparateInstructionsIntoNewBlock( + BasicBlock* block, Instruction* separation_begin_inst) const; + + // Creates and returns a new block that contains only a new OpLabel. The + // caller is responsible for inserting it into a function. + BasicBlock* CreateNewBlock() const; + + // Returns the first operand id of the OpAccessChain or OpInBoundsAccessChain + // instruction |access_chain|. + // NOTE(review): the implementation file calls + // descsroautil::GetFirstIndexOfAccessChain rather than this member; check + // whether this declaration is still needed. + uint32_t GetFirstIndexOfAccessChain(Instruction* access_chain) const; + + // Adds a clone of the OpAccessChain or OpInBoundsAccessChain instruction + // |access_chain| to |case_block|. The clone of |access_chain| will use + // |const_element_idx| for its first index. |old_ids_to_new_ids| keeps the + // mapping from the result id of |access_chain| to the result id of its clone. + void AddConstElementAccessToCaseBlock( + BasicBlock* case_block, Instruction* access_chain, + uint32_t const_element_idx, + std::unordered_map* old_ids_to_new_ids) const; + + // Clones all instructions in |insts_to_be_cloned| except + // |inst_to_skip_cloning| and puts the clones in |block|. 
+ // |old_ids_to_new_ids| keeps the mapping from the result id of each + // instruction of |insts_to_be_cloned| to the result id of its clone. + // |inst_to_skip_cloning| is not cloned; instructions without a result id are + // cloned but add no entry to the map. + void CloneInstsToBlock( + BasicBlock* block, Instruction* inst_to_skip_cloning, + const std::deque& insts_to_be_cloned, + std::unordered_map* old_ids_to_new_ids) const; + + // Adds OpBranch to |branch_destination| at the end of |parent_block|. + void AddBranchToBlock(BasicBlock* parent_block, + uint32_t branch_destination) const; + + // Replaces in-operands of all instructions in the basic block |block| using + // |old_ids_to_new_ids|. It conducts the replacement only if the in-operand + // id is a key of |old_ids_to_new_ids|. + void UseNewIdsInBlock( + BasicBlock* block, + const std::unordered_map& old_ids_to_new_ids) const; + + // Creates a case block for |element_index| case. It adds clones of + // |insts_to_be_cloned| and a clone of |access_chain| with |element_index| as + // its first index. The termination instruction of the created case block will + // be a branch to |branch_target_id|. Puts old ids to new ids map for the + // cloned instructions in |old_ids_to_new_ids|. + BasicBlock* CreateCaseBlock( + Instruction* access_chain, uint32_t element_index, + const std::deque& insts_to_be_cloned, + uint32_t branch_target_id, + std::unordered_map* old_ids_to_new_ids) const; + + // Creates a default block for switch-case statement that has only a single + // instruction OpBranch whose target is a basic block with |merge_block_id|. + // If |null_const_for_phi_is_needed| is true, gets or creates a default null + // constant value for a phi instruction whose operands are |phi_operands| and + // appends it to |phi_operands|. In that case |phi_operands| must already + // hold at least one value, because the type of its first element is used as + // the type of the null constant. + BasicBlock* CreateDefaultBlock(bool null_const_for_phi_is_needed, + std::vector* phi_operands, + uint32_t merge_block_id) const; + + // Creates and adds an OpSwitch used for the selection of OpAccessChain whose + // first Indexes operand is |access_chain_index_var_id|. 
The OpSwitch will be + // added at the end of |parent_block|. It will jump to |default_id| for the + // default case and jump to one of the case blocks whose ids are + // |case_block_ids| if |access_chain_index_var_id| matches the case number; + // the case number of |case_block_ids[i]| is i. |merge_id| is the + // merge block id. + void AddSwitchForAccessChain( + BasicBlock* parent_block, uint32_t access_chain_index_var_id, + uint32_t default_id, uint32_t merge_id, + const std::vector& case_block_ids) const; + + // Creates a phi instruction with |phi_operands| as values and + // |case_block_ids| and |default_block_id| as incoming blocks. The size of + // |phi_operands| must be exactly 1 larger than the size of |case_block_ids|. + // The last element of |phi_operands| will be used for |default_block_id|. It + // adds the phi instruction to the beginning of |parent_block| and returns + // the result id of the phi instruction. + uint32_t CreatePhiInstruction(BasicBlock* parent_block, + const std::vector& phi_operands, + const std::vector& case_block_ids, + uint32_t default_block_id) const; + + // Replaces the incoming block operand of OpPhi instructions with + // |new_incoming_block_id| if the incoming block operand is + // |old_incoming_block_id|. + void ReplacePhiIncomingBlock(uint32_t old_incoming_block_id, + uint32_t new_incoming_block_id) const; + + // Returns the instruction that defines the null constant of the type with id + // |type_id|, obtained through the constant manager. |type_id| must not be 0. 
+ Instruction* GetConstNull(uint32_t type_id) const; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_REPLACE_DESC_VAR_INDEX_ACCESS_H_ diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index 11299755c..bc44e8d6e 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -85,6 +85,7 @@ add_spvtools_unittest(TARGET opt remove_unused_interface_variables_test.cpp register_liveness.cpp relax_float_ops_test.cpp + replace_desc_array_access_using_var_index_test.cpp replace_invalid_opc_test.cpp scalar_analysis.cpp scalar_replacement_test.cpp diff --git a/test/opt/replace_desc_array_access_using_var_index_test.cpp b/test/opt/replace_desc_array_access_using_var_index_test.cpp new file mode 100644 index 000000000..ca6258126 --- /dev/null +++ b/test/opt/replace_desc_array_access_using_var_index_test.cpp @@ -0,0 +1,411 @@ +// Copyright (c) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ReplaceDescArrayAccessUsingVarIndexTest = PassTest<::testing::Test>; + +// A variable-index access chain into a 3-element texture array sits in one +// arm of a conditional branch. The CHECK lines expect an OpSwitch with one +// case per element plus a default, and an OpPhi in the merge block that takes +// a null constant from the default block. +TEST_F(ReplaceDescArrayAccessUsingVarIndexTest, + ReplaceAccessChainToTextureArray) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %psmain "psmain" %gl_FragCoord %in_var_INSTANCEID %out_var_SV_TARGET + OpExecutionMode %psmain OriginUpperLeft + OpSource HLSL 600 + OpName %type_sampler "type.sampler" + OpName %Sampler0 "Sampler0" + OpName %type_2d_image "type.2d.image" + OpName %Tex0 "Tex0" + OpName %in_var_INSTANCEID "in.var.INSTANCEID" + OpName %out_var_SV_TARGET "out.var.SV_TARGET" + OpName %psmain "psmain" + OpName %type_sampled_image "type.sampled.image" + OpDecorate %gl_FragCoord BuiltIn FragCoord + OpDecorate %in_var_INSTANCEID Flat + OpDecorate %in_var_INSTANCEID Location 0 + OpDecorate %out_var_SV_TARGET Location 0 + OpDecorate %Sampler0 DescriptorSet 0 + OpDecorate %Sampler0 Binding 1 + OpDecorate %Tex0 DescriptorSet 0 + OpDecorate %Tex0 Binding 2 + %bool = OpTypeBool +%type_sampler = OpTypeSampler +%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler + %uint = OpTypeInt 32 0 + %uint_3 = OpConstant %uint 3 + %float = OpTypeFloat 32 +%type_2d_image = OpTypeImage %float 2D 2 0 0 0 Unknown +%_arr_type_2d_image_uint_3 = OpTypeArray %type_2d_image %uint_3 +%_ptr_UniformConstant__arr_type_2d_image_uint_3 = OpTypePointer UniformConstant %_arr_type_2d_image_uint_3 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Input_uint = OpTypePointer Input %uint +%_ptr_Output_v4float = OpTypePointer Output %v4float + %void = OpTypeVoid + %21 = OpTypeFunction %void +%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image + %v2float = 
OpTypeVector %float 2 + %v2uint = OpTypeVector %uint 2 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %27 = OpConstantComposite %v2uint %uint_0 %uint_1 +%type_sampled_image = OpTypeSampledImage %type_2d_image + %Sampler0 = OpVariable %_ptr_UniformConstant_type_sampler UniformConstant + %Tex0 = OpVariable %_ptr_UniformConstant__arr_type_2d_image_uint_3 UniformConstant +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%in_var_INSTANCEID = OpVariable %_ptr_Input_uint Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output + %uint_2 = OpConstant %uint 2 + %66 = OpConstantNull %v4float + +; CHECK: [[null_value:%\w+]] = OpConstantNull %v4float + + %psmain = OpFunction %void None %21 + %39 = OpLabel + %29 = OpLoad %v4float %gl_FragCoord + %30 = OpLoad %uint %in_var_INSTANCEID + %37 = OpIEqual %bool %30 %uint_2 + OpSelectionMerge %38 None + OpBranchConditional %37 %28 %40 + +; CHECK: [[var_index:%\w+]] = OpLoad %uint %in_var_INSTANCEID +; CHECK: OpSelectionMerge [[cond_branch_merge:%\w+]] None +; CHECK: OpBranchConditional {{%\w+}} {{%\w+}} [[bb_cond_br:%\w+]] + + %28 = OpLabel + %31 = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %30 + %32 = OpLoad %type_2d_image %31 + OpImageWrite %32 %27 %29 + +; CHECK: OpSelectionMerge [[merge:%\w+]] None +; CHECK: OpSwitch [[var_index]] [[default:%\w+]] 0 [[case0:%\w+]] 1 [[case1:%\w+]] 2 [[case2:%\w+]] +; CHECK: [[case0]] = OpLabel +; CHECK: OpAccessChain +; CHECK: OpLoad +; CHECK: OpImageWrite +; CHECK: OpBranch [[merge]] +; CHECK: [[case1]] = OpLabel +; CHECK: OpAccessChain +; CHECK: OpLoad +; CHECK: OpImageWrite +; CHECK: OpBranch [[merge]] +; CHECK: [[case2]] = OpLabel +; CHECK: OpAccessChain +; CHECK: OpLoad +; CHECK: OpImageWrite +; CHECK: OpBranch [[merge]] +; CHECK: [[default]] = OpLabel +; CHECK: OpBranch [[merge]] +; CHECK: [[merge]] = OpLabel + + %33 = OpLoad %type_sampler %Sampler0 + %34 = OpVectorShuffle %v2float %29 %29 0 1 + %35 = OpSampledImage %type_sampled_image %32 %33 + 
%36 = OpImageSampleImplicitLod %v4float %35 %34 None + +; CHECK: OpSelectionMerge [[merge:%\w+]] None +; CHECK: OpSwitch [[var_index]] [[default:%\w+]] 0 [[case0:%\w+]] 1 [[case1:%\w+]] 2 [[case2:%\w+]] +; CHECK: [[case0]] = OpLabel +; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_0 +; CHECK: [[sam:%\w+]] = OpLoad %type_sampler %Sampler0 +; CHECK: [[img:%\w+]] = OpLoad %type_2d_image [[ac]] +; CHECK: [[sampledImg:%\w+]] = OpSampledImage %type_sampled_image [[img]] [[sam]] +; CHECK: [[value0:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg]] +; CHECK: OpBranch [[merge]] +; CHECK: [[case1]] = OpLabel +; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_1 +; CHECK: [[sam:%\w+]] = OpLoad %type_sampler %Sampler0 +; CHECK: [[img:%\w+]] = OpLoad %type_2d_image [[ac]] +; CHECK: [[sampledImg:%\w+]] = OpSampledImage %type_sampled_image [[img]] [[sam]] +; CHECK: [[value1:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg]] +; CHECK: OpBranch [[merge]] +; CHECK: [[case2]] = OpLabel +; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_2 +; CHECK: [[sam:%\w+]] = OpLoad %type_sampler %Sampler0 +; CHECK: [[img:%\w+]] = OpLoad %type_2d_image [[ac]] +; CHECK: [[sampledImg:%\w+]] = OpSampledImage %type_sampled_image [[img]] [[sam]] +; CHECK: [[value2:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg]] +; CHECK: OpBranch [[merge]] +; CHECK: [[default]] = OpLabel +; CHECK: OpBranch [[merge]] +; CHECK: [[merge]] = OpLabel +; CHECK: [[phi0:%\w+]] = OpPhi %v4float [[value0]] [[case0]] [[value1]] [[case1]] [[value2]] [[case2]] [[null_value]] [[default]] + + OpBranch %38 + %40 = OpLabel + OpBranch %38 + %38 = OpLabel + %41 = OpPhi %v4float %36 %28 %29 %40 + +; CHECK: OpBranch [[cond_branch_merge]] +; CHECK: [[bb_cond_br]] = OpLabel +; CHECK: OpBranch [[cond_branch_merge]] +; CHECK: [[cond_branch_merge]] = OpLabel +; CHECK: [[phi1:%\w+]] = OpPhi %v4float [[phi0]] [[merge]] 
{{%\w+}} [[bb_cond_br]] +; CHECK: OpStore {{%\w+}} [[phi1]] + + OpStore %out_var_SV_TARGET %41 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +// The same variable indexes both a 2-element texture array and a 2-element +// sampler array. The CHECK lines expect a switch over the sampler cases with +// a texture switch nested inside each sampler case. +TEST_F(ReplaceDescArrayAccessUsingVarIndexTest, + ReplaceAccessChainToTextureArrayAndSamplerArray) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %psmain "psmain" %gl_FragCoord %in_var_INSTANCEID %out_var_SV_TARGET + OpExecutionMode %psmain OriginUpperLeft + OpSource HLSL 600 + OpName %type_sampler "type.sampler" + OpName %Sampler0 "Sampler0" + OpName %type_2d_image "type.2d.image" + OpName %Tex0 "Tex0" + OpName %in_var_INSTANCEID "in.var.INSTANCEID" + OpName %out_var_SV_TARGET "out.var.SV_TARGET" + OpName %psmain "psmain" + OpName %type_sampled_image "type.sampled.image" + OpDecorate %gl_FragCoord BuiltIn FragCoord + OpDecorate %in_var_INSTANCEID Flat + OpDecorate %in_var_INSTANCEID Location 0 + OpDecorate %out_var_SV_TARGET Location 0 + OpDecorate %Sampler0 DescriptorSet 0 + OpDecorate %Sampler0 Binding 1 + OpDecorate %Tex0 DescriptorSet 0 + OpDecorate %Tex0 Binding 2 +%type_sampler = OpTypeSampler + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler +%_arr_type_sampler_uint_2 = OpTypeArray %type_sampler %uint_2 +%_ptr_UniformConstant__arr_type_sampler_uint_2 = OpTypePointer UniformConstant %_arr_type_sampler_uint_2 + %float = OpTypeFloat 32 +%type_2d_image = OpTypeImage %float 2D 2 0 0 0 Unknown +%_arr_type_2d_image_uint_2 = OpTypeArray %type_2d_image %uint_2 +%_ptr_UniformConstant__arr_type_2d_image_uint_2 = OpTypePointer UniformConstant %_arr_type_2d_image_uint_2 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Input_uint = OpTypePointer Input %uint +%_ptr_Output_v4float = OpTypePointer Output %v4float + %void = OpTypeVoid + %21 = OpTypeFunction %void +%_ptr_UniformConstant_type_2d_image = 
OpTypePointer UniformConstant %type_2d_image + %v2float = OpTypeVector %float 2 + %v2uint = OpTypeVector %uint 2 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %27 = OpConstantComposite %v2uint %uint_0 %uint_1 +%type_sampled_image = OpTypeSampledImage %type_2d_image + %Sampler0 = OpVariable %_ptr_UniformConstant__arr_type_sampler_uint_2 UniformConstant + %Tex0 = OpVariable %_ptr_UniformConstant__arr_type_2d_image_uint_2 UniformConstant +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%in_var_INSTANCEID = OpVariable %_ptr_Input_uint Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output + %66 = OpConstantNull %v4float + %psmain = OpFunction %void None %21 + %28 = OpLabel + %29 = OpLoad %v4float %gl_FragCoord + %30 = OpLoad %uint %in_var_INSTANCEID + %31 = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %30 + %32 = OpLoad %type_2d_image %31 + OpImageWrite %32 %27 %29 + +; CHECK: [[null_value:%\w+]] = OpConstantNull %v4float + +; CHECK: [[var_index:%\w+]] = OpLoad %uint %in_var_INSTANCEID +; CHECK: OpSelectionMerge [[merge:%\w+]] None +; CHECK: OpSwitch [[var_index]] [[default:%\w+]] 0 [[case0:%\w+]] 1 [[case1:%\w+]] +; CHECK: [[case0]] = OpLabel +; CHECK: OpAccessChain +; CHECK: OpLoad +; CHECK: OpImageWrite +; CHECK: OpBranch [[merge]] +; CHECK: [[case1]] = OpLabel +; CHECK: OpAccessChain +; CHECK: OpLoad +; CHECK: OpImageWrite +; CHECK: OpBranch [[merge]] +; CHECK: [[default]] = OpLabel +; CHECK: OpBranch [[merge]] +; CHECK: [[merge]] = OpLabel + + %33 = OpAccessChain %_ptr_UniformConstant_type_sampler %Sampler0 %30 + %37 = OpLoad %type_sampler %33 + %34 = OpVectorShuffle %v2float %29 %29 0 1 + %35 = OpSampledImage %type_sampled_image %32 %37 + %36 = OpImageSampleImplicitLod %v4float %35 %34 None + +; SPIR-V instructions to be replaced (will be killed by ADCE) +; CHECK: OpSelectionMerge +; CHECK: OpSwitch + +; CHECK: OpSelectionMerge [[merge_sampler:%\w+]] None +; CHECK: OpSwitch [[var_index]] [[default_sampler:%\w+]] 0 
[[case_sampler0:%\w+]] 1 [[case_sampler1:%\w+]] + +; CHECK: [[case_sampler0]] = OpLabel +; CHECK: OpSelectionMerge [[merge_texture0:%\w+]] None +; CHECK: OpSwitch [[var_index]] [[default_texture:%\w+]] 0 [[case_texture0:%\w+]] 1 [[case_texture1:%\w+]] +; CHECK: [[case_texture0]] = OpLabel +; CHECK: [[pt0:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_0 +; CHECK: [[ps0:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_sampler %Sampler0 %uint_0 +; CHECK: [[s0:%\w+]] = OpLoad %type_sampler [[ps0]] +; CHECK: [[t0:%\w+]] = OpLoad %type_2d_image [[pt0]] +; CHECK: [[sampledImg0:%\w+]] = OpSampledImage %type_sampled_image [[t0]] [[s0]] +; CHECK: [[value0:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg0]] +; CHECK: OpBranch [[merge_texture0]] +; CHECK: [[case_texture1]] = OpLabel +; CHECK: [[pt1:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_1 +; CHECK: [[ps0:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_sampler %Sampler0 %uint_0 +; CHECK: [[s0:%\w+]] = OpLoad %type_sampler [[ps0]] +; CHECK: [[t1:%\w+]] = OpLoad %type_2d_image [[pt1]] +; CHECK: [[sampledImg1:%\w+]] = OpSampledImage %type_sampled_image [[t1]] [[s0]] +; CHECK: [[value1:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg1]] +; CHECK: OpBranch [[merge_texture0]] +; CHECK: [[default_texture]] = OpLabel +; CHECK: OpBranch [[merge_texture0]] +; CHECK: [[merge_texture0]] = OpLabel +; CHECK: [[phi0:%\w+]] = OpPhi %v4float [[value0]] [[case_texture0]] [[value1]] [[case_texture1]] [[null_value]] [[default_texture]] +; CHECK: OpBranch [[merge_sampler]] + +; CHECK: [[case_sampler1]] = OpLabel +; CHECK: OpSelectionMerge [[merge_texture1:%\w+]] None +; CHECK: OpSwitch [[var_index]] [[default_texture:%\w+]] 0 [[case_texture0:%\w+]] 1 [[case_texture1:%\w+]] +; CHECK: [[case_texture0]] = OpLabel +; CHECK: [[pt0:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_0 +; CHECK: [[ps1:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_sampler 
%Sampler0 %uint_1 +; CHECK: [[s1:%\w+]] = OpLoad %type_sampler [[ps1]] +; CHECK: [[t0:%\w+]] = OpLoad %type_2d_image [[pt0]] +; CHECK: [[sampledImg0:%\w+]] = OpSampledImage %type_sampled_image [[t0]] [[s1]] +; CHECK: [[value0:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg0]] +; CHECK: OpBranch [[merge_texture1]] +; CHECK: [[case_texture1]] = OpLabel +; CHECK: [[pt1:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_1 +; CHECK: [[ps1:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_sampler %Sampler0 %uint_1 +; CHECK: [[s1:%\w+]] = OpLoad %type_sampler [[ps1]] +; CHECK: [[t1:%\w+]] = OpLoad %type_2d_image [[pt1]] +; CHECK: [[sampledImg1:%\w+]] = OpSampledImage %type_sampled_image [[t1]] [[s1]] +; CHECK: [[value1:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg1]] +; CHECK: OpBranch [[merge_texture1]] +; CHECK: [[default_texture]] = OpLabel +; CHECK: OpBranch [[merge_texture1]] +; CHECK: [[merge_texture1]] = OpLabel +; CHECK: [[phi1:%\w+]] = OpPhi %v4float [[value0]] [[case_texture0]] [[value1]] [[case_texture1]] [[null_value]] [[default_texture]] + +; CHECK: [[default_sampler]] = OpLabel +; CHECK: OpBranch [[merge_sampler]] +; CHECK: [[merge_sampler]] = OpLabel +; CHECK: OpPhi %v4float [[phi0]] [[merge_texture0]] [[phi1]] [[merge_texture1]] [[null_value]] [[default_sampler]] +; CHECK: OpStore + + OpStore %out_var_SV_TARGET %36 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +// With a single-element texture array the CHECK lines expect the access chain +// to use %uint_0 directly, with no OpSwitch and no OpPhi in the output. +TEST_F(ReplaceDescArrayAccessUsingVarIndexTest, + ReplaceAccessChainToTextureArrayWithSingleElement) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %psmain "psmain" %gl_FragCoord %in_var_INSTANCEID %out_var_SV_TARGET + OpExecutionMode %psmain OriginUpperLeft + OpSource HLSL 600 + OpName %type_sampler "type.sampler" + OpName %Sampler0 "Sampler0" + OpName %type_2d_image "type.2d.image" + OpName %Tex0 "Tex0" + OpName %in_var_INSTANCEID "in.var.INSTANCEID" + OpName 
%out_var_SV_TARGET "out.var.SV_TARGET" + OpName %psmain "psmain" + OpName %type_sampled_image "type.sampled.image" + OpDecorate %gl_FragCoord BuiltIn FragCoord + OpDecorate %in_var_INSTANCEID Flat + OpDecorate %in_var_INSTANCEID Location 0 + OpDecorate %out_var_SV_TARGET Location 0 + OpDecorate %Sampler0 DescriptorSet 0 + OpDecorate %Sampler0 Binding 1 + OpDecorate %Tex0 DescriptorSet 0 + OpDecorate %Tex0 Binding 2 +%type_sampler = OpTypeSampler +%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 + %float = OpTypeFloat 32 +%type_2d_image = OpTypeImage %float 2D 2 0 0 0 Unknown +%_arr_type_2d_image_uint_1 = OpTypeArray %type_2d_image %uint_1 +%_ptr_UniformConstant__arr_type_2d_image_uint_1 = OpTypePointer UniformConstant %_arr_type_2d_image_uint_1 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Input_uint = OpTypePointer Input %uint +%_ptr_Output_v4float = OpTypePointer Output %v4float + %void = OpTypeVoid + %21 = OpTypeFunction %void +%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image + %v2float = OpTypeVector %float 2 + %v2uint = OpTypeVector %uint 2 + %uint_0 = OpConstant %uint 0 + %27 = OpConstantComposite %v2uint %uint_0 %uint_1 +%type_sampled_image = OpTypeSampledImage %type_2d_image + %Sampler0 = OpVariable %_ptr_UniformConstant_type_sampler UniformConstant + %Tex0 = OpVariable %_ptr_UniformConstant__arr_type_2d_image_uint_1 UniformConstant +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%in_var_INSTANCEID = OpVariable %_ptr_Input_uint Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output + %uint_2 = OpConstant %uint 2 + %66 = OpConstantNull %v4float + %psmain = OpFunction %void None %21 + %28 = OpLabel + %29 = OpLoad %v4float %gl_FragCoord + %30 = OpLoad %uint %in_var_INSTANCEID + %31 = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %30 + %32 = OpLoad %type_2d_image 
%31 + OpImageWrite %32 %27 %29 + +; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_0 +; CHECK-NOT: OpAccessChain +; CHECK-NOT: OpSwitch +; CHECK-NOT: OpPhi + + %33 = OpLoad %type_sampler %Sampler0 + %34 = OpVectorShuffle %v2float %29 %29 0 1 + %35 = OpSampledImage %type_sampled_image %32 %33 + %36 = OpImageSampleImplicitLod %v4float %35 %34 None + + OpStore %out_var_SV_TARGET %36 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp index d5036fc30..04f81b8c1 100644 --- a/tools/opt/opt.cpp +++ b/tools/opt/opt.cpp @@ -163,6 +163,11 @@ Options (in lexicographical order):)", around known issues with some Vulkan drivers for initialize variables.)"); printf(R"( + --replace-desc-array-access-using-var-index + Replaces accesses to descriptor arrays based on a variable index + with a switch that has a case for every possible value of the + index.)"); + printf(R"( --descriptor-scalar-replacement Replaces every array variable |desc| that has a DescriptorSet and Binding decorations with a new variable for each element of