diff --git a/Android.mk b/Android.mk index a6488f890..2ffa35897 100644 --- a/Android.mk +++ b/Android.mk @@ -51,6 +51,7 @@ SPVTOOLS_SRC_FILES := \ source/val/validate_debug.cpp \ source/val/validate_decorations.cpp \ source/val/validate_derivatives.cpp \ + source/val/validate_extensions.cpp \ source/val/validate_ext_inst.cpp \ source/val/validate_execution_limitations.cpp \ source/val/validate_function.cpp \ diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 47b368fbd..c5f07a314 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -305,6 +305,7 @@ set(SPIRV_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_debug.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_decorations.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_derivatives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_extensions.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ext_inst.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_execution_limitations.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_function.cpp diff --git a/source/val/validate.cpp b/source/val/validate.cpp index faf88a571..257c317bf 100644 --- a/source/val/validate.cpp +++ b/source/val/validate.cpp @@ -328,6 +328,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState( // Miscellaneous if (auto error = DebugPass(*vstate, &instruction)) return error; if (auto error = AnnotationPass(*vstate, &instruction)) return error; + if (auto error = ExtensionPass(*vstate, &instruction)) return error; if (auto error = ExtInstPass(*vstate, &instruction)) return error; if (auto error = ModeSettingPass(*vstate, &instruction)) return error; if (auto error = TypePass(*vstate, &instruction)) return error; diff --git a/source/val/validate.h b/source/val/validate.h index 13ec33550..30228f09d 100644 --- a/source/val/validate.h +++ b/source/val/validate.h @@ -173,6 +173,9 @@ spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of literal numbers. spv_result_t LiteralsPass(ValidationState_t& _, const Instruction* inst); +/// Validates correctness of ExtInstImport instructions. +spv_result_t ExtensionPass(ValidationState_t& _, const Instruction* inst); + /// Validates correctness of ExtInst instructions. spv_result_t ExtInstPass(ValidationState_t& _, const Instruction* inst); diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp new file mode 100644 index 000000000..f3b2305b7 --- /dev/null +++ b/source/val/validate_extensions.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of extension SPIR-V instructions. + +#include "source/val/validate.h" + +#include +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { + +spv_result_t ValidateExtInstImport(ValidationState_t& _, + const Instruction* inst) { + if (spvIsWebGPUEnv(_.context()->target_env)) { + const auto name_id = 1; + const std::string name(reinterpret_cast( + inst->words().data() + inst->operands()[name_id].offset)); + if (name != "GLSL.std.450") { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "For WebGPU, the only valid parameter to OpExtInstImport is " + "\"GLSL.std.450\"."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ExtensionPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + if (opcode == SpvOpExtInstImport) return ValidateExtInstImport(_, inst); + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/test/val/val_webgpu_test.cpp b/test/val/val_webgpu_test.cpp index cc3f90593..ba5919850 100644 --- a/test/val/val_webgpu_test.cpp +++ b/test/val/val_webgpu_test.cpp @@ -219,6 +219,45 @@ TEST_F(ValidateWebGPU, NonVulkanKHRMemoryModelBad) { "environment.\n OpMemoryModel Logical GLSL450\n")); } +TEST_F(ValidateWebGPU, WhitelistedExtendedInstructionsImportGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" +%1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical VulkanKHR + OpEntryPoint Vertex %func "shader" +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%func = OpFunction %void None %void_f +%label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateWebGPU, NonWhitelistedExtendedInstructionsImportBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" +%1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Logical VulkanKHR +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For WebGPU, the only valid parameter to " + "OpExtInstImport is \"GLSL.std.450\".\n %1 = " + "OpExtInstImport \"OpenCL.std\"\n")); +} + } // namespace } // namespace val } // namespace spvtools