Remove MarkV and Stats code. (#2576)

* Remove MarkV and Stats code.

This Cl removes the MarkV and Stats code from SPIRV-Tools. This code was
unused and currently un-maintained.
This commit is contained in:
dan sinclair 2019-05-24 15:43:59 -04:00 committed by Steven Perron
parent 3b5ab540ca
commit 42abaa099a
46 changed files with 8 additions and 25015 deletions

View File

@ -59,7 +59,7 @@ build:
build_script:
- mkdir build && cd build
- cmake -GNinja -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%CONFIGURATION% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF ..
- cmake -GNinja -DCMAKE_BUILD_TYPE=%CONFIGURATION% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF ..
- ninja install
test_script:

View File

@ -13,7 +13,6 @@ SPVTOOLS_SRC_FILES := \
source/ext_inst.cpp \
source/enum_string_mapping.cpp \
source/extensions.cpp \
source/id_descriptor.cpp \
source/libspirv.cpp \
source/name_mapper.cpp \
source/opcode.cpp \

View File

@ -69,6 +69,10 @@ if(NOT ${SKIP_SPIRV_TOOLS_INSTALL})
endif()
option(SPIRV_BUILD_COMPRESSION "Build SPIR-V compressing codec" OFF)
if(SPIRV_BUILD_COMPRESSION)
message(FATAL_ERROR "SPIR-V compression codec has been removed from SPIR-V tools. "
"Please remove SPIRV_BUILD_COMPRESSION from your build options.")
endif(SPIRV_BUILD_COMPRESSION)
option(SPIRV_WERROR "Enable error on warning" ON)
if(("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR (("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") AND (NOT CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC")))
@ -257,9 +261,6 @@ endif()
set(SPIRV_LIBRARIES "-lSPIRV-Tools -lSPIRV-Tools-link -lSPIRV-Tools-opt")
set(SPIRV_SHARED_LIBRARIES "-lSPIRV-Tools-shared")
if(SPIRV_BUILD_COMPRESSION)
set(SPIRV_LIBRARIES "${SPIRV_LIBRARIES} -lSPIRV-Tools-comp")
endif(SPIRV_BUILD_COMPRESSION)
# Build pkg-config file
# Use a first-class target so it's regenerated when relevant files are updated.

View File

@ -307,8 +307,6 @@ The following CMake options are supported:
the command line tools. This will prevent the tests from being built.
* `SPIRV_SKIP_EXECUTABLES={ON|OFF}`, default `OFF`- Build only the library, not
the command line tools and tests.
* `SPIRV_BUILD_COMPRESSION={ON|OFF}`, default `OFF`- Build SPIR-V compressing
codec.
* `SPIRV_USE_SANITIZER=<sanitizer>`, default is no sanitizing - On UNIX
platforms with an appropriate version of `clang` this option enables the use
of the sanitizers documented [here][clang-sanitizers].

View File

@ -44,7 +44,7 @@ mkdir build && cd $SRC/build
# Invoke the build.
BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT}
echo $(date): Starting build...
cmake -DCMAKE_BUILD_TYPE=Release -DANDROID_NATIVE_API_LEVEL=android-14 -DANDROID_ABI="armeabi-v7a with NEON" -DSPIRV_BUILD_COMPRESSION=ON -DSPIRV_SKIP_TESTS=ON -DCMAKE_TOOLCHAIN_FILE=$TOOLCHAIN_PATH -GNinja -DANDROID_NDK=$ANDROID_NDK ..
cmake -DCMAKE_BUILD_TYPE=Release -DANDROID_NATIVE_API_LEVEL=android-14 -DANDROID_ABI="armeabi-v7a with NEON" -DSPIRV_SKIP_TESTS=ON -DCMAKE_TOOLCHAIN_FILE=$TOOLCHAIN_PATH -GNinja -DANDROID_NDK=$ANDROID_NDK ..
echo $(date): Build everything...
ninja

View File

@ -63,7 +63,7 @@ if "%KOKORO_GITHUB_COMMIT%." == "." (
set BUILD_SHA=%KOKORO_GITHUB_COMMIT%
)
set CMAKE_FLAGS=-DCMAKE_INSTALL_PREFIX=%KOKORO_ARTIFACTS_DIR%\install -GNinja -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%BUILD_TYPE% -DRE2_BUILD_TESTING=OFF -DCMAKE_C_COMPILER=cl.exe -DCMAKE_CXX_COMPILER=cl.exe
set CMAKE_FLAGS=-DCMAKE_INSTALL_PREFIX=%KOKORO_ARTIFACTS_DIR%\install -GNinja -DCMAKE_BUILD_TYPE=%BUILD_TYPE% -DRE2_BUILD_TESTING=OFF -DCMAKE_C_COMPILER=cl.exe -DCMAKE_CXX_COMPILER=cl.exe
:: Skip building tests for VS2013
if %VS_VERSION% == 2013 (

View File

@ -196,7 +196,6 @@ set_source_files_properties(
${CMAKE_CURRENT_SOURCE_DIR}/pch_source.cpp
PROPERTIES OBJECT_DEPENDS "${PCH_DEPENDS}")
add_subdirectory(comp)
add_subdirectory(opt)
add_subdirectory(reduce)
add_subdirectory(link)
@ -221,7 +220,6 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/enum_string_mapping.h
${CMAKE_CURRENT_SOURCE_DIR}/ext_inst.h
${CMAKE_CURRENT_SOURCE_DIR}/extensions.h
${CMAKE_CURRENT_SOURCE_DIR}/id_descriptor.h
${CMAKE_CURRENT_SOURCE_DIR}/instruction.h
${CMAKE_CURRENT_SOURCE_DIR}/latest_version_glsl_std_450_header.h
${CMAKE_CURRENT_SOURCE_DIR}/latest_version_opencl_std_header.h
@ -254,7 +252,6 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/enum_string_mapping.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ext_inst.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extensions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/id_descriptor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libspirv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/name_mapper.cpp
${CMAKE_CURRENT_SOURCE_DIR}/opcode.cpp

View File

@ -1,52 +0,0 @@
# Copyright (c) 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(SPIRV_BUILD_COMPRESSION)
add_library(SPIRV-Tools-comp
bit_stream.cpp
bit_stream.h
huffman_codec.h
markv_codec.cpp
markv_codec.h
markv.cpp
markv.h
markv_decoder.cpp
markv_decoder.h
markv_encoder.cpp
markv_encoder.h
markv_logger.h
move_to_front.h
move_to_front.cpp)
spvtools_default_compile_options(SPIRV-Tools-comp)
target_include_directories(SPIRV-Tools-comp
PUBLIC ${spirv-tools_SOURCE_DIR}/include
PUBLIC ${SPIRV_HEADER_INCLUDE_DIR}
PRIVATE ${spirv-tools_BINARY_DIR}
)
target_link_libraries(SPIRV-Tools-comp
PUBLIC ${SPIRV_TOOLS})
set_property(TARGET SPIRV-Tools-comp PROPERTY FOLDER "SPIRV-Tools libraries")
spvtools_check_symbol_exports(SPIRV-Tools-comp)
if(ENABLE_SPIRV_TOOLS_INSTALL)
install(TARGETS SPIRV-Tools-comp
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif(ENABLE_SPIRV_TOOLS_INSTALL)
endif(SPIRV_BUILD_COMPRESSION)

View File

@ -1,348 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cassert>
#include <cstring>
#include <sstream>
#include <type_traits>
#include "source/comp/bit_stream.h"
namespace spvtools {
namespace comp {
namespace {
// Returns if the system is little-endian. Unfortunately only works during
// runtime.
bool IsLittleEndian() {
// This constant value allows the detection of the host machine's endianness.
// Accessing it as an array of bytes is valid due to C++11 section 3.10
// paragraph 10.
static const uint16_t kFF00 = 0xff00;
return reinterpret_cast<const unsigned char*>(&kFF00)[0] == 0;
}
// Copies bytes from the given buffer to a uint64_t buffer.
// Motivation: casting uint64_t* to uint8_t* is ok. Casting in the other
// direction is only advisable if uint8_t* is aligned to 64-bit word boundary.
std::vector<uint64_t> ToBuffer64(const void* buffer, size_t num_bytes) {
std::vector<uint64_t> out;
out.resize((num_bytes + 7) / 8, 0);
memcpy(out.data(), buffer, num_bytes);
return out;
}
// Copies uint8_t buffer to a uint64_t buffer.
std::vector<uint64_t> ToBuffer64(const std::vector<uint8_t>& in) {
return ToBuffer64(in.data(), in.size());
}
// Returns uint64_t containing the same bits as |val|.
// Type size must be less than 8 bytes.
template <typename T>
uint64_t ToU64(T val) {
static_assert(sizeof(T) <= 8, "Type size too big");
uint64_t val64 = 0;
std::memcpy(&val64, &val, sizeof(T));
return val64;
}
// Returns value of type T containing the same bits as |val64|.
// Type size must be less than 8 bytes. Upper (unused) bits of |val64| must be
// zero (irrelevant, but is checked with assertion).
template <typename T>
T FromU64(uint64_t val64) {
assert(sizeof(T) == 8 || (val64 >> (sizeof(T) * 8)) == 0);
static_assert(sizeof(T) <= 8, "Type size too big");
T val = 0;
std::memcpy(&val, &val64, sizeof(T));
return val;
}
// Writes bits from |val| to |writer| in chunks of size |chunk_length|.
// Signal bit is used to signal if the reader should expect another chunk:
// 0 - no more chunks to follow
// 1 - more chunks to follow
// If number of written bits reaches |max_payload| last chunk is truncated.
void WriteVariableWidthInternal(BitWriterInterface* writer, uint64_t val,
size_t chunk_length, size_t max_payload) {
assert(chunk_length > 0);
assert(chunk_length < max_payload);
assert(max_payload == 64 || (val >> max_payload) == 0);
if (val == 0) {
// Split in two writes for more readable logging.
writer->WriteBits(0, chunk_length);
writer->WriteBits(0, 1);
return;
}
size_t payload_written = 0;
while (val) {
if (payload_written + chunk_length >= max_payload) {
// This has to be the last chunk.
// There is no need for the signal bit and the chunk can be truncated.
const size_t left_to_write = max_payload - payload_written;
assert((val >> left_to_write) == 0);
writer->WriteBits(val, left_to_write);
break;
}
writer->WriteBits(val, chunk_length);
payload_written += chunk_length;
val = val >> chunk_length;
// Write a single bit to signal if there is more to come.
writer->WriteBits(val ? 1 : 0, 1);
}
}
// Reads data written with WriteVariableWidthInternal. |chunk_length| and
// |max_payload| should be identical to those used to write the data.
// Returns false if the stream ends prematurely.
bool ReadVariableWidthInternal(BitReaderInterface* reader, uint64_t* val,
size_t chunk_length, size_t max_payload) {
assert(chunk_length > 0);
assert(chunk_length <= max_payload);
size_t payload_read = 0;
while (payload_read + chunk_length < max_payload) {
uint64_t bits = 0;
if (reader->ReadBits(&bits, chunk_length) != chunk_length) return false;
*val |= bits << payload_read;
payload_read += chunk_length;
uint64_t more_to_come = 0;
if (reader->ReadBits(&more_to_come, 1) != 1) return false;
if (!more_to_come) {
return true;
}
}
// Need to read the last chunk which may be truncated. No signal bit follows.
uint64_t bits = 0;
const size_t left_to_read = max_payload - payload_read;
if (reader->ReadBits(&bits, left_to_read) != left_to_read) return false;
*val |= bits << payload_read;
return true;
}
// Calls WriteVariableWidthInternal with the right max_payload argument.
template <typename T>
void WriteVariableWidthUnsigned(BitWriterInterface* writer, T val,
size_t chunk_length) {
static_assert(std::is_unsigned<T>::value, "Type must be unsigned");
static_assert(std::is_integral<T>::value, "Type must be integral");
WriteVariableWidthInternal(writer, val, chunk_length, sizeof(T) * 8);
}
// Calls ReadVariableWidthInternal with the right max_payload argument.
template <typename T>
bool ReadVariableWidthUnsigned(BitReaderInterface* reader, T* val,
size_t chunk_length) {
static_assert(std::is_unsigned<T>::value, "Type must be unsigned");
static_assert(std::is_integral<T>::value, "Type must be integral");
uint64_t val64 = 0;
if (!ReadVariableWidthInternal(reader, &val64, chunk_length, sizeof(T) * 8))
return false;
*val = static_cast<T>(val64);
assert(*val == val64);
return true;
}
// Encodes signed |val| to an unsigned value and calls
// WriteVariableWidthInternal with the right max_payload argument.
template <typename T>
void WriteVariableWidthSigned(BitWriterInterface* writer, T val,
size_t chunk_length, size_t zigzag_exponent) {
static_assert(std::is_signed<T>::value, "Type must be signed");
static_assert(std::is_integral<T>::value, "Type must be integral");
WriteVariableWidthInternal(writer, EncodeZigZag(val, zigzag_exponent),
chunk_length, sizeof(T) * 8);
}
// Calls ReadVariableWidthInternal with the right max_payload argument
// and decodes the value.
template <typename T>
bool ReadVariableWidthSigned(BitReaderInterface* reader, T* val,
size_t chunk_length, size_t zigzag_exponent) {
static_assert(std::is_signed<T>::value, "Type must be signed");
static_assert(std::is_integral<T>::value, "Type must be integral");
uint64_t encoded = 0;
if (!ReadVariableWidthInternal(reader, &encoded, chunk_length, sizeof(T) * 8))
return false;
const int64_t decoded = DecodeZigZag(encoded, zigzag_exponent);
*val = static_cast<T>(decoded);
assert(*val == decoded);
return true;
}
} // namespace
void BitWriterInterface::WriteVariableWidthU64(uint64_t val,
size_t chunk_length) {
WriteVariableWidthUnsigned(this, val, chunk_length);
}
void BitWriterInterface::WriteVariableWidthU32(uint32_t val,
size_t chunk_length) {
WriteVariableWidthUnsigned(this, val, chunk_length);
}
void BitWriterInterface::WriteVariableWidthU16(uint16_t val,
size_t chunk_length) {
WriteVariableWidthUnsigned(this, val, chunk_length);
}
void BitWriterInterface::WriteVariableWidthS64(int64_t val, size_t chunk_length,
size_t zigzag_exponent) {
WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
}
BitWriterWord64::BitWriterWord64(size_t reserve_bits) : end_(0) {
buffer_.reserve(NumBitsToNumWords<64>(reserve_bits));
}
void BitWriterWord64::WriteBits(uint64_t bits, size_t num_bits) {
// Check that |bits| and |num_bits| are valid and consistent.
assert(num_bits <= 64);
const bool is_little_endian = IsLittleEndian();
assert(is_little_endian && "Big-endian architecture support not implemented");
if (!is_little_endian) return;
if (num_bits == 0) return;
bits = GetLowerBits(bits, num_bits);
EmitSequence(bits, num_bits);
// Offset from the start of the current word.
const size_t offset = end_ % 64;
if (offset == 0) {
// If no offset, simply add |bits| as a new word to the buffer_.
buffer_.push_back(bits);
} else {
// Shift bits and add them to the current word after offset.
const uint64_t first_word = bits << offset;
buffer_.back() |= first_word;
// If we don't overflow to the next word, there is nothing more to do.
if (offset + num_bits > 64) {
// We overflow to the next word.
const uint64_t second_word = bits >> (64 - offset);
// Add remaining bits as a new word to buffer_.
buffer_.push_back(second_word);
}
}
// Move end_ into position for next write.
end_ += num_bits;
assert(buffer_.size() * 64 >= end_);
}
bool BitReaderInterface::ReadVariableWidthU64(uint64_t* val,
size_t chunk_length) {
return ReadVariableWidthUnsigned(this, val, chunk_length);
}
bool BitReaderInterface::ReadVariableWidthU32(uint32_t* val,
size_t chunk_length) {
return ReadVariableWidthUnsigned(this, val, chunk_length);
}
bool BitReaderInterface::ReadVariableWidthU16(uint16_t* val,
size_t chunk_length) {
return ReadVariableWidthUnsigned(this, val, chunk_length);
}
bool BitReaderInterface::ReadVariableWidthS64(int64_t* val, size_t chunk_length,
size_t zigzag_exponent) {
return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
}
BitReaderWord64::BitReaderWord64(std::vector<uint64_t>&& buffer)
: buffer_(std::move(buffer)), pos_(0) {}
BitReaderWord64::BitReaderWord64(const std::vector<uint8_t>& buffer)
: buffer_(ToBuffer64(buffer)), pos_(0) {}
BitReaderWord64::BitReaderWord64(const void* buffer, size_t num_bytes)
: buffer_(ToBuffer64(buffer, num_bytes)), pos_(0) {}
size_t BitReaderWord64::ReadBits(uint64_t* bits, size_t num_bits) {
assert(num_bits <= 64);
const bool is_little_endian = IsLittleEndian();
assert(is_little_endian && "Big-endian architecture support not implemented");
if (!is_little_endian) return 0;
if (ReachedEnd()) return 0;
// Index of the current word.
const size_t index = pos_ / 64;
// Bit position in the current word where we start reading.
const size_t offset = pos_ % 64;
// Read all bits from the current word (it might be too much, but
// excessive bits will be removed later).
*bits = buffer_[index] >> offset;
const size_t num_read_from_first_word = std::min(64 - offset, num_bits);
pos_ += num_read_from_first_word;
if (pos_ >= buffer_.size() * 64) {
// Reached end of buffer_.
EmitSequence(*bits, num_read_from_first_word);
return num_read_from_first_word;
}
if (offset + num_bits > 64) {
// Requested |num_bits| overflows to next word.
// Write all bits from the beginning of next word to *bits after offset.
*bits |= buffer_[index + 1] << (64 - offset);
pos_ += offset + num_bits - 64;
}
// We likely have written more bits than requested. Clear excessive bits.
*bits = GetLowerBits(*bits, num_bits);
EmitSequence(*bits, num_bits);
return num_bits;
}
bool BitReaderWord64::ReachedEnd() const { return pos_ >= buffer_.size() * 64; }
bool BitReaderWord64::OnlyZeroesLeft() const {
if (ReachedEnd()) return true;
const size_t index = pos_ / 64;
if (index < buffer_.size() - 1) return false;
assert(index == buffer_.size() - 1);
const size_t offset = pos_ % 64;
const uint64_t remaining_bits = buffer_[index] >> offset;
return !remaining_bits;
}
} // namespace comp
} // namespace spvtools

View File

@ -1,280 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Contains utils for reading, writing and debug printing bit streams.
#ifndef SOURCE_COMP_BIT_STREAM_H_
#define SOURCE_COMP_BIT_STREAM_H_
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <functional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace spvtools {
namespace comp {
// Terminology:
// Bits - usually used for a uint64 word, first bit is the lowest.
// Stream - std::string of '0' and '1', read left-to-right,
// i.e. first bit is at the front and not at the end as in
// std::bitset::to_string().
// Bitset - std::bitset corresponding to uint64 bits and to reverse(stream).
// Converts number of bits to a respective number of chunks of size N.
// For example NumBitsToNumWords<8> returns how many bytes are needed to store
// |num_bits|.
template <size_t N>
inline size_t NumBitsToNumWords(size_t num_bits) {
return (num_bits + (N - 1)) / N;
}
// Returns value of the same type as |in|, where all but the first |num_bits|
// are set to zero.
template <typename T>
inline T GetLowerBits(T in, size_t num_bits) {
return sizeof(T) * 8 == num_bits ? in : in & T((T(1) << num_bits) - T(1));
}
// Encodes signed integer as unsigned. This is a generalized version of
// EncodeZigZag, designed to favor small positive numbers.
// Values are transformed in blocks of 2^|block_exponent|.
// If |block_exponent| is zero, then this degenerates into normal EncodeZigZag.
// Example when |block_exponent| is 1 (return value is the index):
// 0, 1, -1, -2, 2, 3, -3, -4, 4, 5, -5, -6, 6, 7, -7, -8
// Example when |block_exponent| is 2:
// 0, 1, 2, 3, -1, -2, -3, -4, 4, 5, 6, 7, -5, -6, -7, -8
inline uint64_t EncodeZigZag(int64_t val, size_t block_exponent) {
assert(block_exponent < 64);
const uint64_t uval = static_cast<uint64_t>(val >= 0 ? val : -val - 1);
const uint64_t block_num =
((uval >> block_exponent) << 1) + (val >= 0 ? 0 : 1);
const uint64_t pos = GetLowerBits(uval, block_exponent);
return (block_num << block_exponent) + pos;
}
// Decodes signed integer encoded with EncodeZigZag. |block_exponent| must be
// the same.
inline int64_t DecodeZigZag(uint64_t val, size_t block_exponent) {
assert(block_exponent < 64);
const uint64_t block_num = val >> block_exponent;
const uint64_t pos = GetLowerBits(val, block_exponent);
if (block_num & 1) {
// Negative.
return -1LL - ((block_num >> 1) << block_exponent) - pos;
} else {
// Positive.
return ((block_num >> 1) << block_exponent) + pos;
}
}
// Converts first |num_bits| stored in uint64 to a left-to-right stream of bits.
inline std::string BitsToStream(uint64_t bits, size_t num_bits = 64) {
std::bitset<64> bitset(bits);
std::string str = bitset.to_string().substr(64 - num_bits);
std::reverse(str.begin(), str.end());
return str;
}
// Base class for writing sequences of bits.
class BitWriterInterface {
public:
BitWriterInterface() = default;
virtual ~BitWriterInterface() = default;
// Writes lower |num_bits| in |bits| to the stream.
// |num_bits| must be no greater than 64.
virtual void WriteBits(uint64_t bits, size_t num_bits) = 0;
// Writes bits from value of type |T| to the stream. No encoding is done.
// Always writes 8 * sizeof(T) bits.
template <typename T>
void WriteUnencoded(T val) {
static_assert(sizeof(T) <= 64, "Type size too large");
uint64_t bits = 0;
memcpy(&bits, &val, sizeof(T));
WriteBits(bits, sizeof(T) * 8);
}
// Writes |val| in chunks of size |chunk_length| followed by a signal bit:
// 0 - no more chunks to follow
// 1 - more chunks to follow
// for example 255 is encoded into 1111 1 1111 0 for chunk length 4.
// The last chunk can be truncated and signal bit omitted, if the entire
// payload (for example 16 bit for uint16_t has already been written).
void WriteVariableWidthU64(uint64_t val, size_t chunk_length);
void WriteVariableWidthU32(uint32_t val, size_t chunk_length);
void WriteVariableWidthU16(uint16_t val, size_t chunk_length);
void WriteVariableWidthS64(int64_t val, size_t chunk_length,
size_t zigzag_exponent);
// Returns number of bits written.
virtual size_t GetNumBits() const = 0;
// Provides direct access to the buffer data if implemented.
virtual const uint8_t* GetData() const { return nullptr; }
// Returns buffer size in bytes.
size_t GetDataSizeBytes() const { return NumBitsToNumWords<8>(GetNumBits()); }
// Generates and returns byte array containing written bits.
virtual std::vector<uint8_t> GetDataCopy() const = 0;
BitWriterInterface(const BitWriterInterface&) = delete;
BitWriterInterface& operator=(const BitWriterInterface&) = delete;
};
// This class is an implementation of BitWriterInterface, using
// std::vector<uint64_t> to store written bits.
class BitWriterWord64 : public BitWriterInterface {
public:
explicit BitWriterWord64(size_t reserve_bits = 64);
void WriteBits(uint64_t bits, size_t num_bits) override;
size_t GetNumBits() const override { return end_; }
const uint8_t* GetData() const override {
return reinterpret_cast<const uint8_t*>(buffer_.data());
}
std::vector<uint8_t> GetDataCopy() const override {
return std::vector<uint8_t>(GetData(), GetData() + GetDataSizeBytes());
}
// Sets callback to emit bit sequences after every write.
void SetCallback(std::function<void(const std::string&)> callback) {
callback_ = callback;
}
protected:
// Sends string generated from arguments to callback_ if defined.
void EmitSequence(uint64_t bits, size_t num_bits) const {
if (callback_) callback_(BitsToStream(bits, num_bits));
}
private:
std::vector<uint64_t> buffer_;
// Total number of bits written so far. Named 'end' as analogy to std::end().
size_t end_;
// If not null, the writer will use the callback to emit the written bit
// sequence as a string of '0' and '1'.
std::function<void(const std::string&)> callback_;
};
// Base class for reading sequences of bits.
class BitReaderInterface {
public:
BitReaderInterface() {}
virtual ~BitReaderInterface() {}
// Reads |num_bits| from the stream, stores them in |bits|.
// Returns number of read bits. |num_bits| must be no greater than 64.
virtual size_t ReadBits(uint64_t* bits, size_t num_bits) = 0;
// Reads 8 * sizeof(T) bits and stores them in |val|.
template <typename T>
bool ReadUnencoded(T* val) {
static_assert(sizeof(T) <= 64, "Type size too large");
uint64_t bits = 0;
const size_t num_read = ReadBits(&bits, sizeof(T) * 8);
if (num_read != sizeof(T) * 8) return false;
memcpy(val, &bits, sizeof(T));
return true;
}
// Returns number of bits already read.
virtual size_t GetNumReadBits() const = 0;
// These two functions define 'hard' and 'soft' EOF.
//
// Returns true if the end of the buffer was reached.
virtual bool ReachedEnd() const = 0;
// Returns true if we reached the end of the buffer or are nearing it and only
// zero bits are left to read. Implementations of this function are allowed to
// commit a "false negative" error if the end of the buffer was not reached,
// i.e. it can return false even if indeed only zeroes are left.
// It is assumed that the consumer expects that
// the buffer stream ends with padding zeroes, and would accept this as a
// 'soft' EOF. Implementations of this class do not necessarily need to
// implement this, default behavior can simply delegate to ReachedEnd().
virtual bool OnlyZeroesLeft() const { return ReachedEnd(); }
// Reads value encoded with WriteVariableWidthXXX (see BitWriterInterface).
// Reader and writer must use the same |chunk_length| and variable type.
// Returns true on success, false if the bit stream ends prematurely.
bool ReadVariableWidthU64(uint64_t* val, size_t chunk_length);
bool ReadVariableWidthU32(uint32_t* val, size_t chunk_length);
bool ReadVariableWidthU16(uint16_t* val, size_t chunk_length);
bool ReadVariableWidthS64(int64_t* val, size_t chunk_length,
size_t zigzag_exponent);
BitReaderInterface(const BitReaderInterface&) = delete;
BitReaderInterface& operator=(const BitReaderInterface&) = delete;
};
// This class is an implementation of BitReaderInterface which accepts both
// uint8_t and uint64_t buffers as input. uint64_t buffers are consumed and
// owned. uint8_t buffers are copied.
class BitReaderWord64 : public BitReaderInterface {
public:
// Consumes and owns the buffer.
explicit BitReaderWord64(std::vector<uint64_t>&& buffer);
// Copies the buffer and casts it to uint64.
// Consuming the original buffer and casting it to uint64 is difficult,
// as it would potentially cause data misalignment and poor performance.
explicit BitReaderWord64(const std::vector<uint8_t>& buffer);
BitReaderWord64(const void* buffer, size_t num_bytes);
size_t ReadBits(uint64_t* bits, size_t num_bits) override;
size_t GetNumReadBits() const override { return pos_; }
bool ReachedEnd() const override;
bool OnlyZeroesLeft() const override;
BitReaderWord64() = delete;
// Sets callback to emit bit sequences after every read.
void SetCallback(std::function<void(const std::string&)> callback) {
callback_ = callback;
}
protected:
// Sends string generated from arguments to callback_ if defined.
void EmitSequence(uint64_t bits, size_t num_bits) const {
if (callback_) callback_(BitsToStream(bits, num_bits));
}
private:
const std::vector<uint64_t> buffer_;
size_t pos_;
// If not null, the reader will use the callback to emit the read bit
// sequence as a string of '0' and '1'.
std::function<void(const std::string&)> callback_;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_BIT_STREAM_H_

View File

@ -1,389 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Contains utils for reading, writing and debug printing bit streams.
#ifndef SOURCE_COMP_HUFFMAN_CODEC_H_
#define SOURCE_COMP_HUFFMAN_CODEC_H_
#include <algorithm>
#include <cassert>
#include <functional>
#include <iomanip>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
namespace spvtools {
namespace comp {
// Used to generate and apply a Huffman coding scheme.
// |Val| is the type of variable being encoded (for example a string or a
// literal).
template <class Val>
class HuffmanCodec {
public:
// Huffman tree node.
struct Node {
Node() {}
// Creates Node from serialization leaving weight and id undefined.
Node(const Val& in_value, uint32_t in_left, uint32_t in_right)
: value(in_value), left(in_left), right(in_right) {}
Val value = Val();
uint32_t weight = 0;
// Ids are issued sequentially starting from 1. Ids are used as an ordering
// tie-breaker, to make sure that the ordering (and resulting coding scheme)
// are consistent accross multiple platforms.
uint32_t id = 0;
// Handles of children.
uint32_t left = 0;
uint32_t right = 0;
};
// Creates Huffman codec from a histogramm.
// Histogramm counts must not be zero.
explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
if (hist.empty()) return;
// Heuristic estimate.
nodes_.reserve(3 * hist.size());
// Create NIL.
CreateNode();
// The queue is sorted in ascending order by weight (or by node id if
// weights are equal).
std::vector<uint32_t> queue_vector;
queue_vector.reserve(hist.size());
std::priority_queue<uint32_t, std::vector<uint32_t>,
std::function<bool(uint32_t, uint32_t)>>
queue(std::bind(&HuffmanCodec::LeftIsBigger, this,
std::placeholders::_1, std::placeholders::_2),
std::move(queue_vector));
// Put all leaves in the queue.
for (const auto& pair : hist) {
const uint32_t node = CreateNode();
MutableValueOf(node) = pair.first;
MutableWeightOf(node) = pair.second;
assert(WeightOf(node));
queue.push(node);
}
// Form the tree by combining two subtrees with the least weight,
// and pushing the root of the new tree in the queue.
while (true) {
// We push a node at the end of each iteration, so the queue is never
// supposed to be empty at this point, unless there are no leaves, but
// that case was already handled.
assert(!queue.empty());
const uint32_t right = queue.top();
queue.pop();
// If the queue is empty at this point, then the last node is
// the root of the complete Huffman tree.
if (queue.empty()) {
root_ = right;
break;
}
const uint32_t left = queue.top();
queue.pop();
// Combine left and right into a new tree and push it into the queue.
const uint32_t parent = CreateNode();
MutableWeightOf(parent) = WeightOf(right) + WeightOf(left);
MutableLeftOf(parent) = left;
MutableRightOf(parent) = right;
queue.push(parent);
}
// Traverse the tree and form encoding table.
CreateEncodingTable();
}
// Creates Huffman codec from saved tree structure.
// |nodes| is the list of nodes of the tree, nodes[0] being NIL.
// |root_handle| is the index of the root node.
HuffmanCodec(uint32_t root_handle, std::vector<Node>&& nodes) {
nodes_ = std::move(nodes);
assert(!nodes_.empty());
assert(root_handle > 0 && root_handle < nodes_.size());
assert(!LeftOf(0) && !RightOf(0));
root_ = root_handle;
// Traverse the tree and form encoding table.
CreateEncodingTable();
}
// Serializes the codec in the following text format:
// (<root_handle>, {
// {0, 0, 0},
// {val1, left1, right1},
// {val2, left2, right2},
// ...
// })
std::string SerializeToText(int indent_num_whitespaces) const {
const bool value_is_text = std::is_same<Val, std::string>::value;
const std::string indent1 = std::string(indent_num_whitespaces, ' ');
const std::string indent2 = std::string(indent_num_whitespaces + 2, ' ');
std::stringstream code;
code << "(" << root_ << ", {\n";
for (const Node& node : nodes_) {
code << indent2 << "{";
if (value_is_text) code << "\"";
code << node.value;
if (value_is_text) code << "\"";
code << ", " << node.left << ", " << node.right << "},\n";
}
code << indent1 << "})";
return code.str();
}
// Prints the Huffman tree in the following format:
// w------w------'x'
// w------'y'
// Where w stands for the weight of the node.
// Right tree branches appear above left branches. Taking the right path
// adds 1 to the code, taking the left adds 0.
void PrintTree(std::ostream& out) const { PrintTreeInternal(out, root_, 0); }
// Traverses the tree and prints the Huffman table: value, code
// and optionally node weight for every leaf.
void PrintTable(std::ostream& out, bool print_weights = true) {
std::queue<std::pair<uint32_t, std::string>> queue;
queue.emplace(root_, "");
while (!queue.empty()) {
const uint32_t node = queue.front().first;
const std::string code = queue.front().second;
queue.pop();
if (!RightOf(node) && !LeftOf(node)) {
out << ValueOf(node);
if (print_weights) out << " " << WeightOf(node);
out << " " << code << std::endl;
} else {
if (LeftOf(node)) queue.emplace(LeftOf(node), code + "0");
if (RightOf(node)) queue.emplace(RightOf(node), code + "1");
}
}
}
// Returns the Huffman table. The table was built at at construction time,
// this function just returns a const reference.
const std::unordered_map<Val, std::pair<uint64_t, size_t>>& GetEncodingTable()
const {
return encoding_table_;
}
// Encodes |val| and stores its Huffman code in the lower |num_bits| of
// |bits|. Returns false of |val| is not in the Huffman table.
bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const {
auto it = encoding_table_.find(val);
if (it == encoding_table_.end()) return false;
*bits = it->second.first;
*num_bits = it->second.second;
return true;
}
// Reads bits one-by-one using callback |read_bit| until a match is found.
// Matching value is stored in |val|. Returns false if |read_bit| terminates
// before a code was mathced.
// |read_bit| has type bool func(bool* bit). When called, the next bit is
// stored in |bit|. |read_bit| returns false if the stream terminates
// prematurely.
bool DecodeFromStream(const std::function<bool(bool*)>& read_bit,
Val* val) const {
uint32_t node = root_;
while (true) {
assert(node);
if (!RightOf(node) && !LeftOf(node)) {
*val = ValueOf(node);
return true;
}
bool go_right;
if (!read_bit(&go_right)) return false;
if (go_right)
node = RightOf(node);
else
node = LeftOf(node);
}
assert(0);
return false;
}
private:
// Returns value of the node referenced by |handle|.
Val ValueOf(uint32_t node) const { return nodes_.at(node).value; }
// Returns left child of |node|.
uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; }
// Returns right child of |node|.
uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; }
// Returns weight of |node|.
uint32_t WeightOf(uint32_t node) const { return nodes_.at(node).weight; }
// Returns id of |node|.
uint32_t IdOf(uint32_t node) const { return nodes_.at(node).id; }
// Returns mutable reference to value of |node|.
Val& MutableValueOf(uint32_t node) {
assert(node);
return nodes_.at(node).value;
}
// Returns mutable reference to handle of left child of |node|.
uint32_t& MutableLeftOf(uint32_t node) {
assert(node);
return nodes_.at(node).left;
}
// Returns mutable reference to handle of right child of |node|.
uint32_t& MutableRightOf(uint32_t node) {
assert(node);
return nodes_.at(node).right;
}
// Returns mutable reference to weight of |node|.
uint32_t& MutableWeightOf(uint32_t node) { return nodes_.at(node).weight; }
// Returns mutable reference to id of |node|.
uint32_t& MutableIdOf(uint32_t node) { return nodes_.at(node).id; }
// Returns true if |left| has bigger weight than |right|. Node ids are
// used as tie-breaker.
bool LeftIsBigger(uint32_t left, uint32_t right) const {
if (WeightOf(left) == WeightOf(right)) {
assert(IdOf(left) != IdOf(right));
return IdOf(left) > IdOf(right);
}
return WeightOf(left) > WeightOf(right);
}
// Prints subtree (helper function used by PrintTree).
void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth) const {
if (!node) return;
const size_t kTextFieldWidth = 7;
if (!RightOf(node) && !LeftOf(node)) {
out << ValueOf(node) << std::endl;
} else {
if (RightOf(node)) {
std::stringstream label;
label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
<< WeightOf(RightOf(node));
out << label.str();
PrintTreeInternal(out, RightOf(node), depth + 1);
}
if (LeftOf(node)) {
out << std::string(depth * kTextFieldWidth, ' ');
std::stringstream label;
label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
<< WeightOf(LeftOf(node));
out << label.str();
PrintTreeInternal(out, LeftOf(node), depth + 1);
}
}
}
// Traverses the Huffman tree and saves paths to the leaves as bit
// sequences to encoding_table_.
void CreateEncodingTable() {
struct Context {
Context(uint32_t in_node, uint64_t in_bits, size_t in_depth)
: node(in_node), bits(in_bits), depth(in_depth) {}
uint32_t node;
// Huffman tree depth cannot exceed 64 as histogramm counts are expected
// to be positive and limited by numeric_limits<uint32_t>::max().
// For practical applications tree depth would be much smaller than 64.
uint64_t bits;
size_t depth;
};
std::queue<Context> queue;
queue.emplace(root_, 0, 0);
while (!queue.empty()) {
const Context& context = queue.front();
const uint32_t node = context.node;
const uint64_t bits = context.bits;
const size_t depth = context.depth;
queue.pop();
if (!RightOf(node) && !LeftOf(node)) {
auto insertion_result = encoding_table_.emplace(
ValueOf(node), std::pair<uint64_t, size_t>(bits, depth));
assert(insertion_result.second);
(void)insertion_result;
} else {
if (LeftOf(node)) queue.emplace(LeftOf(node), bits, depth + 1);
if (RightOf(node))
queue.emplace(RightOf(node), bits | (1ULL << depth), depth + 1);
}
}
}
// Creates new Huffman tree node and stores it in the deleter array.
uint32_t CreateNode() {
const uint32_t handle = static_cast<uint32_t>(nodes_.size());
nodes_.emplace_back(Node());
nodes_.back().id = next_node_id_++;
return handle;
}
// Huffman tree root handle.
uint32_t root_ = 0;
// Huffman tree deleter.
std::vector<Node> nodes_;
// Encoding table value -> {bits, num_bits}.
// Huffman codes are expected to never exceed 64 bit length (this is in fact
// impossible if frequencies are stored as uint32_t).
std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;
// Next node id issued by CreateNode();
uint32_t next_node_id_ = 1;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_HUFFMAN_CODEC_H_

View File

@ -1,112 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/markv.h"
#include "source/comp/markv_decoder.h"
#include "source/comp/markv_encoder.h"
namespace spvtools {
namespace comp {
namespace {
spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian,
uint32_t magic, uint32_t version, uint32_t generator,
uint32_t id_bound, uint32_t schema) {
MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
return encoder->EncodeHeader(endian, magic, version, generator, id_bound,
schema);
}
spv_result_t EncodeInstruction(void* user_data,
const spv_parsed_instruction_t* inst) {
MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
return encoder->EncodeInstruction(*inst);
}
} // namespace
spv_result_t SpirvToMarkv(
spv_const_context context, const std::vector<uint32_t>& spirv,
const MarkvCodecOptions& options, const MarkvModel& markv_model,
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv) {
spv_context_t hijack_context = *context;
SetContextMessageConsumer(&hijack_context, message_consumer);
spv_validator_options validator_options =
MarkvDecoder::GetValidatorOptions(options);
if (validator_options) {
spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()};
const spv_result_t result = spvValidateWithOptions(
&hijack_context, validator_options, &spirv_binary, nullptr);
if (result != SPV_SUCCESS) return result;
}
MarkvEncoder encoder(&hijack_context, options, &markv_model);
spv_position_t position = {};
if (log_consumer || debug_consumer) {
encoder.CreateLogger(log_consumer, debug_consumer);
spv_text text = nullptr;
if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(),
SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text,
nullptr) != SPV_SUCCESS) {
return DiagnosticStream(position, hijack_context.consumer, "",
SPV_ERROR_INVALID_BINARY)
<< "Failed to disassemble SPIR-V binary.";
}
assert(text);
encoder.SetDisassembly(std::string(text->str, text->length));
spvTextDestroy(text);
}
if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(),
EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) {
return DiagnosticStream(position, hijack_context.consumer, "",
SPV_ERROR_INVALID_BINARY)
<< "Unable to encode to MARK-V.";
}
*markv = encoder.GetMarkvBinary();
return SPV_SUCCESS;
}
spv_result_t MarkvToSpirv(
spv_const_context context, const std::vector<uint8_t>& markv,
const MarkvCodecOptions& options, const MarkvModel& markv_model,
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv) {
spv_position_t position = {};
spv_context_t hijack_context = *context;
SetContextMessageConsumer(&hijack_context, message_consumer);
MarkvDecoder decoder(&hijack_context, markv, options, &markv_model);
if (log_consumer || debug_consumer)
decoder.CreateLogger(log_consumer, debug_consumer);
if (decoder.DecodeModule(spirv) != SPV_SUCCESS) {
return DiagnosticStream(position, hijack_context.consumer, "",
SPV_ERROR_INVALID_BINARY)
<< "Unable to decode MARK-V.";
}
assert(!spirv->empty());
return SPV_SUCCESS;
}
} // namespace comp
} // namespace spvtools

View File

@ -1,74 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// MARK-V is a compression format for SPIR-V binaries. It strips away
// non-essential information (such as result ids which can be regenerated) and
// uses various bit reduction techiniques to reduce the size of the binary and
// make it more similar to other compressed SPIR-V files to further improve
// compression of the dataset.
#ifndef SOURCE_COMP_MARKV_H_
#define SOURCE_COMP_MARKV_H_
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace comp {
class MarkvModel;
struct MarkvCodecOptions {
bool validate_spirv_binary = false;
};
// Debug callback. Called once per instruction.
// |words| is instruction SPIR-V words.
// |bits| is a textual representation of the MARK-V bit sequence used to encode
// the instruction (char '0' for 0, char '1' for 1).
// |comment| contains all logs generated while processing the instruction.
using MarkvDebugConsumer =
std::function<bool(const std::vector<uint32_t>& words,
const std::string& bits, const std::string& comment)>;
// Logging callback. Called often (if decoder reads a single bit, the log
// consumer will receive 1 character string with that bit).
// This callback is more suitable for continous output than MarkvDebugConsumer,
// for example if the codec crashes it would allow to pinpoint on which operand
// or bit the crash happened.
// |snippet| could be any atomic fragment of text logged by the codec. It can
// contain a paragraph of text with newlines, or can be just one character.
using MarkvLogConsumer = std::function<void(const std::string& snippet)>;
// Encodes the given SPIR-V binary to MARK-V binary.
// |log_consumer| is optional (pass MarkvLogConsumer() to disable).
// |debug_consumer| is optional (pass MarkvDebugConsumer() to disable).
spv_result_t SpirvToMarkv(
spv_const_context context, const std::vector<uint32_t>& spirv,
const MarkvCodecOptions& options, const MarkvModel& markv_model,
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv);
// Decodes a SPIR-V binary from the given MARK-V binary.
// |log_consumer| is optional (pass MarkvLogConsumer() to disable).
// |debug_consumer| is optional (pass MarkvDebugConsumer() to disable).
spv_result_t MarkvToSpirv(
spv_const_context context, const std::vector<uint8_t>& markv,
const MarkvCodecOptions& options, const MarkvModel& markv_model,
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv);
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_H_

View File

@ -1,793 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// MARK-V is a compression format for SPIR-V binaries. It strips away
// non-essential information (such as result IDs which can be regenerated) and
// uses various bit reduction techniques to reduce the size of the binary.
#include "source/comp/markv_codec.h"
#include "source/comp/markv_logger.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/latest_version_opencl_std_header.h"
#include "source/opcode.h"
#include "source/util/make_unique.h"
namespace spvtools {
namespace comp {
namespace {
// Custom hash function used to produce short descriptors.
uint32_t ShortHashU32Array(const std::vector<uint32_t>& words) {
// The hash function is a sum of hashes of each word seeded by word index.
// Knuth's multiplicative hash is used to hash the words.
const uint32_t kKnuthMulHash = 2654435761;
uint32_t val = 0;
for (uint32_t i = 0; i < words.size(); ++i) {
val += (words[i] + i + 123) * kKnuthMulHash;
}
return 1 + val % ((1 << MarkvCodec::kShortDescriptorNumBits) - 1);
}
// Returns a set of mtf rank codecs based on a plausible hand-coded
// distribution.
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
GetMtfHuffmanCodecs() {
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> codecs;
std::unique_ptr<HuffmanCodec<uint32_t>> codec;
codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
{0, 5},
{1, 40},
{2, 10},
{3, 5},
{4, 5},
{5, 5},
{6, 3},
{7, 3},
{8, 3},
{9, 3},
{MarkvCodec::kMtfRankEncodedByValueSignal, 10},
}));
codecs.emplace(kMtfAll, std::move(codec));
codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
{1, 50},
{2, 20},
{3, 5},
{4, 5},
{5, 2},
{6, 1},
{7, 1},
{8, 1},
{9, 1},
{MarkvCodec::kMtfRankEncodedByValueSignal, 10},
}));
codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));
return codecs;
}
} // namespace
const uint32_t MarkvCodec::kMarkvMagicNumber = 0x07230303;
const uint32_t MarkvCodec::kMtfSmallestRankEncodedByValue = 10;
const uint32_t MarkvCodec::kMtfRankEncodedByValueSignal =
std::numeric_limits<uint32_t>::max();
const uint32_t MarkvCodec::kShortDescriptorNumBits = 8;
const size_t MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte = 8;
MarkvCodec::MarkvCodec(spv_const_context context,
spv_validator_options validator_options,
const MarkvModel* model)
: validator_options_(validator_options),
grammar_(context),
model_(model),
short_id_descriptors_(ShortHashU32Array),
mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
context_(context) {}
MarkvCodec::~MarkvCodec() { spvValidatorOptionsDestroy(validator_options_); }
MarkvCodec::MarkvHeader::MarkvHeader()
: magic_number(MarkvCodec::kMarkvMagicNumber),
markv_version(MarkvCodec::GetMarkvVersion()) {}
// Defines and returns current MARK-V version.
// static
uint32_t MarkvCodec::GetMarkvVersion() {
const uint32_t kVersionMajor = 1;
const uint32_t kVersionMinor = 4;
return kVersionMinor | (kVersionMajor << 16);
}
size_t MarkvCodec::GetNumBitsToNextByte(size_t bit_pos) const {
return (8 - (bit_pos % 8)) % 8;
}
// Returns true if the opcode has a fixed number of operands. May return a
// false negative.
bool MarkvCodec::OpcodeHasFixedNumberOfOperands(SpvOp opcode) const {
switch (opcode) {
// TODO(atgoo@github.com) This is not a complete list.
case SpvOpNop:
case SpvOpName:
case SpvOpUndef:
case SpvOpSizeOf:
case SpvOpLine:
case SpvOpNoLine:
case SpvOpDecorationGroup:
case SpvOpExtension:
case SpvOpExtInstImport:
case SpvOpMemoryModel:
case SpvOpCapability:
case SpvOpTypeVoid:
case SpvOpTypeBool:
case SpvOpTypeInt:
case SpvOpTypeFloat:
case SpvOpTypeVector:
case SpvOpTypeMatrix:
case SpvOpTypeSampler:
case SpvOpTypeSampledImage:
case SpvOpTypeArray:
case SpvOpTypePointer:
case SpvOpConstantTrue:
case SpvOpConstantFalse:
case SpvOpLabel:
case SpvOpBranch:
case SpvOpFunction:
case SpvOpFunctionParameter:
case SpvOpFunctionEnd:
case SpvOpBitcast:
case SpvOpCopyObject:
case SpvOpTranspose:
case SpvOpSNegate:
case SpvOpFNegate:
case SpvOpIAdd:
case SpvOpFAdd:
case SpvOpISub:
case SpvOpFSub:
case SpvOpIMul:
case SpvOpFMul:
case SpvOpUDiv:
case SpvOpSDiv:
case SpvOpFDiv:
case SpvOpUMod:
case SpvOpSRem:
case SpvOpSMod:
case SpvOpFRem:
case SpvOpFMod:
case SpvOpVectorTimesScalar:
case SpvOpMatrixTimesScalar:
case SpvOpVectorTimesMatrix:
case SpvOpMatrixTimesVector:
case SpvOpMatrixTimesMatrix:
case SpvOpOuterProduct:
case SpvOpDot:
return true;
default:
break;
}
return false;
}
void MarkvCodec::ProcessCurInstruction() {
instructions_.emplace_back(new val::Instruction(&inst_));
const SpvOp opcode = SpvOp(inst_.opcode);
if (inst_.result_id) {
id_to_def_instruction_.emplace(inst_.result_id, instructions_.back().get());
// Collect ids local to the current function.
if (cur_function_id_) {
ids_local_to_cur_function_.push_back(inst_.result_id);
}
// Starting new function.
if (opcode == SpvOpFunction) {
cur_function_id_ = inst_.result_id;
cur_function_return_type_ = inst_.type_id;
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
inst_.result_id);
}
// Store function parameter types in a queue, so that we know which types
// to expect in the following OpFunctionParameter instructions.
const val::Instruction* def_inst = FindDef(inst_.words[4]);
assert(def_inst);
assert(def_inst->opcode() == SpvOpTypeFunction);
for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
remaining_function_parameter_types_.push_back(def_inst->word(i));
}
}
}
// Remove local ids from MTFs if function end.
if (opcode == SpvOpFunctionEnd) {
cur_function_id_ = 0;
for (uint32_t id : ids_local_to_cur_function_) multi_mtf_.RemoveFromAll(id);
ids_local_to_cur_function_.clear();
assert(remaining_function_parameter_types_.empty());
}
if (!inst_.result_id) return;
{
// Save the result ID to type ID mapping.
// In the grammar, type ID always appears before result ID.
// A regular value maps to its type. Some instructions (e.g. OpLabel)
// have no type Id, and will map to 0. The result Id for a
// type-generating instruction (e.g. OpTypeInt) maps to itself.
auto insertion_result = id_to_type_id_.emplace(
inst_.result_id, spvOpcodeGeneratesType(SpvOp(inst_.opcode))
? inst_.result_id
: inst_.type_id);
(void)insertion_result;
assert(insertion_result.second);
}
// Add result_id to MTFs.
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
switch (opcode) {
case SpvOpTypeFloat:
case SpvOpTypeInt:
case SpvOpTypeBool:
case SpvOpTypeVector:
case SpvOpTypePointer:
case SpvOpExtInstImport:
case SpvOpTypeSampledImage:
case SpvOpTypeImage:
case SpvOpTypeSampler:
multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id);
break;
default:
break;
}
if (spvOpcodeIsComposite(opcode)) {
multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id);
}
if (opcode == SpvOpLabel) {
multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id);
}
if (opcode == SpvOpTypeInt) {
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
}
if (opcode == SpvOpTypeFloat) {
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
}
if (opcode == SpvOpTypeBool) {
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
}
if (opcode == SpvOpTypeVector) {
const uint32_t component_type_id = inst_.words[2];
const uint32_t size = inst_.words[3];
if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat),
component_type_id)) {
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt),
component_type_id)) {
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool),
component_type_id)) {
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
}
multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id);
}
if (inst_.opcode == SpvOpTypeFunction) {
const uint32_t return_type = inst_.words[2];
multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type);
multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type),
inst_.result_id);
}
if (inst_.type_id) {
const val::Instruction* type_inst = FindDef(inst_.type_id);
assert(type_inst);
multi_mtf_.Insert(kMtfObject, inst_.result_id);
multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) {
multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id);
}
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id))
multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id))
multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id))
multi_mtf_.Insert(kMtfComposite, inst_.result_id);
switch (type_inst->opcode()) {
case SpvOpTypeInt:
case SpvOpTypeBool:
case SpvOpTypePointer:
case SpvOpTypeVector:
case SpvOpTypeImage:
case SpvOpTypeSampledImage:
case SpvOpTypeSampler:
multi_mtf_.Insert(
GetMtfIdWithTypeGeneratedByOpcode(type_inst->opcode()),
inst_.result_id);
break;
default:
break;
}
if (type_inst->opcode() == SpvOpTypeVector) {
const uint32_t component_type = type_inst->word(2);
multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type),
inst_.result_id);
}
if (type_inst->opcode() == SpvOpTypePointer) {
assert(type_inst->operands().size() > 2);
assert(type_inst->words().size() > type_inst->operands()[2].offset);
const uint32_t data_type =
type_inst->word(type_inst->operands()[2].offset);
multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeComposite, data_type))
multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id);
}
}
if (spvOpcodeGeneratesType(opcode)) {
if (opcode != SpvOpTypeFunction) {
multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id);
}
}
}
if (model_->AnyDescriptorHasCodingScheme()) {
const uint32_t long_descriptor =
long_id_descriptors_.ProcessInstruction(inst_);
if (model_->DescriptorHasCodingScheme(long_descriptor))
multi_mtf_.Insert(GetMtfLongIdDescriptor(long_descriptor),
inst_.result_id);
}
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
const uint32_t short_descriptor =
short_id_descriptors_.ProcessInstruction(inst_);
multi_mtf_.Insert(GetMtfShortIdDescriptor(short_descriptor),
inst_.result_id);
}
}
uint64_t MarkvCodec::GetRuleBasedMtf() {
// This function is only called for id operands (but not result ids).
assert(spvIsIdType(operand_.type) ||
operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID);
assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID);
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
// All operand slots which expect label id.
if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) ||
(inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) ||
(inst_.opcode == SpvOpBranch && operand_index_ == 0) ||
(inst_.opcode == SpvOpBranchConditional &&
(operand_index_ == 1 || operand_index_ == 2)) ||
(inst_.opcode == SpvOpPhi && operand_index_ >= 3 &&
operand_index_ % 2 == 1) ||
(inst_.opcode == SpvOpSwitch && operand_index_ > 0)) {
return kMtfLabel;
}
switch (opcode) {
case SpvOpFAdd:
case SpvOpFSub:
case SpvOpFMul:
case SpvOpFDiv:
case SpvOpFRem:
case SpvOpFMod:
case SpvOpFNegate: {
if (operand_index_ == 0) return kMtfTypeFloatScalarOrVector;
return GetMtfIdOfType(inst_.type_id);
}
case SpvOpISub:
case SpvOpIAdd:
case SpvOpIMul:
case SpvOpSDiv:
case SpvOpUDiv:
case SpvOpSMod:
case SpvOpUMod:
case SpvOpSRem:
case SpvOpSNegate: {
if (operand_index_ == 0) return kMtfTypeIntScalarOrVector;
return kMtfIntScalarOrVector;
}
// TODO(atgoo@github.com) Add OpConvertFToU and other opcodes.
case SpvOpFOrdEqual:
case SpvOpFUnordEqual:
case SpvOpFOrdNotEqual:
case SpvOpFUnordNotEqual:
case SpvOpFOrdLessThan:
case SpvOpFUnordLessThan:
case SpvOpFOrdGreaterThan:
case SpvOpFUnordGreaterThan:
case SpvOpFOrdLessThanEqual:
case SpvOpFUnordLessThanEqual:
case SpvOpFOrdGreaterThanEqual:
case SpvOpFUnordGreaterThanEqual: {
if (operand_index_ == 0) return kMtfTypeBoolScalarOrVector;
if (operand_index_ == 2) return kMtfFloatScalarOrVector;
if (operand_index_ == 3) {
const uint32_t first_operand_id = GetInstWords()[3];
const uint32_t first_operand_type = id_to_type_id_.at(first_operand_id);
return GetMtfIdOfType(first_operand_type);
}
break;
}
case SpvOpVectorShuffle: {
if (operand_index_ == 0) {
assert(inst_.num_operands > 4);
return GetMtfTypeVectorOfSize(inst_.num_operands - 4);
}
assert(inst_.type_id);
if (operand_index_ == 2 || operand_index_ == 3)
return GetMtfVectorOfComponentType(
GetVectorComponentType(inst_.type_id));
break;
}
case SpvOpVectorTimesScalar: {
if (operand_index_ == 0) {
// TODO(atgoo@github.com) Could be narrowed to vector of floats.
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
}
assert(inst_.type_id);
if (operand_index_ == 2) return GetMtfIdOfType(inst_.type_id);
if (operand_index_ == 3)
return GetMtfIdOfType(GetVectorComponentType(inst_.type_id));
break;
}
case SpvOpDot: {
if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat);
assert(inst_.type_id);
if (operand_index_ == 2)
return GetMtfVectorOfComponentType(inst_.type_id);
if (operand_index_ == 3) {
const uint32_t vector_id = GetInstWords()[3];
const uint32_t vector_type = id_to_type_id_.at(vector_id);
return GetMtfIdOfType(vector_type);
}
break;
}
case SpvOpTypeVector: {
if (operand_index_ == 1) {
return kMtfTypeScalar;
}
break;
}
case SpvOpTypeMatrix: {
if (operand_index_ == 1) {
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
}
break;
}
case SpvOpTypePointer: {
if (operand_index_ == 2) {
return kMtfTypeNonFunction;
}
break;
}
case SpvOpTypeStruct: {
if (operand_index_ >= 1) {
return kMtfTypeNonFunction;
}
break;
}
case SpvOpTypeFunction: {
if (operand_index_ == 1) {
return kMtfTypeNonFunction;
}
if (operand_index_ >= 2) {
return kMtfTypeNonFunction;
}
break;
}
case SpvOpLoad: {
if (operand_index_ == 0) return kMtfTypeNonFunction;
if (operand_index_ == 2) {
assert(inst_.type_id);
return GetMtfPointerToType(inst_.type_id);
}
break;
}
case SpvOpStore: {
if (operand_index_ == 0)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer);
if (operand_index_ == 1) {
const uint32_t pointer_id = GetInstWords()[1];
const uint32_t pointer_type = id_to_type_id_.at(pointer_id);
const val::Instruction* pointer_inst = FindDef(pointer_type);
assert(pointer_inst);
assert(pointer_inst->opcode() == SpvOpTypePointer);
const uint32_t data_type =
pointer_inst->word(pointer_inst->operands()[2].offset);
return GetMtfIdOfType(data_type);
}
break;
}
case SpvOpVariable: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
break;
}
case SpvOpAccessChain: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
if (operand_index_ == 2) return kMtfTypePointerToComposite;
if (operand_index_ >= 3)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt);
break;
}
case SpvOpCompositeConstruct: {
if (operand_index_ == 0) return kMtfTypeComposite;
if (operand_index_ >= 2) {
const uint32_t composite_type = GetInstWords()[1];
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type))
return kMtfFloatScalarOrVector;
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type))
return kMtfIntScalarOrVector;
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type))
return kMtfBoolScalarOrVector;
}
break;
}
case SpvOpCompositeExtract: {
if (operand_index_ == 2) return kMtfComposite;
break;
}
case SpvOpConstantComposite: {
if (operand_index_ == 0) return kMtfTypeComposite;
if (operand_index_ >= 2) {
const val::Instruction* composite_type_inst = FindDef(inst_.type_id);
assert(composite_type_inst);
if (composite_type_inst->opcode() == SpvOpTypeVector) {
return GetMtfIdOfType(composite_type_inst->word(2));
}
}
break;
}
case SpvOpExtInst: {
if (operand_index_ == 2)
return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport);
if (operand_index_ >= 4) {
const uint32_t return_type = GetInstWords()[1];
const uint32_t ext_inst_type = inst_.ext_inst_type;
const uint32_t ext_inst_index = GetInstWords()[4];
// TODO(atgoo@github.com) The list of extended instructions is
// incomplete. Only common instructions and low-hanging fruits listed.
if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) {
switch (ext_inst_index) {
case GLSLstd450FAbs:
case GLSLstd450FClamp:
case GLSLstd450FMax:
case GLSLstd450FMin:
case GLSLstd450FMix:
case GLSLstd450Step:
case GLSLstd450SmoothStep:
case GLSLstd450Fma:
case GLSLstd450Pow:
case GLSLstd450Exp:
case GLSLstd450Exp2:
case GLSLstd450Log:
case GLSLstd450Log2:
case GLSLstd450Sqrt:
case GLSLstd450InverseSqrt:
case GLSLstd450Fract:
case GLSLstd450Floor:
case GLSLstd450Ceil:
case GLSLstd450Radians:
case GLSLstd450Degrees:
case GLSLstd450Sin:
case GLSLstd450Cos:
case GLSLstd450Tan:
case GLSLstd450Sinh:
case GLSLstd450Cosh:
case GLSLstd450Tanh:
case GLSLstd450Asin:
case GLSLstd450Acos:
case GLSLstd450Atan:
case GLSLstd450Atan2:
case GLSLstd450Asinh:
case GLSLstd450Acosh:
case GLSLstd450Atanh:
case GLSLstd450MatrixInverse:
case GLSLstd450Cross:
case GLSLstd450Normalize:
case GLSLstd450Reflect:
case GLSLstd450FaceForward:
return GetMtfIdOfType(return_type);
case GLSLstd450Length:
case GLSLstd450Distance:
case GLSLstd450Refract:
return kMtfFloatScalarOrVector;
default:
break;
}
} else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) {
switch (ext_inst_index) {
case OpenCLLIB::Fabs:
case OpenCLLIB::FClamp:
case OpenCLLIB::Fmax:
case OpenCLLIB::Fmin:
case OpenCLLIB::Step:
case OpenCLLIB::Smoothstep:
case OpenCLLIB::Fma:
case OpenCLLIB::Pow:
case OpenCLLIB::Exp:
case OpenCLLIB::Exp2:
case OpenCLLIB::Log:
case OpenCLLIB::Log2:
case OpenCLLIB::Sqrt:
case OpenCLLIB::Rsqrt:
case OpenCLLIB::Fract:
case OpenCLLIB::Floor:
case OpenCLLIB::Ceil:
case OpenCLLIB::Radians:
case OpenCLLIB::Degrees:
case OpenCLLIB::Sin:
case OpenCLLIB::Cos:
case OpenCLLIB::Tan:
case OpenCLLIB::Sinh:
case OpenCLLIB::Cosh:
case OpenCLLIB::Tanh:
case OpenCLLIB::Asin:
case OpenCLLIB::Acos:
case OpenCLLIB::Atan:
case OpenCLLIB::Atan2:
case OpenCLLIB::Asinh:
case OpenCLLIB::Acosh:
case OpenCLLIB::Atanh:
case OpenCLLIB::Cross:
case OpenCLLIB::Normalize:
return GetMtfIdOfType(return_type);
case OpenCLLIB::Length:
case OpenCLLIB::Distance:
return kMtfFloatScalarOrVector;
default:
break;
}
}
}
break;
}
case SpvOpFunction: {
if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
if (operand_index_ == 3) {
const uint32_t return_type = GetInstWords()[1];
return GetMtfFunctionTypeWithReturnType(return_type);
}
break;
}
case SpvOpFunctionCall: {
if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
if (operand_index_ == 2) {
const uint32_t return_type = GetInstWords()[1];
return GetMtfFunctionWithReturnType(return_type);
}
if (operand_index_ >= 3) {
const uint32_t function_id = GetInstWords()[3];
const val::Instruction* function_inst = FindDef(function_id);
if (!function_inst) return kMtfObject;
assert(function_inst->opcode() == SpvOpFunction);
const uint32_t function_type_id = function_inst->word(4);
const val::Instruction* function_type_inst = FindDef(function_type_id);
assert(function_type_inst);
assert(function_type_inst->opcode() == SpvOpTypeFunction);
const uint32_t argument_type = function_type_inst->word(operand_index_);
return GetMtfIdOfType(argument_type);
}
break;
}
case SpvOpReturnValue: {
if (operand_index_ == 0) return GetMtfIdOfType(cur_function_return_type_);
break;
}
case SpvOpBranchConditional: {
if (operand_index_ == 0)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool);
break;
}
case SpvOpSampledImage: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage);
if (operand_index_ == 2)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage);
if (operand_index_ == 3)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler);
break;
}
case SpvOpImageSampleImplicitLod: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
if (operand_index_ == 2)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage);
if (operand_index_ == 3)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector);
break;
}
default:
break;
}
return kMtfNone;
}
} // namespace comp
} // namespace spvtools

View File

@ -1,337 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_COMP_MARKV_CODEC_H_
#define SOURCE_COMP_MARKV_CODEC_H_
#include <list>
#include <map>
#include <memory>
#include <vector>
#include "source/assembly_grammar.h"
#include "source/comp/huffman_codec.h"
#include "source/comp/markv_model.h"
#include "source/comp/move_to_front.h"
#include "source/diagnostic.h"
#include "source/id_descriptor.h"
#include "source/val/instruction.h"
// Base class for MARK-V encoder and decoder. Contains common functionality
// such as:
// - Validator connection and validation state.
// - SPIR-V grammar and helper functions.
namespace spvtools {
namespace comp {
class MarkvLogger;
// Handles for move-to-front sequences. Enums which end with "Begin" define
// handle spaces which start at that value and span 16 or 32 bit wide.
enum : uint64_t {
kMtfNone = 0,
// All ids.
kMtfAll,
// All forward declared ids.
kMtfForwardDeclared,
// All type ids except for generated by OpTypeFunction.
kMtfTypeNonFunction,
// All labels.
kMtfLabel,
// All ids created by instructions which had type_id.
kMtfObject,
// All types generated by OpTypeFloat, OpTypeInt, OpTypeBool.
kMtfTypeScalar,
// All composite types.
kMtfTypeComposite,
// Boolean type or any vector type of it.
kMtfTypeBoolScalarOrVector,
// All float types or any vector floats type.
kMtfTypeFloatScalarOrVector,
// All int types or any vector int type.
kMtfTypeIntScalarOrVector,
// All types declared as return types in OpTypeFunction.
kMtfTypeReturnedByFunction,
// All composite objects.
kMtfComposite,
// All bool objects or vectors of bools.
kMtfBoolScalarOrVector,
// All float objects or vectors of float.
kMtfFloatScalarOrVector,
// All int objects or vectors of int.
kMtfIntScalarOrVector,
// All pointer types which point to composited.
kMtfTypePointerToComposite,
// Used by EncodeMtfRankHuffman.
kMtfGenericNonZeroRank,
// Handle space for ids of specific type.
kMtfIdOfTypeBegin = 0x10000,
// Handle space for ids generated by specific opcode.
kMtfIdGeneratedByOpcode = 0x20000,
// Handle space for ids of objects with type generated by specific opcode.
kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000,
// All vectors of specific component type.
kMtfVectorOfComponentTypeBegin = 0x40000,
// All vector types of specific size.
kMtfTypeVectorOfSizeBegin = 0x50000,
// All pointer types to specific type.
kMtfPointerToTypeBegin = 0x60000,
// All function types which return specific type.
kMtfFunctionTypeWithReturnTypeBegin = 0x70000,
// All function objects which return specific type.
kMtfFunctionWithReturnTypeBegin = 0x80000,
// Short id descriptor space (max 16-bit).
kMtfShortIdDescriptorSpaceBegin = 0x90000,
// Long id descriptor space (32-bit).
kMtfLongIdDescriptorSpaceBegin = 0x100000000,
};
class MarkvCodec {
public:
static const uint32_t kMarkvMagicNumber;
// Mtf ranks smaller than this are encoded with Huffman coding.
static const uint32_t kMtfSmallestRankEncodedByValue;
// Signals that the mtf rank is too large to be encoded with Huffman.
static const uint32_t kMtfRankEncodedByValueSignal;
static const uint32_t kShortDescriptorNumBits;
static const size_t kByteBreakAfterInstIfLessThanUntilNextByte;
static uint32_t GetMarkvVersion();
virtual ~MarkvCodec();
protected:
struct MarkvHeader {
MarkvHeader();
uint32_t magic_number;
uint32_t markv_version;
// Magic number to identify or verify MarkvModel used for encoding.
uint32_t markv_model = 0;
uint32_t markv_length_in_bits = 0;
uint32_t spirv_version = 0;
uint32_t spirv_generator = 0;
};
// |model| is owned by the caller, must be not null and valid during the
// lifetime of the codec.
MarkvCodec(spv_const_context context, spv_validator_options validator_options,
const MarkvModel* model);
// Returns instruction which created |id| or nullptr if such instruction was
// not registered.
const val::Instruction* FindDef(uint32_t id) const {
const auto it = id_to_def_instruction_.find(id);
if (it == id_to_def_instruction_.end()) return nullptr;
return it->second;
}
size_t GetNumBitsToNextByte(size_t bit_pos) const;
bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) const;
// Returns type id of vector type component.
uint32_t GetVectorComponentType(uint32_t vector_type_id) const {
const val::Instruction* type_inst = FindDef(vector_type_id);
assert(type_inst);
assert(type_inst->opcode() == SpvOpTypeVector);
const uint32_t component_type =
type_inst->word(type_inst->operands()[1].offset);
return component_type;
}
// Returns mtf handle for ids of given type.
uint64_t GetMtfIdOfType(uint32_t type_id) const {
return kMtfIdOfTypeBegin + type_id;
}
// Returns mtf handle for ids generated by given opcode.
uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const {
return kMtfIdGeneratedByOpcode + opcode;
}
// Returns mtf handle for ids of type generated by given opcode.
uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const {
return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode;
}
// Returns mtf handle for vectors of specific component type.
uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const {
return kMtfVectorOfComponentTypeBegin + type_id;
}
// Returns mtf handle for vector type of specific size.
uint64_t GetMtfTypeVectorOfSize(uint32_t size) const {
return kMtfTypeVectorOfSizeBegin + size;
}
// Returns mtf handle for pointers to specific size.
uint64_t GetMtfPointerToType(uint32_t type_id) const {
return kMtfPointerToTypeBegin + type_id;
}
// Returns mtf handle for function types with given return type.
uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const {
return kMtfFunctionTypeWithReturnTypeBegin + type_id;
}
// Returns mtf handle for functions with given return type.
uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const {
return kMtfFunctionWithReturnTypeBegin + type_id;
}
// Returns mtf handle for the given long id descriptor.
uint64_t GetMtfLongIdDescriptor(uint32_t descriptor) const {
return kMtfLongIdDescriptorSpaceBegin + descriptor;
}
// Returns mtf handle for the given short id descriptor.
uint64_t GetMtfShortIdDescriptor(uint32_t descriptor) const {
return kMtfShortIdDescriptorSpaceBegin + descriptor;
}
// Process data from the current instruction. This would update MTFs and
// other data containers.
void ProcessCurInstruction();
// Returns move-to-front handle to be used for the current operand slot.
// Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
uint64_t GetRuleBasedMtf();
// Returns words of the current instruction. Decoder has a different
// implementation and the array is valid only until the previously decoded
// word.
virtual const uint32_t* GetInstWords() const { return inst_.words; }
// Returns the opcode of the previous instruction.
SpvOp GetPrevOpcode() const {
if (instructions_.empty()) return SpvOpNop;
return instructions_.back()->opcode();
}
// Returns diagnostic stream, position index is set to instruction number.
DiagnosticStream Diag(spv_result_t error_code) const {
return DiagnosticStream({0, 0, instructions_.size()}, context_->consumer,
"", error_code);
}
// Returns current id bound.
uint32_t GetIdBound() const { return id_bound_; }
// Sets current id bound, expected to be no lower than the previous one.
void SetIdBound(uint32_t id_bound) {
assert(id_bound >= id_bound_);
id_bound_ = id_bound;
}
// Returns Huffman codec for ranks of the mtf with given |handle|.
// Different mtfs can use different rank distributions.
// May return nullptr if the codec doesn't exist.
const HuffmanCodec<uint32_t>* GetMtfHuffmanCodec(uint64_t handle) const {
const auto it = mtf_huffman_codecs_.find(handle);
if (it == mtf_huffman_codecs_.end()) return nullptr;
return it->second.get();
}
// Promotes id in all move-to-front sequences if ids can be shared by multiple
// sequences.
void PromoteIfNeeded(uint32_t id) {
if (!model_->AnyDescriptorHasCodingScheme() &&
model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
// Move-to-front sequences do not share ids. Nothing to do.
return;
}
multi_mtf_.Promote(id);
}
spv_validator_options validator_options_ = nullptr;
const AssemblyGrammar grammar_;
MarkvHeader header_;
// MARK-V model, not owned.
const MarkvModel* model_ = nullptr;
// Current instruction, current operand and current operand index.
spv_parsed_instruction_t inst_;
spv_parsed_operand_t operand_;
uint32_t operand_index_;
// Maps a result ID to its type ID. By convention:
// - a result ID that is a type definition maps to itself.
// - a result ID without a type maps to 0. (E.g. for OpLabel)
std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
// Container for all move-to-front sequences.
MultiMoveToFront multi_mtf_;
// Id of the current function or zero if outside of function.
uint32_t cur_function_id_ = 0;
// Return type of the current function.
uint32_t cur_function_return_type_ = 0;
// Remaining function parameter types. This container is filled on OpFunction,
// and drained on OpFunctionParameter.
std::list<uint32_t> remaining_function_parameter_types_;
// List of ids local to the current function.
std::vector<uint32_t> ids_local_to_cur_function_;
// List of instructions in the order they are given in the module.
std::vector<std::unique_ptr<const val::Instruction>> instructions_;
// Container/computer for long (32-bit) id descriptors.
IdDescriptorCollection long_id_descriptors_;
// Container/computer for short id descriptors.
// Short descriptors are stored in uint32_t, but their actual bit width is
// defined with kShortDescriptorNumBits.
// It doesn't seem logical to have a different computer for short id
// descriptors, since one could actually map/truncate long descriptors.
// But as short descriptors have collisions, the efficiency of
// compression depends on the collision pattern, and short descriptors
// produced by function ShortHashU32Array have been empirically proven to
// produce better results.
IdDescriptorCollection short_id_descriptors_;
// Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
// need to contain a different codec for every handle as most use one and the
// same.
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
mtf_huffman_codecs_;
// If not nullptr, codec will log comments on the compression process.
std::unique_ptr<MarkvLogger> logger_;
spv_const_context context_ = nullptr;
private:
// Maps result id to the instruction which defined it.
std::unordered_map<uint32_t, const val::Instruction*> id_to_def_instruction_;
uint32_t id_bound_ = 1;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_CODEC_H_

View File

@ -1,925 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/markv_decoder.h"
#include <cstring>
#include <iterator>
#include <numeric>
#include "source/ext_inst.h"
#include "source/opcode.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace comp {
spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
if (codec) {
uint64_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to decode non-id word with Huffman";
if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
// The word decoded successfully.
*word = uint32_t(decoded_value);
assert(*word == decoded_value);
return SPV_SUCCESS;
}
// Received kMarkvNoneOfTheAbove signal, use fallback decoding.
}
const size_t chunk_length =
model_->GetOperandVariableWidthChunkLength(operand_.type);
if (chunk_length) {
if (!reader_.ReadVariableWidthU32(word, chunk_length))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to decode non-id word with varint";
} else {
if (!reader_.ReadUnencoded(word))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read unencoded non-id word";
}
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
uint32_t* opcode, uint32_t* num_operands) {
// First try to use the Markov chain codec.
auto* codec =
model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
if (codec) {
uint64_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode opcode_and_num_operands, previous opcode is "
<< spvOpcodeString(GetPrevOpcode());
if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
// The word was successfully decoded.
*opcode = uint32_t(decoded_value & 0xFFFF);
*num_operands = uint32_t(decoded_value >> 16);
return SPV_SUCCESS;
}
// Received kMarkvNoneOfTheAbove signal, use fallback decoding.
}
// Fallback to base-rate codec.
codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
assert(codec);
uint64_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode opcode_and_num_operands with global codec";
if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) {
// Received kMarkvNoneOfTheAbove signal, fallback further.
return SPV_UNSUPPORTED;
}
*opcode = uint32_t(decoded_value & 0xFFFF);
*num_operands = uint32_t(decoded_value >> 16);
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
uint32_t fallback_method,
uint32_t* rank) {
const auto* codec = GetMtfHuffmanCodec(mtf);
if (!codec) {
assert(fallback_method != kMtfNone);
codec = GetMtfHuffmanCodec(fallback_method);
}
if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
uint32_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
if (decoded_value == kMtfRankEncodedByValueSignal) {
// Decode by value.
if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode MTF rank with varint";
*rank += MarkvCodec::kMtfSmallestRankEncodedByValue;
} else {
// Decode using Huffman coding.
assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue);
*rank = decoded_value;
}
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
auto* codec =
model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
uint64_t mtf = kMtfNone;
if (codec) {
uint64_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode descriptor with Huffman";
if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
const uint32_t long_descriptor = uint32_t(decoded_value);
mtf = GetMtfLongIdDescriptor(long_descriptor);
}
}
if (mtf == kMtfNone) {
if (model_->id_fallback_strategy() !=
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
return SPV_UNSUPPORTED;
}
uint64_t decoded_value = 0;
if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits))
return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
const uint32_t short_descriptor = uint32_t(decoded_value);
if (short_descriptor == 0) {
// Forward declared id.
return SPV_UNSUPPORTED;
}
mtf = GetMtfShortIdDescriptor(short_descriptor);
}
return DecodeExistingId(mtf, id);
}
spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
assert(multi_mtf_.GetSize(mtf) > 0);
*id = 0;
uint32_t rank = 0;
if (multi_mtf_.GetSize(mtf) == 1) {
rank = 1;
} else {
const spv_result_t result =
DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
if (result != SPV_SUCCESS) return result;
}
assert(rank);
if (!multi_mtf_.ValueFromRank(mtf, rank, id))
return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
{
const spv_result_t result = DecodeIdWithDescriptor(id);
if (result != SPV_UNSUPPORTED) return result;
}
const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
SpvOp(inst_.opcode))(operand_index_);
uint32_t rank = 0;
*id = 0;
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
uint64_t mtf = GetRuleBasedMtf();
if (mtf != kMtfNone && !can_forward_declare) {
return DecodeExistingId(mtf, id);
}
if (mtf == kMtfNone) mtf = kMtfAll;
{
const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
if (result != SPV_SUCCESS) return result;
}
if (rank == 0) {
// This is the first occurrence of a forward declared id.
*id = GetIdBound();
SetIdBound(*id + 1);
multi_mtf_.Insert(kMtfAll, *id);
multi_mtf_.Insert(kMtfForwardDeclared, *id);
if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
} else {
if (!multi_mtf_.ValueFromRank(mtf, rank, id))
return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
}
} else {
assert(can_forward_declare);
if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode MTF rank with varint";
if (rank == 0) {
// This is the first occurrence of a forward declared id.
*id = GetIdBound();
SetIdBound(*id + 1);
multi_mtf_.Insert(kMtfForwardDeclared, *id);
} else {
if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
}
}
assert(*id);
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeTypeId() {
if (inst_.opcode == SpvOpFunctionParameter) {
assert(!remaining_function_parameter_types_.empty());
inst_.type_id = remaining_function_parameter_types_.front();
remaining_function_parameter_types_.pop_front();
return SPV_SUCCESS;
}
{
const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
if (result != SPV_UNSUPPORTED) return result;
}
assert(model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased);
uint64_t mtf = GetRuleBasedMtf();
assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
operand_index_));
if (mtf == kMtfNone) {
mtf = kMtfTypeNonFunction;
// Function types should have been handled by GetRuleBasedMtf.
assert(inst_.opcode != SpvOpFunction);
}
return DecodeExistingId(mtf, &inst_.type_id);
}
spv_result_t MarkvDecoder::DecodeResultId() {
uint32_t rank = 0;
const uint64_t num_still_forward_declared =
multi_mtf_.GetSize(kMtfForwardDeclared);
if (num_still_forward_declared) {
// Some ids were forward declared. Check if this id is one of them.
uint64_t id_was_forward_declared;
if (!reader_.ReadBits(&id_was_forward_declared, 1))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read id_was_forward_declared flag";
if (id_was_forward_declared) {
if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read MTF rank of forward declared id";
if (rank) {
// The id was forward declared, recover it from kMtfForwardDeclared.
if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
&inst_.result_id))
return Diag(SPV_ERROR_INTERNAL)
<< "Forward declared MTF rank is out of bounds";
// We can now remove the id from kMtfForwardDeclared.
if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to remove id from kMtfForwardDeclared";
}
}
}
if (inst_.result_id == 0) {
// The id was not forward declared, issue a new id.
inst_.result_id = GetIdBound();
SetIdBound(inst_.result_id + 1);
}
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
if (!rank) {
multi_mtf_.Insert(kMtfAll, inst_.result_id);
}
}
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeLiteralNumber(
const spv_parsed_operand_t& operand) {
if (operand.number_bit_width <= 32) {
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
} else {
assert(operand.number_bit_width <= 64);
uint64_t word = 0;
if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
} else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
int64_t val = 0;
if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
model_->s64_block_exponent()))
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
std::memcpy(&word, &val, 8);
} else if (operand.number_kind == SPV_NUMBER_FLOATING) {
if (!reader_.ReadUnencoded(&word))
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
} else {
return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
}
inst_words_.push_back(static_cast<uint32_t>(word));
inst_words_.push_back(static_cast<uint32_t>(word >> 32));
}
return SPV_SUCCESS;
}
bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
const size_t num_bits_to_next_byte =
GetNumBitsToNextByte(reader_.GetNumReadBits());
if (num_bits_to_next_byte == 0 ||
num_bits_to_next_byte > byte_break_if_less_than)
return true;
uint64_t bits = 0;
if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
assert(bits == 0);
if (bits != 0) return false;
return true;
}
spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
const bool header_read_success =
reader_.ReadUnencoded(&header_.magic_number) &&
reader_.ReadUnencoded(&header_.markv_version) &&
reader_.ReadUnencoded(&header_.markv_model) &&
reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
reader_.ReadUnencoded(&header_.spirv_version) &&
reader_.ReadUnencoded(&header_.spirv_generator);
if (!header_read_success)
return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
if (header_.markv_length_in_bits == 0)
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Header markv_length_in_bits field is zero";
if (header_.magic_number != MarkvCodec::kMarkvMagicNumber)
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary has incorrect magic number";
// TODO(atgoo@github.com): Print version strings.
if (header_.markv_version != MarkvCodec::GetMarkvVersion())
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary and the codec have different versions";
const uint32_t model_type = header_.markv_model >> 16;
const uint32_t model_version = header_.markv_model & 0xFFFF;
if (model_type != model_->model_type())
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary and the codec use different MARK-V models";
if (model_version != model_->model_version())
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary and the codec use different versions if the same "
<< "MARK-V model";
spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
spirv_.resize(5, 0);
spirv_[0] = SpvMagicNumber;
spirv_[1] = header_.spirv_version;
spirv_[2] = header_.spirv_generator;
if (logger_) {
reader_.SetCallback(
[this](const std::string& str) { logger_->AppendBitSequence(str); });
}
while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
inst_ = {};
const spv_result_t decode_result = DecodeInstruction();
if (decode_result != SPV_SUCCESS) return decode_result;
}
if (validator_options_) {
spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()};
const spv_result_t result = spvValidateWithOptions(
context_, validator_options_, &validation_binary, nullptr);
if (result != SPV_SUCCESS) return result;
}
// Validate the decode binary
if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
!reader_.OnlyZeroesLeft()) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary has wrong stated bit length "
<< reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
}
// Decoding of the module is finished, validation state should have correct
// id bound.
spirv_[3] = GetIdBound();
*spirv_binary = std::move(spirv_);
return SPV_SUCCESS;
}
// TODO(atgoo@github.com): The implementation borrows heavily from
// Parser::parseOperand.
// Consider coupling them together in some way once MARK-V codec is more mature.
// For now it's better to keep the code independent for experimentation
// purposes.
spv_result_t MarkvDecoder::DecodeOperand(
size_t operand_offset, const spv_operand_type_t type,
spv_operand_pattern_t* expected_operands) {
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
memset(&operand_, 0, sizeof(operand_));
assert((operand_offset >> 16) == 0);
operand_.offset = static_cast<uint16_t>(operand_offset);
operand_.type = type;
// Set default values, may be updated later.
operand_.number_kind = SPV_NUMBER_NONE;
operand_.number_bit_width = 0;
const size_t first_word_index = inst_words_.size();
switch (type) {
case SPV_OPERAND_TYPE_RESULT_ID: {
const spv_result_t result = DecodeResultId();
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(inst_.result_id);
SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
PromoteIfNeeded(inst_.result_id);
break;
}
case SPV_OPERAND_TYPE_TYPE_ID: {
const spv_result_t result = DecodeTypeId();
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(inst_.type_id);
SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
PromoteIfNeeded(inst_.type_id);
break;
}
case SPV_OPERAND_TYPE_ID:
case SPV_OPERAND_TYPE_OPTIONAL_ID:
case SPV_OPERAND_TYPE_SCOPE_ID:
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
uint32_t id = 0;
const spv_result_t result = DecodeRefId(&id);
if (result != SPV_SUCCESS) return result;
if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
operand_.type = SPV_OPERAND_TYPE_ID;
if (opcode == SpvOpExtInst && operand_.offset == 3) {
// The current word is the extended instruction set id.
// Set the extended instruction set type for the current
// instruction.
auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
return Diag(SPV_ERROR_INVALID_ID)
<< "OpExtInst set id " << id
<< " does not reference an OpExtInstImport result Id";
}
inst_.ext_inst_type = ext_inst_type_iter->second;
}
}
inst_words_.push_back(id);
SetIdBound(std::max(GetIdBound(), id + 1));
PromoteIfNeeded(id);
break;
}
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
assert(SpvOpExtInst == opcode);
assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
spv_ext_inst_desc ext_inst;
if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid extended instruction number: " << word;
spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
break;
}
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
// These are regular single-word literal integer operands.
// Post-parsing validation should check the range of the parsed value.
operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
// It turns out they are always unsigned integers!
operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
operand_.number_bit_width = 32;
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
break;
}
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
if (opcode == SpvOpSwitch) {
// The literal operands have the same type as the value
// referenced by the selector Id.
const uint32_t selector_id = inst_words_.at(1);
const auto type_id_iter = id_to_type_id_.find(selector_id);
if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid OpSwitch: selector id " << selector_id
<< " has no type";
}
uint32_t type_id = type_id_iter->second;
if (selector_id == type_id) {
// Recall that by convention, a result ID that is a type definition
// maps to itself.
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid OpSwitch: selector id " << selector_id
<< " is a type, not a value";
}
if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
return error;
if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid OpSwitch: selector id " << selector_id
<< " is not a scalar integer";
}
} else {
assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
// The literal number type is determined by the type Id for the
// constant.
assert(inst_.type_id);
if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
return error;
}
if (auto error = DecodeLiteralNumber(operand_)) return error;
break;
}
case SPV_OPERAND_TYPE_LITERAL_STRING:
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
std::vector<char> str;
auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
if (codec) {
std::string decoded_string;
const bool huffman_result =
codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
assert(huffman_result);
if (!huffman_result)
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read literal string";
if (decoded_string != "kMarkvNoneOfTheAbove") {
std::copy(decoded_string.begin(), decoded_string.end(),
std::back_inserter(str));
str.push_back('\0');
}
}
// The loop is expected to terminate once we encounter '\0' or exhaust
// the bit stream.
if (str.empty()) {
while (true) {
char ch = 0;
if (!reader_.ReadUnencoded(&ch))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read literal string";
str.push_back(ch);
if (ch == '\0') break;
}
}
while (str.size() % 4 != 0) str.push_back('\0');
inst_words_.resize(inst_words_.size() + str.size() / 4);
std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
if (SpvOpExtInstImport == opcode) {
// Record the extended instruction type for the ID for this import.
// There is only one string literal argument to OpExtInstImport,
// so it's sufficient to guard this just on the opcode.
const spv_ext_inst_type_t ext_inst_type =
spvExtInstImportTypeGet(str.data());
if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid extended instruction import '" << str.data()
<< "'";
}
// We must have parsed a valid result ID. It's a condition
// of the grammar, and we only accept non-zero result Ids.
assert(inst_.result_id);
const bool inserted =
import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
.second;
(void)inserted;
assert(inserted);
}
break;
}
case SPV_OPERAND_TYPE_CAPABILITY:
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
case SPV_OPERAND_TYPE_MEMORY_MODEL:
case SPV_OPERAND_TYPE_EXECUTION_MODE:
case SPV_OPERAND_TYPE_STORAGE_CLASS:
case SPV_OPERAND_TYPE_DIMENSIONALITY:
case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
case SPV_OPERAND_TYPE_LINKAGE_TYPE:
case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
case SPV_OPERAND_TYPE_DECORATION:
case SPV_OPERAND_TYPE_BUILT_IN:
case SPV_OPERAND_TYPE_GROUP_OPERATION:
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
// A single word that is a plain enum value.
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
// Map an optional operand type to its corresponding concrete type.
if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
spv_operand_desc entry;
if (grammar_.lookupOperand(type, word, &entry)) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid " << spvOperandTypeStr(operand_.type)
<< " operand: " << word;
}
// Prepare to accept operands to this operand, if needed.
spvPushOperandTypes(entry->operandTypes, expected_operands);
break;
}
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
case SPV_OPERAND_TYPE_LOOP_CONTROL:
case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
// This operand is a mask.
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
// Map an optional operand type to its corresponding concrete type.
if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
operand_.type = SPV_OPERAND_TYPE_IMAGE;
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
// Check validity of set mask bits. Also prepare for operands for those
// masks if they have any. To get operand order correct, scan from
// MSB to LSB since we can only prepend operands to a pattern.
// The only case in the grammar where you have more than one mask bit
// having an operand is for image operands. See SPIR-V 3.14 Image
// Operands.
uint32_t remaining_word = word;
for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
if (remaining_word & mask) {
spv_operand_desc entry;
if (grammar_.lookupOperand(type, mask, &entry)) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid " << spvOperandTypeStr(operand_.type)
<< " operand: " << word << " has invalid mask component "
<< mask;
}
remaining_word ^= mask;
spvPushOperandTypes(entry->operandTypes, expected_operands);
}
}
if (word == 0) {
// An all-zeroes mask *might* also be valid.
spv_operand_desc entry;
if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
// Prepare for its operands, if any.
spvPushOperandTypes(entry->operandTypes, expected_operands);
}
}
break;
}
default:
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Internal error: Unhandled operand type: " << type;
}
operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
assert(spvOperandIsConcrete(operand_.type));
parsed_operands_.push_back(operand_);
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeInstruction() {
parsed_operands_.clear();
inst_words_.clear();
// Opcode/num_words placeholder, the word will be filled in later.
inst_words_.push_back(0);
bool num_operands_still_unknown = true;
{
uint32_t opcode = 0;
uint32_t num_operands = 0;
const spv_result_t opcode_decoding_result =
DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
if (opcode_decoding_result < 0) return opcode_decoding_result;
if (opcode_decoding_result == SPV_SUCCESS) {
inst_.num_operands = static_cast<uint16_t>(num_operands);
num_operands_still_unknown = false;
} else {
if (!reader_.ReadVariableWidthU32(&opcode,
model_->opcode_chunk_length())) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read opcode of instruction";
}
}
inst_.opcode = static_cast<uint16_t>(opcode);
}
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
spv_opcode_desc opcode_desc;
if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
}
spv_operand_pattern_t expected_operands;
expected_operands.reserve(opcode_desc->numTypes);
for (auto i = 0; i < opcode_desc->numTypes; i++) {
expected_operands.push_back(
opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
}
if (num_operands_still_unknown) {
if (!OpcodeHasFixedNumberOfOperands(opcode)) {
if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
model_->num_operands_chunk_length()))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read num_operands of instruction";
} else {
inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
}
}
for (operand_index_ = 0;
operand_index_ < static_cast<size_t>(inst_.num_operands);
++operand_index_) {
assert(!expected_operands.empty());
const spv_operand_type_t type =
spvTakeFirstMatchableOperand(&expected_operands);
const size_t operand_offset = inst_words_.size();
const spv_result_t decode_result =
DecodeOperand(operand_offset, type, &expected_operands);
if (decode_result != SPV_SUCCESS) return decode_result;
}
assert(inst_.num_operands == parsed_operands_.size());
// Only valid while inst_words_ and parsed_operands_ remain unchanged (until
// next DecodeInstruction call).
inst_.words = inst_words_.data();
inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
inst_.num_words = static_cast<uint16_t>(inst_words_.size());
inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
assert(inst_.num_words ==
std::accumulate(
parsed_operands_.begin(), parsed_operands_.end(), 1,
[](int num_words, const spv_parsed_operand_t& operand) {
return num_words += operand.num_words;
}) &&
"num_words in instruction doesn't correspond to the sum of num_words"
"in the operands");
RecordNumberType();
ProcessCurInstruction();
if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte))
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
if (logger_) {
logger_->NewLine();
std::stringstream ss;
ss << spvOpcodeString(opcode) << " ";
for (size_t index = 1; index < inst_words_.size(); ++index)
ss << inst_words_[index] << " ";
logger_->AppendText(ss.str());
logger_->NewLine();
logger_->NewLine();
if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
}
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
assert(type_id != 0);
auto type_info_iter = type_id_to_number_type_info_.find(type_id);
if (type_info_iter == type_id_to_number_type_info_.end()) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Type Id " << type_id << " is not a type";
}
const NumberType& info = type_info_iter->second;
if (info.type == SPV_NUMBER_NONE) {
// This is a valid type, but for something other than a scalar number.
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Type Id " << type_id << " is not a scalar numeric type";
}
parsed_operand->number_kind = info.type;
parsed_operand->number_bit_width = info.bit_width;
// Round up the word count.
parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
return SPV_SUCCESS;
}
void MarkvDecoder::RecordNumberType() {
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
if (spvOpcodeGeneratesType(opcode)) {
NumberType info = {SPV_NUMBER_NONE, 0};
if (SpvOpTypeInt == opcode) {
info.bit_width = inst_.words[inst_.operands[1].offset];
info.type = inst_.words[inst_.operands[2].offset]
? SPV_NUMBER_SIGNED_INT
: SPV_NUMBER_UNSIGNED_INT;
} else if (SpvOpTypeFloat == opcode) {
info.bit_width = inst_.words[inst_.operands[1].offset];
info.type = SPV_NUMBER_FLOATING;
}
// The *result* Id of a type generating instruction is the type Id.
type_id_to_number_type_info_[inst_.result_id] = info;
}
}
} // namespace comp
} // namespace spvtools

View File

@ -1,175 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/bit_stream.h"
#include "source/comp/markv.h"
#include "source/comp/markv_codec.h"
#include "source/comp/markv_logger.h"
#include "source/util/make_unique.h"
#ifndef SOURCE_COMP_MARKV_DECODER_H_
#define SOURCE_COMP_MARKV_DECODER_H_
namespace spvtools {
namespace comp {
class MarkvLogger;
// Decodes MARK-V buffers written by MarkvEncoder.
class MarkvDecoder : public MarkvCodec {
public:
// |model| is owned by the caller, must be not null and valid during the
// lifetime of MarkvEncoder.
MarkvDecoder(spv_const_context context, const std::vector<uint8_t>& markv,
const MarkvCodecOptions& options, const MarkvModel* model)
: MarkvCodec(context, GetValidatorOptions(options), model),
options_(options),
reader_(markv) {
SetIdBound(1);
parsed_operands_.reserve(25);
inst_words_.reserve(25);
}
~MarkvDecoder() = default;
// Creates an internal logger which writes comments on the decoding process.
void CreateLogger(MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer) {
logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer);
}
// Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
// Can be called only once. Fails if data of wrong format or ends prematurely,
// of if validation fails.
spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
// Creates and returns validator options. Returned value owned by the caller.
static spv_validator_options GetValidatorOptions(
const MarkvCodecOptions& options) {
return options.validate_spirv_binary ? spvValidatorOptionsCreate()
: nullptr;
}
private:
// Describes the format of a typed literal number.
struct NumberType {
spv_number_kind_t type;
uint32_t bit_width;
};
// Reads a single bit from reader_. The read bit is stored in |bit|.
// Returns false iff reader_ fails.
bool ReadBit(bool* bit) {
uint64_t bits = 0;
const bool result = reader_.ReadBits(&bits, 1);
if (result) *bit = bits ? true : false;
return result;
};
// Returns ReadBit bound to the class object.
std::function<bool(bool*)> GetReadBitCallback() {
return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
}
// Reads a single non-id word from bit stream. operand_.type determines if
// the word needs to be decoded and how.
spv_result_t DecodeNonIdWord(uint32_t* word);
// Reads and decodes both opcode and num_operands as a single code.
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
uint32_t* num_operands);
// Reads mtf rank from bit stream. |mtf| is used to determine the codec
// scheme. |fallback_method| is used if no codec defined for |mtf|.
spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
uint32_t* rank);
// Reads id using coding based on mtf associated with the id descriptor.
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
spv_result_t DecodeIdWithDescriptor(uint32_t* id);
// Reads id using coding based on the given |mtf|, which is expected to
// contain the needed |id|.
spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
// Reads type id of the current instruction if can't be inferred.
spv_result_t DecodeTypeId();
// Reads result id of the current instruction if can't be inferred.
spv_result_t DecodeResultId();
// Reads id which is neither type nor result id.
spv_result_t DecodeRefId(uint32_t* id);
// Reads and discards bits until the beginning of the next byte if the
// number of bits until the next byte is less than |byte_break_if_less_than|.
bool ReadToByteBreak(size_t byte_break_if_less_than);
// Returns instruction words decoded up to this point.
const uint32_t* GetInstWords() const override { return inst_words_.data(); }
// Reads a literal number as it is described in |operand| from the bit stream,
// decodes and writes it to spirv_.
spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
// Reads instruction from bit stream, decodes and validates it.
// Decoded instruction is valid until the next call of DecodeInstruction().
spv_result_t DecodeInstruction();
// Read operand from the stream decodes and validates it.
spv_result_t DecodeOperand(size_t operand_offset,
const spv_operand_type_t type,
spv_operand_pattern_t* expected_operands);
// Records the numeric type for an operand according to the type information
// associated with the given non-zero type Id. This can fail if the type Id
// is not a type Id, or if the type Id does not reference a scalar numeric
// type. On success, return SPV_SUCCESS and populates the num_words,
// number_kind, and number_bit_width fields of parsed_operand.
spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
uint32_t type_id);
// Records the number type for the current instruction, if it generates a
// type. For types that aren't scalar numbers, record something with number
// kind SPV_NUMBER_NONE.
void RecordNumberType();
MarkvCodecOptions options_;
// Temporary sink where decoded SPIR-V words are written. Once it contains the
// entire module, the container is moved and returned.
std::vector<uint32_t> spirv_;
// Bit stream containing encoded data.
BitReaderWord64 reader_;
// Temporary storage for operands of the currently parsed instruction.
// Valid until next DecodeInstruction call.
std::vector<spv_parsed_operand_t> parsed_operands_;
// Temporary storage for current instruction words.
// Valid until next DecodeInstruction call.
std::vector<uint32_t> inst_words_;
// Maps a type ID to its number type description.
std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
// Maps an ExtInstImport id to the extended instruction type.
std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_DECODER_H_

View File

@ -1,486 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/markv_encoder.h"
#include "source/binary.h"
#include "source/opcode.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace comp {
namespace {
const size_t kCommentNumWhitespaces = 2;
} // namespace
spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) {
auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
if (codec) {
uint64_t bits = 0;
size_t num_bits = 0;
if (codec->Encode(word, &bits, &num_bits)) {
// Encoding successful.
writer_.WriteBits(bits, num_bits);
return SPV_SUCCESS;
} else {
// Encoding failed, write kMarkvNoneOfTheAbove flag.
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
&num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Non-id word Huffman table for "
<< spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
<< operand_index_ << " is missing kMarkvNoneOfTheAbove";
writer_.WriteBits(bits, num_bits);
}
}
// Fallback encoding.
const size_t chunk_length =
model_->GetOperandVariableWidthChunkLength(operand_.type);
if (chunk_length) {
writer_.WriteVariableWidthU32(word, chunk_length);
} else {
writer_.WriteUnencoded(word);
}
return SPV_SUCCESS;
}
spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode,
uint32_t num_operands) {
uint64_t bits = 0;
size_t num_bits = 0;
const uint32_t word = opcode | (num_operands << 16);
// First try to use the Markov chain codec.
auto* codec =
model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
if (codec) {
if (codec->Encode(word, &bits, &num_bits)) {
// The word was successfully encoded into bits/num_bits.
writer_.WriteBits(bits, num_bits);
return SPV_SUCCESS;
} else {
// The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
// and use fallback encoding.
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
&num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "opcode_and_num_operands Huffman table for "
<< spvOpcodeString(GetPrevOpcode())
<< "is missing kMarkvNoneOfTheAbove";
writer_.WriteBits(bits, num_bits);
}
}
// Fallback to base-rate codec.
codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
assert(codec);
if (codec->Encode(word, &bits, &num_bits)) {
// The word was successfully encoded into bits/num_bits.
writer_.WriteBits(bits, num_bits);
return SPV_SUCCESS;
} else {
// The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
// and return false.
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Global opcode_and_num_operands Huffman table is missing "
<< "kMarkvNoneOfTheAbove";
writer_.WriteBits(bits, num_bits);
return SPV_UNSUPPORTED;
}
}
spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
uint64_t fallback_method) {
const auto* codec = GetMtfHuffmanCodec(mtf);
if (!codec) {
assert(fallback_method != kMtfNone);
codec = GetMtfHuffmanCodec(fallback_method);
}
if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank";
uint64_t bits = 0;
size_t num_bits = 0;
if (rank < MarkvCodec::kMtfSmallestRankEncodedByValue) {
// Encode using Huffman coding.
if (!codec->Encode(rank, &bits, &num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to encode MTF rank with Huffman";
writer_.WriteBits(bits, num_bits);
} else {
// Encode by value.
if (!codec->Encode(MarkvCodec::kMtfRankEncodedByValueSignal, &bits,
&num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to encode kMtfRankEncodedByValueSignal";
writer_.WriteBits(bits, num_bits);
writer_.WriteVariableWidthU32(
rank - MarkvCodec::kMtfSmallestRankEncodedByValue,
model_->mtf_rank_chunk_length());
}
return SPV_SUCCESS;
}
spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) {
// Get the descriptor for id.
const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id);
auto* codec =
model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
uint64_t bits = 0;
size_t num_bits = 0;
uint64_t mtf = kMtfNone;
if (long_descriptor && codec &&
codec->Encode(long_descriptor, &bits, &num_bits)) {
// If the descriptor exists and is in the table, write the descriptor and
// proceed to encoding the rank.
writer_.WriteBits(bits, num_bits);
mtf = GetMtfLongIdDescriptor(long_descriptor);
} else {
if (codec) {
// The descriptor doesn't exist or we have no coding for it. Write
// kMarkvNoneOfTheAbove and go to fallback method.
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
&num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Descriptor Huffman table for "
<< spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
<< operand_index_ << " is missing kMarkvNoneOfTheAbove";
writer_.WriteBits(bits, num_bits);
}
if (model_->id_fallback_strategy() !=
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
return SPV_UNSUPPORTED;
}
const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id);
writer_.WriteBits(short_descriptor, MarkvCodec::kShortDescriptorNumBits);
if (short_descriptor == 0) {
// Forward declared id.
return SPV_UNSUPPORTED;
}
mtf = GetMtfShortIdDescriptor(short_descriptor);
}
// Descriptor has been encoded. Now encode the rank of the id in the
// associated mtf sequence.
return EncodeExistingId(mtf, id);
}
spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) {
assert(multi_mtf_.GetSize(mtf) > 0);
if (multi_mtf_.GetSize(mtf) == 1) {
// If the sequence has only one element no need to write rank, the decoder
// would make the same decision.
return SPV_SUCCESS;
}
uint32_t rank = 0;
if (!multi_mtf_.RankFromValue(mtf, id, &rank))
return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence";
return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank);
}
spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) {
{
// Try to encode using id descriptor mtfs.
const spv_result_t result = EncodeIdWithDescriptor(id);
if (result != SPV_UNSUPPORTED) return result;
// If can't be done continue with other methods.
}
const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
SpvOp(inst_.opcode))(operand_index_);
uint32_t rank = 0;
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
// Encode using rule-based mtf.
uint64_t mtf = GetRuleBasedMtf();
if (mtf != kMtfNone && !can_forward_declare) {
assert(multi_mtf_.HasValue(kMtfAll, id));
return EncodeExistingId(mtf, id);
}
if (mtf == kMtfNone) mtf = kMtfAll;
if (!multi_mtf_.RankFromValue(mtf, id, &rank)) {
// This is the first occurrence of a forward declared id.
multi_mtf_.Insert(kMtfAll, id);
multi_mtf_.Insert(kMtfForwardDeclared, id);
if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id);
rank = 0;
}
return EncodeMtfRankHuffman(rank, mtf, kMtfAll);
} else {
assert(can_forward_declare);
if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) {
// This is the first occurrence of a forward declared id.
multi_mtf_.Insert(kMtfForwardDeclared, id);
rank = 0;
}
writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
return SPV_SUCCESS;
}
}
spv_result_t MarkvEncoder::EncodeTypeId() {
if (inst_.opcode == SpvOpFunctionParameter) {
assert(!remaining_function_parameter_types_.empty());
assert(inst_.type_id == remaining_function_parameter_types_.front());
remaining_function_parameter_types_.pop_front();
return SPV_SUCCESS;
}
{
// Try to encode using id descriptor mtfs.
const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id);
if (result != SPV_UNSUPPORTED) return result;
// If can't be done continue with other methods.
}
assert(model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased);
uint64_t mtf = GetRuleBasedMtf();
assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
operand_index_));
if (mtf == kMtfNone) {
mtf = kMtfTypeNonFunction;
// Function types should have been handled by GetRuleBasedMtf.
assert(inst_.opcode != SpvOpFunction);
}
return EncodeExistingId(mtf, inst_.type_id);
}
spv_result_t MarkvEncoder::EncodeResultId() {
uint32_t rank = 0;
const uint64_t num_still_forward_declared =
multi_mtf_.GetSize(kMtfForwardDeclared);
if (num_still_forward_declared) {
// We write the rank only if kMtfForwardDeclared is not empty. If it is
// empty the decoder knows that there are no forward declared ids to expect.
if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) {
// This is a definition of a forward declared id. We can remove the id
// from kMtfForwardDeclared.
if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to remove id from kMtfForwardDeclared";
writer_.WriteBits(1, 1);
writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
} else {
rank = 0;
writer_.WriteBits(0, 1);
}
}
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
if (!rank) {
multi_mtf_.Insert(kMtfAll, inst_.result_id);
}
}
return SPV_SUCCESS;
}
spv_result_t MarkvEncoder::EncodeLiteralNumber(
const spv_parsed_operand_t& operand) {
if (operand.number_bit_width <= 32) {
const uint32_t word = inst_.words[operand.offset];
return EncodeNonIdWord(word);
} else {
assert(operand.number_bit_width <= 64);
const uint64_t word = uint64_t(inst_.words[operand.offset]) |
(uint64_t(inst_.words[operand.offset + 1]) << 32);
if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
} else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
int64_t val = 0;
std::memcpy(&val, &word, 8);
writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
model_->s64_block_exponent());
} else if (operand.number_kind == SPV_NUMBER_FLOATING) {
writer_.WriteUnencoded(word);
} else {
return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
}
}
return SPV_SUCCESS;
}
void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) {
const size_t num_bits_to_next_byte =
GetNumBitsToNextByte(writer_.GetNumBits());
if (num_bits_to_next_byte == 0 ||
num_bits_to_next_byte > byte_break_if_less_than)
return;
if (logger_) {
logger_->AppendWhitespaces(kCommentNumWhitespaces);
logger_->AppendText("<byte break>");
}
writer_.WriteBits(0, num_bits_to_next_byte);
}
spv_result_t MarkvEncoder::EncodeInstruction(
const spv_parsed_instruction_t& inst) {
SpvOp opcode = SpvOp(inst.opcode);
inst_ = inst;
LogDisassemblyInstruction();
const spv_result_t opcode_encodig_result =
EncodeOpcodeAndNumOperands(opcode, inst.num_operands);
if (opcode_encodig_result < 0) return opcode_encodig_result;
if (opcode_encodig_result != SPV_SUCCESS) {
// Fallback encoding for opcode and num_operands.
writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length());
if (!OpcodeHasFixedNumberOfOperands(opcode)) {
// If the opcode has a variable number of operands, encode the number of
// operands with the instruction.
if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);
writer_.WriteVariableWidthU16(inst.num_operands,
model_->num_operands_chunk_length());
}
}
// Write operands.
const uint32_t num_operands = inst_.num_operands;
for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) {
operand_ = inst_.operands[operand_index_];
if (logger_) {
logger_->AppendWhitespaces(kCommentNumWhitespaces);
logger_->AppendText("<");
logger_->AppendText(spvOperandTypeStr(operand_.type));
logger_->AppendText(">");
}
switch (operand_.type) {
case SPV_OPERAND_TYPE_RESULT_ID:
case SPV_OPERAND_TYPE_TYPE_ID:
case SPV_OPERAND_TYPE_ID:
case SPV_OPERAND_TYPE_OPTIONAL_ID:
case SPV_OPERAND_TYPE_SCOPE_ID:
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
const uint32_t id = inst_.words[operand_.offset];
if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) {
const spv_result_t result = EncodeTypeId();
if (result != SPV_SUCCESS) return result;
} else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) {
const spv_result_t result = EncodeResultId();
if (result != SPV_SUCCESS) return result;
} else {
const spv_result_t result = EncodeRefId(id);
if (result != SPV_SUCCESS) return result;
}
PromoteIfNeeded(id);
break;
}
case SPV_OPERAND_TYPE_LITERAL_INTEGER: {
const spv_result_t result =
EncodeNonIdWord(inst_.words[operand_.offset]);
if (result != SPV_SUCCESS) return result;
break;
}
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
const spv_result_t result = EncodeLiteralNumber(operand_);
if (result != SPV_SUCCESS) return result;
break;
}
case SPV_OPERAND_TYPE_LITERAL_STRING: {
const char* src =
reinterpret_cast<const char*>(&inst_.words[operand_.offset]);
auto* codec = model_->GetLiteralStringHuffmanCodec(opcode);
if (codec) {
uint64_t bits = 0;
size_t num_bits = 0;
const std::string str = src;
if (codec->Encode(str, &bits, &num_bits)) {
writer_.WriteBits(bits, num_bits);
break;
} else {
bool result =
codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits);
(void)result;
assert(result);
writer_.WriteBits(bits, num_bits);
}
}
const size_t length = spv_strnlen_s(src, operand_.num_words * 4);
if (length == operand_.num_words * 4)
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to find terminal character of literal string";
for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]);
break;
}
default: {
for (int i = 0; i < operand_.num_words; ++i) {
const uint32_t word = inst_.words[operand_.offset + i];
const spv_result_t result = EncodeNonIdWord(word);
if (result != SPV_SUCCESS) return result;
}
break;
}
}
}
AddByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte);
if (logger_) {
logger_->NewLine();
logger_->NewLine();
if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
}
ProcessCurInstruction();
return SPV_SUCCESS;
}
} // namespace comp
} // namespace spvtools

View File

@ -1,167 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/bit_stream.h"
#include "source/comp/markv.h"
#include "source/comp/markv_codec.h"
#include "source/comp/markv_logger.h"
#include "source/util/make_unique.h"
#ifndef SOURCE_COMP_MARKV_ENCODER_H_
#define SOURCE_COMP_MARKV_ENCODER_H_
#include <cstring>
namespace spvtools {
namespace comp {
// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
// EncodeInstruction which can be used as callback by spvBinaryParse.
// Encoded binary is written to an internally maintained bitstream.
// After the last instruction is encoded, the resulting MARK-V binary can be
// acquired by calling GetMarkvBinary().
//
// The encoder uses SPIR-V validator to keep internal state, therefore
// SPIR-V binary needs to be able to pass validator checks.
// CreateCommentsLogger() can be used to enable the encoder to write comments
// on how encoding was done, which can later be accessed with GetComments().
class MarkvEncoder : public MarkvCodec {
public:
// |model| is owned by the caller, must be not null and valid during the
// lifetime of MarkvEncoder.
MarkvEncoder(spv_const_context context, const MarkvCodecOptions& options,
const MarkvModel* model)
: MarkvCodec(context, GetValidatorOptions(options), model),
options_(options) {}
~MarkvEncoder() override = default;
// Writes data from SPIR-V header to MARK-V header.
spv_result_t EncodeHeader(spv_endianness_t /* endian */, uint32_t /* magic */,
uint32_t version, uint32_t generator,
uint32_t id_bound, uint32_t /* schema */) {
SetIdBound(id_bound);
header_.spirv_version = version;
header_.spirv_generator = generator;
return SPV_SUCCESS;
}
// Creates an internal logger which writes comments on the encoding process.
void CreateLogger(MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer) {
logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer);
writer_.SetCallback(
[this](const std::string& str) { logger_->AppendBitSequence(str); });
}
// Encodes SPIR-V instruction to MARK-V and writes to bit stream.
// Operation can fail if the instruction fails to pass the validator or if
// the encoder stubmles on something unexpected.
spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
// Concatenates MARK-V header and the bit stream with encoded instructions
// into a single buffer and returns it as spv_markv_binary. The returned
// value is owned by the caller and needs to be destroyed with
// spvMarkvBinaryDestroy().
std::vector<uint8_t> GetMarkvBinary() {
header_.markv_length_in_bits =
static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
header_.markv_model =
(model_->model_type() << 16) | model_->model_version();
const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
std::vector<uint8_t> markv(num_bytes);
assert(writer_.GetData());
std::memcpy(markv.data(), &header_, sizeof(header_));
std::memcpy(markv.data() + sizeof(header_), writer_.GetData(),
writer_.GetDataSizeBytes());
return markv;
}
// Optionally adds disassembly to the comments.
// Disassembly should contain all instructions in the module separated by
// \n, and no header.
void SetDisassembly(std::string&& disassembly) {
disassembly_ = MakeUnique<std::stringstream>(std::move(disassembly));
}
// Extracts the next instruction line from the disassembly and logs it.
void LogDisassemblyInstruction() {
if (logger_ && disassembly_) {
std::string line;
std::getline(*disassembly_, line, '\n');
logger_->AppendTextNewLine(line);
}
}
private:
// Creates and returns validator options. Returned value owned by the caller.
static spv_validator_options GetValidatorOptions(
const MarkvCodecOptions& options) {
return options.validate_spirv_binary ? spvValidatorOptionsCreate()
: nullptr;
}
// Writes a single word to bit stream. operand_.type determines if the word is
// encoded and how.
spv_result_t EncodeNonIdWord(uint32_t word);
// Writes both opcode and num_operands as a single code.
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode,
uint32_t num_operands);
// Writes mtf rank to bit stream. |mtf| is used to determine the codec
// scheme. |fallback_method| is used if no codec defined for |mtf|.
spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
uint64_t fallback_method);
// Writes id using coding based on mtf associated with the id descriptor.
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
spv_result_t EncodeIdWithDescriptor(uint32_t id);
// Writes id using coding based on the given |mtf|, which is expected to
// contain the given |id|.
spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id);
// Writes type id of the current instruction if can't be inferred.
spv_result_t EncodeTypeId();
// Writes result id of the current instruction if can't be inferred.
spv_result_t EncodeResultId();
// Writes ids which are neither type nor result ids.
spv_result_t EncodeRefId(uint32_t id);
// Writes bits to the stream until the beginning of the next byte if the
// number of bits until the next byte is less than |byte_break_if_less_than|.
void AddByteBreak(size_t byte_break_if_less_than);
// Encodes a literal number operand and writes it to the bit stream.
spv_result_t EncodeLiteralNumber(const spv_parsed_operand_t& operand);
MarkvCodecOptions options_;
// Bit stream where encoded instructions are written.
BitWriterWord64 writer_;
// If not nullptr, disassembled instruction lines will be written to comments.
// Format: \n separated instruction lines, no header.
std::unique_ptr<std::stringstream> disassembly_;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_ENCODER_H_

View File

@ -1,93 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_COMP_MARKV_LOGGER_H_
#define SOURCE_COMP_MARKV_LOGGER_H_
#include "source/comp/markv.h"
namespace spvtools {
namespace comp {
class MarkvLogger {
public:
MarkvLogger(MarkvLogConsumer log_consumer, MarkvDebugConsumer debug_consumer)
: log_consumer_(log_consumer), debug_consumer_(debug_consumer) {}
void AppendText(const std::string& str) {
Append(str);
use_delimiter_ = false;
}
void AppendTextNewLine(const std::string& str) {
Append(str);
Append("\n");
use_delimiter_ = false;
}
void AppendBitSequence(const std::string& str) {
if (debug_consumer_) instruction_bits_ << str;
if (use_delimiter_) Append("-");
Append(str);
use_delimiter_ = true;
}
void AppendWhitespaces(size_t num) {
Append(std::string(num, ' '));
use_delimiter_ = false;
}
void NewLine() {
Append("\n");
use_delimiter_ = false;
}
bool DebugInstruction(const spv_parsed_instruction_t& inst) {
bool result = true;
if (debug_consumer_) {
result = debug_consumer_(
std::vector<uint32_t>(inst.words, inst.words + inst.num_words),
instruction_bits_.str(), instruction_comment_.str());
instruction_bits_.str(std::string());
instruction_comment_.str(std::string());
}
return result;
}
private:
MarkvLogger(const MarkvLogger&) = delete;
MarkvLogger(MarkvLogger&&) = delete;
MarkvLogger& operator=(const MarkvLogger&) = delete;
MarkvLogger& operator=(MarkvLogger&&) = delete;
void Append(const std::string& str) {
if (log_consumer_) log_consumer_(str);
if (debug_consumer_) instruction_comment_ << str;
}
MarkvLogConsumer log_consumer_;
MarkvDebugConsumer debug_consumer_;
std::stringstream instruction_bits_;
std::stringstream instruction_comment_;
// If true a delimiter will be appended before the next bit sequence.
// Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0.
bool use_delimiter_ = false;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_LOGGER_H_

View File

@ -1,232 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_COMP_MARKV_MODEL_H_
#define SOURCE_COMP_MARKV_MODEL_H_
#include <unordered_set>
#include "source/comp/huffman_codec.h"
#include "source/latest_version_spirv_header.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace comp {
// Base class for MARK-V models.
// The class contains encoding/decoding model with various constants and
// codecs used by the compression algorithm.
class MarkvModel {
public:
MarkvModel()
: operand_chunk_lengths_(
static_cast<size_t>(SPV_OPERAND_TYPE_NUM_OPERAND_TYPES), 0) {
// Set default values.
operand_chunk_lengths_[SPV_OPERAND_TYPE_TYPE_ID] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_RESULT_ID] = 8;
operand_chunk_lengths_[SPV_OPERAND_TYPE_ID] = 8;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SCOPE_ID] = 8;
operand_chunk_lengths_[SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID] = 8;
operand_chunk_lengths_[SPV_OPERAND_TYPE_LITERAL_INTEGER] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_CAPABILITY] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SOURCE_LANGUAGE] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_EXECUTION_MODEL] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_ADDRESSING_MODEL] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_MEMORY_MODEL] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_EXECUTION_MODE] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_STORAGE_CLASS] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_DIMENSIONALITY] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_FP_ROUNDING_MODE] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_LINKAGE_TYPE] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_ACCESS_QUALIFIER] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_DECORATION] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_BUILT_IN] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_GROUP_OPERATION] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_FP_FAST_MATH_MODE] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_FUNCTION_CONTROL] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_LOOP_CONTROL] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_IMAGE] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_IMAGE] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SELECTION_CONTROL] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER] = 6;
}
uint32_t model_type() const { return model_type_; }
uint32_t model_version() const { return model_version_; }
uint32_t opcode_chunk_length() const { return opcode_chunk_length_; }
uint32_t num_operands_chunk_length() const {
return num_operands_chunk_length_;
}
uint32_t mtf_rank_chunk_length() const { return mtf_rank_chunk_length_; }
uint32_t u64_chunk_length() const { return u64_chunk_length_; }
uint32_t s64_chunk_length() const { return s64_chunk_length_; }
uint32_t s64_block_exponent() const { return s64_block_exponent_; }
enum class IdFallbackStrategy {
kRuleBased = 0,
kShortDescriptor,
};
IdFallbackStrategy id_fallback_strategy() const {
return id_fallback_strategy_;
}
// Returns a codec for common opcode_and_num_operands words for the given
// previous opcode. May return nullptr if the codec doesn't exist.
const HuffmanCodec<uint64_t>* GetOpcodeAndNumOperandsMarkovHuffmanCodec(
uint32_t prev_opcode) const {
if (prev_opcode == SpvOpNop)
return opcode_and_num_operands_huffman_codec_.get();
const auto it =
opcode_and_num_operands_markov_huffman_codecs_.find(prev_opcode);
if (it == opcode_and_num_operands_markov_huffman_codecs_.end())
return nullptr;
return it->second.get();
}
// Returns a codec for common non-id words used for given operand slot.
// Operand slot is defined by the opcode and the operand index.
// May return nullptr if the codec doesn't exist.
const HuffmanCodec<uint64_t>* GetNonIdWordHuffmanCodec(
uint32_t opcode, uint32_t operand_index) const {
const auto it = non_id_word_huffman_codecs_.find(
std::pair<uint32_t, uint32_t>(opcode, operand_index));
if (it == non_id_word_huffman_codecs_.end()) return nullptr;
return it->second.get();
}
// Returns a codec for common id descriptos used for given operand slot.
// Operand slot is defined by the opcode and the operand index.
// May return nullptr if the codec doesn't exist.
const HuffmanCodec<uint64_t>* GetIdDescriptorHuffmanCodec(
uint32_t opcode, uint32_t operand_index) const {
const auto it = id_descriptor_huffman_codecs_.find(
std::pair<uint32_t, uint32_t>(opcode, operand_index));
if (it == id_descriptor_huffman_codecs_.end()) return nullptr;
return it->second.get();
}
// Returns a codec for common strings used by the given opcode.
// Operand slot is defined by the opcode and the operand index.
// May return nullptr if the codec doesn't exist.
const HuffmanCodec<std::string>* GetLiteralStringHuffmanCodec(
uint32_t opcode) const {
const auto it = literal_string_huffman_codecs_.find(opcode);
if (it == literal_string_huffman_codecs_.end()) return nullptr;
return it->second.get();
}
// Checks if |descriptor| has a coding scheme in any of
// id_descriptor_huffman_codecs_.
bool DescriptorHasCodingScheme(uint32_t descriptor) const {
return descriptors_with_coding_scheme_.count(descriptor);
}
// Checks if any descriptor has a coding scheme.
bool AnyDescriptorHasCodingScheme() const {
return !descriptors_with_coding_scheme_.empty();
}
// Returns chunk length used for variable length encoding of spirv operand
// words.
uint32_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) const {
return operand_chunk_lengths_.at(static_cast<size_t>(type));
}
// Sets model type.
void SetModelType(uint32_t in_model_type) { model_type_ = in_model_type; }
// Sets model version.
void SetModelVersion(uint32_t in_model_version) {
model_version_ = in_model_version;
}
// Returns value used by Huffman codecs as a signal that a value is not in the
// coding table.
static uint64_t GetMarkvNoneOfTheAbove() {
// Magic number.
return 1111111111111111111;
}
MarkvModel(const MarkvModel&) = delete;
const MarkvModel& operator=(const MarkvModel&) = delete;
protected:
// Huffman codec for base-rate of opcode_and_num_operands.
std::unique_ptr<HuffmanCodec<uint64_t>>
opcode_and_num_operands_huffman_codec_;
// Huffman codecs for opcode_and_num_operands. The map key is previous opcode.
std::map<uint32_t, std::unique_ptr<HuffmanCodec<uint64_t>>>
opcode_and_num_operands_markov_huffman_codecs_;
// Huffman codecs for non-id single-word operand values.
// The map key is pair <opcode, operand_index>.
std::map<std::pair<uint32_t, uint32_t>,
std::unique_ptr<HuffmanCodec<uint64_t>>>
non_id_word_huffman_codecs_;
// Huffman codecs for id descriptors. The map key is pair
// <opcode, operand_index>.
std::map<std::pair<uint32_t, uint32_t>,
std::unique_ptr<HuffmanCodec<uint64_t>>>
id_descriptor_huffman_codecs_;
// Set of all descriptors which have a coding scheme in any of
// id_descriptor_huffman_codecs_.
std::unordered_set<uint32_t> descriptors_with_coding_scheme_;
// Huffman codecs for literal strings. The map key is the opcode of the
// current instruction. This assumes, that there is no more than one literal
// string operand per instruction, but would still work even if this is not
// the case. Names and debug information strings are not collected.
std::map<uint32_t, std::unique_ptr<HuffmanCodec<std::string>>>
literal_string_huffman_codecs_;
// Chunk lengths used for variable width encoding of operands (index is
// spv_operand_type of the operand).
std::vector<uint32_t> operand_chunk_lengths_;
uint32_t opcode_chunk_length_ = 7;
uint32_t num_operands_chunk_length_ = 3;
uint32_t mtf_rank_chunk_length_ = 5;
uint32_t u64_chunk_length_ = 8;
uint32_t s64_chunk_length_ = 8;
uint32_t s64_block_exponent_ = 10;
IdFallbackStrategy id_fallback_strategy_ =
IdFallbackStrategy::kShortDescriptor;
uint32_t model_type_ = 0;
uint32_t model_version_ = 0;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_MODEL_H_

View File

@ -1,456 +0,0 @@
// Copyright (c) 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/move_to_front.h"
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <sstream>
#include <unordered_set>
#include <utility>
namespace spvtools {
namespace comp {
bool MoveToFront::Insert(uint32_t value) {
auto it = value_to_node_.find(value);
if (it != value_to_node_.end() && IsInTree(it->second)) return false;
const uint32_t old_size = GetSize();
(void)old_size;
InsertNode(CreateNode(next_timestamp_++, value));
last_accessed_value_ = value;
last_accessed_value_valid_ = true;
assert(value_to_node_.count(value));
assert(old_size + 1 == GetSize());
return true;
}
bool MoveToFront::Remove(uint32_t value) {
auto it = value_to_node_.find(value);
if (it == value_to_node_.end()) return false;
if (!IsInTree(it->second)) return false;
if (last_accessed_value_ == value) last_accessed_value_valid_ = false;
const uint32_t orphan = RemoveNode(it->second);
(void)orphan;
// The node of |value| is still alive but it's orphaned now. Can still be
// reused later.
assert(!IsInTree(orphan));
assert(ValueOf(orphan) == value);
return true;
}
bool MoveToFront::RankFromValue(uint32_t value, uint32_t* rank) {
if (last_accessed_value_valid_ && last_accessed_value_ == value) {
*rank = 1;
return true;
}
const uint32_t old_size = GetSize();
if (old_size == 1) {
if (ValueOf(root_) == value) {
*rank = 1;
return true;
} else {
return false;
}
}
const auto it = value_to_node_.find(value);
if (it == value_to_node_.end()) {
return false;
}
uint32_t target = it->second;
if (!IsInTree(target)) {
return false;
}
uint32_t node = target;
*rank = 1 + SizeOf(LeftOf(node));
while (node) {
if (IsRightChild(node)) *rank += 1 + SizeOf(LeftOf(ParentOf(node)));
node = ParentOf(node);
}
// Don't update timestamp if the node has rank 1.
if (*rank != 1) {
// Update timestamp and reposition the node.
target = RemoveNode(target);
assert(ValueOf(target) == value);
assert(old_size == GetSize() + 1);
MutableTimestampOf(target) = next_timestamp_++;
InsertNode(target);
assert(old_size == GetSize());
}
last_accessed_value_ = value;
last_accessed_value_valid_ = true;
return true;
}
bool MoveToFront::HasValue(uint32_t value) const {
const auto it = value_to_node_.find(value);
if (it == value_to_node_.end()) {
return false;
}
return IsInTree(it->second);
}
bool MoveToFront::Promote(uint32_t value) {
if (last_accessed_value_valid_ && last_accessed_value_ == value) {
return true;
}
const uint32_t old_size = GetSize();
if (old_size == 1) return ValueOf(root_) == value;
const auto it = value_to_node_.find(value);
if (it == value_to_node_.end()) {
return false;
}
uint32_t target = it->second;
if (!IsInTree(target)) {
return false;
}
// Update timestamp and reposition the node.
target = RemoveNode(target);
assert(ValueOf(target) == value);
assert(old_size == GetSize() + 1);
MutableTimestampOf(target) = next_timestamp_++;
InsertNode(target);
assert(old_size == GetSize());
last_accessed_value_ = value;
last_accessed_value_valid_ = true;
return true;
}
bool MoveToFront::ValueFromRank(uint32_t rank, uint32_t* value) {
if (last_accessed_value_valid_ && rank == 1) {
*value = last_accessed_value_;
return true;
}
const uint32_t old_size = GetSize();
if (rank <= 0 || rank > old_size) {
return false;
}
if (old_size == 1) {
*value = ValueOf(root_);
return true;
}
const bool update_timestamp = (rank != 1);
uint32_t node = root_;
while (node) {
const uint32_t left_subtree_num_nodes = SizeOf(LeftOf(node));
if (rank == left_subtree_num_nodes + 1) {
// This is the node we are looking for.
// Don't update timestamp if the node has rank 1.
if (update_timestamp) {
node = RemoveNode(node);
assert(old_size == GetSize() + 1);
MutableTimestampOf(node) = next_timestamp_++;
InsertNode(node);
assert(old_size == GetSize());
}
*value = ValueOf(node);
last_accessed_value_ = *value;
last_accessed_value_valid_ = true;
return true;
}
if (rank < left_subtree_num_nodes + 1) {
// Descend into the left subtree. The rank is still valid.
node = LeftOf(node);
} else {
// Descend into the right subtree. We leave behind the left subtree and
// the current node, adjust the |rank| accordingly.
rank -= left_subtree_num_nodes + 1;
node = RightOf(node);
}
}
assert(0);
return false;
}
uint32_t MoveToFront::CreateNode(uint32_t timestamp, uint32_t value) {
uint32_t handle = static_cast<uint32_t>(nodes_.size());
const auto result = value_to_node_.emplace(value, handle);
if (result.second) {
// Create new node.
nodes_.emplace_back(Node());
Node& node = nodes_.back();
node.timestamp = timestamp;
node.value = value;
node.size = 1;
// Non-NIL nodes start with height 1 because their NIL children are
// leaves.
node.height = 1;
} else {
// Reuse old node.
handle = result.first->second;
assert(!IsInTree(handle));
assert(ValueOf(handle) == value);
assert(SizeOf(handle) == 1);
assert(HeightOf(handle) == 1);
MutableTimestampOf(handle) = timestamp;
}
return handle;
}
void MoveToFront::InsertNode(uint32_t node) {
assert(!IsInTree(node));
assert(SizeOf(node) == 1);
assert(HeightOf(node) == 1);
assert(TimestampOf(node));
if (!root_) {
root_ = node;
return;
}
uint32_t iter = root_;
uint32_t parent = 0;
// Will determine if |node| will become the right or left child after
// insertion (but before balancing).
bool right_child = true;
// Find the node which will become |node|'s parent after insertion
// (but before balancing).
while (iter) {
parent = iter;
assert(TimestampOf(iter) != TimestampOf(node));
right_child = TimestampOf(iter) > TimestampOf(node);
iter = right_child ? RightOf(iter) : LeftOf(iter);
}
assert(parent);
// Connect node and parent.
MutableParentOf(node) = parent;
if (right_child)
MutableRightOf(parent) = node;
else
MutableLeftOf(parent) = node;
// Insertion is finished. Start the balancing process.
bool needs_rebalancing = true;
parent = ParentOf(node);
while (parent) {
UpdateNode(parent);
if (needs_rebalancing) {
const int parent_balance = BalanceOf(parent);
if (RightOf(parent) == node) {
// Added node to the right subtree.
if (parent_balance > 1) {
// Parent is right heavy, rotate left.
if (BalanceOf(node) < 0) RotateRight(node);
parent = RotateLeft(parent);
} else if (parent_balance == 0 || parent_balance == -1) {
// Parent is balanced or left heavy, no need to balance further.
needs_rebalancing = false;
}
} else {
// Added node to the left subtree.
if (parent_balance < -1) {
// Parent is left heavy, rotate right.
if (BalanceOf(node) > 0) RotateLeft(node);
parent = RotateRight(parent);
} else if (parent_balance == 0 || parent_balance == 1) {
// Parent is balanced or right heavy, no need to balance further.
needs_rebalancing = false;
}
}
}
assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1));
node = parent;
parent = ParentOf(parent);
}
}
uint32_t MoveToFront::RemoveNode(uint32_t node) {
if (LeftOf(node) && RightOf(node)) {
// If |node| has two children, then use another node as scapegoat and swap
// their contents. We pick the scapegoat on the side of the tree which has
// more nodes.
const uint32_t scapegoat = SizeOf(LeftOf(node)) >= SizeOf(RightOf(node))
? RightestDescendantOf(LeftOf(node))
: LeftestDescendantOf(RightOf(node));
assert(scapegoat);
std::swap(MutableValueOf(node), MutableValueOf(scapegoat));
std::swap(MutableTimestampOf(node), MutableTimestampOf(scapegoat));
value_to_node_[ValueOf(node)] = node;
value_to_node_[ValueOf(scapegoat)] = scapegoat;
node = scapegoat;
}
// |node| may have only one child at this point.
assert(!RightOf(node) || !LeftOf(node));
uint32_t parent = ParentOf(node);
uint32_t child = RightOf(node) ? RightOf(node) : LeftOf(node);
// Orphan |node| and reconnect parent and child.
if (child) MutableParentOf(child) = parent;
if (parent) {
if (LeftOf(parent) == node)
MutableLeftOf(parent) = child;
else
MutableRightOf(parent) = child;
}
MutableParentOf(node) = 0;
MutableLeftOf(node) = 0;
MutableRightOf(node) = 0;
UpdateNode(node);
const uint32_t orphan = node;
if (root_ == node) root_ = child;
// Removal is finished. Start the balancing process.
bool needs_rebalancing = true;
node = child;
while (parent) {
UpdateNode(parent);
if (needs_rebalancing) {
const int parent_balance = BalanceOf(parent);
if (parent_balance == 1 || parent_balance == -1) {
// The height of the subtree was not changed.
needs_rebalancing = false;
} else {
if (RightOf(parent) == node) {
// Removed node from the right subtree.
if (parent_balance < -1) {
// Parent is left heavy, rotate right.
const uint32_t sibling = LeftOf(parent);
if (BalanceOf(sibling) > 0) RotateLeft(sibling);
parent = RotateRight(parent);
}
} else {
// Removed node from the left subtree.
if (parent_balance > 1) {
// Parent is right heavy, rotate left.
const uint32_t sibling = RightOf(parent);
if (BalanceOf(sibling) < 0) RotateRight(sibling);
parent = RotateLeft(parent);
}
}
}
}
assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1));
node = parent;
parent = ParentOf(parent);
}
return orphan;
}
uint32_t MoveToFront::RotateLeft(const uint32_t node) {
const uint32_t pivot = RightOf(node);
assert(pivot);
// LeftOf(pivot) gets attached to node in place of pivot.
MutableRightOf(node) = LeftOf(pivot);
if (RightOf(node)) MutableParentOf(RightOf(node)) = node;
// Pivot gets attached to ParentOf(node) in place of node.
MutableParentOf(pivot) = ParentOf(node);
if (!ParentOf(node))
root_ = pivot;
else if (IsLeftChild(node))
MutableLeftOf(ParentOf(node)) = pivot;
else
MutableRightOf(ParentOf(node)) = pivot;
// Node is child of pivot.
MutableLeftOf(pivot) = node;
MutableParentOf(node) = pivot;
// Update both node and pivot. Pivot is the new parent of node, so node should
// be updated first.
UpdateNode(node);
UpdateNode(pivot);
return pivot;
}
uint32_t MoveToFront::RotateRight(const uint32_t node) {
const uint32_t pivot = LeftOf(node);
assert(pivot);
// RightOf(pivot) gets attached to node in place of pivot.
MutableLeftOf(node) = RightOf(pivot);
if (LeftOf(node)) MutableParentOf(LeftOf(node)) = node;
// Pivot gets attached to ParentOf(node) in place of node.
MutableParentOf(pivot) = ParentOf(node);
if (!ParentOf(node))
root_ = pivot;
else if (IsLeftChild(node))
MutableLeftOf(ParentOf(node)) = pivot;
else
MutableRightOf(ParentOf(node)) = pivot;
// Node is child of pivot.
MutableRightOf(pivot) = node;
MutableParentOf(node) = pivot;
// Update both node and pivot. Pivot is the new parent of node, so node should
// be updated first.
UpdateNode(node);
UpdateNode(pivot);
return pivot;
}
void MoveToFront::UpdateNode(uint32_t node) {
MutableSizeOf(node) = 1 + SizeOf(LeftOf(node)) + SizeOf(RightOf(node));
MutableHeightOf(node) =
1 + std::max(HeightOf(LeftOf(node)), HeightOf(RightOf(node)));
}
} // namespace comp
} // namespace spvtools

View File

@ -1,384 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_COMP_MOVE_TO_FRONT_H_
#define SOURCE_COMP_MOVE_TO_FRONT_H_
#include <cassert>
#include <cstdint>
#include <map>
#include <set>
#include <unordered_map>
#include <vector>
namespace spvtools {
namespace comp {
// Log(n) move-to-front implementation. Implements the following functions:
// Insert - pushes value to the front of the mtf sequence
// (only unique values allowed).
// Remove - remove value from the sequence.
// ValueFromRank - access value by its 1-indexed rank in the sequence.
// RankFromValue - get the rank of the given value in the sequence.
// Accessing a value with ValueFromRank or RankFromValue moves the value to the
// front of the sequence (rank of 1).
//
// The implementation is based on an AVL-based order statistic tree. The tree
// is ordered by timestamps issued when values are inserted or accessed (recent
// values go to the left side of the tree, old values are gradually rotated to
// the right side).
//
// Terminology
// rank: 1-indexed rank showing how recently the value was inserted or accessed.
// node: handle used internally to access node data.
// size: size of the subtree of a node (including the node).
// height: distance from a node to the farthest leaf.
class MoveToFront {
public:
explicit MoveToFront(size_t reserve_capacity = 4) {
nodes_.reserve(reserve_capacity);
// Create NIL node.
nodes_.emplace_back(Node());
}
virtual ~MoveToFront() = default;
// Inserts value in the move-to-front sequence. Does nothing if the value is
// already in the sequence. Returns true if insertion was successful.
// The inserted value is placed at the front of the sequence (rank 1).
bool Insert(uint32_t value);
// Removes value from move-to-front sequence. Returns false iff the value
// was not found.
bool Remove(uint32_t value);
// Computes 1-indexed rank of value in the move-to-front sequence and moves
// the value to the front. Example:
// Before the call: 4 8 2 1 7
// RankFromValue(8) returns 2
// After the call: 8 4 2 1 7
// Returns true iff the value was found in the sequence.
bool RankFromValue(uint32_t value, uint32_t* rank);
// Returns value corresponding to a 1-indexed rank in the move-to-front
// sequence and moves the value to the front. Example:
// Before the call: 4 8 2 1 7
// ValueFromRank(2) returns 8
// After the call: 8 4 2 1 7
// Returns true iff the rank is within bounds [1, GetSize()].
bool ValueFromRank(uint32_t rank, uint32_t* value);
// Moves the value to the front of the sequence.
// Returns false iff value is not in the sequence.
bool Promote(uint32_t value);
// Returns true iff the move-to-front sequence contains the value.
bool HasValue(uint32_t value) const;
// Returns the number of elements in the move-to-front sequence.
uint32_t GetSize() const { return SizeOf(root_); }
protected:
// Internal tree data structure uses handles instead of pointers. Leaves and
// root parent reference a singleton under handle 0. Although dereferencing
// a null pointer is not possible, inappropriate access to handle 0 would
// cause an assertion. Handles are not garbage collected if value was
// deprecated
// with DeprecateValue(). But handles are recycled when a node is
// repositioned.
// Internal tree data structure node.
struct Node {
// Timestamp from a logical clock which updates every time the element is
// accessed through ValueFromRank or RankFromValue.
uint32_t timestamp = 0;
// The size of the node's subtree, including the node.
// SizeOf(LeftOf(node)) + SizeOf(RightOf(node)) + 1.
uint32_t size = 0;
// Handles to connected nodes.
uint32_t left = 0;
uint32_t right = 0;
uint32_t parent = 0;
// Distance to the farthest leaf.
// Leaves have height 0, real nodes at least 1.
uint32_t height = 0;
// Stored value.
uint32_t value = 0;
};
// Creates node and sets correct values. Non-NIL nodes should be created only
// through this function. If the node with this value has been created
// previously
// and since orphaned, reuses the old node instead of creating a new one.
uint32_t CreateNode(uint32_t timestamp, uint32_t value);
// Node accessor methods. Naming is designed to be similar to natural
// language as these functions tend to be used in sequences, for example:
// ParentOf(LeftestDescendentOf(RightOf(node)))
// Returns value of the node referenced by |handle|.
uint32_t ValueOf(uint32_t node) const { return nodes_.at(node).value; }
// Returns left child of |node|.
uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; }
// Returns right child of |node|.
uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; }
// Returns parent of |node|.
uint32_t ParentOf(uint32_t node) const { return nodes_.at(node).parent; }
// Returns timestamp of |node|.
uint32_t TimestampOf(uint32_t node) const {
assert(node);
return nodes_.at(node).timestamp;
}
// Returns size of |node|.
uint32_t SizeOf(uint32_t node) const { return nodes_.at(node).size; }
// Returns height of |node|.
uint32_t HeightOf(uint32_t node) const { return nodes_.at(node).height; }
// Returns mutable reference to value of |node|.
uint32_t& MutableValueOf(uint32_t node) {
assert(node);
return nodes_.at(node).value;
}
// Returns mutable reference to handle of left child of |node|.
uint32_t& MutableLeftOf(uint32_t node) {
assert(node);
return nodes_.at(node).left;
}
// Returns mutable reference to handle of right child of |node|.
uint32_t& MutableRightOf(uint32_t node) {
assert(node);
return nodes_.at(node).right;
}
// Returns mutable reference to handle of parent of |node|.
uint32_t& MutableParentOf(uint32_t node) {
assert(node);
return nodes_.at(node).parent;
}
// Returns mutable reference to timestamp of |node|.
uint32_t& MutableTimestampOf(uint32_t node) {
assert(node);
return nodes_.at(node).timestamp;
}
// Returns mutable reference to size of |node|.
uint32_t& MutableSizeOf(uint32_t node) {
assert(node);
return nodes_.at(node).size;
}
// Returns mutable reference to height of |node|.
uint32_t& MutableHeightOf(uint32_t node) {
assert(node);
return nodes_.at(node).height;
}
// Returns true iff |node| is left child of its parent.
bool IsLeftChild(uint32_t node) const {
assert(node);
return LeftOf(ParentOf(node)) == node;
}
// Returns true iff |node| is right child of its parent.
bool IsRightChild(uint32_t node) const {
assert(node);
return RightOf(ParentOf(node)) == node;
}
// Returns true iff |node| has no relatives.
bool IsOrphan(uint32_t node) const {
assert(node);
return !ParentOf(node) && !LeftOf(node) && !RightOf(node);
}
// Returns true iff |node| is in the tree.
bool IsInTree(uint32_t node) const {
assert(node);
return node == root_ || !IsOrphan(node);
}
// Returns the height difference between right and left subtrees.
int BalanceOf(uint32_t node) const {
return int(HeightOf(RightOf(node))) - int(HeightOf(LeftOf(node)));
}
// Updates size and height of the node, assuming that the children have
// correct values.
void UpdateNode(uint32_t node);
// Returns the most LeftOf(LeftOf(... descendent which is not leaf.
uint32_t LeftestDescendantOf(uint32_t node) const {
uint32_t parent = 0;
while (node) {
parent = node;
node = LeftOf(node);
}
return parent;
}
// Returns the most RightOf(RightOf(... descendent which is not leaf.
uint32_t RightestDescendantOf(uint32_t node) const {
uint32_t parent = 0;
while (node) {
parent = node;
node = RightOf(node);
}
return parent;
}
// Inserts node in the tree. The node must be an orphan.
void InsertNode(uint32_t node);
// Removes node from the tree. May change value_to_node_ if removal uses a
// scapegoat. Returns the removed (orphaned) handle for recycling. The
// returned handle may not be equal to |node| if scapegoat was used.
uint32_t RemoveNode(uint32_t node);
// Rotates |node| left, reassigns all connections and returns the node
// which takes place of the |node|.
uint32_t RotateLeft(const uint32_t node);
// Rotates |node| right, reassigns all connections and returns the node
// which takes place of the |node|.
uint32_t RotateRight(const uint32_t node);
// Root node handle. The tree is empty if root_ is 0.
uint32_t root_ = 0;
// Incremented counters for next timestamp and value.
uint32_t next_timestamp_ = 1;
// Holds all tree nodes. Indices of this vector are node handles.
std::vector<Node> nodes_;
// Maps ids to node handles.
std::unordered_map<uint32_t, uint32_t> value_to_node_;
// Cache for the last accessed value in the sequence.
uint32_t last_accessed_value_ = 0;
bool last_accessed_value_valid_ = false;
};
class MultiMoveToFront {
public:
// Inserts |value| to sequence with handle |mtf|.
// Returns false if |mtf| already has |value|.
bool Insert(uint64_t mtf, uint32_t value) {
if (GetMtf(mtf).Insert(value)) {
val_to_mtfs_[value].insert(mtf);
return true;
}
return false;
}
// Removes |value| from sequence with handle |mtf|.
// Returns false if |mtf| doesn't have |value|.
bool Remove(uint64_t mtf, uint32_t value) {
if (GetMtf(mtf).Remove(value)) {
val_to_mtfs_[value].erase(mtf);
return true;
}
assert(val_to_mtfs_[value].count(mtf) == 0);
return false;
}
// Removes |value| from all sequences which have it.
void RemoveFromAll(uint32_t value) {
auto it = val_to_mtfs_.find(value);
if (it == val_to_mtfs_.end()) return;
auto& mtfs_containing_value = it->second;
for (uint64_t mtf : mtfs_containing_value) {
GetMtf(mtf).Remove(value);
}
val_to_mtfs_.erase(value);
}
// Computes rank of |value| in sequence |mtf|.
// Returns false if |mtf| doesn't have |value|.
bool RankFromValue(uint64_t mtf, uint32_t value, uint32_t* rank) {
return GetMtf(mtf).RankFromValue(value, rank);
}
// Finds |value| with |rank| in sequence |mtf|.
// Returns false if |rank| is out of bounds.
bool ValueFromRank(uint64_t mtf, uint32_t rank, uint32_t* value) {
return GetMtf(mtf).ValueFromRank(rank, value);
}
// Returns size of |mtf| sequence.
uint32_t GetSize(uint64_t mtf) { return GetMtf(mtf).GetSize(); }
// Promotes |value| in all sequences which have it.
void Promote(uint32_t value) {
const auto it = val_to_mtfs_.find(value);
if (it == val_to_mtfs_.end()) return;
const auto& mtfs_containing_value = it->second;
for (uint64_t mtf : mtfs_containing_value) {
GetMtf(mtf).Promote(value);
}
}
// Inserts |value| in sequence |mtf| or promotes if it's already there.
void InsertOrPromote(uint64_t mtf, uint32_t value) {
if (!Insert(mtf, value)) {
GetMtf(mtf).Promote(value);
}
}
// Returns if |mtf| sequence has |value|.
bool HasValue(uint64_t mtf, uint32_t value) {
return GetMtf(mtf).HasValue(value);
}
private:
// Returns actual MoveToFront object corresponding to |handle|.
// As multiple operations are often performed consecutively for the same
// sequence, the last returned value is cached.
MoveToFront& GetMtf(uint64_t handle) {
if (!cached_mtf_ || cached_handle_ != handle) {
cached_handle_ = handle;
cached_mtf_ = &mtfs_[handle];
}
return *cached_mtf_;
}
// Container holding MoveToFront objects. Map key is sequence handle.
std::map<uint64_t, MoveToFront> mtfs_;
// Container mapping value to sequences which contain that value.
std::unordered_map<uint32_t, std::set<uint64_t>> val_to_mtfs_;
// Cache for the last accessed sequence.
uint64_t cached_handle_ = 0;
MoveToFront* cached_mtf_ = nullptr;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MOVE_TO_FRONT_H_

View File

@ -1,78 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/id_descriptor.h"
#include <cassert>
#include <iostream>
#include "source/opcode.h"
#include "source/operand.h"
namespace spvtools {
namespace {
// Hashes an array of words. Order of words is important.
uint32_t HashU32Array(const std::vector<uint32_t>& words) {
// The hash function is a sum of hashes of each word seeded by word index.
// Knuth's multiplicative hash is used to hash the words.
const uint32_t kKnuthMulHash = 2654435761;
uint32_t val = 0;
for (uint32_t i = 0; i < words.size(); ++i) {
val += (words[i] + i + 123) * kKnuthMulHash;
}
return val;
}
} // namespace
uint32_t IdDescriptorCollection::ProcessInstruction(
const spv_parsed_instruction_t& inst) {
if (!inst.result_id) return 0;
assert(words_.empty());
words_.push_back(inst.words[0]);
for (size_t operand_index = 0; operand_index < inst.num_operands;
++operand_index) {
const auto& operand = inst.operands[operand_index];
if (spvIsIdType(operand.type)) {
const uint32_t id = inst.words[operand.offset];
const auto it = id_to_descriptor_.find(id);
// Forward declared ids are not hashed.
if (it != id_to_descriptor_.end()) {
words_.push_back(it->second);
}
} else {
for (size_t operand_word_index = 0;
operand_word_index < operand.num_words; ++operand_word_index) {
words_.push_back(inst.words[operand.offset + operand_word_index]);
}
}
}
uint32_t descriptor =
custom_hash_func_ ? custom_hash_func_(words_) : HashU32Array(words_);
if (descriptor == 0) descriptor = 1;
assert(descriptor);
words_.clear();
const auto result = id_to_descriptor_.emplace(inst.result_id, descriptor);
assert(result.second);
(void)result;
return descriptor;
}
} // namespace spvtools

View File

@ -1,63 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_ID_DESCRIPTOR_H_
#define SOURCE_ID_DESCRIPTOR_H_
#include <unordered_map>
#include <vector>
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
using CustomHashFunc = std::function<uint32_t(const std::vector<uint32_t>&)>;
// Computes and stores id descriptors.
//
// Descriptors are computed as hash of all words in the instruction where ids
// were substituted with previously computed descriptors.
class IdDescriptorCollection {
public:
explicit IdDescriptorCollection(
CustomHashFunc custom_hash_func = CustomHashFunc())
: custom_hash_func_(custom_hash_func) {
words_.reserve(16);
}
// Computes descriptor for the result id of the given instruction and
// registers it in id_to_descriptor_. Returns the computed descriptor.
// This function needs to be sequentially called for every instruction in the
// module.
uint32_t ProcessInstruction(const spv_parsed_instruction_t& inst);
// Returns a previously computed descriptor id.
uint32_t GetDescriptor(uint32_t id) const {
const auto it = id_to_descriptor_.find(id);
if (it == id_to_descriptor_.end()) return 0;
return it->second;
}
private:
std::unordered_map<uint32_t, uint32_t> id_to_descriptor_;
std::function<uint32_t(const std::vector<uint32_t>&)> custom_hash_func_;
// Scratch buffer used for hashing. Class member to optimize on allocation.
std::vector<uint32_t> words_;
};
} // namespace spvtools
#endif // SOURCE_ID_DESCRIPTOR_H_

View File

@ -183,33 +183,9 @@ add_spvtools_unittest(
endif()
add_spvtools_unittest(
TARGET bit_stream
SRCS bit_stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.h
LIBS ${SPIRV_TOOLS})
add_spvtools_unittest(
TARGET huffman_codec
SRCS huffman_codec.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.h
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/huffman_codec.h
LIBS ${SPIRV_TOOLS})
add_spvtools_unittest(
TARGET move_to_front
SRCS move_to_front_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/move_to_front.h
${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/move_to_front.cpp
LIBS ${SPIRV_TOOLS})
add_subdirectory(comp)
add_subdirectory(link)
add_subdirectory(opt)
add_subdirectory(reduce)
add_subdirectory(stats)
add_subdirectory(tools)
add_subdirectory(util)
add_subdirectory(val)

File diff suppressed because it is too large Load Diff

View File

@ -1,29 +0,0 @@
# Copyright (c) 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set(VAL_TEST_COMMON_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/../test_fixture.h
${CMAKE_CURRENT_SOURCE_DIR}/../unit_spirv.h
)
if(SPIRV_BUILD_COMPRESSION)
add_spvtools_unittest(TARGET markv_codec
SRCS
markv_codec_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../tools/comp/markv_model_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../tools/comp/markv_model_shader.cpp
${VAL_TEST_COMMON_SRCS}
LIBS SPIRV-Tools-comp ${SPIRV_TOOLS}
)
endif(SPIRV_BUILD_COMPRESSION)

View File

@ -1,829 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Tests for unique type declaration rules validator.
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "gmock/gmock.h"
#include "source/comp/markv.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
#include "tools/comp/markv_model_factory.h"
namespace spvtools {
namespace comp {
namespace {
using spvtest::ScopedContext;
using MarkvTest = ::testing::TestWithParam<MarkvModelType>;
void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
const spv_position_t& position,
const char* message) {
switch (level) {
case SPV_MSG_FATAL:
case SPV_MSG_INTERNAL_ERROR:
case SPV_MSG_ERROR:
std::cerr << "error: " << position.index << ": " << message << std::endl;
break;
case SPV_MSG_WARNING:
std::cout << "warning: " << position.index << ": " << message
<< std::endl;
break;
case SPV_MSG_INFO:
std::cout << "info: " << position.index << ": " << message << std::endl;
break;
default:
break;
}
}
// Compiles |code| to SPIR-V |words|.
void Compile(const std::string& code, std::vector<uint32_t>* words,
uint32_t options = SPV_TEXT_TO_BINARY_OPTION_NONE,
spv_target_env env = SPV_ENV_UNIVERSAL_1_2) {
spvtools::Context ctx(env);
ctx.SetMessageConsumer(DiagnosticsMessageHandler);
spv_binary spirv_binary;
ASSERT_EQ(SPV_SUCCESS, spvTextToBinaryWithOptions(
ctx.CContext(), code.c_str(), code.size(), options,
&spirv_binary, nullptr));
*words = std::vector<uint32_t>(spirv_binary->code,
spirv_binary->code + spirv_binary->wordCount);
spvBinaryDestroy(spirv_binary);
}
// Disassembles SPIR-V |words| to |out_text|.
void Disassemble(const std::vector<uint32_t>& words, std::string* out_text,
spv_target_env env = SPV_ENV_UNIVERSAL_1_2) {
spvtools::Context ctx(env);
ctx.SetMessageConsumer(DiagnosticsMessageHandler);
spv_text text = nullptr;
ASSERT_EQ(SPV_SUCCESS, spvBinaryToText(ctx.CContext(), words.data(),
words.size(), 0, &text, nullptr));
assert(text);
*out_text = std::string(text->str, text->length);
spvTextDestroy(text);
}
// Encodes/decodes |original|, assembles/dissasembles |original|, then compares
// the results of the two operations.
void TestEncodeDecode(MarkvModelType model_type,
const std::string& original_text) {
spvtools::Context ctx(SPV_ENV_UNIVERSAL_1_2);
std::unique_ptr<MarkvModel> model = CreateMarkvModel(model_type);
MarkvCodecOptions options;
std::vector<uint32_t> expected_binary;
Compile(original_text, &expected_binary);
ASSERT_FALSE(expected_binary.empty());
std::string expected_text;
Disassemble(expected_binary, &expected_text);
ASSERT_FALSE(expected_text.empty());
std::vector<uint32_t> binary_to_encode;
Compile(original_text, &binary_to_encode,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ASSERT_FALSE(binary_to_encode.empty());
std::stringstream encoder_comments;
const auto output_to_string_stream =
[&encoder_comments](const std::string& str) { encoder_comments << str; };
std::vector<uint8_t> markv;
ASSERT_EQ(SPV_SUCCESS,
SpirvToMarkv(ctx.CContext(), binary_to_encode, options, *model,
DiagnosticsMessageHandler, output_to_string_stream,
MarkvDebugConsumer(), &markv));
ASSERT_FALSE(markv.empty());
std::vector<uint32_t> decoded_binary;
ASSERT_EQ(SPV_SUCCESS,
MarkvToSpirv(ctx.CContext(), markv, options, *model,
DiagnosticsMessageHandler, MarkvLogConsumer(),
MarkvDebugConsumer(), &decoded_binary));
ASSERT_FALSE(decoded_binary.empty());
EXPECT_EQ(expected_binary, decoded_binary) << encoder_comments.str();
std::string decoded_text;
Disassemble(decoded_binary, &decoded_text);
ASSERT_FALSE(decoded_text.empty());
EXPECT_EQ(expected_text, decoded_text) << encoder_comments.str();
}
void TestEncodeDecodeShaderMainBody(MarkvModelType model_type,
const std::string& body) {
const std::string prefix =
R"(
OpCapability Shader
OpCapability Int64
OpCapability Float64
%ext_inst = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%f64 = OpTypeFloat 64
%u64 = OpTypeInt 64 0
%s64 = OpTypeInt 64 1
%boolvec2 = OpTypeVector %bool 2
%s32vec2 = OpTypeVector %s32 2
%u32vec2 = OpTypeVector %u32 2
%f32vec2 = OpTypeVector %f32 2
%f64vec2 = OpTypeVector %f64 2
%boolvec3 = OpTypeVector %bool 3
%u32vec3 = OpTypeVector %u32 3
%s32vec3 = OpTypeVector %s32 3
%f32vec3 = OpTypeVector %f32 3
%f64vec3 = OpTypeVector %f64 3
%boolvec4 = OpTypeVector %bool 4
%u32vec4 = OpTypeVector %u32 4
%s32vec4 = OpTypeVector %s32 4
%f32vec4 = OpTypeVector %f32 4
%f64vec4 = OpTypeVector %f64 4
%f32_0 = OpConstant %f32 0
%f32_1 = OpConstant %f32 1
%f32_2 = OpConstant %f32 2
%f32_3 = OpConstant %f32 3
%f32_4 = OpConstant %f32 4
%f32_pi = OpConstant %f32 3.14159
%s32_0 = OpConstant %s32 0
%s32_1 = OpConstant %s32 1
%s32_2 = OpConstant %s32 2
%s32_3 = OpConstant %s32 3
%s32_4 = OpConstant %s32 4
%s32_m1 = OpConstant %s32 -1
%u32_0 = OpConstant %u32 0
%u32_1 = OpConstant %u32 1
%u32_2 = OpConstant %u32 2
%u32_3 = OpConstant %u32 3
%u32_4 = OpConstant %u32 4
%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1
%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2
%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2
%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3
%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3
%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4
%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1
%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2
%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2
%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3
%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3
%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4
%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1
%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2
%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2
%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3
%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3
%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4
%main = OpFunction %void None %func
%main_entry = OpLabel)";
const std::string suffix =
R"(
OpReturn
OpFunctionEnd)";
TestEncodeDecode(model_type, prefix + body + suffix);
}
TEST_P(MarkvTest, U32Literal) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%u32 = OpTypeInt 32 0
%100 = OpConstant %u32 0
%200 = OpConstant %u32 1
%300 = OpConstant %u32 4294967295
)");
}
TEST_P(MarkvTest, S32Literal) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%s32 = OpTypeInt 32 1
%100 = OpConstant %s32 0
%200 = OpConstant %s32 1
%300 = OpConstant %s32 -1
%400 = OpConstant %s32 2147483647
%500 = OpConstant %s32 -2147483648
)");
}
TEST_P(MarkvTest, U64Literal) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpCapability Int64
OpMemoryModel Logical GLSL450
%u64 = OpTypeInt 64 0
%100 = OpConstant %u64 0
%200 = OpConstant %u64 1
%300 = OpConstant %u64 18446744073709551615
)");
}
TEST_P(MarkvTest, S64Literal) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpCapability Int64
OpMemoryModel Logical GLSL450
%s64 = OpTypeInt 64 1
%100 = OpConstant %s64 0
%200 = OpConstant %s64 1
%300 = OpConstant %s64 -1
%400 = OpConstant %s64 9223372036854775807
%500 = OpConstant %s64 -9223372036854775808
)");
}
TEST_P(MarkvTest, U16Literal) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpCapability Int16
OpMemoryModel Logical GLSL450
%u16 = OpTypeInt 16 0
%100 = OpConstant %u16 0
%200 = OpConstant %u16 1
%300 = OpConstant %u16 65535
)");
}
TEST_P(MarkvTest, S16Literal) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpCapability Int16
OpMemoryModel Logical GLSL450
%s16 = OpTypeInt 16 1
%100 = OpConstant %s16 0
%200 = OpConstant %s16 1
%300 = OpConstant %s16 -1
%400 = OpConstant %s16 32767
%500 = OpConstant %s16 -32768
)");
}
TEST_P(MarkvTest, F32Literal) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%f32 = OpTypeFloat 32
%100 = OpConstant %f32 0
%200 = OpConstant %f32 1
%300 = OpConstant %f32 0.1
%400 = OpConstant %f32 -0.1
)");
}
TEST_P(MarkvTest, F64Literal) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpCapability Float64
OpMemoryModel Logical GLSL450
%f64 = OpTypeFloat 64
%100 = OpConstant %f64 0
%200 = OpConstant %f64 1
%300 = OpConstant %f64 0.1
%400 = OpConstant %f64 -0.1
)");
}
TEST_P(MarkvTest, F16Literal) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpCapability Float16
OpMemoryModel Logical GLSL450
%f16 = OpTypeFloat 16
%100 = OpConstant %f16 0
%200 = OpConstant %f16 1
%300 = OpConstant %f16 0.1
%400 = OpConstant %f16 -0.1
)");
}
TEST_P(MarkvTest, StringLiteral) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_KHR_16bit_storage"
OpExtension "xxx"
OpExtension "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
OpExtension ""
OpMemoryModel Logical GLSL450
)");
}
TEST_P(MarkvTest, WithFunction) {
TestEncodeDecode(GetParam(), R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
OpExtension "SPV_KHR_16bit_storage"
OpMemoryModel Physical32 OpenCL
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%void = OpTypeVoid
%void_func = OpTypeFunction %void
%100 = OpConstant %u32 1
%200 = OpConstant %u32 2
%main = OpFunction %void None %void_func
%entry_main = OpLabel
%300 = OpIAdd %u32 %100 %200
OpReturn
OpFunctionEnd
)");
}
TEST_P(MarkvTest, WithMultipleFunctions) {
TestEncodeDecode(GetParam(), R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
OpMemoryModel Physical32 OpenCL
%f32 = OpTypeFloat 32
%one = OpConstant %f32 1
%void = OpTypeVoid
%void_func = OpTypeFunction %void
%f32_func = OpTypeFunction %f32 %f32
%sqr_plus_one = OpFunction %f32 None %f32_func
%x = OpFunctionParameter %f32
%100 = OpLabel
%x2 = OpFMul %f32 %x %x
%x2p1 = OpFunctionCall %f32 %plus_one %x2
OpReturnValue %x2p1
OpFunctionEnd
%plus_one = OpFunction %f32 None %f32_func
%y = OpFunctionParameter %f32
%200 = OpLabel
%yp1 = OpFAdd %f32 %y %one
OpReturnValue %yp1
OpFunctionEnd
%main = OpFunction %void None %void_func
%entry_main = OpLabel
%1p1 = OpFunctionCall %f32 %sqr_plus_one %one
OpReturn
OpFunctionEnd
)");
}
TEST_P(MarkvTest, ForwardDeclaredId) {
TestEncodeDecode(GetParam(), R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
OpMemoryModel Physical32 OpenCL
OpEntryPoint Kernel %1 "simple_kernel"
%2 = OpTypeInt 32 0
%3 = OpTypeVector %2 2
%4 = OpConstant %2 2
%5 = OpTypeArray %2 %4
%6 = OpTypeVoid
%7 = OpTypeFunction %6
%1 = OpFunction %6 None %7
%8 = OpLabel
OpReturn
OpFunctionEnd
)");
}
TEST_P(MarkvTest, WithSwitch) {
TestEncodeDecode(GetParam(), R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
OpCapability Int64
OpMemoryModel Physical32 OpenCL
%u64 = OpTypeInt 64 0
%void = OpTypeVoid
%void_func = OpTypeFunction %void
%val = OpConstant %u64 1
%main = OpFunction %void None %void_func
%entry_main = OpLabel
OpSwitch %val %default 1 %case1 1000000000000 %case2
%case1 = OpLabel
OpNop
OpBranch %after_switch
%case2 = OpLabel
OpNop
OpBranch %after_switch
%default = OpLabel
OpNop
OpBranch %after_switch
%after_switch = OpLabel
OpReturn
OpFunctionEnd
)");
}
TEST_P(MarkvTest, WithLoop) {
TestEncodeDecode(GetParam(), R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
OpMemoryModel Physical32 OpenCL
%void = OpTypeVoid
%void_func = OpTypeFunction %void
%main = OpFunction %void None %void_func
%entry_main = OpLabel
OpLoopMerge %merge %continue DontUnroll|DependencyLength 10
OpBranch %begin_loop
%begin_loop = OpLabel
OpNop
OpBranch %continue
%continue = OpLabel
OpNop
OpBranch %begin_loop
%merge = OpLabel
OpReturn
OpFunctionEnd
)");
}
TEST_P(MarkvTest, WithDecorate) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
OpDecorate %1 ArrayStride 4
OpDecorate %1 Uniform
%2 = OpTypeFloat 32
%1 = OpTypeRuntimeArray %2
)");
}
TEST_P(MarkvTest, WithExtInst) {
TestEncodeDecode(GetParam(), R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
%opencl = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical32 OpenCL
%f32 = OpTypeFloat 32
%void = OpTypeVoid
%void_func = OpTypeFunction %void
%100 = OpConstant %f32 1.1
%main = OpFunction %void None %void_func
%entry_main = OpLabel
%200 = OpExtInst %f32 %opencl cos %100
OpReturn
OpFunctionEnd
)");
}
TEST_P(MarkvTest, F32Mul) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%val1 = OpFMul %f32 %f32_0 %f32_1
%val2 = OpFMul %f32 %f32_2 %f32_0
%val3 = OpFMul %f32 %f32_pi %f32_2
%val4 = OpFMul %f32 %f32_1 %f32_1
)");
}
TEST_P(MarkvTest, U32Mul) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%val1 = OpIMul %u32 %u32_0 %u32_1
%val2 = OpIMul %u32 %u32_2 %u32_0
%val3 = OpIMul %u32 %u32_3 %u32_2
%val4 = OpIMul %u32 %u32_1 %u32_1
)");
}
TEST_P(MarkvTest, S32Mul) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%val1 = OpIMul %s32 %s32_0 %s32_1
%val2 = OpIMul %s32 %s32_2 %s32_0
%val3 = OpIMul %s32 %s32_m1 %s32_2
%val4 = OpIMul %s32 %s32_1 %s32_1
)");
}
TEST_P(MarkvTest, F32Add) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%val1 = OpFAdd %f32 %f32_0 %f32_1
%val2 = OpFAdd %f32 %f32_2 %f32_0
%val3 = OpFAdd %f32 %f32_pi %f32_2
%val4 = OpFAdd %f32 %f32_1 %f32_1
)");
}
TEST_P(MarkvTest, U32Add) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%val1 = OpIAdd %u32 %u32_0 %u32_1
%val2 = OpIAdd %u32 %u32_2 %u32_0
%val3 = OpIAdd %u32 %u32_3 %u32_2
%val4 = OpIAdd %u32 %u32_1 %u32_1
)");
}
TEST_P(MarkvTest, S32Add) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%val1 = OpIAdd %s32 %s32_0 %s32_1
%val2 = OpIAdd %s32 %s32_2 %s32_0
%val3 = OpIAdd %s32 %s32_m1 %s32_2
%val4 = OpIAdd %s32 %s32_1 %s32_1
)");
}
TEST_P(MarkvTest, F32Dot) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%dot2_1 = OpDot %f32 %f32vec2_01 %f32vec2_12
%dot2_2 = OpDot %f32 %f32vec2_01 %f32vec2_01
%dot2_3 = OpDot %f32 %f32vec2_12 %f32vec2_12
%dot3_1 = OpDot %f32 %f32vec3_012 %f32vec3_123
%dot3_2 = OpDot %f32 %f32vec3_012 %f32vec3_012
%dot3_3 = OpDot %f32 %f32vec3_123 %f32vec3_123
%dot4_1 = OpDot %f32 %f32vec4_0123 %f32vec4_1234
%dot4_2 = OpDot %f32 %f32vec4_0123 %f32vec4_0123
%dot4_3 = OpDot %f32 %f32vec4_1234 %f32vec4_1234
)");
}
TEST_P(MarkvTest, F32VectorCompositeConstruct) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%cc1 = OpCompositeConstruct %f32vec4 %f32vec2_01 %f32vec2_12
%cc2 = OpCompositeConstruct %f32vec3 %f32vec2_01 %f32_2
%cc3 = OpCompositeConstruct %f32vec2 %f32_1 %f32_2
%cc4 = OpCompositeConstruct %f32vec4 %f32_1 %f32_2 %cc3
)");
}
TEST_P(MarkvTest, U32VectorCompositeConstruct) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%cc1 = OpCompositeConstruct %u32vec4 %u32vec2_01 %u32vec2_12
%cc2 = OpCompositeConstruct %u32vec3 %u32vec2_01 %u32_2
%cc3 = OpCompositeConstruct %u32vec2 %u32_1 %u32_2
%cc4 = OpCompositeConstruct %u32vec4 %u32_1 %u32_2 %cc3
)");
}
TEST_P(MarkvTest, S32VectorCompositeConstruct) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%cc1 = OpCompositeConstruct %u32vec4 %u32vec2_01 %u32vec2_12
%cc2 = OpCompositeConstruct %u32vec3 %u32vec2_01 %u32_2
%cc3 = OpCompositeConstruct %u32vec2 %u32_1 %u32_2
%cc4 = OpCompositeConstruct %u32vec4 %u32_1 %u32_2 %cc3
)");
}
TEST_P(MarkvTest, F32VectorCompositeExtract) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
%f32vec3_013 = OpCompositeExtract %f32vec3 %f32vec4_0123 0 1 3
)");
}
TEST_P(MarkvTest, F32VectorComparison) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
%c1 = OpFOrdEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
%c2 = OpFUnordEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
%c3 = OpFOrdNotEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
%c4 = OpFUnordNotEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
%c5 = OpFOrdLessThan %boolvec4 %f32vec4_0123 %f32vec4_3210
%c6 = OpFUnordLessThan %boolvec4 %f32vec4_0123 %f32vec4_3210
%c7 = OpFOrdGreaterThan %boolvec4 %f32vec4_0123 %f32vec4_3210
%c8 = OpFUnordGreaterThan %boolvec4 %f32vec4_0123 %f32vec4_3210
%c9 = OpFOrdLessThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
%c10 = OpFUnordLessThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
%c11 = OpFOrdGreaterThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
%c12 = OpFUnordGreaterThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
)");
}
TEST_P(MarkvTest, VectorShuffle) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
%sh1 = OpVectorShuffle %f32vec2 %f32vec4_0123 %f32vec4_3210 3 6
%sh2 = OpVectorShuffle %f32vec3 %f32vec2_01 %f32vec4_3210 0 3 4
)");
}
TEST_P(MarkvTest, VectorTimesScalar) {
TestEncodeDecodeShaderMainBody(GetParam(), R"(
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
%res1 = OpVectorTimesScalar %f32vec4 %f32vec4_0123 %f32_2
%res2 = OpVectorTimesScalar %f32vec4 %f32vec4_3210 %f32_2
)");
}
TEST_P(MarkvTest, SpirvSpecSample) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %4 "main" %31 %33 %42 %57
OpExecutionMode %4 OriginLowerLeft
; Debug information
OpSource GLSL 450
OpName %4 "main"
OpName %9 "scale"
OpName %17 "S"
OpMemberName %17 0 "b"
OpMemberName %17 1 "v"
OpMemberName %17 2 "i"
OpName %18 "blockName"
OpMemberName %18 0 "s"
OpMemberName %18 1 "cond"
OpName %20 ""
OpName %31 "color"
OpName %33 "color1"
OpName %42 "color2"
OpName %48 "i"
OpName %57 "multiplier"
; Annotations (non-debug)
OpDecorate %15 ArrayStride 16
OpMemberDecorate %17 0 Offset 0
OpMemberDecorate %17 1 Offset 16
OpMemberDecorate %17 2 Offset 96
OpMemberDecorate %18 0 Offset 0
OpMemberDecorate %18 1 Offset 112
OpDecorate %18 Block
OpDecorate %20 DescriptorSet 0
OpDecorate %42 NoPerspective
; All types, variables, and constants
%2 = OpTypeVoid
%3 = OpTypeFunction %2 ; void ()
%6 = OpTypeFloat 32 ; 32-bit float
%7 = OpTypeVector %6 4 ; vec4
%8 = OpTypePointer Function %7 ; function-local vec4*
%10 = OpConstant %6 1
%11 = OpConstant %6 2
%12 = OpConstantComposite %7 %10 %10 %11 %10 ; vec4(1.0, 1.0, 2.0, 1.0)
%13 = OpTypeInt 32 0 ; 32-bit int, sign-less
%14 = OpConstant %13 5
%15 = OpTypeArray %7 %14
%16 = OpTypeInt 32 1
%17 = OpTypeStruct %13 %15 %16
%18 = OpTypeStruct %17 %13
%19 = OpTypePointer Uniform %18
%20 = OpVariable %19 Uniform
%21 = OpConstant %16 1
%22 = OpTypePointer Uniform %13
%25 = OpTypeBool
%26 = OpConstant %13 0
%30 = OpTypePointer Output %7
%31 = OpVariable %30 Output
%32 = OpTypePointer Input %7
%33 = OpVariable %32 Input
%35 = OpConstant %16 0
%36 = OpConstant %16 2
%37 = OpTypePointer Uniform %7
%42 = OpVariable %32 Input
%47 = OpTypePointer Function %16
%55 = OpConstant %16 4
%57 = OpVariable %32 Input
; All functions
%4 = OpFunction %2 None %3 ; main()
%5 = OpLabel
%9 = OpVariable %8 Function
%48 = OpVariable %47 Function
OpStore %9 %12
%23 = OpAccessChain %22 %20 %21 ; location of cond
%24 = OpLoad %13 %23 ; load 32-bit int from cond
%27 = OpINotEqual %25 %24 %26 ; convert to bool
OpSelectionMerge %29 None ; structured if
OpBranchConditional %27 %28 %41 ; if cond
%28 = OpLabel ; then
%34 = OpLoad %7 %33
%38 = OpAccessChain %37 %20 %35 %21 %36 ; s.v[2]
%39 = OpLoad %7 %38
%40 = OpFAdd %7 %34 %39
OpStore %31 %40
OpBranch %29
%41 = OpLabel ; else
%43 = OpLoad %7 %42
%44 = OpExtInst %7 %1 Sqrt %43 ; extended instruction sqrt
%45 = OpLoad %7 %9
%46 = OpFMul %7 %44 %45
OpStore %31 %46
OpBranch %29
%29 = OpLabel ; endif
OpStore %48 %35
OpBranch %49
%49 = OpLabel
OpLoopMerge %51 %52 None ; structured loop
OpBranch %53
%53 = OpLabel
%54 = OpLoad %16 %48
%56 = OpSLessThan %25 %54 %55 ; i < 4 ?
OpBranchConditional %56 %50 %51 ; body or break
%50 = OpLabel ; body
%58 = OpLoad %7 %57
%59 = OpLoad %7 %31
%60 = OpFMul %7 %59 %58
OpStore %31 %60
OpBranch %52
%52 = OpLabel ; continue target
%61 = OpLoad %16 %48
%62 = OpIAdd %16 %61 %21 ; ++i
OpStore %48 %62
OpBranch %49 ; loop back
%51 = OpLabel ; loop merge point
OpReturn
OpFunctionEnd
)");
}
TEST_P(MarkvTest, SampleFromDeadBranchEliminationTest) {
TestEncodeDecode(GetParam(), R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %gl_FragColor
OpExecutionMode %main OriginUpperLeft
OpSource GLSL 140
OpName %main "main"
OpName %gl_FragColor "gl_FragColor"
%void = OpTypeVoid
%5 = OpTypeFunction %void
%bool = OpTypeBool
%true = OpConstantTrue %bool
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_ptr_Function_v4float = OpTypePointer Function %v4float
%float_0 = OpConstant %float 0
%12 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
%float_1 = OpConstant %float 1
%14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
%_ptr_Output_v4float = OpTypePointer Output %v4float
%gl_FragColor = OpVariable %_ptr_Output_v4float Output
%_ptr_Input_v4float = OpTypePointer Input %v4float
%main = OpFunction %void None %5
%17 = OpLabel
OpSelectionMerge %18 None
OpBranchConditional %true %19 %20
%19 = OpLabel
OpBranch %18
%20 = OpLabel
OpBranch %18
%18 = OpLabel
%21 = OpPhi %v4float %12 %19 %14 %20
OpStore %gl_FragColor %21
OpReturn
OpFunctionEnd
)");
}
INSTANTIATE_TEST_SUITE_P(AllMarkvModels, MarkvTest,
::testing::ValuesIn(std::vector<MarkvModelType>{
kMarkvModelShaderLite,
kMarkvModelShaderMid,
kMarkvModelShaderMax,
}));
} // namespace
} // namespace comp
} // namespace spvtools

View File

@ -1,317 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "gmock/gmock.h"
#include "source/comp/bit_stream.h"
#include "source/comp/huffman_codec.h"
namespace spvtools {
namespace comp {
namespace {
const std::map<std::string, uint32_t>& GetTestSet() {
static const std::map<std::string, uint32_t> hist = {
{"a", 4}, {"e", 7}, {"f", 3}, {"h", 2}, {"i", 3},
{"m", 2}, {"n", 2}, {"s", 2}, {"t", 2}, {"l", 1},
{"o", 2}, {"p", 1}, {"r", 1}, {"u", 1}, {"x", 1},
};
return hist;
}
class TestBitReader {
public:
TestBitReader(const std::string& bits) : bits_(bits) {}
bool ReadBit(bool* bit) {
if (pos_ < bits_.length()) {
*bit = bits_[pos_++] == '0' ? false : true;
return true;
}
return false;
}
private:
std::string bits_;
size_t pos_ = 0;
};
TEST(Huffman, PrintTree) {
HuffmanCodec<std::string> huffman(GetTestSet());
std::stringstream ss;
huffman.PrintTree(ss);
// clang-format off
const std::string expected = std::string(R"(
15-----7------e
8------4------a
4------2------m
2------n
19-----8------4------2------o
2------s
4------2------t
2------1------l
1------p
11-----5------2------1------r
1------u
3------f
6------3------i
3------1------x
2------h
)").substr(1);
// clang-format on
EXPECT_EQ(expected, ss.str());
}
TEST(Huffman, PrintTable) {
HuffmanCodec<std::string> huffman(GetTestSet());
std::stringstream ss;
huffman.PrintTable(ss);
const std::string expected = std::string(R"(
e 7 11
a 4 101
i 3 0001
f 3 0010
t 2 0101
s 2 0110
o 2 0111
n 2 1000
m 2 1001
h 2 00000
x 1 00001
u 1 00110
r 1 00111
p 1 01000
l 1 01001
)")
.substr(1);
EXPECT_EQ(expected, ss.str());
}
TEST(Huffman, TestValidity) {
HuffmanCodec<std::string> huffman(GetTestSet());
const auto& encoding_table = huffman.GetEncodingTable();
std::vector<std::string> codes;
for (const auto& entry : encoding_table) {
codes.push_back(BitsToStream(entry.second.first, entry.second.second));
}
std::sort(codes.begin(), codes.end());
ASSERT_LT(codes.size(), 20u) << "Inefficient test ahead";
for (size_t i = 0; i < codes.size(); ++i) {
for (size_t j = i + 1; j < codes.size(); ++j) {
ASSERT_FALSE(codes[i] == codes[j].substr(0, codes[i].length()))
<< codes[i] << " is prefix of " << codes[j];
}
}
}
TEST(Huffman, TestEncode) {
HuffmanCodec<std::string> huffman(GetTestSet());
uint64_t bits = 0;
size_t num_bits = 0;
EXPECT_TRUE(huffman.Encode("e", &bits, &num_bits));
EXPECT_EQ(2u, num_bits);
EXPECT_EQ("11", BitsToStream(bits, num_bits));
EXPECT_TRUE(huffman.Encode("a", &bits, &num_bits));
EXPECT_EQ(3u, num_bits);
EXPECT_EQ("101", BitsToStream(bits, num_bits));
EXPECT_TRUE(huffman.Encode("x", &bits, &num_bits));
EXPECT_EQ(5u, num_bits);
EXPECT_EQ("00001", BitsToStream(bits, num_bits));
EXPECT_FALSE(huffman.Encode("y", &bits, &num_bits));
}
TEST(Huffman, TestDecode) {
HuffmanCodec<std::string> huffman(GetTestSet());
TestBitReader bit_reader(
"01001"
"0001"
"1000"
"00110"
"00001"
"00");
auto read_bit = [&bit_reader](bool* bit) { return bit_reader.ReadBit(bit); };
std::string decoded;
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ("l", decoded);
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ("i", decoded);
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ("n", decoded);
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ("u", decoded);
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ("x", decoded);
ASSERT_FALSE(huffman.DecodeFromStream(read_bit, &decoded));
}
TEST(Huffman, TestDecodeNumbers) {
const std::map<uint32_t, uint32_t> hist = {{1, 10}, {2, 5}, {3, 15}};
HuffmanCodec<uint32_t> huffman(hist);
TestBitReader bit_reader(
"1"
"1"
"01"
"00"
"01"
"1");
auto read_bit = [&bit_reader](bool* bit) { return bit_reader.ReadBit(bit); };
uint32_t decoded;
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ(3u, decoded);
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ(3u, decoded);
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ(2u, decoded);
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ(1u, decoded);
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ(2u, decoded);
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ(3u, decoded);
}
TEST(Huffman, SerializeToTextU64) {
const std::map<uint64_t, uint32_t> hist = {{1001, 10}, {1002, 5}, {1003, 15}};
HuffmanCodec<uint64_t> huffman(hist);
const std::string code = huffman.SerializeToText(2);
const std::string expected = R"((5, {
{0, 0, 0},
{1001, 0, 0},
{1002, 0, 0},
{1003, 0, 0},
{0, 1, 2},
{0, 4, 3},
}))";
ASSERT_EQ(expected, code);
}
TEST(Huffman, SerializeToTextString) {
const std::map<std::string, uint32_t> hist = {
{"aaa", 10}, {"bbb", 20}, {"ccc", 15}};
HuffmanCodec<std::string> huffman(hist);
const std::string code = huffman.SerializeToText(4);
const std::string expected = R"((5, {
{"", 0, 0},
{"aaa", 0, 0},
{"bbb", 0, 0},
{"ccc", 0, 0},
{"", 3, 1},
{"", 4, 2},
}))";
ASSERT_EQ(expected, code);
}
TEST(Huffman, CreateFromTextString) {
std::vector<HuffmanCodec<std::string>::Node> nodes = {
{},
{"root", 2, 3},
{"left", 0, 0},
{"right", 0, 0},
};
HuffmanCodec<std::string> huffman(1, std::move(nodes));
std::stringstream ss;
huffman.PrintTree(ss);
const std::string expected = std::string(R"(
0------right
0------left
)")
.substr(1);
EXPECT_EQ(expected, ss.str());
}
TEST(Huffman, CreateFromTextU64) {
HuffmanCodec<uint64_t> huffman(5, {
{0, 0, 0},
{1001, 0, 0},
{1002, 0, 0},
{1003, 0, 0},
{0, 1, 2},
{0, 4, 3},
});
std::stringstream ss;
huffman.PrintTree(ss);
const std::string expected = std::string(R"(
0------1003
0------0------1002
0------1001
)")
.substr(1);
EXPECT_EQ(expected, ss.str());
TestBitReader bit_reader("01");
auto read_bit = [&bit_reader](bool* bit) { return bit_reader.ReadBit(bit); };
uint64_t decoded = 0;
ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
EXPECT_EQ(1002u, decoded);
uint64_t bits = 0;
size_t num_bits = 0;
EXPECT_TRUE(huffman.Encode(1001, &bits, &num_bits));
EXPECT_EQ(2u, num_bits);
EXPECT_EQ("00", BitsToStream(bits, num_bits));
}
} // namespace
} // namespace comp
} // namespace spvtools

View File

@ -1,828 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>
#include "gmock/gmock.h"
#include "source/comp/move_to_front.h"
namespace spvtools {
namespace comp {
namespace {
// Class used to test the inner workings of MoveToFront.
class MoveToFrontTester : public MoveToFront {
public:
// Inserts the value in the internal tree data structure. For testing only.
void TestInsert(uint32_t val) { InsertNode(CreateNode(val, val)); }
// Removes the value from the internal tree data structure. For testing only.
void TestRemove(uint32_t val) {
const auto it = value_to_node_.find(val);
assert(it != value_to_node_.end());
RemoveNode(it->second);
}
// Prints the internal tree data structure to |out|. For testing only.
void PrintTree(std::ostream& out, bool print_timestamp = false) const {
if (root_) PrintTreeInternal(out, root_, 1, print_timestamp);
}
// Returns node handle corresponding to the value. The value may not be in the
// tree.
uint32_t GetNodeHandle(uint32_t value) const {
const auto it = value_to_node_.find(value);
if (it == value_to_node_.end()) return 0;
return it->second;
}
// Returns total node count (both those in the tree and removed,
// but not the NIL singleton).
size_t GetTotalNodeCount() const {
assert(nodes_.size());
return nodes_.size() - 1;
}
uint32_t GetLastAccessedValue() const { return last_accessed_value_; }
private:
// Prints the internal tree data structure for debug purposes in the following
// format:
// 10H3S4----5H1S1-----D2
// 15H2S2----12H1S1----D3
// Right links are horizontal, left links step down one line.
// 5H1S1 is read as value 5, height 1, size 1. Optionally node label can also
// contain timestamp (5H1S1T15). D3 stands for depth 3.
void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth,
bool print_timestamp) const;
};
void MoveToFrontTester::PrintTreeInternal(std::ostream& out, uint32_t node,
size_t depth,
bool print_timestamp) const {
if (!node) {
out << "D" << depth - 1 << std::endl;
return;
}
const size_t kTextFieldWvaluethWithoutTimestamp = 10;
const size_t kTextFieldWvaluethWithTimestamp = 14;
const size_t text_field_wvalueth = print_timestamp
? kTextFieldWvaluethWithTimestamp
: kTextFieldWvaluethWithoutTimestamp;
std::stringstream label;
label << ValueOf(node) << "H" << HeightOf(node) << "S" << SizeOf(node);
if (print_timestamp) label << "T" << TimestampOf(node);
const size_t label_length = label.str().length();
if (label_length < text_field_wvalueth)
label << std::string(text_field_wvalueth - label_length, '-');
out << label.str();
PrintTreeInternal(out, RightOf(node), depth + 1, print_timestamp);
if (LeftOf(node)) {
out << std::string(depth * text_field_wvalueth, ' ');
PrintTreeInternal(out, LeftOf(node), depth + 1, print_timestamp);
}
}
void CheckTree(const MoveToFrontTester& mtf, const std::string& expected,
bool print_timestamp = false) {
std::stringstream ss;
mtf.PrintTree(ss, print_timestamp);
EXPECT_EQ(expected, ss.str());
}
TEST(MoveToFront, EmptyTree) {
MoveToFrontTester mtf;
CheckTree(mtf, std::string());
}
TEST(MoveToFront, InsertLeftRotation) {
MoveToFrontTester mtf;
mtf.TestInsert(30);
mtf.TestInsert(20);
CheckTree(mtf, std::string(R"(
30H2S2----20H1S1----D2
)")
.substr(1));
mtf.TestInsert(10);
CheckTree(mtf, std::string(R"(
20H2S3----10H1S1----D2
30H1S1----D2
)")
.substr(1));
}
TEST(MoveToFront, InsertRightRotation) {
MoveToFrontTester mtf;
mtf.TestInsert(10);
mtf.TestInsert(20);
CheckTree(mtf, std::string(R"(
10H2S2----D1
20H1S1----D2
)")
.substr(1));
mtf.TestInsert(30);
CheckTree(mtf, std::string(R"(
20H2S3----10H1S1----D2
30H1S1----D2
)")
.substr(1));
}
TEST(MoveToFront, InsertRightLeftRotation) {
MoveToFrontTester mtf;
mtf.TestInsert(30);
mtf.TestInsert(20);
CheckTree(mtf, std::string(R"(
30H2S2----20H1S1----D2
)")
.substr(1));
mtf.TestInsert(25);
CheckTree(mtf, std::string(R"(
25H2S3----20H1S1----D2
30H1S1----D2
)")
.substr(1));
}
TEST(MoveToFront, InsertLeftRightRotation) {
MoveToFrontTester mtf;
mtf.TestInsert(10);
mtf.TestInsert(20);
CheckTree(mtf, std::string(R"(
10H2S2----D1
20H1S1----D2
)")
.substr(1));
mtf.TestInsert(15);
CheckTree(mtf, std::string(R"(
15H2S3----10H1S1----D2
20H1S1----D2
)")
.substr(1));
}
TEST(MoveToFront, RemoveSingleton) {
MoveToFrontTester mtf;
mtf.TestInsert(10);
CheckTree(mtf, std::string(R"(
10H1S1----D1
)")
.substr(1));
mtf.TestRemove(10);
CheckTree(mtf, "");
}
TEST(MoveToFront, RemoveRootWithScapegoat) {
MoveToFrontTester mtf;
mtf.TestInsert(10);
mtf.TestInsert(5);
mtf.TestInsert(15);
CheckTree(mtf, std::string(R"(
10H2S3----5H1S1-----D2
15H1S1----D2
)")
.substr(1));
mtf.TestRemove(10);
CheckTree(mtf, std::string(R"(
15H2S2----5H1S1-----D2
)")
.substr(1));
}
TEST(MoveToFront, RemoveRightRotation) {
MoveToFrontTester mtf;
mtf.TestInsert(10);
mtf.TestInsert(5);
mtf.TestInsert(15);
mtf.TestInsert(20);
CheckTree(mtf, std::string(R"(
10H3S4----5H1S1-----D2
15H2S2----D2
20H1S1----D3
)")
.substr(1));
mtf.TestRemove(5);
CheckTree(mtf, std::string(R"(
15H2S3----10H1S1----D2
20H1S1----D2
)")
.substr(1));
}
TEST(MoveToFront, RemoveLeftRotation) {
MoveToFrontTester mtf;
mtf.TestInsert(10);
mtf.TestInsert(15);
mtf.TestInsert(5);
mtf.TestInsert(1);
CheckTree(mtf, std::string(R"(
10H3S4----5H2S2-----1H1S1-----D3
15H1S1----D2
)")
.substr(1));
mtf.TestRemove(15);
CheckTree(mtf, std::string(R"(
5H2S3-----1H1S1-----D2
10H1S1----D2
)")
.substr(1));
}
TEST(MoveToFront, RemoveLeftRightRotation) {
MoveToFrontTester mtf;
mtf.TestInsert(10);
mtf.TestInsert(15);
mtf.TestInsert(5);
mtf.TestInsert(12);
CheckTree(mtf, std::string(R"(
10H3S4----5H1S1-----D2
15H2S2----12H1S1----D3
)")
.substr(1));
mtf.TestRemove(5);
CheckTree(mtf, std::string(R"(
12H2S3----10H1S1----D2
15H1S1----D2
)")
.substr(1));
}
TEST(MoveToFront, RemoveRightLeftRotation) {
MoveToFrontTester mtf;
mtf.TestInsert(10);
mtf.TestInsert(15);
mtf.TestInsert(5);
mtf.TestInsert(8);
CheckTree(mtf, std::string(R"(
10H3S4----5H2S2-----D2
8H1S1-----D3
15H1S1----D2
)")
.substr(1));
mtf.TestRemove(15);
CheckTree(mtf, std::string(R"(
8H2S3-----5H1S1-----D2
10H1S1----D2
)")
.substr(1));
}
TEST(MoveToFront, MultipleOperations) {
MoveToFrontTester mtf;
std::vector<uint32_t> vals = {5, 11, 12, 16, 15, 6, 14, 2,
7, 10, 4, 8, 9, 3, 1, 13};
for (uint32_t i : vals) {
mtf.TestInsert(i);
}
CheckTree(mtf, std::string(R"(
11H5S16---5H4S10----3H3S4-----2H2S2-----1H1S1-----D5
4H1S1-----D4
7H3S5-----6H1S1-----D4
9H2S3-----8H1S1-----D5
10H1S1----D5
15H3S5----13H2S3----12H1S1----D4
14H1S1----D4
16H1S1----D3
)")
.substr(1));
mtf.TestRemove(11);
CheckTree(mtf, std::string(R"(
10H5S15---5H4S9-----3H3S4-----2H2S2-----1H1S1-----D5
4H1S1-----D4
7H3S4-----6H1S1-----D4
9H2S2-----8H1S1-----D5
15H3S5----13H2S3----12H1S1----D4
14H1S1----D4
16H1S1----D3
)")
.substr(1));
mtf.TestInsert(11);
CheckTree(mtf, std::string(R"(
10H5S16---5H4S9-----3H3S4-----2H2S2-----1H1S1-----D5
4H1S1-----D4
7H3S4-----6H1S1-----D4
9H2S2-----8H1S1-----D5
13H3S6----12H2S2----11H1S1----D4
15H2S3----14H1S1----D4
16H1S1----D4
)")
.substr(1));
mtf.TestRemove(5);
CheckTree(mtf, std::string(R"(
10H5S15---6H4S8-----3H3S4-----2H2S2-----1H1S1-----D5
4H1S1-----D4
8H2S3-----7H1S1-----D4
9H1S1-----D4
13H3S6----12H2S2----11H1S1----D4
15H2S3----14H1S1----D4
16H1S1----D4
)")
.substr(1));
mtf.TestInsert(5);
CheckTree(mtf, std::string(R"(
10H5S16---6H4S9-----3H3S5-----2H2S2-----1H1S1-----D5
4H2S2-----D4
5H1S1-----D5
8H2S3-----7H1S1-----D4
9H1S1-----D4
13H3S6----12H2S2----11H1S1----D4
15H2S3----14H1S1----D4
16H1S1----D4
)")
.substr(1));
mtf.TestRemove(2);
mtf.TestRemove(1);
mtf.TestRemove(4);
mtf.TestRemove(3);
mtf.TestRemove(6);
mtf.TestRemove(5);
mtf.TestRemove(7);
mtf.TestRemove(9);
CheckTree(mtf, std::string(R"(
13H4S8----10H3S4----8H1S1-----D3
12H2S2----11H1S1----D4
15H2S3----14H1S1----D3
16H1S1----D3
)")
.substr(1));
}
TEST(MoveToFront, BiggerScaleTreeTest) {
MoveToFrontTester mtf;
std::set<uint32_t> all_vals;
const uint32_t kMagic1 = 2654435761;
const uint32_t kMagic2 = 10000;
for (uint32_t i = 1; i < 1000; ++i) {
const uint32_t val = (i * kMagic1) % kMagic2;
if (!all_vals.count(val)) {
mtf.TestInsert(val);
all_vals.insert(val);
}
}
for (uint32_t i = 1; i < 1000; ++i) {
const uint32_t val = (i * kMagic1) % kMagic2;
if (val % 2 == 0) {
mtf.TestRemove(val);
all_vals.erase(val);
}
}
for (uint32_t i = 1000; i < 2000; ++i) {
const uint32_t val = (i * kMagic1) % kMagic2;
if (!all_vals.count(val)) {
mtf.TestInsert(val);
all_vals.insert(val);
}
}
for (uint32_t i = 1; i < 2000; ++i) {
const uint32_t val = (i * kMagic1) % kMagic2;
if (val > 50) {
mtf.TestRemove(val);
all_vals.erase(val);
}
}
EXPECT_EQ(all_vals, std::set<uint32_t>({2, 4, 11, 13, 24, 33, 35, 37, 46}));
CheckTree(mtf, std::string(R"(
33H4S9----11H3S5----2H2S2-----D3
4H1S1-----D4
13H2S2----D3
24H1S1----D4
37H2S3----35H1S1----D3
46H1S1----D3
)")
.substr(1));
}
TEST(MoveToFront, RankFromValue) {
MoveToFrontTester mtf;
uint32_t rank = 0;
EXPECT_FALSE(mtf.RankFromValue(1, &rank));
EXPECT_TRUE(mtf.Insert(1));
EXPECT_TRUE(mtf.Insert(2));
EXPECT_TRUE(mtf.Insert(3));
EXPECT_FALSE(mtf.Insert(2));
CheckTree(mtf,
std::string(R"(
2H2S3T2-------1H1S1T1-------D2
3H1S1T3-------D2
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_FALSE(mtf.RankFromValue(4, &rank));
EXPECT_TRUE(mtf.RankFromValue(1, &rank));
EXPECT_EQ(3u, rank);
CheckTree(mtf,
std::string(R"(
3H2S3T3-------2H1S1T2-------D2
1H1S1T4-------D2
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_TRUE(mtf.RankFromValue(1, &rank));
EXPECT_EQ(1u, rank);
EXPECT_TRUE(mtf.RankFromValue(3, &rank));
EXPECT_EQ(2u, rank);
EXPECT_TRUE(mtf.RankFromValue(2, &rank));
EXPECT_EQ(3u, rank);
EXPECT_TRUE(mtf.Insert(40));
EXPECT_TRUE(mtf.RankFromValue(1, &rank));
EXPECT_EQ(4u, rank);
EXPECT_TRUE(mtf.Insert(50));
EXPECT_TRUE(mtf.RankFromValue(1, &rank));
EXPECT_EQ(2u, rank);
CheckTree(mtf,
std::string(R"(
2H3S5T6-------3H1S1T5-------D2
50H2S3T9------40H1S1T7------D3
1H1S1T10------D3
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_TRUE(mtf.RankFromValue(50, &rank));
EXPECT_EQ(2u, rank);
EXPECT_EQ(5u, mtf.GetSize());
CheckTree(mtf,
std::string(R"(
2H3S5T6-------3H1S1T5-------D2
1H2S3T10------40H1S1T7------D3
50H1S1T11-----D3
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_FALSE(mtf.RankFromValue(0, &rank));
EXPECT_FALSE(mtf.RankFromValue(20, &rank));
}
TEST(MoveToFront, ValueFromRank) {
MoveToFrontTester mtf;
uint32_t value = 0;
EXPECT_FALSE(mtf.ValueFromRank(0, &value));
EXPECT_FALSE(mtf.ValueFromRank(1, &value));
EXPECT_TRUE(mtf.Insert(1));
EXPECT_EQ(1u, mtf.GetLastAccessedValue());
EXPECT_TRUE(mtf.Insert(2));
EXPECT_EQ(2u, mtf.GetLastAccessedValue());
EXPECT_TRUE(mtf.Insert(3));
EXPECT_EQ(3u, mtf.GetLastAccessedValue());
EXPECT_TRUE(mtf.ValueFromRank(3, &value));
EXPECT_EQ(1u, value);
EXPECT_EQ(1u, mtf.GetLastAccessedValue());
EXPECT_TRUE(mtf.ValueFromRank(1, &value));
EXPECT_EQ(1u, value);
EXPECT_EQ(1u, mtf.GetLastAccessedValue());
CheckTree(mtf,
std::string(R"(
3H2S3T3-------2H1S1T2-------D2
1H1S1T4-------D2
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_TRUE(mtf.ValueFromRank(2, &value));
EXPECT_EQ(3u, value);
EXPECT_EQ(3u, mtf.GetSize());
CheckTree(mtf,
std::string(R"(
1H2S3T4-------2H1S1T2-------D2
3H1S1T5-------D2
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_TRUE(mtf.ValueFromRank(3, &value));
EXPECT_EQ(2u, value);
CheckTree(mtf,
std::string(R"(
3H2S3T5-------1H1S1T4-------D2
2H1S1T6-------D2
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_TRUE(mtf.Insert(10));
CheckTree(mtf,
std::string(R"(
3H3S4T5-------1H1S1T4-------D2
2H2S2T6-------D2
10H1S1T7------D3
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_TRUE(mtf.ValueFromRank(1, &value));
EXPECT_EQ(10u, value);
}
TEST(MoveToFront, Remove) {
MoveToFrontTester mtf;
EXPECT_FALSE(mtf.Remove(1));
EXPECT_EQ(0u, mtf.GetTotalNodeCount());
EXPECT_TRUE(mtf.Insert(1));
EXPECT_TRUE(mtf.Insert(2));
EXPECT_TRUE(mtf.Insert(3));
CheckTree(mtf,
std::string(R"(
2H2S3T2-------1H1S1T1-------D2
3H1S1T3-------D2
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_EQ(1u, mtf.GetNodeHandle(1));
EXPECT_EQ(3u, mtf.GetTotalNodeCount());
EXPECT_TRUE(mtf.Remove(1));
EXPECT_EQ(3u, mtf.GetTotalNodeCount());
CheckTree(mtf,
std::string(R"(
2H2S2T2-------D1
3H1S1T3-------D2
)")
.substr(1),
/* print_timestamp = */ true);
uint32_t value = 0;
EXPECT_TRUE(mtf.ValueFromRank(2, &value));
EXPECT_EQ(2u, value);
CheckTree(mtf,
std::string(R"(
3H2S2T3-------D1
2H1S1T4-------D2
)")
.substr(1),
/* print_timestamp = */ true);
EXPECT_TRUE(mtf.Insert(1));
EXPECT_EQ(1u, mtf.GetNodeHandle(1));
EXPECT_EQ(3u, mtf.GetTotalNodeCount());
}
TEST(MoveToFront, LargerScale) {
MoveToFrontTester mtf;
uint32_t value = 0;
uint32_t rank = 0;
for (uint32_t i = 1; i < 1000; ++i) {
ASSERT_TRUE(mtf.Insert(i));
ASSERT_EQ(i, mtf.GetSize());
ASSERT_TRUE(mtf.RankFromValue(i, &rank));
ASSERT_EQ(1u, rank);
ASSERT_TRUE(mtf.ValueFromRank(1, &value));
ASSERT_EQ(i, value);
}
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
ASSERT_EQ(1u, value);
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
ASSERT_EQ(2u, value);
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
ASSERT_EQ(3u, value);
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
ASSERT_EQ(4u, value);
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
ASSERT_EQ(5u, value);
ASSERT_TRUE(mtf.ValueFromRank(999, &value));
ASSERT_EQ(6u, value);
ASSERT_TRUE(mtf.ValueFromRank(101, &value));
ASSERT_EQ(905u, value);
ASSERT_TRUE(mtf.ValueFromRank(101, &value));
ASSERT_EQ(906u, value);
ASSERT_TRUE(mtf.ValueFromRank(101, &value));
ASSERT_EQ(907u, value);
ASSERT_TRUE(mtf.ValueFromRank(201, &value));
ASSERT_EQ(805u, value);
ASSERT_TRUE(mtf.ValueFromRank(201, &value));
ASSERT_EQ(806u, value);
ASSERT_TRUE(mtf.ValueFromRank(201, &value));
ASSERT_EQ(807u, value);
ASSERT_TRUE(mtf.ValueFromRank(301, &value));
ASSERT_EQ(705u, value);
ASSERT_TRUE(mtf.ValueFromRank(301, &value));
ASSERT_EQ(706u, value);
ASSERT_TRUE(mtf.ValueFromRank(301, &value));
ASSERT_EQ(707u, value);
ASSERT_TRUE(mtf.RankFromValue(605, &rank));
ASSERT_EQ(401u, rank);
ASSERT_TRUE(mtf.RankFromValue(606, &rank));
ASSERT_EQ(401u, rank);
ASSERT_TRUE(mtf.RankFromValue(607, &rank));
ASSERT_EQ(401u, rank);
ASSERT_TRUE(mtf.ValueFromRank(1, &value));
ASSERT_EQ(607u, value);
ASSERT_TRUE(mtf.ValueFromRank(2, &value));
ASSERT_EQ(606u, value);
ASSERT_TRUE(mtf.ValueFromRank(3, &value));
ASSERT_EQ(605u, value);
ASSERT_TRUE(mtf.ValueFromRank(4, &value));
ASSERT_EQ(707u, value);
ASSERT_TRUE(mtf.ValueFromRank(5, &value));
ASSERT_EQ(706u, value);
ASSERT_TRUE(mtf.ValueFromRank(6, &value));
ASSERT_EQ(705u, value);
ASSERT_TRUE(mtf.ValueFromRank(7, &value));
ASSERT_EQ(807u, value);
ASSERT_TRUE(mtf.ValueFromRank(8, &value));
ASSERT_EQ(806u, value);
ASSERT_TRUE(mtf.ValueFromRank(9, &value));
ASSERT_EQ(805u, value);
ASSERT_TRUE(mtf.ValueFromRank(10, &value));
ASSERT_EQ(907u, value);
ASSERT_TRUE(mtf.ValueFromRank(11, &value));
ASSERT_EQ(906u, value);
ASSERT_TRUE(mtf.ValueFromRank(12, &value));
ASSERT_EQ(905u, value);
ASSERT_TRUE(mtf.ValueFromRank(13, &value));
ASSERT_EQ(6u, value);
ASSERT_TRUE(mtf.ValueFromRank(14, &value));
ASSERT_EQ(5u, value);
ASSERT_TRUE(mtf.ValueFromRank(15, &value));
ASSERT_EQ(4u, value);
ASSERT_TRUE(mtf.ValueFromRank(16, &value));
ASSERT_EQ(3u, value);
ASSERT_TRUE(mtf.ValueFromRank(17, &value));
ASSERT_EQ(2u, value);
ASSERT_TRUE(mtf.ValueFromRank(18, &value));
ASSERT_EQ(1u, value);
ASSERT_TRUE(mtf.ValueFromRank(19, &value));
ASSERT_EQ(999u, value);
ASSERT_TRUE(mtf.ValueFromRank(20, &value));
ASSERT_EQ(998u, value);
ASSERT_TRUE(mtf.ValueFromRank(21, &value));
ASSERT_EQ(997u, value);
ASSERT_TRUE(mtf.RankFromValue(997, &rank));
ASSERT_EQ(1u, rank);
ASSERT_TRUE(mtf.RankFromValue(998, &rank));
ASSERT_EQ(2u, rank);
ASSERT_TRUE(mtf.RankFromValue(996, &rank));
ASSERT_EQ(22u, rank);
ASSERT_TRUE(mtf.Remove(995));
ASSERT_TRUE(mtf.RankFromValue(994, &rank));
ASSERT_EQ(23u, rank);
for (uint32_t i = 10; i < 1000; ++i) {
if (i != 995) {
ASSERT_TRUE(mtf.Remove(i));
} else {
ASSERT_FALSE(mtf.Remove(i));
}
}
CheckTree(mtf,
std::string(R"(
6H4S9T1029----8H2S3T8-------7H1S1T7-------D3
9H1S1T9-------D3
2H3S5T1033----4H2S3T1031----5H1S1T1030----D4
3H1S1T1032----D4
1H1S1T1034----D3
)")
.substr(1),
/* print_timestamp = */ true);
ASSERT_TRUE(mtf.Insert(1000));
ASSERT_TRUE(mtf.ValueFromRank(1, &value));
ASSERT_EQ(1000u, value);
}
} // namespace
} // namespace comp
} // namespace spvtools

View File

@ -1,27 +0,0 @@
# Copyright (c) 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set(VAL_TEST_COMMON_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/../test_fixture.h
${CMAKE_CURRENT_SOURCE_DIR}/../unit_spirv.h
)
add_spvtools_unittest(TARGET stats
SRCS stats_aggregate_test.cpp
stats_analyzer_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../tools/stats/spirv_stats.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../tools/stats/stats_analyzer.cpp
${VAL_TEST_COMMON_SRCS}
LIBS ${SPIRV_TOOLS}
)

View File

@ -1,438 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Tests for unique type declaration rules validator.
#include <string>
#include <unordered_map>
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
#include "tools/stats/spirv_stats.h"
namespace spvtools {
namespace stats {
namespace {
using spvtest::ScopedContext;
void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
const spv_position_t& position,
const char* message) {
switch (level) {
case SPV_MSG_FATAL:
case SPV_MSG_INTERNAL_ERROR:
case SPV_MSG_ERROR:
std::cerr << "error: " << position.index << ": " << message << std::endl;
break;
case SPV_MSG_WARNING:
std::cout << "warning: " << position.index << ": " << message
<< std::endl;
break;
case SPV_MSG_INFO:
std::cout << "info: " << position.index << ": " << message << std::endl;
break;
default:
break;
}
}
// Calls AggregateStats for binary compiled from |code|.
void CompileAndAggregateStats(const std::string& code, SpirvStats* stats,
spv_target_env env = SPV_ENV_UNIVERSAL_1_1) {
spvtools::Context ctx(env);
ctx.SetMessageConsumer(DiagnosticsMessageHandler);
spv_binary binary;
ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(ctx.CContext(), code.c_str(),
code.size(), &binary, nullptr));
ASSERT_EQ(SPV_SUCCESS, AggregateStats(ctx.CContext(), binary->code,
binary->wordCount, nullptr, stats));
spvBinaryDestroy(binary);
}
TEST(AggregateStats, CapabilityHistogram) {
const std::string code1 = R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
OpMemoryModel Physical32 OpenCL
)";
const std::string code2 = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
)";
SpirvStats stats;
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(4u, stats.capability_hist.size());
EXPECT_EQ(0u, stats.capability_hist.count(SpvCapabilityShader));
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityAddresses));
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityKernel));
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityGenericPointer));
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityLinkage));
CompileAndAggregateStats(code2, &stats);
EXPECT_EQ(5u, stats.capability_hist.size());
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityShader));
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityAddresses));
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityKernel));
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityGenericPointer));
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityLinkage));
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(5u, stats.capability_hist.size());
EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityShader));
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityAddresses));
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityKernel));
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityGenericPointer));
EXPECT_EQ(3u, stats.capability_hist.at(SpvCapabilityLinkage));
CompileAndAggregateStats(code2, &stats);
EXPECT_EQ(5u, stats.capability_hist.size());
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityShader));
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityAddresses));
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityKernel));
EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityGenericPointer));
EXPECT_EQ(4u, stats.capability_hist.at(SpvCapabilityLinkage));
}
TEST(AggregateStats, ExtensionHistogram) {
const std::string code1 = R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
OpExtension "SPV_KHR_16bit_storage"
OpMemoryModel Physical32 OpenCL
)";
const std::string code2 = R"(
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_NV_viewport_array2"
OpExtension "greatest_extension_ever"
OpMemoryModel Logical GLSL450
)";
SpirvStats stats;
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(1u, stats.extension_hist.size());
EXPECT_EQ(0u, stats.extension_hist.count("SPV_NV_viewport_array2"));
EXPECT_EQ(1u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
CompileAndAggregateStats(code2, &stats);
EXPECT_EQ(3u, stats.extension_hist.size());
EXPECT_EQ(1u, stats.extension_hist.at("SPV_NV_viewport_array2"));
EXPECT_EQ(1u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
EXPECT_EQ(1u, stats.extension_hist.at("greatest_extension_ever"));
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(3u, stats.extension_hist.size());
EXPECT_EQ(1u, stats.extension_hist.at("SPV_NV_viewport_array2"));
EXPECT_EQ(2u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
EXPECT_EQ(1u, stats.extension_hist.at("greatest_extension_ever"));
CompileAndAggregateStats(code2, &stats);
EXPECT_EQ(3u, stats.extension_hist.size());
EXPECT_EQ(2u, stats.extension_hist.at("SPV_NV_viewport_array2"));
EXPECT_EQ(2u, stats.extension_hist.at("SPV_KHR_16bit_storage"));
EXPECT_EQ(2u, stats.extension_hist.at("greatest_extension_ever"));
}
TEST(AggregateStats, VersionHistogram) {
const std::string code1 = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
)";
SpirvStats stats;
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(1u, stats.version_hist.size());
EXPECT_EQ(1u, stats.version_hist.at(0x00010100));
CompileAndAggregateStats(code1, &stats, SPV_ENV_UNIVERSAL_1_0);
EXPECT_EQ(2u, stats.version_hist.size());
EXPECT_EQ(1u, stats.version_hist.at(0x00010100));
EXPECT_EQ(1u, stats.version_hist.at(0x00010000));
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(2u, stats.version_hist.size());
EXPECT_EQ(2u, stats.version_hist.at(0x00010100));
EXPECT_EQ(1u, stats.version_hist.at(0x00010000));
CompileAndAggregateStats(code1, &stats, SPV_ENV_UNIVERSAL_1_0);
EXPECT_EQ(2u, stats.version_hist.size());
EXPECT_EQ(2u, stats.version_hist.at(0x00010100));
EXPECT_EQ(2u, stats.version_hist.at(0x00010000));
}
TEST(AggregateStats, GeneratorHistogram) {
const std::string code1 = R"(
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
)";
const uint32_t kGeneratorKhronosAssembler = SPV_GENERATOR_KHRONOS_ASSEMBLER
<< 16;
SpirvStats stats;
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(1u, stats.generator_hist.size());
EXPECT_EQ(1u, stats.generator_hist.at(kGeneratorKhronosAssembler));
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(1u, stats.generator_hist.size());
EXPECT_EQ(2u, stats.generator_hist.at(kGeneratorKhronosAssembler));
}
TEST(AggregateStats, OpcodeHistogram) {
const std::string code1 = R"(
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Linkage
OpMemoryModel Physical32 OpenCL
%u64 = OpTypeInt 64 0
%u32 = OpTypeInt 32 0
%f32 = OpTypeFloat 32
)";
const std::string code2 = R"(
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_NV_viewport_array2"
OpMemoryModel Logical GLSL450
)";
SpirvStats stats;
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(4u, stats.opcode_hist.size());
EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpCapability));
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpMemoryModel));
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeInt));
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpTypeFloat));
CompileAndAggregateStats(code2, &stats);
EXPECT_EQ(5u, stats.opcode_hist.size());
EXPECT_EQ(6u, stats.opcode_hist.at(SpvOpCapability));
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpMemoryModel));
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeInt));
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpTypeFloat));
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpExtension));
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(5u, stats.opcode_hist.size());
EXPECT_EQ(10u, stats.opcode_hist.at(SpvOpCapability));
EXPECT_EQ(3u, stats.opcode_hist.at(SpvOpMemoryModel));
EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpTypeInt));
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeFloat));
EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpExtension));
CompileAndAggregateStats(code2, &stats);
EXPECT_EQ(5u, stats.opcode_hist.size());
EXPECT_EQ(12u, stats.opcode_hist.at(SpvOpCapability));
EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpMemoryModel));
EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpTypeInt));
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeFloat));
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpExtension));
}
TEST(AggregateStats, OpcodeMarkovHistogram) {
const std::string code1 = R"(
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_NV_viewport_array2"
OpMemoryModel Logical GLSL450
)";
const std::string code2 = R"(
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Linkage
OpMemoryModel Physical32 OpenCL
%u64 = OpTypeInt 64 0
%u32 = OpTypeInt 32 0
%f32 = OpTypeFloat 32
)";
SpirvStats stats;
stats.opcode_markov_hist.resize(2);
CompileAndAggregateStats(code1, &stats);
ASSERT_EQ(2u, stats.opcode_markov_hist.size());
EXPECT_EQ(2u, stats.opcode_markov_hist[0].size());
EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpCapability).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size());
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability));
EXPECT_EQ(1u,
stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel));
EXPECT_EQ(1u, stats.opcode_markov_hist[1].size());
EXPECT_EQ(2u, stats.opcode_markov_hist[1].at(SpvOpCapability).size());
EXPECT_EQ(1u,
stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension));
EXPECT_EQ(
1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel));
CompileAndAggregateStats(code2, &stats);
ASSERT_EQ(2u, stats.opcode_markov_hist.size());
EXPECT_EQ(4u, stats.opcode_markov_hist[0].size());
EXPECT_EQ(3u, stats.opcode_markov_hist[0].at(SpvOpCapability).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpMemoryModel).size());
EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).size());
EXPECT_EQ(
4u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability));
EXPECT_EQ(1u,
stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpMemoryModel));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel));
EXPECT_EQ(1u,
stats.opcode_markov_hist[0].at(SpvOpMemoryModel).at(SpvOpTypeInt));
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeInt));
EXPECT_EQ(1u,
stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeFloat));
EXPECT_EQ(3u, stats.opcode_markov_hist[1].size());
EXPECT_EQ(4u, stats.opcode_markov_hist[1].at(SpvOpCapability).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpMemoryModel).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpTypeInt).size());
EXPECT_EQ(
2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpCapability));
EXPECT_EQ(1u,
stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension));
EXPECT_EQ(
2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel));
EXPECT_EQ(1u,
stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpTypeInt));
EXPECT_EQ(1u,
stats.opcode_markov_hist[1].at(SpvOpMemoryModel).at(SpvOpTypeInt));
EXPECT_EQ(1u,
stats.opcode_markov_hist[1].at(SpvOpTypeInt).at(SpvOpTypeFloat));
}
TEST(AggregateStats, ConstantLiteralsHistogram) {
const std::string code1 = R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
OpCapability Float64
OpCapability Int16
OpCapability Int64
OpMemoryModel Physical32 OpenCL
%u16 = OpTypeInt 16 0
%u32 = OpTypeInt 32 0
%u64 = OpTypeInt 64 0
%f32 = OpTypeFloat 32
%f64 = OpTypeFloat 64
%1 = OpConstant %f32 0.1
%2 = OpConstant %f32 -2
%3 = OpConstant %f64 -2
%4 = OpConstant %u16 16
%5 = OpConstant %u16 2
%6 = OpConstant %u32 32
%7 = OpConstant %u64 64
)";
const std::string code2 = R"(
OpCapability Shader
OpCapability Linkage
OpCapability Int16
OpCapability Int64
OpMemoryModel Logical GLSL450
%f32 = OpTypeFloat 32
%u16 = OpTypeInt 16 0
%s16 = OpTypeInt 16 1
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%u64 = OpTypeInt 64 0
%s64 = OpTypeInt 64 1
%1 = OpConstant %f32 0.1
%2 = OpConstant %f32 -2
%3 = OpConstant %u16 1
%4 = OpConstant %u16 16
%5 = OpConstant %u16 2
%6 = OpConstant %s16 -16
%7 = OpConstant %u32 32
%8 = OpConstant %s32 2
%9 = OpConstant %s32 -32
%10 = OpConstant %u64 64
%11 = OpConstant %s64 -64
)";
SpirvStats stats;
CompileAndAggregateStats(code1, &stats);
EXPECT_EQ(2u, stats.f32_constant_hist.size());
EXPECT_EQ(1u, stats.f64_constant_hist.size());
EXPECT_EQ(1u, stats.f32_constant_hist.at(0.1f));
EXPECT_EQ(1u, stats.f32_constant_hist.at(-2.f));
EXPECT_EQ(1u, stats.f64_constant_hist.at(-2));
EXPECT_EQ(2u, stats.u16_constant_hist.size());
EXPECT_EQ(0u, stats.s16_constant_hist.size());
EXPECT_EQ(1u, stats.u32_constant_hist.size());
EXPECT_EQ(0u, stats.s32_constant_hist.size());
EXPECT_EQ(1u, stats.u64_constant_hist.size());
EXPECT_EQ(0u, stats.s64_constant_hist.size());
EXPECT_EQ(1u, stats.u16_constant_hist.at(16));
EXPECT_EQ(1u, stats.u16_constant_hist.at(2));
EXPECT_EQ(1u, stats.u32_constant_hist.at(32));
EXPECT_EQ(1u, stats.u64_constant_hist.at(64));
CompileAndAggregateStats(code2, &stats);
EXPECT_EQ(2u, stats.f32_constant_hist.size());
EXPECT_EQ(1u, stats.f64_constant_hist.size());
EXPECT_EQ(2u, stats.f32_constant_hist.at(0.1f));
EXPECT_EQ(2u, stats.f32_constant_hist.at(-2.f));
EXPECT_EQ(1u, stats.f64_constant_hist.at(-2));
EXPECT_EQ(3u, stats.u16_constant_hist.size());
EXPECT_EQ(1u, stats.s16_constant_hist.size());
EXPECT_EQ(1u, stats.u32_constant_hist.size());
EXPECT_EQ(2u, stats.s32_constant_hist.size());
EXPECT_EQ(1u, stats.u64_constant_hist.size());
EXPECT_EQ(1u, stats.s64_constant_hist.size());
EXPECT_EQ(2u, stats.u16_constant_hist.at(16));
EXPECT_EQ(2u, stats.u16_constant_hist.at(2));
EXPECT_EQ(1u, stats.u16_constant_hist.at(1));
EXPECT_EQ(1u, stats.s16_constant_hist.at(-16));
EXPECT_EQ(2u, stats.u32_constant_hist.at(32));
EXPECT_EQ(1u, stats.s32_constant_hist.at(2));
EXPECT_EQ(1u, stats.s32_constant_hist.at(-32));
EXPECT_EQ(2u, stats.u64_constant_hist.at(64));
EXPECT_EQ(1u, stats.s64_constant_hist.at(-64));
}
} // namespace
} // namespace stats
} // namespace spvtools

View File

@ -1,174 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Tests for unique type declaration rules validator.
#include <sstream>
#include <string>
#include "source/latest_version_spirv_header.h"
#include "test/test_fixture.h"
#include "tools/stats/stats_analyzer.h"
namespace spvtools {
namespace stats {
namespace {
// Fills |stats| with some synthetic header stats, as if aggregated from 100
// modules (100 used for simpler percentage evaluation).
void FillDefaultStats(SpirvStats* stats) {
*stats = SpirvStats();
stats->version_hist[0x00010000] = 40;
stats->version_hist[0x00010100] = 60;
stats->generator_hist[0x00000000] = 64;
stats->generator_hist[0x00010000] = 1;
stats->generator_hist[0x00020000] = 2;
stats->generator_hist[0x00030000] = 3;
stats->generator_hist[0x00040000] = 4;
stats->generator_hist[0x00050000] = 5;
stats->generator_hist[0x00060000] = 6;
stats->generator_hist[0x00070000] = 7;
stats->generator_hist[0x00080000] = 8;
int num_version_entries = 0;
for (const auto& pair : stats->version_hist) {
num_version_entries += pair.second;
}
int num_generator_entries = 0;
for (const auto& pair : stats->generator_hist) {
num_generator_entries += pair.second;
}
EXPECT_EQ(num_version_entries, num_generator_entries);
}
TEST(StatsAnalyzer, Version) {
SpirvStats stats;
FillDefaultStats(&stats);
StatsAnalyzer analyzer(stats);
std::stringstream ss;
analyzer.WriteVersion(ss);
const std::string output = ss.str();
const std::string expected_output = "Version 1.1 60%\nVersion 1.0 40%\n";
EXPECT_EQ(expected_output, output);
}
TEST(StatsAnalyzer, Generator) {
SpirvStats stats;
FillDefaultStats(&stats);
StatsAnalyzer analyzer(stats);
std::stringstream ss;
analyzer.WriteGenerator(ss);
const std::string output = ss.str();
const std::string expected_output =
"Khronos 64%\nKhronos Glslang Reference Front End 8%\n"
"Khronos SPIR-V Tools Assembler 7%\nKhronos LLVM/SPIR-V Translator 6%"
"\nARM 5%\nNVIDIA 4%\nCodeplay 3%\nValve 2%\nLunarG 1%\n";
EXPECT_EQ(expected_output, output);
}
TEST(StatsAnalyzer, Capability) {
SpirvStats stats;
FillDefaultStats(&stats);
stats.capability_hist[SpvCapabilityShader] = 25;
stats.capability_hist[SpvCapabilityKernel] = 75;
StatsAnalyzer analyzer(stats);
std::stringstream ss;
analyzer.WriteCapability(ss);
const std::string output = ss.str();
const std::string expected_output = "Kernel 75%\nShader 25%\n";
EXPECT_EQ(expected_output, output);
}
TEST(StatsAnalyzer, Extension) {
SpirvStats stats;
FillDefaultStats(&stats);
stats.extension_hist["greatest_extension_ever"] = 1;
stats.extension_hist["worst_extension_ever"] = 10;
StatsAnalyzer analyzer(stats);
std::stringstream ss;
analyzer.WriteExtension(ss);
const std::string output = ss.str();
const std::string expected_output =
"worst_extension_ever 10%\ngreatest_extension_ever 1%\n";
EXPECT_EQ(expected_output, output);
}
TEST(StatsAnalyzer, Opcode) {
SpirvStats stats;
FillDefaultStats(&stats);
stats.opcode_hist[SpvOpCapability] = 20;
stats.opcode_hist[SpvOpConstant] = 80;
stats.opcode_hist[SpvOpDecorate] = 100;
StatsAnalyzer analyzer(stats);
std::stringstream ss;
analyzer.WriteOpcode(ss);
const std::string output = ss.str();
const std::string expected_output =
"Total unique opcodes used: 3\nDecorate 50%\n"
"Constant 40%\nCapability 10%\n";
EXPECT_EQ(expected_output, output);
}
TEST(StatsAnalyzer, OpcodeMarkov) {
SpirvStats stats;
FillDefaultStats(&stats);
stats.opcode_hist[SpvOpFMul] = 400;
stats.opcode_hist[SpvOpFAdd] = 200;
stats.opcode_hist[SpvOpFSub] = 400;
stats.opcode_markov_hist.resize(1);
auto& hist = stats.opcode_markov_hist[0];
hist[SpvOpFMul][SpvOpFAdd] = 100;
hist[SpvOpFMul][SpvOpFSub] = 300;
hist[SpvOpFAdd][SpvOpFMul] = 100;
hist[SpvOpFAdd][SpvOpFAdd] = 100;
StatsAnalyzer analyzer(stats);
std::stringstream ss;
analyzer.WriteOpcodeMarkov(ss);
const std::string output = ss.str();
const std::string expected_output =
"FMul -> FSub 75% (base rate 40%, pair occurrences 300)\n"
"FMul -> FAdd 25% (base rate 20%, pair occurrences 100)\n"
"FAdd -> FAdd 50% (base rate 20%, pair occurrences 100)\n"
"FAdd -> FMul 50% (base rate 40%, pair occurrences 100)\n";
EXPECT_EQ(expected_output, output);
}
} // namespace
} // namespace stats
} // namespace spvtools

View File

@ -48,13 +48,6 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
add_spvtools_tool(TARGET spirv-reduce SRCS reduce/reduce.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-reduce ${SPIRV_TOOLS})
endif()
add_spvtools_tool(TARGET spirv-link SRCS link/linker.cpp LIBS SPIRV-Tools-link ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-stats
SRCS stats/stats.cpp
stats/stats_analyzer.cpp
stats/stats_analyzer.h
stats/spirv_stats.cpp
stats/spirv_stats.h
LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-cfg
SRCS cfg/cfg.cpp
cfg/bin_to_dot.h
@ -62,26 +55,12 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
LIBS ${SPIRV_TOOLS})
target_include_directories(spirv-cfg PRIVATE ${spirv-tools_SOURCE_DIR}
${SPIRV_HEADER_INCLUDE_DIR})
target_include_directories(spirv-stats PRIVATE ${spirv-tools_SOURCE_DIR}
${SPIRV_HEADER_INCLUDE_DIR})
set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt spirv-stats
set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt
spirv-cfg spirv-link)
if(NOT DEFINED IOS_PLATFORM)
set(SPIRV_INSTALL_TARGETS ${SPIRV_INSTALL_TARGETS} spirv-reduce)
endif()
if(SPIRV_BUILD_COMPRESSION)
add_spvtools_tool(TARGET spirv-markv
SRCS comp/markv.cpp
comp/markv_model_factory.cpp
comp/markv_model_shader.cpp
LIBS SPIRV-Tools-comp SPIRV-Tools-opt ${SPIRV_TOOLS})
target_include_directories(spirv-markv PRIVATE ${spirv-tools_SOURCE_DIR}
${SPIRV_HEADER_INCLUDE_DIR})
set(SPIRV_INSTALL_TARGETS ${SPIRV_INSTALL_TARGETS} spirv-markv)
endif(SPIRV_BUILD_COMPRESSION)
if(ENABLE_SPIRV_TOOLS_INSTALL)
install(TARGETS ${SPIRV_INSTALL_TARGETS}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}

View File

@ -1,385 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "source/comp/markv.h"
#include "source/spirv_target_env.h"
#include "source/table.h"
#include "spirv-tools/optimizer.hpp"
#include "tools/comp/markv_model_factory.h"
#include "tools/io.h"
namespace {
const auto kSpvEnv = SPV_ENV_UNIVERSAL_1_2;
enum Task {
kNoTask = 0,
kEncode,
kDecode,
kTest,
};
struct ScopedContext {
ScopedContext(spv_target_env env) : context(spvContextCreate(env)) {}
~ScopedContext() { spvContextDestroy(context); }
spv_context context;
};
void print_usage(char* argv0) {
printf(
R"(%s - Encodes or decodes a SPIR-V binary to or from a MARK-V binary.
USAGE: %s [e|d|t] [options] [<filename>]
The input binary is read from <filename>. If no file is specified,
or if the filename is "-", then the binary is read from standard input.
If no output is specified then the output is printed to stdout in a human
readable format.
WIP: MARK-V codec is in early stages of development. At the moment it only
can encode and decode some SPIR-V files and only if exacly the same build of
software is used (is doesn't write or handle version numbers yet).
Tasks:
e Encode SPIR-V to MARK-V.
d Decode MARK-V to SPIR-V.
t Test the codec by first encoding the given SPIR-V file to
MARK-V, then decoding it back to SPIR-V and comparing results.
Options:
-h, --help Print this help.
--comments Write codec comments to stderr.
--version Display MARK-V codec version.
--validate Validate SPIR-V while encoding or decoding.
--model=<model-name>
Compression model, possible values:
shader_lite - fast, poor compression ratio
shader_mid - balanced
shader_max - best compression ratio
Default: shader_lite
-o <filename> Set the output filename.
Output goes to standard output if this option is
not specified, or if the filename is "-".
Not needed for 't' task (testing).
)",
argv0, argv0);
}
void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
const spv_position_t& position,
const char* message) {
switch (level) {
case SPV_MSG_FATAL:
case SPV_MSG_INTERNAL_ERROR:
case SPV_MSG_ERROR:
std::cerr << "error: " << position.index << ": " << message << std::endl;
break;
case SPV_MSG_WARNING:
std::cerr << "warning: " << position.index << ": " << message
<< std::endl;
break;
case SPV_MSG_INFO:
std::cerr << "info: " << position.index << ": " << message << std::endl;
break;
default:
break;
}
}
} // namespace
int main(int argc, char** argv) {
const char* input_filename = nullptr;
const char* output_filename = nullptr;
Task task = kNoTask;
if (argc < 3) {
print_usage(argv[0]);
return 0;
}
const char* task_char = argv[1];
if (0 == strcmp("e", task_char)) {
task = kEncode;
} else if (0 == strcmp("d", task_char)) {
task = kDecode;
} else if (0 == strcmp("t", task_char)) {
task = kTest;
}
if (task == kNoTask) {
print_usage(argv[0]);
return 1;
}
bool want_comments = false;
bool validate_spirv_binary = false;
spvtools::comp::MarkvModelType model_type =
spvtools::comp::kMarkvModelUnknown;
for (int argi = 2; argi < argc; ++argi) {
if ('-' == argv[argi][0]) {
switch (argv[argi][1]) {
case 'h':
print_usage(argv[0]);
return 0;
case 'o': {
if (!output_filename && argi + 1 < argc &&
(task == kEncode || task == kDecode)) {
output_filename = argv[++argi];
} else {
print_usage(argv[0]);
return 1;
}
} break;
case '-': {
if (0 == strcmp(argv[argi], "--help")) {
print_usage(argv[0]);
return 0;
} else if (0 == strcmp(argv[argi], "--comments")) {
want_comments = true;
} else if (0 == strcmp(argv[argi], "--version")) {
fprintf(stderr, "error: Not implemented\n");
return 1;
} else if (0 == strcmp(argv[argi], "--validate")) {
validate_spirv_binary = true;
} else if (0 == strcmp(argv[argi], "--model=shader_lite")) {
if (model_type != spvtools::comp::kMarkvModelUnknown)
fprintf(stderr, "error: More than one model specified\n");
model_type = spvtools::comp::kMarkvModelShaderLite;
} else if (0 == strcmp(argv[argi], "--model=shader_mid")) {
if (model_type != spvtools::comp::kMarkvModelUnknown)
fprintf(stderr, "error: More than one model specified\n");
model_type = spvtools::comp::kMarkvModelShaderMid;
} else if (0 == strcmp(argv[argi], "--model=shader_max")) {
if (model_type != spvtools::comp::kMarkvModelUnknown)
fprintf(stderr, "error: More than one model specified\n");
model_type = spvtools::comp::kMarkvModelShaderMax;
} else {
print_usage(argv[0]);
return 1;
}
} break;
case '\0': {
// Setting a filename of "-" to indicate stdin.
if (!input_filename) {
input_filename = argv[argi];
} else {
fprintf(stderr, "error: More than one input file specified\n");
return 1;
}
} break;
default:
print_usage(argv[0]);
return 1;
}
} else {
if (!input_filename) {
input_filename = argv[argi];
} else {
fprintf(stderr, "error: More than one input file specified\n");
return 1;
}
}
}
if (model_type == spvtools::comp::kMarkvModelUnknown)
model_type = spvtools::comp::kMarkvModelShaderLite;
const auto no_comments = spvtools::comp::MarkvLogConsumer();
const auto output_to_stderr = [](const std::string& str) {
std::cerr << str;
};
ScopedContext ctx(kSpvEnv);
std::unique_ptr<spvtools::comp::MarkvModel> model =
spvtools::comp::CreateMarkvModel(model_type);
std::vector<uint32_t> spirv;
std::vector<uint8_t> markv;
spvtools::comp::MarkvCodecOptions options;
options.validate_spirv_binary = validate_spirv_binary;
if (task == kEncode) {
if (!ReadFile<uint32_t>(input_filename, "rb", &spirv)) return 1;
assert(!spirv.empty());
if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv(
ctx.context, spirv, options, *model,
DiagnosticsMessageHandler,
want_comments ? output_to_stderr : no_comments,
spvtools::comp::MarkvDebugConsumer(), &markv)) {
std::cerr << "error: Failed to encode " << input_filename << " to MARK-V "
<< std::endl;
return 1;
}
if (!WriteFile<uint8_t>(output_filename, "wb", markv.data(), markv.size()))
return 1;
} else if (task == kDecode) {
if (!ReadFile<uint8_t>(input_filename, "rb", &markv)) return 1;
assert(!markv.empty());
if (SPV_SUCCESS != spvtools::comp::MarkvToSpirv(
ctx.context, markv, options, *model,
DiagnosticsMessageHandler,
want_comments ? output_to_stderr : no_comments,
spvtools::comp::MarkvDebugConsumer(), &spirv)) {
std::cerr << "error: Failed to decode " << input_filename << " to SPIR-V "
<< std::endl;
return 1;
}
if (!WriteFile<uint32_t>(output_filename, "wb", spirv.data(), spirv.size()))
return 1;
} else if (task == kTest) {
if (!ReadFile<uint32_t>(input_filename, "rb", &spirv)) return 1;
assert(!spirv.empty());
std::vector<uint32_t> spirv_before;
spvtools::Optimizer optimizer(kSpvEnv);
optimizer.RegisterPass(spvtools::CreateCompactIdsPass());
if (!optimizer.Run(spirv.data(), spirv.size(), &spirv_before)) {
std::cerr << "error: Optimizer failure on: " << input_filename
<< std::endl;
}
std::vector<std::string> encoder_instruction_bits;
std::vector<std::string> encoder_instruction_comments;
std::vector<std::vector<uint32_t>> encoder_instruction_words;
std::vector<std::string> decoder_instruction_bits;
std::vector<std::string> decoder_instruction_comments;
std::vector<std::vector<uint32_t>> decoder_instruction_words;
const auto encoder_debug_consumer = [&](const std::vector<uint32_t>& words,
const std::string& bits,
const std::string& comment) {
encoder_instruction_words.push_back(words);
encoder_instruction_bits.push_back(bits);
encoder_instruction_comments.push_back(comment);
return true;
};
if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv(
ctx.context, spirv_before, options, *model,
DiagnosticsMessageHandler,
want_comments ? output_to_stderr : no_comments,
encoder_debug_consumer, &markv)) {
std::cerr << "error: Failed to encode " << input_filename << " to MARK-V "
<< std::endl;
return 1;
}
const auto write_bug_report = [&]() {
for (size_t inst_index = 0; inst_index < decoder_instruction_words.size();
++inst_index) {
std::cerr << "\nInstruction #" << inst_index << std::endl;
std::cerr << "\nEncoder words: ";
for (uint32_t word : encoder_instruction_words[inst_index])
std::cerr << word << " ";
std::cerr << "\nDecoder words: ";
for (uint32_t word : decoder_instruction_words[inst_index])
std::cerr << word << " ";
std::cerr << std::endl;
std::cerr << "\nEncoder bits: " << encoder_instruction_bits[inst_index];
std::cerr << "\nDecoder bits: " << decoder_instruction_bits[inst_index];
std::cerr << std::endl;
std::cerr << "\nEncoder comments:\n"
<< encoder_instruction_comments[inst_index];
std::cerr << "Decoder comments:\n"
<< decoder_instruction_comments[inst_index];
std::cerr << std::endl;
}
};
const auto decoder_debug_consumer = [&](const std::vector<uint32_t>& words,
const std::string& bits,
const std::string& comment) {
const size_t inst_index = decoder_instruction_words.size();
if (inst_index >= encoder_instruction_words.size()) {
write_bug_report();
std::cerr << "error: Decoder has more instructions than encoder: "
<< input_filename << std::endl;
return false;
}
decoder_instruction_words.push_back(words);
decoder_instruction_bits.push_back(bits);
decoder_instruction_comments.push_back(comment);
if (encoder_instruction_words[inst_index] !=
decoder_instruction_words[inst_index]) {
write_bug_report();
std::cerr << "error: Words of the last decoded instruction differ from "
"reference: "
<< input_filename << std::endl;
return false;
}
if (encoder_instruction_bits[inst_index] !=
decoder_instruction_bits[inst_index]) {
write_bug_report();
std::cerr << "error: Bits of the last decoded instruction differ from "
"reference: "
<< input_filename << std::endl;
return false;
}
return true;
};
std::vector<uint32_t> spirv_after;
const spv_result_t decoding_result = spvtools::comp::MarkvToSpirv(
ctx.context, markv, options, *model, DiagnosticsMessageHandler,
want_comments ? output_to_stderr : no_comments, decoder_debug_consumer,
&spirv_after);
if (decoding_result == SPV_REQUESTED_TERMINATION) {
std::cerr << "error: Decoding interrupted by the debugger: "
<< input_filename << std::endl;
return 1;
}
if (decoding_result != SPV_SUCCESS) {
std::cerr << "error: Failed to decode encoded " << input_filename
<< " back to SPIR-V " << std::endl;
return 1;
}
assert(spirv_before.size() == spirv_after.size());
assert(std::mismatch(std::next(spirv_before.begin(), 5), spirv_before.end(),
std::next(spirv_after.begin(), 5)) ==
std::make_pair(spirv_before.end(), spirv_after.end()));
}
return 0;
}

View File

@ -1,50 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tools/comp/markv_model_factory.h"
#include "source/util/make_unique.h"
#include "tools/comp/markv_model_shader.h"
namespace spvtools {
namespace comp {
std::unique_ptr<MarkvModel> CreateMarkvModel(MarkvModelType type) {
std::unique_ptr<MarkvModel> model;
switch (type) {
case kMarkvModelShaderLite: {
model = MakeUnique<MarkvModelShaderLite>();
break;
}
case kMarkvModelShaderMid: {
model = MakeUnique<MarkvModelShaderMid>();
break;
}
case kMarkvModelShaderMax: {
model = MakeUnique<MarkvModelShaderMax>();
break;
}
case kMarkvModelUnknown: {
assert(0 && "kMarkvModelUnknown supplied to CreateMarkvModel");
return model;
}
}
model->SetModelType(static_cast<uint32_t>(type));
return model;
}
} // namespace comp
} // namespace spvtools

View File

@ -1,37 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TOOLS_COMP_MARKV_MODEL_FACTORY_H_
#define TOOLS_COMP_MARKV_MODEL_FACTORY_H_
#include <memory>
#include "source/comp/markv_model.h"
namespace spvtools {
namespace comp {
enum MarkvModelType {
kMarkvModelUnknown = 0,
kMarkvModelShaderLite,
kMarkvModelShaderMid,
kMarkvModelShaderMax,
};
std::unique_ptr<MarkvModel> CreateMarkvModel(MarkvModelType type);
} // namespace comp
} // namespace spvtools
#endif // TOOLS_COMP_MARKV_MODEL_FACTORY_H_

View File

@ -1,84 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tools/comp/markv_model_shader.h"
#include <algorithm>
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "source/util/make_unique.h"
namespace spvtools {
namespace comp {
namespace {
// Signals that the value is not in the coding scheme and a fallback method
// needs to be used.
const uint64_t kMarkvNoneOfTheAbove = MarkvModel::GetMarkvNoneOfTheAbove();
inline uint32_t CombineOpcodeAndNumOperands(uint32_t opcode,
uint32_t num_operands) {
return opcode | (num_operands << 16);
}
#include "tools/comp/markv_model_shader_default_autogen.inc"
} // namespace
MarkvModelShaderLite::MarkvModelShaderLite() {
const uint16_t kVersionNumber = 1;
SetModelVersion(kVersionNumber);
opcode_and_num_operands_huffman_codec_ =
MakeUnique<HuffmanCodec<uint64_t>>(GetOpcodeAndNumOperandsHist());
id_fallback_strategy_ = IdFallbackStrategy::kShortDescriptor;
}
MarkvModelShaderMid::MarkvModelShaderMid() {
const uint16_t kVersionNumber = 1;
SetModelVersion(kVersionNumber);
opcode_and_num_operands_huffman_codec_ =
MakeUnique<HuffmanCodec<uint64_t>>(GetOpcodeAndNumOperandsHist());
non_id_word_huffman_codecs_ = GetNonIdWordHuffmanCodecs();
id_descriptor_huffman_codecs_ = GetIdDescriptorHuffmanCodecs();
descriptors_with_coding_scheme_ = GetDescriptorsWithCodingScheme();
literal_string_huffman_codecs_ = GetLiteralStringHuffmanCodecs();
id_fallback_strategy_ = IdFallbackStrategy::kShortDescriptor;
}
MarkvModelShaderMax::MarkvModelShaderMax() {
const uint16_t kVersionNumber = 1;
SetModelVersion(kVersionNumber);
opcode_and_num_operands_huffman_codec_ =
MakeUnique<HuffmanCodec<uint64_t>>(GetOpcodeAndNumOperandsHist());
opcode_and_num_operands_markov_huffman_codecs_ =
GetOpcodeAndNumOperandsMarkovHuffmanCodecs();
non_id_word_huffman_codecs_ = GetNonIdWordHuffmanCodecs();
id_descriptor_huffman_codecs_ = GetIdDescriptorHuffmanCodecs();
descriptors_with_coding_scheme_ = GetDescriptorsWithCodingScheme();
literal_string_huffman_codecs_ = GetLiteralStringHuffmanCodecs();
id_fallback_strategy_ = IdFallbackStrategy::kRuleBased;
}
} // namespace comp
} // namespace spvtools

View File

@ -1,47 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TOOLS_COMP_MARKV_MODEL_SHADER_H_
#define TOOLS_COMP_MARKV_MODEL_SHADER_H_
#include "source/comp/markv_model.h"
namespace spvtools {
namespace comp {
// MARK-V shader compression model, which only uses fast and lightweight
// algorithms, which do not require training and are not heavily dependent on
// SPIR-V grammar. Compression ratio is worse than by other models.
class MarkvModelShaderLite : public MarkvModel {
public:
MarkvModelShaderLite();
};
// MARK-V shader compression model with balanced compression ratio and runtime
// performance.
class MarkvModelShaderMid : public MarkvModel {
public:
MarkvModelShaderMid();
};
// MARK-V shader compression model designed for maximum compression.
class MarkvModelShaderMax : public MarkvModel {
public:
MarkvModelShaderMax();
};
} // namespace comp
} // namespace spvtools
#endif // TOOLS_COMP_MARKV_MODEL_SHADER_H_

File diff suppressed because it is too large Load Diff

View File

@ -1,165 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tools/stats/spirv_stats.h"
#include <cassert>
#include <algorithm>
#include <memory>
#include <string>
#include "source/diagnostic.h"
#include "source/enum_string_mapping.h"
#include "source/extensions.h"
#include "source/id_descriptor.h"
#include "source/instruction.h"
#include "source/opcode.h"
#include "source/operand.h"
#include "source/val/instruction.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"
#include "spirv-tools/libspirv.h"
namespace spvtools {
namespace stats {
namespace {
// Helper class for stats aggregation. Receives as in/out parameter.
// Constructs ValidationState and updates it by running validator for each
// instruction.
class StatsAggregator {
public:
StatsAggregator(SpirvStats* in_out_stats, const val::ValidationState_t* state)
: stats_(in_out_stats), vstate_(state) {}
// Processes the instructions to collect stats.
void aggregate() {
const auto& instructions = vstate_->ordered_instructions();
++stats_->version_hist[vstate_->version()];
++stats_->generator_hist[vstate_->generator()];
for (size_t i = 0; i < instructions.size(); ++i) {
const auto& inst = instructions[i];
ProcessOpcode(&inst, i);
ProcessCapability(&inst);
ProcessExtension(&inst);
ProcessConstant(&inst);
}
}
// Collects OpCapability statistics.
void ProcessCapability(const val::Instruction* inst) {
if (inst->opcode() != SpvOpCapability) return;
const uint32_t capability = inst->word(inst->operands()[0].offset);
++stats_->capability_hist[capability];
}
// Collects OpExtension statistics.
void ProcessExtension(const val::Instruction* inst) {
if (inst->opcode() != SpvOpExtension) return;
const std::string extension = GetExtensionString(&inst->c_inst());
++stats_->extension_hist[extension];
}
// Collects OpCode statistics.
void ProcessOpcode(const val::Instruction* inst, size_t idx) {
const SpvOp opcode = inst->opcode();
++stats_->opcode_hist[opcode];
if (idx == 0) return;
--idx;
const auto& instructions = vstate_->ordered_instructions();
auto step_it = stats_->opcode_markov_hist.begin();
for (; step_it != stats_->opcode_markov_hist.end(); --idx, ++step_it) {
auto& hist = (*step_it)[instructions[idx].opcode()];
++hist[opcode];
if (idx == 0) break;
}
}
// Collects OpConstant statistics.
void ProcessConstant(const val::Instruction* inst) {
if (inst->opcode() != SpvOpConstant) return;
const uint32_t type_id = inst->GetOperandAs<uint32_t>(0);
const auto type_decl_it = vstate_->all_definitions().find(type_id);
assert(type_decl_it != vstate_->all_definitions().end());
const val::Instruction& type_decl_inst = *type_decl_it->second;
const SpvOp type_op = type_decl_inst.opcode();
if (type_op == SpvOpTypeInt) {
const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
const uint32_t is_signed = type_decl_inst.GetOperandAs<uint32_t>(2);
assert(is_signed == 0 || is_signed == 1);
if (bit_width == 16) {
if (is_signed)
++stats_->s16_constant_hist[inst->GetOperandAs<int16_t>(2)];
else
++stats_->u16_constant_hist[inst->GetOperandAs<uint16_t>(2)];
} else if (bit_width == 32) {
if (is_signed)
++stats_->s32_constant_hist[inst->GetOperandAs<int32_t>(2)];
else
++stats_->u32_constant_hist[inst->GetOperandAs<uint32_t>(2)];
} else if (bit_width == 64) {
if (is_signed)
++stats_->s64_constant_hist[inst->GetOperandAs<int64_t>(2)];
else
++stats_->u64_constant_hist[inst->GetOperandAs<uint64_t>(2)];
} else {
assert(false && "TypeInt bit width is not 16, 32 or 64");
}
} else if (type_op == SpvOpTypeFloat) {
const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
if (bit_width == 32) {
++stats_->f32_constant_hist[inst->GetOperandAs<float>(2)];
} else if (bit_width == 64) {
++stats_->f64_constant_hist[inst->GetOperandAs<double>(2)];
} else {
assert(bit_width == 16);
}
}
}
private:
SpirvStats* stats_;
const val::ValidationState_t* vstate_;
IdDescriptorCollection id_descriptors_;
};
} // namespace
spv_result_t AggregateStats(const spv_context context, const uint32_t* words,
const size_t num_words, spv_diagnostic* pDiagnostic,
SpirvStats* stats) {
std::unique_ptr<val::ValidationState_t> vstate;
spv_validator_options_t options;
spv_result_t result = ValidateBinaryAndKeepValidationState(
context, &options, words, num_words, pDiagnostic, &vstate);
if (result != SPV_SUCCESS) return result;
StatsAggregator stats_aggregator(stats, vstate.get());
stats_aggregator.aggregate();
return SPV_SUCCESS;
}
} // namespace stats
} // namespace spvtools

View File

@ -1,93 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TOOLS_STATS_SPIRV_STATS_H_
#define TOOLS_STATS_SPIRV_STATS_H_
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace stats {
struct SpirvStats {
// Version histogram, version_word -> count.
std::unordered_map<uint32_t, uint32_t> version_hist;
// Generator histogram, generator_word -> count.
std::unordered_map<uint32_t, uint32_t> generator_hist;
// Capability histogram, SpvCapabilityXXX -> count.
std::unordered_map<uint32_t, uint32_t> capability_hist;
// Extension histogram, extension_string -> count.
std::unordered_map<std::string, uint32_t> extension_hist;
// Opcode histogram, SpvOpXXX -> count.
std::unordered_map<uint32_t, uint32_t> opcode_hist;
// OpConstant u16 histogram, value -> count.
std::unordered_map<uint16_t, uint32_t> u16_constant_hist;
// OpConstant u32 histogram, value -> count.
std::unordered_map<uint32_t, uint32_t> u32_constant_hist;
// OpConstant u64 histogram, value -> count.
std::unordered_map<uint64_t, uint32_t> u64_constant_hist;
// OpConstant s16 histogram, value -> count.
std::unordered_map<int16_t, uint32_t> s16_constant_hist;
// OpConstant s32 histogram, value -> count.
std::unordered_map<int32_t, uint32_t> s32_constant_hist;
// OpConstant s64 histogram, value -> count.
std::unordered_map<int64_t, uint32_t> s64_constant_hist;
// OpConstant f32 histogram, value -> count.
std::unordered_map<float, uint32_t> f32_constant_hist;
// OpConstant f64 histogram, value -> count.
std::unordered_map<double, uint32_t> f64_constant_hist;
// Used to collect statistics on opcodes triggering other opcodes.
// Container scheme: gap between instructions -> cue opcode -> later opcode
// -> count.
// For example opcode_markov_hist[2][OpFMul][OpFAdd] corresponds to
// the number of times an OpMul appears, followed by 2 other instructions,
// followed by OpFAdd.
// opcode_markov_hist[0][OpFMul][OpFAdd] corresponds to how many times
// OpFMul appears, directly followed by OpFAdd.
// The size of the outer std::vector also serves as an input parameter,
// determining how many steps will be collected.
// I.e. do opcode_markov_hist.resize(1) to collect data for one step only.
std::vector<
std::unordered_map<uint32_t, std::unordered_map<uint32_t, uint32_t>>>
opcode_markov_hist;
};
// Aggregates existing |stats| with new stats extracted from |binary|.
spv_result_t AggregateStats(const spv_context context, const uint32_t* words,
const size_t num_words, spv_diagnostic* pDiagnostic,
SpirvStats* stats);
} // namespace stats
} // namespace spvtools
#endif // TOOLS_STATS_SPIRV_STATS_H_

View File

@ -1,173 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <fstream>
#include <iostream>
#include <unordered_map>
#include <vector>
#include "spirv-tools/libspirv.h"
#include "tools/io.h"
#include "tools/stats/spirv_stats.h"
#include "tools/stats/stats_analyzer.h"
namespace {
void PrintUsage(char* argv0) {
printf(
R"(%s - Collect statistics from one or more SPIR-V binary file(s).
USAGE: %s [options] [<filepaths>]
TIP: In order to collect statistics from all .spv files under current dir use
find . -name "*.spv" -print0 | xargs -0 -s 2000000 %s
Options:
-h, --help
Print this help.
-v, --verbose
Print additional info to stderr.
)",
argv0, argv0, argv0);
}
void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
const spv_position_t& position,
const char* message) {
switch (level) {
case SPV_MSG_FATAL:
case SPV_MSG_INTERNAL_ERROR:
case SPV_MSG_ERROR:
std::cerr << "error: " << position.index << ": " << message << std::endl;
break;
case SPV_MSG_WARNING:
std::cout << "warning: " << position.index << ": " << message
<< std::endl;
break;
case SPV_MSG_INFO:
std::cout << "info: " << position.index << ": " << message << std::endl;
break;
default:
break;
}
}
} // namespace
int main(int argc, char** argv) {
bool continue_processing = true;
int return_code = 0;
bool expect_output_path = false;
bool verbose = false;
std::vector<const char*> paths;
const char* output_path = nullptr;
for (int argi = 1; continue_processing && argi < argc; ++argi) {
const char* cur_arg = argv[argi];
if ('-' == cur_arg[0]) {
if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) {
PrintUsage(argv[0]);
continue_processing = false;
return_code = 0;
} else if (0 == strcmp(cur_arg, "--verbose") ||
0 == strcmp(cur_arg, "-v")) {
verbose = true;
} else if (0 == strcmp(cur_arg, "--output") ||
0 == strcmp(cur_arg, "-o")) {
expect_output_path = true;
} else {
PrintUsage(argv[0]);
continue_processing = false;
return_code = 1;
}
} else {
if (expect_output_path) {
output_path = cur_arg;
expect_output_path = false;
} else {
paths.push_back(cur_arg);
}
}
}
// Exit if command line parsing was not successful.
if (!continue_processing) {
return return_code;
}
std::cerr << "Processing " << paths.size() << " files..." << std::endl;
spvtools::Context ctx(SPV_ENV_UNIVERSAL_1_1);
ctx.SetMessageConsumer(DiagnosticsMessageHandler);
spvtools::stats::SpirvStats stats;
stats.opcode_markov_hist.resize(1);
for (size_t index = 0; index < paths.size(); ++index) {
const size_t kMilestonePeriod = 1000;
if (verbose) {
if (index % kMilestonePeriod == kMilestonePeriod - 1)
std::cerr << "Processed " << index + 1 << " files..." << std::endl;
}
const char* path = paths[index];
std::vector<uint32_t> contents;
if (!ReadFile<uint32_t>(path, "rb", &contents)) return 1;
if (SPV_SUCCESS !=
spvtools::stats::AggregateStats(ctx.CContext(), contents.data(),
contents.size(), nullptr, &stats)) {
std::cerr << "error: Failed to aggregate stats for " << path << std::endl;
return 1;
}
}
spvtools::stats::StatsAnalyzer analyzer(stats);
std::ofstream fout;
if (output_path) {
fout.open(output_path);
if (!fout.is_open()) {
std::cerr << "error: Failed to open " << output_path << std::endl;
return 1;
}
}
std::ostream& out = fout.is_open() ? fout : std::cout;
out << std::endl;
analyzer.WriteVersion(out);
analyzer.WriteGenerator(out);
out << std::endl;
analyzer.WriteCapability(out);
out << std::endl;
analyzer.WriteExtension(out);
out << std::endl;
analyzer.WriteOpcode(out);
out << std::endl;
analyzer.WriteOpcodeMarkov(out);
out << std::endl;
analyzer.WriteConstantLiterals(out);
return 0;
}

View File

@ -1,235 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tools/stats/stats_analyzer.h"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "source/comp/markv_model.h"
#include "source/enum_string_mapping.h"
#include "source/latest_version_spirv_header.h"
#include "source/opcode.h"
#include "source/operand.h"
#include "source/spirv_constant.h"
namespace spvtools {
namespace stats {
namespace {
// Signals that the value is not in the coding scheme and a fallback method
// needs to be used.
const uint64_t kMarkvNoneOfTheAbove =
comp::MarkvModel::GetMarkvNoneOfTheAbove();
std::string GetVersionString(uint32_t word) {
std::stringstream ss;
ss << "Version " << SPV_SPIRV_VERSION_MAJOR_PART(word) << "."
<< SPV_SPIRV_VERSION_MINOR_PART(word);
return ss.str();
}
std::string GetGeneratorString(uint32_t word) {
return spvGeneratorStr(SPV_GENERATOR_TOOL_PART(word));
}
std::string GetOpcodeString(uint32_t word) {
return spvOpcodeString(static_cast<SpvOp>(word));
}
std::string GetCapabilityString(uint32_t word) {
return CapabilityToString(static_cast<SpvCapability>(word));
}
template <class T>
std::string KeyIsLabel(T key) {
std::stringstream ss;
ss << key;
return ss.str();
}
template <class Key>
std::unordered_map<Key, double> GetRecall(
const std::unordered_map<Key, uint32_t>& hist, uint64_t total) {
std::unordered_map<Key, double> freq;
for (const auto& pair : hist) {
const double frequency =
static_cast<double>(pair.second) / static_cast<double>(total);
freq.emplace(pair.first, frequency);
}
return freq;
}
template <class Key>
std::unordered_map<Key, double> GetPrevalence(
const std::unordered_map<Key, uint32_t>& hist) {
uint64_t total = 0;
for (const auto& pair : hist) {
total += pair.second;
}
return GetRecall(hist, total);
}
// Writes |freq| to |out| sorted by frequency in the following format:
// LABEL3 70%
// LABEL1 20%
// LABEL2 10%
// |label_from_key| is used to convert |Key| to label.
template <class Key>
void WriteFreq(std::ostream& out, const std::unordered_map<Key, double>& freq,
std::string (*label_from_key)(Key)) {
std::vector<std::pair<Key, double>> sorted_freq(freq.begin(), freq.end());
std::sort(sorted_freq.begin(), sorted_freq.end(),
[](const std::pair<Key, double>& left,
const std::pair<Key, double>& right) {
return left.second > right.second;
});
for (const auto& pair : sorted_freq) {
if (pair.second < 0.001) break;
out << label_from_key(pair.first) << " " << pair.second * 100.0 << "%"
<< std::endl;
}
}
} // namespace
StatsAnalyzer::StatsAnalyzer(const SpirvStats& stats) : stats_(stats) {
num_modules_ = 0;
for (const auto& pair : stats_.version_hist) {
num_modules_ += pair.second;
}
version_freq_ = GetRecall(stats_.version_hist, num_modules_);
generator_freq_ = GetRecall(stats_.generator_hist, num_modules_);
capability_freq_ = GetRecall(stats_.capability_hist, num_modules_);
extension_freq_ = GetRecall(stats_.extension_hist, num_modules_);
opcode_freq_ = GetPrevalence(stats_.opcode_hist);
}
void StatsAnalyzer::WriteVersion(std::ostream& out) {
WriteFreq(out, version_freq_, GetVersionString);
}
void StatsAnalyzer::WriteGenerator(std::ostream& out) {
WriteFreq(out, generator_freq_, GetGeneratorString);
}
void StatsAnalyzer::WriteCapability(std::ostream& out) {
WriteFreq(out, capability_freq_, GetCapabilityString);
}
void StatsAnalyzer::WriteExtension(std::ostream& out) {
WriteFreq(out, extension_freq_, KeyIsLabel);
}
void StatsAnalyzer::WriteOpcode(std::ostream& out) {
out << "Total unique opcodes used: " << opcode_freq_.size() << std::endl;
WriteFreq(out, opcode_freq_, GetOpcodeString);
}
void StatsAnalyzer::WriteConstantLiterals(std::ostream& out) {
out << "Constant literals" << std::endl;
out << "Float 32" << std::endl;
WriteFreq(out, GetPrevalence(stats_.f32_constant_hist), KeyIsLabel);
out << std::endl << "Float 64" << std::endl;
WriteFreq(out, GetPrevalence(stats_.f64_constant_hist), KeyIsLabel);
out << std::endl << "Unsigned int 16" << std::endl;
WriteFreq(out, GetPrevalence(stats_.u16_constant_hist), KeyIsLabel);
out << std::endl << "Signed int 16" << std::endl;
WriteFreq(out, GetPrevalence(stats_.s16_constant_hist), KeyIsLabel);
out << std::endl << "Unsigned int 32" << std::endl;
WriteFreq(out, GetPrevalence(stats_.u32_constant_hist), KeyIsLabel);
out << std::endl << "Signed int 32" << std::endl;
WriteFreq(out, GetPrevalence(stats_.s32_constant_hist), KeyIsLabel);
out << std::endl << "Unsigned int 64" << std::endl;
WriteFreq(out, GetPrevalence(stats_.u64_constant_hist), KeyIsLabel);
out << std::endl << "Signed int 64" << std::endl;
WriteFreq(out, GetPrevalence(stats_.s64_constant_hist), KeyIsLabel);
}
void StatsAnalyzer::WriteOpcodeMarkov(std::ostream& out) {
if (stats_.opcode_markov_hist.empty()) return;
const std::unordered_map<uint32_t, std::unordered_map<uint32_t, uint32_t>>&
cue_to_hist = stats_.opcode_markov_hist[0];
// Sort by prevalence of the opcodes in opcode_freq_ (descending).
std::vector<std::pair<uint32_t, std::unordered_map<uint32_t, uint32_t>>>
sorted_cue_to_hist(cue_to_hist.begin(), cue_to_hist.end());
std::sort(
sorted_cue_to_hist.begin(), sorted_cue_to_hist.end(),
[this](const std::pair<uint32_t, std::unordered_map<uint32_t, uint32_t>>&
left,
const std::pair<uint32_t, std::unordered_map<uint32_t, uint32_t>>&
right) {
const double lf = opcode_freq_[left.first];
const double rf = opcode_freq_[right.first];
if (lf == rf) return right.first > left.first;
return lf > rf;
});
for (const auto& kv : sorted_cue_to_hist) {
const uint32_t cue = kv.first;
const double kFrequentEnoughToAnalyze = 0.0001;
if (opcode_freq_[cue] < kFrequentEnoughToAnalyze) continue;
const std::unordered_map<uint32_t, uint32_t>& hist = kv.second;
uint32_t total = 0;
for (const auto& pair : hist) {
total += pair.second;
}
std::vector<std::pair<uint32_t, uint32_t>> sorted_hist(hist.begin(),
hist.end());
std::sort(sorted_hist.begin(), sorted_hist.end(),
[](const std::pair<uint32_t, uint32_t>& left,
const std::pair<uint32_t, uint32_t>& right) {
if (left.second == right.second)
return right.first > left.first;
return left.second > right.second;
});
for (const auto& pair : sorted_hist) {
const double prior = opcode_freq_[pair.first];
const double posterior =
static_cast<double>(pair.second) / static_cast<double>(total);
out << GetOpcodeString(cue) << " -> " << GetOpcodeString(pair.first)
<< " " << posterior * 100 << "% (base rate " << prior * 100
<< "%, pair occurrences " << pair.second << ")" << std::endl;
}
}
}
} // namespace stats
} // namespace spvtools

View File

@ -1,58 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TOOLS_STATS_STATS_ANALYZER_H_
#define TOOLS_STATS_STATS_ANALYZER_H_
#include <string>
#include <unordered_map>
#include "tools/stats/spirv_stats.h"
namespace spvtools {
namespace stats {
class StatsAnalyzer {
public:
explicit StatsAnalyzer(const SpirvStats& stats);
// Writes respective histograms to |out|.
void WriteVersion(std::ostream& out);
void WriteGenerator(std::ostream& out);
void WriteCapability(std::ostream& out);
void WriteExtension(std::ostream& out);
void WriteOpcode(std::ostream& out);
void WriteConstantLiterals(std::ostream& out);
// Writes first order Markov analysis to |out|.
// stats_.opcode_markov_hist needs to contain raw data for at least one
// level.
void WriteOpcodeMarkov(std::ostream& out);
private:
const SpirvStats& stats_;
uint32_t num_modules_;
std::unordered_map<uint32_t, double> version_freq_;
std::unordered_map<uint32_t, double> generator_freq_;
std::unordered_map<uint32_t, double> capability_freq_;
std::unordered_map<std::string, double> extension_freq_;
std::unordered_map<uint32_t, double> opcode_freq_;
};
} // namespace stats
} // namespace spvtools
#endif // TOOLS_STATS_STATS_ANALYZER_H_