diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp index f74bd1257..ad106a6f8 100644 --- a/source/val/validate_builtins.cpp +++ b/source/val/validate_builtins.cpp @@ -168,7 +168,7 @@ class BuiltInsValidator { spv_result_t ValidateVertexIndexAtDefinition(const Decoration& decoration, const Instruction& inst); spv_result_t ValidateVertexIdOrInstanceIdAtDefinition( - const Instruction& inst); + const Decoration& decoration, const Instruction& inst); spv_result_t ValidateWorkgroupSizeAtDefinition(const Decoration& decoration, const Instruction& inst); // Used for GlobalInvocationId, LocalInvocationId, NumWorkgroups, WorkgroupId. @@ -2088,8 +2088,11 @@ spv_result_t BuiltInsValidator::ValidateVertexIndexAtDefinition( } spv_result_t BuiltInsValidator::ValidateVertexIdOrInstanceIdAtDefinition( - const Instruction& inst) { - if (spvIsVulkanEnv(_.context()->target_env)) { + const Decoration& decoration, const Instruction& inst) { + const SpvBuiltIn label = SpvBuiltIn(decoration.params()[0]); + bool allow_instance_id = _.HasCapability(SpvCapabilityRayTracingNV) && + label == SpvBuiltInInstanceId; + if (spvIsVulkanEnv(_.context()->target_env) && !allow_instance_id) { return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "Vulkan spec doesn't allow BuiltIn VertexId/InstanceId " "to be used."; @@ -2455,7 +2458,7 @@ spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinition( } case SpvBuiltInVertexId: case SpvBuiltInInstanceId: { - return ValidateVertexIdOrInstanceIdAtDefinition(inst); + return ValidateVertexIdOrInstanceIdAtDefinition(decoration, inst); } case SpvBuiltInLocalInvocationIndex: case SpvBuiltInWorkDim: diff --git a/test/val/val_builtins_test.cpp b/test/val/val_builtins_test.cpp index 1ad53ba42..eff21f9bd 100644 --- a/test/val/val_builtins_test.cpp +++ b/test/val/val_builtins_test.cpp @@ -2169,6 +2169,46 @@ OpFunctionEnd "be declared when using BuiltIn FragDepth")); } +TEST_F(ValidateBuiltIns, AllowInstanceIdWithIntersectionShader) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.capabilities_ += R"( +OpCapability RayTracingNV +)"; + + generator.extensions_ = R"( +OpExtension "SPV_NV_ray_tracing" +)"; + + generator.before_types_ = R"( +OpMemberDecorate %input_type 0 BuiltIn InstanceId +)"; + + generator.after_types_ = R"( +%input_type = OpTypeStruct %u32 +%input_ptr = OpTypePointer Input %input_type +%input = OpVariable %input_ptr Input +)"; + + EntryPoint entry_point; + entry_point.name = "main_d_r"; + entry_point.execution_model = "IntersectionNV"; + entry_point.interfaces = "%input"; + entry_point.body = R"( +%val2 = OpFunctionCall %void %foo +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + generator.add_at_the_end_ = R"( +%foo = OpFunction %void None %func +%foo_entry = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + } // namespace } // namespace val } // namespace spvtools