Fix reachability in the validator (#3541)

Fixes #3529

* Make BasicBlock::reachable() only consider static reachability
* Fix reachability calculation to be independent of block order
* add tests
This commit is contained in:
alan-baker 2020-07-15 21:27:03 -04:00 committed by GitHub
parent 2fa735dc06
commit 7221ccf85e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 105 additions and 29 deletions

View File

@ -58,15 +58,9 @@ void BasicBlock::RegisterSuccessors(
for (auto& block : next_blocks) {
block->predecessors_.push_back(this);
successors_.push_back(block);
if (block->reachable_ == false) block->set_reachable(reachable_);
}
}
void BasicBlock::RegisterBranchInstruction(SpvOp branch_instruction) {
if (branch_instruction == SpvOpUnreachable) reachable_ = false;
return;
}
bool BasicBlock::dominates(const BasicBlock& other) const {
return (this == &other) ||
!(other.dom_end() ==

View File

@ -106,9 +106,6 @@ class BasicBlock {
/// Returns the immedate post dominator of this basic block
const BasicBlock* immediate_post_dominator() const;
/// Ends the block without a successor
void RegisterBranchInstruction(SpvOp branch_instruction);
/// Returns the label instruction for the block, or nullptr if not set.
const Instruction* label() const { return label_; }

View File

@ -130,7 +130,6 @@ spv_result_t Function::RegisterBlock(uint32_t block_id, bool is_definition) {
undefined_blocks_.erase(block_id);
current_block_ = &inserted_block->second;
ordered_blocks_.push_back(current_block_);
if (IsFirstBlock(block_id)) current_block_->set_reachable(true);
} else if (success) { // Block doesn't exsist but this is not a definition
undefined_blocks_.insert(block_id);
}
@ -138,8 +137,7 @@ spv_result_t Function::RegisterBlock(uint32_t block_id, bool is_definition) {
return SPV_SUCCESS;
}
void Function::RegisterBlockEnd(std::vector<uint32_t> next_list,
SpvOp branch_instruction) {
void Function::RegisterBlockEnd(std::vector<uint32_t> next_list) {
assert(
current_block_ &&
"RegisterBlockEnd can only be called when parsing a binary in a block");
@ -174,7 +172,6 @@ void Function::RegisterBlockEnd(std::vector<uint32_t> next_list,
}
}
current_block_->RegisterBranchInstruction(branch_instruction);
current_block_->RegisterSuccessors(next_blocks);
current_block_ = nullptr;
return;

View File

@ -97,9 +97,7 @@ class Function {
/// Registers the end of the block
///
/// @param[in] successors_list A list of ids to the block's successors
/// @param[in] branch_instruction the branch instruction that ended the block
void RegisterBlockEnd(std::vector<uint32_t> successors_list,
SpvOp branch_instruction);
void RegisterBlockEnd(std::vector<uint32_t> successors_list);
/// Registers the end of the function. This is idempotent.
void RegisterFunctionEnd();

View File

@ -368,6 +368,10 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
// Catch undefined forward references before performing further checks.
if (auto error = ValidateForwardDecls(*vstate)) return error;
// Calculate reachability after all the blocks are parsed, but early that it
// can be relied on in subsequent pases.
ReachabilityPass(*vstate);
// ID usage needs be handled in its own iteration of the instructions,
// between the two others. It depends on the first loop to have been
// finished, so that all instructions have been registered. And the following

View File

@ -197,6 +197,9 @@ spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst);
/// Validates correctness of miscellaneous instructions.
spv_result_t MiscPass(ValidationState_t& _, const Instruction* inst);
/// Calculates the reachability of basic blocks.
void ReachabilityPass(ValidationState_t& _);
/// Validates execution limitations.
///
/// Verifies execution models are allowed for all functionality they contain.

View File

@ -1062,7 +1062,7 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
uint32_t target = inst->GetOperandAs<uint32_t>(0);
CFG_ASSERT(FirstBlockAssert, target);
_.current_function().RegisterBlockEnd({target}, opcode);
_.current_function().RegisterBlockEnd({target});
} break;
case SpvOpBranchConditional: {
uint32_t tlabel = inst->GetOperandAs<uint32_t>(1);
@ -1070,7 +1070,7 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
CFG_ASSERT(FirstBlockAssert, tlabel);
CFG_ASSERT(FirstBlockAssert, flabel);
_.current_function().RegisterBlockEnd({tlabel, flabel}, opcode);
_.current_function().RegisterBlockEnd({tlabel, flabel});
} break;
case SpvOpSwitch: {
@ -1080,7 +1080,7 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
CFG_ASSERT(FirstBlockAssert, target);
cases.push_back(target);
}
_.current_function().RegisterBlockEnd({cases}, opcode);
_.current_function().RegisterBlockEnd({cases});
} break;
case SpvOpReturn: {
const uint32_t return_type = _.current_function().GetResultTypeId();
@ -1090,13 +1090,13 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
return _.diag(SPV_ERROR_INVALID_CFG, inst)
<< "OpReturn can only be called from a function with void "
<< "return type.";
_.current_function().RegisterBlockEnd(std::vector<uint32_t>(), opcode);
_.current_function().RegisterBlockEnd(std::vector<uint32_t>());
break;
}
case SpvOpKill:
case SpvOpReturnValue:
case SpvOpUnreachable:
_.current_function().RegisterBlockEnd(std::vector<uint32_t>(), opcode);
_.current_function().RegisterBlockEnd(std::vector<uint32_t>());
if (opcode == SpvOpKill) {
_.current_function().RegisterExecutionModelLimitation(
SpvExecutionModelFragment,
@ -1109,6 +1109,27 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
return SPV_SUCCESS;
}
void ReachabilityPass(ValidationState_t& _) {
for (auto& f : _.functions()) {
std::vector<BasicBlock*> stack;
auto entry = f.first_block();
// Skip function declarations.
if (entry) stack.push_back(entry);
while (!stack.empty()) {
auto block = stack.back();
stack.pop_back();
if (block->reachable()) continue;
block->set_reachable(true);
for (auto succ : *block->successors()) {
stack.push_back(succ);
}
}
}
}
spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) {
switch (inst->opcode()) {
case SpvOpPhi:

View File

@ -1204,14 +1204,6 @@ TEST_P(ValidateCFG, UnreachableMergeWithBranchUse) {
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateCFG, WebGPUUnreachableMergeWithBranchUse) {
CompileSuccessfully(
GetUnreachableMergeWithBranchUse(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("cannot be the target of a branch."));
}
std::string GetUnreachableMergeWithMultipleUses(SpvCapability cap,
spv_target_env env) {
std::string header =
@ -4503,6 +4495,76 @@ OpFunctionEnd
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateCFG, UnreachableIsStaticallyReachable) {
const std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeFunction %1
%3 = OpFunction %1 None %2
%4 = OpLabel
OpBranch %5
%5 = OpLabel
OpUnreachable
OpFunctionEnd
)";
CompileSuccessfully(text);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
auto f = vstate_->function(3);
auto entry = f->GetBlock(4).first;
ASSERT_TRUE(entry->reachable());
auto end = f->GetBlock(5).first;
ASSERT_TRUE(end->reachable());
}
TEST_F(ValidateCFG, BlockOrderDoesNotAffectReachability) {
const std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeFunction %1
%3 = OpTypeBool
%4 = OpUndef %3
%5 = OpFunction %1 None %2
%6 = OpLabel
OpBranch %7
%7 = OpLabel
OpSelectionMerge %8 None
OpBranchConditional %4 %9 %10
%8 = OpLabel
OpReturn
%9 = OpLabel
OpBranch %8
%10 = OpLabel
OpBranch %8
%11 = OpLabel
OpUnreachable
OpFunctionEnd
)";
CompileSuccessfully(text);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
auto f = vstate_->function(5);
auto b6 = f->GetBlock(6).first;
auto b7 = f->GetBlock(7).first;
auto b8 = f->GetBlock(8).first;
auto b9 = f->GetBlock(9).first;
auto b10 = f->GetBlock(10).first;
auto b11 = f->GetBlock(11).first;
ASSERT_TRUE(b6->reachable());
ASSERT_TRUE(b7->reachable());
ASSERT_TRUE(b8->reachable());
ASSERT_TRUE(b9->reachable());
ASSERT_TRUE(b10->reachable());
ASSERT_FALSE(b11->reachable());
}
} // namespace
} // namespace val
} // namespace spvtools