diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp index e3c766292..afc065676 100644 --- a/source/val/validate_type.cpp +++ b/source/val/validate_type.cpp @@ -190,6 +190,35 @@ spv_result_t ValidateTypeRuntimeArray(ValidationState_t& _, return SPV_SUCCESS; } +bool ContainsOpaqueType(ValidationState_t& _, const Instruction* str) { + const size_t elem_type_index = 1; + uint32_t elem_type_id; + Instruction* elem_type; + + if (spvOpcodeIsBaseOpaqueType(str->opcode())) { + return true; + } + + switch (str->opcode()) { + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + elem_type_id = str->GetOperandAs(elem_type_index); + elem_type = _.FindDef(elem_type_id); + return ContainsOpaqueType(_, elem_type); + case SpvOpTypeStruct: + for (size_t member_type_index = 1; + member_type_index < str->operands().size(); ++member_type_index) { + auto member_type_id = str->GetOperandAs(member_type_index); + auto member_type = _.FindDef(member_type_id); + if (ContainsOpaqueType(_, member_type)) return true; + } + break; + default: + break; + } + return false; +} + spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) { const uint32_t struct_id = inst->GetOperandAs(0); for (size_t member_type_index = 1; @@ -289,6 +318,14 @@ spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) { if (num_builtin_members > 0) { _.RegisterStructTypeWithBuiltInMember(struct_id); } + + if (spvIsVulkanEnv(_.context()->target_env) && + !_.options()->before_hlsl_legalization && ContainsOpaqueType(_, inst)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "In " << spvLogStringForEnv(_.context()->target_env) + << ", OpTypeStruct must not contain an opaque type."; + } + return SPV_SUCCESS; } diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp index 73e2ce114..883949044 100644 --- a/test/val/val_id_test.cpp +++ b/test/val/val_id_test.cpp @@ -937,6 +937,26 @@ TEST_F(ValidateIdWithMessage, OpTypeStructMemberTypeBad) { "a type.")); } +TEST_F(ValidateIdWithMessage, OpTypeStructOpaqueTypeBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + %1 = OpTypeSampler + %2 = OpTypeStruct %1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeStruct must not contain an opaque type")); +} + TEST_F(ValidateIdWithMessage, OpTypePointerGood) { std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0