Introducing a new flow for running the Validator.

We are adding a new API which can be called to run the SPIR-V validator,
and retrieve the ValidationState_t object. This is very useful for
unit testing.

I have also added basic unit tests that demonstrate usage of this flow
and ease of use to verify correctness.
This commit is contained in:
Ehsan Nasiri 2017-01-11 15:03:53 -05:00 committed by David Neto
parent d5e4f06eec
commit 68e36ec7e9
6 changed files with 188 additions and 24 deletions

View File

@ -182,64 +182,58 @@ spv_result_t spvValidate(const spv_const_context context,
return spvValidateBinary(context, binary->code, binary->wordCount,
pDiagnostic);
}
spv_result_t spvValidateBinary(const spv_const_context context,
const uint32_t* words, const size_t num_words,
spv_diagnostic* pDiagnostic) {
spv_context_t hijack_context = *context;
spv_result_t ValidateBinaryUsingContextAndValidationState(
const spv_context_t& context, const uint32_t* words, const size_t num_words,
spv_diagnostic* pDiagnostic, ValidationState_t* vstate) {
spv_const_binary binary = new spv_const_binary_t{words, num_words};
if (pDiagnostic) {
*pDiagnostic = nullptr;
libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
}
spv_endianness_t endian;
spv_position_t position = {};
if (spvBinaryEndianness(binary, &endian)) {
return libspirv::DiagnosticStream(position, hijack_context.consumer,
return libspirv::DiagnosticStream(position, context.consumer,
SPV_ERROR_INVALID_BINARY)
<< "Invalid SPIR-V magic number.";
}
spv_header_t header;
if (spvBinaryHeaderGet(binary, endian, &header)) {
return libspirv::DiagnosticStream(position, hijack_context.consumer,
return libspirv::DiagnosticStream(position, context.consumer,
SPV_ERROR_INVALID_BINARY)
<< "Invalid SPIR-V header.";
}
// NOTE: Parse the module and perform inline validation checks. These
// checks do not require the the knowledge of the whole module.
ValidationState_t vstate(&hijack_context);
if (auto error = spvBinaryParse(&hijack_context, &vstate, words, num_words,
if (auto error = spvBinaryParse(&context, vstate, words, num_words,
setHeader, ProcessInstruction, pDiagnostic))
return error;
if (vstate.in_function_body())
return vstate.diag(SPV_ERROR_INVALID_LAYOUT)
if (vstate->in_function_body())
return vstate->diag(SPV_ERROR_INVALID_LAYOUT)
<< "Missing OpFunctionEnd at end of module.";
// TODO(umar): Add validation checks which require the parsing of the entire
// module. Use the information from the ProcessInstruction pass to make the
// checks.
if (vstate.unresolved_forward_id_count() > 0) {
if (vstate->unresolved_forward_id_count() > 0) {
stringstream ss;
vector<uint32_t> ids = vstate.UnresolvedForwardIds();
vector<uint32_t> ids = vstate->UnresolvedForwardIds();
transform(begin(ids), end(ids), ostream_iterator<string>(ss, " "),
bind(&ValidationState_t::getIdName, std::ref(vstate), _1));
bind(&ValidationState_t::getIdName, std::ref(*vstate), _1));
auto id_str = ss.str();
return vstate.diag(SPV_ERROR_INVALID_ID)
return vstate->diag(SPV_ERROR_INVALID_ID)
<< "The following forward referenced IDs have not been defined:\n"
<< id_str.substr(0, id_str.size() - 1);
}
// CFG checks are performed after the binary has been parsed
// and the CFGPass has collected information about the control flow
if (auto error = PerformCfgChecks(vstate)) return error;
if (auto error = UpdateIdUse(vstate)) return error;
if (auto error = CheckIdDefinitionDominateUse(vstate)) return error;
if (auto error = PerformCfgChecks(*vstate)) return error;
if (auto error = UpdateIdUse(*vstate)) return error;
if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error;
// NOTE: Copy each instruction for easier processing
std::vector<spv_instruction_t> instructions;
@ -258,7 +252,39 @@ spv_result_t spvValidateBinary(const spv_const_context context,
position.index = SPV_INDEX_INSTRUCTION;
return spvValidateIDs(instructions.data(), instructions.size(),
hijack_context.opcode_table,
hijack_context.operand_table,
hijack_context.ext_inst_table, vstate, &position);
context.opcode_table,
context.operand_table,
context.ext_inst_table, *vstate, &position);
}
spv_result_t spvValidateBinary(const spv_const_context context,
const uint32_t* words, const size_t num_words,
spv_diagnostic* pDiagnostic) {
spv_context_t hijack_context = *context;
if (pDiagnostic) {
*pDiagnostic = nullptr;
libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
}
// Create the ValidationState using the context.
ValidationState_t vstate(&hijack_context);
return ValidateBinaryUsingContextAndValidationState(
hijack_context, words, num_words, pDiagnostic, &vstate);
}
spv_result_t spvtools::ValidateBinaryAndKeepValidationState(
const spv_const_context context, const uint32_t* words,
const size_t num_words, spv_diagnostic* pDiagnostic,
std::unique_ptr<ValidationState_t>* vstate) {
spv_context_t hijack_context = *context;
if (pDiagnostic) {
*pDiagnostic = nullptr;
libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
}
vstate->reset(new ValidationState_t(&hijack_context));
return ValidateBinaryUsingContextAndValidationState(
hijack_context, words, num_words, pDiagnostic, vstate->get());
}

View File

@ -190,4 +190,15 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions,
spv_position position,
const spvtools::MessageConsumer& consumer);
namespace spvtools {
// Performs validation for the SPIRV-V module binary.
// The main difference between this API and spvValidateBinary is that the
// "Validation State" is not destroyed upon function return; it lives on and is
// pointed to by the vstate unique_ptr.
spv_result_t ValidateBinaryAndKeepValidationState(
const spv_const_context context, const uint32_t* words,
const size_t num_words, spv_diagnostic* pDiagnostic,
std::unique_ptr<libspirv::ValidationState_t>* vstate);
} // namespace spvtools
#endif // LIBSPIRV_VALIDATE_H_

View File

@ -75,3 +75,9 @@ add_spvtools_unittest(TARGET val_limits
LIBS ${SPIRV_TOOLS}
)
add_spvtools_unittest(TARGET val_validation_state
SRCS val_validation_state_test.cpp
${VAL_TEST_COMMON_SRCS}
LIBS ${SPIRV_TOOLS}
)

View File

@ -67,6 +67,14 @@ spv_result_t ValidateBase<T>::ValidateInstructions(spv_target_env env) {
&diagnostic_);
}
template <typename T>
spv_result_t ValidateBase<T>::ValidateAndRetrieveValidationState(
spv_target_env env) {
return spvtools::ValidateBinaryAndKeepValidationState(
ScopedContext(env).context, get_const_binary()->code,
get_const_binary()->wordCount, &diagnostic_, &vstate_);
}
template <typename T>
std::string ValidateBase<T>::getDiagnosticString() {
return std::string(diagnostic_->error);

View File

@ -18,6 +18,7 @@
#define LIBSPIRV_TEST_VALIDATE_FIXTURES_H_
#include "unit_spirv.h"
#include "source/val/validation_state.h"
namespace spvtest {
@ -45,11 +46,17 @@ class ValidateBase : public ::testing::Test,
// spvValidate function
spv_result_t ValidateInstructions(spv_target_env env = SPV_ENV_UNIVERSAL_1_0);
// Performs validation. Returns the status and stores validation state into
// the vstate_ member.
spv_result_t ValidateAndRetrieveValidationState(
spv_target_env env = SPV_ENV_UNIVERSAL_1_0);
std::string getDiagnosticString();
spv_position_t getErrorPosition();
spv_binary binary_;
spv_diagnostic diagnostic_;
std::unique_ptr<libspirv::ValidationState_t> vstate_;
};
}
#endif

View File

@ -0,0 +1,106 @@
// Copyright (c) 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Basic tests for the ValidationState_t datastructure.
#include <string>
#include "gmock/gmock.h"
#include "unit_spirv.h"
#include "val_fixtures.h"
namespace {
using std::string;
using ::testing::HasSubstr;
using ValidationStateTest = spvtest::ValidateBase<bool>;
const char header[] =
" OpCapability Shader"
" OpCapability Linkage"
" OpMemoryModel Logical GLSL450 ";
const char kVoidFVoid[] =
" %void = OpTypeVoid"
" %void_f = OpTypeFunction %void"
" %func = OpFunction %void None %void_f"
" %label = OpLabel"
" OpReturn"
" OpFunctionEnd ";
// Tests that the instruction count in ValidationState is correct.
TEST_F(ValidationStateTest, CheckNumInstructions) {
string spirv = string(header) + "%int = OpTypeInt 32 0";
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size());
}
// Tests that the number of global variables in ValidationState is correct.
TEST_F(ValidationStateTest, CheckNumGlobalVars) {
string spirv = string(header) + R"(
%int = OpTypeInt 32 0
%_ptr_int = OpTypePointer Input %int
%var_1 = OpVariable %_ptr_int Input
%var_2 = OpVariable %_ptr_int Input
)";
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
EXPECT_EQ(unsigned(2), vstate_->num_global_vars());
}
// Tests that the number of local variables in ValidationState is correct.
TEST_F(ValidationStateTest, CheckNumLocalVars) {
string spirv = string(header) + R"(
%int = OpTypeInt 32 0
%_ptr_int = OpTypePointer Function %int
%voidt = OpTypeVoid
%funct = OpTypeFunction %voidt
%main = OpFunction %voidt None %funct
%entry = OpLabel
%var_1 = OpVariable %_ptr_int Function
%var_2 = OpVariable %_ptr_int Function
%var_3 = OpVariable %_ptr_int Function
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
EXPECT_EQ(unsigned(3), vstate_->num_local_vars());
}
// Tests that the "id bound" in ValidationState is correct.
TEST_F(ValidationStateTest, CheckIdBound) {
string spirv = string(header) + R"(
%int = OpTypeInt 32 0
%voidt = OpTypeVoid
)";
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
EXPECT_EQ(unsigned(3), vstate_->getIdBound());
}
// Tests that the entry_points in ValidationState is correct.
TEST_F(ValidationStateTest, CheckEntryPoints) {
string spirv = string(header) + " OpEntryPoint Vertex %func \"shader\"" +
string(kVoidFVoid);
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
EXPECT_EQ(size_t(1), vstate_->entry_points().size());
EXPECT_EQ(SpvOpFunction,
vstate_->FindDef(vstate_->entry_points()[0])->opcode());
}
} // anonymous namespace