SPIRV-Tools/test/cpp_interface_test.cpp
Steven Perron 5c8b4f5a1c
Validate the input to Optimizer::Run (#1799)
* Run the validator in the optimization fuzzers.

The optimizers assumes that the input to the optimizer is valid.  Since
the fuzzers do not check that the input is valid before passing the
spir-v to the optimizer, we are getting a few errors.

The solution is to run the validator in the optimizer to validate the
input.

For the legalization passes, we need to add an extra option to the
validator to accept certain types of variable pointers, even if the
capability is not given.  At the same time, we changed the option
"--legalize-hlsl" to relax the validator in the same way instead of
turning it off.
2018-08-08 11:16:19 -04:00

324 lines
10 KiB
C++

// Copyright (c) 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <utility>
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "spirv-tools/optimizer.hpp"
#include "spirv/1.1/spirv.h"
namespace spvtools {
namespace {
using ::testing::ContainerEq;
using ::testing::HasSubstr;
// Return a string that contains the minimum instructions needed to form
// a valid module. Other instructions can be appended to this string.
std::string Header() {
return R"(OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
)";
}
TEST(CppInterface, SuccessfulRoundTrip) {
const std::string input_text = "%2 = OpSizeOf %1 %3\n";
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(input_text, &binary));
EXPECT_TRUE(binary.size() > 5u);
EXPECT_EQ(SpvMagicNumber, binary[0]);
EXPECT_EQ(SpvVersion, binary[1]);
// This cannot pass validation since %1 is not defined.
t.SetMessageConsumer([](spv_message_level_t level, const char* source,
const spv_position_t& position, const char* message) {
EXPECT_EQ(SPV_MSG_ERROR, level);
EXPECT_STREQ("input", source);
EXPECT_EQ(0u, position.line);
EXPECT_EQ(0u, position.column);
EXPECT_EQ(1u, position.index);
EXPECT_STREQ("ID 1 has not been defined\n %2 = OpSizeOf %1 %3\n", message);
});
EXPECT_FALSE(t.Validate(binary));
std::string output_text;
EXPECT_TRUE(t.Disassemble(binary, &output_text));
EXPECT_EQ(input_text, output_text);
}
TEST(CppInterface, AssembleEmptyModule) {
std::vector<uint32_t> binary(10, 42);
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
EXPECT_TRUE(t.Assemble("", &binary));
// We only have the header.
EXPECT_EQ(5u, binary.size());
EXPECT_EQ(SpvMagicNumber, binary[0]);
EXPECT_EQ(SpvVersion, binary[1]);
}
TEST(CppInterface, AssembleOverloads) {
const std::string input_text = "%2 = OpSizeOf %1 %3\n";
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
{
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(input_text, &binary));
EXPECT_TRUE(binary.size() > 5u);
EXPECT_EQ(SpvMagicNumber, binary[0]);
EXPECT_EQ(SpvVersion, binary[1]);
}
{
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size(), &binary));
EXPECT_TRUE(binary.size() > 5u);
EXPECT_EQ(SpvMagicNumber, binary[0]);
EXPECT_EQ(SpvVersion, binary[1]);
}
{ // Ignore the last newline.
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size() - 1, &binary));
EXPECT_TRUE(binary.size() > 5u);
EXPECT_EQ(SpvMagicNumber, binary[0]);
EXPECT_EQ(SpvVersion, binary[1]);
}
}
TEST(CppInterface, DisassembleEmptyModule) {
std::string text(10, 'x');
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
int invocation_count = 0;
t.SetMessageConsumer(
[&invocation_count](spv_message_level_t level, const char* source,
const spv_position_t& position, const char* message) {
++invocation_count;
EXPECT_EQ(SPV_MSG_ERROR, level);
EXPECT_STREQ("input", source);
EXPECT_EQ(0u, position.line);
EXPECT_EQ(0u, position.column);
EXPECT_EQ(0u, position.index);
EXPECT_STREQ("Missing module.", message);
});
EXPECT_FALSE(t.Disassemble({}, &text));
EXPECT_EQ("xxxxxxxxxx", text); // The original string is unmodified.
EXPECT_EQ(1, invocation_count);
}
TEST(CppInterface, DisassembleOverloads) {
const std::string input_text = "%2 = OpSizeOf %1 %3\n";
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(input_text, &binary));
{
std::string output_text;
EXPECT_TRUE(t.Disassemble(binary, &output_text));
EXPECT_EQ(input_text, output_text);
}
{
std::string output_text;
EXPECT_TRUE(t.Disassemble(binary.data(), binary.size(), &output_text));
EXPECT_EQ(input_text, output_text);
}
}
TEST(CppInterface, SuccessfulValidation) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
int invocation_count = 0;
t.SetMessageConsumer([&invocation_count](spv_message_level_t, const char*,
const spv_position_t&, const char*) {
++invocation_count;
});
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(Header(), &binary));
EXPECT_TRUE(t.Validate(binary));
EXPECT_EQ(0, invocation_count);
}
TEST(CppInterface, ValidateOverloads) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(Header(), &binary));
{ EXPECT_TRUE(t.Validate(binary)); }
{ EXPECT_TRUE(t.Validate(binary.data(), binary.size())); }
}
TEST(CppInterface, ValidateEmptyModule) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
int invocation_count = 0;
t.SetMessageConsumer(
[&invocation_count](spv_message_level_t level, const char* source,
const spv_position_t& position, const char* message) {
++invocation_count;
EXPECT_EQ(SPV_MSG_ERROR, level);
EXPECT_STREQ("input", source);
EXPECT_EQ(0u, position.line);
EXPECT_EQ(0u, position.column);
EXPECT_EQ(0u, position.index);
EXPECT_STREQ("Invalid SPIR-V magic number.", message);
});
EXPECT_FALSE(t.Validate({}));
EXPECT_EQ(1, invocation_count);
}
// Returns the assembly for a SPIR-V module with a struct declaration
// with the given number of members.
std::string MakeModuleHavingStruct(int num_members) {
std::stringstream os;
os << Header();
os << R"(%1 = OpTypeInt 32 0
%2 = OpTypeStruct)";
for (int i = 0; i < num_members; i++) os << " %1";
return os.str();
}
TEST(CppInterface, ValidateWithOptionsPass) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary));
const ValidatorOptions opts;
EXPECT_TRUE(t.Validate(binary.data(), binary.size(), opts));
}
TEST(CppInterface, ValidateWithOptionsFail) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary));
ValidatorOptions opts;
opts.SetUniversalLimit(spv_validator_limit_max_struct_members, 9);
std::stringstream os;
t.SetMessageConsumer([&os](spv_message_level_t, const char*,
const spv_position_t&,
const char* message) { os << message; });
EXPECT_FALSE(t.Validate(binary.data(), binary.size(), opts));
EXPECT_THAT(
os.str(),
HasSubstr(
"Number of OpTypeStruct members (10) has exceeded the limit (9)"));
}
// Checks that after running the given optimizer |opt| on the given |original|
// source code, we can get the given |optimized| source code.
void CheckOptimization(const std::string& original,
const std::string& optimized, const Optimizer& opt) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> original_binary;
ASSERT_TRUE(t.Assemble(original, &original_binary));
std::vector<uint32_t> optimized_binary;
EXPECT_TRUE(opt.Run(original_binary.data(), original_binary.size(),
&optimized_binary));
std::string optimized_text;
EXPECT_TRUE(t.Disassemble(optimized_binary, &optimized_text));
EXPECT_EQ(optimized, optimized_text);
}
TEST(CppInterface, OptimizeEmptyModule) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> binary;
EXPECT_TRUE(t.Assemble("", &binary));
Optimizer o(SPV_ENV_UNIVERSAL_1_1);
o.RegisterPass(CreateStripDebugInfoPass());
// Fails to validate.
EXPECT_FALSE(o.Run(binary.data(), binary.size(), &binary));
}
TEST(CppInterface, OptimizeModifiedModule) {
Optimizer o(SPV_ENV_UNIVERSAL_1_1);
o.RegisterPass(CreateStripDebugInfoPass());
CheckOptimization(Header() + "OpSource GLSL 450", Header(), o);
}
TEST(CppInterface, OptimizeMulitplePasses) {
std::string original_text = Header() +
"OpSource GLSL 450 "
"OpDecorate %true SpecId 1 "
"%bool = OpTypeBool "
"%true = OpSpecConstantTrue %bool";
Optimizer o(SPV_ENV_UNIVERSAL_1_1);
o.RegisterPass(CreateStripDebugInfoPass())
.RegisterPass(CreateFreezeSpecConstantValuePass());
std::string expected_text = Header() +
"%bool = OpTypeBool\n"
"%true = OpConstantTrue %bool\n";
CheckOptimization(original_text, expected_text, o);
}
TEST(CppInterface, OptimizeDoNothingWithPassToken) {
CreateFreezeSpecConstantValuePass();
auto token = CreateUnifyConstantPass();
}
TEST(CppInterface, OptimizeReassignPassToken) {
auto token = CreateNullPass();
token = CreateStripDebugInfoPass();
CheckOptimization(
Header() + "OpSource GLSL 450", Header(),
Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token)));
}
TEST(CppInterface, OptimizeMoveConstructPassToken) {
auto token1 = CreateStripDebugInfoPass();
Optimizer::PassToken token2(std::move(token1));
CheckOptimization(
Header() + "OpSource GLSL 450", Header(),
Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2)));
}
TEST(CppInterface, OptimizeMoveAssignPassToken) {
auto token1 = CreateStripDebugInfoPass();
auto token2 = CreateNullPass();
token2 = std::move(token1);
CheckOptimization(
Header() + "OpSource GLSL 450", Header(),
Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2)));
}
TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) {
SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
std::vector<uint32_t> binary;
ASSERT_TRUE(t.Assemble(Header() + "OpSource GLSL 450", &binary));
EXPECT_TRUE(Optimizer(SPV_ENV_UNIVERSAL_1_1)
.RegisterPass(CreateStripDebugInfoPass())
.Run(binary.data(), binary.size(), &binary));
std::string optimized_text;
EXPECT_TRUE(t.Disassemble(binary, &optimized_text));
EXPECT_EQ(Header(), optimized_text);
}
// TODO(antiagainst): tests for SetMessageConsumer().
} // namespace
} // namespace spvtools