mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-25 04:50:04 +00:00
Structured switch checks
Fixes #491 * Basic blocks now have a link to the terminator * Check all case sepecific rules * Missing check for branching into the middle of a case (#1618)
This commit is contained in:
parent
4f866abfd8
commit
ea7239fa73
@ -37,6 +37,8 @@ enum BlockType : uint32_t {
|
||||
kBlockTypeCOUNT ///< Total number of block types. (must be the last element)
|
||||
};
|
||||
|
||||
class Instruction;
|
||||
|
||||
// This class represents a basic block in a SPIR-V module
|
||||
class BasicBlock {
|
||||
public:
|
||||
@ -107,6 +109,12 @@ class BasicBlock {
|
||||
/// Ends the block without a successor
|
||||
void RegisterBranchInstruction(SpvOp branch_instruction);
|
||||
|
||||
/// Registers the terminator instruction for the block.
|
||||
void set_terminator(const Instruction* t) { terminator_ = t; }
|
||||
|
||||
/// Returns the terminator instruction for the block.
|
||||
const Instruction* terminator() const { return terminator_; }
|
||||
|
||||
/// Adds @p next BasicBlocks as successors of this BasicBlock
|
||||
void RegisterSuccessors(
|
||||
const std::vector<BasicBlock*>& next = std::vector<BasicBlock*>());
|
||||
@ -209,6 +217,9 @@ class BasicBlock {
|
||||
|
||||
/// True if the block is reachable in the CFG
|
||||
bool reachable_;
|
||||
|
||||
/// Terminator of this block.
|
||||
const Instruction* terminator_;
|
||||
};
|
||||
|
||||
/// @brief Returns true if the iterators point to the same element or if both
|
||||
|
@ -409,6 +409,11 @@ void ValidationState_t::RegisterInstruction(
|
||||
if (in_function_body()) {
|
||||
ordered_instructions_.emplace_back(&inst, ¤t_function(),
|
||||
current_function().current_block());
|
||||
if (in_block() &&
|
||||
spvOpcodeIsBlockTerminator(static_cast<SpvOp>(inst.opcode))) {
|
||||
current_function().current_block()->set_terminator(
|
||||
&ordered_instructions_.back());
|
||||
}
|
||||
} else {
|
||||
ordered_instructions_.emplace_back(&inst, nullptr, nullptr);
|
||||
}
|
||||
|
@ -167,6 +167,153 @@ string ConstructErrorString(const Construct& construct,
|
||||
exit_string;
|
||||
}
|
||||
|
||||
// Finds the fall through case construct of |target_block| and records it in
|
||||
// |case_fall_through|. Returns SPV_ERROR_INVALID_CFG if the case construct
|
||||
// headed by |target_block| branches to multiple case constructs.
|
||||
spv_result_t FindCaseFallThrough(
|
||||
const ValidationState_t& _, const BasicBlock* target_block,
|
||||
uint32_t* case_fall_through, const BasicBlock* merge,
|
||||
const std::unordered_set<uint32_t>& case_targets) {
|
||||
std::vector<const BasicBlock*> stack;
|
||||
stack.push_back(target_block);
|
||||
std::unordered_set<const BasicBlock*> visited;
|
||||
while (!stack.empty()) {
|
||||
const auto block = stack.back();
|
||||
stack.pop_back();
|
||||
|
||||
if (block == merge) continue;
|
||||
|
||||
if (!visited.insert(block).second) continue;
|
||||
|
||||
if (target_block->reachable() && block->reachable() &&
|
||||
target_block->dominates(*block)) {
|
||||
// Still in the case construct.
|
||||
for (auto successor : *block->successors()) {
|
||||
stack.push_back(successor);
|
||||
}
|
||||
} else {
|
||||
// Exiting the case construct to non-merge block.
|
||||
if (!case_targets.count(block->id())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (*case_fall_through == 0u) {
|
||||
*case_fall_through = block->id();
|
||||
} else if (*case_fall_through != block->id()) {
|
||||
// Case construct has at most one branch to another case construct.
|
||||
return _.diag(SPV_ERROR_INVALID_CFG)
|
||||
<< "Case construct that targets "
|
||||
<< _.getIdName(target_block->id())
|
||||
<< " has branches to multiple other case construct targets "
|
||||
<< _.getIdName(*case_fall_through) << " and "
|
||||
<< _.getIdName(block->id());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t StructuredSwitchChecks(const ValidationState_t& _,
|
||||
const Function& function,
|
||||
const Instruction* switch_inst,
|
||||
const BasicBlock* header,
|
||||
const BasicBlock* merge) {
|
||||
std::unordered_set<uint32_t> case_targets;
|
||||
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
|
||||
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
|
||||
if (target != merge->id()) case_targets.insert(target);
|
||||
}
|
||||
// Tracks how many times each case construct is targeted by another case
|
||||
// construct.
|
||||
std::map<uint32_t, uint32_t> num_fall_through_targeted;
|
||||
uint32_t default_case_fall_through = 0u;
|
||||
uint32_t default_target = switch_inst->GetOperandAs<uint32_t>(1u);
|
||||
std::unordered_set<uint32_t> seen;
|
||||
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
|
||||
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
|
||||
if (target == merge->id()) continue;
|
||||
|
||||
if (!seen.insert(target).second) continue;
|
||||
|
||||
const auto target_block = function.GetBlock(target).first;
|
||||
// OpSwitch must dominate all its case constructs.
|
||||
if (header->reachable() && target_block->reachable() &&
|
||||
!header->dominates(*target_block)) {
|
||||
return _.diag(SPV_ERROR_INVALID_CFG)
|
||||
<< "Selection header " << _.getIdName(header->id())
|
||||
<< " does not dominate its case construct " << _.getIdName(target);
|
||||
}
|
||||
|
||||
uint32_t case_fall_through = 0u;
|
||||
if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through,
|
||||
merge, case_targets)) {
|
||||
return error;
|
||||
}
|
||||
|
||||
// Track how many time the fall through case has been targeted.
|
||||
if (case_fall_through != 0u) {
|
||||
auto where = num_fall_through_targeted.lower_bound(case_fall_through);
|
||||
if (where == num_fall_through_targeted.end() ||
|
||||
where->first != case_fall_through) {
|
||||
num_fall_through_targeted.insert(where,
|
||||
std::make_pair(case_fall_through, 1));
|
||||
} else {
|
||||
where->second++;
|
||||
}
|
||||
}
|
||||
|
||||
if (case_fall_through == default_target) {
|
||||
case_fall_through = default_case_fall_through;
|
||||
}
|
||||
if (case_fall_through != 0u) {
|
||||
bool is_default = i == 1;
|
||||
if (is_default) {
|
||||
default_case_fall_through = case_fall_through;
|
||||
} else {
|
||||
// Allow code like:
|
||||
// case x:
|
||||
// case y:
|
||||
// ...
|
||||
// case z:
|
||||
//
|
||||
// Where x and y target the same block and fall through to z.
|
||||
uint32_t j = i;
|
||||
while ((j + 2 < switch_inst->operands().size()) &&
|
||||
target == switch_inst->GetOperandAs<uint32_t>(j + 2)) {
|
||||
j += 2;
|
||||
}
|
||||
// If Target T1 branches to Target T2, or if Target T1 branches to the
|
||||
// Default target and the Default target branches to Target T2, then T1
|
||||
// must immediately precede T2 in the list of OpSwitch Target operands.
|
||||
if ((switch_inst->operands().size() < j + 2) ||
|
||||
(case_fall_through != switch_inst->GetOperandAs<uint32_t>(j + 2))) {
|
||||
return _.diag(SPV_ERROR_INVALID_CFG)
|
||||
<< "Case construct that targets " << _.getIdName(target)
|
||||
<< " has branches to the case construct that targets "
|
||||
<< _.getIdName(case_fall_through)
|
||||
<< ", but does not immediately precede it in the "
|
||||
"OpSwitch's "
|
||||
"target list";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Each case construct must be branched to by at most one other case
|
||||
// construct.
|
||||
for (const auto& pair : num_fall_through_targeted) {
|
||||
if (pair.second > 1) {
|
||||
return _.diag(SPV_ERROR_INVALID_CFG)
|
||||
<< "Multiple case constructs have branches to the case construct "
|
||||
"that targets "
|
||||
<< _.getIdName(pair.first);
|
||||
}
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t StructuredControlFlowChecks(
|
||||
const ValidationState_t& _, Function* function,
|
||||
const vector<pair<uint32_t, uint32_t>>& back_edges) {
|
||||
@ -261,16 +408,15 @@ spv_result_t StructuredControlFlowChecks(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(umar): an OpSwitch block dominates all its defined case
|
||||
// constructs
|
||||
// TODO(umar): each case construct has at most one branch to another
|
||||
// case construct
|
||||
// TODO(umar): each case construct is branched to by at most one other
|
||||
// case construct
|
||||
// TODO(umar): if Target T1 branches to Target T2, or if Target T1
|
||||
// branches to the Default and the Default branches to Target T2, then
|
||||
// T1 must immediately precede T2 in the list of the OpSwitch Target
|
||||
// operands
|
||||
// Checks rules for case constructs.
|
||||
if (construct.type() == ConstructType::kSelection &&
|
||||
header->terminator()->opcode() == SpvOpSwitch) {
|
||||
const auto terminator = header->terminator();
|
||||
if (auto error =
|
||||
StructuredSwitchChecks(_, *function, terminator, header, merge)) {
|
||||
return error;
|
||||
}
|
||||
}
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
@ -1450,6 +1450,289 @@ OpFunctionEnd
|
||||
"selection header <ID>"));
|
||||
}
|
||||
|
||||
/// TODO(umar): Switch instructions
|
||||
TEST_F(ValidateCFG, SwitchDefaultOnly) {
|
||||
std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpConstant %2 0
|
||||
%4 = OpTypeFunction %1
|
||||
%5 = OpFunction %1 None %4
|
||||
%6 = OpLabel
|
||||
OpSelectionMerge %7 None
|
||||
OpSwitch %3 %7
|
||||
%7 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(text);
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateCFG, SwitchSingleCase) {
|
||||
std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpConstant %2 0
|
||||
%4 = OpTypeFunction %1
|
||||
%5 = OpFunction %1 None %4
|
||||
%6 = OpLabel
|
||||
OpSelectionMerge %7 None
|
||||
OpSwitch %3 %7 0 %8
|
||||
%8 = OpLabel
|
||||
OpBranch %7
|
||||
%7 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(text);
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateCFG, MultipleFallThroughBlocks) {
|
||||
std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpConstant %2 0
|
||||
%4 = OpTypeFunction %1
|
||||
%5 = OpTypeBool
|
||||
%6 = OpConstantTrue %5
|
||||
%7 = OpFunction %1 None %4
|
||||
%8 = OpLabel
|
||||
OpSelectionMerge %9 None
|
||||
OpSwitch %3 %10 0 %11 1 %12
|
||||
%10 = OpLabel
|
||||
OpBranchConditional %6 %11 %12
|
||||
%11 = OpLabel
|
||||
OpBranch %9
|
||||
%12 = OpLabel
|
||||
OpBranch %9
|
||||
%9 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(text);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Case construct that targets 10 has branches to multiple other case "
|
||||
"construct targets 12 and 11"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateCFG, MultipleFallThroughToDefault) {
|
||||
std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpConstant %2 0
|
||||
%4 = OpTypeFunction %1
|
||||
%5 = OpTypeBool
|
||||
%6 = OpConstantTrue %5
|
||||
%7 = OpFunction %1 None %4
|
||||
%8 = OpLabel
|
||||
OpSelectionMerge %9 None
|
||||
OpSwitch %3 %10 0 %11 1 %12
|
||||
%10 = OpLabel
|
||||
OpBranch %9
|
||||
%11 = OpLabel
|
||||
OpBranch %10
|
||||
%12 = OpLabel
|
||||
OpBranch %10
|
||||
%9 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(text);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("Multiple case constructs have branches to the case construct "
|
||||
"that targets 10"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateCFG, MultipleFallThroughToNonDefault) {
|
||||
std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpConstant %2 0
|
||||
%4 = OpTypeFunction %1
|
||||
%5 = OpTypeBool
|
||||
%6 = OpConstantTrue %5
|
||||
%7 = OpFunction %1 None %4
|
||||
%8 = OpLabel
|
||||
OpSelectionMerge %9 None
|
||||
OpSwitch %3 %10 0 %11 1 %12
|
||||
%10 = OpLabel
|
||||
OpBranch %12
|
||||
%11 = OpLabel
|
||||
OpBranch %12
|
||||
%12 = OpLabel
|
||||
OpBranch %9
|
||||
%9 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(text);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("Multiple case constructs have branches to the case construct "
|
||||
"that targets 12"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateCFG, DuplicateTargetWithFallThrough) {
|
||||
std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpConstant %2 0
|
||||
%4 = OpTypeFunction %1
|
||||
%5 = OpTypeBool
|
||||
%6 = OpConstantTrue %5
|
||||
%7 = OpFunction %1 None %4
|
||||
%8 = OpLabel
|
||||
OpSelectionMerge %9 None
|
||||
OpSwitch %3 %10 0 %10 1 %11
|
||||
%10 = OpLabel
|
||||
OpBranch %11
|
||||
%11 = OpLabel
|
||||
OpBranch %9
|
||||
%9 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(text);
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateCFG, WrongOperandList) {
|
||||
std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpConstant %2 0
|
||||
%4 = OpTypeFunction %1
|
||||
%5 = OpTypeBool
|
||||
%6 = OpConstantTrue %5
|
||||
%7 = OpFunction %1 None %4
|
||||
%8 = OpLabel
|
||||
OpSelectionMerge %9 None
|
||||
OpSwitch %3 %10 0 %11 1 %12
|
||||
%10 = OpLabel
|
||||
OpBranch %9
|
||||
%12 = OpLabel
|
||||
OpBranch %11
|
||||
%11 = OpLabel
|
||||
OpBranch %9
|
||||
%9 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(text);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("Case construct that targets 12 has branches to the case "
|
||||
"construct that targets 11, but does not immediately "
|
||||
"precede it in the OpSwitch's target list"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateCFG, WrongOperandListThroughDefault) {
|
||||
std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpConstant %2 0
|
||||
%4 = OpTypeFunction %1
|
||||
%5 = OpTypeBool
|
||||
%6 = OpConstantTrue %5
|
||||
%7 = OpFunction %1 None %4
|
||||
%8 = OpLabel
|
||||
OpSelectionMerge %9 None
|
||||
OpSwitch %3 %10 0 %11 1 %12
|
||||
%10 = OpLabel
|
||||
OpBranch %11
|
||||
%12 = OpLabel
|
||||
OpBranch %10
|
||||
%11 = OpLabel
|
||||
OpBranch %9
|
||||
%9 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(text);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("Case construct that targets 12 has branches to the case "
|
||||
"construct that targets 11, but does not immediately "
|
||||
"precede it in the OpSwitch's target list"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateCFG, WrongOperandListNotLast) {
|
||||
std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
OpMemoryModel Logical GLSL450
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeInt 32 0
|
||||
%3 = OpConstant %2 0
|
||||
%4 = OpTypeFunction %1
|
||||
%5 = OpTypeBool
|
||||
%6 = OpConstantTrue %5
|
||||
%7 = OpFunction %1 None %4
|
||||
%8 = OpLabel
|
||||
OpSelectionMerge %9 None
|
||||
OpSwitch %3 %10 0 %11 1 %12 2 %13
|
||||
%10 = OpLabel
|
||||
OpBranch %9
|
||||
%12 = OpLabel
|
||||
OpBranch %11
|
||||
%11 = OpLabel
|
||||
OpBranch %9
|
||||
%13 = OpLabel
|
||||
OpBranch %9
|
||||
%9 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(text);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("Case construct that targets 12 has branches to the case "
|
||||
"construct that targets 11, but does not immediately "
|
||||
"precede it in the OpSwitch's target list"));
|
||||
}
|
||||
|
||||
/// TODO(umar): Nested CFG constructs
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user