Fix switch case construct validation (#5852)

* Fix switch case construct validation

Fixes https://crbug.com/tint/372311599

* Stop using block depth in switch validation and instead use the more
  robust structured exit logic from the switch construct
  * This is valid because the function has already handled the
    additional valid cases for case constructs

* formatting
This commit is contained in:
alan-baker 2024-10-16 11:09:21 -04:00 committed by GitHub
parent 2ea729062b
commit a832c13331
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 104 additions and 13 deletions

View File

@ -468,13 +468,13 @@ std::string ConstructErrorString(const Construct& construct,
// headed by |target_block| branches to multiple case constructs.
spv_result_t FindCaseFallThrough(
ValidationState_t& _, BasicBlock* target_block, uint32_t* case_fall_through,
const BasicBlock* merge, const std::unordered_set<uint32_t>& case_targets,
Function* function) {
const Construct& switch_construct,
const std::unordered_set<uint32_t>& case_targets) {
const auto* merge = switch_construct.exit_block();
std::vector<BasicBlock*> stack;
stack.push_back(target_block);
std::unordered_set<const BasicBlock*> visited;
bool target_reachable = target_block->structurally_reachable();
int target_depth = function->GetBlockDepth(target_block);
while (!stack.empty()) {
auto block = stack.back();
stack.pop_back();
@ -492,9 +492,14 @@ spv_result_t FindCaseFallThrough(
} else {
// Exiting the case construct to non-merge block.
if (!case_targets.count(block->id())) {
int depth = function->GetBlockDepth(block);
if ((depth < target_depth) ||
(depth == target_depth && block->is_type(kBlockTypeContinue))) {
// We have already filtered out the following:
// * The switch's merge
// * Other case targets
// * Blocks in the same case construct
//
// So the only remaining valid branches are the structured exits from
// the overall selection construct of the switch.
if (switch_construct.IsStructuredExit(_, block)) {
continue;
}
@ -526,9 +531,10 @@ spv_result_t FindCaseFallThrough(
}
spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function,
const Instruction* switch_inst,
const BasicBlock* header,
const BasicBlock* merge) {
const Construct& switch_construct) {
const auto* header = switch_construct.entry_block();
const auto* merge = switch_construct.exit_block();
const auto* switch_inst = header->terminator();
std::unordered_set<uint32_t> case_targets;
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
@ -546,6 +552,7 @@ spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function,
break;
}
}
std::unordered_map<uint32_t, uint32_t> seen_to_fall_through;
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
@ -566,7 +573,7 @@ spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function,
}
if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through,
merge, case_targets, function)) {
switch_construct, case_targets)) {
return error;
}
@ -866,9 +873,7 @@ spv_result_t StructuredControlFlowChecks(
// Checks rules for case constructs.
if (construct.type() == ConstructType::kSelection &&
header->terminator()->opcode() == spv::Op::OpSwitch) {
const auto terminator = header->terminator();
if (auto error =
StructuredSwitchChecks(_, function, terminator, header, merge)) {
if (auto error = StructuredSwitchChecks(_, function, construct)) {
return error;
}
}

View File

@ -5155,6 +5155,92 @@ TEST_F(ValidateCFG, StructurallyUnreachableContinuePredecessor) {
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateCFG, FullyLoopPrecedingSwitchToContinue) {
const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
OpName %main "main"
%void = OpTypeVoid
%3 = OpTypeFunction %void
%bool = OpTypeBool
%true = OpConstantTrue %bool
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
%main = OpFunction %void None %3
%4 = OpLabel
OpBranch %7
%7 = OpLabel
OpLoopMerge %8 %6 None
OpBranch %5
%5 = OpLabel
OpSelectionMerge %9 None
OpBranchConditional %true %10 %9
%10 = OpLabel
OpSelectionMerge %16 None
OpSwitch %int_0 %13
%13 = OpLabel
OpBranch %19
%19 = OpLabel
OpLoopMerge %20 %18 None
OpBranch %17
%17 = OpLabel
OpReturn
%18 = OpLabel
OpBranch %19
%20 = OpLabel
OpSelectionMerge %23 None
OpSwitch %int_1 %21
%21 = OpLabel
OpBranch %6
%23 = OpLabel
OpBranch %16
%16 = OpLabel
OpBranch %9
%9 = OpLabel
OpBranch %6
%6 = OpLabel
OpBranch %7
%8 = OpLabel
OpUnreachable
OpFunctionEnd
)";
CompileSuccessfully(text);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateCFG, CaseBreak) {
const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
OpName %main "main"
%void = OpTypeVoid
%3 = OpTypeFunction %void
%bool = OpTypeBool
%true = OpConstantTrue %bool
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
%main = OpFunction %void None %3
%4 = OpLabel
OpSelectionMerge %merge None
OpSwitch %int_1 %case 2 %merge
%case = OpLabel
OpBranch %merge
%merge = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(text);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
} // namespace
} // namespace val
} // namespace spvtools