diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp index 799ba3d54..39f00fedc 100644 --- a/source/val/validate_function.cpp +++ b/source/val/validate_function.cpp @@ -14,6 +14,8 @@ #include "source/val/validate.h" +#include + #include "source/opcode.h" #include "source/val/instruction.h" #include "source/val/validation_state.h" @@ -38,6 +40,27 @@ spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { << "' does not match the Function Type's return type '" << _.getIdName(return_id) << "'."; } + + for (auto& pair : inst->uses()) { + const auto* use = pair.first; + const std::vector acceptable = { + SpvOpFunctionCall, + SpvOpEntryPoint, + SpvOpEnqueueKernel, + SpvOpGetKernelNDrangeSubGroupCount, + SpvOpGetKernelNDrangeMaxSubGroupSize, + SpvOpGetKernelWorkGroupSize, + SpvOpGetKernelPreferredWorkGroupSizeMultiple, + SpvOpGetKernelLocalSizeForSubgroupCount, + SpvOpGetKernelMaxNumSubgroups}; + if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == + acceptable.end()) { + return _.diag(SPV_ERROR_INVALID_ID, use) + << "Invalid use of function result id " << _.getIdName(inst->id()) + << "."; + } + } + return SPV_SUCCESS; } diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp index 14d098dd2..045166925 100644 --- a/test/val/val_cfg_test.cpp +++ b/test/val/val_cfg_test.cpp @@ -477,9 +477,10 @@ TEST_P(ValidateCFG, BranchTargetFirstBlockBadSinceEntryBlock) { TEST_P(ValidateCFG, BranchTargetFirstBlockBadSinceValue) { Block entry("entry"); + entry.SetBody("%undef = OpUndef %voidt\n"); Block bad("bad"); Block end("end", SpvOpReturn); - Block badvalue("func"); // This referenes the function name. + Block badvalue("undef"); // This referenes the OpUndef. std::string str = header(GetParam()) + nameOps("entry", "bad", std::make_pair("func", "Main")) + types_consts() + @@ -493,11 +494,10 @@ TEST_P(ValidateCFG, BranchTargetFirstBlockBadSinceValue) { CompileSuccessfully(str); ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - MatchesRegex("Block\\(s\\) \\{.\\[Main\\]\\} are referenced but not " - "defined in function .\\[Main\\]\n" - " %Main = OpFunction %void None %10\n")) + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("Block\\(s\\) \\{..\\} are referenced but not " + "defined in function .\\[Main\\]\n" + " %Main = OpFunction %void None %10\n")) << str; } diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp index c0d865331..eee503d62 100644 --- a/test/val/val_id_test.cpp +++ b/test/val/val_id_test.cpp @@ -2192,13 +2192,14 @@ TEST_F(ValidateIdWithMessage, OpStoreObjectGood) { %6 = OpVariable %3 UniformConstant %7 = OpFunction %1 None %4 %8 = OpLabel - OpStore %6 %7 +%9 = OpUndef %1 + OpStore %6 %9 OpReturn OpFunctionEnd)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpStore Object '7's type is void.")); + HasSubstr("OpStore Object '9's type is void.")); } TEST_F(ValidateIdWithMessage, OpStoreTypeBad) { std::string spirv = kGLSL450MemoryModel + R"( @@ -3577,6 +3578,22 @@ OpFunctionEnd)"; HasSubstr("OpFunction Function Type '2' is not a function type.")); } +TEST_F(ValidateIdWithMessage, OpFunctionUseBad) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel +OpReturnValue %3 +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid use of function result id 3.")); +} + TEST_F(ValidateIdWithMessage, OpFunctionParameterGood) { std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid