SPIRV-Tools/source/comp/markv_codec.cpp
Lenny Komow e9e4393b1c Fix Visual Studio size_t cast compiler warning
Visual Studio was complaining about possible loss of data on 64-bit
builds, due to an implicit cast from size_t to int. This changes the
data to use an int with no cast.
2017-07-13 13:02:43 -06:00

1557 lines
54 KiB
C++

// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Contains
// - SPIR-V to MARK-V encoder
// - MARK-V to SPIR-V decoder
//
// MARK-V is a compression format for SPIR-V binaries. It strips away
// non-essential information (such as result ids which can be regenerated) and
// uses various bit reduction techniques to reduce the size of the binary.
//
// MarkvModel is a flatbuffers object containing a set of rules defining how
// compression/decompression is done (coding schemes, dictionaries).
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <iostream>
#include <iterator>
#include <list>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "binary.h"
#include "diagnostic.h"
#include "enum_string_mapping.h"
#include "extensions.h"
#include "ext_inst.h"
#include "instruction.h"
#include "opcode.h"
#include "operand.h"
#include "spirv-tools/libspirv.h"
#include "spirv-tools/markv.h"
#include "spirv_endian.h"
#include "spirv_validator_options.h"
#include "util/bit_stream.h"
#include "util/parse_number.h"
#include "validate.h"
#include "val/instruction.h"
#include "val/validation_state.h"
using libspirv::Instruction;
using libspirv::ValidationState_t;
using spvtools::ValidateInstructionAndUpdateValidationState;
using spvutils::BitReaderWord64;
using spvutils::BitWriterWord64;
// Options accepted by the MARK-V encoder. No settings are defined yet.
struct spv_markv_encoder_options_t {};

// Options accepted by the MARK-V decoder. No settings are defined yet.
struct spv_markv_decoder_options_t {};
namespace {
// Magic number placed in the header of the SPIR-V module rebuilt by the
// decoder.
constexpr uint32_t kSpirvMagicNumber = SpvMagicNumber;

// Magic number that starts every MARK-V header.
constexpr uint32_t kMarkvMagicNumber = 0x07230303;

// MARK-V service opcodes. They start at 65536, outside the 16-bit SPIR-V
// opcode range, so they never collide with a real opcode.
enum {
  kMarkvFirstOpcode = 65536,
  // Written before an instruction whose result id was already referenced
  // (forward declared), telling the decoder to read the result id explicitly.
  kMarkvOpNextInstructionEncodesResultId = kMarkvFirstOpcode,
};

// Number of spaces inserted between entries of the encoder comment log.
constexpr size_t kCommentNumWhitespaces = 2;
// TODO(atgoo@github.com): This is a placeholder for an autogenerated flatbuffer
// containing MARK-V model for a specific dataset.
//
// For now every accessor returns a hard-coded constant.
class MarkvModel {
 public:
  // Chunk lengths for variable width encoding of instruction structure.
  size_t opcode_chunk_length() const { return 7; }
  size_t num_operands_chunk_length() const { return 3; }
  size_t id_index_chunk_length() const { return 3; }

  // Chunk lengths for variable width encoding of literal numbers.
  size_t u16_chunk_length() const { return 4; }
  size_t s16_chunk_length() const { return 4; }
  size_t u32_chunk_length() const { return 8; }
  size_t s32_chunk_length() const { return 8; }
  size_t u64_chunk_length() const { return 8; }
  size_t s64_chunk_length() const { return 8; }

  // Block exponents used together with the signed chunk lengths above.
  size_t s16_block_exponent() const { return 6; }
  size_t s32_block_exponent() const { return 10; }
  size_t s64_block_exponent() const { return 10; }
};
// Returns the shared default model. It is a function-local static, created
// on first call; callers receive a non-owning pointer.
const MarkvModel* GetDefaultModel() {
  static MarkvModel default_model;
  const MarkvModel* const result = &default_model;
  return result;
}
// Returns the chunk length used for variable width encoding of a SPIR-V
// operand word of the given |type|.
// Returns zero for operand types that may span multiple words, and for words
// that are not expected to profit from variable width encoding; such words are
// written unencoded.
// The chunk length reflects the expected magnitude of the value. Most of these
// values will later be encoded with probability-based coding; variable width
// integer coding is a quick interim solution.
// TODO(atgoo@github.com): Put this in MarkvModel flatbuffer.
size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) {
  switch (type) {
    // Ids.
    case SPV_OPERAND_TYPE_RESULT_ID:
    case SPV_OPERAND_TYPE_ID:
    case SPV_OPERAND_TYPE_SCOPE_ID:
    case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
      return 8;

    case SPV_OPERAND_TYPE_LITERAL_INTEGER:
    case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER:
    case SPV_OPERAND_TYPE_CAPABILITY:
    case SPV_OPERAND_TYPE_EXECUTION_MODE:
    case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
    case SPV_OPERAND_TYPE_DECORATION:
    case SPV_OPERAND_TYPE_BUILT_IN:
    case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER:
      return 6;

    case SPV_OPERAND_TYPE_TYPE_ID:
    case SPV_OPERAND_TYPE_STORAGE_CLASS:
    case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
    case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
    case SPV_OPERAND_TYPE_LOOP_CONTROL:
    case SPV_OPERAND_TYPE_IMAGE:
    case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
    case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
    case SPV_OPERAND_TYPE_SELECTION_CONTROL:
      return 4;

    case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
    case SPV_OPERAND_TYPE_EXECUTION_MODEL:
    case SPV_OPERAND_TYPE_DIMENSIONALITY:
    case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
    case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
      return 3;

    case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
    case SPV_OPERAND_TYPE_MEMORY_MODEL:
    case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
    case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
    case SPV_OPERAND_TYPE_LINKAGE_TYPE:
    case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
    case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
    case SPV_OPERAND_TYPE_GROUP_OPERATION:
    case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
    case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
      return 2;

    default:
      // Not variable-width encoded.
      return 0;
  }
}
// Returns true if |opcode| is known to always take the same number of
// operands, in which case the encoder does not need to store the operand
// count. The list is incomplete, so false negatives are possible.
bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
  switch (opcode) {
    // TODO(atgoo@github.com) This is not a complete list.

    // Miscellaneous and debug instructions.
    case SpvOpNop:
    case SpvOpUndef:
    case SpvOpSizeOf:
    case SpvOpName:
    case SpvOpLine:
    case SpvOpNoLine:

    // Annotation, extension and mode-setting instructions.
    case SpvOpDecorationGroup:
    case SpvOpExtension:
    case SpvOpExtInstImport:
    case SpvOpMemoryModel:
    case SpvOpCapability:

    // Type declarations.
    case SpvOpTypeVoid:
    case SpvOpTypeBool:
    case SpvOpTypeInt:
    case SpvOpTypeFloat:
    case SpvOpTypeVector:
    case SpvOpTypeMatrix:
    case SpvOpTypeSampler:
    case SpvOpTypeSampledImage:
    case SpvOpTypeArray:
    case SpvOpTypePointer:

    // Constants.
    case SpvOpConstantTrue:
    case SpvOpConstantFalse:

    // Control flow and function structure.
    case SpvOpLabel:
    case SpvOpBranch:
    case SpvOpFunction:
    case SpvOpFunctionParameter:
    case SpvOpFunctionEnd:

    // Conversion and composite instructions.
    case SpvOpBitcast:
    case SpvOpCopyObject:
    case SpvOpTranspose:

    // Arithmetic instructions.
    case SpvOpSNegate:
    case SpvOpFNegate:
    case SpvOpIAdd:
    case SpvOpFAdd:
    case SpvOpISub:
    case SpvOpFSub:
    case SpvOpIMul:
    case SpvOpFMul:
    case SpvOpUDiv:
    case SpvOpSDiv:
    case SpvOpFDiv:
    case SpvOpUMod:
    case SpvOpSRem:
    case SpvOpSMod:
    case SpvOpFRem:
    case SpvOpFMod:
    case SpvOpVectorTimesScalar:
    case SpvOpMatrixTimesScalar:
    case SpvOpVectorTimesMatrix:
    case SpvOpMatrixTimesVector:
    case SpvOpMatrixTimesMatrix:
    case SpvOpOuterProduct:
    case SpvOpDot:
      return true;

    default:
      return false;
  }
}
// Returns the number of padding bits needed to move |bit_pos| forward to the
// next byte boundary; 0 if |bit_pos| is already byte-aligned.
size_t GetNumBitsToNextByte(size_t bit_pos) {
  const size_t bits_into_current_byte = bit_pos % 8;
  return bits_into_current_byte == 0 ? 0 : 8 - bits_into_current_byte;
}

// Returns true when |bit_pos| is not byte-aligned, i.e. when a byte break
// (zero padding up to the next byte boundary) should be inserted.
bool ShouldByteBreak(size_t bit_pos) {
  return GetNumBitsToNextByte(bit_pos) != 0;
}
// Returns the current MARK-V format version: the major version occupies the
// upper 16 bits and the minor version the lower 16 bits.
uint32_t GetMarkvVersion() {
  const uint32_t major = 1;
  const uint32_t minor = 0;
  return (major << 16) | minor;
}
// Accumulates human-readable comments describing how the encoder wrote the
// bit stream. Bit sequences appended back-to-back are separated with '-';
// any other append resets that so the next bit sequence starts a new group.
class CommentLogger {
 public:
  // Appends |str| verbatim.
  void AppendText(const std::string& str) {
    Append(str);
    use_delimiter_ = false;
  }

  // Appends |str| followed by a newline.
  void AppendTextNewLine(const std::string& str) {
    Append(str);
    Append("\n");
    use_delimiter_ = false;
  }

  // Appends a bit sequence, prefixed with '-' if the previous append was also
  // a bit sequence.
  void AppendBitSequence(const std::string& str) {
    if (use_delimiter_)
      Append("-");
    Append(str);
    use_delimiter_ = true;
  }

  // Appends |num| spaces.
  void AppendWhitespaces(size_t num) {
    Append(std::string(num, ' '));
    use_delimiter_ = false;
  }

  // Appends a newline.
  void NewLine() {
    Append("\n");
    use_delimiter_ = false;
  }

  // Returns everything appended so far.
  std::string GetText() const {
    return ss_.str();
  }

 private:
  void Append(const std::string& str) {
    ss_ << str;
  }

  // Output-only buffer: the logger never reads back from the stream.
  std::ostringstream ss_;

  // If true a delimiter will be appended before the next bit sequence.
  // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0.
  bool use_delimiter_ = false;
};
// Builds a heap-allocated spv_text holding a NUL-terminated copy of |str|.
// Ownership passes to the caller, who must release it with spvTextDestroy.
spv_text CreateSpvText(const std::string& str) {
  const size_t num_chars = str.length();

  spv_text text = new spv_text_t();
  char* buffer = new char[num_chars + 1];
  std::strncpy(buffer, str.c_str(), num_chars);
  buffer[num_chars] = '\0';

  text->str = buffer;
  text->length = num_chars;
  return text;
}
// Base class for MARK-V encoder and decoder. Contains common functionality
// such as:
// - Validator connection and validation state.
// - SPIR-V grammar and helper functions.
//
// The codec takes ownership of the validator options passed to its
// constructor and destroys them in its destructor. A copy would destroy the
// same options a second time, so copy construction and copy assignment are
// deleted.
class MarkvCodecBase {
 public:
  virtual ~MarkvCodecBase() {
    spvValidatorOptionsDestroy(validator_options_);
  }

  MarkvCodecBase() = delete;
  MarkvCodecBase(const MarkvCodecBase&) = delete;
  MarkvCodecBase& operator=(const MarkvCodecBase&) = delete;

  // Replaces the model used by the codec. |model| is stored, not copied, so
  // it must stay valid while the codec uses it.
  void SetModel(const MarkvModel* model) {
    model_ = model;
  }

 protected:
  // Header stored at the beginning of a MARK-V binary.
  struct MarkvHeader {
    MarkvHeader() {
      magic_number = kMarkvMagicNumber;
      markv_version = GetMarkvVersion();
      markv_model = 0;
      markv_length_in_bits = 0;
      spirv_version = 0;
      spirv_generator = 0;
    }

    uint32_t magic_number;
    uint32_t markv_version;
    // Magic number to identify or verify MarkvModel used for encoding.
    uint32_t markv_model;
    uint32_t markv_length_in_bits;
    uint32_t spirv_version;
    uint32_t spirv_generator;
  };

  // The header is copied byte-wise into the MARK-V binary, so its layout must
  // be exactly its six 32-bit fields with no padding.
  static_assert(sizeof(MarkvHeader) == 6 * sizeof(uint32_t),
                "MarkvHeader must not contain padding");

  // Takes ownership of |validator_options|.
  explicit MarkvCodecBase(spv_const_context context,
                          spv_validator_options validator_options)
      : validator_options_(validator_options),
        vstate_(context, validator_options_), grammar_(context),
        model_(GetDefaultModel()) {}

  // Validates a single instruction and updates validation state of the module.
  spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
    return ValidateInstructionAndUpdateValidationState(&vstate_, &inst);
  }

  // Returns the current instruction (the one last processed by the validator).
  const Instruction& GetCurrentInstruction() const {
    return vstate_.ordered_instructions().back();
  }

  // Owned; released in the destructor. Declared before |vstate_|, which is
  // constructed from it.
  spv_validator_options validator_options_;

  ValidationState_t vstate_;

  const libspirv::AssemblyGrammar grammar_;

  MarkvHeader header_;

  // Not owned.
  const MarkvModel* model_;

  // Move-to-front list of all ids.
  // TODO(atgoo@github.com) Consider a better move-to-front implementation.
  std::list<uint32_t> move_to_front_ids_;
};
// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
// EncodeInstruction which can be used as callback by spvBinaryParse.
// Encoded binary is written to an internally maintained bitstream.
// After the last instruction is encoded, the resulting MARK-V binary can be
// acquired by calling GetMarkvBinary().
// The encoder uses SPIR-V validator to keep internal state, therefore
// SPIR-V binary needs to be able to pass validator checks.
// CreateCommentsLogger() can be used to enable the encoder to write comments
// on how encoding was done, which can later be accessed with GetComments().
class MarkvEncoder : public MarkvCodecBase {
 public:
  MarkvEncoder(spv_const_context context,
               spv_const_markv_encoder_options options)
      : MarkvCodecBase(context, GetValidatorOptions(options)),
        options_(options) {
    (void) options_;
  }

  // Writes data from SPIR-V header to MARK-V header. The id bound is passed
  // to the validation state; endianness, magic number and schema are ignored.
  spv_result_t EncodeHeader(
      spv_endianness_t /* endian */, uint32_t /* magic */,
      uint32_t version, uint32_t generator, uint32_t id_bound,
      uint32_t /* schema */) {
    vstate_.setIdBound(id_bound);
    header_.spirv_version = version;
    header_.spirv_generator = generator;
    return SPV_SUCCESS;
  }

  // Encodes SPIR-V instruction to MARK-V and writes to bit stream.
  // Operation can fail if the instruction fails to pass the validator or if
  // the encoder stumbles on something unexpected.
  spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);

  // Concatenates MARK-V header and the bit stream with encoded instructions
  // into a single buffer and returns it as spv_markv_binary. The returned
  // value is owned by the caller and needs to be destroyed with
  // spvMarkvBinaryDestroy().
  spv_markv_binary GetMarkvBinary() {
    header_.markv_length_in_bits =
        static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
    const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();

    spv_markv_binary markv_binary = new spv_markv_binary_t();
    markv_binary->data = new uint8_t[num_bytes];
    markv_binary->length = num_bytes;

    assert(writer_.GetData());
    std::memcpy(markv_binary->data, &header_, sizeof(header_));
    std::memcpy(markv_binary->data + sizeof(header_),
                writer_.GetData(), writer_.GetDataSizeBytes());
    return markv_binary;
  }

  // Creates an internal logger which writes comments on the encoding process.
  // Output can later be accessed with GetComments().
  void CreateCommentsLogger() {
    logger_.reset(new CommentLogger());
    writer_.SetCallback([this](const std::string& str){
      logger_->AppendBitSequence(str);
    });
  }

  // Optionally adds disassembly to the comments.
  // Disassembly should contain all instructions in the module separated by
  // \n, and no header.
  void SetDisassembly(std::string&& disassembly) {
    disassembly_.reset(new std::stringstream(std::move(disassembly)));
  }

  // Extracts the next instruction line from the disassembly and logs it.
  // Does nothing unless both a logger and a disassembly were set.
  void LogDisassemblyInstruction() {
    if (logger_ && disassembly_) {
      std::string line;
      std::getline(*disassembly_, line, '\n');
      logger_->AppendTextNewLine(line);
    }
  }

  // Extracts the text from the comment logger.
  std::string GetComments() const {
    if (!logger_)
      return "";
    return logger_->GetText();
  }

 private:
  // Creates and returns validator options. Return value owned by the caller.
  static spv_validator_options GetValidatorOptions(
      spv_const_markv_encoder_options) {
    return spvValidatorOptionsCreate();
  }

  // Writes a single word to bit stream. |type| determines if the word is
  // encoded and how.
  void EncodeOperandWord(spv_operand_type_t type, uint32_t word) {
    const size_t chunk_length =
        GetOperandVariableWidthChunkLength(type);
    if (chunk_length) {
      writer_.WriteVariableWidthU32(word, chunk_length);
    } else {
      writer_.WriteUnencoded(word);
    }
  }

  // Returns the move-to-front index of |id| and moves |id| to the front.
  // For an id seen for the first time, the id is registered and pushed to the
  // front, and the returned value is the number of previously known ids: an
  // out-of-range index which tells the decoder to issue a fresh id.
  // The index is returned as uint16 (that is how it is encoded), so it is only
  // exact while at most 65536 ids are known. SPIR-V itself does not guarantee
  // that bound.
  uint16_t GetIdIndex(uint32_t id) {
    if (all_known_ids_.count(id)) {
      const auto it = std::find(move_to_front_ids_.begin(),
                                move_to_front_ids_.end(), id);
      if (it == move_to_front_ids_.end()) {
        assert(0 && "Id not found in move_to_front_ids_");
        return 0;
      }
      const size_t index = static_cast<size_t>(
          std::distance(move_to_front_ids_.begin(), it));
      if (index != 0) {
        // Relink the existing node at the front; no node is freed or
        // reallocated.
        move_to_front_ids_.splice(move_to_front_ids_.begin(),
                                  move_to_front_ids_, it);
      }
      return static_cast<uint16_t>(index);
    }

    all_known_ids_.insert(id);
    move_to_front_ids_.push_front(id);
    return static_cast<uint16_t>(move_to_front_ids_.size() - 1);
  }

  // Pads the bit stream with zero bits up to the next byte boundary if it is
  // not already byte-aligned.
  void AddByteBreakIfAgreed() {
    if (!ShouldByteBreak(writer_.GetNumBits()))
      return;

    if (logger_) {
      logger_->AppendWhitespaces(kCommentNumWhitespaces);
      logger_->AppendText("ByteBreak:");
    }

    writer_.WriteBits(0, GetNumBitsToNextByte(writer_.GetNumBits()));
  }

  // Encodes a literal number operand and writes it to the bit stream.
  void EncodeLiteralNumber(const Instruction& instruction,
                           const spv_parsed_operand_t& operand);

  spv_const_markv_encoder_options options_;

  // Bit stream where encoded instructions are written.
  BitWriterWord64 writer_;

  // If not nullptr, encoder will write comments.
  std::unique_ptr<CommentLogger> logger_;

  // If not nullptr, disassembled instruction lines will be written to comments.
  // Format: \n separated instruction lines, no header.
  std::unique_ptr<std::stringstream> disassembly_;

  // All ids which were previously encountered in the module.
  std::unordered_set<uint32_t> all_known_ids_;
};
// Decodes MARK-V buffers written by MarkvEncoder.
class MarkvDecoder : public MarkvCodecBase {
 public:
  MarkvDecoder(spv_const_context context,
               const uint8_t* markv_data,
               size_t markv_size_bytes,
               spv_const_markv_decoder_options options)
      : MarkvCodecBase(context, GetValidatorOptions(options)),
        options_(options), reader_(markv_data, markv_size_bytes) {
    (void) options_;
    vstate_.setIdBound(1);
    parsed_operands_.reserve(25);
  }

  // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
  // Can be called only once. Fails if data of wrong format or ends prematurely,
  // or if validation fails.
  spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);

 private:
  // Describes the format of a typed literal number.
  struct NumberType {
    spv_number_kind_t type;
    uint32_t bit_width;
  };

  // Creates and returns validator options. Return value owned by the caller.
  static spv_validator_options GetValidatorOptions(
      spv_const_markv_decoder_options) {
    return spvValidatorOptionsCreate();
  }

  // Reads a single word from bit stream. |type| determines if the word needs
  // to be decoded and how. Returns false if read fails.
  bool DecodeOperandWord(spv_operand_type_t type, uint32_t* word) {
    const size_t chunk_length = GetOperandVariableWidthChunkLength(type);
    if (chunk_length) {
      return reader_.ReadVariableWidthU32(word, chunk_length);
    } else {
      return reader_.ReadUnencoded(word);
    }
  }

  // Fetches the id at |index| in the move-to-front list and moves it to the
  // front. An index at or past the end of the list means a new id: the
  // current id bound is issued as the id, pushed to the front, and the bound
  // is incremented.
  uint32_t GetIdAndMoveToFront(uint16_t index) {
    if (index >= move_to_front_ids_.size()) {
      // Issue new id.
      const uint32_t id = vstate_.getIdBound();
      move_to_front_ids_.push_front(id);
      vstate_.setIdBound(id + 1);
      return id;
    }

    const auto it = std::next(move_to_front_ids_.begin(), index);
    if (it != move_to_front_ids_.begin()) {
      // Relink the existing node at the front; |it| stays valid.
      move_to_front_ids_.splice(move_to_front_ids_.begin(),
                                move_to_front_ids_, it);
    }
    return *it;
  }

  // Decodes id index and fetches the id from move-to-front list.
  bool DecodeId(uint32_t* id) {
    uint16_t index = 0;
    if (!reader_.ReadVariableWidthU16(&index, model_->id_index_chunk_length()))
      return false;

    *id = GetIdAndMoveToFront(index);
    return true;
  }

  // Consumes the zero padding up to the next byte boundary if the reader is
  // not byte-aligned. Returns false if the read fails or a padding bit is set.
  bool ReadToByteBreakIfAgreed() {
    if (!ShouldByteBreak(reader_.GetNumReadBits()))
      return true;

    uint64_t bits = 0;
    if (!reader_.ReadBits(&bits,
                          GetNumBitsToNextByte(reader_.GetNumReadBits())))
      return false;

    if (bits != 0)
      return false;

    return true;
  }

  // Reads a literal number as it is described in |operand| from the bit stream,
  // decodes and writes it to spirv_.
  spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);

  // Reads instruction from bit stream, decodes and validates it.
  // Decoded instruction is valid until the next call of DecodeInstruction().
  spv_result_t DecodeInstruction(spv_parsed_instruction_t* inst);

  // Read operand from the stream decodes and validates it.
  spv_result_t DecodeOperand(size_t instruction_offset, size_t operand_offset,
                             spv_parsed_instruction_t* inst,
                             const spv_operand_type_t type,
                             spv_operand_pattern_t* expected_operands,
                             bool read_result_id);

  // Records the numeric type for an operand according to the type information
  // associated with the given non-zero type Id. This can fail if the type Id
  // is not a type Id, or if the type Id does not reference a scalar numeric
  // type. On success, return SPV_SUCCESS and populates the num_words,
  // number_kind, and number_bit_width fields of parsed_operand.
  spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
                                         uint32_t type_id);

  // Records the number type for the given instruction, if that
  // instruction generates a type. For types that aren't scalar numbers,
  // record something with number kind SPV_NUMBER_NONE.
  void RecordNumberType(const spv_parsed_instruction_t& inst);

  spv_const_markv_decoder_options options_;

  // Temporary sink where decoded SPIR-V words are written. Once it contains the
  // entire module, the container is moved and returned.
  std::vector<uint32_t> spirv_;

  // Bit stream containing encoded data.
  BitReaderWord64 reader_;

  // Temporary storage for operands of the currently parsed instruction.
  // Valid until next DecodeInstruction call.
  std::vector<spv_parsed_operand_t> parsed_operands_;

  // Maps a result ID to its type ID. By convention:
  //  - a result ID that is a type definition maps to itself.
  //  - a result ID without a type maps to 0. (E.g. for OpLabel)
  std::unordered_map<uint32_t, uint32_t> id_to_type_id_;

  // Maps a type ID to its number type description.
  std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;

  // Maps an ExtInstImport id to the extended instruction type.
  std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
};
// Writes a typed literal number operand of |instruction| to the bit stream.
// Integers use variable width encoding with the model's chunk lengths (signed
// values additionally use the block exponents); floats are written unencoded.
// 64-bit literals are reassembled from two SPIR-V words, low word first.
void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction,
                                       const spv_parsed_operand_t& operand) {
  const auto kind = operand.number_kind;

  switch (operand.number_bit_width) {
    case 32: {
      const uint32_t word32 = instruction.word(operand.offset);
      if (kind == SPV_NUMBER_UNSIGNED_INT) {
        writer_.WriteVariableWidthU32(word32, model_->u32_chunk_length());
      } else if (kind == SPV_NUMBER_SIGNED_INT) {
        int32_t signed32 = 0;
        std::memcpy(&signed32, &word32, sizeof(signed32));
        writer_.WriteVariableWidthS32(signed32, model_->s32_chunk_length(),
                                      model_->s32_block_exponent());
      } else if (kind == SPV_NUMBER_FLOATING) {
        writer_.WriteUnencoded(word32);
      } else {
        assert(0);
      }
      break;
    }

    case 16: {
      const uint16_t word16 =
          static_cast<uint16_t>(instruction.word(operand.offset));
      if (kind == SPV_NUMBER_UNSIGNED_INT) {
        writer_.WriteVariableWidthU16(word16, model_->u16_chunk_length());
      } else if (kind == SPV_NUMBER_SIGNED_INT) {
        int16_t signed16 = 0;
        std::memcpy(&signed16, &word16, sizeof(signed16));
        writer_.WriteVariableWidthS16(signed16, model_->s16_chunk_length(),
                                      model_->s16_block_exponent());
      } else if (kind == SPV_NUMBER_FLOATING) {
        // TODO(atgoo@github.com) Write only 16 bits.
        writer_.WriteUnencoded(word16);
      } else {
        assert(0);
      }
      break;
    }

    default: {
      assert(operand.number_bit_width == 64);
      const uint64_t low_word = instruction.word(operand.offset);
      const uint64_t high_word = instruction.word(operand.offset + 1);
      const uint64_t word64 = low_word | (high_word << 32);
      if (kind == SPV_NUMBER_UNSIGNED_INT) {
        writer_.WriteVariableWidthU64(word64, model_->u64_chunk_length());
      } else if (kind == SPV_NUMBER_SIGNED_INT) {
        int64_t signed64 = 0;
        std::memcpy(&signed64, &word64, sizeof(signed64));
        writer_.WriteVariableWidthS64(signed64, model_->s64_chunk_length(),
                                      model_->s64_block_exponent());
      } else if (kind == SPV_NUMBER_FLOATING) {
        writer_.WriteUnencoded(word64);
      } else {
        assert(0);
      }
      break;
    }
  }
}
// Validates |inst|, then writes it to the bit stream:
//  - a service opcode if the instruction's result id was forward declared
//    (the decoder then reads the result id explicitly),
//  - the opcode, and the operand count unless the opcode has a fixed count,
//  - every operand; a result id that was not forward declared is only
//    registered in the move-to-front list, not written.
// The instruction is followed by zero padding up to the next byte boundary.
spv_result_t MarkvEncoder::EncodeInstruction(
    const spv_parsed_instruction_t& inst) {
  // Id indices are written as 16-bit values, so at most 2^16 distinct ids can
  // be addressed. Beyond that GetIdIndex truncates indices and the decoder
  // would silently reconstruct wrong ids, so fail explicitly instead.
  const size_t kMaxNumAddressableIds = 0x10000;

  const spv_result_t validation_result = UpdateValidationState(inst);
  if (validation_result != SPV_SUCCESS)
    return validation_result;

  bool result_id_was_forward_declared = false;
  if (all_known_ids_.count(inst.result_id)) {
    // Result id of the instruction was forward declared.
    // Write a service opcode to signal this to the decoder.
    writer_.WriteVariableWidthU32(kMarkvOpNextInstructionEncodesResultId,
                                  model_->opcode_chunk_length());
    result_id_was_forward_declared = true;
  }

  const Instruction& instruction = GetCurrentInstruction();
  const auto& operands = instruction.operands();

  LogDisassemblyInstruction();

  // Write opcode.
  writer_.WriteVariableWidthU32(inst.opcode, model_->opcode_chunk_length());

  if (!OpcodeHasFixedNumberOfOperands(SpvOp(inst.opcode))) {
    // If the opcode has a variable number of operands, encode the number of
    // operands with the instruction.

    if (logger_)
      logger_->AppendWhitespaces(kCommentNumWhitespaces);

    writer_.WriteVariableWidthU16(inst.num_operands,
                                  model_->num_operands_chunk_length());
  }

  // Write operands.
  for (const auto& operand : operands) {
    if (operand.type == SPV_OPERAND_TYPE_RESULT_ID &&
        !result_id_was_forward_declared) {
      // Register the id, but don't encode it.
      GetIdIndex(instruction.word(operand.offset));
      continue;
    }

    if (logger_)
      logger_->AppendWhitespaces(kCommentNumWhitespaces);

    if (operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) {
      EncodeLiteralNumber(instruction, operand);
    } else if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
      const char* src =
          reinterpret_cast<const char*>(&instruction.words()[operand.offset]);
      const size_t length = spv_strnlen_s(src, operand.num_words * 4);
      if (length == operand.num_words * 4)
        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
            << "Failed to find terminal character of literal string";
      // Write the characters including the terminating NUL.
      for (size_t i = 0; i < length + 1; ++i)
        writer_.WriteUnencoded(src[i]);
    } else if (spvIsIdType(operand.type)) {
      const uint16_t id_index = GetIdIndex(instruction.word(operand.offset));
      // Conservative: once more ids are known than a 16-bit index can
      // address, an index may already have been truncated.
      if (move_to_front_ids_.size() > kMaxNumAddressableIds)
        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
            << "Too many ids to encode: MARK-V id indices are limited to "
               "16 bits";
      writer_.WriteVariableWidthU16(id_index, model_->id_index_chunk_length());
    } else {
      for (int i = 0; i < operand.num_words; ++i) {
        const uint32_t word = instruction.word(operand.offset + i);
        EncodeOperandWord(operand.type, word);
      }
    }
  }

  AddByteBreakIfAgreed();

  if (logger_) {
    logger_->NewLine();
    logger_->NewLine();
  }

  return SPV_SUCCESS;
}
// Reads one typed literal number described by |operand| from the bit stream
// and appends its SPIR-V words to spirv_: one word for 16/32-bit literals, two
// words (low word first) for 64-bit literals. Returns an error diagnostic if
// the bit stream cannot supply the value.
spv_result_t MarkvDecoder::DecodeLiteralNumber(
    const spv_parsed_operand_t& operand) {
  const auto kind = operand.number_kind;

  switch (operand.number_bit_width) {
    case 32: {
      uint32_t word32 = 0;
      if (kind == SPV_NUMBER_UNSIGNED_INT) {
        if (!reader_.ReadVariableWidthU32(&word32, model_->u32_chunk_length()))
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
              << "Failed to read literal U32";
      } else if (kind == SPV_NUMBER_SIGNED_INT) {
        int32_t signed32 = 0;
        if (!reader_.ReadVariableWidthS32(&signed32,
                                          model_->s32_chunk_length(),
                                          model_->s32_block_exponent()))
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
              << "Failed to read literal S32";
        std::memcpy(&word32, &signed32, sizeof(word32));
      } else if (kind == SPV_NUMBER_FLOATING) {
        if (!reader_.ReadUnencoded(&word32))
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
              << "Failed to read literal F32";
      } else {
        assert(0);
      }
      spirv_.push_back(word32);
      return SPV_SUCCESS;
    }

    case 16: {
      uint32_t word32 = 0;
      if (kind == SPV_NUMBER_UNSIGNED_INT) {
        uint16_t unsigned16 = 0;
        if (!reader_.ReadVariableWidthU16(&unsigned16,
                                          model_->u16_chunk_length()))
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
              << "Failed to read literal U16";
        word32 = unsigned16;
      } else if (kind == SPV_NUMBER_SIGNED_INT) {
        int16_t signed16 = 0;
        if (!reader_.ReadVariableWidthS16(&signed16,
                                          model_->s16_chunk_length(),
                                          model_->s16_block_exponent()))
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
              << "Failed to read literal S16";
        // A 16-bit signed literal occupies a whole SPIR-V word as an int32
        // value, not as its raw 16 bits.
        const int32_t widened = signed16;
        std::memcpy(&word32, &widened, sizeof(word32));
      } else if (kind == SPV_NUMBER_FLOATING) {
        uint16_t half_bits = 0;
        if (!reader_.ReadUnencoded(&half_bits))
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
              << "Failed to read literal F16";
        word32 = half_bits;
      } else {
        assert(0);
      }
      spirv_.push_back(word32);
      return SPV_SUCCESS;
    }

    default: {
      assert(operand.number_bit_width == 64);
      uint64_t word64 = 0;
      if (kind == SPV_NUMBER_UNSIGNED_INT) {
        if (!reader_.ReadVariableWidthU64(&word64, model_->u64_chunk_length()))
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
              << "Failed to read literal U64";
      } else if (kind == SPV_NUMBER_SIGNED_INT) {
        int64_t signed64 = 0;
        if (!reader_.ReadVariableWidthS64(&signed64,
                                          model_->s64_chunk_length(),
                                          model_->s64_block_exponent()))
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
              << "Failed to read literal S64";
        std::memcpy(&word64, &signed64, sizeof(word64));
      } else if (kind == SPV_NUMBER_FLOATING) {
        if (!reader_.ReadUnencoded(&word64))
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
              << "Failed to read literal F64";
      } else {
        assert(0);
      }
      spirv_.push_back(static_cast<uint32_t>(word64));
      spirv_.push_back(static_cast<uint32_t>(word64 >> 32));
      return SPV_SUCCESS;
    }
  }
}
// Decodes the whole MARK-V buffer into |spirv_binary|:
//  1. reads and checks the MARK-V header,
//  2. writes a SPIR-V header (magic, version, generator; bound filled in last),
//  3. decodes and validates instructions until the stated bit length is read,
//  4. checks that exactly the stated number of bits was consumed and that the
//     remainder of the buffer is zero.
// All header values come from the input buffer and are not trusted: malformed
// values produce an error diagnostic rather than an assertion failure.
spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
  const bool header_read_success =
      reader_.ReadUnencoded(&header_.magic_number) &&
      reader_.ReadUnencoded(&header_.markv_version) &&
      reader_.ReadUnencoded(&header_.markv_model) &&
      reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
      reader_.ReadUnencoded(&header_.spirv_version) &&
      reader_.ReadUnencoded(&header_.spirv_generator);

  if (!header_read_success)
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Unable to read MARK-V header";

  if (header_.magic_number != kMarkvMagicNumber)
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "MARK-V binary has incorrect magic number";

  // TODO(atgoo@github.com): Print version strings.
  if (header_.markv_version != GetMarkvVersion())
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "MARK-V binary and the codec have different versions";

  // Capacity heuristic. markv_length_in_bits is only verified after decoding,
  // so cap the reservation: a corrupted header must not trigger a multi-GB
  // allocation up front. The vector still grows on demand past the cap.
  const size_t kMaxReservedWords = size_t(1) << 20;
  spirv_.reserve(std::min<size_t>(header_.markv_length_in_bits / 2,
                                  kMaxReservedWords));

  spirv_.resize(5, 0);
  spirv_[0] = kSpirvMagicNumber;
  spirv_[1] = header_.spirv_version;
  spirv_[2] = header_.spirv_generator;

  // A zero or too-small stated length skips this loop and is rejected by the
  // bit-count check below.
  while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
    spv_parsed_instruction_t inst = {};
    const spv_result_t decode_result = DecodeInstruction(&inst);
    if (decode_result != SPV_SUCCESS)
      return decode_result;

    const spv_result_t validation_result = UpdateValidationState(inst);
    if (validation_result != SPV_SUCCESS)
      return validation_result;
  }

  if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
      !reader_.OnlyZeroesLeft()) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "MARK-V binary has wrong stated bit length "
        << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
  }

  // Decoding of the module is finished, validation state should have correct
  // id bound.
  spirv_[3] = vstate_.getIdBound();

  *spirv_binary = std::move(spirv_);
  return SPV_SUCCESS;
}
// TODO(atgoo@github.com): The implementation borrows heavily from
// Parser::parseOperand.
// Consider coupling them together in some way once MARK-V codec is more mature.
// For now it's better to keep the code independent for experimentation
// purposes.
spv_result_t MarkvDecoder::DecodeOperand(
size_t instruction_offset, size_t operand_offset,
spv_parsed_instruction_t* inst, const spv_operand_type_t type,
spv_operand_pattern_t* expected_operands,
bool read_result_id) {
const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
spv_parsed_operand_t parsed_operand;
memset(&parsed_operand, 0, sizeof(parsed_operand));
assert((operand_offset >> 16) == 0);
parsed_operand.offset = static_cast<uint16_t>(operand_offset);
parsed_operand.type = type;
// Set default values, may be updated later.
parsed_operand.number_kind = SPV_NUMBER_NONE;
parsed_operand.number_bit_width = 0;
const size_t first_word_index = spirv_.size();
switch (type) {
case SPV_OPERAND_TYPE_TYPE_ID: {
if (!DecodeId(&inst->type_id)) {
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read type_id";
}
if (inst->type_id == 0)
return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded type_id is 0";
spirv_.push_back(inst->type_id);
vstate_.setIdBound(std::max(vstate_.getIdBound(), inst->type_id + 1));
break;
}
case SPV_OPERAND_TYPE_RESULT_ID: {
if (read_result_id) {
if (!DecodeId(&inst->result_id))
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read result_id";
} else {
inst->result_id = vstate_.getIdBound();
vstate_.setIdBound(inst->result_id + 1);
move_to_front_ids_.push_front(inst->result_id);
}
spirv_.push_back(inst->result_id);
// Save the result ID to type ID mapping.
// In the grammar, type ID always appears before result ID.
// A regular value maps to its type. Some instructions (e.g. OpLabel)
// have no type Id, and will map to 0. The result Id for a
// type-generating instruction (e.g. OpTypeInt) maps to itself.
auto insertion_result = id_to_type_id_.emplace(
inst->result_id,
spvOpcodeGeneratesType(opcode) ? inst->result_id : inst->type_id);
if(!insertion_result.second) {
return vstate_.diag(SPV_ERROR_INVALID_ID)
<< "Unexpected behavior: id->type_id pair was already registered";
}
break;
}
case SPV_OPERAND_TYPE_ID:
case SPV_OPERAND_TYPE_OPTIONAL_ID:
case SPV_OPERAND_TYPE_SCOPE_ID:
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
uint32_t id = 0;
if (!DecodeId(&id))
return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id";
if (id == 0)
return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
spirv_.push_back(id);
vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1));
if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
parsed_operand.type = SPV_OPERAND_TYPE_ID;
if (opcode == SpvOpExtInst && parsed_operand.offset == 3) {
// The current word is the extended instruction set id.
// Set the extended instruction set type for the current instruction.
auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
return vstate_.diag(SPV_ERROR_INVALID_ID)
<< "OpExtInst set id " << id
<< " does not reference an OpExtInstImport result Id";
}
inst->ext_inst_type = ext_inst_type_iter->second;
}
}
break;
}
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
uint32_t word = 0;
if (!DecodeOperandWord(type, &word))
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read enum";
spirv_.push_back(word);
assert(SpvOpExtInst == opcode);
assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE);
spv_ext_inst_desc ext_inst;
if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst))
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid extended instruction number: " << word;
spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
break;
}
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
// These are regular single-word literal integer operands.
// Post-parsing validation should check the range of the parsed value.
parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
// It turns out they are always unsigned integers!
parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT;
parsed_operand.number_bit_width = 32;
uint32_t word = 0;
if (!DecodeOperandWord(type, &word))
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read literal integer";
spirv_.push_back(word);
break;
}
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER:
parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
if (opcode == SpvOpSwitch) {
// The literal operands have the same type as the value
// referenced by the selector Id.
const uint32_t selector_id = spirv_.at(instruction_offset + 1);
const auto type_id_iter = id_to_type_id_.find(selector_id);
if (type_id_iter == id_to_type_id_.end() ||
type_id_iter->second == 0) {
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid OpSwitch: selector id " << selector_id
<< " has no type";
}
uint32_t type_id = type_id_iter->second;
if (selector_id == type_id) {
// Recall that by convention, a result ID that is a type definition
// maps to itself.
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid OpSwitch: selector id " << selector_id
<< " is a type, not a value";
}
if (auto error = SetNumericTypeInfoForType(&parsed_operand, type_id))
return error;
if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT &&
parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) {
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid OpSwitch: selector id " << selector_id
<< " is not a scalar integer";
}
} else {
assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
// The literal number type is determined by the type Id for the
// constant.
assert(inst->type_id);
if (auto error =
SetNumericTypeInfoForType(&parsed_operand, inst->type_id))
return error;
}
if (auto error = DecodeLiteralNumber(parsed_operand))
return error;
break;
case SPV_OPERAND_TYPE_LITERAL_STRING:
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING;
std::vector<char> str;
// The loop is expected to terminate once we encounter '\0' or exhaust
// the bit stream.
while (true) {
char ch = 0;
if (!reader_.ReadUnencoded(&ch))
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read literal string";
str.push_back(ch);
if (ch == '\0')
break;
}
while (str.size() % 4 != 0)
str.push_back('\0');
spirv_.resize(spirv_.size() + str.size() / 4);
std::memcpy(&spirv_[first_word_index], str.data(), str.size());
if (SpvOpExtInstImport == opcode) {
// Record the extended instruction type for the ID for this import.
// There is only one string literal argument to OpExtInstImport,
// so it's sufficient to guard this just on the opcode.
const spv_ext_inst_type_t ext_inst_type =
spvExtInstImportTypeGet(str.data());
if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid extended instruction import '" << str.data() << "'";
}
// We must have parsed a valid result ID. It's a condition
// of the grammar, and we only accept non-zero result Ids.
assert(inst->result_id);
const bool inserted = import_id_to_ext_inst_type_.emplace(
inst->result_id, ext_inst_type).second;
(void)inserted;
assert(inserted);
}
break;
}
case SPV_OPERAND_TYPE_CAPABILITY:
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
case SPV_OPERAND_TYPE_MEMORY_MODEL:
case SPV_OPERAND_TYPE_EXECUTION_MODE:
case SPV_OPERAND_TYPE_STORAGE_CLASS:
case SPV_OPERAND_TYPE_DIMENSIONALITY:
case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
case SPV_OPERAND_TYPE_LINKAGE_TYPE:
case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
case SPV_OPERAND_TYPE_DECORATION:
case SPV_OPERAND_TYPE_BUILT_IN:
case SPV_OPERAND_TYPE_GROUP_OPERATION:
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
// A single word that is a plain enum value.
uint32_t word = 0;
if (!DecodeOperandWord(type, &word))
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read enum";
spirv_.push_back(word);
// Map an optional operand type to its corresponding concrete type.
if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
spv_operand_desc entry;
if (grammar_.lookupOperand(type, word, &entry)) {
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid "
<< spvOperandTypeStr(parsed_operand.type)
<< " operand: " << word;
}
// Prepare to accept operands to this operand, if needed.
spvPushOperandTypes(entry->operandTypes, expected_operands);
break;
}
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
case SPV_OPERAND_TYPE_LOOP_CONTROL:
case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
// This operand is a mask.
uint32_t word = 0;
if (!DecodeOperandWord(type, &word))
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read " << spvOperandTypeStr(type)
<< " for " << spvOpcodeString(SpvOp(inst->opcode));
spirv_.push_back(word);
// Map an optional operand type to its corresponding concrete type.
if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
// Check validity of set mask bits. Also prepare for operands for those
// masks if they have any. To get operand order correct, scan from
// MSB to LSB since we can only prepend operands to a pattern.
// The only case in the grammar where you have more than one mask bit
// having an operand is for image operands. See SPIR-V 3.14 Image
// Operands.
uint32_t remaining_word = word;
for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
if (remaining_word & mask) {
spv_operand_desc entry;
if (grammar_.lookupOperand(type, mask, &entry)) {
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid " << spvOperandTypeStr(parsed_operand.type)
<< " operand: " << word << " has invalid mask component "
<< mask;
}
remaining_word ^= mask;
spvPushOperandTypes(entry->operandTypes, expected_operands);
}
}
if (word == 0) {
// An all-zeroes mask *might* also be valid.
spv_operand_desc entry;
if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
// Prepare for its operands, if any.
spvPushOperandTypes(entry->operandTypes, expected_operands);
}
}
break;
}
default:
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
<< "Internal error: Unhandled operand type: " << type;
}
parsed_operand.num_words = uint16_t(spirv_.size() - first_word_index);
assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(parsed_operand.type));
assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(parsed_operand.type));
parsed_operands_.push_back(parsed_operand);
return SPV_SUCCESS;
}
// Decodes one instruction from the MARK-V bit stream, appends its SPIR-V
// words to spirv_ and fills *inst.
//
// MARK-V meta-opcodes preceding the real opcode are consumed first. Currently
// only kMarkvOpNextInstructionEncodesResultId is accepted; it makes the result
// id of the next instruction be read from the stream instead of regenerated.
//
// Returns SPV_SUCCESS, or an error diagnostic (from vstate_.diag here or from
// DecodeOperand) on malformed input. On success inst->words and
// inst->operands point into spirv_ and parsed_operands_ and are only valid
// while those containers remain unchanged.
spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) {
  parsed_operands_.clear();
  const size_t instruction_offset = spirv_.size();

  bool read_result_id = false;

  // Consume MARK-V meta-opcodes until a real SPIR-V opcode is read.
  while (true) {
    uint32_t word = 0;
    if (!reader_.ReadVariableWidthU32(&word,
                                      model_->opcode_chunk_length())) {
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
          << "Failed to read opcode of instruction";
    }

    if (word >= kMarkvFirstOpcode) {
      if (word == kMarkvOpNextInstructionEncodesResultId) {
        read_result_id = true;
      } else {
        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
            << "Encountered unknown MARK-V opcode";
      }
    } else {
      inst->opcode = static_cast<uint16_t>(word);
      break;
    }
  }

  const SpvOp opcode = static_cast<SpvOp>(inst->opcode);

  // Opcode/num_words placeholder, the word will be filled in later.
  spirv_.push_back(0);

  spv_opcode_desc opcode_desc;
  if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
  }

  // The grammar's operand types are pushed last-to-first, i.e. the pattern is
  // kept in reverse order and the next expected operand sits at its back.
  spv_operand_pattern_t expected_operands;
  expected_operands.reserve(opcode_desc->numTypes);
  for (auto i = 0; i < opcode_desc->numTypes; i++)
    expected_operands.push_back(
        opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);

  if (!OpcodeHasFixedNumberOfOperands(opcode)) {
    if (!reader_.ReadVariableWidthU16(&inst->num_operands,
                                      model_->num_operands_chunk_length()))
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
          << "Failed to read num_operands of instruction";
  } else {
    inst->num_operands = static_cast<uint16_t>(expected_operands.size());
  }

  for (size_t operand_index = 0;
       operand_index < static_cast<size_t>(inst->num_operands);
       ++operand_index) {
    // num_operands may have been read from the bit stream, so it can exceed
    // what the grammar allows. Fail cleanly instead of taking an operand type
    // from an exhausted pattern (which an assert alone does not prevent in
    // release builds).
    if (expected_operands.empty()) {
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
          << "Too many operands for Op" << spvOpcodeString(opcode)
          << ": num_operands is " << inst->num_operands
          << " but the operand pattern was exhausted after " << operand_index
          << " operands";
    }

    const spv_operand_type_t type =
        spvTakeFirstMatchableOperand(&expected_operands);

    const size_t operand_offset = spirv_.size() - instruction_offset;

    const spv_result_t decode_result =
        DecodeOperand(instruction_offset, operand_offset, inst, type,
                      &expected_operands, read_result_id);

    if (decode_result != SPV_SUCCESS)
      return decode_result;
  }

  // Every operand the grammar still requires must have been decoded; only
  // optional (or variable-length) operands may be left over.
  if (!expected_operands.empty() &&
      !spvOperandIsOptional(expected_operands.back())) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Too few operands for Op" << spvOpcodeString(opcode)
        << ": num_operands is " << inst->num_operands
        << " but more operands are required";
  }

  assert(inst->num_operands == parsed_operands_.size());

  // The SPIR-V word count is stored in the upper 16 bits of the first word of
  // the instruction, so longer instructions cannot be represented.
  const size_t word_count = spirv_.size() - instruction_offset;
  if (word_count > 0xFFFFu) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Instruction Op" << spvOpcodeString(opcode) << " is too long: "
        << word_count << " words exceeds the limit of 65535";
  }

  // Only valid while spirv_ and parsed_operands_ remain unchanged.
  inst->words = &spirv_[instruction_offset];
  inst->operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
  inst->num_words = static_cast<uint16_t>(word_count);
  spirv_[instruction_offset] =
      spvOpcodeMake(inst->num_words, SpvOp(inst->opcode));

  assert(inst->num_words == std::accumulate(
      parsed_operands_.begin(), parsed_operands_.end(), 1,
      [](int sum, const spv_parsed_operand_t& operand) {
        return sum + operand.num_words;
      }) && "num_words in instruction doesn't correspond to the sum of "
      "num_words in the operands");

  RecordNumberType(*inst);

  if (!ReadToByteBreakIfAgreed())
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Failed to read to byte break";

  return SPV_SUCCESS;
}
// Copies the scalar numeric kind and bit width recorded for |type_id| into
// |parsed_operand|, and sets its word count to the number of 32-bit words
// needed to hold a value of that width.
//
// Fails with SPV_ERROR_INVALID_BINARY if |type_id| was never recorded as a
// type, or if it was recorded but is not a scalar numeric type.
spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
    spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
  assert(type_id != 0);

  const auto found = type_id_to_number_type_info_.find(type_id);
  if (found == type_id_to_number_type_info_.end()) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Type Id " << type_id << " is not a type";
  }

  const NumberType& number_type = found->second;
  const bool is_scalar_number = (number_type.type != SPV_NUMBER_NONE);
  if (!is_scalar_number) {
    // A known type, but not a scalar integer/float.
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Type Id " << type_id << " is not a scalar numeric type";
  }

  parsed_operand->number_kind = number_type.type;
  parsed_operand->number_bit_width = number_type.bit_width;

  // Round the bit width up to whole 32-bit words.
  const auto words_needed = (number_type.bit_width + 31) / 32;
  parsed_operand->num_words = static_cast<uint16_t>(words_needed);

  return SPV_SUCCESS;
}
// For a type-generating instruction, records in type_id_to_number_type_info_
// what kind of scalar number the type denotes (if any), keyed by the
// instruction's result id. OpTypeInt and OpTypeFloat record their bit width
// and signedness/float kind; every other type is recorded as SPV_NUMBER_NONE.
// Instructions that do not generate a type are ignored.
void MarkvDecoder::RecordNumberType(const spv_parsed_instruction_t& inst) {
  const SpvOp opcode = static_cast<SpvOp>(inst.opcode);
  if (!spvOpcodeGeneratesType(opcode))
    return;

  NumberType info = {SPV_NUMBER_NONE, 0};

  switch (opcode) {
    case SpvOpTypeInt: {
      // Operands: result id, width, signedness.
      const uint32_t width = inst.words[inst.operands[1].offset];
      const uint32_t signedness = inst.words[inst.operands[2].offset];
      info.bit_width = width;
      info.type = signedness ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
      break;
    }
    case SpvOpTypeFloat:
      // Operands: result id, width.
      info.bit_width = inst.words[inst.operands[1].offset];
      info.type = SPV_NUMBER_FLOATING;
      break;
    default:
      // Non-numeric type: keep SPV_NUMBER_NONE with width 0.
      break;
  }

  // The *result* Id of a type generating instruction is the type Id.
  type_id_to_number_type_info_[inst.result_id] = info;
}
// spvBinaryParse header callback: forwards the parsed module header to the
// MarkvEncoder passed as |user_data|.
spv_result_t EncodeHeader(
    void* user_data, spv_endianness_t endian, uint32_t magic,
    uint32_t version, uint32_t generator, uint32_t id_bound,
    uint32_t schema) {
  auto* const encoder = static_cast<MarkvEncoder*>(user_data);
  const spv_result_t result = encoder->EncodeHeader(
      endian, magic, version, generator, id_bound, schema);
  return result;
}
// spvBinaryParse instruction callback: forwards each parsed instruction to
// the MarkvEncoder passed as |user_data|.
spv_result_t EncodeInstruction(
    void* user_data, const spv_parsed_instruction_t* inst) {
  auto* const encoder = static_cast<MarkvEncoder*>(user_data);
  const spv_parsed_instruction_t& parsed = *inst;
  return encoder->EncodeInstruction(parsed);
}
} // namespace
// Encodes the SPIR-V module spirv_words[0, spirv_num_words) as MARK-V.
//
// On success *markv_binary receives the encoder's MARK-V binary (presumably
// released with spvMarkvBinaryDestroy). If |comments| is non-null, the
// disassembly is fed to the encoder's comment logger and *comments receives
// the resulting text. If |diagnostic| is non-null it collects diagnostics.
//
// Returns SPV_SUCCESS, SPV_ERROR_INVALID_POINTER if |markv_binary| is null,
// or SPV_ERROR_INVALID_BINARY if the input cannot be parsed or encoded.
spv_result_t spvSpirvToMarkv(spv_const_context context,
                             const uint32_t* spirv_words,
                             const size_t spirv_num_words,
                             spv_const_markv_encoder_options options,
                             spv_markv_binary* markv_binary,
                             spv_text* comments, spv_diagnostic* diagnostic) {
  spv_context_t hijack_context = *context;
  if (diagnostic) {
    *diagnostic = nullptr;
    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
  }

  spv_position_t position = {};

  // markv_binary is the mandatory output parameter: reject a null pointer
  // before doing any work instead of dereferencing it at the very end.
  if (!markv_binary) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_POINTER)
        << "markv_binary must not be null.";
  }

  spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words};

  spv_endianness_t endian;
  if (spvBinaryEndianness(&spirv_binary, &endian)) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
        << "Invalid SPIR-V magic number.";
  }

  spv_header_t header;
  if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
        << "Invalid SPIR-V header.";
  }

  MarkvEncoder encoder(&hijack_context, options);

  if (comments) {
    encoder.CreateCommentsLogger();

    spv_text text = nullptr;
    if (spvBinaryToText(&hijack_context, spirv_words, spirv_num_words,
                        SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr)
        != SPV_SUCCESS) {
      return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                        SPV_ERROR_INVALID_BINARY)
          << "Failed to disassemble SPIR-V binary.";
    }
    assert(text);
    encoder.SetDisassembly(std::string(text->str, text->length));
    spvTextDestroy(text);
  }

  if (spvBinaryParse(
      &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader,
      EncodeInstruction, diagnostic) != SPV_SUCCESS) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
        << "Unable to encode to MARK-V.";
  }

  if (comments)
    *comments = CreateSpvText(encoder.GetComments());

  *markv_binary = encoder.GetMarkvBinary();
  return SPV_SUCCESS;
}
// Decodes the MARK-V stream markv_data[0, markv_size_bytes) back into SPIR-V.
//
// On success *spirv_binary receives a newly allocated spv_binary_t whose
// |code| array was allocated with new[] (presumably released with
// spvBinaryDestroy). The |comments| parameter is currently unused. If
// |diagnostic| is non-null it collects diagnostics.
//
// Returns SPV_SUCCESS, SPV_ERROR_INVALID_POINTER if |spirv_binary| is null,
// or SPV_ERROR_INVALID_BINARY if decoding fails.
spv_result_t spvMarkvToSpirv(spv_const_context context,
                             const uint8_t* markv_data,
                             size_t markv_size_bytes,
                             spv_const_markv_decoder_options options,
                             spv_binary* spirv_binary,
                             spv_text* /* comments */, spv_diagnostic* diagnostic) {
  spv_position_t position = {};
  spv_context_t hijack_context = *context;
  if (diagnostic) {
    *diagnostic = nullptr;
    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
  }

  // spirv_binary is the mandatory output parameter: reject a null pointer up
  // front rather than decoding the whole module and then dereferencing it.
  if (!spirv_binary) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_POINTER)
        << "spirv_binary must not be null.";
  }

  MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options);

  std::vector<uint32_t> words;

  if (decoder.DecodeModule(&words) != SPV_SUCCESS) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
        << "Unable to decode MARK-V.";
  }

  assert(!words.empty());

  // Own the word array until the spv_binary_t exists, so that a throwing
  // allocation cannot leak either object.
  std::unique_ptr<uint32_t[]> code(new uint32_t[words.size()]);
  std::memcpy(code.get(), words.data(), words.size() * sizeof(uint32_t));

  *spirv_binary = new spv_binary_t();
  (*spirv_binary)->code = code.release();
  (*spirv_binary)->wordCount = words.size();
  return SPV_SUCCESS;
}
// Releases a MARK-V binary: first its data array (delete[]), then the struct
// itself. Passing nullptr is a no-op.
void spvMarkvBinaryDestroy(spv_markv_binary binary) {
  if (binary) {
    delete[] binary->data;
    delete binary;
  }
}
// Allocates a MARK-V encoder options object; release it with
// spvMarkvEncoderOptionsDestroy.
spv_markv_encoder_options spvMarkvEncoderOptionsCreate() {
  spv_markv_encoder_options options = new spv_markv_encoder_options_t;
  return options;
}
// Releases an options object created by spvMarkvEncoderOptionsCreate.
// Passing nullptr is a no-op.
void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) {
  if (!options) return;
  delete options;
}
// Allocates a MARK-V decoder options object; release it with
// spvMarkvDecoderOptionsDestroy.
spv_markv_decoder_options spvMarkvDecoderOptionsCreate() {
  spv_markv_decoder_options options = new spv_markv_decoder_options_t;
  return options;
}
// Releases an options object created by spvMarkvDecoderOptionsCreate.
// Passing nullptr is a no-op.
void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) {
  if (!options) return;
  delete options;
}