diff --git a/Android.mk b/Android.mk index 7c6a076d2..a43142236 100644 --- a/Android.mk +++ b/Android.mk @@ -156,6 +156,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/scalar_replacement_pass.cpp \ source/opt/set_spec_constant_default_value_pass.cpp \ source/opt/simplification_pass.cpp \ + source/opt/split_invalid_unreachable_pass.cpp \ source/opt/ssa_rewrite_pass.cpp \ source/opt/strength_reduction_pass.cpp \ source/opt/strip_atomic_counter_memory_pass.cpp \ diff --git a/BUILD.gn b/BUILD.gn index 6c28118f6..5c6f311d2 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -616,6 +616,8 @@ static_library("spvtools_opt") { "source/opt/set_spec_constant_default_value_pass.h", "source/opt/simplification_pass.cpp", "source/opt/simplification_pass.h", + "source/opt/split_invalid_unreachable_pass.cpp", + "source/opt/split_invalid_unreachable_pass.h", "source/opt/ssa_rewrite_pass.cpp", "source/opt/ssa_rewrite_pass.h", "source/opt/strength_reduction_pass.cpp", diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index af9f3e54a..313c9924d 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -772,6 +772,10 @@ Optimizer::PassToken CreateLegalizeVectorShufflePass(); // declaration and an initial store. Optimizer::PassToken CreateDecomposeInitializedVariablesPass(); +// Create a pass to attempt to split up invalid unreachable merge-blocks and +// continue-targets to legalize for WebGPU. 
+Optimizer::PassToken CreateSplitInvalidUnreachablePass();
+
 }  // namespace spvtools
 
 #endif  // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index b02485a71..4cd800106 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -97,6 +97,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
   scalar_replacement_pass.h
   set_spec_constant_default_value_pass.h
   simplification_pass.h
+  split_invalid_unreachable_pass.h
   ssa_rewrite_pass.h
   strength_reduction_pass.h
   strip_atomic_counter_memory_pass.h
@@ -195,6 +196,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
   scalar_replacement_pass.cpp
   set_spec_constant_default_value_pass.cpp
   simplification_pass.cpp
+  split_invalid_unreachable_pass.cpp
   ssa_rewrite_pass.cpp
   strength_reduction_pass.cpp
   strip_atomic_counter_memory_pass.cpp
diff --git a/source/opt/function.cpp b/source/opt/function.cpp
index 9bd46e2a0..252005252 100644
--- a/source/opt/function.cpp
+++ b/source/opt/function.cpp
@@ -147,6 +147,20 @@ BasicBlock* Function::InsertBasicBlockAfter(
   return nullptr;
 }
 
+BasicBlock* Function::InsertBasicBlockBefore(
+    std::unique_ptr<BasicBlock>&& new_block, BasicBlock* position) {
+  for (auto bb_iter = begin(); bb_iter != end(); ++bb_iter) {
+    if (&*bb_iter == position) {
+      // Set the parent before the move below leaves |new_block| empty.
+      new_block->SetParent(this);
+      bb_iter = bb_iter.InsertBefore(std::move(new_block));
+      return &*bb_iter;
+    }
+  }
+  assert(false && "Could not find insertion point.");
+  return nullptr;
+}
+
 bool Function::IsRecursive() const {
   IRContext* ctx = blocks_.front()->GetLabel()->context();
   IRContext::ProcessFunction mark_visited = [this](Function* fp) {
diff --git a/source/opt/function.h b/source/opt/function.h
index c80b078cd..b1317ad3e 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -125,6 +125,12 @@ class Function {
   BasicBlock* InsertBasicBlockAfter(std::unique_ptr<BasicBlock>&& new_block,
                                     BasicBlock* position);
 
+  // Inserts |new_block| immediately before |position| and returns a pointer
+  // to the inserted block. Asserts, and returns nullptr when asserts are
+  // compiled out, if |position| is not a block of this function.
+  BasicBlock* InsertBasicBlockBefore(std::unique_ptr<BasicBlock>&& new_block,
+                                     BasicBlock* position);
+
   // Return true if the function calls
itself either directly or indirectly.
   bool IsRecursive() const;
 
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index 34ddedca7..d6f9cef94 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -225,6 +225,7 @@ Optimizer& Optimizer::RegisterVulkanToWebGPUPasses() {
       .RegisterPass(CreateStripAtomicCounterMemoryPass())
       .RegisterPass(CreateGenerateWebGPUInitializersPass())
       .RegisterPass(CreateLegalizeVectorShufflePass())
+      .RegisterPass(CreateSplitInvalidUnreachablePass())
       .RegisterPass(CreateEliminateDeadConstantPass())
       .RegisterPass(CreateFlattenDecorationPass())
       .RegisterPass(CreateAggressiveDCEPass())
@@ -873,4 +874,9 @@ Optimizer::PassToken CreateDecomposeInitializedVariablesPass() {
       MakeUnique<opt::DecomposeInitializedVariablesPass>());
 }
 
+Optimizer::PassToken CreateSplitInvalidUnreachablePass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::SplitInvalidUnreachablePass>());
+}
+
 }  // namespace spvtools
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 06464b060..b7a9cb070 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -65,6 +65,7 @@
 #include "source/opt/scalar_replacement_pass.h"
 #include "source/opt/set_spec_constant_default_value_pass.h"
 #include "source/opt/simplification_pass.h"
+#include "source/opt/split_invalid_unreachable_pass.h"
 #include "source/opt/ssa_rewrite_pass.h"
 #include "source/opt/strength_reduction_pass.h"
 #include "source/opt/strip_atomic_counter_memory_pass.h"
diff --git a/source/opt/split_invalid_unreachable_pass.cpp b/source/opt/split_invalid_unreachable_pass.cpp
new file mode 100644
index 000000000..31cfbc330
--- /dev/null
+++ b/source/opt/split_invalid_unreachable_pass.cpp
@@ -0,0 +1,115 @@
+// Copyright (c) 2019 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/split_invalid_unreachable_pass.h"
+
+#include <memory>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "source/opt/ir_builder.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace opt {
+
+// For each entry point function, finds the unreachable blocks that are used
+// both as a merge block and as a continue target and gives every merge
+// instruction that names such a block a new OpUnreachable block of its own.
+Pass::Status SplitInvalidUnreachablePass::Process() {
+  bool changed = false;
+
+  // Ids of the functions named by the module's entry points.
+  std::unordered_set<uint32_t> entry_points;
+  for (const auto& entry_point : context()->module()->entry_points()) {
+    entry_points.insert(entry_point.GetSingleWordOperand(1));
+  }
+
+  for (auto func = context()->module()->begin();
+       func != context()->module()->end(); ++func) {
+    if (entry_points.find(func->result_id()) == entry_points.end()) continue;
+
+    // Assume every block is unreachable until the traversal below visits it,
+    // and record every id named as a continue target or as a merge block.
+    std::unordered_set<uint32_t> continue_targets;
+    std::unordered_set<uint32_t> merge_blocks;
+    std::unordered_set<BasicBlock*> unreachable_blocks;
+    for (auto block = func->begin(); block != func->end(); ++block) {
+      unreachable_blocks.insert(&*block);
+      uint32_t continue_target = block->ContinueBlockIdIfAny();
+      if (continue_target != 0) continue_targets.insert(continue_target);
+      uint32_t merge_block = block->MergeBlockIdIfAny();
+      if (merge_block != 0) merge_blocks.insert(merge_block);
+    }
+
+    cfg()->ForEachBlockInPostOrder(
+        func->entry().get(), [&unreachable_blocks](BasicBlock* inner_block) {
+          unreachable_blocks.erase(inner_block);
+        });
+
+    // Gather the blocks to split in function order, before any block is
+    // inserted, so the ids handed out below do not depend on the iteration
+    // order of a hash set keyed by pointer.
+    std::vector<BasicBlock*> invalid_blocks;
+    for (auto block = func->begin(); block != func->end(); ++block) {
+      const uint32_t block_id = block->id();
+      if (unreachable_blocks.count(&*block) != 0 &&
+          continue_targets.count(block_id) != 0 &&
+          merge_blocks.count(block_id) != 0) {
+        invalid_blocks.push_back(&*block);
+      }
+    }
+
+    for (BasicBlock* unreachable : invalid_blocks) {
+      // Collect the merge uses up front; each one is rewritten below.
+      std::vector<Instruction*> merge_uses;
+      context()->get_def_use_mgr()->ForEachUse(
+          unreachable->GetLabelInst(),
+          [&merge_uses](Instruction* use, uint32_t idx) {
+            if ((use->opcode() == SpvOpLoopMerge && idx == 0) ||
+                use->opcode() == SpvOpSelectionMerge) {
+              merge_uses.push_back(use);
+            }
+          });
+
+      for (Instruction* use : merge_uses) {
+        // New block containing only OpUnreachable, placed just before the
+        // shared block; |use| is retargeted to it.
+        uint32_t new_id = context()->TakeNextId();
+        std::unique_ptr<Instruction> new_label(
+            new Instruction(context(), SpvOpLabel, 0, new_id, {}));
+        get_def_use_mgr()->AnalyzeInstDefUse(new_label.get());
+        std::unique_ptr<BasicBlock> new_block(
+            new BasicBlock(std::move(new_label)));
+        auto* block_ptr = new_block.get();
+        InstructionBuilder builder(context(), new_block.get(),
+                                   IRContext::kAnalysisDefUse |
+                                       IRContext::kAnalysisInstrToBlockMapping);
+        builder.AddUnreachable();
+        cfg()->RegisterBlock(block_ptr);
+        func->InsertBasicBlockBefore(std::move(new_block), unreachable);
+        use->SetInOperand(0, {new_id});
+        get_def_use_mgr()->UpdateDefUse(use);
+        cfg()->AddEdges(block_ptr);
+        changed = true;
+      }
+    }
+  }
+
+  return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/split_invalid_unreachable_pass.h b/source/opt/split_invalid_unreachable_pass.h
new file mode 100644
index 000000000..a5613448e
--- /dev/null
+++ b/source/opt/split_invalid_unreachable_pass.h
@@ -0,0 +1,51 @@
+// Copyright (c) 2019 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_SPLIT_INVALID_UNREACHABLE_PASS_H_
+#define SOURCE_OPT_SPLIT_INVALID_UNREACHABLE_PASS_H_
+
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// Attempts to legalize for WebGPU by splitting up invalid unreachable blocks.
+// Specifically, looking for cases of unreachable merge-blocks and
+// continue-targets that are used more than once, which is illegal in WebGPU.
+class SplitInvalidUnreachablePass : public Pass {
+ public:
+  // Name reported for this pass.
+  const char* name() const override { return "split-invalid-unreachable"; }
+  Status Process() override;
+
+  // Process() adds blocks and retargets merge instructions without updating
+  // the dominator, loop, scalar evolution, register pressure or structured
+  // CFG analyses, so those are not reported as preserved.
+  IRContext::Analysis GetPreservedAnalyses() override {
+    return IRContext::kAnalysisInstrToBlockMapping |
+           IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
+           IRContext::kAnalysisCFG | IRContext::kAnalysisNameMap |
+           IRContext::kAnalysisValueNumberTable |
+           IRContext::kAnalysisBuiltinVarId |
+           IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisTypes |
+           IRContext::kAnalysisDefUse | IRContext::kAnalysisConstants;
+  }
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_SPLIT_INVALID_UNREACHABLE_PASS_H_
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index e10fe6b1c..ba29cf8b6 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -81,6 +81,7 @@ add_spvtools_unittest(TARGET opt
        scalar_replacement_test.cpp
        set_spec_const_default_value_test.cpp
        simplification_test.cpp
+       split_invalid_unreachable_test.cpp
        strength_reduction_test.cpp
        strip_atomic_counter_memory_test.cpp
        strip_debug_info_test.cpp
diff --git a/test/opt/optimizer_test.cpp b/test/opt/optimizer_test.cpp
index 549335d88..f13164b37 100644
--- a/test/opt/optimizer_test.cpp
+++ b/test/opt/optimizer_test.cpp
@@ -238,7 +238,8 @@ TEST(Optimizer, VulkanToWebGPUSetsCorrectPasses) { "strip-debug", "strip-atomic-counter-memory", "generate-webgpu-initializers", - "legalize-vector-shuffle"}; + "legalize-vector-shuffle", + "split-invalid-unreachable"}; std::sort(registered_passes.begin(), registered_passes.end()); std::sort(expected_passes.begin(), expected_passes.end()); @@ -524,7 +525,74 @@ INSTANTIATE_TEST_SUITE_P( "OpReturn\n" "OpFunctionEnd\n", // pass - "legalize-vector-shuffle"}})); + "legalize-vector-shuffle"}, + // Split Invalid Unreachable + {// input + "OpCapability Shader\n" + "OpCapability VulkanMemoryModelKHR\n" + "OpExtension \"SPV_KHR_vulkan_memory_model\"\n" + "OpMemoryModel Logical VulkanKHR\n" + "OpEntryPoint Vertex %1 \"shader\"\n" + "%uint = OpTypeInt 32 0\n" + "%uint_1 = OpConstant %uint 1\n" + "%uint_2 = OpConstant %uint 2\n" + "%void = OpTypeVoid\n" + "%bool = OpTypeBool\n" + "%7 = OpTypeFunction %void\n" + "%1 = OpFunction %void None %7\n" + "%8 = OpLabel\n" + "OpBranch %9\n" + "%9 = OpLabel\n" + "OpLoopMerge %10 %11 None\n" + "OpBranch %12\n" + "%12 = OpLabel\n" + "%13 = OpSLessThan %bool %uint_1 %uint_2\n" + "OpSelectionMerge %11 None\n" + "OpBranchConditional %13 %14 %15\n" + "%14 = OpLabel\n" + "OpReturn\n" + "%15 = OpLabel\n" + "OpReturn\n" + "%10 = OpLabel\n" + "OpUnreachable\n" + "%11 = OpLabel\n" + "OpBranch %9\n" + "OpFunctionEnd\n", + // expected + "OpCapability Shader\n" + "OpCapability VulkanMemoryModelKHR\n" + "OpExtension \"SPV_KHR_vulkan_memory_model\"\n" + "OpMemoryModel Logical VulkanKHR\n" + "OpEntryPoint Vertex %1 \"shader\"\n" + "%uint = OpTypeInt 32 0\n" + "%uint_1 = OpConstant %uint 1\n" + "%uint_2 = OpConstant %uint 2\n" + "%void = OpTypeVoid\n" + "%bool = OpTypeBool\n" + "%7 = OpTypeFunction %void\n" + "%1 = OpFunction %void None %7\n" + "%8 = OpLabel\n" + "OpBranch %9\n" + "%9 = OpLabel\n" + "OpLoopMerge %10 %11 None\n" + "OpBranch %12\n" + "%12 = OpLabel\n" + "%13 = OpSLessThan %bool %uint_1 %uint_2\n" + "OpSelectionMerge %16 
None\n" + "OpBranchConditional %13 %14 %15\n" + "%14 = OpLabel\n" + "OpReturn\n" + "%15 = OpLabel\n" + "OpReturn\n" + "%10 = OpLabel\n" + "OpUnreachable\n" + "%16 = OpLabel\n" + "OpUnreachable\n" + "%11 = OpLabel\n" + "OpBranch %9\n" + "OpFunctionEnd\n", + // pass + "split-invalid-unreachable"}})); TEST(Optimizer, WebGPUToVulkanSetsCorrectPasses) { Optimizer opt(SPV_ENV_VULKAN_1_1); diff --git a/test/opt/split_invalid_unreachable_test.cpp b/test/opt/split_invalid_unreachable_test.cpp new file mode 100644 index 000000000..868c7b51e --- /dev/null +++ b/test/opt/split_invalid_unreachable_test.cpp @@ -0,0 +1,155 @@ +// Copyright (c) 2019 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <string>
+
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using SplitInvalidUnreachableTest = PassTest<::testing::Test>;
+
+// Module preamble shared by all tests: one Vertex entry point, function %1.
+const std::string spirv_header = R"(OpCapability Shader
+OpCapability VulkanMemoryModelKHR
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical VulkanKHR
+OpEntryPoint Vertex %1 "shader"
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%uint_2 = OpConstant %uint 2
+%void = OpTypeVoid
+%bool = OpTypeBool
+%7 = OpTypeFunction %void
+)";
+
+// Opens function %1; its entry block %8 branches to %9.
+const std::string function_head = R"(%1 = OpFunction %void None %7
+%8 = OpLabel
+OpBranch %9
+)";
+
+const std::string function_tail = "OpFunctionEnd\n";
+
+// Returns a loop header |block_id| that names |merge_id| and |continue_id| in
+// its OpLoopMerge and branches to |body_id|.
+std::string GetLoopMergeBlock(std::string block_id, std::string merge_id,
+                              std::string continue_id, std::string body_id) {
+  std::string result = block_id + " = OpLabel\n";
+  result += "OpLoopMerge " + merge_id + " " + continue_id + " None\n";
+  result += "OpBranch " + body_id + "\n";
+  return result;
+}
+
+// Returns a block |block_id| that computes |condition_id|, names |merge_id| in
+// its OpSelectionMerge and branches on the condition to |true_id|/|false_id|.
+std::string GetSelectionMergeBlock(std::string block_id,
+                                   std::string condition_id,
+                                   std::string merge_id, std::string true_id,
+                                   std::string false_id) {
+  std::string result = block_id + " = OpLabel\n";
+  result += condition_id + " = OpSLessThan %bool %uint_1 %uint_2\n";
+  result += "OpSelectionMerge " + merge_id + " None\n";
+  result += "OpBranchConditional " + condition_id + " " + true_id + " " +
+            false_id + "\n";
+  return result;
+}
+
+// Returns a block |block_id| that ends in OpReturn.
+std::string GetReturnBlock(std::string block_id) {
+  return block_id + " = OpLabel\n" + "OpReturn\n";
+}
+
+// Returns a block |block_id| that ends in OpUnreachable.
+std::string GetUnreachableBlock(std::string block_id) {
+  return block_id + " = OpLabel\n" + "OpUnreachable\n";
+}
+
+// Returns a block |block_id| that branches unconditionally to |target_id|.
+std::string GetBranchBlock(std::string block_id, std::string target_id) {
+  return block_id + " = OpLabel\n" + "OpBranch " + target_id + "\n";
+}
+
+// No block is both a merge block and a continue target: nothing changes.
+TEST_F(SplitInvalidUnreachableTest, NoInvalidBlocks) {
+  std::string input = spirv_header + function_head;
+  input += GetLoopMergeBlock("%9", "%10", "%11", "%12");
+  input += GetSelectionMergeBlock("%12", "%13", "%14", "%15", "%16");
+  input += GetReturnBlock("%15");
+  input += GetReturnBlock("%16");
+  input += GetUnreachableBlock("%10");
+  input += GetBranchBlock("%11", "%9");
+  input += GetUnreachableBlock("%14");
+  input += function_tail;
+
+  SinglePassRunAndCheck<SplitInvalidUnreachablePass>(input, input,
+                                                     /* skip_nop = */ false);
+}
+
+// %11 is the loop's continue target and also the selection's merge block.
+TEST_F(SplitInvalidUnreachableTest, SelectionInLoop) {
+  std::string input = spirv_header + function_head;
+  input += GetLoopMergeBlock("%9", "%10", "%11", "%12");
+  input += GetSelectionMergeBlock("%12", "%13", "%11", "%15", "%16");
+  input += GetReturnBlock("%15");
+  input += GetReturnBlock("%16");
+  input += GetUnreachableBlock("%10");
+  input += GetBranchBlock("%11", "%9");
+  input += function_tail;
+
+  std::string expected = spirv_header + function_head;
+  expected += GetLoopMergeBlock("%9", "%10", "%11", "%12");
+  expected += GetSelectionMergeBlock("%12", "%13", "%16", "%14", "%15");
+  expected += GetReturnBlock("%14");
+  expected += GetReturnBlock("%15");
+  expected += GetUnreachableBlock("%10");
+  expected += GetUnreachableBlock("%16");
+  expected += GetBranchBlock("%11", "%9");
+  expected += function_tail;
+
+  SinglePassRunAndCheck<SplitInvalidUnreachablePass>(input, expected,
+                                                     /* skip_nop = */ false);
+}
+
+// %11 is the selection's merge block and also the loop's continue target.
+TEST_F(SplitInvalidUnreachableTest, LoopInSelection) {
+  std::string input = spirv_header + function_head;
+  input += GetSelectionMergeBlock("%9", "%10", "%11", "%12", "%13");
+  input += GetLoopMergeBlock("%12", "%14", "%11", "%15");
+  input += GetReturnBlock("%13");
+  input += GetUnreachableBlock("%14");
+  input += GetBranchBlock("%11", "%12");
+  input += GetReturnBlock("%15");
+  input += function_tail;
+
+  std::string expected = spirv_header + function_head;
+  expected += GetSelectionMergeBlock("%9", "%10", "%16", "%12", "%13");
+  expected += GetLoopMergeBlock("%12", "%14", "%11", "%15");
+  expected += GetReturnBlock("%13");
+  expected += GetUnreachableBlock("%14");
+  expected += GetUnreachableBlock("%16");
+  expected += GetBranchBlock("%11", "%12");
+  expected += GetReturnBlock("%15");
+  expected += function_tail;
+
+  SinglePassRunAndCheck<SplitInvalidUnreachablePass>(input, expected,
+                                                     /* skip_nop = */ false);
+}
+
+}  // namespace
+}  // namespace opt
+}  // namespace spvtools