Add pass to fix some invalid unreachable blocks for WebGPU (#2563)

Attempts to split up unreachable blocks that are used both as a
merge-block and a continue-target.

Fixes #2429
This commit is contained in:
Ryan Harrison 2019-05-09 12:56:10 -04:00 committed by GitHub
parent 89fe836fe2
commit f6d9a17843
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 404 additions and 2 deletions

View File

@ -156,6 +156,7 @@ SPVTOOLS_OPT_SRC_FILES := \
source/opt/scalar_replacement_pass.cpp \
source/opt/set_spec_constant_default_value_pass.cpp \
source/opt/simplification_pass.cpp \
source/opt/split_invalid_unreachable_pass.cpp \
source/opt/ssa_rewrite_pass.cpp \
source/opt/strength_reduction_pass.cpp \
source/opt/strip_atomic_counter_memory_pass.cpp \

View File

@ -616,6 +616,8 @@ static_library("spvtools_opt") {
"source/opt/set_spec_constant_default_value_pass.h",
"source/opt/simplification_pass.cpp",
"source/opt/simplification_pass.h",
"source/opt/split_invalid_unreachable_pass.cpp",
"source/opt/split_invalid_unreachable_pass.h",
"source/opt/ssa_rewrite_pass.cpp",
"source/opt/ssa_rewrite_pass.h",
"source/opt/strength_reduction_pass.cpp",

View File

@ -772,6 +772,10 @@ Optimizer::PassToken CreateLegalizeVectorShufflePass();
// declaration and an initial store.
Optimizer::PassToken CreateDecomposeInitializedVariablesPass();
// Create a pass to attempt to split up invalid unreachable merge-blocks and
// continue-targets to legalize for WebGPU.
Optimizer::PassToken CreateSplitInvalidUnreachablePass();
} // namespace spvtools
#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_

View File

@ -97,6 +97,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
scalar_replacement_pass.h
set_spec_constant_default_value_pass.h
simplification_pass.h
split_invalid_unreachable_pass.h
ssa_rewrite_pass.h
strength_reduction_pass.h
strip_atomic_counter_memory_pass.h
@ -195,6 +196,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
scalar_replacement_pass.cpp
set_spec_constant_default_value_pass.cpp
simplification_pass.cpp
split_invalid_unreachable_pass.cpp
ssa_rewrite_pass.cpp
strength_reduction_pass.cpp
strip_atomic_counter_memory_pass.cpp

View File

@ -147,6 +147,19 @@ BasicBlock* Function::InsertBasicBlockAfter(
return nullptr;
}
BasicBlock* Function::InsertBasicBlockBefore(
std::unique_ptr<BasicBlock>&& new_block, BasicBlock* position) {
for (auto bb_iter = begin(); bb_iter != end(); ++bb_iter) {
if (&*bb_iter == position) {
new_block->SetParent(this);
bb_iter = bb_iter.InsertBefore(std::move(new_block));
return &*bb_iter;
}
}
assert(false && "Could not find insertion point.");
return nullptr;
}
bool Function::IsRecursive() const {
IRContext* ctx = blocks_.front()->GetLabel()->context();
IRContext::ProcessFunction mark_visited = [this](Function* fp) {

View File

@ -125,6 +125,9 @@ class Function {
BasicBlock* InsertBasicBlockAfter(std::unique_ptr<BasicBlock>&& new_block,
BasicBlock* position);
BasicBlock* InsertBasicBlockBefore(std::unique_ptr<BasicBlock>&& new_block,
BasicBlock* position);
// Return true if the function calls itself either directly or indirectly.
bool IsRecursive() const;

View File

@ -225,6 +225,7 @@ Optimizer& Optimizer::RegisterVulkanToWebGPUPasses() {
.RegisterPass(CreateStripAtomicCounterMemoryPass())
.RegisterPass(CreateGenerateWebGPUInitializersPass())
.RegisterPass(CreateLegalizeVectorShufflePass())
.RegisterPass(CreateSplitInvalidUnreachablePass())
.RegisterPass(CreateEliminateDeadConstantPass())
.RegisterPass(CreateFlattenDecorationPass())
.RegisterPass(CreateAggressiveDCEPass())
@ -873,4 +874,9 @@ Optimizer::PassToken CreateDecomposeInitializedVariablesPass() {
MakeUnique<opt::DecomposeInitializedVariablesPass>());
}
Optimizer::PassToken CreateSplitInvalidUnreachablePass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::SplitInvalidUnreachablePass>());
}
} // namespace spvtools

View File

@ -65,6 +65,7 @@
#include "source/opt/scalar_replacement_pass.h"
#include "source/opt/set_spec_constant_default_value_pass.h"
#include "source/opt/simplification_pass.h"
#include "source/opt/split_invalid_unreachable_pass.h"
#include "source/opt/ssa_rewrite_pass.h"
#include "source/opt/strength_reduction_pass.h"
#include "source/opt/strip_atomic_counter_memory_pass.h"

View File

@ -0,0 +1,95 @@
// Copyright (c) 2019 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/opt/split_invalid_unreachable_pass.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace opt {
Pass::Status SplitInvalidUnreachablePass::Process() {
bool changed = false;
std::unordered_set<uint32_t> entry_points;
for (auto entry_point : context()->module()->entry_points()) {
entry_points.insert(entry_point.GetSingleWordOperand(1));
}
for (auto func = context()->module()->begin();
func != context()->module()->end(); ++func) {
if (entry_points.find(func->result_id()) == entry_points.end()) continue;
std::unordered_set<uint32_t> continue_targets;
std::unordered_set<uint32_t> merge_blocks;
std::unordered_set<BasicBlock*> unreachable_blocks;
for (auto block = func->begin(); block != func->end(); ++block) {
unreachable_blocks.insert(&*block);
uint32_t continue_target = block->ContinueBlockIdIfAny();
if (continue_target != 0) continue_targets.insert(continue_target);
uint32_t merge_block = block->MergeBlockIdIfAny();
if (merge_block != 0) merge_blocks.insert(merge_block);
}
cfg()->ForEachBlockInPostOrder(
func->entry().get(), [&unreachable_blocks](BasicBlock* inner_block) {
unreachable_blocks.erase(inner_block);
});
for (auto unreachable : unreachable_blocks) {
uint32_t block_id = unreachable->id();
if (continue_targets.find(block_id) == continue_targets.end() ||
merge_blocks.find(block_id) == merge_blocks.end()) {
continue;
}
std::vector<std::tuple<Instruction*, uint32_t>> usages;
context()->get_def_use_mgr()->ForEachUse(
unreachable->GetLabelInst(),
[&usages](Instruction* use, uint32_t idx) {
if ((use->opcode() == SpvOpLoopMerge && idx == 0) ||
use->opcode() == SpvOpSelectionMerge) {
usages.push_back(std::make_pair(use, idx));
}
});
for (auto usage : usages) {
Instruction* use;
uint32_t idx;
std::tie(use, idx) = usage;
uint32_t new_id = context()->TakeNextId();
std::unique_ptr<Instruction> new_label(
new Instruction(context(), SpvOpLabel, 0, new_id, {}));
get_def_use_mgr()->AnalyzeInstDefUse(new_label.get());
std::unique_ptr<BasicBlock> new_block(
new BasicBlock(std::move(new_label)));
auto* block_ptr = new_block.get();
InstructionBuilder builder(context(), new_block.get(),
IRContext::kAnalysisDefUse |
IRContext::kAnalysisInstrToBlockMapping);
builder.AddUnreachable();
cfg()->RegisterBlock(block_ptr);
(&*func)->InsertBasicBlockBefore(std::move(new_block), unreachable);
use->SetInOperand(0, {new_id});
get_def_use_mgr()->UpdateDefUse(use);
cfg()->AddEdges(block_ptr);
changed = true;
}
}
}
return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
} // namespace opt
} // namespace spvtools

View File

@ -0,0 +1,51 @@
// Copyright (c) 2019 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_OPT_SPLIT_INVALID_UNREACHABLE_PASS_H_
#define SOURCE_OPT_SPLIT_INVALID_UNREACHABLE_PASS_H_
#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "source/opt/pass.h"
namespace spvtools {
namespace opt {
// Attempts to legalize for WebGPU by splitting up invalid unreachable blocks.
// Specifically, looking for cases of unreachable merge-blocks and
// continue-targets that are used more then once, which is illegal in WebGPU.
class SplitInvalidUnreachablePass : public Pass {
public:
const char* name() const override { return "split-invalid-unreachable"; }
Status Process() override;
IRContext::Analysis GetPreservedAnalyses() override {
return IRContext::kAnalysisInstrToBlockMapping |
IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis |
IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap |
IRContext::kAnalysisScalarEvolution |
IRContext::kAnalysisRegisterPressure |
IRContext::kAnalysisValueNumberTable |
IRContext::kAnalysisStructuredCFG |
IRContext::kAnalysisBuiltinVarId |
IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisTypes |
IRContext::kAnalysisDefUse | IRContext::kAnalysisConstants;
}
};
} // namespace opt
} // namespace spvtools
#endif // SOURCE_OPT_SPLIT_INVALID_UNREACHABLE_PASS_H_

View File

@ -81,6 +81,7 @@ add_spvtools_unittest(TARGET opt
scalar_replacement_test.cpp
set_spec_const_default_value_test.cpp
simplification_test.cpp
split_invalid_unreachable_test.cpp
strength_reduction_test.cpp
strip_atomic_counter_memory_test.cpp
strip_debug_info_test.cpp

View File

@ -238,7 +238,8 @@ TEST(Optimizer, VulkanToWebGPUSetsCorrectPasses) {
"strip-debug",
"strip-atomic-counter-memory",
"generate-webgpu-initializers",
"legalize-vector-shuffle"};
"legalize-vector-shuffle",
"split-invalid-unreachable"};
std::sort(registered_passes.begin(), registered_passes.end());
std::sort(expected_passes.begin(), expected_passes.end());
@ -524,7 +525,74 @@ INSTANTIATE_TEST_SUITE_P(
"OpReturn\n"
"OpFunctionEnd\n",
// pass
"legalize-vector-shuffle"}}));
"legalize-vector-shuffle"},
// Split Invalid Unreachable
{// input
"OpCapability Shader\n"
"OpCapability VulkanMemoryModelKHR\n"
"OpExtension \"SPV_KHR_vulkan_memory_model\"\n"
"OpMemoryModel Logical VulkanKHR\n"
"OpEntryPoint Vertex %1 \"shader\"\n"
"%uint = OpTypeInt 32 0\n"
"%uint_1 = OpConstant %uint 1\n"
"%uint_2 = OpConstant %uint 2\n"
"%void = OpTypeVoid\n"
"%bool = OpTypeBool\n"
"%7 = OpTypeFunction %void\n"
"%1 = OpFunction %void None %7\n"
"%8 = OpLabel\n"
"OpBranch %9\n"
"%9 = OpLabel\n"
"OpLoopMerge %10 %11 None\n"
"OpBranch %12\n"
"%12 = OpLabel\n"
"%13 = OpSLessThan %bool %uint_1 %uint_2\n"
"OpSelectionMerge %11 None\n"
"OpBranchConditional %13 %14 %15\n"
"%14 = OpLabel\n"
"OpReturn\n"
"%15 = OpLabel\n"
"OpReturn\n"
"%10 = OpLabel\n"
"OpUnreachable\n"
"%11 = OpLabel\n"
"OpBranch %9\n"
"OpFunctionEnd\n",
// expected
"OpCapability Shader\n"
"OpCapability VulkanMemoryModelKHR\n"
"OpExtension \"SPV_KHR_vulkan_memory_model\"\n"
"OpMemoryModel Logical VulkanKHR\n"
"OpEntryPoint Vertex %1 \"shader\"\n"
"%uint = OpTypeInt 32 0\n"
"%uint_1 = OpConstant %uint 1\n"
"%uint_2 = OpConstant %uint 2\n"
"%void = OpTypeVoid\n"
"%bool = OpTypeBool\n"
"%7 = OpTypeFunction %void\n"
"%1 = OpFunction %void None %7\n"
"%8 = OpLabel\n"
"OpBranch %9\n"
"%9 = OpLabel\n"
"OpLoopMerge %10 %11 None\n"
"OpBranch %12\n"
"%12 = OpLabel\n"
"%13 = OpSLessThan %bool %uint_1 %uint_2\n"
"OpSelectionMerge %16 None\n"
"OpBranchConditional %13 %14 %15\n"
"%14 = OpLabel\n"
"OpReturn\n"
"%15 = OpLabel\n"
"OpReturn\n"
"%10 = OpLabel\n"
"OpUnreachable\n"
"%16 = OpLabel\n"
"OpUnreachable\n"
"%11 = OpLabel\n"
"OpBranch %9\n"
"OpFunctionEnd\n",
// pass
"split-invalid-unreachable"}}));
TEST(Optimizer, WebGPUToVulkanSetsCorrectPasses) {
Optimizer opt(SPV_ENV_VULKAN_1_1);

View File

@ -0,0 +1,155 @@
// Copyright (c) 2019 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "test/opt/pass_fixture.h"
#include "test/opt/pass_utils.h"
namespace spvtools {
namespace opt {
namespace {
using SplitInvalidUnreachableTest = PassTest<::testing::Test>;
std::string spirv_header = R"(OpCapability Shader
OpCapability VulkanMemoryModelKHR
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical VulkanKHR
OpEntryPoint Vertex %1 "shader"
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%void = OpTypeVoid
%bool = OpTypeBool
%7 = OpTypeFunction %void
)";
std::string function_head = R"(%1 = OpFunction %void None %7
%8 = OpLabel
OpBranch %9
)";
std::string function_tail = "OpFunctionEnd\n";
std::string GetLoopMergeBlock(std::string block_id, std::string merge_id,
std::string continue_id, std::string body_id) {
std::string result;
result += block_id + " = OpLabel\n";
result += "OpLoopMerge " + merge_id + " " + continue_id + " None\n";
result += "OpBranch " + body_id + "\n";
return result;
}
std::string GetSelectionMergeBlock(std::string block_id,
std::string condition_id,
std::string merge_id, std::string true_id,
std::string false_id) {
std::string result;
result += block_id + " = OpLabel\n";
result += condition_id + " = OpSLessThan %bool %uint_1 %uint_2\n";
result += "OpSelectionMerge " + merge_id + " None\n";
result += "OpBranchConditional " + condition_id + " " + true_id + " " +
false_id + "\n";
return result;
}
std::string GetReturnBlock(std::string block_id) {
std::string result;
result += block_id + " = OpLabel\n";
result += "OpReturn\n";
return result;
}
std::string GetUnreachableBlock(std::string block_id) {
std::string result;
result += block_id + " = OpLabel\n";
result += "OpUnreachable\n";
return result;
}
std::string GetBranchBlock(std::string block_id, std::string target_id) {
std::string result;
result += block_id + " = OpLabel\n";
result += "OpBranch " + target_id + "\n";
return result;
}
TEST_F(SplitInvalidUnreachableTest, NoInvalidBlocks) {
std::string input = spirv_header + function_head;
input += GetLoopMergeBlock("%9", "%10", "%11", "%12");
input += GetSelectionMergeBlock("%12", "%13", "%14", "%15", "%16");
input += GetReturnBlock("%15");
input += GetReturnBlock("%16");
input += GetUnreachableBlock("%10");
input += GetBranchBlock("%11", "%9");
input += GetUnreachableBlock("%14");
input += function_tail;
SinglePassRunAndCheck<SplitInvalidUnreachablePass>(input, input,
/* skip_nop = */ false);
}
TEST_F(SplitInvalidUnreachableTest, SelectionInLoop) {
std::string input = spirv_header + function_head;
input += GetLoopMergeBlock("%9", "%10", "%11", "%12");
input += GetSelectionMergeBlock("%12", "%13", "%11", "%15", "%16");
input += GetReturnBlock("%15");
input += GetReturnBlock("%16");
input += GetUnreachableBlock("%10");
input += GetBranchBlock("%11", "%9");
input += function_tail;
std::string expected = spirv_header + function_head;
expected += GetLoopMergeBlock("%9", "%10", "%11", "%12");
expected += GetSelectionMergeBlock("%12", "%13", "%16", "%14", "%15");
expected += GetReturnBlock("%14");
expected += GetReturnBlock("%15");
expected += GetUnreachableBlock("%10");
expected += GetUnreachableBlock("%16");
expected += GetBranchBlock("%11", "%9");
expected += function_tail;
SinglePassRunAndCheck<SplitInvalidUnreachablePass>(input, expected,
/* skip_nop = */ false);
}
TEST_F(SplitInvalidUnreachableTest, LoopInSelection) {
std::string input = spirv_header + function_head;
input += GetSelectionMergeBlock("%9", "%10", "%11", "%12", "%13");
input += GetLoopMergeBlock("%12", "%14", "%11", "%15");
input += GetReturnBlock("%13");
input += GetUnreachableBlock("%14");
input += GetBranchBlock("%11", "%12");
input += GetReturnBlock("%15");
input += function_tail;
std::string expected = spirv_header + function_head;
expected += GetSelectionMergeBlock("%9", "%10", "%16", "%12", "%13");
expected += GetLoopMergeBlock("%12", "%14", "%11", "%15");
expected += GetReturnBlock("%13");
expected += GetUnreachableBlock("%14");
expected += GetUnreachableBlock("%16");
expected += GetBranchBlock("%11", "%12");
expected += GetReturnBlock("%15");
expected += function_tail;
SinglePassRunAndCheck<SplitInvalidUnreachablePass>(input, expected,
/* skip_nop = */ false);
}
} // namespace
} // namespace opt
} // namespace spvtools