Check for unreachable blocks in merge-return. (#1966)

Merge return assumes that the only unreachable blocks are those needed
to keep the structured cfg valid.  Even those must be essentially empty
blocks.

If this is not the case, we get unpredictable behaviour.  This commit
add a check in merge return, and emits an error if it is not the case.

Added a pass of dead branch elimination before merge return in both the
performance and size passes.  It is a precondition of merge return.

Fixes #1962.
This commit is contained in:
Steven Perron 2018-10-10 15:18:15 -04:00 committed by GitHub
parent bc09f53c96
commit 82663f34c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 284 additions and 37 deletions

View File

@ -93,7 +93,7 @@ BasicBlock* DeadBranchElimPass::GetParentBlock(uint32_t id) {
bool DeadBranchElimPass::MarkLiveBlocks(
Function* func, std::unordered_set<BasicBlock*>* live_blocks) {
StructuredCFGAnalysis cfgAnalysis(context());
StructuredCFGAnalysis* cfgAnalysis = context()->GetStructuedCFGAnalaysis();
std::unordered_set<BasicBlock*> continues;
std::vector<BasicBlock*> stack;
@ -164,7 +164,7 @@ bool DeadBranchElimPass::MarkLiveBlocks(
if (mergeInst && mergeInst->opcode() == SpvOpSelectionMerge) {
Instruction* first_break = FindFirstExitFromSelectionMerge(
live_lab_id, mergeInst->GetSingleWordInOperand(0),
cfgAnalysis.LoopMergeBlock(live_lab_id));
cfgAnalysis->LoopMergeBlock(live_lab_id));
if (first_break == nullptr) {
context()->KillInst(mergeInst);
} else {

View File

@ -55,6 +55,9 @@ void IRContext::BuildInvalidAnalyses(IRContext::Analysis set) {
if (set & kAnalysisValueNumberTable) {
BuildValueNumberTable();
}
if (set & kAnalysisStructuredCFG) {
BuildStructuredCFGAnalysis();
}
}
void IRContext::InvalidateAnalysesExceptFor(
@ -89,6 +92,9 @@ void IRContext::InvalidateAnalyses(IRContext::Analysis analyses_to_invalidate) {
if (analyses_to_invalidate & kAnalysisValueNumberTable) {
vn_table_.reset(nullptr);
}
if (analyses_to_invalidate & kAnalysisStructuredCFG) {
struct_cfg_analysis_.reset(nullptr);
}
valid_analyses_ = Analysis(valid_analyses_ & ~analyses_to_invalidate);
}

View File

@ -37,6 +37,7 @@
#include "source/opt/module.h"
#include "source/opt/register_pressure.h"
#include "source/opt/scalar_analysis.h"
#include "source/opt/struct_cfg_analysis.h"
#include "source/opt/type_manager.h"
#include "source/opt/value_number_table.h"
#include "source/util/make_unique.h"
@ -71,7 +72,8 @@ class IRContext {
kAnalysisScalarEvolution = 1 << 8,
kAnalysisRegisterPressure = 1 << 9,
kAnalysisValueNumberTable = 1 << 10,
kAnalysisEnd = 1 << 11
kAnalysisStructuredCFG = 1 << 11,
kAnalysisEnd = 1 << 12
};
friend inline Analysis operator|(Analysis lhs, Analysis rhs);
@ -227,6 +229,15 @@ class IRContext {
return vn_table_.get();
}
// Returns a pointer to a StructuredCFGAnalysis. If the analysis is invalid,
// it is rebuilt first.
StructuredCFGAnalysis* GetStructuedCFGAnalaysis() {
if (!AreAnalysesValid(kAnalysisStructuredCFG)) {
BuildStructuredCFGAnalysis();
}
return struct_cfg_analysis_.get();
}
// Returns a pointer to a liveness analysis. If the liveness analysis is
// invalid, it is rebuilt first.
LivenessAnalysis* GetLivenessAnalysis() {
@ -509,6 +520,13 @@ class IRContext {
valid_analyses_ = valid_analyses_ | kAnalysisValueNumberTable;
}
// Builds the structured CFG analysis from scratch, even if it was already
// valid.
void BuildStructuredCFGAnalysis() {
struct_cfg_analysis_ = MakeUnique<StructuredCFGAnalysis>(this);
valid_analyses_ = valid_analyses_ | kAnalysisStructuredCFG;
}
// Removes all computed dominator and post-dominator trees. This will force
// the context to rebuild the trees on demand.
void ResetDominatorAnalysis() {
@ -619,6 +637,8 @@ class IRContext {
std::unique_ptr<InstructionFolder> inst_folder_;
std::unique_ptr<StructuredCFGAnalysis> struct_cfg_analysis_;
// The maximum legal value for the id bound.
uint32_t max_id_bound_;
};

View File

@ -22,6 +22,7 @@
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "source/opt/reflect.h"
#include "source/util/bit_vector.h"
#include "source/util/make_unique.h"
namespace spvtools {
@ -42,7 +43,9 @@ Pass::Status MergeReturnPass::Process() {
modified = true;
if (is_shader) {
ProcessStructured(&function, return_blocks);
if (!ProcessStructured(&function, return_blocks)) {
return Status::Failure;
}
} else {
MergeReturnBlocks(&function, return_blocks);
}
@ -51,8 +54,18 @@ Pass::Status MergeReturnPass::Process() {
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
void MergeReturnPass::ProcessStructured(
bool MergeReturnPass::ProcessStructured(
Function* function, const std::vector<BasicBlock*>& return_blocks) {
if (HasNontrivialUnreachableBlocks(function)) {
if (consumer()) {
std::string message =
"Module contains unreachable blocks during merge return. Run dead "
"branch elimination before merge return.";
consumer()(SPV_MSG_ERROR, 0, {0, 0, 0}, message.c_str());
}
return false;
}
AddDummyLoopAroundFunction();
std::list<BasicBlock*> order;
@ -114,6 +127,7 @@ void MergeReturnPass::ProcessStructured(
// Invalidate it at this point to make sure it will be rebuilt.
context()->RemoveDominatorAnalysis(function);
AddNewPhiNodes();
return true;
}
void MergeReturnPass::CreateReturnBlock() {
@ -708,5 +722,41 @@ void MergeReturnPass::CreateDummyLoop(BasicBlock* merge_target) {
}
}
bool MergeReturnPass::HasNontrivialUnreachableBlocks(Function* function) {
utils::BitVector reachable_blocks;
cfg()->ForEachBlockInPostOrder(
function->entry().get(),
[&reachable_blocks](BasicBlock* bb) { reachable_blocks.Set(bb->id()); });
for (auto& bb : *function) {
if (reachable_blocks.Get(bb.id())) {
continue;
}
StructuredCFGAnalysis* struct_cfg_analysis =
context()->GetStructuedCFGAnalaysis();
if (struct_cfg_analysis->IsMergeBlock(bb.id())) {
// |bb| must be an empty block ending with OpUnreachable.
if (bb.begin()->opcode() != SpvOpUnreachable) {
return true;
}
} else if (struct_cfg_analysis->IsContinueBlock(bb.id())) {
// |bb| must be an empty block ending with a branch to the header.
Instruction* inst = &*bb.begin();
if (inst->opcode() != SpvOpBranch) {
return true;
}
if (inst->GetSingleWordInOperand(0) !=
struct_cfg_analysis->ContainingLoop(bb.id())) {
return true;
}
} else {
return true;
}
}
return false;
}
} // namespace opt
} // namespace spvtools

View File

@ -163,7 +163,7 @@ class MergeReturnPass : public MemPass {
// statement. It is assumed that |function| has structured control flow, and
// that |return_blocks| is a list of all of the basic blocks in |function|
// that have a return.
void ProcessStructured(Function* function,
bool ProcessStructured(Function* function,
const std::vector<BasicBlock*>& return_blocks);
// Changes an OpReturn* or OpUnreachable instruction at the end of |block|
@ -322,6 +322,7 @@ class MergeReturnPass : public MemPass {
// it is mapped to it original single predcessor. It is assumed there are no
// values that will need a phi on the new edges.
std::unordered_map<BasicBlock*, BasicBlock*> new_merge_nodes_;
bool HasNontrivialUnreachableBlocks(Function* function);
};
} // namespace opt

View File

@ -149,7 +149,8 @@ Optimizer& Optimizer::RegisterLegalizationPasses() {
}
Optimizer& Optimizer::RegisterPerformancePasses() {
return RegisterPass(CreateMergeReturnPass())
return RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateMergeReturnPass())
.RegisterPass(CreateInlineExhaustivePass())
.RegisterPass(CreateAggressiveDCEPass())
.RegisterPass(CreatePrivateToLocalPass())
@ -186,7 +187,8 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
}
Optimizer& Optimizer::RegisterSizePasses() {
return RegisterPass(CreateMergeReturnPass())
return RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateMergeReturnPass())
.RegisterPass(CreateInlineExhaustivePass())
.RegisterPass(CreateAggressiveDCEPass())
.RegisterPass(CreatePrivateToLocalPass())

View File

@ -14,8 +14,11 @@
#include "source/opt/struct_cfg_analysis.h"
#include "source/opt/ir_context.h"
namespace {
const uint32_t kMergeNodeIndex = 0;
const uint32_t kContinueNodeIndex = 1;
}
namespace spvtools {
@ -74,6 +77,7 @@ void StructuredCFGAnalysis::AddBlocksInFunction(Function* func) {
}
state.emplace_back(new_state);
merge_blocks_.Set(new_state.merge_node);
}
}
}
@ -100,5 +104,25 @@ uint32_t StructuredCFGAnalysis::LoopMergeBlock(uint32_t bb_id) {
return merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
}
uint32_t StructuredCFGAnalysis::LoopContinueBlock(uint32_t bb_id) {
uint32_t header_id = ContainingLoop(bb_id);
if (header_id == 0) {
return 0;
}
BasicBlock* header = context_->cfg()->block(header_id);
Instruction* merge_inst = header->GetMergeInst();
return merge_inst->GetSingleWordInOperand(kContinueNodeIndex);
}
bool StructuredCFGAnalysis::IsContinueBlock(uint32_t bb_id) {
assert(bb_id != 0);
return LoopContinueBlock(bb_id) == bb_id;
}
bool StructuredCFGAnalysis::IsMergeBlock(uint32_t bb_id) {
return merge_blocks_.Get(bb_id);
}
} // namespace opt
} // namespace spvtools

View File

@ -17,11 +17,14 @@
#include <unordered_map>
#include "ir_context.h"
#include "source/opt/function.h"
#include "source/util/bit_vector.h"
namespace spvtools {
namespace opt {
class IRContext;
// An analysis that, for each basic block, finds the constructs in which it is
// contained, so we can easily get headers and merge nodes.
class StructuredCFGAnalysis {
@ -60,6 +63,14 @@ class StructuredCFGAnalysis {
// construct.
uint32_t LoopMergeBlock(uint32_t bb_id);
// Returns the id of the continue block of the innermost loop construct
// that contains |bb_id|. Return |0| if |bb_id| is not contained in any loop
// construct.
uint32_t LoopContinueBlock(uint32_t bb_id);
bool IsContinueBlock(uint32_t bb_id);
bool IsMergeBlock(uint32_t bb_id);
private:
// Struct used to hold the information for a basic block.
// |containing_construct| is the header for the innermost containing
@ -82,6 +93,7 @@ class StructuredCFGAnalysis {
// A map from a basic block to the headers of its inner most containing
// constructs.
std::unordered_map<uint32_t, ConstructInfo> bb_to_construct_;
utils::BitVector merge_blocks_;
};
} // namespace opt

View File

@ -1019,6 +1019,127 @@ OpFunctionEnd
SinglePassRunAndMatch<MergeReturnPass>(test, false);
}
TEST_F(MergeReturnPassTest,
StructuredControlFlowWithNonTrivialUnreachableMerge) {
const std::string before =
R"(
OpCapability Addresses
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %6 "simple_shader"
%2 = OpTypeVoid
%3 = OpTypeBool
%4 = OpConstantFalse %3
%1 = OpTypeFunction %2
%6 = OpFunction %2 None %1
%7 = OpLabel
OpSelectionMerge %10 None
OpBranchConditional %4 %8 %9
%8 = OpLabel
OpReturn
%9 = OpLabel
OpReturn
%10 = OpLabel
%11 = OpUndef %3
OpUnreachable
OpFunctionEnd
)";
std::vector<Message> messages = {
{SPV_MSG_ERROR, nullptr, 0, 0,
"Module contains unreachable blocks during merge return. Run dead "
"branch elimination before merge return."}};
SetMessageConsumer(GetTestMessageConsumer(messages));
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
auto result = SinglePassRunToBinary<MergeReturnPass>(before, false);
EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
EXPECT_TRUE(messages.empty());
}
TEST_F(MergeReturnPassTest,
StructuredControlFlowWithNonTrivialUnreachableContinue) {
const std::string before =
R"(
OpCapability Addresses
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %6 "simple_shader"
%2 = OpTypeVoid
%3 = OpTypeBool
%4 = OpConstantFalse %3
%1 = OpTypeFunction %2
%6 = OpFunction %2 None %1
%7 = OpLabel
OpBranch %header
%header = OpLabel
OpLoopMerge %merge %continue None
OpBranchConditional %4 %8 %merge
%8 = OpLabel
OpReturn
%continue = OpLabel
%11 = OpUndef %3
OpBranch %header
%merge = OpLabel
OpReturn
OpFunctionEnd
)";
std::vector<Message> messages = {
{SPV_MSG_ERROR, nullptr, 0, 0,
"Module contains unreachable blocks during merge return. Run dead "
"branch elimination before merge return."}};
SetMessageConsumer(GetTestMessageConsumer(messages));
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
auto result = SinglePassRunToBinary<MergeReturnPass>(before, false);
EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
EXPECT_TRUE(messages.empty());
}
TEST_F(MergeReturnPassTest, StructuredControlFlowWithUnreachableBlock) {
const std::string before =
R"(
OpCapability Addresses
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %6 "simple_shader"
%2 = OpTypeVoid
%3 = OpTypeBool
%4 = OpConstantFalse %3
%1 = OpTypeFunction %2
%6 = OpFunction %2 None %1
%7 = OpLabel
OpBranch %header
%header = OpLabel
OpLoopMerge %merge %continue None
OpBranchConditional %4 %8 %merge
%8 = OpLabel
OpReturn
%continue = OpLabel
OpBranch %header
%merge = OpLabel
OpReturn
%unreachable = OpLabel
OpUnreachable
OpFunctionEnd
)";
std::vector<Message> messages = {
{SPV_MSG_ERROR, nullptr, 0, 0,
"Module contains unreachable blocks during merge return. Run dead "
"branch elimination before merge return."}};
SetMessageConsumer(GetTestMessageConsumer(messages));
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
auto result = SinglePassRunToBinary<MergeReturnPass>(before, false);
EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
EXPECT_TRUE(messages.empty());
}
} // namespace
} // namespace opt
} // namespace spvtools

View File

@ -35,6 +35,26 @@ const char* kDebugOpcodes[] = {
} // anonymous namespace
MessageConsumer GetTestMessageConsumer(
std::vector<Message>& expected_messages) {
return [&expected_messages](spv_message_level_t level, const char* source,
const spv_position_t& position,
const char* message) {
EXPECT_TRUE(!expected_messages.empty());
if (expected_messages.empty()) {
return;
}
EXPECT_EQ(expected_messages[0].level, level);
EXPECT_EQ(expected_messages[0].line_number, position.line);
EXPECT_EQ(expected_messages[0].column_number, position.column);
EXPECT_STREQ(expected_messages[0].source_file, source);
EXPECT_STREQ(expected_messages[0].message, message);
expected_messages.erase(expected_messages.begin());
};
}
bool FindAndReplace(std::string* process_str, const std::string find_str,
const std::string replace_str) {
if (process_str->empty() || find_str.empty()) {

View File

@ -21,9 +21,25 @@
#include <string>
#include <vector>
#include "external/googletest/googletest/include/gtest/gtest.h"
#include "include/spirv-tools/libspirv.h"
#include "include/spirv-tools/libspirv.hpp"
namespace spvtools {
namespace opt {
struct Message {
spv_message_level_t level;
const char* source_file;
uint32_t line_number;
uint32_t column_number;
const char* message;
};
// Return a message consumer that can be used to check that the message produced
// are the messages in |expexted_messages|, and in the same order.
MessageConsumer GetTestMessageConsumer(std::vector<Message>& expected_messages);
// In-place substring replacement. Finds the |find_str| in the |process_str|
// and replaces the found substring with |replace_str|. Returns true if at
// least one replacement is done successfully, returns false otherwise. The

View File

@ -17,6 +17,7 @@
#include <vector>
#include "gmock/gmock.h"
#include "pass_utils.h"
#include "test/opt/assembly_builder.h"
#include "test/opt/pass_fixture.h"
@ -434,34 +435,6 @@ TEST_F(ReplaceInvalidOpcodeTest, BarrierReplace) {
SinglePassRunAndMatch<ReplaceInvalidOpcodePass>(text, false);
}
struct Message {
spv_message_level_t level;
const char* source_file;
uint32_t line_number;
uint32_t column_number;
const char* message;
};
MessageConsumer GetTestMessageConsumer(
std::vector<Message>& expected_messages) {
return [&expected_messages](spv_message_level_t level, const char* source,
const spv_position_t& position,
const char* message) {
EXPECT_TRUE(!expected_messages.empty());
if (expected_messages.empty()) {
return;
}
EXPECT_EQ(expected_messages[0].level, level);
EXPECT_EQ(expected_messages[0].line_number, position.line);
EXPECT_EQ(expected_messages[0].column_number, position.column);
EXPECT_STREQ(expected_messages[0].source_file, source);
EXPECT_STREQ(expected_messages[0].message, message);
expected_messages.erase(expected_messages.begin());
};
}
TEST_F(ReplaceInvalidOpcodeTest, MessageTest) {
const std::string text = R"(
OpCapability Shader

View File

@ -135,6 +135,7 @@ class TestPerformanceOptimizationPasses(expect.ValidObjectFile1_3,
flags = ['-O']
expected_passes = [
'eliminate-dead-branches',
'merge-return',
'inline-entry-points-exhaustive',
'eliminate-dead-code-aggressive',
@ -181,6 +182,7 @@ class TestSizeOptimizationPasses(expect.ValidObjectFile1_3,
flags = ['-Os']
expected_passes = [
'eliminate-dead-branches',
'merge-return',
'inline-entry-points-exhaustive',
'eliminate-dead-code-aggressive',