mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-12-29 03:01:08 +00:00
e9e4393b1c
Visual Studio was complaining about possible loss of data on 64-bit builds, due to an implicit cast from size_t to int. This changes the data to use an int with no cast.
1557 lines
54 KiB
C++
1557 lines
54 KiB
C++
// Copyright (c) 2017 Google Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Contains
|
|
// - SPIR-V to MARK-V encoder
|
|
// - MARK-V to SPIR-V decoder
|
|
//
|
|
// MARK-V is a compression format for SPIR-V binaries. It strips away
|
|
// non-essential information (such as result ids which can be regenerated) and
|
|
// uses various bit reduction techniques to reduce the size of the binary.
|
|
//
|
|
// MarkvModel is a flatbuffers object containing a set of rules defining how
|
|
// compression/decompression is done (coding schemes, dictionaries).
|
|
|
|
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <iostream>
#include <list>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "binary.h"
#include "diagnostic.h"
#include "enum_string_mapping.h"
#include "extensions.h"
#include "ext_inst.h"
#include "instruction.h"
#include "opcode.h"
#include "operand.h"
#include "spirv-tools/libspirv.h"
#include "spirv-tools/markv.h"
#include "spirv_endian.h"
#include "spirv_validator_options.h"
#include "util/bit_stream.h"
#include "util/parse_number.h"
#include "validate.h"
#include "val/instruction.h"
#include "val/validation_state.h"
|
|
|
|
using libspirv::Instruction;
|
|
using libspirv::ValidationState_t;
|
|
using spvtools::ValidateInstructionAndUpdateValidationState;
|
|
using spvutils::BitReaderWord64;
|
|
using spvutils::BitWriterWord64;
|
|
|
|
// Options for the MARK-V encoder. The struct currently carries no settings.
struct spv_markv_encoder_options_t {};
|
|
|
|
// Options for the MARK-V decoder. The struct currently carries no settings.
struct spv_markv_decoder_options_t {};
|
|
|
|
namespace {
|
|
|
|
// Magic number written as the first word of the decoded SPIR-V module.
const uint32_t kSpirvMagicNumber = SpvMagicNumber;

// Magic number stored in (and expected from) the MARK-V header.
const uint32_t kMarkvMagicNumber = 0x07230303;

// Service opcodes which exist only in the MARK-V stream. They start at 65536,
// the first value which does not fit in 16 bits.
enum {
  kMarkvFirstOpcode = 65536,
  // Signals that the following instruction has its result id encoded
  // explicitly.
  kMarkvOpNextInstructionEncodesResultId = kMarkvFirstOpcode,
};

// Number of spaces the encoder inserts between fields in its comment output.
const size_t kCommentNumWhitespaces = 2;
|
|
|
|
// TODO(atgoo@github.com): This is a placeholder for an autogenerated flatbuffer
// containing MARK-V model for a specific dataset.
//
// Provides the coding parameters (chunk lengths and block exponents) which
// the encoder and the decoder pass to the variable width integer routines of
// the bit stream. All values are currently hard-coded.
class MarkvModel {
 public:
  // Instruction level fields.
  size_t opcode_chunk_length() const { return 7; }
  size_t num_operands_chunk_length() const { return 3; }
  size_t id_index_chunk_length() const { return 3; }

  // 16-bit literals.
  size_t u16_chunk_length() const { return 4; }
  size_t s16_chunk_length() const { return 4; }
  size_t s16_block_exponent() const { return 6; }

  // 32-bit literals.
  size_t u32_chunk_length() const { return 8; }
  size_t s32_chunk_length() const { return 8; }
  size_t s32_block_exponent() const { return 10; }

  // 64-bit literals.
  size_t u64_chunk_length() const { return 8; }
  size_t s64_chunk_length() const { return 8; }
  size_t s64_block_exponent() const { return 10; }
};
|
|
|
|
// Returns a pointer to the default model. The model is a function-local
// static, so every call returns the same instance.
const MarkvModel* GetDefaultModel() {
  static MarkvModel default_model;
  return &default_model;
}
|
|
|
|
// Returns chunk length used for variable length encoding of spirv operand
|
|
// words. Returns zero if operand type corresponds to potentially multiple
|
|
// words or a word which is not expected to profit from variable width encoding.
|
|
// Chunk length is selected based on the size of expected value.
|
|
// Most of these values will later be encoded with probability-based coding,
|
|
// but variable width integer coding is a good quick solution.
|
|
// TODO(atgoo@github.com): Put this in MarkvModel flatbuffer.
|
|
size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) {
|
|
switch (type) {
|
|
case SPV_OPERAND_TYPE_TYPE_ID:
|
|
return 4;
|
|
case SPV_OPERAND_TYPE_RESULT_ID:
|
|
case SPV_OPERAND_TYPE_ID:
|
|
case SPV_OPERAND_TYPE_SCOPE_ID:
|
|
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
|
|
return 8;
|
|
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_CAPABILITY:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
|
|
return 3;
|
|
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
|
|
case SPV_OPERAND_TYPE_MEMORY_MODEL:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODE:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_STORAGE_CLASS:
|
|
return 4;
|
|
case SPV_OPERAND_TYPE_DIMENSIONALITY:
|
|
case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
|
|
return 3;
|
|
case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
|
|
case SPV_OPERAND_TYPE_LINKAGE_TYPE:
|
|
case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
|
|
return 3;
|
|
case SPV_OPERAND_TYPE_DECORATION:
|
|
case SPV_OPERAND_TYPE_BUILT_IN:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_GROUP_OPERATION:
|
|
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
|
|
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
|
|
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
|
|
case SPV_OPERAND_TYPE_LOOP_CONTROL:
|
|
case SPV_OPERAND_TYPE_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
|
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
|
|
return 4;
|
|
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER:
|
|
return 6;
|
|
default:
|
|
return 0;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Returns true if the opcode has a fixed number of operands. May return a
|
|
// false negative.
|
|
bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
|
|
switch (opcode) {
|
|
// TODO(atgoo@github.com) This is not a complete list.
|
|
case SpvOpNop:
|
|
case SpvOpName:
|
|
case SpvOpUndef:
|
|
case SpvOpSizeOf:
|
|
case SpvOpLine:
|
|
case SpvOpNoLine:
|
|
case SpvOpDecorationGroup:
|
|
case SpvOpExtension:
|
|
case SpvOpExtInstImport:
|
|
case SpvOpMemoryModel:
|
|
case SpvOpCapability:
|
|
case SpvOpTypeVoid:
|
|
case SpvOpTypeBool:
|
|
case SpvOpTypeInt:
|
|
case SpvOpTypeFloat:
|
|
case SpvOpTypeVector:
|
|
case SpvOpTypeMatrix:
|
|
case SpvOpTypeSampler:
|
|
case SpvOpTypeSampledImage:
|
|
case SpvOpTypeArray:
|
|
case SpvOpTypePointer:
|
|
case SpvOpConstantTrue:
|
|
case SpvOpConstantFalse:
|
|
case SpvOpLabel:
|
|
case SpvOpBranch:
|
|
case SpvOpFunction:
|
|
case SpvOpFunctionParameter:
|
|
case SpvOpFunctionEnd:
|
|
case SpvOpBitcast:
|
|
case SpvOpCopyObject:
|
|
case SpvOpTranspose:
|
|
case SpvOpSNegate:
|
|
case SpvOpFNegate:
|
|
case SpvOpIAdd:
|
|
case SpvOpFAdd:
|
|
case SpvOpISub:
|
|
case SpvOpFSub:
|
|
case SpvOpIMul:
|
|
case SpvOpFMul:
|
|
case SpvOpUDiv:
|
|
case SpvOpSDiv:
|
|
case SpvOpFDiv:
|
|
case SpvOpUMod:
|
|
case SpvOpSRem:
|
|
case SpvOpSMod:
|
|
case SpvOpFRem:
|
|
case SpvOpFMod:
|
|
case SpvOpVectorTimesScalar:
|
|
case SpvOpMatrixTimesScalar:
|
|
case SpvOpVectorTimesMatrix:
|
|
case SpvOpMatrixTimesVector:
|
|
case SpvOpMatrixTimesMatrix:
|
|
case SpvOpOuterProduct:
|
|
case SpvOpDot:
|
|
return true;
|
|
default:
|
|
break;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Returns the number of bits between |bit_pos| and the next byte boundary.
// Returns zero if |bit_pos| already lies on a byte boundary.
size_t GetNumBitsToNextByte(size_t bit_pos) {
  const size_t bits_into_byte = bit_pos % 8;
  return bits_into_byte == 0 ? 0 : 8 - bits_into_byte;
}
|
|
|
|
// Returns true if the stream should be padded to the next byte boundary when
// positioned at |bit_pos|. Currently that is the case whenever |bit_pos| is
// not already byte aligned.
bool ShouldByteBreak(size_t bit_pos) {
  // Number of padding bits a byte break at |bit_pos| would cost.
  const size_t padding_bits = (8 - (bit_pos % 8)) % 8;
  // The original code carried a disabled extra condition which would limit
  // breaks to at most two padding bits; it remains disabled here.
  return padding_bits != 0;
}
|
|
|
|
// Defines and returns the current MARK-V version. The major version occupies
// the bits above bit 15 and the minor version is OR-ed into the low bits.
uint32_t GetMarkvVersion() {
  const uint32_t version_major = 1;
  const uint32_t version_minor = 0;
  return (version_major << 16) | version_minor;
}
|
|
|
|
// Collects human readable comments as one growing piece of text.
// Bit sequences appended back to back are separated by '-'; appending
// anything else (text, whitespace, a new line) cancels the pending separator.
// This produces output like: 1100-0 1110-1-1100-1-1111-0 110-0.
class CommentLogger {
 public:
  // Appends |str| as is.
  void AppendText(const std::string& str) {
    Emit(str);
    pending_delimiter_ = false;
  }

  // Appends |str| followed by a line break.
  void AppendTextNewLine(const std::string& str) {
    Emit(str);
    Emit("\n");
    pending_delimiter_ = false;
  }

  // Appends the bit sequence |str|, preceded by '-' if the previous append
  // was a bit sequence as well.
  void AppendBitSequence(const std::string& str) {
    if (pending_delimiter_) Emit("-");
    Emit(str);
    pending_delimiter_ = true;
  }

  // Appends |num| space characters.
  void AppendWhitespaces(size_t num) {
    Emit(std::string(num, ' '));
    pending_delimiter_ = false;
  }

  // Appends a line break.
  void NewLine() {
    Emit("\n");
    pending_delimiter_ = false;
  }

  // Returns a copy of everything appended so far.
  std::string GetText() const { return text_; }

 private:
  // Single point through which all output is added to the buffer.
  void Emit(const std::string& str) { text_ += str; }

  // Accumulated comment text.
  std::string text_;

  // If true a delimiter will be appended before the next bit sequence.
  bool pending_delimiter_ = false;
};
|
|
|
|
// Creates spv_text object containing text from |str|.
|
|
// The returned value is owned by the caller and needs to be destroyed with
|
|
// spvTextDestroy.
|
|
spv_text CreateSpvText(const std::string& str) {
|
|
spv_text out = new spv_text_t();
|
|
assert(out);
|
|
char* cstr = new char[str.length() + 1];
|
|
assert(cstr);
|
|
std::strncpy(cstr, str.c_str(), str.length());
|
|
cstr[str.length()] = '\0';
|
|
out->str = cstr;
|
|
out->length = str.length();
|
|
return out;
|
|
}
|
|
|
|
// Base class for MARK-V encoder and decoder. Contains common functionality
|
|
// such as:
|
|
// - Validator connection and validation state.
|
|
// - SPIR-V grammar and helper functions.
|
|
class MarkvCodecBase {
|
|
public:
|
|
virtual ~MarkvCodecBase() {
|
|
spvValidatorOptionsDestroy(validator_options_);
|
|
}
|
|
|
|
MarkvCodecBase() = delete;
|
|
|
|
void SetModel(const MarkvModel* model) {
|
|
model_ = model;
|
|
}
|
|
|
|
protected:
|
|
struct MarkvHeader {
|
|
MarkvHeader() {
|
|
magic_number = kMarkvMagicNumber;
|
|
markv_version = GetMarkvVersion();
|
|
markv_model = 0;
|
|
markv_length_in_bits = 0;
|
|
spirv_version = 0;
|
|
spirv_generator = 0;
|
|
}
|
|
|
|
uint32_t magic_number;
|
|
uint32_t markv_version;
|
|
// Magic number to identify or verify MarkvModel used for encoding.
|
|
uint32_t markv_model;
|
|
uint32_t markv_length_in_bits;
|
|
uint32_t spirv_version;
|
|
uint32_t spirv_generator;
|
|
};
|
|
|
|
explicit MarkvCodecBase(spv_const_context context,
|
|
spv_validator_options validator_options)
|
|
: validator_options_(validator_options),
|
|
vstate_(context, validator_options_), grammar_(context),
|
|
model_(GetDefaultModel()) {}
|
|
|
|
// Validates a single instruction and updates validation state of the module.
|
|
spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
|
|
return ValidateInstructionAndUpdateValidationState(&vstate_, &inst);
|
|
}
|
|
|
|
// Returns the current instruction (the one last processed by the validator).
|
|
const Instruction& GetCurrentInstruction() const {
|
|
return vstate_.ordered_instructions().back();
|
|
}
|
|
|
|
spv_validator_options validator_options_;
|
|
ValidationState_t vstate_;
|
|
const libspirv::AssemblyGrammar grammar_;
|
|
MarkvHeader header_;
|
|
const MarkvModel* model_;
|
|
|
|
// Move-to-front list of all ids.
|
|
// TODO(atgoo@github.com) Consider a better move-to-front implementation.
|
|
std::list<uint32_t> move_to_front_ids_;
|
|
};
|
|
|
|
// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
|
|
// EncodeInstruction which can be used as callback by spvBinaryParse.
|
|
// Encoded binary is written to an internally maintained bitstream.
|
|
// After the last instruction is encoded, the resulting MARK-V binary can be
|
|
// acquired by calling GetMarkvBinary().
|
|
// The encoder uses SPIR-V validator to keep internal state, therefore
|
|
// SPIR-V binary needs to be able to pass validator checks.
|
|
// CreateCommentsLogger() can be used to enable the encoder to write comments
|
|
// on how encoding was done, which can later be accessed with GetComments().
|
|
class MarkvEncoder : public MarkvCodecBase {
|
|
public:
|
|
MarkvEncoder(spv_const_context context,
|
|
spv_const_markv_encoder_options options)
|
|
: MarkvCodecBase(context, GetValidatorOptions(options)),
|
|
options_(options) {
|
|
(void) options_;
|
|
}
|
|
|
|
// Writes data from SPIR-V header to MARK-V header.
|
|
spv_result_t EncodeHeader(
|
|
spv_endianness_t /* endian */, uint32_t /* magic */,
|
|
uint32_t version, uint32_t generator, uint32_t id_bound,
|
|
uint32_t /* schema */) {
|
|
vstate_.setIdBound(id_bound);
|
|
header_.spirv_version = version;
|
|
header_.spirv_generator = generator;
|
|
return SPV_SUCCESS;
|
|
}
|
|
|
|
// Encodes SPIR-V instruction to MARK-V and writes to bit stream.
|
|
// Operation can fail if the instruction fails to pass the validator or if
|
|
// the encoder stubmles on something unexpected.
|
|
spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
|
|
|
|
// Concatenates MARK-V header and the bit stream with encoded instructions
|
|
// into a single buffer and returns it as spv_markv_binary. The returned
|
|
// value is owned by the caller and needs to be destroyed with
|
|
// spvMarkvBinaryDestroy().
|
|
spv_markv_binary GetMarkvBinary() {
|
|
header_.markv_length_in_bits =
|
|
static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
|
|
const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
|
|
|
|
spv_markv_binary markv_binary = new spv_markv_binary_t();
|
|
markv_binary->data = new uint8_t[num_bytes];
|
|
markv_binary->length = num_bytes;
|
|
assert(writer_.GetData());
|
|
std::memcpy(markv_binary->data, &header_, sizeof(header_));
|
|
std::memcpy(markv_binary->data + sizeof(header_),
|
|
writer_.GetData(), writer_.GetDataSizeBytes());
|
|
return markv_binary;
|
|
}
|
|
|
|
// Creates an internal logger which writes comments on the encoding process.
|
|
// Output can later be accessed with GetComments().
|
|
void CreateCommentsLogger() {
|
|
logger_.reset(new CommentLogger());
|
|
writer_.SetCallback([this](const std::string& str){
|
|
logger_->AppendBitSequence(str);
|
|
});
|
|
}
|
|
|
|
// Optionally adds disassembly to the comments.
|
|
// Disassembly should contain all instructions in the module separated by
|
|
// \n, and no header.
|
|
void SetDisassembly(std::string&& disassembly) {
|
|
disassembly_.reset(new std::stringstream(std::move(disassembly)));
|
|
}
|
|
|
|
// Extracts the next instruction line from the disassembly and logs it.
|
|
void LogDisassemblyInstruction() {
|
|
if (logger_ && disassembly_) {
|
|
std::string line;
|
|
std::getline(*disassembly_, line, '\n');
|
|
logger_->AppendTextNewLine(line);
|
|
}
|
|
}
|
|
|
|
// Extracts the text from the comment logger.
|
|
std::string GetComments() const {
|
|
if (!logger_)
|
|
return "";
|
|
return logger_->GetText();
|
|
}
|
|
|
|
private:
|
|
// Creates and returns validator options. Return value owned by the caller.
|
|
static spv_validator_options GetValidatorOptions(
|
|
spv_const_markv_encoder_options) {
|
|
return spvValidatorOptionsCreate();
|
|
}
|
|
|
|
// Writes a single word to bit stream. |type| determines if the word is
|
|
// encoded and how.
|
|
void EncodeOperandWord(spv_operand_type_t type, uint32_t word) {
|
|
const size_t chunk_length =
|
|
GetOperandVariableWidthChunkLength(type);
|
|
if (chunk_length) {
|
|
writer_.WriteVariableWidthU32(word, chunk_length);
|
|
} else {
|
|
writer_.WriteUnencoded(word);
|
|
}
|
|
}
|
|
|
|
// Returns id index and updates move-to-front.
|
|
// Index is uint16 as SPIR-V module is guaranteed to have no more than 65535
|
|
// instructions.
|
|
uint16_t GetIdIndex(uint32_t id) {
|
|
if (all_known_ids_.count(id)) {
|
|
uint16_t index = 0;
|
|
for (auto it = move_to_front_ids_.begin();
|
|
it != move_to_front_ids_.end(); ++it) {
|
|
if (*it == id) {
|
|
if (index != 0) {
|
|
move_to_front_ids_.erase(it);
|
|
move_to_front_ids_.push_front(id);
|
|
}
|
|
return index;
|
|
}
|
|
++index;
|
|
}
|
|
assert(0 && "Id not found in move_to_front_ids_");
|
|
return 0;
|
|
} else {
|
|
all_known_ids_.insert(id);
|
|
move_to_front_ids_.push_front(id);
|
|
return static_cast<uint16_t>(move_to_front_ids_.size() - 1);
|
|
}
|
|
}
|
|
|
|
void AddByteBreakIfAgreed() {
|
|
if (!ShouldByteBreak(writer_.GetNumBits()))
|
|
return;
|
|
|
|
if (logger_) {
|
|
logger_->AppendWhitespaces(kCommentNumWhitespaces);
|
|
logger_->AppendText("ByteBreak:");
|
|
}
|
|
|
|
writer_.WriteBits(0, GetNumBitsToNextByte(writer_.GetNumBits()));
|
|
}
|
|
|
|
// Encodes a literal number operand and writes it to the bit stream.
|
|
void EncodeLiteralNumber(const Instruction& instruction,
|
|
const spv_parsed_operand_t& operand);
|
|
|
|
spv_const_markv_encoder_options options_;
|
|
|
|
// Bit stream where encoded instructions are written.
|
|
BitWriterWord64 writer_;
|
|
|
|
// If not nullptr, encoder will write comments.
|
|
std::unique_ptr<CommentLogger> logger_;
|
|
|
|
// If not nullptr, disassembled instruction lines will be written to comments.
|
|
// Format: \n separated instruction lines, no header.
|
|
std::unique_ptr<std::stringstream> disassembly_;
|
|
|
|
// All ids which were previosly encountered in the module.
|
|
std::unordered_set<uint32_t> all_known_ids_;
|
|
};
|
|
|
|
// Decodes MARK-V buffers written by MarkvEncoder.
|
|
class MarkvDecoder : public MarkvCodecBase {
|
|
public:
|
|
MarkvDecoder(spv_const_context context,
|
|
const uint8_t* markv_data,
|
|
size_t markv_size_bytes,
|
|
spv_const_markv_decoder_options options)
|
|
: MarkvCodecBase(context, GetValidatorOptions(options)),
|
|
options_(options), reader_(markv_data, markv_size_bytes) {
|
|
(void) options_;
|
|
vstate_.setIdBound(1);
|
|
parsed_operands_.reserve(25);
|
|
}
|
|
|
|
// Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
|
|
// Can be called only once. Fails if data of wrong format or ends prematurely,
|
|
// of if validation fails.
|
|
spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
|
|
|
|
private:
|
|
// Describes the format of a typed literal number.
|
|
struct NumberType {
|
|
spv_number_kind_t type;
|
|
uint32_t bit_width;
|
|
};
|
|
|
|
// Creates and returns validator options. Return value owned by the caller.
|
|
static spv_validator_options GetValidatorOptions(
|
|
spv_const_markv_decoder_options) {
|
|
return spvValidatorOptionsCreate();
|
|
}
|
|
|
|
// Reads a single word from bit stream. |type| determines if the word needs
|
|
// to be decoded and how. Returns false if read fails.
|
|
bool DecodeOperandWord(spv_operand_type_t type, uint32_t* word) {
|
|
const size_t chunk_length = GetOperandVariableWidthChunkLength(type);
|
|
if (chunk_length) {
|
|
return reader_.ReadVariableWidthU32(word, chunk_length);
|
|
} else {
|
|
return reader_.ReadUnencoded(word);
|
|
}
|
|
}
|
|
|
|
// Fetches the id from the move-to-front list and moves it to front.
|
|
uint32_t GetIdAndMoveToFront(uint16_t index) {
|
|
if (index >= move_to_front_ids_.size()) {
|
|
// Issue new id.
|
|
const uint32_t id = vstate_.getIdBound();
|
|
move_to_front_ids_.push_front(id);
|
|
vstate_.setIdBound(id + 1);
|
|
return id;
|
|
} else {
|
|
if (index == 0)
|
|
return move_to_front_ids_.front();
|
|
|
|
// Iterate to index.
|
|
auto it = move_to_front_ids_.begin();
|
|
for (size_t i = 0; i < index; ++i)
|
|
++it;
|
|
const uint32_t id = *it;
|
|
move_to_front_ids_.erase(it);
|
|
move_to_front_ids_.push_front(id);
|
|
return id;
|
|
}
|
|
}
|
|
|
|
// Decodes id index and fetches the id from move-to-front list.
|
|
bool DecodeId(uint32_t* id) {
|
|
uint16_t index = 0;
|
|
if (!reader_.ReadVariableWidthU16(&index, model_->id_index_chunk_length()))
|
|
return false;
|
|
|
|
*id = GetIdAndMoveToFront(index);
|
|
return true;
|
|
}
|
|
|
|
bool ReadToByteBreakIfAgreed() {
|
|
if (!ShouldByteBreak(reader_.GetNumReadBits()))
|
|
return true;
|
|
|
|
uint64_t bits = 0;
|
|
if (!reader_.ReadBits(&bits,
|
|
GetNumBitsToNextByte(reader_.GetNumReadBits())))
|
|
return false;
|
|
|
|
if (bits != 0)
|
|
return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
// Reads a literal number as it is described in |operand| from the bit stream,
|
|
// decodes and writes it to spirv_.
|
|
spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
|
|
|
|
// Reads instruction from bit stream, decodes and validates it.
|
|
// Decoded instruction is valid until the next call of DecodeInstruction().
|
|
spv_result_t DecodeInstruction(spv_parsed_instruction_t* inst);
|
|
|
|
// Read operand from the stream decodes and validates it.
|
|
spv_result_t DecodeOperand(size_t instruction_offset, size_t operand_offset,
|
|
spv_parsed_instruction_t* inst,
|
|
const spv_operand_type_t type,
|
|
spv_operand_pattern_t* expected_operands,
|
|
bool read_result_id);
|
|
|
|
// Records the numeric type for an operand according to the type information
|
|
// associated with the given non-zero type Id. This can fail if the type Id
|
|
// is not a type Id, or if the type Id does not reference a scalar numeric
|
|
// type. On success, return SPV_SUCCESS and populates the num_words,
|
|
// number_kind, and number_bit_width fields of parsed_operand.
|
|
spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
|
|
uint32_t type_id);
|
|
|
|
// Records the number type for the given instruction, if that
|
|
// instruction generates a type. For types that aren't scalar numbers,
|
|
// record something with number kind SPV_NUMBER_NONE.
|
|
void RecordNumberType(const spv_parsed_instruction_t& inst);
|
|
|
|
spv_const_markv_decoder_options options_;
|
|
|
|
// Temporary sink where decoded SPIR-V words are written. Once it contains the
|
|
// entire module, the container is moved and returned.
|
|
std::vector<uint32_t> spirv_;
|
|
|
|
// Bit stream containing encoded data.
|
|
BitReaderWord64 reader_;
|
|
|
|
// Temporary storage for operands of the currently parsed instruction.
|
|
// Valid until next DecodeInstruction call.
|
|
std::vector<spv_parsed_operand_t> parsed_operands_;
|
|
|
|
// Maps a result ID to its type ID. By convention:
|
|
// - a result ID that is a type definition maps to itself.
|
|
// - a result ID without a type maps to 0. (E.g. for OpLabel)
|
|
std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
|
|
// Maps a type ID to its number type description.
|
|
std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
|
|
// Maps an ExtInstImport id to the extended instruction type.
|
|
std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
|
|
};
|
|
|
|
void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction,
|
|
const spv_parsed_operand_t& operand) {
|
|
if (operand.number_bit_width == 32) {
|
|
const uint32_t word = instruction.word(operand.offset);
|
|
if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
|
|
writer_.WriteVariableWidthU32(word, model_->u32_chunk_length());
|
|
} else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
|
|
int32_t val = 0;
|
|
std::memcpy(&val, &word, 4);
|
|
writer_.WriteVariableWidthS32(val, model_->s32_chunk_length(),
|
|
model_->s32_block_exponent());
|
|
} else if (operand.number_kind == SPV_NUMBER_FLOATING) {
|
|
writer_.WriteUnencoded(word);
|
|
} else {
|
|
assert(0);
|
|
}
|
|
} else if (operand.number_bit_width == 16) {
|
|
const uint16_t word =
|
|
static_cast<uint16_t>(instruction.word(operand.offset));
|
|
if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
|
|
writer_.WriteVariableWidthU16(word, model_->u16_chunk_length());
|
|
} else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
|
|
int16_t val = 0;
|
|
std::memcpy(&val, &word, 2);
|
|
writer_.WriteVariableWidthS16(val, model_->s16_chunk_length(),
|
|
model_->s16_block_exponent());
|
|
} else if (operand.number_kind == SPV_NUMBER_FLOATING) {
|
|
// TODO(atgoo@github.com) Write only 16 bits.
|
|
writer_.WriteUnencoded(word);
|
|
} else {
|
|
assert(0);
|
|
}
|
|
} else {
|
|
assert(operand.number_bit_width == 64);
|
|
const uint64_t word =
|
|
uint64_t(instruction.word(operand.offset)) |
|
|
(uint64_t(instruction.word(operand.offset + 1)) << 32);
|
|
if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
|
|
writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
|
|
} else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
|
|
int64_t val = 0;
|
|
std::memcpy(&val, &word, 8);
|
|
writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
|
|
model_->s64_block_exponent());
|
|
} else if (operand.number_kind == SPV_NUMBER_FLOATING) {
|
|
writer_.WriteUnencoded(word);
|
|
} else {
|
|
assert(0);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Encodes one SPIR-V instruction and appends it to the bit stream.
// The instruction is first run through the validator; a validation failure is
// returned as is and nothing is written. The encoded form consists of:
// - an optional service opcode, if the result id was seen before,
// - the opcode,
// - the operand count, unless the opcode is known to have a fixed count,
// - the operands (the result id is skipped unless it was seen before),
// - padding to the next byte boundary, if ShouldByteBreak() asks for it.
// Returns SPV_ERROR_INVALID_BINARY if a literal string has no terminator.
spv_result_t MarkvEncoder::EncodeInstruction(
    const spv_parsed_instruction_t& inst) {
  const spv_result_t validation_result = UpdateValidationState(inst);
  if (validation_result != SPV_SUCCESS) return validation_result;

  // A result id which is already known was forward declared. Write a service
  // opcode to signal this to the decoder.
  const bool result_id_was_forward_declared =
      all_known_ids_.count(inst.result_id) != 0;
  if (result_id_was_forward_declared) {
    writer_.WriteVariableWidthU32(kMarkvOpNextInstructionEncodesResultId,
                                  model_->opcode_chunk_length());
  }

  const Instruction& instruction = GetCurrentInstruction();
  const auto& operands = instruction.operands();

  LogDisassemblyInstruction();

  // Write opcode.
  writer_.WriteVariableWidthU32(inst.opcode, model_->opcode_chunk_length());

  // If the opcode has a variable number of operands, encode the number of
  // operands with the instruction.
  if (!OpcodeHasFixedNumberOfOperands(SpvOp(inst.opcode))) {
    if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);

    writer_.WriteVariableWidthU16(inst.num_operands,
                                  model_->num_operands_chunk_length());
  }

  // Write operands.
  for (const auto& operand : operands) {
    const bool is_result_id = operand.type == SPV_OPERAND_TYPE_RESULT_ID;
    if (is_result_id && !result_id_was_forward_declared) {
      // Register the id, but don't encode it.
      GetIdIndex(instruction.word(operand.offset));
      continue;
    }

    if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);

    if (operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) {
      EncodeLiteralNumber(instruction, operand);
      continue;
    }

    if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
      const char* src =
          reinterpret_cast<const char*>(&instruction.words()[operand.offset]);
      const size_t max_length = operand.num_words * 4;
      const size_t length = spv_strnlen_s(src, max_length);
      if (length == max_length)
        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
            << "Failed to find terminal character of literal string";

      // Write the characters one by one, terminator included.
      for (size_t i = 0; i <= length; ++i) writer_.WriteUnencoded(src[i]);
      continue;
    }

    if (spvIsIdType(operand.type)) {
      const uint16_t id_index = GetIdIndex(instruction.word(operand.offset));
      writer_.WriteVariableWidthU16(id_index, model_->id_index_chunk_length());
      continue;
    }

    // Everything else is written word by word.
    for (int i = 0; i < operand.num_words; ++i) {
      EncodeOperandWord(operand.type, instruction.word(operand.offset + i));
    }
  }

  AddByteBreakIfAgreed();

  if (logger_) {
    logger_->NewLine();
    logger_->NewLine();
  }

  return SPV_SUCCESS;
}
|
|
|
|
// Reads the typed literal number described by |operand| from the bit stream
// and appends the decoded word(s) to spirv_. The bit width of the operand
// (32, 16, anything else is asserted to be 64) selects the word size, the
// number kind selects the coding: unsigned and signed integers are read with
// variable width coding using the parameters of the model, floating point
// values are read unencoded. 16 and 32 bit values produce one word, 64 bit
// values produce two words (low half first).
// Returns SPV_ERROR_INVALID_BINARY if the value cannot be read.
spv_result_t MarkvDecoder::DecodeLiteralNumber(
    const spv_parsed_operand_t& operand) {
  switch (operand.number_bit_width) {
    case 32: {
      uint32_t word = 0;
      switch (operand.number_kind) {
        case SPV_NUMBER_UNSIGNED_INT:
          if (!reader_.ReadVariableWidthU32(&word, model_->u32_chunk_length()))
            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                << "Failed to read literal U32";
          break;
        case SPV_NUMBER_SIGNED_INT: {
          int32_t signed_value = 0;
          if (!reader_.ReadVariableWidthS32(&signed_value,
                                            model_->s32_chunk_length(),
                                            model_->s32_block_exponent()))
            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                << "Failed to read literal S32";
          std::memcpy(&word, &signed_value, 4);
          break;
        }
        case SPV_NUMBER_FLOATING:
          if (!reader_.ReadUnencoded(&word))
            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                << "Failed to read literal F32";
          break;
        default:
          assert(0);
          break;
      }
      spirv_.push_back(word);
      break;
    }

    case 16: {
      uint32_t word = 0;
      switch (operand.number_kind) {
        case SPV_NUMBER_UNSIGNED_INT: {
          uint16_t unsigned_value = 0;
          if (!reader_.ReadVariableWidthU16(&unsigned_value,
                                            model_->u16_chunk_length()))
            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                << "Failed to read literal U16";
          word = unsigned_value;
          break;
        }
        case SPV_NUMBER_SIGNED_INT: {
          int16_t signed_value = 0;
          if (!reader_.ReadVariableWidthS16(&signed_value,
                                            model_->s16_chunk_length(),
                                            model_->s16_block_exponent()))
            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                << "Failed to read literal S16";
          // Int16 is stored as int32 in SPIR-V, not as bits.
          const int32_t widened_value = signed_value;
          std::memcpy(&word, &widened_value, 4);
          break;
        }
        case SPV_NUMBER_FLOATING: {
          uint16_t half_bits = 0;
          if (!reader_.ReadUnencoded(&half_bits))
            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                << "Failed to read literal F16";
          word = half_bits;
          break;
        }
        default:
          assert(0);
          break;
      }
      spirv_.push_back(word);
      break;
    }

    default: {
      assert(operand.number_bit_width == 64);
      uint64_t word = 0;
      switch (operand.number_kind) {
        case SPV_NUMBER_UNSIGNED_INT:
          if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                << "Failed to read literal U64";
          break;
        case SPV_NUMBER_SIGNED_INT: {
          int64_t signed_value = 0;
          if (!reader_.ReadVariableWidthS64(&signed_value,
                                            model_->s64_chunk_length(),
                                            model_->s64_block_exponent()))
            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                << "Failed to read literal S64";
          std::memcpy(&word, &signed_value, 8);
          break;
        }
        case SPV_NUMBER_FLOATING:
          if (!reader_.ReadUnencoded(&word))
            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                << "Failed to read literal F64";
          break;
        default:
          assert(0);
          break;
      }
      // Low half first, then high half.
      spirv_.push_back(static_cast<uint32_t>(word));
      spirv_.push_back(static_cast<uint32_t>(word >> 32));
      break;
    }
  }
  return SPV_SUCCESS;
}
|
|
|
|
// Decodes the whole MARK-V binary held by the reader into SPIR-V.
// Steps:
// - Read the six header words and check magic number and MARK-V version.
// - Write the SPIR-V header (magic number, version, generator; the id bound
//   is filled in at the end, the fifth word stays zero).
// - Decode and validate instructions until the number of bits stated in the
//   header has been consumed.
// - Check that exactly the stated number of bits was read and that only
//   zeroes are left in the stream.
// On success the decoded words are moved into |spirv_binary| and SPV_SUCCESS
// is returned. Malformed input yields SPV_ERROR_INVALID_BINARY; failures of
// instruction decoding or validation are returned unchanged.
spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
  const bool header_read_success =
      reader_.ReadUnencoded(&header_.magic_number) &&
      reader_.ReadUnencoded(&header_.markv_version) &&
      reader_.ReadUnencoded(&header_.markv_model) &&
      reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
      reader_.ReadUnencoded(&header_.spirv_version) &&
      reader_.ReadUnencoded(&header_.spirv_generator);

  if (!header_read_success)
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Unable to read MARK-V header";

  // The header comes from the input buffer, so its content must not be
  // asserted on: a malformed binary has to produce a diagnostic, not abort a
  // debug build. A wrong magic number is reported right below; a stated
  // length of zero is reported by the bit length check after the loop.
  if (header_.magic_number != kMarkvMagicNumber)
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "MARK-V binary has incorrect magic number";

  // TODO(atgoo@github.com): Print version strings.
  if (header_.markv_version != GetMarkvVersion())
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "MARK-V binary and the codec have different versions";

  // NOTE(review): markv_length_in_bits is taken from the input as is, so a
  // corrupt header can request a very large reservation here.
  spirv_.reserve(header_.markv_length_in_bits / 2);  // Heuristic.
  spirv_.resize(5, 0);
  spirv_[0] = kSpirvMagicNumber;
  spirv_[1] = header_.spirv_version;
  spirv_[2] = header_.spirv_generator;

  while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
    spv_parsed_instruction_t inst = {};
    const spv_result_t decode_result = DecodeInstruction(&inst);
    if (decode_result != SPV_SUCCESS)
      return decode_result;

    const spv_result_t validation_result = UpdateValidationState(inst);
    if (validation_result != SPV_SUCCESS)
      return validation_result;
  }

  if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
      !reader_.OnlyZeroesLeft()) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "MARK-V binary has wrong stated bit length "
        << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
  }

  // Decoding of the module is finished, validation state should have correct
  // id bound.
  spirv_[3] = vstate_.getIdBound();

  *spirv_binary = std::move(spirv_);
  return SPV_SUCCESS;
}
|
|
|
|
// TODO(atgoo@github.com): The implementation borrows heavily from
|
|
// Parser::parseOperand.
|
|
// Consider coupling them together in some way once MARK-V codec is more mature.
|
|
// For now it's better to keep the code independent for experimentation
|
|
// purposes.
|
|
spv_result_t MarkvDecoder::DecodeOperand(
|
|
size_t instruction_offset, size_t operand_offset,
|
|
spv_parsed_instruction_t* inst, const spv_operand_type_t type,
|
|
spv_operand_pattern_t* expected_operands,
|
|
bool read_result_id) {
|
|
const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
|
|
|
|
spv_parsed_operand_t parsed_operand;
|
|
memset(&parsed_operand, 0, sizeof(parsed_operand));
|
|
|
|
assert((operand_offset >> 16) == 0);
|
|
parsed_operand.offset = static_cast<uint16_t>(operand_offset);
|
|
parsed_operand.type = type;
|
|
|
|
// Set default values, may be updated later.
|
|
parsed_operand.number_kind = SPV_NUMBER_NONE;
|
|
parsed_operand.number_bit_width = 0;
|
|
|
|
const size_t first_word_index = spirv_.size();
|
|
|
|
switch (type) {
|
|
case SPV_OPERAND_TYPE_TYPE_ID: {
|
|
if (!DecodeId(&inst->type_id)) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Failed to read type_id";
|
|
}
|
|
|
|
if (inst->type_id == 0)
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded type_id is 0";
|
|
|
|
spirv_.push_back(inst->type_id);
|
|
vstate_.setIdBound(std::max(vstate_.getIdBound(), inst->type_id + 1));
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_RESULT_ID: {
|
|
if (read_result_id) {
|
|
if (!DecodeId(&inst->result_id))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Failed to read result_id";
|
|
} else {
|
|
inst->result_id = vstate_.getIdBound();
|
|
vstate_.setIdBound(inst->result_id + 1);
|
|
move_to_front_ids_.push_front(inst->result_id);
|
|
}
|
|
|
|
spirv_.push_back(inst->result_id);
|
|
|
|
// Save the result ID to type ID mapping.
|
|
// In the grammar, type ID always appears before result ID.
|
|
// A regular value maps to its type. Some instructions (e.g. OpLabel)
|
|
// have no type Id, and will map to 0. The result Id for a
|
|
// type-generating instruction (e.g. OpTypeInt) maps to itself.
|
|
auto insertion_result = id_to_type_id_.emplace(
|
|
inst->result_id,
|
|
spvOpcodeGeneratesType(opcode) ? inst->result_id : inst->type_id);
|
|
if(!insertion_result.second) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_ID)
|
|
<< "Unexpected behavior: id->type_id pair was already registered";
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_ID:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_ID:
|
|
case SPV_OPERAND_TYPE_SCOPE_ID:
|
|
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
|
|
uint32_t id = 0;
|
|
if (!DecodeId(&id))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id";
|
|
|
|
if (id == 0)
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
|
|
|
|
spirv_.push_back(id);
|
|
vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1));
|
|
|
|
if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
|
|
|
|
parsed_operand.type = SPV_OPERAND_TYPE_ID;
|
|
|
|
if (opcode == SpvOpExtInst && parsed_operand.offset == 3) {
|
|
// The current word is the extended instruction set id.
|
|
// Set the extended instruction set type for the current instruction.
|
|
auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
|
|
if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_ID)
|
|
<< "OpExtInst set id " << id
|
|
<< " does not reference an OpExtInstImport result Id";
|
|
}
|
|
inst->ext_inst_type = ext_inst_type_iter->second;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
|
|
uint32_t word = 0;
|
|
if (!DecodeOperandWord(type, &word))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Failed to read enum";
|
|
|
|
spirv_.push_back(word);
|
|
|
|
assert(SpvOpExtInst == opcode);
|
|
assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE);
|
|
spv_ext_inst_desc ext_inst;
|
|
if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid extended instruction number: " << word;
|
|
spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
|
|
// These are regular single-word literal integer operands.
|
|
// Post-parsing validation should check the range of the parsed value.
|
|
parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
|
|
// It turns out they are always unsigned integers!
|
|
parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT;
|
|
parsed_operand.number_bit_width = 32;
|
|
|
|
uint32_t word = 0;
|
|
if (!DecodeOperandWord(type, &word))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Failed to read literal integer";
|
|
|
|
spirv_.push_back(word);
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER:
|
|
parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
|
|
if (opcode == SpvOpSwitch) {
|
|
// The literal operands have the same type as the value
|
|
// referenced by the selector Id.
|
|
const uint32_t selector_id = spirv_.at(instruction_offset + 1);
|
|
const auto type_id_iter = id_to_type_id_.find(selector_id);
|
|
if (type_id_iter == id_to_type_id_.end() ||
|
|
type_id_iter->second == 0) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid OpSwitch: selector id " << selector_id
|
|
<< " has no type";
|
|
}
|
|
uint32_t type_id = type_id_iter->second;
|
|
|
|
if (selector_id == type_id) {
|
|
// Recall that by convention, a result ID that is a type definition
|
|
// maps to itself.
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid OpSwitch: selector id " << selector_id
|
|
<< " is a type, not a value";
|
|
}
|
|
if (auto error = SetNumericTypeInfoForType(&parsed_operand, type_id))
|
|
return error;
|
|
if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT &&
|
|
parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid OpSwitch: selector id " << selector_id
|
|
<< " is not a scalar integer";
|
|
}
|
|
} else {
|
|
assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
|
|
// The literal number type is determined by the type Id for the
|
|
// constant.
|
|
assert(inst->type_id);
|
|
if (auto error =
|
|
SetNumericTypeInfoForType(&parsed_operand, inst->type_id))
|
|
return error;
|
|
}
|
|
|
|
if (auto error = DecodeLiteralNumber(parsed_operand))
|
|
return error;
|
|
|
|
break;
|
|
|
|
case SPV_OPERAND_TYPE_LITERAL_STRING:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
|
|
parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING;
|
|
std::vector<char> str;
|
|
// The loop is expected to terminate once we encounter '\0' or exhaust
|
|
// the bit stream.
|
|
while (true) {
|
|
char ch = 0;
|
|
if (!reader_.ReadUnencoded(&ch))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Failed to read literal string";
|
|
|
|
str.push_back(ch);
|
|
|
|
if (ch == '\0')
|
|
break;
|
|
}
|
|
|
|
while (str.size() % 4 != 0)
|
|
str.push_back('\0');
|
|
|
|
spirv_.resize(spirv_.size() + str.size() / 4);
|
|
std::memcpy(&spirv_[first_word_index], str.data(), str.size());
|
|
|
|
if (SpvOpExtInstImport == opcode) {
|
|
// Record the extended instruction type for the ID for this import.
|
|
// There is only one string literal argument to OpExtInstImport,
|
|
// so it's sufficient to guard this just on the opcode.
|
|
const spv_ext_inst_type_t ext_inst_type =
|
|
spvExtInstImportTypeGet(str.data());
|
|
if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid extended instruction import '" << str.data() << "'";
|
|
}
|
|
// We must have parsed a valid result ID. It's a condition
|
|
// of the grammar, and we only accept non-zero result Ids.
|
|
assert(inst->result_id);
|
|
const bool inserted = import_id_to_ext_inst_type_.emplace(
|
|
inst->result_id, ext_inst_type).second;
|
|
(void)inserted;
|
|
assert(inserted);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_CAPABILITY:
|
|
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
|
|
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
|
|
case SPV_OPERAND_TYPE_MEMORY_MODEL:
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODE:
|
|
case SPV_OPERAND_TYPE_STORAGE_CLASS:
|
|
case SPV_OPERAND_TYPE_DIMENSIONALITY:
|
|
case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
|
|
case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
|
|
case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
|
|
case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
|
|
case SPV_OPERAND_TYPE_LINKAGE_TYPE:
|
|
case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
|
|
case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
|
|
case SPV_OPERAND_TYPE_DECORATION:
|
|
case SPV_OPERAND_TYPE_BUILT_IN:
|
|
case SPV_OPERAND_TYPE_GROUP_OPERATION:
|
|
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
|
|
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
|
|
// A single word that is a plain enum value.
|
|
uint32_t word = 0;
|
|
if (!DecodeOperandWord(type, &word))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Failed to read enum";
|
|
|
|
spirv_.push_back(word);
|
|
|
|
// Map an optional operand type to its corresponding concrete type.
|
|
if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
|
|
parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
|
|
|
|
spv_operand_desc entry;
|
|
if (grammar_.lookupOperand(type, word, &entry)) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid "
|
|
<< spvOperandTypeStr(parsed_operand.type)
|
|
<< " operand: " << word;
|
|
}
|
|
|
|
// Prepare to accept operands to this operand, if needed.
|
|
spvPushOperandTypes(entry->operandTypes, expected_operands);
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
|
|
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
|
|
case SPV_OPERAND_TYPE_LOOP_CONTROL:
|
|
case SPV_OPERAND_TYPE_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
|
case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
|
|
// This operand is a mask.
|
|
uint32_t word = 0;
|
|
if (!DecodeOperandWord(type, &word))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Failed to read " << spvOperandTypeStr(type)
|
|
<< " for " << spvOpcodeString(SpvOp(inst->opcode));
|
|
|
|
spirv_.push_back(word);
|
|
|
|
// Map an optional operand type to its corresponding concrete type.
|
|
if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
|
|
parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
|
|
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
|
|
parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
|
|
|
|
// Check validity of set mask bits. Also prepare for operands for those
|
|
// masks if they have any. To get operand order correct, scan from
|
|
// MSB to LSB since we can only prepend operands to a pattern.
|
|
// The only case in the grammar where you have more than one mask bit
|
|
// having an operand is for image operands. See SPIR-V 3.14 Image
|
|
// Operands.
|
|
uint32_t remaining_word = word;
|
|
for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
|
|
if (remaining_word & mask) {
|
|
spv_operand_desc entry;
|
|
if (grammar_.lookupOperand(type, mask, &entry)) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid " << spvOperandTypeStr(parsed_operand.type)
|
|
<< " operand: " << word << " has invalid mask component "
|
|
<< mask;
|
|
}
|
|
remaining_word ^= mask;
|
|
spvPushOperandTypes(entry->operandTypes, expected_operands);
|
|
}
|
|
}
|
|
if (word == 0) {
|
|
// An all-zeroes mask *might* also be valid.
|
|
spv_operand_desc entry;
|
|
if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
|
|
// Prepare for its operands, if any.
|
|
spvPushOperandTypes(entry->operandTypes, expected_operands);
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
default:
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Internal error: Unhandled operand type: " << type;
|
|
}
|
|
|
|
parsed_operand.num_words = uint16_t(spirv_.size() - first_word_index);
|
|
|
|
assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(parsed_operand.type));
|
|
assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(parsed_operand.type));
|
|
|
|
parsed_operands_.push_back(parsed_operand);
|
|
|
|
return SPV_SUCCESS;
|
|
}
|
|
|
|
// Decodes one instruction from the MARK-V stream and appends its SPIR-V words
// to spirv_.
//
// The opcode is read first; the MARK-V pseudo-opcode
// kMarkvOpNextInstructionEncodesResultId may precede it and means that the
// result id of the instruction is stored explicitly in the stream. The
// operands are then decoded one by one following the grammar's operand
// pattern, and the leading opcode/word-count word is patched in at the end.
//
// On success |inst| describes the decoded instruction. inst->words and
// inst->operands point into spirv_ and parsed_operands_ and are only valid
// until those containers change.
//
// Returns SPV_SUCCESS, or an error code after emitting a diagnostic.
spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) {
  parsed_operands_.clear();
  const size_t instruction_offset = spirv_.size();

  // Word counts and operand offsets are stored in 16-bit fields.
  const size_t kMaxInstructionWords = 0xFFFF;

  bool read_result_id = false;

  while (true) {
    uint32_t word = 0;
    if (!reader_.ReadVariableWidthU32(&word,
                                      model_->opcode_chunk_length())) {
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
          << "Failed to read opcode of instruction";
    }

    if (word >= kMarkvFirstOpcode) {
      // MARK-V specific pseudo-opcode: consume it and keep reading.
      if (word == kMarkvOpNextInstructionEncodesResultId) {
        read_result_id = true;
      } else {
        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
            << "Encountered unknown MARK-V opcode";
      }
    } else {
      inst->opcode = static_cast<uint16_t>(word);
      break;
    }
  }

  const SpvOp opcode = static_cast<SpvOp>(inst->opcode);

  // Opcode/num_words placeholder, the word will be filled in later.
  spirv_.push_back(0);

  spv_opcode_desc opcode_desc;
  if (grammar_.lookupOpcode(opcode, &opcode_desc)
      != SPV_SUCCESS) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
  }

  // The pattern is consumed from the back, so store the grammar's operand
  // types in reverse order.
  spv_operand_pattern_t expected_operands;
  expected_operands.reserve(opcode_desc->numTypes);
  for (size_t i = opcode_desc->numTypes; i > 0; --i)
    expected_operands.push_back(opcode_desc->operandTypes[i - 1]);

  if (!OpcodeHasFixedNumberOfOperands(opcode)) {
    if (!reader_.ReadVariableWidthU16(&inst->num_operands,
                                      model_->num_operands_chunk_length()))
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
          << "Failed to read num_operands of instruction";
  } else {
    inst->num_operands = static_cast<uint16_t>(expected_operands.size());
  }

  for (size_t operand_index = 0;
       operand_index < static_cast<size_t>(inst->num_operands);
       ++operand_index) {
    // num_operands may come from the stream, so an exhausted pattern is an
    // input error, not a programming error: taking an operand from an empty
    // pattern must never be attempted.
    if (expected_operands.empty())
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
          << "Instruction has more operands than expected";

    const spv_operand_type_t type =
        spvTakeFirstMatchableOperand(&expected_operands);

    const size_t operand_offset = spirv_.size() - instruction_offset;
    if (operand_offset > kMaxInstructionWords)
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
          << "Instruction is too long";

    const spv_result_t decode_result =
        DecodeOperand(instruction_offset, operand_offset, inst, type,
                      &expected_operands, read_result_id);

    if (decode_result != SPV_SUCCESS)
      return decode_result;
  }

  assert(inst->num_operands == parsed_operands_.size());

  // Reject instructions whose word count would be truncated when stored in
  // the 16-bit word-count field.
  const size_t total_words = spirv_.size() - instruction_offset;
  if (total_words > kMaxInstructionWords)
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Instruction is too long";

  // Only valid while spirv_ and parsed_operands_ remain unchanged.
  inst->words = &spirv_[instruction_offset];
  inst->operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
  inst->num_words = static_cast<uint16_t>(total_words);
  spirv_[instruction_offset] =
      spvOpcodeMake(inst->num_words, SpvOp(inst->opcode));

  assert(inst->num_words == std::accumulate(
      parsed_operands_.begin(), parsed_operands_.end(), 1,
      [](int num_words, const spv_parsed_operand_t& operand) {
        return num_words += operand.num_words;
      }) && "num_words in instruction doesn't correspond to the sum of num_words"
            "in the operands");

  // Remember numeric type information if this instruction declares a type.
  RecordNumberType(*inst);

  if (!ReadToByteBreakIfAgreed())
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Failed to read to byte break";

  return SPV_SUCCESS;
}
|
|
|
|
// Fills the numeric fields of |parsed_operand| (number kind, bit width and
// word count) from the number type recorded for |type_id|.
//
// |type_id| must be non-zero. Returns SPV_ERROR_INVALID_BINARY, after
// emitting a diagnostic, when |type_id| has no recorded type information or
// when the recorded type is not a scalar number.
spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
    spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
  assert(type_id != 0);

  const auto found = type_id_to_number_type_info_.find(type_id);
  if (found == type_id_to_number_type_info_.end()) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Type Id " << type_id << " is not a type";
  }

  const NumberType& number_type = found->second;
  if (number_type.type == SPV_NUMBER_NONE) {
    // Known type, but it does not describe a scalar number.
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Type Id " << type_id << " is not a scalar numeric type";
  }

  parsed_operand->number_kind = number_type.type;
  parsed_operand->number_bit_width = number_type.bit_width;
  // Number of 32-bit words needed to hold bit_width bits, rounded up.
  parsed_operand->num_words =
      static_cast<uint16_t>((number_type.bit_width + 31) / 32);
  return SPV_SUCCESS;
}
|
|
|
|
void MarkvDecoder::RecordNumberType(const spv_parsed_instruction_t& inst) {
|
|
const SpvOp opcode = static_cast<SpvOp>(inst.opcode);
|
|
if (spvOpcodeGeneratesType(opcode)) {
|
|
NumberType info = {SPV_NUMBER_NONE, 0};
|
|
if (SpvOpTypeInt == opcode) {
|
|
info.bit_width = inst.words[inst.operands[1].offset];
|
|
info.type = inst.words[inst.operands[2].offset] ?
|
|
SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
|
|
} else if (SpvOpTypeFloat == opcode) {
|
|
info.bit_width = inst.words[inst.operands[1].offset];
|
|
info.type = SPV_NUMBER_FLOATING;
|
|
}
|
|
// The *result* Id of a type generating instruction is the type Id.
|
|
type_id_to_number_type_info_[inst.result_id] = info;
|
|
}
|
|
}
|
|
|
|
// Header callback handed to spvBinaryParse. |user_data| is the MarkvEncoder
// doing the encoding; the call is forwarded to its EncodeHeader method.
spv_result_t EncodeHeader(
    void* user_data, spv_endianness_t endian, uint32_t magic,
    uint32_t version, uint32_t generator, uint32_t id_bound,
    uint32_t schema) {
  MarkvEncoder* const markv_encoder = static_cast<MarkvEncoder*>(user_data);
  const spv_result_t status = markv_encoder->EncodeHeader(
      endian, magic, version, generator, id_bound, schema);
  return status;
}
|
|
|
|
// Instruction callback handed to spvBinaryParse. |user_data| is the
// MarkvEncoder doing the encoding; the call is forwarded to its
// EncodeInstruction method.
spv_result_t EncodeInstruction(
    void* user_data, const spv_parsed_instruction_t* inst) {
  MarkvEncoder* const markv_encoder = static_cast<MarkvEncoder*>(user_data);
  const spv_result_t status = markv_encoder->EncodeInstruction(*inst);
  return status;
}
|
|
|
|
} // namespace
|
|
|
|
// Encodes a SPIR-V binary into MARK-V.
//
//   context:         library context; a private copy is used so that the
//                    message consumer can be redirected into |diagnostic|.
//   spirv_words:     the SPIR-V module, |spirv_num_words| words long.
//   options:         encoder options handed to the MarkvEncoder.
//   markv_binary:    receives the encoded binary on success.
//   comments:        optional; when non-null the module is also disassembled
//                    and the encoder's comments text is returned here.
//   diagnostic:      optional; reset to null and then used to collect
//                    messages.
//
// Returns SPV_SUCCESS, or SPV_ERROR_INVALID_BINARY after emitting a message.
spv_result_t spvSpirvToMarkv(spv_const_context context,
                             const uint32_t* spirv_words,
                             const size_t spirv_num_words,
                             spv_const_markv_encoder_options options,
                             spv_markv_binary* markv_binary,
                             spv_text* comments, spv_diagnostic* diagnostic) {
  spv_context_t hijack_context = *context;
  if (diagnostic) {
    *diagnostic = nullptr;
    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
  }

  spv_position_t position = {};

  // Emits |message| through the context's consumer and yields the error
  // code shared by all failure paths of this function.
  const auto fail = [&position,
                     &hijack_context](const char* message) -> spv_result_t {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
           << message;
  };

  spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words};

  spv_endianness_t endian;
  if (spvBinaryEndianness(&spirv_binary, &endian) != SPV_SUCCESS) {
    return fail("Invalid SPIR-V magic number.");
  }

  spv_header_t header;
  if (spvBinaryHeaderGet(&spirv_binary, endian, &header) != SPV_SUCCESS) {
    return fail("Invalid SPIR-V header.");
  }

  MarkvEncoder encoder(&hijack_context, options);

  const bool want_comments = comments != nullptr;
  if (want_comments) {
    encoder.CreateCommentsLogger();

    // Give the encoder the disassembly of the module.
    spv_text text = nullptr;
    const spv_result_t disassembly_status =
        spvBinaryToText(&hijack_context, spirv_words, spirv_num_words,
                        SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr);
    if (disassembly_status != SPV_SUCCESS) {
      return fail("Failed to disassemble SPIR-V binary.");
    }
    assert(text);
    encoder.SetDisassembly(std::string(text->str, text->length));
    spvTextDestroy(text);
  }

  // The parser drives the encoder through the two callbacks.
  const spv_result_t parse_status = spvBinaryParse(
      &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader,
      EncodeInstruction, diagnostic);
  if (parse_status != SPV_SUCCESS) {
    return fail("Unable to encode to MARK-V.");
  }

  if (want_comments) {
    *comments = CreateSpvText(encoder.GetComments());
  }

  *markv_binary = encoder.GetMarkvBinary();
  return SPV_SUCCESS;
}
|
|
|
|
// Decodes a MARK-V binary back into SPIR-V.
//
//   context:          library context; a private copy is used so that the
//                     message consumer can be redirected into |diagnostic|.
//   markv_data:       the MARK-V stream, |markv_size_bytes| bytes long.
//   options:          decoder options handed to the MarkvDecoder.
//   spirv_binary:     receives a newly allocated binary holding the decoded
//                     words on success.
//   (comments):       unused.
//   diagnostic:       optional; reset to null and then used to collect
//                     messages.
//
// Returns SPV_SUCCESS, or SPV_ERROR_INVALID_BINARY after emitting a message.
spv_result_t spvMarkvToSpirv(spv_const_context context,
                             const uint8_t* markv_data,
                             size_t markv_size_bytes,
                             spv_const_markv_decoder_options options,
                             spv_binary* spirv_binary,
                             spv_text* /* comments */, spv_diagnostic* diagnostic) {
  spv_position_t position = {};
  spv_context_t hijack_context = *context;
  if (diagnostic) {
    *diagnostic = nullptr;
    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
  }

  MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options);

  std::vector<uint32_t> words;

  const spv_result_t decode_status = decoder.DecodeModule(&words);
  if (decode_status != SPV_SUCCESS) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
        << "Unable to decode MARK-V.";
  }

  assert(!words.empty());

  // Copy the decoded words into a heap-allocated binary for the caller.
  const size_t word_count = words.size();
  *spirv_binary = new spv_binary_t();
  (*spirv_binary)->code = new uint32_t[word_count];
  (*spirv_binary)->wordCount = word_count;
  std::memcpy((*spirv_binary)->code, words.data(),
              word_count * sizeof(uint32_t));

  return SPV_SUCCESS;
}
|
|
|
|
// Releases a MARK-V binary: frees its data array and then the binary object
// itself. A null |binary| is ignored.
void spvMarkvBinaryDestroy(spv_markv_binary binary) {
  if (binary != nullptr) {
    delete[] binary->data;
    delete binary;
  }
}
|
|
|
|
// Allocates a new encoder options object. Release it with
// spvMarkvEncoderOptionsDestroy.
spv_markv_encoder_options spvMarkvEncoderOptionsCreate() {
  spv_markv_encoder_options created = new spv_markv_encoder_options_t;
  return created;
}
|
|
|
|
// Frees an encoder options object. Deleting a null pointer is a no-op, so a
// null |options| is accepted.
void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) {
  delete options;
}
|
|
|
|
// Allocates a new decoder options object. Release it with
// spvMarkvDecoderOptionsDestroy.
spv_markv_decoder_options spvMarkvDecoderOptionsCreate() {
  spv_markv_decoder_options created = new spv_markv_decoder_options_t;
  return created;
}
|
|
|
|
// Frees a decoder options object. Deleting a null pointer is a no-op, so a
// null |options| is accepted.
void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) {
  delete options;
}
|