mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-23 12:10:06 +00:00
Introducing a new flow for running the Validator.
We are adding a new API which can be called to run the SPIR-V validator, and retrieve the ValidationState_t object. This is very useful for unit testing. I have also added basic unit tests that demonstrate usage of this flow and ease of use to verify correctness.
This commit is contained in:
parent
d5e4f06eec
commit
68e36ec7e9
@ -182,64 +182,58 @@ spv_result_t spvValidate(const spv_const_context context,
|
||||
return spvValidateBinary(context, binary->code, binary->wordCount,
|
||||
pDiagnostic);
|
||||
}
|
||||
spv_result_t spvValidateBinary(const spv_const_context context,
|
||||
const uint32_t* words, const size_t num_words,
|
||||
spv_diagnostic* pDiagnostic) {
|
||||
spv_context_t hijack_context = *context;
|
||||
|
||||
spv_result_t ValidateBinaryUsingContextAndValidationState(
|
||||
const spv_context_t& context, const uint32_t* words, const size_t num_words,
|
||||
spv_diagnostic* pDiagnostic, ValidationState_t* vstate) {
|
||||
spv_const_binary binary = new spv_const_binary_t{words, num_words};
|
||||
if (pDiagnostic) {
|
||||
*pDiagnostic = nullptr;
|
||||
libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
|
||||
}
|
||||
|
||||
spv_endianness_t endian;
|
||||
spv_position_t position = {};
|
||||
if (spvBinaryEndianness(binary, &endian)) {
|
||||
return libspirv::DiagnosticStream(position, hijack_context.consumer,
|
||||
return libspirv::DiagnosticStream(position, context.consumer,
|
||||
SPV_ERROR_INVALID_BINARY)
|
||||
<< "Invalid SPIR-V magic number.";
|
||||
}
|
||||
|
||||
spv_header_t header;
|
||||
if (spvBinaryHeaderGet(binary, endian, &header)) {
|
||||
return libspirv::DiagnosticStream(position, hijack_context.consumer,
|
||||
return libspirv::DiagnosticStream(position, context.consumer,
|
||||
SPV_ERROR_INVALID_BINARY)
|
||||
<< "Invalid SPIR-V header.";
|
||||
}
|
||||
|
||||
// NOTE: Parse the module and perform inline validation checks. These
|
||||
// checks do not require the the knowledge of the whole module.
|
||||
ValidationState_t vstate(&hijack_context);
|
||||
if (auto error = spvBinaryParse(&hijack_context, &vstate, words, num_words,
|
||||
if (auto error = spvBinaryParse(&context, vstate, words, num_words,
|
||||
setHeader, ProcessInstruction, pDiagnostic))
|
||||
return error;
|
||||
|
||||
if (vstate.in_function_body())
|
||||
return vstate.diag(SPV_ERROR_INVALID_LAYOUT)
|
||||
if (vstate->in_function_body())
|
||||
return vstate->diag(SPV_ERROR_INVALID_LAYOUT)
|
||||
<< "Missing OpFunctionEnd at end of module.";
|
||||
|
||||
// TODO(umar): Add validation checks which require the parsing of the entire
|
||||
// module. Use the information from the ProcessInstruction pass to make the
|
||||
// checks.
|
||||
if (vstate.unresolved_forward_id_count() > 0) {
|
||||
if (vstate->unresolved_forward_id_count() > 0) {
|
||||
stringstream ss;
|
||||
vector<uint32_t> ids = vstate.UnresolvedForwardIds();
|
||||
vector<uint32_t> ids = vstate->UnresolvedForwardIds();
|
||||
|
||||
transform(begin(ids), end(ids), ostream_iterator<string>(ss, " "),
|
||||
bind(&ValidationState_t::getIdName, std::ref(vstate), _1));
|
||||
bind(&ValidationState_t::getIdName, std::ref(*vstate), _1));
|
||||
|
||||
auto id_str = ss.str();
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID)
|
||||
return vstate->diag(SPV_ERROR_INVALID_ID)
|
||||
<< "The following forward referenced IDs have not been defined:\n"
|
||||
<< id_str.substr(0, id_str.size() - 1);
|
||||
}
|
||||
|
||||
// CFG checks are performed after the binary has been parsed
|
||||
// and the CFGPass has collected information about the control flow
|
||||
if (auto error = PerformCfgChecks(vstate)) return error;
|
||||
if (auto error = UpdateIdUse(vstate)) return error;
|
||||
if (auto error = CheckIdDefinitionDominateUse(vstate)) return error;
|
||||
if (auto error = PerformCfgChecks(*vstate)) return error;
|
||||
if (auto error = UpdateIdUse(*vstate)) return error;
|
||||
if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error;
|
||||
|
||||
// NOTE: Copy each instruction for easier processing
|
||||
std::vector<spv_instruction_t> instructions;
|
||||
@ -258,7 +252,39 @@ spv_result_t spvValidateBinary(const spv_const_context context,
|
||||
|
||||
position.index = SPV_INDEX_INSTRUCTION;
|
||||
return spvValidateIDs(instructions.data(), instructions.size(),
|
||||
hijack_context.opcode_table,
|
||||
hijack_context.operand_table,
|
||||
hijack_context.ext_inst_table, vstate, &position);
|
||||
context.opcode_table,
|
||||
context.operand_table,
|
||||
context.ext_inst_table, *vstate, &position);
|
||||
}
|
||||
|
||||
spv_result_t spvValidateBinary(const spv_const_context context,
|
||||
const uint32_t* words, const size_t num_words,
|
||||
spv_diagnostic* pDiagnostic) {
|
||||
spv_context_t hijack_context = *context;
|
||||
if (pDiagnostic) {
|
||||
*pDiagnostic = nullptr;
|
||||
libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
|
||||
}
|
||||
|
||||
// Create the ValidationState using the context.
|
||||
ValidationState_t vstate(&hijack_context);
|
||||
|
||||
return ValidateBinaryUsingContextAndValidationState(
|
||||
hijack_context, words, num_words, pDiagnostic, &vstate);
|
||||
}
|
||||
|
||||
spv_result_t spvtools::ValidateBinaryAndKeepValidationState(
|
||||
const spv_const_context context, const uint32_t* words,
|
||||
const size_t num_words, spv_diagnostic* pDiagnostic,
|
||||
std::unique_ptr<ValidationState_t>* vstate) {
|
||||
spv_context_t hijack_context = *context;
|
||||
if (pDiagnostic) {
|
||||
*pDiagnostic = nullptr;
|
||||
libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
|
||||
}
|
||||
|
||||
vstate->reset(new ValidationState_t(&hijack_context));
|
||||
return ValidateBinaryUsingContextAndValidationState(
|
||||
hijack_context, words, num_words, pDiagnostic, vstate->get());
|
||||
}
|
||||
|
||||
|
@ -190,4 +190,15 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions,
|
||||
spv_position position,
|
||||
const spvtools::MessageConsumer& consumer);
|
||||
|
||||
namespace spvtools {
|
||||
// Performs validation for the SPIRV-V module binary.
|
||||
// The main difference between this API and spvValidateBinary is that the
|
||||
// "Validation State" is not destroyed upon function return; it lives on and is
|
||||
// pointed to by the vstate unique_ptr.
|
||||
spv_result_t ValidateBinaryAndKeepValidationState(
|
||||
const spv_const_context context, const uint32_t* words,
|
||||
const size_t num_words, spv_diagnostic* pDiagnostic,
|
||||
std::unique_ptr<libspirv::ValidationState_t>* vstate);
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // LIBSPIRV_VALIDATE_H_
|
||||
|
@ -75,3 +75,9 @@ add_spvtools_unittest(TARGET val_limits
|
||||
LIBS ${SPIRV_TOOLS}
|
||||
)
|
||||
|
||||
add_spvtools_unittest(TARGET val_validation_state
|
||||
SRCS val_validation_state_test.cpp
|
||||
${VAL_TEST_COMMON_SRCS}
|
||||
LIBS ${SPIRV_TOOLS}
|
||||
)
|
||||
|
||||
|
@ -67,6 +67,14 @@ spv_result_t ValidateBase<T>::ValidateInstructions(spv_target_env env) {
|
||||
&diagnostic_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
spv_result_t ValidateBase<T>::ValidateAndRetrieveValidationState(
|
||||
spv_target_env env) {
|
||||
return spvtools::ValidateBinaryAndKeepValidationState(
|
||||
ScopedContext(env).context, get_const_binary()->code,
|
||||
get_const_binary()->wordCount, &diagnostic_, &vstate_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string ValidateBase<T>::getDiagnosticString() {
|
||||
return std::string(diagnostic_->error);
|
||||
|
@ -18,6 +18,7 @@
|
||||
#define LIBSPIRV_TEST_VALIDATE_FIXTURES_H_
|
||||
|
||||
#include "unit_spirv.h"
|
||||
#include "source/val/validation_state.h"
|
||||
|
||||
namespace spvtest {
|
||||
|
||||
@ -45,11 +46,17 @@ class ValidateBase : public ::testing::Test,
|
||||
// spvValidate function
|
||||
spv_result_t ValidateInstructions(spv_target_env env = SPV_ENV_UNIVERSAL_1_0);
|
||||
|
||||
// Performs validation. Returns the status and stores validation state into
|
||||
// the vstate_ member.
|
||||
spv_result_t ValidateAndRetrieveValidationState(
|
||||
spv_target_env env = SPV_ENV_UNIVERSAL_1_0);
|
||||
|
||||
std::string getDiagnosticString();
|
||||
spv_position_t getErrorPosition();
|
||||
|
||||
spv_binary binary_;
|
||||
spv_diagnostic diagnostic_;
|
||||
std::unique_ptr<libspirv::ValidationState_t> vstate_;
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
106
test/val/val_validation_state_test.cpp
Normal file
106
test/val/val_validation_state_test.cpp
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright (c) 2016 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Basic tests for the ValidationState_t datastructure.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "unit_spirv.h"
|
||||
#include "val_fixtures.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using std::string;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
using ValidationStateTest = spvtest::ValidateBase<bool>;
|
||||
|
||||
const char header[] =
|
||||
" OpCapability Shader"
|
||||
" OpCapability Linkage"
|
||||
" OpMemoryModel Logical GLSL450 ";
|
||||
|
||||
const char kVoidFVoid[] =
|
||||
" %void = OpTypeVoid"
|
||||
" %void_f = OpTypeFunction %void"
|
||||
" %func = OpFunction %void None %void_f"
|
||||
" %label = OpLabel"
|
||||
" OpReturn"
|
||||
" OpFunctionEnd ";
|
||||
|
||||
// Tests that the instruction count in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckNumInstructions) {
|
||||
string spirv = string(header) + "%int = OpTypeInt 32 0";
|
||||
CompileSuccessfully(spirv);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
|
||||
EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size());
|
||||
}
|
||||
|
||||
// Tests that the number of global variables in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckNumGlobalVars) {
|
||||
string spirv = string(header) + R"(
|
||||
%int = OpTypeInt 32 0
|
||||
%_ptr_int = OpTypePointer Input %int
|
||||
%var_1 = OpVariable %_ptr_int Input
|
||||
%var_2 = OpVariable %_ptr_int Input
|
||||
)";
|
||||
CompileSuccessfully(spirv);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
|
||||
EXPECT_EQ(unsigned(2), vstate_->num_global_vars());
|
||||
}
|
||||
|
||||
// Tests that the number of local variables in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckNumLocalVars) {
|
||||
string spirv = string(header) + R"(
|
||||
%int = OpTypeInt 32 0
|
||||
%_ptr_int = OpTypePointer Function %int
|
||||
%voidt = OpTypeVoid
|
||||
%funct = OpTypeFunction %voidt
|
||||
%main = OpFunction %voidt None %funct
|
||||
%entry = OpLabel
|
||||
%var_1 = OpVariable %_ptr_int Function
|
||||
%var_2 = OpVariable %_ptr_int Function
|
||||
%var_3 = OpVariable %_ptr_int Function
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
CompileSuccessfully(spirv);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
|
||||
EXPECT_EQ(unsigned(3), vstate_->num_local_vars());
|
||||
}
|
||||
|
||||
// Tests that the "id bound" in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckIdBound) {
|
||||
string spirv = string(header) + R"(
|
||||
%int = OpTypeInt 32 0
|
||||
%voidt = OpTypeVoid
|
||||
)";
|
||||
CompileSuccessfully(spirv);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
|
||||
EXPECT_EQ(unsigned(3), vstate_->getIdBound());
|
||||
}
|
||||
|
||||
// Tests that the entry_points in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckEntryPoints) {
|
||||
string spirv = string(header) + " OpEntryPoint Vertex %func \"shader\"" +
|
||||
string(kVoidFVoid);
|
||||
CompileSuccessfully(spirv);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
|
||||
EXPECT_EQ(size_t(1), vstate_->entry_points().size());
|
||||
EXPECT_EQ(SpvOpFunction,
|
||||
vstate_->FindDef(vstate_->entry_points()[0])->opcode());
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
Loading…
Reference in New Issue
Block a user