Structured switch checks

Fixes #491

* Basic blocks now have a link to the terminator
* Check all case sepecific rules
* Missing check for branching into the middle of a case (#1618)
This commit is contained in:
Alan Baker 2018-06-08 15:49:50 -04:00
parent 4f866abfd8
commit ea7239fa73
4 changed files with 456 additions and 11 deletions

View File

@ -37,6 +37,8 @@ enum BlockType : uint32_t {
kBlockTypeCOUNT ///< Total number of block types. (must be the last element)
};
class Instruction;
// This class represents a basic block in a SPIR-V module
class BasicBlock {
public:
@ -107,6 +109,12 @@ class BasicBlock {
/// Ends the block without a successor
void RegisterBranchInstruction(SpvOp branch_instruction);
/// Registers the terminator instruction for the block.
void set_terminator(const Instruction* t) { terminator_ = t; }
/// Returns the terminator instruction for the block.
const Instruction* terminator() const { return terminator_; }
/// Adds @p next BasicBlocks as successors of this BasicBlock
void RegisterSuccessors(
const std::vector<BasicBlock*>& next = std::vector<BasicBlock*>());
@ -209,6 +217,9 @@ class BasicBlock {
/// True if the block is reachable in the CFG
bool reachable_;
/// Terminator of this block.
const Instruction* terminator_;
};
/// @brief Returns true if the iterators point to the same element or if both

View File

@ -409,6 +409,11 @@ void ValidationState_t::RegisterInstruction(
if (in_function_body()) {
ordered_instructions_.emplace_back(&inst, &current_function(),
current_function().current_block());
if (in_block() &&
spvOpcodeIsBlockTerminator(static_cast<SpvOp>(inst.opcode))) {
current_function().current_block()->set_terminator(
&ordered_instructions_.back());
}
} else {
ordered_instructions_.emplace_back(&inst, nullptr, nullptr);
}

View File

@ -167,6 +167,153 @@ string ConstructErrorString(const Construct& construct,
exit_string;
}
// Finds the fall through case construct of |target_block| and records it in
// |case_fall_through|. Returns SPV_ERROR_INVALID_CFG if the case construct
// headed by |target_block| branches to multiple case constructs.
spv_result_t FindCaseFallThrough(
const ValidationState_t& _, const BasicBlock* target_block,
uint32_t* case_fall_through, const BasicBlock* merge,
const std::unordered_set<uint32_t>& case_targets) {
std::vector<const BasicBlock*> stack;
stack.push_back(target_block);
std::unordered_set<const BasicBlock*> visited;
while (!stack.empty()) {
const auto block = stack.back();
stack.pop_back();
if (block == merge) continue;
if (!visited.insert(block).second) continue;
if (target_block->reachable() && block->reachable() &&
target_block->dominates(*block)) {
// Still in the case construct.
for (auto successor : *block->successors()) {
stack.push_back(successor);
}
} else {
// Exiting the case construct to non-merge block.
if (!case_targets.count(block->id())) {
continue;
}
if (*case_fall_through == 0u) {
*case_fall_through = block->id();
} else if (*case_fall_through != block->id()) {
// Case construct has at most one branch to another case construct.
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Case construct that targets "
<< _.getIdName(target_block->id())
<< " has branches to multiple other case construct targets "
<< _.getIdName(*case_fall_through) << " and "
<< _.getIdName(block->id());
}
}
}
return SPV_SUCCESS;
}
spv_result_t StructuredSwitchChecks(const ValidationState_t& _,
const Function& function,
const Instruction* switch_inst,
const BasicBlock* header,
const BasicBlock* merge) {
std::unordered_set<uint32_t> case_targets;
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
if (target != merge->id()) case_targets.insert(target);
}
// Tracks how many times each case construct is targeted by another case
// construct.
std::map<uint32_t, uint32_t> num_fall_through_targeted;
uint32_t default_case_fall_through = 0u;
uint32_t default_target = switch_inst->GetOperandAs<uint32_t>(1u);
std::unordered_set<uint32_t> seen;
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
if (target == merge->id()) continue;
if (!seen.insert(target).second) continue;
const auto target_block = function.GetBlock(target).first;
// OpSwitch must dominate all its case constructs.
if (header->reachable() && target_block->reachable() &&
!header->dominates(*target_block)) {
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Selection header " << _.getIdName(header->id())
<< " does not dominate its case construct " << _.getIdName(target);
}
uint32_t case_fall_through = 0u;
if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through,
merge, case_targets)) {
return error;
}
// Track how many time the fall through case has been targeted.
if (case_fall_through != 0u) {
auto where = num_fall_through_targeted.lower_bound(case_fall_through);
if (where == num_fall_through_targeted.end() ||
where->first != case_fall_through) {
num_fall_through_targeted.insert(where,
std::make_pair(case_fall_through, 1));
} else {
where->second++;
}
}
if (case_fall_through == default_target) {
case_fall_through = default_case_fall_through;
}
if (case_fall_through != 0u) {
bool is_default = i == 1;
if (is_default) {
default_case_fall_through = case_fall_through;
} else {
// Allow code like:
// case x:
// case y:
// ...
// case z:
//
// Where x and y target the same block and fall through to z.
uint32_t j = i;
while ((j + 2 < switch_inst->operands().size()) &&
target == switch_inst->GetOperandAs<uint32_t>(j + 2)) {
j += 2;
}
// If Target T1 branches to Target T2, or if Target T1 branches to the
// Default target and the Default target branches to Target T2, then T1
// must immediately precede T2 in the list of OpSwitch Target operands.
if ((switch_inst->operands().size() < j + 2) ||
(case_fall_through != switch_inst->GetOperandAs<uint32_t>(j + 2))) {
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Case construct that targets " << _.getIdName(target)
<< " has branches to the case construct that targets "
<< _.getIdName(case_fall_through)
<< ", but does not immediately precede it in the "
"OpSwitch's "
"target list";
}
}
}
}
// Each case construct must be branched to by at most one other case
// construct.
for (const auto& pair : num_fall_through_targeted) {
if (pair.second > 1) {
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Multiple case constructs have branches to the case construct "
"that targets "
<< _.getIdName(pair.first);
}
}
return SPV_SUCCESS;
}
spv_result_t StructuredControlFlowChecks(
const ValidationState_t& _, Function* function,
const vector<pair<uint32_t, uint32_t>>& back_edges) {
@ -261,16 +408,15 @@ spv_result_t StructuredControlFlowChecks(
}
}
// TODO(umar): an OpSwitch block dominates all its defined case
// constructs
// TODO(umar): each case construct has at most one branch to another
// case construct
// TODO(umar): each case construct is branched to by at most one other
// case construct
// TODO(umar): if Target T1 branches to Target T2, or if Target T1
// branches to the Default and the Default branches to Target T2, then
// T1 must immediately precede T2 in the list of the OpSwitch Target
// operands
// Checks rules for case constructs.
if (construct.type() == ConstructType::kSelection &&
header->terminator()->opcode() == SpvOpSwitch) {
const auto terminator = header->terminator();
if (auto error =
StructuredSwitchChecks(_, *function, terminator, header, merge)) {
return error;
}
}
}
return SPV_SUCCESS;
}

View File

@ -1450,6 +1450,289 @@ OpFunctionEnd
"selection header <ID>"));
}
/// TODO(umar): Switch instructions
TEST_F(ValidateCFG, SwitchDefaultOnly) {
std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpTypeFunction %1
%5 = OpFunction %1 None %4
%6 = OpLabel
OpSelectionMerge %7 None
OpSwitch %3 %7
%7 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateCFG, SwitchSingleCase) {
std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpTypeFunction %1
%5 = OpFunction %1 None %4
%6 = OpLabel
OpSelectionMerge %7 None
OpSwitch %3 %7 0 %8
%8 = OpLabel
OpBranch %7
%7 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateCFG, MultipleFallThroughBlocks) {
std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpTypeFunction %1
%5 = OpTypeBool
%6 = OpConstantTrue %5
%7 = OpFunction %1 None %4
%8 = OpLabel
OpSelectionMerge %9 None
OpSwitch %3 %10 0 %11 1 %12
%10 = OpLabel
OpBranchConditional %6 %11 %12
%11 = OpLabel
OpBranch %9
%12 = OpLabel
OpBranch %9
%9 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Case construct that targets 10 has branches to multiple other case "
"construct targets 12 and 11"));
}
TEST_F(ValidateCFG, MultipleFallThroughToDefault) {
std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpTypeFunction %1
%5 = OpTypeBool
%6 = OpConstantTrue %5
%7 = OpFunction %1 None %4
%8 = OpLabel
OpSelectionMerge %9 None
OpSwitch %3 %10 0 %11 1 %12
%10 = OpLabel
OpBranch %9
%11 = OpLabel
OpBranch %10
%12 = OpLabel
OpBranch %10
%9 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Multiple case constructs have branches to the case construct "
"that targets 10"));
}
TEST_F(ValidateCFG, MultipleFallThroughToNonDefault) {
std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpTypeFunction %1
%5 = OpTypeBool
%6 = OpConstantTrue %5
%7 = OpFunction %1 None %4
%8 = OpLabel
OpSelectionMerge %9 None
OpSwitch %3 %10 0 %11 1 %12
%10 = OpLabel
OpBranch %12
%11 = OpLabel
OpBranch %12
%12 = OpLabel
OpBranch %9
%9 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Multiple case constructs have branches to the case construct "
"that targets 12"));
}
TEST_F(ValidateCFG, DuplicateTargetWithFallThrough) {
std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpTypeFunction %1
%5 = OpTypeBool
%6 = OpConstantTrue %5
%7 = OpFunction %1 None %4
%8 = OpLabel
OpSelectionMerge %9 None
OpSwitch %3 %10 0 %10 1 %11
%10 = OpLabel
OpBranch %11
%11 = OpLabel
OpBranch %9
%9 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateCFG, WrongOperandList) {
std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpTypeFunction %1
%5 = OpTypeBool
%6 = OpConstantTrue %5
%7 = OpFunction %1 None %4
%8 = OpLabel
OpSelectionMerge %9 None
OpSwitch %3 %10 0 %11 1 %12
%10 = OpLabel
OpBranch %9
%12 = OpLabel
OpBranch %11
%11 = OpLabel
OpBranch %9
%9 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Case construct that targets 12 has branches to the case "
"construct that targets 11, but does not immediately "
"precede it in the OpSwitch's target list"));
}
TEST_F(ValidateCFG, WrongOperandListThroughDefault) {
std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpTypeFunction %1
%5 = OpTypeBool
%6 = OpConstantTrue %5
%7 = OpFunction %1 None %4
%8 = OpLabel
OpSelectionMerge %9 None
OpSwitch %3 %10 0 %11 1 %12
%10 = OpLabel
OpBranch %11
%12 = OpLabel
OpBranch %10
%11 = OpLabel
OpBranch %9
%9 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Case construct that targets 12 has branches to the case "
"construct that targets 11, but does not immediately "
"precede it in the OpSwitch's target list"));
}
TEST_F(ValidateCFG, WrongOperandListNotLast) {
std::string text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpTypeFunction %1
%5 = OpTypeBool
%6 = OpConstantTrue %5
%7 = OpFunction %1 None %4
%8 = OpLabel
OpSelectionMerge %9 None
OpSwitch %3 %10 0 %11 1 %12 2 %13
%10 = OpLabel
OpBranch %9
%12 = OpLabel
OpBranch %11
%11 = OpLabel
OpBranch %9
%13 = OpLabel
OpBranch %9
%9 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Case construct that targets 12 has branches to the case "
"construct that targets 11, but does not immediately "
"precede it in the OpSwitch's target list"));
}
/// TODO(umar): Nested CFG constructs
} // namespace