diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp index 98eb98c24..aaba32490 100644 --- a/source/val/validate_builtins.cpp +++ b/source/val/validate_builtins.cpp @@ -207,6 +207,11 @@ class BuiltInsValidator { const Instruction& referenced_inst, const Instruction& referenced_from_inst); + spv_result_t ValidateInstanceIdAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + spv_result_t ValidateInstanceIndexAtReference( const Decoration& decoration, const Instruction& built_in_inst, const Instruction& referenced_inst, @@ -2098,6 +2103,43 @@ spv_result_t BuiltInsValidator::ValidateVertexIdOrInstanceIdAtDefinition( "to be used."; } + if (label == SpvBuiltInInstanceId) { + return ValidateInstanceIdAtReference(decoration, inst, inst, inst); + } + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateInstanceIdAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + for (const SpvExecutionModel execution_model : execution_models_) { + switch (execution_model) { + case SpvExecutionModelIntersectionNV: + case SpvExecutionModelClosestHitNV: + case SpvExecutionModelAnyHitNV: + // Do nothing, valid stages + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn InstanceId to be used " + "only with IntersectionNV, ClosestHitNV and AnyHitNV " + "execution models. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst); + break; + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateInstanceIdAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + return SPV_SUCCESS; } diff --git a/test/val/val_builtins_test.cpp b/test/val/val_builtins_test.cpp index 9e3798b13..b1458c935 100644 --- a/test/val/val_builtins_test.cpp +++ b/test/val/val_builtins_test.cpp @@ -2209,6 +2209,44 @@ OpFunctionEnd EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); } +TEST_F(ValidateBuiltIns, DisallowInstanceIdWithRayGenShader) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.capabilities_ += R"( +OpCapability RayTracingNV +)"; + + generator.extensions_ = R"( +OpExtension "SPV_NV_ray_tracing" +)"; + + generator.before_types_ = R"( +OpMemberDecorate %input_type 0 BuiltIn InstanceId +)"; + + generator.after_types_ = R"( +%input_type = OpTypeStruct %u32 +%input_ptr = OpTypePointer Input %input_type +%input_ptr_u32 = OpTypePointer Input %u32 +%input = OpVariable %input_ptr Input +)"; + + EntryPoint entry_point; + entry_point.name = "main_d_r"; + entry_point.execution_model = "RayGenerationNV"; + entry_point.interfaces = "%input"; + entry_point.body = R"( +%input_member = OpAccessChain %input_ptr_u32 %input %u32_0 +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vulkan spec allows BuiltIn InstanceId to be used " + "only with IntersectionNV, ClosestHitNV and " + "AnyHitNV execution models")); +} + } // namespace } // namespace val } // namespace spvtools