// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Contains
//   - SPIR-V to MARK-V encoder
//   - MARK-V to SPIR-V decoder
//
// MARK-V is a compression format for SPIR-V binaries. It strips away
// non-essential information (such as result ids which can be regenerated) and
// uses various bit reduction techniques to reduce the size of the binary.
//
// MarkvModel is a flatbuffers object containing a set of rules defining how
// compression/decompression is done (coding schemes, dictionaries).
#include #include #include #include #include #include #include #include #include #include #include "binary.h" #include "diagnostic.h" #include "enum_string_mapping.h" #include "extensions.h" #include "ext_inst.h" #include "instruction.h" #include "opcode.h" #include "operand.h" #include "spirv-tools/libspirv.h" #include "spirv-tools/markv.h" #include "spirv_endian.h" #include "spirv_validator_options.h" #include "util/bit_stream.h" #include "util/parse_number.h" #include "validate.h" #include "val/instruction.h" #include "val/validation_state.h" using libspirv::Instruction; using libspirv::ValidationState_t; using spvtools::ValidateInstructionAndUpdateValidationState; using spvutils::BitReaderWord64; using spvutils::BitWriterWord64; struct spv_markv_encoder_options_t { }; struct spv_markv_decoder_options_t { }; namespace { const uint32_t kSpirvMagicNumber = SpvMagicNumber; const uint32_t kMarkvMagicNumber = 0x07230303; enum { kMarkvFirstOpcode = 65536, kMarkvOpNextInstructionEncodesResultId = 65536, }; const size_t kCommentNumWhitespaces = 2; // TODO(atgoo@github.com): This is a placeholder for an autogenerated flatbuffer // containing MARK-V model for a specific dataset. class MarkvModel { public: size_t opcode_chunk_length() const { return 7; } size_t num_operands_chunk_length() const { return 3; } size_t id_index_chunk_length() const { return 3; } size_t u16_chunk_length() const { return 4; } size_t s16_chunk_length() const { return 4; } size_t s16_block_exponent() const { return 6; } size_t u32_chunk_length() const { return 8; } size_t s32_chunk_length() const { return 8; } size_t s32_block_exponent() const { return 10; } size_t u64_chunk_length() const { return 8; } size_t s64_chunk_length() const { return 8; } size_t s64_block_exponent() const { return 10; } }; const MarkvModel* GetDefaultModel() { static MarkvModel model; return &model; } // Returns chunk length used for variable length encoding of spirv operand // words. 
Returns zero if operand type corresponds to potentially multiple // words or a word which is not expected to profit from variable width encoding. // Chunk length is selected based on the size of expected value. // Most of these values will later be encoded with probability-based coding, // but variable width integer coding is a good quick solution. // TODO(atgoo@github.com): Put this in MarkvModel flatbuffer. size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) { switch (type) { case SPV_OPERAND_TYPE_TYPE_ID: return 4; case SPV_OPERAND_TYPE_RESULT_ID: case SPV_OPERAND_TYPE_ID: case SPV_OPERAND_TYPE_SCOPE_ID: case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: return 8; case SPV_OPERAND_TYPE_LITERAL_INTEGER: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: return 6; case SPV_OPERAND_TYPE_CAPABILITY: return 6; case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: case SPV_OPERAND_TYPE_EXECUTION_MODEL: return 3; case SPV_OPERAND_TYPE_ADDRESSING_MODEL: case SPV_OPERAND_TYPE_MEMORY_MODEL: return 2; case SPV_OPERAND_TYPE_EXECUTION_MODE: return 6; case SPV_OPERAND_TYPE_STORAGE_CLASS: return 4; case SPV_OPERAND_TYPE_DIMENSIONALITY: case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: return 3; case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: return 2; case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: return 6; case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: case SPV_OPERAND_TYPE_LINKAGE_TYPE: case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: return 2; case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: return 3; case SPV_OPERAND_TYPE_DECORATION: case SPV_OPERAND_TYPE_BUILT_IN: return 6; case SPV_OPERAND_TYPE_GROUP_OPERATION: case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: return 2; case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: case SPV_OPERAND_TYPE_FUNCTION_CONTROL: case SPV_OPERAND_TYPE_LOOP_CONTROL: case SPV_OPERAND_TYPE_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: case 
SPV_OPERAND_TYPE_SELECTION_CONTROL: return 4; case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: return 6; default: return 0; } return 0; } // Returns true if the opcode has a fixed number of operands. May return a // false negative. bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) { switch (opcode) { // TODO(atgoo@github.com) This is not a complete list. case SpvOpNop: case SpvOpName: case SpvOpUndef: case SpvOpSizeOf: case SpvOpLine: case SpvOpNoLine: case SpvOpDecorationGroup: case SpvOpExtension: case SpvOpExtInstImport: case SpvOpMemoryModel: case SpvOpCapability: case SpvOpTypeVoid: case SpvOpTypeBool: case SpvOpTypeInt: case SpvOpTypeFloat: case SpvOpTypeVector: case SpvOpTypeMatrix: case SpvOpTypeSampler: case SpvOpTypeSampledImage: case SpvOpTypeArray: case SpvOpTypePointer: case SpvOpConstantTrue: case SpvOpConstantFalse: case SpvOpLabel: case SpvOpBranch: case SpvOpFunction: case SpvOpFunctionParameter: case SpvOpFunctionEnd: case SpvOpBitcast: case SpvOpCopyObject: case SpvOpTranspose: case SpvOpSNegate: case SpvOpFNegate: case SpvOpIAdd: case SpvOpFAdd: case SpvOpISub: case SpvOpFSub: case SpvOpIMul: case SpvOpFMul: case SpvOpUDiv: case SpvOpSDiv: case SpvOpFDiv: case SpvOpUMod: case SpvOpSRem: case SpvOpSMod: case SpvOpFRem: case SpvOpFMod: case SpvOpVectorTimesScalar: case SpvOpMatrixTimesScalar: case SpvOpVectorTimesMatrix: case SpvOpMatrixTimesVector: case SpvOpMatrixTimesMatrix: case SpvOpOuterProduct: case SpvOpDot: return true; default: break; } return false; } size_t GetNumBitsToNextByte(size_t bit_pos) { return (8 - (bit_pos % 8)) % 8; } bool ShouldByteBreak(size_t bit_pos) { const size_t num_bits_to_next_byte = GetNumBitsToNextByte(bit_pos); return num_bits_to_next_byte > 0; // && num_bits_to_next_byte <= 2; } // Defines and returns current MARK-V version. 
uint32_t GetMarkvVersion() { const uint32_t kVersionMajor = 1; const uint32_t kVersionMinor = 0; return kVersionMinor | (kVersionMajor << 16); } class CommentLogger { public: void AppendText(const std::string& str) { Append(str); use_delimiter_ = false; } void AppendTextNewLine(const std::string& str) { Append(str); Append("\n"); use_delimiter_ = false; } void AppendBitSequence(const std::string& str) { if (use_delimiter_) Append("-"); Append(str); use_delimiter_ = true; } void AppendWhitespaces(size_t num) { Append(std::string(num, ' ')); use_delimiter_ = false; } void NewLine() { Append("\n"); use_delimiter_ = false; } std::string GetText() const { return ss_.str(); } private: void Append(const std::string& str) { ss_ << str; // std::cerr << str; } std::stringstream ss_; // If true a delimiter will be appended before the next bit sequence. // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0. bool use_delimiter_ = false; }; // Creates spv_text object containing text from |str|. // The returned value is owned by the caller and needs to be destroyed with // spvTextDestroy. spv_text CreateSpvText(const std::string& str) { spv_text out = new spv_text_t(); assert(out); char* cstr = new char[str.length() + 1]; assert(cstr); std::strncpy(cstr, str.c_str(), str.length()); cstr[str.length()] = '\0'; out->str = cstr; out->length = str.length(); return out; } // Base class for MARK-V encoder and decoder. Contains common functionality // such as: // - Validator connection and validation state. // - SPIR-V grammar and helper functions. 
class MarkvCodecBase { public: virtual ~MarkvCodecBase() { spvValidatorOptionsDestroy(validator_options_); } MarkvCodecBase() = delete; void SetModel(const MarkvModel* model) { model_ = model; } protected: struct MarkvHeader { MarkvHeader() { magic_number = kMarkvMagicNumber; markv_version = GetMarkvVersion(); markv_model = 0; markv_length_in_bits = 0; spirv_version = 0; spirv_generator = 0; } uint32_t magic_number; uint32_t markv_version; // Magic number to identify or verify MarkvModel used for encoding. uint32_t markv_model; uint32_t markv_length_in_bits; uint32_t spirv_version; uint32_t spirv_generator; }; explicit MarkvCodecBase(spv_const_context context, spv_validator_options validator_options) : validator_options_(validator_options), vstate_(context, validator_options_), grammar_(context), model_(GetDefaultModel()) {} // Validates a single instruction and updates validation state of the module. spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) { return ValidateInstructionAndUpdateValidationState(&vstate_, &inst); } // Returns the current instruction (the one last processed by the validator). const Instruction& GetCurrentInstruction() const { return vstate_.ordered_instructions().back(); } spv_validator_options validator_options_; ValidationState_t vstate_; const libspirv::AssemblyGrammar grammar_; MarkvHeader header_; const MarkvModel* model_; // Move-to-front list of all ids. // TODO(atgoo@github.com) Consider a better move-to-front implementation. std::list move_to_front_ids_; }; // SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and // EncodeInstruction which can be used as callback by spvBinaryParse. // Encoded binary is written to an internally maintained bitstream. // After the last instruction is encoded, the resulting MARK-V binary can be // acquired by calling GetMarkvBinary(). // The encoder uses SPIR-V validator to keep internal state, therefore // SPIR-V binary needs to be able to pass validator checks. 
// CreateCommentsLogger() can be used to enable the encoder to write comments // on how encoding was done, which can later be accessed with GetComments(). class MarkvEncoder : public MarkvCodecBase { public: MarkvEncoder(spv_const_context context, spv_const_markv_encoder_options options) : MarkvCodecBase(context, GetValidatorOptions(options)), options_(options) { (void) options_; } // Writes data from SPIR-V header to MARK-V header. spv_result_t EncodeHeader( spv_endianness_t /* endian */, uint32_t /* magic */, uint32_t version, uint32_t generator, uint32_t id_bound, uint32_t /* schema */) { vstate_.setIdBound(id_bound); header_.spirv_version = version; header_.spirv_generator = generator; return SPV_SUCCESS; } // Encodes SPIR-V instruction to MARK-V and writes to bit stream. // Operation can fail if the instruction fails to pass the validator or if // the encoder stubmles on something unexpected. spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst); // Concatenates MARK-V header and the bit stream with encoded instructions // into a single buffer and returns it as spv_markv_binary. The returned // value is owned by the caller and needs to be destroyed with // spvMarkvBinaryDestroy(). spv_markv_binary GetMarkvBinary() { header_.markv_length_in_bits = static_cast(sizeof(header_) * 8 + writer_.GetNumBits()); const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes(); spv_markv_binary markv_binary = new spv_markv_binary_t(); markv_binary->data = new uint8_t[num_bytes]; markv_binary->length = num_bytes; assert(writer_.GetData()); std::memcpy(markv_binary->data, &header_, sizeof(header_)); std::memcpy(markv_binary->data + sizeof(header_), writer_.GetData(), writer_.GetDataSizeBytes()); return markv_binary; } // Creates an internal logger which writes comments on the encoding process. // Output can later be accessed with GetComments(). 
void CreateCommentsLogger() { logger_.reset(new CommentLogger()); writer_.SetCallback([this](const std::string& str){ logger_->AppendBitSequence(str); }); } // Optionally adds disassembly to the comments. // Disassembly should contain all instructions in the module separated by // \n, and no header. void SetDisassembly(std::string&& disassembly) { disassembly_.reset(new std::stringstream(std::move(disassembly))); } // Extracts the next instruction line from the disassembly and logs it. void LogDisassemblyInstruction() { if (logger_ && disassembly_) { std::string line; std::getline(*disassembly_, line, '\n'); logger_->AppendTextNewLine(line); } } // Extracts the text from the comment logger. std::string GetComments() const { if (!logger_) return ""; return logger_->GetText(); } private: // Creates and returns validator options. Return value owned by the caller. static spv_validator_options GetValidatorOptions( spv_const_markv_encoder_options) { return spvValidatorOptionsCreate(); } // Writes a single word to bit stream. |type| determines if the word is // encoded and how. void EncodeOperandWord(spv_operand_type_t type, uint32_t word) { const size_t chunk_length = GetOperandVariableWidthChunkLength(type); if (chunk_length) { writer_.WriteVariableWidthU32(word, chunk_length); } else { writer_.WriteUnencoded(word); } } // Returns id index and updates move-to-front. // Index is uint16 as SPIR-V module is guaranteed to have no more than 65535 // instructions. 
uint16_t GetIdIndex(uint32_t id) { if (all_known_ids_.count(id)) { uint16_t index = 0; for (auto it = move_to_front_ids_.begin(); it != move_to_front_ids_.end(); ++it) { if (*it == id) { if (index != 0) { move_to_front_ids_.erase(it); move_to_front_ids_.push_front(id); } return index; } ++index; } assert(0 && "Id not found in move_to_front_ids_"); return 0; } else { all_known_ids_.insert(id); move_to_front_ids_.push_front(id); return static_cast(move_to_front_ids_.size() - 1); } } void AddByteBreakIfAgreed() { if (!ShouldByteBreak(writer_.GetNumBits())) return; if (logger_) { logger_->AppendWhitespaces(kCommentNumWhitespaces); logger_->AppendText("ByteBreak:"); } writer_.WriteBits(0, GetNumBitsToNextByte(writer_.GetNumBits())); } // Encodes a literal number operand and writes it to the bit stream. void EncodeLiteralNumber(const Instruction& instruction, const spv_parsed_operand_t& operand); spv_const_markv_encoder_options options_; // Bit stream where encoded instructions are written. BitWriterWord64 writer_; // If not nullptr, encoder will write comments. std::unique_ptr logger_; // If not nullptr, disassembled instruction lines will be written to comments. // Format: \n separated instruction lines, no header. std::unique_ptr disassembly_; // All ids which were previosly encountered in the module. std::unordered_set all_known_ids_; }; // Decodes MARK-V buffers written by MarkvEncoder. class MarkvDecoder : public MarkvCodecBase { public: MarkvDecoder(spv_const_context context, const uint8_t* markv_data, size_t markv_size_bytes, spv_const_markv_decoder_options options) : MarkvCodecBase(context, GetValidatorOptions(options)), options_(options), reader_(markv_data, markv_size_bytes) { (void) options_; vstate_.setIdBound(1); parsed_operands_.reserve(25); } // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|. // Can be called only once. Fails if data of wrong format or ends prematurely, // of if validation fails. 
spv_result_t DecodeModule(std::vector* spirv_binary); private: // Describes the format of a typed literal number. struct NumberType { spv_number_kind_t type; uint32_t bit_width; }; // Creates and returns validator options. Return value owned by the caller. static spv_validator_options GetValidatorOptions( spv_const_markv_decoder_options) { return spvValidatorOptionsCreate(); } // Reads a single word from bit stream. |type| determines if the word needs // to be decoded and how. Returns false if read fails. bool DecodeOperandWord(spv_operand_type_t type, uint32_t* word) { const size_t chunk_length = GetOperandVariableWidthChunkLength(type); if (chunk_length) { return reader_.ReadVariableWidthU32(word, chunk_length); } else { return reader_.ReadUnencoded(word); } } // Fetches the id from the move-to-front list and moves it to front. uint32_t GetIdAndMoveToFront(uint16_t index) { if (index >= move_to_front_ids_.size()) { // Issue new id. const uint32_t id = vstate_.getIdBound(); move_to_front_ids_.push_front(id); vstate_.setIdBound(id + 1); return id; } else { if (index == 0) return move_to_front_ids_.front(); // Iterate to index. auto it = move_to_front_ids_.begin(); for (size_t i = 0; i < index; ++i) ++it; const uint32_t id = *it; move_to_front_ids_.erase(it); move_to_front_ids_.push_front(id); return id; } } // Decodes id index and fetches the id from move-to-front list. bool DecodeId(uint32_t* id) { uint16_t index = 0; if (!reader_.ReadVariableWidthU16(&index, model_->id_index_chunk_length())) return false; *id = GetIdAndMoveToFront(index); return true; } bool ReadToByteBreakIfAgreed() { if (!ShouldByteBreak(reader_.GetNumReadBits())) return true; uint64_t bits = 0; if (!reader_.ReadBits(&bits, GetNumBitsToNextByte(reader_.GetNumReadBits()))) return false; if (bits != 0) return false; return true; } // Reads a literal number as it is described in |operand| from the bit stream, // decodes and writes it to spirv_. 
spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand); // Reads instruction from bit stream, decodes and validates it. // Decoded instruction is valid until the next call of DecodeInstruction(). spv_result_t DecodeInstruction(spv_parsed_instruction_t* inst); // Read operand from the stream decodes and validates it. spv_result_t DecodeOperand(size_t instruction_offset, size_t operand_offset, spv_parsed_instruction_t* inst, const spv_operand_type_t type, spv_operand_pattern_t* expected_operands, bool read_result_id); // Records the numeric type for an operand according to the type information // associated with the given non-zero type Id. This can fail if the type Id // is not a type Id, or if the type Id does not reference a scalar numeric // type. On success, return SPV_SUCCESS and populates the num_words, // number_kind, and number_bit_width fields of parsed_operand. spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand, uint32_t type_id); // Records the number type for the given instruction, if that // instruction generates a type. For types that aren't scalar numbers, // record something with number kind SPV_NUMBER_NONE. void RecordNumberType(const spv_parsed_instruction_t& inst); spv_const_markv_decoder_options options_; // Temporary sink where decoded SPIR-V words are written. Once it contains the // entire module, the container is moved and returned. std::vector spirv_; // Bit stream containing encoded data. BitReaderWord64 reader_; // Temporary storage for operands of the currently parsed instruction. // Valid until next DecodeInstruction call. std::vector parsed_operands_; // Maps a result ID to its type ID. By convention: // - a result ID that is a type definition maps to itself. // - a result ID without a type maps to 0. (E.g. for OpLabel) std::unordered_map id_to_type_id_; // Maps a type ID to its number type description. 
std::unordered_map type_id_to_number_type_info_; // Maps an ExtInstImport id to the extended instruction type. std::unordered_map import_id_to_ext_inst_type_; }; void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction, const spv_parsed_operand_t& operand) { if (operand.number_bit_width == 32) { const uint32_t word = instruction.word(operand.offset); if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { writer_.WriteVariableWidthU32(word, model_->u32_chunk_length()); } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int32_t val = 0; std::memcpy(&val, &word, 4); writer_.WriteVariableWidthS32(val, model_->s32_chunk_length(), model_->s32_block_exponent()); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { writer_.WriteUnencoded(word); } else { assert(0); } } else if (operand.number_bit_width == 16) { const uint16_t word = static_cast(instruction.word(operand.offset)); if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { writer_.WriteVariableWidthU16(word, model_->u16_chunk_length()); } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int16_t val = 0; std::memcpy(&val, &word, 2); writer_.WriteVariableWidthS16(val, model_->s16_chunk_length(), model_->s16_block_exponent()); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { // TODO(atgoo@github.com) Write only 16 bits. 
writer_.WriteUnencoded(word); } else { assert(0); } } else { assert(operand.number_bit_width == 64); const uint64_t word = uint64_t(instruction.word(operand.offset)) | (uint64_t(instruction.word(operand.offset + 1)) << 32); if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { writer_.WriteVariableWidthU64(word, model_->u64_chunk_length()); } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int64_t val = 0; std::memcpy(&val, &word, 8); writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(), model_->s64_block_exponent()); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { writer_.WriteUnencoded(word); } else { assert(0); } } } spv_result_t MarkvEncoder::EncodeInstruction( const spv_parsed_instruction_t& inst) { const spv_result_t validation_result = UpdateValidationState(inst); if (validation_result != SPV_SUCCESS) return validation_result; bool result_id_was_forward_declared = false; if (all_known_ids_.count(inst.result_id)) { // Result id of the instruction was forward declared. // Write a service opcode to signal this to the decoder. writer_.WriteVariableWidthU32(kMarkvOpNextInstructionEncodesResultId, model_->opcode_chunk_length()); result_id_was_forward_declared = true; } const Instruction& instruction = GetCurrentInstruction(); const auto& operands = instruction.operands(); LogDisassemblyInstruction(); // Write opcode. writer_.WriteVariableWidthU32(inst.opcode, model_->opcode_chunk_length()); if (!OpcodeHasFixedNumberOfOperands(SpvOp(inst.opcode))) { // If the opcode has a variable number of operands, encode the number of // operands with the instruction. if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces); writer_.WriteVariableWidthU16(inst.num_operands, model_->num_operands_chunk_length()); } // Write operands. for (const auto& operand : operands) { if (operand.type == SPV_OPERAND_TYPE_RESULT_ID && !result_id_was_forward_declared) { // Register the id, but don't encode it. 
GetIdIndex(instruction.word(operand.offset)); continue; } if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces); if (operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) { EncodeLiteralNumber(instruction, operand); } else if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) { const char* src = reinterpret_cast(&instruction.words()[operand.offset]); const size_t length = spv_strnlen_s(src, operand.num_words * 4); if (length == operand.num_words * 4) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to find terminal character of literal string"; for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]); } else if (spvIsIdType(operand.type)) { const uint16_t id_index = GetIdIndex(instruction.word(operand.offset)); writer_.WriteVariableWidthU16(id_index, model_->id_index_chunk_length()); } else { for (int i = 0; i < operand.num_words; ++i) { const uint32_t word = instruction.word(operand.offset + i); EncodeOperandWord(operand.type, word); } } } AddByteBreakIfAgreed(); if (logger_) { logger_->NewLine(); logger_->NewLine(); } return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeLiteralNumber( const spv_parsed_operand_t& operand) { if (operand.number_bit_width == 32) { uint32_t word = 0; if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { if (!reader_.ReadVariableWidthU32(&word, model_->u32_chunk_length())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U32"; } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int32_t val = 0; if (!reader_.ReadVariableWidthS32(&val, model_->s32_chunk_length(), model_->s32_block_exponent())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S32"; std::memcpy(&word, &val, 4); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { if (!reader_.ReadUnencoded(&word)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F32"; } else { assert(0); } spirv_.push_back(word); } else if (operand.number_bit_width == 16) { uint32_t word = 
0; if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { uint16_t val = 0; if (!reader_.ReadVariableWidthU16(&val, model_->u16_chunk_length())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U16"; word = val; } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int16_t val = 0; if (!reader_.ReadVariableWidthS16(&val, model_->s16_chunk_length(), model_->s16_block_exponent())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S16"; // Int16 is stored as int32 in SPIR-V, not as bits. int32_t val32 = val; std::memcpy(&word, &val32, 4); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { uint16_t word16 = 0; if (!reader_.ReadUnencoded(&word16)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F16"; word = word16; } else { assert(0); } spirv_.push_back(word); } else { assert(operand.number_bit_width == 64); uint64_t word = 0; if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64"; } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int64_t val = 0; if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(), model_->s64_block_exponent())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64"; std::memcpy(&word, &val, 8); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { if (!reader_.ReadUnencoded(&word)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64"; } else { assert(0); } spirv_.push_back(static_cast(word)); spirv_.push_back(static_cast(word >> 32)); } return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeModule(std::vector* spirv_binary) { const bool header_read_success = reader_.ReadUnencoded(&header_.magic_number) && reader_.ReadUnencoded(&header_.markv_version) && reader_.ReadUnencoded(&header_.markv_model) && reader_.ReadUnencoded(&header_.markv_length_in_bits) 
&& reader_.ReadUnencoded(&header_.spirv_version) && reader_.ReadUnencoded(&header_.spirv_generator); if (!header_read_success) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header"; assert(header_.magic_number == kMarkvMagicNumber); assert(header_.markv_length_in_bits > 0); if (header_.magic_number != kMarkvMagicNumber) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary has incorrect magic number"; // TODO(atgoo@github.com): Print version strings. if (header_.markv_version != GetMarkvVersion()) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary and the codec have different versions"; spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic. spirv_.resize(5, 0); spirv_[0] = kSpirvMagicNumber; spirv_[1] = header_.spirv_version; spirv_[2] = header_.spirv_generator; while (reader_.GetNumReadBits() < header_.markv_length_in_bits) { spv_parsed_instruction_t inst = {}; const spv_result_t decode_result = DecodeInstruction(&inst); if (decode_result != SPV_SUCCESS) return decode_result; const spv_result_t validation_result = UpdateValidationState(inst); if (validation_result != SPV_SUCCESS) return validation_result; } if (reader_.GetNumReadBits() != header_.markv_length_in_bits || !reader_.OnlyZeroesLeft()) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary has wrong stated bit length " << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits; } // Decoding of the module is finished, validation state should have correct // id bound. spirv_[3] = vstate_.getIdBound(); *spirv_binary = std::move(spirv_); return SPV_SUCCESS; } // TODO(atgoo@github.com): The implementation borrows heavily from // Parser::parseOperand. // Consider coupling them together in some way once MARK-V codec is more mature. // For now it's better to keep the code independent for experimentation // purposes. 
spv_result_t MarkvDecoder::DecodeOperand( size_t instruction_offset, size_t operand_offset, spv_parsed_instruction_t* inst, const spv_operand_type_t type, spv_operand_pattern_t* expected_operands, bool read_result_id) { const SpvOp opcode = static_cast(inst->opcode); spv_parsed_operand_t parsed_operand; memset(&parsed_operand, 0, sizeof(parsed_operand)); assert((operand_offset >> 16) == 0); parsed_operand.offset = static_cast(operand_offset); parsed_operand.type = type; // Set default values, may be updated later. parsed_operand.number_kind = SPV_NUMBER_NONE; parsed_operand.number_bit_width = 0; const size_t first_word_index = spirv_.size(); switch (type) { case SPV_OPERAND_TYPE_TYPE_ID: { if (!DecodeId(&inst->type_id)) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read type_id"; } if (inst->type_id == 0) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded type_id is 0"; spirv_.push_back(inst->type_id); vstate_.setIdBound(std::max(vstate_.getIdBound(), inst->type_id + 1)); break; } case SPV_OPERAND_TYPE_RESULT_ID: { if (read_result_id) { if (!DecodeId(&inst->result_id)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read result_id"; } else { inst->result_id = vstate_.getIdBound(); vstate_.setIdBound(inst->result_id + 1); move_to_front_ids_.push_front(inst->result_id); } spirv_.push_back(inst->result_id); // Save the result ID to type ID mapping. // In the grammar, type ID always appears before result ID. // A regular value maps to its type. Some instructions (e.g. OpLabel) // have no type Id, and will map to 0. The result Id for a // type-generating instruction (e.g. OpTypeInt) maps to itself. auto insertion_result = id_to_type_id_.emplace( inst->result_id, spvOpcodeGeneratesType(opcode) ? 
inst->result_id : inst->type_id); if(!insertion_result.second) { return vstate_.diag(SPV_ERROR_INVALID_ID) << "Unexpected behavior: id->type_id pair was already registered"; } break; } case SPV_OPERAND_TYPE_ID: case SPV_OPERAND_TYPE_OPTIONAL_ID: case SPV_OPERAND_TYPE_SCOPE_ID: case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { uint32_t id = 0; if (!DecodeId(&id)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id"; if (id == 0) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0"; spirv_.push_back(id); vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1)); if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) { parsed_operand.type = SPV_OPERAND_TYPE_ID; if (opcode == SpvOpExtInst && parsed_operand.offset == 3) { // The current word is the extended instruction set id. // Set the extended instruction set type for the current instruction. auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { return vstate_.diag(SPV_ERROR_INVALID_ID) << "OpExtInst set id " << id << " does not reference an OpExtInstImport result Id"; } inst->ext_inst_type = ext_inst_type_iter->second; } } break; } case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { uint32_t word = 0; if (!DecodeOperandWord(type, &word)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read enum"; spirv_.push_back(word); assert(SpvOpExtInst == opcode); assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE); spv_ext_inst_desc ext_inst; if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid extended instruction number: " << word; spvPushOperandTypes(ext_inst->operandTypes, expected_operands); break; } case SPV_OPERAND_TYPE_LITERAL_INTEGER: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { // These are regular single-word literal integer operands. // Post-parsing validation should check the range of the parsed value. 
parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; // It turns out they are always unsigned integers! parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT; parsed_operand.number_bit_width = 32; uint32_t word = 0; if (!DecodeOperandWord(type, &word)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal integer"; spirv_.push_back(word); break; } case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; if (opcode == SpvOpSwitch) { // The literal operands have the same type as the value // referenced by the selector Id. const uint32_t selector_id = spirv_.at(instruction_offset + 1); const auto type_id_iter = id_to_type_id_.find(selector_id); if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid OpSwitch: selector id " << selector_id << " has no type"; } uint32_t type_id = type_id_iter->second; if (selector_id == type_id) { // Recall that by convention, a result ID that is a type definition // maps to itself. return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid OpSwitch: selector id " << selector_id << " is a type, not a value"; } if (auto error = SetNumericTypeInfoForType(&parsed_operand, type_id)) return error; if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT && parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid OpSwitch: selector id " << selector_id << " is not a scalar integer"; } } else { assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); // The literal number type is determined by the type Id for the // constant. 
assert(inst->type_id); if (auto error = SetNumericTypeInfoForType(&parsed_operand, inst->type_id)) return error; } if (auto error = DecodeLiteralNumber(parsed_operand)) return error; break; case SPV_OPERAND_TYPE_LITERAL_STRING: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING; std::vector str; // The loop is expected to terminate once we encounter '\0' or exhaust // the bit stream. while (true) { char ch = 0; if (!reader_.ReadUnencoded(&ch)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal string"; str.push_back(ch); if (ch == '\0') break; } while (str.size() % 4 != 0) str.push_back('\0'); spirv_.resize(spirv_.size() + str.size() / 4); std::memcpy(&spirv_[first_word_index], str.data(), str.size()); if (SpvOpExtInstImport == opcode) { // Record the extended instruction type for the ID for this import. // There is only one string literal argument to OpExtInstImport, // so it's sufficient to guard this just on the opcode. const spv_ext_inst_type_t ext_inst_type = spvExtInstImportTypeGet(str.data()); if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid extended instruction import '" << str.data() << "'"; } // We must have parsed a valid result ID. It's a condition // of the grammar, and we only accept non-zero result Ids. 
assert(inst->result_id); const bool inserted = import_id_to_ext_inst_type_.emplace( inst->result_id, ext_inst_type).second; (void)inserted; assert(inserted); } break; } case SPV_OPERAND_TYPE_CAPABILITY: case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: case SPV_OPERAND_TYPE_EXECUTION_MODEL: case SPV_OPERAND_TYPE_ADDRESSING_MODEL: case SPV_OPERAND_TYPE_MEMORY_MODEL: case SPV_OPERAND_TYPE_EXECUTION_MODE: case SPV_OPERAND_TYPE_STORAGE_CLASS: case SPV_OPERAND_TYPE_DIMENSIONALITY: case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: case SPV_OPERAND_TYPE_LINKAGE_TYPE: case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: case SPV_OPERAND_TYPE_DECORATION: case SPV_OPERAND_TYPE_BUILT_IN: case SPV_OPERAND_TYPE_GROUP_OPERATION: case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: { // A single word that is a plain enum value. uint32_t word = 0; if (!DecodeOperandWord(type, &word)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read enum"; spirv_.push_back(word); // Map an optional operand type to its corresponding concrete type. if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; spv_operand_desc entry; if (grammar_.lookupOperand(type, word, &entry)) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid " << spvOperandTypeStr(parsed_operand.type) << " operand: " << word; } // Prepare to accept operands to this operand, if needed. 
spvPushOperandTypes(entry->operandTypes, expected_operands); break; } case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: case SPV_OPERAND_TYPE_FUNCTION_CONTROL: case SPV_OPERAND_TYPE_LOOP_CONTROL: case SPV_OPERAND_TYPE_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: case SPV_OPERAND_TYPE_SELECTION_CONTROL: { // This operand is a mask. uint32_t word = 0; if (!DecodeOperandWord(type, &word)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read " << spvOperandTypeStr(type) << " for " << spvOpcodeString(SpvOp(inst->opcode)); spirv_.push_back(word); // Map an optional operand type to its corresponding concrete type. if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) parsed_operand.type = SPV_OPERAND_TYPE_IMAGE; else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; // Check validity of set mask bits. Also prepare for operands for those // masks if they have any. To get operand order correct, scan from // MSB to LSB since we can only prepend operands to a pattern. // The only case in the grammar where you have more than one mask bit // having an operand is for image operands. See SPIR-V 3.14 Image // Operands. uint32_t remaining_word = word; for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { if (remaining_word & mask) { spv_operand_desc entry; if (grammar_.lookupOperand(type, mask, &entry)) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid " << spvOperandTypeStr(parsed_operand.type) << " operand: " << word << " has invalid mask component " << mask; } remaining_word ^= mask; spvPushOperandTypes(entry->operandTypes, expected_operands); } } if (word == 0) { // An all-zeroes mask *might* also be valid. spv_operand_desc entry; if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { // Prepare for its operands, if any. 
spvPushOperandTypes(entry->operandTypes, expected_operands); } } break; } default: return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Internal error: Unhandled operand type: " << type; } parsed_operand.num_words = uint16_t(spirv_.size() - first_word_index); assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(parsed_operand.type)); assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(parsed_operand.type)); parsed_operands_.push_back(parsed_operand); return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) { parsed_operands_.clear(); const size_t instruction_offset = spirv_.size(); bool read_result_id = false; while (true) { uint32_t word = 0; if (!reader_.ReadVariableWidthU32(&word, model_->opcode_chunk_length())) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read opcode of instruction"; } if (word >= kMarkvFirstOpcode) { if (word == kMarkvOpNextInstructionEncodesResultId) { read_result_id = true; } else { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Encountered unknown MARK-V opcode"; } } else { inst->opcode = static_cast(word); break; } } const SpvOp opcode = static_cast(inst->opcode); // Opcode/num_words placeholder, the word will be filled in later. 
spirv_.push_back(0); spv_opcode_desc opcode_desc; if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode"; } spv_operand_pattern_t expected_operands; expected_operands.reserve(opcode_desc->numTypes); for (auto i = 0; i < opcode_desc->numTypes; i++) expected_operands.push_back(opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); if (!OpcodeHasFixedNumberOfOperands(opcode)) { if (!reader_.ReadVariableWidthU16(&inst->num_operands, model_->num_operands_chunk_length())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read num_operands of instruction"; } else { inst->num_operands = static_cast(expected_operands.size()); } for (size_t operand_index = 0; operand_index < static_cast(inst->num_operands); ++operand_index) { assert(!expected_operands.empty()); const spv_operand_type_t type = spvTakeFirstMatchableOperand(&expected_operands); const size_t operand_offset = spirv_.size() - instruction_offset; const spv_result_t decode_result = DecodeOperand(instruction_offset, operand_offset, inst, type, &expected_operands, read_result_id); if (decode_result != SPV_SUCCESS) return decode_result; } assert(inst->num_operands == parsed_operands_.size()); // Only valid while spirv_ and parsed_operands_ remain unchanged. inst->words = &spirv_[instruction_offset]; inst->operands = parsed_operands_.empty() ? 
nullptr : parsed_operands_.data(); inst->num_words = static_cast(spirv_.size() - instruction_offset); spirv_[instruction_offset] = spvOpcodeMake(inst->num_words, SpvOp(inst->opcode)); assert(inst->num_words == std::accumulate( parsed_operands_.begin(), parsed_operands_.end(), 1, [](int num_words, const spv_parsed_operand_t& operand) { return num_words += operand.num_words; }) && "num_words in instruction doesn't correspond to the sum of num_words" "in the operands"); RecordNumberType(*inst); if (!ReadToByteBreakIfAgreed()) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break"; return SPV_SUCCESS; } spv_result_t MarkvDecoder::SetNumericTypeInfoForType( spv_parsed_operand_t* parsed_operand, uint32_t type_id) { assert(type_id != 0); auto type_info_iter = type_id_to_number_type_info_.find(type_id); if (type_info_iter == type_id_to_number_type_info_.end()) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Type Id " << type_id << " is not a type"; } const NumberType& info = type_info_iter->second; if (info.type == SPV_NUMBER_NONE) { // This is a valid type, but for something other than a scalar number. return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Type Id " << type_id << " is not a scalar numeric type"; } parsed_operand->number_kind = info.type; parsed_operand->number_bit_width = info.bit_width; // Round up the word count. parsed_operand->num_words = static_cast((info.bit_width + 31) / 32); return SPV_SUCCESS; } void MarkvDecoder::RecordNumberType(const spv_parsed_instruction_t& inst) { const SpvOp opcode = static_cast(inst.opcode); if (spvOpcodeGeneratesType(opcode)) { NumberType info = {SPV_NUMBER_NONE, 0}; if (SpvOpTypeInt == opcode) { info.bit_width = inst.words[inst.operands[1].offset]; info.type = inst.words[inst.operands[2].offset] ? 
SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT; } else if (SpvOpTypeFloat == opcode) { info.bit_width = inst.words[inst.operands[1].offset]; info.type = SPV_NUMBER_FLOATING; } // The *result* Id of a type generating instruction is the type Id. type_id_to_number_type_info_[inst.result_id] = info; } } spv_result_t EncodeHeader( void* user_data, spv_endianness_t endian, uint32_t magic, uint32_t version, uint32_t generator, uint32_t id_bound, uint32_t schema) { MarkvEncoder* encoder = reinterpret_cast(user_data); return encoder->EncodeHeader( endian, magic, version, generator, id_bound, schema); } spv_result_t EncodeInstruction( void* user_data, const spv_parsed_instruction_t* inst) { MarkvEncoder* encoder = reinterpret_cast(user_data); return encoder->EncodeInstruction(*inst); } } // namespace spv_result_t spvSpirvToMarkv(spv_const_context context, const uint32_t* spirv_words, const size_t spirv_num_words, spv_const_markv_encoder_options options, spv_markv_binary* markv_binary, spv_text* comments, spv_diagnostic* diagnostic) { spv_context_t hijack_context = *context; if (diagnostic) { *diagnostic = nullptr; libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); } spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words}; spv_endianness_t endian; spv_position_t position = {}; if (spvBinaryEndianness(&spirv_binary, &endian)) { return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V magic number."; } spv_header_t header; if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) { return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V header."; } MarkvEncoder encoder(&hijack_context, options); if (comments) { encoder.CreateCommentsLogger(); spv_text text = nullptr; if (spvBinaryToText(&hijack_context, spirv_words, spirv_num_words, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr) != SPV_SUCCESS) { return 
libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Failed to disassemble SPIR-V binary."; } assert(text); encoder.SetDisassembly(std::string(text->str, text->length)); spvTextDestroy(text); } if (spvBinaryParse( &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader, EncodeInstruction, diagnostic) != SPV_SUCCESS) { return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Unable to encode to MARK-V."; } if (comments) *comments = CreateSpvText(encoder.GetComments()); *markv_binary = encoder.GetMarkvBinary(); return SPV_SUCCESS; } spv_result_t spvMarkvToSpirv(spv_const_context context, const uint8_t* markv_data, size_t markv_size_bytes, spv_const_markv_decoder_options options, spv_binary* spirv_binary, spv_text* /* comments */, spv_diagnostic* diagnostic) { spv_position_t position = {}; spv_context_t hijack_context = *context; if (diagnostic) { *diagnostic = nullptr; libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); } MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options); std::vector words; if (decoder.DecodeModule(&words) != SPV_SUCCESS) { return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Unable to decode MARK-V."; } assert(!words.empty()); *spirv_binary = new spv_binary_t(); (*spirv_binary)->code = new uint32_t[words.size()]; (*spirv_binary)->wordCount = words.size(); std::memcpy((*spirv_binary)->code, words.data(), 4 * words.size()); return SPV_SUCCESS; } void spvMarkvBinaryDestroy(spv_markv_binary binary) { if (!binary) return; delete[] binary->data; delete binary; } spv_markv_encoder_options spvMarkvEncoderOptionsCreate() { return new spv_markv_encoder_options_t; } void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) { delete options; } spv_markv_decoder_options spvMarkvDecoderOptionsCreate() { return new spv_markv_decoder_options_t; } void 
spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) { delete options; }