Add a check for invalid exits from case construct.

Fixes #1618.

Adds a check that validates acceptable exits from case constructs. Case
constructs may only exit to another case construct, the corresponding
merge, an outer loop continue or outer loop merge.
This commit is contained in:
Alan Baker 2018-06-29 11:01:13 -04:00 committed by David Neto
parent fa78d3bec9
commit c460f44fbc
2 changed files with 90 additions and 10 deletions

View File

@ -173,21 +173,23 @@ string ConstructErrorString(const Construct& construct,
// |case_fall_through|. Returns SPV_ERROR_INVALID_CFG if the case construct
// headed by |target_block| branches to multiple case constructs.
spv_result_t FindCaseFallThrough(
const ValidationState_t& _, const BasicBlock* target_block,
const ValidationState_t& _, BasicBlock* target_block,
uint32_t* case_fall_through, const BasicBlock* merge,
const std::unordered_set<uint32_t>& case_targets) {
std::vector<const BasicBlock*> stack;
const std::unordered_set<uint32_t>& case_targets, Function* function) {
std::vector<BasicBlock*> stack;
stack.push_back(target_block);
std::unordered_set<const BasicBlock*> visited;
bool target_reachable = target_block->reachable();
int target_depth = function->GetBlockDepth(target_block);
while (!stack.empty()) {
const auto block = stack.back();
auto block = stack.back();
stack.pop_back();
if (block == merge) continue;
if (!visited.insert(block).second) continue;
if (target_block->reachable() && block->reachable() &&
if (target_reachable && block->reachable() &&
target_block->dominates(*block)) {
// Still in the case construct.
for (auto successor : *block->successors()) {
@ -196,7 +198,18 @@ spv_result_t FindCaseFallThrough(
} else {
// Exiting the case construct to non-merge block.
if (!case_targets.count(block->id())) {
continue;
int depth = function->GetBlockDepth(block);
if ((depth < target_depth) ||
(depth == target_depth && block->is_type(kBlockTypeContinue))) {
continue;
}
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Case construct that targets "
<< _.getIdName(target_block->id())
<< " has invalid branch to block " << _.getIdName(block->id())
<< " (not another case construct, corresponding merge, outer "
"loop merge or outer loop continue)";
}
if (*case_fall_through == 0u) {
@ -221,7 +234,7 @@ spv_result_t FindCaseFallThrough(
}
spv_result_t StructuredSwitchChecks(const ValidationState_t& _,
const Function& function,
Function* function,
const Instruction* switch_inst,
const BasicBlock* header,
const BasicBlock* merge) {
@ -242,7 +255,7 @@ spv_result_t StructuredSwitchChecks(const ValidationState_t& _,
if (!seen.insert(target).second) continue;
const auto target_block = function.GetBlock(target).first;
const auto target_block = function->GetBlock(target).first;
// OpSwitch must dominate all its case constructs.
if (header->reachable() && target_block->reachable() &&
!header->dominates(*target_block)) {
@ -253,7 +266,7 @@ spv_result_t StructuredSwitchChecks(const ValidationState_t& _,
uint32_t case_fall_through = 0u;
if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through,
merge, case_targets)) {
merge, case_targets, function)) {
return error;
}
@ -429,7 +442,7 @@ spv_result_t StructuredControlFlowChecks(
header->terminator()->opcode() == SpvOpSwitch) {
const auto terminator = header->terminator();
if (auto error =
StructuredSwitchChecks(_, *function, terminator, header, merge)) {
StructuredSwitchChecks(_, function, terminator, header, merge)) {
return error;
}
}

View File

@ -1754,5 +1754,72 @@ OpFunctionEnd
" OpSwitch %uint_0 %10 0 %11 1 %12 2 %13"));
}
TEST_F(ValidateCFG, InvalidCaseExit) {
const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "func"
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%4 = OpTypeFunction %2
%5 = OpConstant %3 0
%1 = OpFunction %2 None %4
%6 = OpLabel
OpSelectionMerge %7 None
OpSwitch %5 %7 0 %8 1 %9
%8 = OpLabel
OpBranch %10
%9 = OpLabel
OpBranch %10
%10 = OpLabel
OpReturn
%7 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Case construct that targets 8 has invalid branch to "
"block 10 (not another case construct, corresponding "
"merge, outer loop merge or outer loop continue"));
}
TEST_F(ValidateCFG, GoodCaseExitsToOuterConstructs) {
const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %func "func"
%void = OpTypeVoid
%bool = OpTypeBool
%true = OpConstantTrue %bool
%int = OpTypeInt 32 0
%int0 = OpConstant %int 0
%func_ty = OpTypeFunction %void
%func = OpFunction %void None %func_ty
%1 = OpLabel
OpBranch %2
%2 = OpLabel
OpLoopMerge %7 %6 None
OpBranch %3
%3 = OpLabel
OpSelectionMerge %5 None
OpSwitch %int0 %5 0 %4
%4 = OpLabel
OpBranchConditional %true %6 %7
%5 = OpLabel
OpBranchConditional %true %6 %7
%6 = OpLabel
OpBranch %2
%7 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
/// TODO(umar): Nested CFG constructs
} // namespace