mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-18 11:10:05 +00:00
Fix switch case construct validation (#5852)
* Fix switch case construct validation Fixes https://crbug.com/tint/372311599 * Stop using block depth in switch validation and instead use the more robust structured exit logic from the switch construct * This is valid because the function has already handled the additional valid cases for case constructs * formatting
This commit is contained in:
parent
2ea729062b
commit
a832c13331
@ -468,13 +468,13 @@ std::string ConstructErrorString(const Construct& construct,
|
|||||||
// headed by |target_block| branches to multiple case constructs.
|
// headed by |target_block| branches to multiple case constructs.
|
||||||
spv_result_t FindCaseFallThrough(
|
spv_result_t FindCaseFallThrough(
|
||||||
ValidationState_t& _, BasicBlock* target_block, uint32_t* case_fall_through,
|
ValidationState_t& _, BasicBlock* target_block, uint32_t* case_fall_through,
|
||||||
const BasicBlock* merge, const std::unordered_set<uint32_t>& case_targets,
|
const Construct& switch_construct,
|
||||||
Function* function) {
|
const std::unordered_set<uint32_t>& case_targets) {
|
||||||
|
const auto* merge = switch_construct.exit_block();
|
||||||
std::vector<BasicBlock*> stack;
|
std::vector<BasicBlock*> stack;
|
||||||
stack.push_back(target_block);
|
stack.push_back(target_block);
|
||||||
std::unordered_set<const BasicBlock*> visited;
|
std::unordered_set<const BasicBlock*> visited;
|
||||||
bool target_reachable = target_block->structurally_reachable();
|
bool target_reachable = target_block->structurally_reachable();
|
||||||
int target_depth = function->GetBlockDepth(target_block);
|
|
||||||
while (!stack.empty()) {
|
while (!stack.empty()) {
|
||||||
auto block = stack.back();
|
auto block = stack.back();
|
||||||
stack.pop_back();
|
stack.pop_back();
|
||||||
@ -492,9 +492,14 @@ spv_result_t FindCaseFallThrough(
|
|||||||
} else {
|
} else {
|
||||||
// Exiting the case construct to non-merge block.
|
// Exiting the case construct to non-merge block.
|
||||||
if (!case_targets.count(block->id())) {
|
if (!case_targets.count(block->id())) {
|
||||||
int depth = function->GetBlockDepth(block);
|
// We have already filtered out the following:
|
||||||
if ((depth < target_depth) ||
|
// * The switch's merge
|
||||||
(depth == target_depth && block->is_type(kBlockTypeContinue))) {
|
// * Other case targets
|
||||||
|
// * Blocks in the same case construct
|
||||||
|
//
|
||||||
|
// So the only remaining valid branches are the structured exits from
|
||||||
|
// the overall selection construct of the switch.
|
||||||
|
if (switch_construct.IsStructuredExit(_, block)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -526,9 +531,10 @@ spv_result_t FindCaseFallThrough(
|
|||||||
}
|
}
|
||||||
|
|
||||||
spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function,
|
spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function,
|
||||||
const Instruction* switch_inst,
|
const Construct& switch_construct) {
|
||||||
const BasicBlock* header,
|
const auto* header = switch_construct.entry_block();
|
||||||
const BasicBlock* merge) {
|
const auto* merge = switch_construct.exit_block();
|
||||||
|
const auto* switch_inst = header->terminator();
|
||||||
std::unordered_set<uint32_t> case_targets;
|
std::unordered_set<uint32_t> case_targets;
|
||||||
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
|
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
|
||||||
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
|
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
|
||||||
@ -546,6 +552,7 @@ spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<uint32_t, uint32_t> seen_to_fall_through;
|
std::unordered_map<uint32_t, uint32_t> seen_to_fall_through;
|
||||||
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
|
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
|
||||||
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
|
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
|
||||||
@ -566,7 +573,7 @@ spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through,
|
if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through,
|
||||||
merge, case_targets, function)) {
|
switch_construct, case_targets)) {
|
||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -866,9 +873,7 @@ spv_result_t StructuredControlFlowChecks(
|
|||||||
// Checks rules for case constructs.
|
// Checks rules for case constructs.
|
||||||
if (construct.type() == ConstructType::kSelection &&
|
if (construct.type() == ConstructType::kSelection &&
|
||||||
header->terminator()->opcode() == spv::Op::OpSwitch) {
|
header->terminator()->opcode() == spv::Op::OpSwitch) {
|
||||||
const auto terminator = header->terminator();
|
if (auto error = StructuredSwitchChecks(_, function, construct)) {
|
||||||
if (auto error =
|
|
||||||
StructuredSwitchChecks(_, function, terminator, header, merge)) {
|
|
||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5155,6 +5155,92 @@ TEST_F(ValidateCFG, StructurallyUnreachableContinuePredecessor) {
|
|||||||
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
|
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateCFG, FullyLoopPrecedingSwitchToContinue) {
|
||||||
|
const std::string text = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint Fragment %main "main"
|
||||||
|
OpExecutionMode %main OriginUpperLeft
|
||||||
|
OpName %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%3 = OpTypeFunction %void
|
||||||
|
%bool = OpTypeBool
|
||||||
|
%true = OpConstantTrue %bool
|
||||||
|
%int = OpTypeInt 32 1
|
||||||
|
%int_0 = OpConstant %int 0
|
||||||
|
%int_1 = OpConstant %int 1
|
||||||
|
%main = OpFunction %void None %3
|
||||||
|
%4 = OpLabel
|
||||||
|
OpBranch %7
|
||||||
|
%7 = OpLabel
|
||||||
|
OpLoopMerge %8 %6 None
|
||||||
|
OpBranch %5
|
||||||
|
%5 = OpLabel
|
||||||
|
OpSelectionMerge %9 None
|
||||||
|
OpBranchConditional %true %10 %9
|
||||||
|
%10 = OpLabel
|
||||||
|
OpSelectionMerge %16 None
|
||||||
|
OpSwitch %int_0 %13
|
||||||
|
%13 = OpLabel
|
||||||
|
OpBranch %19
|
||||||
|
%19 = OpLabel
|
||||||
|
OpLoopMerge %20 %18 None
|
||||||
|
OpBranch %17
|
||||||
|
%17 = OpLabel
|
||||||
|
OpReturn
|
||||||
|
%18 = OpLabel
|
||||||
|
OpBranch %19
|
||||||
|
%20 = OpLabel
|
||||||
|
OpSelectionMerge %23 None
|
||||||
|
OpSwitch %int_1 %21
|
||||||
|
%21 = OpLabel
|
||||||
|
OpBranch %6
|
||||||
|
%23 = OpLabel
|
||||||
|
OpBranch %16
|
||||||
|
%16 = OpLabel
|
||||||
|
OpBranch %9
|
||||||
|
%9 = OpLabel
|
||||||
|
OpBranch %6
|
||||||
|
%6 = OpLabel
|
||||||
|
OpBranch %7
|
||||||
|
%8 = OpLabel
|
||||||
|
OpUnreachable
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
|
||||||
|
CompileSuccessfully(text);
|
||||||
|
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateCFG, CaseBreak) {
|
||||||
|
const std::string text = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint Fragment %main "main"
|
||||||
|
OpExecutionMode %main OriginUpperLeft
|
||||||
|
OpName %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%3 = OpTypeFunction %void
|
||||||
|
%bool = OpTypeBool
|
||||||
|
%true = OpConstantTrue %bool
|
||||||
|
%int = OpTypeInt 32 1
|
||||||
|
%int_0 = OpConstant %int 0
|
||||||
|
%int_1 = OpConstant %int 1
|
||||||
|
%main = OpFunction %void None %3
|
||||||
|
%4 = OpLabel
|
||||||
|
OpSelectionMerge %merge None
|
||||||
|
OpSwitch %int_1 %case 2 %merge
|
||||||
|
%case = OpLabel
|
||||||
|
OpBranch %merge
|
||||||
|
%merge = OpLabel
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
|
||||||
|
CompileSuccessfully(text);
|
||||||
|
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace val
|
} // namespace val
|
||||||
} // namespace spvtools
|
} // namespace spvtools
|
||||||
|
Loading…
Reference in New Issue
Block a user