diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp index 8fe30a881..fe79dde1e 100644 --- a/source/val/validate_cfg.cpp +++ b/source/val/validate_cfg.cpp @@ -112,6 +112,19 @@ spv_result_t ValidatePhi(ValidationState_t& _, const Instruction* inst) { return SPV_SUCCESS; } +spv_result_t ValidateBranch(ValidationState_t& _, const Instruction* inst) { + // target operands must be OpLabel + const auto id = inst->GetOperandAs(0); + const auto target = _.FindDef(id); + if (!target || SpvOpLabel != target->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "'Target Label' operands for OpBranch must be the ID " + "of an OpLabel instruction"; + } + + return SPV_SUCCESS; +} + spv_result_t ValidateBranchConditional(ValidationState_t& _, const Instruction* inst) { // num_operands is either 3 or 5 --- if 5, the last two need to be literal @@ -155,6 +168,26 @@ spv_result_t ValidateBranchConditional(ValidationState_t& _, return SPV_SUCCESS; } +spv_result_t ValidateSwitch(ValidationState_t& _, const Instruction* inst) { + const auto num_operands = inst->operands().size(); + // At least two operands (selector, default), any more than that are + // literal/target. + + // target operands must be OpLabel + for (size_t i = 2; i < num_operands; i += 2) { + // literal, id + const auto id = inst->GetOperandAs(i + 1); + const auto target = _.FindDef(id); + if (!target || SpvOpLabel != target->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "'Target Label' operands for OpSwitch must be IDs of an " + "OpLabel instruction"; + } + } + + return SPV_SUCCESS; +} + spv_result_t ValidateReturnValue(ValidationState_t& _, const Instruction* inst) { const auto value_id = inst->GetOperandAs(0); @@ -764,12 +797,18 @@ spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) { case SpvOpPhi: if (auto error = ValidatePhi(_, inst)) return error; break; + case SpvOpBranch: + if (auto error = ValidateBranch(_, inst)) return error; + break; case SpvOpBranchConditional: if (auto error = ValidateBranchConditional(_, inst)) return error; break; case SpvOpReturnValue: if (auto error = ValidateReturnValue(_, inst)) return error; break; + case SpvOpSwitch: + if (auto error = ValidateSwitch(_, inst)) return error; + break; default: break; } diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp index aed0a5788..fae570278 100644 --- a/test/val/val_cfg_test.cpp +++ b/test/val/val_cfg_test.cpp @@ -493,13 +493,10 @@ TEST_P(ValidateCFG, BranchTargetFirstBlockBadSinceValue) { str += "OpFunctionEnd\n"; CompileSuccessfully(str); - ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - MatchesRegex("Block\\(s\\) \\{11\\[%11\\]\\} are referenced but not " - "defined in function .\\[%Main\\]\n %Main = OpFunction " - "%void None %10\n")) - << str; + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("'Target Label' operands for OpBranch must " + "be the ID of an OpLabel instruction")); } TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) { @@ -2060,6 +2057,57 @@ TEST_F(ValidateCFG, KernelWithPhiPtr) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateCFG, SwitchTargetMustBeLabel) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "foo" + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void + %1 = OpFunction %void None %5 + %6 = OpLabel + %7 = OpCopyObject %uint %uint_0 + OpSelectionMerge %8 None + OpSwitch %uint_0 %8 0 %7 + %8 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("'Target Label' operands for OpSwitch must " + "be IDs of an OpLabel instruction")); +} + +TEST_F(ValidateCFG, BranchTargetMustBeLabel) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "foo" + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void + %1 = OpFunction %void None %5 + %2 = OpLabel + %7 = OpCopyObject %uint %uint_0 + OpBranch %7 + %8 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("'Target Label' operands for OpBranch must " + "be the ID of an OpLabel instruction")); +} + /// TODO(umar): Nested CFG constructs } // namespace