From 4b64beb1aea92ef667e89317d79fc4fa5cb1a40b Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Thu, 8 Aug 2019 10:53:19 -0400 Subject: [PATCH] Add descriptor array scalar replacement (#2742) Creates a pass that will replace a descriptor array with individual variables. See #2740 for details. Fixes #2740. --- Android.mk | 1 + BUILD.gn | 2 + include/spirv-tools/optimizer.hpp | 11 + source/opt/CMakeLists.txt | 2 + source/opt/desc_sroa.cpp | 255 ++++++++++++++++++++++ source/opt/desc_sroa.h | 84 +++++++ source/opt/ir_context.cpp | 36 +++ source/opt/ir_context.h | 4 + source/opt/optimizer.cpp | 7 + source/opt/passes.h | 1 + source/util/string_utils.h | 44 ++++ test/assembly_context_test.cpp | 6 +- test/binary_parse_test.cpp | 3 +- test/comment_test.cpp | 3 +- test/ext_inst.debuginfo_test.cpp | 3 +- test/ext_inst.opencl_test.cpp | 3 +- test/opt/CMakeLists.txt | 1 + test/opt/decoration_manager_test.cpp | 3 +- test/opt/desc_sroa_test.cpp | 209 ++++++++++++++++++ test/text_to_binary.annotation_test.cpp | 3 +- test/text_to_binary.debug_test.cpp | 3 +- test/text_to_binary.extension_test.cpp | 3 +- test/text_to_binary.mode_setting_test.cpp | 3 +- test/unit_spirv.cpp | 3 +- test/unit_spirv.h | 23 -- tools/opt/opt.cpp | 9 + 26 files changed, 689 insertions(+), 36 deletions(-) create mode 100644 source/opt/desc_sroa.cpp create mode 100644 source/opt/desc_sroa.h create mode 100644 test/opt/desc_sroa_test.cpp diff --git a/Android.mk b/Android.mk index 8a507da1f..a6278af0d 100644 --- a/Android.mk +++ b/Android.mk @@ -95,6 +95,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/decompose_initialized_variables_pass.cpp \ source/opt/decoration_manager.cpp \ source/opt/def_use_manager.cpp \ + source/opt/desc_sroa.cpp \ source/opt/dominator_analysis.cpp \ source/opt/dominator_tree.cpp \ source/opt/eliminate_dead_constant_pass.cpp \ diff --git a/BUILD.gn b/BUILD.gn index 90d80e54f..84b21e13b 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -491,6 +491,8 @@ static_library("spvtools_opt") { "source/opt/decoration_manager.h", "source/opt/def_use_manager.cpp", "source/opt/def_use_manager.h", + "source/opt/desc_sroa.cpp", + "source/opt/desc_sroa.h", "source/opt/dominator_analysis.cpp", "source/opt/dominator_analysis.h", "source/opt/dominator_tree.cpp", diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index a52dcd0ad..d442b97f5 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -784,6 +784,17 @@ Optimizer::PassToken CreateSplitInvalidUnreachablePass(); // wide. Optimizer::PassToken CreateGraphicsRobustAccessPass(); +// Create descriptor scalar replacement pass. +// This pass replaces every array variable |desc| that has a DescriptorSet and +// Binding decorations with a new variable for each element of the array. +// Suppose |desc| was bound at binding |b|. Then the variable corresponding to +// |desc[i]| will have binding |b+i|. The descriptor set will be the same. It +// is assumed that no other variable already has a binding that will used by one +// of the new variables. If not, the pass will generate invalid Spir-V. All +// accesses to |desc| must be OpAccessChain instructions with a literal index +// for the first index. +Optimizer::PassToken CreateDescriptorScalarReplacementPass(); + } // namespace spvtools #endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_ diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 278f794c1..2ebad512a 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -33,6 +33,7 @@ set(SPIRV_TOOLS_OPT_SOURCES decompose_initialized_variables_pass.h decoration_manager.h def_use_manager.h + desc_sroa.h dominator_analysis.h dominator_tree.h eliminate_dead_constant_pass.h @@ -134,6 +135,7 @@ set(SPIRV_TOOLS_OPT_SOURCES decompose_initialized_variables_pass.cpp decoration_manager.cpp def_use_manager.cpp + desc_sroa.cpp dominator_analysis.cpp dominator_tree.cpp eliminate_dead_constant_pass.cpp diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp new file mode 100644 index 000000000..36256ffaf --- /dev/null +++ b/source/opt/desc_sroa.cpp @@ -0,0 +1,255 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/desc_sroa.h" + +#include + +namespace spvtools { +namespace opt { + +Pass::Status DescriptorScalarReplacement::Process() { + bool modified = false; + + std::vector vars_to_kill; + + for (Instruction& var : context()->types_values()) { + if (IsCandidate(&var)) { + modified = true; + if (!ReplaceCandidate(&var)) { + return Status::Failure; + } + vars_to_kill.push_back(&var); + } + } + + for (Instruction* var : vars_to_kill) { + context()->KillInst(var); + } + + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool DescriptorScalarReplacement::IsCandidate(Instruction* var) { + if (var->opcode() != SpvOpVariable) { + return false; + } + + uint32_t ptr_type_id = var->type_id(); + Instruction* ptr_type_inst = + context()->get_def_use_mgr()->GetDef(ptr_type_id); + if (ptr_type_inst->opcode() != SpvOpTypePointer) { + return false; + } + + uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1); + Instruction* var_type_inst = + context()->get_def_use_mgr()->GetDef(var_type_id); + if (var_type_inst->opcode() != SpvOpTypeArray) { + return false; + } + + bool has_desc_set_decoration = false; + context()->get_decoration_mgr()->ForEachDecoration( + var->result_id(), SpvDecorationDescriptorSet, + [&has_desc_set_decoration](const Instruction&) { + has_desc_set_decoration = true; + }); + if (!has_desc_set_decoration) { + return false; + } + + bool has_binding_decoration = false; + context()->get_decoration_mgr()->ForEachDecoration( + var->result_id(), SpvDecorationBinding, + [&has_binding_decoration](const Instruction&) { + has_binding_decoration = true; + }); + if (!has_binding_decoration) { + return false; + } + + return true; +} + +bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) { + std::vector work_list; + bool failed = !get_def_use_mgr()->WhileEachUser( + var->result_id(), [this, &work_list](Instruction* use) { + if (use->opcode() == SpvOpName) { + return true; + } + + if (use->IsDecoration()) { + return true; + } + + switch (use->opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + work_list.push_back(use); + return true; + default: + context()->EmitErrorMessage( + "Variable cannot be replaced: invalid instruction", use); + return false; + } + return true; + }); + + if (failed) { + return false; + } + + for (Instruction* use : work_list) { + if (!ReplaceAccessChain(var, use)) { + return false; + } + } + return true; +} + +bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var, + Instruction* use) { + if (use->NumInOperands() <= 1) { + context()->EmitErrorMessage( + "Variable cannot be replaced: invalid instruction", use); + return false; + } + + uint32_t idx_id = use->GetSingleWordInOperand(1); + const analysis::Constant* idx_const = + context()->get_constant_mgr()->FindDeclaredConstant(idx_id); + if (idx_const == nullptr) { + context()->EmitErrorMessage("Variable cannot be replaced: invalid index", + use); + return false; + } + + uint32_t idx = idx_const->GetU32(); + uint32_t replacement_var = GetReplacementVariable(var, idx); + + if (use->NumInOperands() == 2) { + // We are not indexing into the replacement variable. We can replaces the + // access chain with the replacement varibale itself. + context()->ReplaceAllUsesWith(use->result_id(), replacement_var); + context()->KillInst(use); + return true; + } + + // We need to build a new access chain with the replacement variable as the + // base address. + Instruction::OperandList new_operands; + + // Same result id and result type. + new_operands.emplace_back(use->GetOperand(0)); + new_operands.emplace_back(use->GetOperand(1)); + + // Use the replacement variable as the base address. + new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var}}); + + // Drop the first index because it is consumed by the replacment, and copy the + // rest. + for (uint32_t i = 4; i < use->NumOperands(); i++) { + new_operands.emplace_back(use->GetOperand(i)); + } + + use->ReplaceOperands(new_operands); + context()->UpdateDefUse(use); + return true; +} + +uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var, + uint32_t idx) { + auto replacement_vars = replacement_variables_.find(var); + if (replacement_vars == replacement_variables_.end()) { + uint32_t ptr_type_id = var->type_id(); + Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id); + assert(ptr_type_inst->opcode() == SpvOpTypePointer && + "Variable should be a pointer to an array."); + uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1); + Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id); + assert(arr_type_inst->opcode() == SpvOpTypeArray && + "Variable should be a pointer to an array."); + + uint32_t array_len_id = arr_type_inst->GetSingleWordInOperand(1); + const analysis::Constant* array_len_const = + context()->get_constant_mgr()->FindDeclaredConstant(array_len_id); + assert(array_len_const != nullptr && "Array length must be a constant."); + uint32_t array_len = array_len_const->GetU32(); + + replacement_vars = replacement_variables_ + .insert({var, std::vector(array_len, 0)}) + .first; + } + + if (replacement_vars->second[idx] == 0) { + replacement_vars->second[idx] = CreateReplacementVariable(var, idx); + } + + return replacement_vars->second[idx]; +} + +uint32_t DescriptorScalarReplacement::CreateReplacementVariable( + Instruction* var, uint32_t idx) { + // The storage class for the new variable is the same as the original. + SpvStorageClass storage_class = + static_cast(var->GetSingleWordInOperand(0)); + + // The type for the new variable will be a pointer to type of the elements of + // the array. + uint32_t ptr_type_id = var->type_id(); + Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id); + assert(ptr_type_inst->opcode() == SpvOpTypePointer && + "Variable should be a pointer to an array."); + uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1); + Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id); + assert(arr_type_inst->opcode() == SpvOpTypeArray && + "Variable should be a pointer to an array."); + uint32_t element_type_id = arr_type_inst->GetSingleWordInOperand(0); + + uint32_t ptr_element_type_id = context()->get_type_mgr()->FindPointerToType( + element_type_id, storage_class); + + // Create the variable. + uint32_t id = TakeNextId(); + std::unique_ptr variable( + new Instruction(context(), SpvOpVariable, ptr_element_type_id, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_STORAGE_CLASS, + {static_cast(storage_class)}}})); + context()->AddGlobalValue(std::move(variable)); + + // Copy all of the decorations to the new variable. The only difference is + // the Binding decoration needs to be adjusted. + for (auto old_decoration : + get_decoration_mgr()->GetDecorationsFor(var->result_id(), true)) { + assert(old_decoration->opcode() == SpvOpDecorate); + std::unique_ptr new_decoration( + old_decoration->Clone(context())); + new_decoration->SetInOperand(0, {id}); + + uint32_t decoration = new_decoration->GetSingleWordInOperand(1u); + if (decoration == SpvDecorationBinding) { + uint32_t new_binding = new_decoration->GetSingleWordInOperand(2) + idx; + new_decoration->SetInOperand(2, {new_binding}); + } + context()->AddAnnotationInst(std::move(new_decoration)); + } + + return id; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h new file mode 100644 index 000000000..a95c6b582 --- /dev/null +++ b/source/opt/desc_sroa.h @@ -0,0 +1,84 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_DESC_SROA_H_ +#define SOURCE_OPT_DESC_SROA_H_ + +#include +#include +#include +#include +#include +#include + +#include "source/opt/function.h" +#include "source/opt/pass.h" +#include "source/opt/type_manager.h" + +namespace spvtools { +namespace opt { + +// Documented in optimizer.hpp +class DescriptorScalarReplacement : public Pass { + public: + DescriptorScalarReplacement() {} + + const char* name() const override { return "descriptor-scalar-replacement"; } + + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Returns true if |var| is an OpVariable instruction that represents a + // descriptor array. These are the variables that we want to replace. + bool IsCandidate(Instruction* var); + + // Replaces all references to |var| by new variables, one for each element of + // the array |var|. The binding for the new variables corresponding to + // element i will be the binding of |var| plus i. Returns true if successful. + bool ReplaceCandidate(Instruction* var); + + // Replaces the base address |var| in the OpAccessChain or + // OpInBoundsAccessChain instruction |use| by the variable that the access + // chain accesses. The first index in |use| must be an |OpConstant|. Returns + // |true| if successful. + bool ReplaceAccessChain(Instruction* var, Instruction* use); + + // Returns the id of the variable that will be used to replace the |idx|th + // element of |var|. The variable is created if it has not already been + // created. + uint32_t GetReplacementVariable(Instruction* var, uint32_t idx); + + // Returns the id of a new variable that can be used to replace the |idx|th + // element of |var|. + uint32_t CreateReplacementVariable(Instruction* var, uint32_t idx); + + // A map from an OpVariable instruction to the set of variables that will be + // used to replace it. The entry |replacement_variables_[var][i]| is the id of + // a variable that will be used in the place of the the ith element of the + // array |var|. If the entry is |0|, then the variable has not been + // created yet. + std::map> replacement_variables_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_DESC_SROA_H_ diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp index b600f12bf..823c2b7bd 100644 --- a/source/opt/ir_context.cpp +++ b/source/opt/ir_context.cpp @@ -788,6 +788,42 @@ bool IRContext::ProcessCallTreeFromRoots(ProcessFunction& pfn, return modified; } +void IRContext::EmitErrorMessage(std::string message, Instruction* inst) { + if (!consumer()) { + return; + } + + Instruction* line_inst = inst; + while (line_inst != nullptr) { // Stop at the beginning of the basic block. + if (!line_inst->dbg_line_insts().empty()) { + line_inst = &line_inst->dbg_line_insts().back(); + if (line_inst->opcode() == SpvOpNoLine) { + line_inst = nullptr; + } + break; + } + line_inst = line_inst->PreviousNode(); + } + + uint32_t line_number = 0; + uint32_t col_number = 0; + char* source = nullptr; + if (line_inst != nullptr) { + Instruction* file_name = + get_def_use_mgr()->GetDef(line_inst->GetSingleWordInOperand(0)); + source = reinterpret_cast(&file_name->GetInOperand(0).words[0]); + + // Get the line number and column number. + line_number = line_inst->GetSingleWordInOperand(1); + col_number = line_inst->GetSingleWordInOperand(2); + } + + message += + "\n " + inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + consumer()(SPV_MSG_ERROR, source, {line_number, col_number, 0}, + message.c_str()); +} + // Gets the dominator analysis for function |f|. DominatorAnalysis* IRContext::GetDominatorAnalysis(const Function* f) { if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) { diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index 308f6337e..05df9c037 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h @@ -556,6 +556,10 @@ class IRContext { bool ProcessCallTreeFromRoots(ProcessFunction& pfn, std::queue* roots); + // Emmits a error message to the message consumer indicating the error + // described by |message| occurred in |inst|. + void EmitErrorMessage(std::string message, Instruction* inst); + private: // Builds the def-use manager from scratch, even if it was already valid. void BuildDefUseManager() { diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 2dd170891..4cc5e972e 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -315,6 +315,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) { RegisterPass(CreateCombineAccessChainsPass()); } else if (pass_name == "convert-local-access-chains") { RegisterPass(CreateLocalAccessChainConvertPass()); + } else if (pass_name == "descriptor-scalar-replacement") { + RegisterPass(CreateDescriptorScalarReplacementPass()); } else if (pass_name == "eliminate-dead-code-aggressive") { RegisterPass(CreateAggressiveDCEPass()); } else if (pass_name == "propagate-line-info") { @@ -886,4 +888,9 @@ Optimizer::PassToken CreateGraphicsRobustAccessPass() { MakeUnique()); } +Optimizer::PassToken CreateDescriptorScalarReplacementPass() { + return MakeUnique( + MakeUnique()); +} + } // namespace spvtools diff --git a/source/opt/passes.h b/source/opt/passes.h index 5eddc220f..86588f7eb 100644 --- a/source/opt/passes.h +++ b/source/opt/passes.h @@ -29,6 +29,7 @@ #include "source/opt/dead_insert_elim_pass.h" #include "source/opt/dead_variable_elimination.h" #include "source/opt/decompose_initialized_variables_pass.h" +#include "source/opt/desc_sroa.h" #include "source/opt/eliminate_dead_constant_pass.h" #include "source/opt/eliminate_dead_functions_pass.h" #include "source/opt/eliminate_dead_members_pass.h" diff --git a/source/util/string_utils.h b/source/util/string_utils.h index f1cd179c9..4282aa949 100644 --- a/source/util/string_utils.h +++ b/source/util/string_utils.h @@ -15,8 +15,10 @@ #ifndef SOURCE_UTIL_STRING_UTILS_H_ #define SOURCE_UTIL_STRING_UTILS_H_ +#include #include #include +#include #include "source/util/string_utils.h" @@ -42,6 +44,48 @@ std::string CardinalToOrdinal(size_t cardinal); // string will be empty. std::pair SplitFlagArgs(const std::string& flag); +// Encodes a string as a sequence of words, using the SPIR-V encoding. +inline std::vector MakeVector(std::string input) { + std::vector result; + uint32_t word = 0; + size_t num_bytes = input.size(); + // SPIR-V strings are null-terminated. The byte_index == num_bytes + // case is used to push the terminating null byte. + for (size_t byte_index = 0; byte_index <= num_bytes; byte_index++) { + const auto new_byte = + (byte_index < num_bytes ? uint8_t(input[byte_index]) : uint8_t(0)); + word |= (new_byte << (8 * (byte_index % sizeof(uint32_t)))); + if (3 == (byte_index % sizeof(uint32_t))) { + result.push_back(word); + word = 0; + } + } + // Emit a trailing partial word. + if ((num_bytes + 1) % sizeof(uint32_t)) { + result.push_back(word); + } + return result; +} + +// Decode a string from a sequence of words, using the SPIR-V encoding. +template +inline std::string MakeString(const VectorType& words) { + std::string result; + + for (uint32_t word : words) { + for (int byte_index = 0; byte_index < 4; byte_index++) { + uint32_t extracted_word = (word >> (8 * byte_index)) & 0xFF; + char c = static_cast(extracted_word); + if (c == 0) { + return result; + } + result += c; + } + } + assert(false && "Did not find terminating null for the string."); + return result; +} // namespace utils + } // namespace utils } // namespace spvtools diff --git a/test/assembly_context_test.cpp b/test/assembly_context_test.cpp index ee0bb24a9..c8aa06be7 100644 --- a/test/assembly_context_test.cpp +++ b/test/assembly_context_test.cpp @@ -17,6 +17,7 @@ #include "gmock/gmock.h" #include "source/instruction.h" +#include "source/util/string_utils.h" #include "test/unit_spirv.h" namespace spvtools { @@ -40,9 +41,8 @@ TEST_P(EncodeStringTest, Sample) { ASSERT_EQ(SPV_SUCCESS, context.binaryEncodeString(GetParam().str.c_str(), &inst)); // We already trust MakeVector - EXPECT_THAT(inst.words, - Eq(Concatenate({GetParam().initial_contents, - spvtest::MakeVector(GetParam().str)}))); + EXPECT_THAT(inst.words, Eq(Concatenate({GetParam().initial_contents, + utils::MakeVector(GetParam().str)}))); } // clang-format off diff --git a/test/binary_parse_test.cpp b/test/binary_parse_test.cpp index b9661023e..54664fce7 100644 --- a/test/binary_parse_test.cpp +++ b/test/binary_parse_test.cpp @@ -21,6 +21,7 @@ #include "gmock/gmock.h" #include "source/latest_version_opencl_std_header.h" #include "source/table.h" +#include "source/util/string_utils.h" #include "test/test_fixture.h" #include "test/unit_spirv.h" @@ -39,7 +40,7 @@ namespace { using ::spvtest::Concatenate; using ::spvtest::MakeInstruction; -using ::spvtest::MakeVector; +using utils::MakeVector; using ::spvtest::ScopedContext; using ::testing::_; using ::testing::AnyOf; diff --git a/test/comment_test.cpp b/test/comment_test.cpp index f46b72ac5..49f8df651 100644 --- a/test/comment_test.cpp +++ b/test/comment_test.cpp @@ -15,6 +15,7 @@ #include #include "gmock/gmock.h" +#include "source/util/string_utils.h" #include "test/test_fixture.h" #include "test/unit_spirv.h" @@ -23,7 +24,7 @@ namespace { using spvtest::Concatenate; using spvtest::MakeInstruction; -using spvtest::MakeVector; +using utils::MakeVector; using spvtest::TextToBinaryTest; using testing::Eq; diff --git a/test/ext_inst.debuginfo_test.cpp b/test/ext_inst.debuginfo_test.cpp index ec012e0ea..9090c2473 100644 --- a/test/ext_inst.debuginfo_test.cpp +++ b/test/ext_inst.debuginfo_test.cpp @@ -17,6 +17,7 @@ #include "DebugInfo.h" #include "gmock/gmock.h" +#include "source/util/string_utils.h" #include "test/test_fixture.h" #include "test/unit_spirv.h" @@ -31,7 +32,7 @@ namespace { using spvtest::Concatenate; using spvtest::MakeInstruction; -using spvtest::MakeVector; +using utils::MakeVector; using testing::Eq; struct InstructionCase { diff --git a/test/ext_inst.opencl_test.cpp b/test/ext_inst.opencl_test.cpp index 7dd903e10..7547d9224 100644 --- a/test/ext_inst.opencl_test.cpp +++ b/test/ext_inst.opencl_test.cpp @@ -17,6 +17,7 @@ #include "gmock/gmock.h" #include "source/latest_version_opencl_std_header.h" +#include "source/util/string_utils.h" #include "test/test_fixture.h" #include "test/unit_spirv.h" @@ -25,7 +26,7 @@ namespace { using spvtest::Concatenate; using spvtest::MakeInstruction; -using spvtest::MakeVector; +using utils::MakeVector; using spvtest::TextToBinaryTest; using testing::Eq; diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index 6131c9b9e..366a61f68 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -34,6 +34,7 @@ add_spvtools_unittest(TARGET opt decompose_initialized_variables_test.cpp decoration_manager_test.cpp def_use_test.cpp + desc_sroa_test.cpp eliminate_dead_const_test.cpp eliminate_dead_functions_test.cpp eliminate_dead_member_test.cpp diff --git a/test/opt/decoration_manager_test.cpp b/test/opt/decoration_manager_test.cpp index 3eb3ef58e..fcfbff060 100644 --- a/test/opt/decoration_manager_test.cpp +++ b/test/opt/decoration_manager_test.cpp @@ -22,6 +22,7 @@ #include "source/opt/decoration_manager.h" #include "source/opt/ir_context.h" #include "source/spirv_constant.h" +#include "source/util/string_utils.h" #include "test/unit_spirv.h" namespace spvtools { @@ -29,7 +30,7 @@ namespace opt { namespace analysis { namespace { -using spvtest::MakeVector; +using utils::MakeVector; class DecorationManagerTest : public ::testing::Test { public: diff --git a/test/opt/desc_sroa_test.cpp b/test/opt/desc_sroa_test.cpp new file mode 100644 index 000000000..04ea0f736 --- /dev/null +++ b/test/opt/desc_sroa_test.cpp @@ -0,0 +1,209 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using DescriptorScalarReplacementTest = PassTest<::testing::Test>; + +TEST_F(DescriptorScalarReplacementTest, ExpandTexture) { + const std::string text = R"( +; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var1]] Binding 0 +; CHECK: OpDecorate [[var2:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var2]] Binding 1 +; CHECK: OpDecorate [[var3:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var3]] Binding 2 +; CHECK: OpDecorate [[var4:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var4]] Binding 3 +; CHECK: OpDecorate [[var5:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var5]] Binding 4 +; CHECK: [[image_type:%\w+]] = OpTypeImage +; CHECK: [[ptr_type:%\w+]] = OpTypePointer UniformConstant [[image_type]] +; CHECK: [[var1]] = OpVariable [[ptr_type]] UniformConstant +; CHECK: [[var2]] = OpVariable [[ptr_type]] UniformConstant +; CHECK: [[var3]] = OpVariable [[ptr_type]] UniformConstant +; CHECK: [[var4]] = OpVariable [[ptr_type]] UniformConstant +; CHECK: [[var5]] = OpVariable [[ptr_type]] UniformConstant +; CHECK: OpLoad [[image_type]] [[var1]] +; CHECK: OpLoad [[image_type]] [[var2]] +; CHECK: OpLoad [[image_type]] [[var3]] +; CHECK: OpLoad [[image_type]] [[var4]] +; CHECK: OpLoad [[image_type]] [[var5]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource HLSL 600 + OpDecorate %MyTextures DescriptorSet 0 + OpDecorate %MyTextures Binding 0 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 + %int_3 = OpConstant %int 3 + %int_4 = OpConstant %int 4 + %uint = OpTypeInt 32 0 + %uint_5 = OpConstant %uint 5 + %float = OpTypeFloat 32 +%type_2d_image = OpTypeImage %float 2D 2 0 0 1 Unknown +%_arr_type_2d_image_uint_5 = OpTypeArray %type_2d_image %uint_5 +%_ptr_UniformConstant__arr_type_2d_image_uint_5 = OpTypePointer UniformConstant %_arr_type_2d_image_uint_5 + %v2float = OpTypeVector %float 2 + %void = OpTypeVoid + %26 = OpTypeFunction %void +%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image + %MyTextures = OpVariable %_ptr_UniformConstant__arr_type_2d_image_uint_5 UniformConstant + %main = OpFunction %void None %26 + %28 = OpLabel + %29 = OpUndef %v2float + %30 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_0 + %31 = OpLoad %type_2d_image %30 + %35 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_1 + %36 = OpLoad %type_2d_image %35 + %40 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_2 + %41 = OpLoad %type_2d_image %40 + %45 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_3 + %46 = OpLoad %type_2d_image %45 + %50 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_4 + %51 = OpLoad %type_2d_image %50 + OpReturn + OpFunctionEnd + + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DescriptorScalarReplacementTest, ExpandSampler) { + const std::string text = R"( +; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var1]] Binding 1 +; CHECK: OpDecorate [[var2:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var2]] Binding 2 +; CHECK: OpDecorate [[var3:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var3]] Binding 3 +; CHECK: [[sampler_type:%\w+]] = OpTypeSampler +; CHECK: [[ptr_type:%\w+]] = OpTypePointer UniformConstant [[sampler_type]] +; CHECK: [[var1]] = OpVariable [[ptr_type]] UniformConstant +; CHECK: [[var2]] = OpVariable [[ptr_type]] UniformConstant +; CHECK: [[var3]] = OpVariable [[ptr_type]] UniformConstant +; CHECK: OpLoad [[sampler_type]] [[var1]] +; CHECK: OpLoad [[sampler_type]] [[var2]] +; CHECK: OpLoad [[sampler_type]] [[var3]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource HLSL 600 + OpDecorate %MySampler DescriptorSet 0 + OpDecorate %MySampler Binding 1 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 + %uint = OpTypeInt 32 0 + %uint_3 = OpConstant %uint 3 +%type_sampler = OpTypeSampler +%_arr_type_sampler_uint_3 = OpTypeArray %type_sampler %uint_3 +%_ptr_UniformConstant__arr_type_sampler_uint_3 = OpTypePointer UniformConstant %_arr_type_sampler_uint_3 + %void = OpTypeVoid + %26 = OpTypeFunction %void +%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler + %MySampler = OpVariable %_ptr_UniformConstant__arr_type_sampler_uint_3 UniformConstant + %main = OpFunction %void None %26 + %28 = OpLabel + %31 = OpAccessChain %_ptr_UniformConstant_type_sampler %MySampler %int_0 + %32 = OpLoad %type_sampler %31 + %35 = OpAccessChain %_ptr_UniformConstant_type_sampler %MySampler %int_1 + %36 = OpLoad %type_sampler %35 + %40 = OpAccessChain %_ptr_UniformConstant_type_sampler %MySampler %int_2 + %41 = OpLoad %type_sampler %40 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DescriptorScalarReplacementTest, ExpandSSBO) { + // Tests the expansion of an SSBO. Also check that an access chain with more + // than 1 index is correctly handled. + const std::string text = R"( +; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var1]] Binding 0 +; CHECK: OpDecorate [[var2:%\w+]] DescriptorSet 0 +; CHECK: OpDecorate [[var2]] Binding 1 +; CHECK: OpTypeStruct +; CHECK: [[struct_type:%\w+]] = OpTypeStruct +; CHECK: [[ptr_type:%\w+]] = OpTypePointer Uniform [[struct_type]] +; CHECK: [[var1]] = OpVariable [[ptr_type]] Uniform +; CHECK: [[var2]] = OpVariable [[ptr_type]] Uniform +; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[var1]] %uint_0 %uint_0 %uint_0 +; CHECK: OpLoad %v4float [[ac1]] +; CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[var2]] %uint_0 %uint_0 %uint_0 +; CHECK: OpLoad %v4float [[ac2]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource HLSL 600 + OpDecorate %buffers DescriptorSet 0 + OpDecorate %buffers Binding 0 + OpMemberDecorate %S 0 Offset 0 + OpDecorate %_runtimearr_S ArrayStride 16 + OpMemberDecorate %type_StructuredBuffer_S 0 Offset 0 + OpMemberDecorate %type_StructuredBuffer_S 0 NonWritable + OpDecorate %type_StructuredBuffer_S BufferBlock + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %S = OpTypeStruct %v4float +%_runtimearr_S = OpTypeRuntimeArray %S +%type_StructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_arr_type_StructuredBuffer_S_uint_2 = OpTypeArray %type_StructuredBuffer_S %uint_2 +%_ptr_Uniform__arr_type_StructuredBuffer_S_uint_2 = OpTypePointer Uniform %_arr_type_StructuredBuffer_S_uint_2 +%_ptr_Uniform_type_StructuredBuffer_S = OpTypePointer Uniform %type_StructuredBuffer_S + %void = OpTypeVoid + %19 = OpTypeFunction %void +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float + %buffers = OpVariable %_ptr_Uniform__arr_type_StructuredBuffer_S_uint_2 Uniform + %main = OpFunction %void None %19 + %21 = OpLabel + %22 = OpAccessChain %_ptr_Uniform_v4float %buffers %uint_0 %uint_0 %uint_0 %uint_0 + %23 = OpLoad %v4float %22 + %24 = OpAccessChain %_ptr_Uniform_type_StructuredBuffer_S %buffers %uint_1 + %25 = OpAccessChain %_ptr_Uniform_v4float %24 %uint_0 %uint_0 %uint_0 + %26 = OpLoad %v4float %25 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/text_to_binary.annotation_test.cpp b/test/text_to_binary.annotation_test.cpp index 69a48610d..61bdf64c8 100644 --- a/test/text_to_binary.annotation_test.cpp +++ b/test/text_to_binary.annotation_test.cpp @@ -21,6 +21,7 @@ #include #include "gmock/gmock.h" +#include "source/util/string_utils.h" #include "test/test_fixture.h" #include "test/unit_spirv.h" @@ -29,7 +30,7 @@ namespace { using spvtest::EnumCase; using spvtest::MakeInstruction; -using spvtest::MakeVector; +using utils::MakeVector; using spvtest::TextToBinaryTest; using ::testing::Combine; using ::testing::Eq; diff --git a/test/text_to_binary.debug_test.cpp b/test/text_to_binary.debug_test.cpp index f9a464586..39ba5c524 100644 --- a/test/text_to_binary.debug_test.cpp +++ b/test/text_to_binary.debug_test.cpp @@ -19,6 +19,7 @@ #include #include "gmock/gmock.h" +#include "source/util/string_utils.h" #include "test/test_fixture.h" #include "test/unit_spirv.h" @@ -26,7 +27,7 @@ namespace spvtools { namespace { using spvtest::MakeInstruction; -using spvtest::MakeVector; +using utils::MakeVector; using spvtest::TextToBinaryTest; using ::testing::Eq; diff --git a/test/text_to_binary.extension_test.cpp b/test/text_to_binary.extension_test.cpp index 84552b534..9408e9ac2 100644 --- a/test/text_to_binary.extension_test.cpp +++ b/test/text_to_binary.extension_test.cpp @@ -22,6 +22,7 @@ #include "gmock/gmock.h" #include "source/latest_version_glsl_std_450_header.h" #include "source/latest_version_opencl_std_header.h" +#include "source/util/string_utils.h" #include "test/test_fixture.h" #include "test/unit_spirv.h" @@ -30,7 +31,7 @@ namespace { using spvtest::Concatenate; using spvtest::MakeInstruction; -using spvtest::MakeVector; +using utils::MakeVector; using spvtest::TextToBinaryTest; using ::testing::Combine; using ::testing::Eq; diff --git a/test/text_to_binary.mode_setting_test.cpp b/test/text_to_binary.mode_setting_test.cpp index d1b69dd5e..8ddf42196 100644 --- a/test/text_to_binary.mode_setting_test.cpp +++ b/test/text_to_binary.mode_setting_test.cpp @@ -20,6 +20,7 @@ #include #include "gmock/gmock.h" +#include "source/util/string_utils.h" #include "test/test_fixture.h" #include "test/unit_spirv.h" @@ -28,7 +29,7 @@ namespace { using spvtest::EnumCase; using spvtest::MakeInstruction; -using spvtest::MakeVector; +using utils::MakeVector; using ::testing::Combine; using ::testing::Eq; using ::testing::TestWithParam; diff --git a/test/unit_spirv.cpp b/test/unit_spirv.cpp index 84ed87a51..085443948 100644 --- a/test/unit_spirv.cpp +++ b/test/unit_spirv.cpp @@ -15,12 +15,13 @@ #include "test/unit_spirv.h" #include "gmock/gmock.h" +#include "source/util/string_utils.h" #include "test/test_fixture.h" namespace spvtools { namespace { -using spvtest::MakeVector; +using utils::MakeVector; using ::testing::Eq; using Words = std::vector; diff --git a/test/unit_spirv.h b/test/unit_spirv.h index 224428884..32646620d 100644 --- a/test/unit_spirv.h +++ b/test/unit_spirv.h @@ -133,29 +133,6 @@ inline std::vector Concatenate( return result; } -// Encodes a string as a sequence of words, using the SPIR-V encoding. -inline std::vector MakeVector(std::string input) { - std::vector result; - uint32_t word = 0; - size_t num_bytes = input.size(); - // SPIR-V strings are null-terminated. The byte_index == num_bytes - // case is used to push the terminating null byte. - for (size_t byte_index = 0; byte_index <= num_bytes; byte_index++) { - const auto new_byte = - (byte_index < num_bytes ? uint8_t(input[byte_index]) : uint8_t(0)); - word |= (new_byte << (8 * (byte_index % sizeof(uint32_t)))); - if (3 == (byte_index % sizeof(uint32_t))) { - result.push_back(word); - word = 0; - } - } - // Emit a trailing partial word. - if ((num_bytes + 1) % sizeof(uint32_t)) { - result.push_back(word); - } - return result; -} - // A type for easily creating spv_text_t values, with an implicit conversion to // spv_text. struct AutoText { diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp index 8dcf9337a..c18b64c50 100644 --- a/tools/opt/opt.cpp +++ b/tools/opt/opt.cpp @@ -147,6 +147,15 @@ Options (in lexicographical order):)", around known issues with some Vulkan drivers for initialize variables.)"); printf(R"( + --descriptor-scalar-replacement + Replaces every array variable |desc| that has a DescriptorSet + and Binding decorations with a new variable for each element of + the array. Suppose |desc| was bound at binding |b|. Then the + variable corresponding to |desc[i]| will have binding |b+i|. + The descriptor set will be the same. All accesses to |desc| + must be in OpAccessChain instructions with a literal index for + the first index.)"); + printf(R"( --eliminate-dead-branches Convert conditional branches with constant condition to the indicated unconditional brranch. Delete all resulting dead