Validate the input to Optimizer::Run (#1799)

* Run the validator in the optimization fuzzers.

The optimizers assumes that the input to the optimizer is valid.  Since
the fuzzers do not check that the input is valid before passing the
spir-v to the optimizer, we are getting a few errors.

The solution is to run the validator in the optimizer to validate the
input.

For the legalization passes, we need to add an extra option to the
validator to accept certain types of variable pointers, even if the
capability is not given.  At the same time, we changed the option
"--legalize-hlsl" to relax the validator in the same way instead of
turning it off.
This commit is contained in:
Steven Perron 2018-08-08 11:16:19 -04:00 committed by GitHub
parent 3a20879f4d
commit 5c8b4f5a1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 180 additions and 105 deletions

View File

@ -158,6 +158,9 @@ class Optimizer {
// It's allowed to alias |original_binary| to the start of |optimized_binary|.
bool Run(const uint32_t* original_binary, size_t original_binary_size,
std::vector<uint32_t>* optimized_binary) const;
bool Run(const uint32_t* original_binary, size_t original_binary_size,
std::vector<uint32_t>* optimized_binary,
const ValidatorOptions& options) const;
// Returns a vector of strings with all the pass names added to this
// optimizer's pass manager. These strings are valid until the associated

View File

@ -14,6 +14,8 @@
#include "spirv-tools/libspirv.hpp"
#include <iostream>
#include <string>
#include <utility>
#include <vector>
@ -115,8 +117,15 @@ bool SpirvTools::Validate(const uint32_t* binary,
bool SpirvTools::Validate(const uint32_t* binary, const size_t binary_size,
const ValidatorOptions& options) const {
spv_const_binary_t the_binary{binary, binary_size};
return spvValidateWithOptions(impl_->context, options, &the_binary,
nullptr) == SPV_SUCCESS;
spv_diagnostic diagnostic = nullptr;
bool valid = spvValidateWithOptions(impl_->context, options, &the_binary,
&diagnostic) == SPV_SUCCESS;
if (!valid && impl_->context->consumer) {
impl_->context->consumer.operator()(
SPV_MSG_ERROR, nullptr, diagnostic->position, diagnostic->error);
}
spvDiagnosticDestroy(diagnostic);
return valid;
}
} // namespace spvtools

View File

@ -448,6 +448,20 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
bool Optimizer::Run(const uint32_t* original_binary,
const size_t original_binary_size,
std::vector<uint32_t>* optimized_binary) const {
return Run(original_binary, original_binary_size, optimized_binary,
ValidatorOptions());
}
bool Optimizer::Run(const uint32_t* original_binary,
const size_t original_binary_size,
std::vector<uint32_t>* optimized_binary,
const ValidatorOptions& options) const {
spvtools::SpirvTools tools(impl_->target_env);
tools.SetMessageConsumer(impl_->pass_manager.consumer());
if (!tools.Validate(original_binary, original_binary_size, options)) {
return false;
}
std::unique_ptr<opt::IRContext> context = BuildModule(
impl_->target_env, consumer(), original_binary, original_binary_size);
if (context == nullptr) return false;

View File

@ -27,6 +27,15 @@ namespace {
using ::testing::ContainerEq;
using ::testing::HasSubstr;
// Return a string that contains the minimum instructions needed to form
// a valid module. Other instructions can be appended to this string.
std::string Header() {
return R"(OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
)";
}
TEST(CppInterface, SuccessfulRoundTrip) {
const std::string input_text = "%2 = OpSizeOf %1 %3\n";
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
@ -130,10 +139,6 @@ TEST(CppInterface, DisassembleOverloads) {
}
TEST(CppInterface, SuccessfulValidation) {
const std::string input_text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450)";
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
int invocation_count = 0;
t.SetMessageConsumer([&invocation_count](spv_message_level_t, const char*,
@ -142,19 +147,15 @@ TEST(CppInterface, SuccessfulValidation) {
});
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(input_text, &binary));
EXPECT_TRUE(t.Assemble(Header(), &binary));
EXPECT_TRUE(t.Validate(binary));
EXPECT_EQ(0, invocation_count);
}
TEST(CppInterface, ValidateOverloads) {
const std::string input_text = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450)";
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(input_text, &binary));
EXPECT_TRUE(t.Assemble(Header(), &binary));
{ EXPECT_TRUE(t.Validate(binary)); }
{ EXPECT_TRUE(t.Validate(binary.data(), binary.size())); }
@ -182,11 +183,9 @@ TEST(CppInterface, ValidateEmptyModule) {
// with the given number of members.
std::string MakeModuleHavingStruct(int num_members) {
std::stringstream os;
os << R"(OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeInt 32 0
%2 = OpTypeStruct)";
os << Header();
os << R"(%1 = OpTypeInt 32 0
%2 = OpTypeStruct)";
for (int i = 0; i < num_members; i++) os << " %1";
return os.str();
}
@ -220,8 +219,8 @@ TEST(CppInterface, ValidateWithOptionsFail) {
// Checks that after running the given optimizer |opt| on the given |original|
// source code, we can get the given |optimized| source code.
void CheckOptimization(const char* original, const char* optimized,
const Optimizer& opt) {
void CheckOptimization(const std::string& original,
const std::string& optimized, const Optimizer& opt) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> original_binary;
ASSERT_TRUE(t.Assemble(original, &original_binary));
@ -242,29 +241,31 @@ TEST(CppInterface, OptimizeEmptyModule) {
Optimizer o(SPV_ENV_UNIVERSAL_1_1);
o.RegisterPass(CreateStripDebugInfoPass());
EXPECT_TRUE(o.Run(binary.data(), binary.size(), &binary));
// Fails to validate.
EXPECT_FALSE(o.Run(binary.data(), binary.size(), &binary));
}
TEST(CppInterface, OptimizeModifiedModule) {
Optimizer o(SPV_ENV_UNIVERSAL_1_1);
o.RegisterPass(CreateStripDebugInfoPass());
CheckOptimization("OpSource GLSL 450", "", o);
CheckOptimization(Header() + "OpSource GLSL 450", Header(), o);
}
TEST(CppInterface, OptimizeMulitplePasses) {
const char* original_text =
"OpSource GLSL 450 "
"OpDecorate %true SpecId 1 "
"%bool = OpTypeBool "
"%true = OpSpecConstantTrue %bool";
std::string original_text = Header() +
"OpSource GLSL 450 "
"OpDecorate %true SpecId 1 "
"%bool = OpTypeBool "
"%true = OpSpecConstantTrue %bool";
Optimizer o(SPV_ENV_UNIVERSAL_1_1);
o.RegisterPass(CreateStripDebugInfoPass())
.RegisterPass(CreateFreezeSpecConstantValuePass());
const char* expected_text =
"%bool = OpTypeBool\n"
"%true = OpConstantTrue %bool\n";
std::string expected_text = Header() +
"%bool = OpTypeBool\n"
"%true = OpConstantTrue %bool\n";
CheckOptimization(original_text, expected_text, o);
}
@ -279,7 +280,7 @@ TEST(CppInterface, OptimizeReassignPassToken) {
token = CreateStripDebugInfoPass();
CheckOptimization(
"OpSource GLSL 450", "",
Header() + "OpSource GLSL 450", Header(),
Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token)));
}
@ -288,7 +289,7 @@ TEST(CppInterface, OptimizeMoveConstructPassToken) {
Optimizer::PassToken token2(std::move(token1));
CheckOptimization(
"OpSource GLSL 450", "",
Header() + "OpSource GLSL 450", Header(),
Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2)));
}
@ -298,14 +299,14 @@ TEST(CppInterface, OptimizeMoveAssignPassToken) {
token2 = std::move(token1);
CheckOptimization(
"OpSource GLSL 450", "",
Header() + "OpSource GLSL 450", Header(),
Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2)));
}
TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> binary;
ASSERT_TRUE(t.Assemble("OpSource GLSL 450", &binary));
ASSERT_TRUE(t.Assemble(Header() + "OpSource GLSL 450", &binary));
EXPECT_TRUE(Optimizer(SPV_ENV_UNIVERSAL_1_1)
.RegisterPass(CreateStripDebugInfoPass())
@ -313,7 +314,7 @@ TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) {
std::string optimized_text;
EXPECT_TRUE(t.Disassemble(binary, &optimized_text));
EXPECT_EQ("", optimized_text);
EXPECT_EQ(Header(), optimized_text);
}
// TODO(antiagainst): tests for SetMessageConsumer().

View File

@ -107,7 +107,7 @@ OpMemoryModel Logical Simple
OpEntryPoint GLCompute %100 "main"
%200 = OpTypeVoid
%300 = OpTypeFunction %200
%100 = OpFunction %300 None %200
%100 = OpFunction %200 None %300
%400 = OpLabel
OpReturn
OpFunctionEnd
@ -136,7 +136,7 @@ OpMemoryModel Logical Simple
OpEntryPoint GLCompute %1 "main"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%1 = OpFunction %3 None %2
%1 = OpFunction %2 None %3
%4 = OpLabel
OpReturn
OpFunctionEnd
@ -151,7 +151,7 @@ OpMemoryModel Logical Simple
OpEntryPoint GLCompute %100 "main"
%200 = OpTypeVoid
%300 = OpTypeFunction %200
%100 = OpFunction %300 None %200
%100 = OpFunction %200 None %300
%400 = OpLabel
OpReturn
OpFunctionEnd
@ -185,7 +185,7 @@ OpMemoryModel Logical Simple
OpEntryPoint GLCompute %1 "main"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%1 = OpFunction %3 None %2
%1 = OpFunction %2 None %3
%4 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -26,10 +26,20 @@ namespace {
using ::testing::Eq;
// Return a string that contains the minimum instructions needed to form
// a valid module. Other instructions can be appended to this string.
std::string Header() {
return R"(OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
)";
}
TEST(Optimizer, CanRunNullPassWithDistinctInputOutputVectors) {
SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
std::vector<uint32_t> binary_in;
tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary_in);
tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid",
&binary_in);
Optimizer opt(SPV_ENV_UNIVERSAL_1_0);
opt.RegisterPass(CreateNullPass());
@ -38,13 +48,15 @@ TEST(Optimizer, CanRunNullPassWithDistinctInputOutputVectors) {
std::string disassembly;
tools.Disassemble(binary_out.data(), binary_out.size(), &disassembly);
EXPECT_THAT(disassembly, Eq("OpName %foo \"foo\"\n%foo = OpTypeVoid\n"));
EXPECT_THAT(disassembly,
Eq(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid\n"));
}
TEST(Optimizer, CanRunTransformingPassWithDistinctInputOutputVectors) {
SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
std::vector<uint32_t> binary_in;
tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary_in);
tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid",
&binary_in);
Optimizer opt(SPV_ENV_UNIVERSAL_1_0);
opt.RegisterPass(CreateStripDebugInfoPass());
@ -53,7 +65,7 @@ TEST(Optimizer, CanRunTransformingPassWithDistinctInputOutputVectors) {
std::string disassembly;
tools.Disassemble(binary_out.data(), binary_out.size(), &disassembly);
EXPECT_THAT(disassembly, Eq("%void = OpTypeVoid\n"));
EXPECT_THAT(disassembly, Eq(Header() + "%void = OpTypeVoid\n"));
}
TEST(Optimizer, CanRunNullPassWithAliasedVectors) {
@ -73,7 +85,7 @@ TEST(Optimizer, CanRunNullPassWithAliasedVectors) {
TEST(Optimizer, CanRunNullPassWithAliasedVectorDataButDifferentSize) {
SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
std::vector<uint32_t> binary;
tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary);
tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary);
Optimizer opt(SPV_ENV_UNIVERSAL_1_0);
opt.RegisterPass(CreateNullPass());
@ -88,13 +100,14 @@ TEST(Optimizer, CanRunNullPassWithAliasedVectorDataButDifferentSize) {
std::string disassembly;
tools.Disassemble(binary.data(), binary.size(), &disassembly);
EXPECT_THAT(disassembly, Eq("OpName %foo \"foo\"\n%foo = OpTypeVoid\n"));
EXPECT_THAT(disassembly,
Eq(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid\n"));
}
TEST(Optimizer, CanRunTransformingPassWithAliasedVectors) {
SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
std::vector<uint32_t> binary;
tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary);
tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary);
Optimizer opt(SPV_ENV_UNIVERSAL_1_0);
opt.RegisterPass(CreateStripDebugInfoPass());
@ -102,7 +115,7 @@ TEST(Optimizer, CanRunTransformingPassWithAliasedVectors) {
std::string disassembly;
tools.Disassemble(binary.data(), binary.size(), &disassembly);
EXPECT_THAT(disassembly, Eq("%void = OpTypeVoid\n"));
EXPECT_THAT(disassembly, Eq(Header() + "%void = OpTypeVoid\n"));
}
TEST(Optimizer, CanValidateFlags) {

View File

@ -40,8 +40,8 @@ endfunction()
if (NOT ${SPIRV_SKIP_EXECUTABLES})
add_spvtools_tool(TARGET spirv-as SRCS as/as.cpp LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-dis SRCS dis/dis.cpp LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-val SRCS val/val.cpp LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-opt SRCS opt/opt.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-val SRCS val/val.cpp util/cli_consumer.cpp LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-opt SRCS opt/opt.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-link SRCS link/linker.cpp LIBS SPIRV-Tools-link ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-stats
SRCS stats/stats.cpp

View File

@ -29,6 +29,7 @@
#include "source/spirv_validator_options.h"
#include "spirv-tools/optimizer.hpp"
#include "tools/io.h"
#include "tools/util/cli_consumer.h"
namespace {
@ -180,12 +181,12 @@ Options (in lexicographical order):
early return in a loop.
--legalize-hlsl
Runs a series of optimizations that attempts to take SPIR-V
generated by and HLSL front-end and generate legal Vulkan SPIR-V.
generated by an HLSL front-end and generates legal Vulkan SPIR-V.
The optimizations are:
%s
Note this does not guarantee legal code. This option implies
--skip-validation.
Note this does not guarantee legal code. This option passes the
option --relax-logical-pointer to the validator.
--local-redundancy-elimination
Looks for instructions in the same basic block that compute the
same value, and deletes the redundant ones.
@ -396,7 +397,7 @@ bool ReadFlagsFromFile(const char* oconfig_flag,
OptStatus ParseFlags(int argc, const char** argv,
spvtools::Optimizer* optimizer, const char** in_file,
const char** out_file, spv_validator_options options,
const char** out_file, spvtools::ValidatorOptions* options,
bool* skip_validator);
// Parses and handles the -Oconfig flag. |prog_name| contains the name of
@ -485,7 +486,7 @@ std::string CanonicalizeFlag(const char** argv, int argc, int* argi) {
// success.
OptStatus ParseFlags(int argc, const char** argv,
spvtools::Optimizer* optimizer, const char** in_file,
const char** out_file, spv_validator_options options,
const char** out_file, spvtools::ValidatorOptions* options,
bool* skip_validator) {
std::vector<std::string> pass_flags;
for (int argi = 1; argi < argc; ++argi) {
@ -527,7 +528,7 @@ OptStatus ParseFlags(int argc, const char** argv,
} else if (0 == strcmp(cur_arg, "--time-report")) {
optimizer->SetTimeReport(&std::cerr);
} else if (0 == strcmp(cur_arg, "--relax-struct-store")) {
options->relax_struct_store = true;
options->SetRelaxStructStore(true);
} else {
// Some passes used to accept the form '--pass arg', canonicalize them
// to '--pass=arg'.
@ -535,7 +536,9 @@ OptStatus ParseFlags(int argc, const char** argv,
// If we were requested to legalize SPIR-V generated from the HLSL
// front-end, skip validation.
if (0 == strcmp(cur_arg, "--legalize-hlsl")) *skip_validator = true;
if (0 == strcmp(cur_arg, "--legalize-hlsl")) {
options->SetRelaxLogicalPointer(true);
}
}
} else {
if (!*in_file) {
@ -563,18 +566,13 @@ int main(int argc, const char** argv) {
bool skip_validator = false;
spv_target_env target_env = kDefaultEnvironment;
spv_validator_options options = spvValidatorOptionsCreate();
spvtools::ValidatorOptions options;
spvtools::Optimizer optimizer(target_env);
optimizer.SetMessageConsumer([](spv_message_level_t level, const char* source,
const spv_position_t& position,
const char* message) {
std::cerr << spvtools::StringifyMessage(level, source, position, message)
<< std::endl;
});
optimizer.SetMessageConsumer(spvtools::utils::CLIMessageConsumer);
OptStatus status = ParseFlags(argc, argv, &optimizer, &in_file, &out_file,
options, &skip_validator);
&options, &skip_validator);
if (status.action == OPT_STOP) {
return status.code;
@ -590,28 +588,9 @@ int main(int argc, const char** argv) {
return 1;
}
if (!skip_validator) {
// Let's do validation first.
spv_context context = spvContextCreate(target_env);
spv_diagnostic diagnostic = nullptr;
spv_const_binary_t binary_struct = {binary.data(), binary.size()};
spv_result_t error =
spvValidateWithOptions(context, options, &binary_struct, &diagnostic);
if (error) {
spvDiagnosticPrint(diagnostic);
spvDiagnosticDestroy(diagnostic);
spvValidatorOptionsDestroy(options);
spvContextDestroy(context);
return error;
}
spvDiagnosticDestroy(diagnostic);
spvValidatorOptionsDestroy(options);
spvContextDestroy(context);
}
// By using the same vector as input and output, we save time in the case
// that there was no change.
bool ok = optimizer.Run(binary.data(), binary.size(), &binary);
bool ok = optimizer.Run(binary.data(), binary.size(), &binary, options);
if (!WriteFile<uint32_t>(out_file, "wb", binary.data(), binary.size())) {
return 1;

View File

@ -0,0 +1,45 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tools/util/cli_consumer.h"
#include <iostream>
namespace spvtools {
namespace utils {
void CLIMessageConsumer(spv_message_level_t level, const char*,
const spv_position_t& position, const char* message) {
switch (level) {
case SPV_MSG_FATAL:
case SPV_MSG_INTERNAL_ERROR:
case SPV_MSG_ERROR:
std::cerr << "error: line " << position.index << ": " << message
<< std::endl;
break;
case SPV_MSG_WARNING:
std::cout << "warning: line " << position.index << ": " << message
<< std::endl;
break;
case SPV_MSG_INFO:
std::cout << "info: line " << position.index << ": " << message
<< std::endl;
break;
default:
break;
}
}
} // namespace utils
} // namespace spvtools

31
tools/util/cli_consumer.h Normal file
View File

@ -0,0 +1,31 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_UTIL_CLI_CONSUMMER_H_
#define SOURCE_UTIL_CLI_CONSUMMER_H_
#include <include/spirv-tools/libspirv.h>
namespace spvtools {
namespace utils {
// A message consumer that can be used by command line tools like spirv-opt and
// spirv-val to display messages.
void CLIMessageConsumer(spv_message_level_t level, const char*,
const spv_position_t& position, const char* message);
} // namespace utils
} // namespace spvtools
#endif // SOURCE_UTIL_CLI_CONSUMMER_H_

View File

@ -22,6 +22,7 @@
#include "source/spirv_validator_options.h"
#include "spirv-tools/libspirv.hpp"
#include "tools/io.h"
#include "tools/util/cli_consumer.h"
void print_usage(char* argv0) {
printf(
@ -164,28 +165,7 @@ int main(int argc, char** argv) {
if (!ReadFile<uint32_t>(inFile, "rb", &contents)) return 1;
spvtools::SpirvTools tools(target_env);
tools.SetMessageConsumer([](spv_message_level_t level, const char*,
const spv_position_t& position,
const char* message) {
switch (level) {
case SPV_MSG_FATAL:
case SPV_MSG_INTERNAL_ERROR:
case SPV_MSG_ERROR:
std::cerr << "error: line " << position.index << ": " << message
<< std::endl;
break;
case SPV_MSG_WARNING:
std::cout << "warning: line " << position.index << ": " << message
<< std::endl;
break;
case SPV_MSG_INFO:
std::cout << "info: line " << position.index << ": " << message
<< std::endl;
break;
default:
break;
}
});
tools.SetMessageConsumer(spvtools::utils::CLIMessageConsumer);
bool succeed = tools.Validate(contents.data(), contents.size(), options);