From 68e36ec7e9da5300b5d183be543162928aee5776 Mon Sep 17 00:00:00 2001 From: Ehsan Nasiri Date: Wed, 11 Jan 2017 15:03:53 -0500 Subject: [PATCH] Introducing a new flow for running the Validator. We are adding a new API which can be called to run the SPIR-V validator, and retrieve the ValidationState_t object. This is very useful for unit testing. I have also added basic unit tests that demonstrate usage of this flow and ease of use to verify correctness. --- source/validate.cpp | 74 +++++++++++------ source/validate.h | 11 +++ test/val/CMakeLists.txt | 6 ++ test/val/val_fixtures.cpp | 8 ++ test/val/val_fixtures.h | 7 ++ test/val/val_validation_state_test.cpp | 106 +++++++++++++++++++++++++ 6 files changed, 188 insertions(+), 24 deletions(-) create mode 100644 test/val/val_validation_state_test.cpp diff --git a/source/validate.cpp b/source/validate.cpp index 45ae5a457..fe6c4c4ac 100644 --- a/source/validate.cpp +++ b/source/validate.cpp @@ -182,64 +182,58 @@ spv_result_t spvValidate(const spv_const_context context, return spvValidateBinary(context, binary->code, binary->wordCount, pDiagnostic); } -spv_result_t spvValidateBinary(const spv_const_context context, - const uint32_t* words, const size_t num_words, - spv_diagnostic* pDiagnostic) { - spv_context_t hijack_context = *context; +spv_result_t ValidateBinaryUsingContextAndValidationState( + const spv_context_t& context, const uint32_t* words, const size_t num_words, + spv_diagnostic* pDiagnostic, ValidationState_t* vstate) { spv_const_binary binary = new spv_const_binary_t{words, num_words}; - if (pDiagnostic) { - *pDiagnostic = nullptr; - libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); - } spv_endianness_t endian; spv_position_t position = {}; if (spvBinaryEndianness(binary, &endian)) { - return libspirv::DiagnosticStream(position, hijack_context.consumer, + return libspirv::DiagnosticStream(position, context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V magic number."; } spv_header_t header; if (spvBinaryHeaderGet(binary, endian, &header)) { - return libspirv::DiagnosticStream(position, hijack_context.consumer, + return libspirv::DiagnosticStream(position, context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V header."; } // NOTE: Parse the module and perform inline validation checks. These // checks do not require the the knowledge of the whole module. - ValidationState_t vstate(&hijack_context); - if (auto error = spvBinaryParse(&hijack_context, &vstate, words, num_words, + if (auto error = spvBinaryParse(&context, vstate, words, num_words, setHeader, ProcessInstruction, pDiagnostic)) return error; - if (vstate.in_function_body()) - return vstate.diag(SPV_ERROR_INVALID_LAYOUT) + if (vstate->in_function_body()) + return vstate->diag(SPV_ERROR_INVALID_LAYOUT) << "Missing OpFunctionEnd at end of module."; // TODO(umar): Add validation checks which require the parsing of the entire // module. Use the information from the ProcessInstruction pass to make the // checks. - if (vstate.unresolved_forward_id_count() > 0) { + if (vstate->unresolved_forward_id_count() > 0) { stringstream ss; - vector ids = vstate.UnresolvedForwardIds(); + vector ids = vstate->UnresolvedForwardIds(); transform(begin(ids), end(ids), ostream_iterator(ss, " "), - bind(&ValidationState_t::getIdName, std::ref(vstate), _1)); + bind(&ValidationState_t::getIdName, std::ref(*vstate), _1)); auto id_str = ss.str(); - return vstate.diag(SPV_ERROR_INVALID_ID) + return vstate->diag(SPV_ERROR_INVALID_ID) << "The following forward referenced IDs have not been defined:\n" << id_str.substr(0, id_str.size() - 1); } // CFG checks are performed after the binary has been parsed // and the CFGPass has collected information about the control flow - if (auto error = PerformCfgChecks(vstate)) return error; - if (auto error = UpdateIdUse(vstate)) return error; - if (auto error = CheckIdDefinitionDominateUse(vstate)) return error; + if (auto error = PerformCfgChecks(*vstate)) return error; + if (auto error = UpdateIdUse(*vstate)) return error; + if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error; // NOTE: Copy each instruction for easier processing std::vector instructions; @@ -258,7 +252,39 @@ spv_result_t spvValidateBinary(const spv_const_context context, position.index = SPV_INDEX_INSTRUCTION; return spvValidateIDs(instructions.data(), instructions.size(), - hijack_context.opcode_table, - hijack_context.operand_table, - hijack_context.ext_inst_table, vstate, &position); + context.opcode_table, + context.operand_table, + context.ext_inst_table, *vstate, &position); } + +spv_result_t spvValidateBinary(const spv_const_context context, + const uint32_t* words, const size_t num_words, + spv_diagnostic* pDiagnostic) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + // Create the ValidationState using the context. + ValidationState_t vstate(&hijack_context); + + return ValidateBinaryUsingContextAndValidationState( + hijack_context, words, num_words, pDiagnostic, &vstate); +} + +spv_result_t spvtools::ValidateBinaryAndKeepValidationState( + const spv_const_context context, const uint32_t* words, + const size_t num_words, spv_diagnostic* pDiagnostic, + std::unique_ptr* vstate) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + vstate->reset(new ValidationState_t(&hijack_context)); + return ValidateBinaryUsingContextAndValidationState( + hijack_context, words, num_words, pDiagnostic, vstate->get()); +} + diff --git a/source/validate.h b/source/validate.h index 258b0ebbb..3237e6cc5 100644 --- a/source/validate.h +++ b/source/validate.h @@ -190,4 +190,15 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions, spv_position position, const spvtools::MessageConsumer& consumer); +namespace spvtools { +// Performs validation for the SPIRV-V module binary. +// The main difference between this API and spvValidateBinary is that the +// "Validation State" is not destroyed upon function return; it lives on and is +// pointed to by the vstate unique_ptr. +spv_result_t ValidateBinaryAndKeepValidationState( + const spv_const_context context, const uint32_t* words, + const size_t num_words, spv_diagnostic* pDiagnostic, + std::unique_ptr* vstate); +} // namespace spvtools + #endif // LIBSPIRV_VALIDATE_H_ diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt index 5ac130256..1e4becdd6 100644 --- a/test/val/CMakeLists.txt +++ b/test/val/CMakeLists.txt @@ -75,3 +75,9 @@ add_spvtools_unittest(TARGET val_limits LIBS ${SPIRV_TOOLS} ) +add_spvtools_unittest(TARGET val_validation_state + SRCS val_validation_state_test.cpp + ${VAL_TEST_COMMON_SRCS} + LIBS ${SPIRV_TOOLS} +) + diff --git a/test/val/val_fixtures.cpp b/test/val/val_fixtures.cpp index 2db99a89f..b0174335b 100644 --- a/test/val/val_fixtures.cpp +++ b/test/val/val_fixtures.cpp @@ -67,6 +67,14 @@ spv_result_t ValidateBase::ValidateInstructions(spv_target_env env) { &diagnostic_); } +template +spv_result_t ValidateBase::ValidateAndRetrieveValidationState( + spv_target_env env) { + return spvtools::ValidateBinaryAndKeepValidationState( + ScopedContext(env).context, get_const_binary()->code, + get_const_binary()->wordCount, &diagnostic_, &vstate_); +} + template std::string ValidateBase::getDiagnosticString() { return std::string(diagnostic_->error); diff --git a/test/val/val_fixtures.h b/test/val/val_fixtures.h index bb4fe1871..1d947058a 100644 --- a/test/val/val_fixtures.h +++ b/test/val/val_fixtures.h @@ -18,6 +18,7 @@ #define LIBSPIRV_TEST_VALIDATE_FIXTURES_H_ #include "unit_spirv.h" +#include "source/val/validation_state.h" namespace spvtest { @@ -45,11 +46,17 @@ class ValidateBase : public ::testing::Test, // spvValidate function spv_result_t ValidateInstructions(spv_target_env env = SPV_ENV_UNIVERSAL_1_0); + // Performs validation. Returns the status and stores validation state into + // the vstate_ member. + spv_result_t ValidateAndRetrieveValidationState( + spv_target_env env = SPV_ENV_UNIVERSAL_1_0); + std::string getDiagnosticString(); spv_position_t getErrorPosition(); spv_binary binary_; spv_diagnostic diagnostic_; + std::unique_ptr vstate_; }; } #endif diff --git a/test/val/val_validation_state_test.cpp b/test/val/val_validation_state_test.cpp new file mode 100644 index 000000000..5eb09f774 --- /dev/null +++ b/test/val/val_validation_state_test.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Basic tests for the ValidationState_t datastructure. + +#include + +#include "gmock/gmock.h" +#include "unit_spirv.h" +#include "val_fixtures.h" + +namespace { + +using std::string; +using ::testing::HasSubstr; + +using ValidationStateTest = spvtest::ValidateBase; + +const char header[] = + " OpCapability Shader" + " OpCapability Linkage" + " OpMemoryModel Logical GLSL450 "; + +const char kVoidFVoid[] = + " %void = OpTypeVoid" + " %void_f = OpTypeFunction %void" + " %func = OpFunction %void None %void_f" + " %label = OpLabel" + " OpReturn" + " OpFunctionEnd "; + +// Tests that the instruction count in ValidationState is correct. +TEST_F(ValidationStateTest, CheckNumInstructions) { + string spirv = string(header) + "%int = OpTypeInt 32 0"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size()); +} + +// Tests that the number of global variables in ValidationState is correct. +TEST_F(ValidationStateTest, CheckNumGlobalVars) { + string spirv = string(header) + R"( + %int = OpTypeInt 32 0 +%_ptr_int = OpTypePointer Input %int + %var_1 = OpVariable %_ptr_int Input + %var_2 = OpVariable %_ptr_int Input + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(unsigned(2), vstate_->num_global_vars()); +} + +// Tests that the number of local variables in ValidationState is correct. +TEST_F(ValidationStateTest, CheckNumLocalVars) { + string spirv = string(header) + R"( + %int = OpTypeInt 32 0 + %_ptr_int = OpTypePointer Function %int + %voidt = OpTypeVoid + %funct = OpTypeFunction %voidt + %main = OpFunction %voidt None %funct + %entry = OpLabel + %var_1 = OpVariable %_ptr_int Function + %var_2 = OpVariable %_ptr_int Function + %var_3 = OpVariable %_ptr_int Function + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(unsigned(3), vstate_->num_local_vars()); +} + +// Tests that the "id bound" in ValidationState is correct. +TEST_F(ValidationStateTest, CheckIdBound) { + string spirv = string(header) + R"( + %int = OpTypeInt 32 0 + %voidt = OpTypeVoid + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(unsigned(3), vstate_->getIdBound()); +} + +// Tests that the entry_points in ValidationState is correct. +TEST_F(ValidationStateTest, CheckEntryPoints) { + string spirv = string(header) + " OpEntryPoint Vertex %func \"shader\"" + + string(kVoidFVoid); + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(size_t(1), vstate_->entry_points().size()); + EXPECT_EQ(SpvOpFunction, + vstate_->FindDef(vstate_->entry_points()[0])->opcode()); +} + +} // anonymous namespace