// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Contains
//   - SPIR-V to MARK-V encoder
//   - MARK-V to SPIR-V decoder
//
// MARK-V is a compression format for SPIR-V binaries. It strips away
// non-essential information (such as result ids which can be regenerated) and
// uses various bit reduction techniques to reduce the size of the binary.
//
// MarkvModel is a flatbuffers object containing a set of rules defining how
// compression/decompression is done (coding schemes, dictionaries).
// NOTE(review): in this copy of the file the targets of the first eleven
// #include directives are missing (angle-bracket contents appear to have been
// stripped). They presumably named standard headers; restore them from
// upstream — TODO confirm.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "spirv/1.2/GLSL.std.450.h"
#include "spirv/1.2/OpenCL.std.h"
#include "spirv/1.2/spirv.h"

#include "binary.h"
#include "diagnostic.h"
#include "enum_string_mapping.h"
#include "extensions.h"
#include "ext_inst.h"
#include "id_descriptor.h"
#include "instruction.h"
#include "markv_autogen.h"
#include "opcode.h"
#include "operand.h"
#include "spirv-tools/libspirv.h"
#include "spirv-tools/markv.h"
#include "spirv_endian.h"
#include "spirv_validator_options.h"
#include "util/bit_stream.h"
#include "util/huffman_codec.h"
#include "util/move_to_front.h"
#include "util/parse_number.h"
#include "validate.h"
#include "val/instruction.h"
#include "val/validation_state.h"

using libspirv::IdDescriptorCollection;
using libspirv::Instruction;
using libspirv::ValidationState_t;
using spvtools::ValidateInstructionAndUpdateValidationState;
using spvutils::BitReaderWord64;
using spvutils::BitWriterWord64;
using spvutils::HuffmanCodec;
// NOTE(review): the two aliases below look like they lost their template
// arguments in this copy — TODO confirm against upstream.
using MoveToFront = spvutils::MoveToFront;
using MultiMoveToFront = spvutils::MultiMoveToFront;

// Encoder/decoder option structs. Both are currently empty: neither codec
// reads any options from them.
struct spv_markv_encoder_options_t {
};

struct spv_markv_decoder_options_t {
};

// Everything from here on is file-local (anonymous namespace).
namespace {

// SPIR-V magic number (alias of SpvMagicNumber).
const uint32_t kSpirvMagicNumber = SpvMagicNumber;
// Magic number stored in the MARK-V header (MarkvHeader::magic_number).
const uint32_t kMarkvMagicNumber = 0x07230303;

// Handles for move-to-front sequences. Enums which end with "Begin" define
// handle spaces which start at that value and span 16 or 32 bit wide.
// A handle inside a space is formed as "Begin + value" (type id, opcode,
// size, ...). Consecutive 16-bit spaces are 0x10000 apart.
enum : uint64_t {
  kMtfNone = 0,
  // All ids.
  kMtfAll,
  // All forward declared ids.
  kMtfForwardDeclared,
  // All type ids except for those generated by OpTypeFunction.
  kMtfTypeNonFunction,
  // All labels.
  kMtfLabel,
  // All ids created by instructions which had type_id.
  kMtfObject,
  // All types generated by OpTypeFloat, OpTypeInt, OpTypeBool.
  kMtfTypeScalar,
  // All composite types.
  kMtfTypeComposite,
  // Boolean type or any vector type of it.
  kMtfTypeBoolScalarOrVector,
  // All float types or any vector of floats type.
  kMtfTypeFloatScalarOrVector,
  // All int types or any vector of ints type.
  kMtfTypeIntScalarOrVector,
  // All types declared as return types in OpTypeFunction.
  kMtfTypeReturnedByFunction,
  // All object ids which are integer constants.
  kMtfConstInteger,
  // All composite objects.
  kMtfComposite,
  // All bool objects or vectors of bools.
  kMtfBoolScalarOrVector,
  // All float objects or vectors of float.
  kMtfFloatScalarOrVector,
  // All int objects or vectors of int.
  kMtfIntScalarOrVector,
  // All pointer types which point to composites.
  kMtfTypePointerToComposite,
  // Used by EncodeMtfRankHuffman.
  kMtfGenericNonZeroRank,
  // Handle space for ids of specific type.
  kMtfIdOfTypeBegin = 0x10000,
  // Handle space for ids generated by specific opcode.
  // NOTE(review): unlike the other handle spaces, this name lacks the
  // "Begin" suffix, although it is used the same way.
  kMtfIdGeneratedByOpcode = 0x20000,
  // Handle space for ids of objects with type generated by specific opcode.
  kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000,
  // All vectors of specific component type.
  kMtfVectorOfComponentTypeBegin = 0x40000,
  // All vector types of specific size.
  kMtfTypeVectorOfSizeBegin = 0x50000,
  // All pointer types to specific type.
  kMtfPointerToTypeBegin = 0x60000,
  // All function types which return specific type.
  kMtfFunctionTypeWithReturnTypeBegin = 0x70000,
  // All function objects which return specific type.
  kMtfFunctionWithReturnTypeBegin = 0x80000,
  // All float vectors of specific size.
  kMtfFloatVectorOfSizeBegin = 0x90000,
  // Id descriptor space (32-bit).
  kMtfIdDescriptorSpaceBegin = 0x100000000,
};

// Used by "presumed index" technique which does special treatment of integer
// constants no greater than this value.
const uint32_t kMarkvMaxPresumedAccessIndex = 31;

// Signals that the value is not in the coding scheme and a fallback method
// needs to be used. (GetMarkvNonOfTheAbove is declared elsewhere, presumably
// in markv_autogen.h — TODO confirm.)
const uint64_t kMarkvNoneOfTheAbove = GetMarkvNonOfTheAbove();

// Mtf ranks smaller than this are encoded with Huffman coding; larger ranks
// are presumably written by value.
const uint32_t kMtfSmallestRankEncodedByValue = 10;

// Signals that the mtf rank is too large to be encoded with Huffman.
// NOTE(review): template argument lists in this block (e.g.
// `std::numeric_limits::max()`, `std::map>>`, `HuffmanCodec`) appear to have
// been stripped in this copy of the file; restore them from upstream — TODO
// confirm.
const uint32_t kMtfRankEncodedByValueSignal = std::numeric_limits::max();

// Presumably the indentation width used when writing encoder comments — TODO
// confirm at the use site.
const size_t kCommentNumWhitespaces = 2;

// Presumably passed as |byte_break_if_less_than| to AddByteBreak /
// ReadToByteBreak after each instruction — TODO confirm at the use site.
const size_t kByteBreakAfterInstIfLessThanUntilNextByte = 8;

// Returns a set of mtf rank codecs based on a plausible hand-coded
// distribution.
// Two codecs are built: one keyed by kMtfAll (ranks 0-9) and one keyed by
// kMtfGenericNonZeroRank (ranks 1-9). Both also contain
// kMtfRankEncodedByValueSignal as an escape symbol for ranks outside the
// table. The second value of each pair is the symbol's relative weight.
std::map>> GetMtfHuffmanCodecs() {
  std::map>> codecs;

  std::unique_ptr> codec;

  codec.reset(new HuffmanCodec(std::map({
    { 0, 5 },
    { 1, 40 },
    { 2, 10 },
    { 3, 5 },
    { 4, 5 },
    { 5, 5 },
    { 6, 3 },
    { 7, 3 },
    { 8, 3 },
    { 9, 3 },
    { kMtfRankEncodedByValueSignal, 10 },
  })));
  codecs.emplace(kMtfAll, std::move(codec));

  codec.reset(new HuffmanCodec(std::map({
    { 1, 50 },
    { 2, 20 },
    { 3, 5 },
    { 4, 5 },
    { 5, 2 },
    { 6, 1 },
    { 7, 1 },
    { 8, 1 },
    { 9, 1 },
    { kMtfRankEncodedByValueSignal, 10 },
  })));
  codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));

  return codecs;
}

// Encoding/decoding model containing various constants and codecs.
// All codec tables are built once, in the constructor, from
// GetMtfHuffmanCodecs() and the Get*HuffmanCodecs()/Get*Hist() generators.
class MarkvModel {
 public:
  MarkvModel()
      : mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
        opcode_and_num_operands_huffman_codec_(GetOpcodeAndNumOperandsHist()),
        opcode_and_num_operands_markov_huffman_codecs_(
            GetOpcodeAndNumOperandsMarkovHuffmanCodecs()),
        non_id_word_huffman_codecs_(GetNonIdWordHuffmanCodecs()),
        id_descriptor_huffman_codecs_(GetIdDescriptorHuffmanCodecs()),
        literal_string_huffman_codecs_(GetLiteralStringHuffmanCodecs()) {}

  // Hard-coded parameters for variable-width encoding of the corresponding
  // values: chunk lengths (in bits) and, for signed types, block exponents.
  size_t opcode_chunk_length() const { return 7; }
  size_t num_operands_chunk_length() const { return 3; }
  size_t mtf_rank_chunk_length() const { return 5; }

  size_t u16_chunk_length() const { return 4; }
  size_t s16_chunk_length() const { return 4; }
  size_t s16_block_exponent() const { return 6; }

  size_t u32_chunk_length() const { return 8; }
  size_t s32_chunk_length() const { return 8; }
  size_t s32_block_exponent() const { return 10; }

  size_t u64_chunk_length() const { return 8; }
  size_t s64_chunk_length() const { return 8; }
  size_t s64_block_exponent() const { return 10; }

  // Returns Huffman codec for ranks of the mtf with given |handle|.
  // Different mtfs can use different rank distributions.
  // May return nullptr if the codec doesn't exist.
  const HuffmanCodec* GetMtfHuffmanCodec(uint64_t handle) const {
    const auto it = mtf_huffman_codecs_.find(handle);
    if (it == mtf_huffman_codecs_.end())
      return nullptr;
    return it->second.get();
  }

  // Returns a codec for common opcode_and_num_operands words for the given
  // previous opcode. May return nullptr if the codec doesn't exist.
  // For SpvOpNop, which MarkvCodecBase::GetPrevOpcode returns when there is
  // no previous instruction, the base-rate codec is returned instead.
  const HuffmanCodec* GetOpcodeAndNumOperandsMarkovHuffmanCodec(
      uint32_t prev_opcode) const {
    if (prev_opcode == SpvOpNop)
      return &opcode_and_num_operands_huffman_codec_;

    const auto it =
        opcode_and_num_operands_markov_huffman_codecs_.find(prev_opcode);
    if (it == opcode_and_num_operands_markov_huffman_codecs_.end())
      return nullptr;
    return it->second.get();
  }

  // Returns a codec for common non-id words used for given operand slot.
  // Operand slot is defined by the opcode and the operand index.
  // May return nullptr if the codec doesn't exist.
  const HuffmanCodec* GetNonIdWordHuffmanCodec(
      uint32_t opcode, uint32_t operand_index) const {
    const auto it = non_id_word_huffman_codecs_.find(
        std::pair(opcode, operand_index));
    if (it == non_id_word_huffman_codecs_.end())
      return nullptr;
    return it->second.get();
  }

  // Returns a codec for common id descriptors used for given operand slot.
  // Operand slot is defined by the opcode and the operand index.
  // May return nullptr if the codec doesn't exist.
  const HuffmanCodec* GetIdDescriptorHuffmanCodec(
      uint32_t opcode, uint32_t operand_index) const {
    const auto it = id_descriptor_huffman_codecs_.find(
        std::pair(opcode, operand_index));
    if (it == id_descriptor_huffman_codecs_.end())
      return nullptr;
    return it->second.get();
  }

  // Returns a codec for common strings used by the given opcode.
  // The lookup is keyed by the opcode alone (not by operand index).
  // May return nullptr if the codec doesn't exist.
  const HuffmanCodec* GetLiteralStringHuffmanCodec(
      uint32_t opcode) const {
    const auto it = literal_string_huffman_codecs_.find(opcode);
    if (it == literal_string_huffman_codecs_.end())
      return nullptr;
    return it->second.get();
  }

 private:
  // Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
  // need to contain a different codec for every handle as most use one and the
  // same.
  std::map>> mtf_huffman_codecs_;

  // Huffman codec for base-rate of opcode_and_num_operands.
  HuffmanCodec opcode_and_num_operands_huffman_codec_;

  // Huffman codecs for opcode_and_num_operands. The map key is previous opcode.
  std::map>> opcode_and_num_operands_markov_huffman_codecs_;

  // Huffman codecs for non-id single-word operand values.
  // The map key is pair <opcode, operand_index>.
  std::map, std::unique_ptr>> non_id_word_huffman_codecs_;

  // Huffman codecs for id descriptors. The map key is pair
  // <opcode, operand_index>.
  std::map, std::unique_ptr>> id_descriptor_huffman_codecs_;

  // Huffman codecs for literal strings. The map key is the opcode of the
  // current instruction. This assumes, that there is no more than one literal
  // string operand per instruction, but would still work even if this is not
  // the case. Names and debug information strings are not collected.
  std::map>> literal_string_huffman_codecs_;
};

// Returns the shared default model. It is a function-local static, so it is
// constructed on the first call and reused afterwards.
const MarkvModel* GetDefaultModel() {
  static MarkvModel model;
  return &model;
}

// Returns chunk length used for variable length encoding of spirv operand
// words. Returns zero if operand type corresponds to potentially multiple
// words or a word which is not expected to profit from variable width encoding.
// Chunk length is selected based on the size of expected value.
// Most of these values will later be encoded with probability-based coding,
// but variable width integer coding is a good quick solution.
// TODO(atgoo@github.com): Put this in MarkvModel flatbuffer.
size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) { switch (type) { case SPV_OPERAND_TYPE_TYPE_ID: return 4; case SPV_OPERAND_TYPE_RESULT_ID: case SPV_OPERAND_TYPE_ID: case SPV_OPERAND_TYPE_SCOPE_ID: case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: return 8; case SPV_OPERAND_TYPE_LITERAL_INTEGER: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: return 6; case SPV_OPERAND_TYPE_CAPABILITY: return 6; case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: case SPV_OPERAND_TYPE_EXECUTION_MODEL: return 3; case SPV_OPERAND_TYPE_ADDRESSING_MODEL: case SPV_OPERAND_TYPE_MEMORY_MODEL: return 2; case SPV_OPERAND_TYPE_EXECUTION_MODE: return 6; case SPV_OPERAND_TYPE_STORAGE_CLASS: return 4; case SPV_OPERAND_TYPE_DIMENSIONALITY: case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: return 3; case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: return 2; case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: return 6; case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: case SPV_OPERAND_TYPE_LINKAGE_TYPE: case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: return 2; case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: return 3; case SPV_OPERAND_TYPE_DECORATION: case SPV_OPERAND_TYPE_BUILT_IN: return 6; case SPV_OPERAND_TYPE_GROUP_OPERATION: case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: return 2; case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: case SPV_OPERAND_TYPE_FUNCTION_CONTROL: case SPV_OPERAND_TYPE_LOOP_CONTROL: case SPV_OPERAND_TYPE_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: case SPV_OPERAND_TYPE_SELECTION_CONTROL: return 4; case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: return 6; default: return 0; } return 0; } // Returns true if the opcode has a fixed number of operands. May return a // false negative. bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) { switch (opcode) { // TODO(atgoo@github.com) This is not a complete list. 
case SpvOpNop: case SpvOpName: case SpvOpUndef: case SpvOpSizeOf: case SpvOpLine: case SpvOpNoLine: case SpvOpDecorationGroup: case SpvOpExtension: case SpvOpExtInstImport: case SpvOpMemoryModel: case SpvOpCapability: case SpvOpTypeVoid: case SpvOpTypeBool: case SpvOpTypeInt: case SpvOpTypeFloat: case SpvOpTypeVector: case SpvOpTypeMatrix: case SpvOpTypeSampler: case SpvOpTypeSampledImage: case SpvOpTypeArray: case SpvOpTypePointer: case SpvOpConstantTrue: case SpvOpConstantFalse: case SpvOpLabel: case SpvOpBranch: case SpvOpFunction: case SpvOpFunctionParameter: case SpvOpFunctionEnd: case SpvOpBitcast: case SpvOpCopyObject: case SpvOpTranspose: case SpvOpSNegate: case SpvOpFNegate: case SpvOpIAdd: case SpvOpFAdd: case SpvOpISub: case SpvOpFSub: case SpvOpIMul: case SpvOpFMul: case SpvOpUDiv: case SpvOpSDiv: case SpvOpFDiv: case SpvOpUMod: case SpvOpSRem: case SpvOpSMod: case SpvOpFRem: case SpvOpFMod: case SpvOpVectorTimesScalar: case SpvOpMatrixTimesScalar: case SpvOpVectorTimesMatrix: case SpvOpMatrixTimesVector: case SpvOpMatrixTimesMatrix: case SpvOpOuterProduct: case SpvOpDot: return true; default: break; } return false; } size_t GetNumBitsToNextByte(size_t bit_pos) { return (8 - (bit_pos % 8)) % 8; } // Defines and returns current MARK-V version. 
uint32_t GetMarkvVersion() { const uint32_t kVersionMajor = 1; const uint32_t kVersionMinor = 1; return kVersionMinor | (kVersionMajor << 16); } class CommentLogger { public: void AppendText(const std::string& str) { Append(str); use_delimiter_ = false; } void AppendTextNewLine(const std::string& str) { Append(str); Append("\n"); use_delimiter_ = false; } void AppendBitSequence(const std::string& str) { if (use_delimiter_) Append("-"); Append(str); use_delimiter_ = true; } void AppendWhitespaces(size_t num) { Append(std::string(num, ' ')); use_delimiter_ = false; } void NewLine() { Append("\n"); use_delimiter_ = false; } std::string GetText() const { return ss_.str(); } private: void Append(const std::string& str) { ss_ << str; // std::cerr << str; } std::stringstream ss_; // If true a delimiter will be appended before the next bit sequence. // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0. bool use_delimiter_ = false; }; // Creates spv_text object containing text from |str|. // The returned value is owned by the caller and needs to be destroyed with // spvTextDestroy. spv_text CreateSpvText(const std::string& str) { spv_text out = new spv_text_t(); assert(out); char* cstr = new char[str.length() + 1]; assert(cstr); std::strncpy(cstr, str.c_str(), str.length()); cstr[str.length()] = '\0'; out->str = cstr; out->length = str.length(); return out; } // Base class for MARK-V encoder and decoder. Contains common functionality // such as: // - Validator connection and validation state. // - SPIR-V grammar and helper functions. 
class MarkvCodecBase {
 public:
  // Releases the validator options handle passed to the constructor; the
  // codec owns that handle.
  // NOTE(review): a copy of this object would release the handle twice.
  // Copies are presumably already implicitly deleted by non-copyable members;
  // consider deleting them explicitly.
  virtual ~MarkvCodecBase() {
    spvValidatorOptionsDestroy(validator_options_);
  }

  MarkvCodecBase() = delete;

  // Replaces the model used for coding. The pointer is stored, not copied;
  // the caller must keep |model| alive while this codec uses it.
  void SetModel(const MarkvModel* model) {
    model_ = model;
  }

 protected:
  struct MarkvHeader {
    MarkvHeader() {
      magic_number = kMarkvMagicNumber;
      markv_version = GetMarkvVersion();
      markv_model = 0;
      markv_length_in_bits = 0;
      spirv_version = 0;
      spirv_generator = 0;
    }

    uint32_t magic_number;
    uint32_t markv_version;
    // Magic number to identify or verify MarkvModel used for encoding.
    uint32_t markv_model;
    uint32_t markv_length_in_bits;
    uint32_t spirv_version;
    uint32_t spirv_generator;
  };

  // Takes ownership of |validator_options| (released in the destructor).
  // validator_options_ is declared before vstate_, so it is already
  // initialized when vstate_ is constructed from it.
  explicit MarkvCodecBase(spv_const_context context,
                          spv_validator_options validator_options)
      : validator_options_(validator_options),
        vstate_(context, validator_options_),
        grammar_(context),
        model_(GetDefaultModel()) {}

  // Validates a single instruction and updates validation state of the module.
  spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
    return ValidateInstructionAndUpdateValidationState(&vstate_, &inst);
  }

  // Returns instruction which created |id| or nullptr if such instruction was
  // not registered.
  const Instruction* GetDefInst(uint32_t id) const {
    const auto it = vstate_.all_definitions().find(id);
    if (it == vstate_.all_definitions().end())
      return nullptr;
    return it->second;
  }

  // Returns type id of vector type component.
  // Asserts (debug builds only) that |vector_type_id| is defined and is an
  // OpTypeVector. The component type is read from operand 1 (presumably
  // operand 0 is the result id).
  uint32_t GetVectorComponentType(uint32_t vector_type_id) const {
    const auto it = vstate_.all_definitions().find(vector_type_id);
    assert(it != vstate_.all_definitions().end());

    const Instruction* type_inst = it->second;
    assert(type_inst->opcode() == SpvOpTypeVector);

    const uint32_t component_type =
        type_inst->word(type_inst->operands()[1].offset);
    return component_type;
  }

  // The GetMtf* helpers below build a handle as "space Begin + value".
  // NOTE(review): the 16-bit handle spaces are 0x10000 apart, so a value of
  // 0x10000 or more would collide with the next space. Presumably ids,
  // opcodes and sizes stay below that — TODO confirm.

  // Returns mtf handle for ids of given type.
  uint64_t GetMtfIdOfType(uint32_t type_id) const {
    return kMtfIdOfTypeBegin + type_id;
  }

  // Returns mtf handle for ids generated by given opcode.
  uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const {
    return kMtfIdGeneratedByOpcode + opcode;
  }

  // Returns mtf handle for ids of type generated by given opcode.
  uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const {
    return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode;
  }

  // Returns mtf handle for vectors of specific component type.
  uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const {
    return kMtfVectorOfComponentTypeBegin + type_id;
  }

  // Returns mtf handle for float vectors of specific size.
  uint64_t GetMtfFloatVectorOfSize(uint32_t size) const {
    return kMtfFloatVectorOfSizeBegin + size;
  }

  // Returns mtf handle for vector type of specific size.
  uint64_t GetMtfTypeVectorOfSize(uint32_t size) const {
    return kMtfTypeVectorOfSizeBegin + size;
  }

  // Returns mtf handle for pointers to specific type.
  uint64_t GetMtfPointerToType(uint32_t type_id) const {
    return kMtfPointerToTypeBegin + type_id;
  }

  // Returns mtf handle for function types with given return type.
  uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const {
    return kMtfFunctionTypeWithReturnTypeBegin + type_id;
  }

  // Returns mtf handle for functions with given return type.
  uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const {
    return kMtfFunctionWithReturnTypeBegin + type_id;
  }

  // Returns mtf handle for the given id descriptor (32-bit handle space).
  uint64_t GetMtfIdDescriptor(uint32_t descriptor) const {
    return kMtfIdDescriptorSpaceBegin + descriptor;
  }

  // Process data from the current instruction. This would update MTFs and
  // other data containers.
  void ProcessCurInstruction();

  // Returns move-to-front handle to be used for the current operand slot.
  // Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
  uint64_t GetRuleBasedMtf();

  // Returns words of the current instruction. Decoder has a different
  // implementation and the array is valid only until the previously decoded
  // word.
  virtual const uint32_t* GetInstWords() const {
    return inst_.words;
  }

  // Returns the opcode of the previous instruction, or SpvOpNop if no
  // instruction has been recorded yet.
  SpvOp GetPrevOpcode() const {
    if (instructions_.empty())
      return SpvOpNop;

    return instructions_.back()->opcode();
  }

  // NOTE(review): several container members below appear to have lost their
  // template argument lists in this copy of the file (e.g.
  // `std::unordered_map id_to_type_id_;`); restore them from upstream — TODO
  // confirm.

  // Owned; released in the destructor.
  spv_validator_options validator_options_ = nullptr;
  ValidationState_t vstate_;
  const libspirv::AssemblyGrammar grammar_;
  MarkvHeader header_;
  // Not owned. Defaults to GetDefaultModel(); can be replaced by SetModel().
  const MarkvModel* model_ = nullptr;

  // Current instruction, current operand and current operand index.
  // NOTE(review): these have no initializer here; presumably they are
  // assigned before first use — TODO confirm.
  spv_parsed_instruction_t inst_;
  spv_parsed_operand_t operand_;
  uint32_t operand_index_;

  // Maps a result ID to its type ID.  By convention:
  //  - a result ID that is a type definition maps to itself.
  //  - a result ID without a type maps to 0.  (E.g. for OpLabel)
  std::unordered_map id_to_type_id_;

  // Container for all move-to-front sequences.
  MultiMoveToFront multi_mtf_;

  // Id of the current function or zero if outside of function.
  uint32_t cur_function_id_ = 0;

  // Return type of the current function.
  uint32_t cur_function_return_type_ = 0;

  // Remaining function parameter types. This container is filled on OpFunction,
  // and drained on OpFunctionParameter.
  std::list remaining_function_parameter_types_;

  // List of ids local to the current function.
  std::vector ids_local_to_cur_function_;

  // List of instructions in the order they are given in the module.
  std::vector instructions_;

  // Maps used for the 'presumed id' techniques. Maps small constant integer
  // value to its id and back.
  std::map presumed_index_to_id_;
  std::map id_to_presumed_index_;

  // Container/computer for id descriptors.
  IdDescriptorCollection id_descriptors_;
};

// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
// EncodeInstruction which can be used as callback by spvBinaryParse.
// Encoded binary is written to an internally maintained bitstream.
// After the last instruction is encoded, the resulting MARK-V binary can be
// acquired by calling GetMarkvBinary().
// The encoder uses SPIR-V validator to keep internal state, therefore
// SPIR-V binary needs to be able to pass validator checks.
// CreateCommentsLogger() can be used to enable the encoder to write comments
// on how encoding was done, which can later be accessed with GetComments().
// NOTE(review): template argument lists in this and the following class
// (e.g. `static_cast(...)`, `std::unique_ptr logger_;`, `std::vector spirv_;`)
// appear to have been stripped in this copy of the file; restore them from
// upstream — TODO confirm.
class MarkvEncoder : public MarkvCodecBase {
 public:
  // |options| is stored but not read: the (void) cast only silences the
  // unused-member warning.
  MarkvEncoder(spv_const_context context,
               spv_const_markv_encoder_options options)
      : MarkvCodecBase(context, GetValidatorOptions(options)),
        options_(options) {
    (void) options_;
  }

  // Writes data from SPIR-V header to MARK-V header.
  // Also sets the validator's id bound from |id_bound|. Always returns
  // SPV_SUCCESS.
  spv_result_t EncodeHeader(
      spv_endianness_t /* endian */, uint32_t /* magic */,
      uint32_t version, uint32_t generator, uint32_t id_bound,
      uint32_t /* schema */) {
    vstate_.setIdBound(id_bound);
    header_.spirv_version = version;
    header_.spirv_generator = generator;
    return SPV_SUCCESS;
  }

  // Encodes SPIR-V instruction to MARK-V and writes to bit stream.
  // Operation can fail if the instruction fails to pass the validator or if
  // the encoder stumbles on something unexpected.
  spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);

  // Concatenates MARK-V header and the bit stream with encoded instructions
  // into a single buffer and returns it as spv_markv_binary. The returned
  // value is owned by the caller and needs to be destroyed with
  // spvMarkvBinaryDestroy().
  // Before copying, the header's markv_length_in_bits is updated to the
  // header size plus the number of bits written so far.
  spv_markv_binary GetMarkvBinary() {
    header_.markv_length_in_bits =
        static_cast(sizeof(header_) * 8 + writer_.GetNumBits());
    const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();

    spv_markv_binary markv_binary = new spv_markv_binary_t();
    markv_binary->data = new uint8_t[num_bytes];
    markv_binary->length = num_bytes;
    assert(writer_.GetData());
    std::memcpy(markv_binary->data, &header_, sizeof(header_));
    std::memcpy(markv_binary->data + sizeof(header_),
           writer_.GetData(), writer_.GetDataSizeBytes());
    return markv_binary;
  }

  // Creates an internal logger which writes comments on the encoding process.
  // Output can later be accessed with GetComments().
  // Every bit sequence written by writer_ is forwarded to the logger.
  void CreateCommentsLogger() {
    logger_.reset(new CommentLogger());
    writer_.SetCallback([this](const std::string& str){
      logger_->AppendBitSequence(str);
    });
  }

  // Optionally adds disassembly to the comments.
  // Disassembly should contain all instructions in the module separated by
  // \n, and no header.
  void SetDisassembly(std::string&& disassembly) {
    disassembly_.reset(new std::stringstream(std::move(disassembly)));
  }

  // Extracts the next instruction line from the disassembly and logs it.
  // Does nothing unless both the logger and the disassembly are set.
  void LogDisassemblyInstruction() {
    if (logger_ && disassembly_) {
      std::string line;
      std::getline(*disassembly_, line, '\n');
      logger_->AppendTextNewLine(line);
    }
  }

  // Extracts the text from the comment logger.
  // Returns an empty string if CreateCommentsLogger() was never called.
  std::string GetComments() const {
    if (!logger_)
      return "";
    return logger_->GetText();
  }

 private:
  // Creates and returns validator options. Return value owned by the caller.
  static spv_validator_options GetValidatorOptions(
      spv_const_markv_encoder_options) {
    return spvValidatorOptionsCreate();
  }

  // Writes a single word to bit stream. operand_.type determines if the word is
  // encoded and how.
  spv_result_t EncodeNonIdWord(uint32_t word);

  // Writes both opcode and num_operands as a single code.
  // Returns SPV_UNSUPPORTED iff no suitable codec was found.
  spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode, uint32_t num_operands);

  // Writes mtf rank to bit stream. |mtf| is used to determine the codec
  // scheme. |fallback_method| is used if no codec defined for |mtf|.
  spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
                                    uint64_t fallback_method);

  // Writes id using coding based on mtf associated with the id descriptor.
  // Returns SPV_UNSUPPORTED iff fallback method needs to be used.
  spv_result_t EncodeIdWithDescriptor(uint32_t id);

  // Writes id using coding based on the given |mtf|, which is expected to
  // contain the given |id|.
  spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id);

  // Writes type id of the current instruction if can't be inferred.
  spv_result_t EncodeTypeId();

  // Writes result id of the current instruction if can't be inferred.
  spv_result_t EncodeResultId();

  // Writes ids which are neither type nor result ids.
  spv_result_t EncodeRefId(uint32_t id);

  // Writes bits to the stream until the beginning of the next byte if the
  // number of bits until the next byte is less than |byte_break_if_less_than|.
  void AddByteBreak(size_t byte_break_if_less_than);

  // Encodes a literal number operand and writes it to the bit stream.
  spv_result_t EncodeLiteralNumber(const Instruction& instruction,
                                   const spv_parsed_operand_t& operand);

  // Stored but currently unused.
  spv_const_markv_encoder_options options_;

  // Bit stream where encoded instructions are written.
  BitWriterWord64 writer_;

  // If not nullptr, encoder will write comments.
  std::unique_ptr logger_;

  // If not nullptr, disassembled instruction lines will be written to comments.
  // Format: \n separated instruction lines, no header.
  std::unique_ptr disassembly_;
};

// Decodes MARK-V buffers written by MarkvEncoder.
class MarkvDecoder : public MarkvCodecBase {
 public:
  // |markv_data| is handed to reader_ together with its size in bytes.
  // The initial id bound is 1. Operand and word buffers are pre-reserved
  // for 25 entries.
  MarkvDecoder(spv_const_context context,
               const uint8_t* markv_data,
               size_t markv_size_bytes,
               spv_const_markv_decoder_options options)
      : MarkvCodecBase(context, GetValidatorOptions(options)),
        options_(options), reader_(markv_data, markv_size_bytes) {
    (void) options_;
    vstate_.setIdBound(1);
    parsed_operands_.reserve(25);
    inst_words_.reserve(25);
  }

  // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
  // Can be called only once. Fails if data of wrong format or ends prematurely,
  // or if validation fails.
  spv_result_t DecodeModule(std::vector* spirv_binary);

 private:
  // Describes the format of a typed literal number.
  struct NumberType {
    spv_number_kind_t type;
    uint32_t bit_width;
  };

  // Creates and returns validator options. Return value owned by the caller.
  static spv_validator_options GetValidatorOptions(
      spv_const_markv_decoder_options) {
    return spvValidatorOptionsCreate();
  }

  // Reads a single bit from reader_. The read bit is stored in |bit|.
  // Returns false iff reader_ fails. |bit| is left untouched on failure.
  // NOTE(review): the ';' after the closing brace is redundant (an empty
  // member declaration) and can be removed.
  bool ReadBit(bool* bit) {
    uint64_t bits = 0;
    const bool result = reader_.ReadBits(&bits, 1);
    if (result)
      *bit = bits ? true : false;
    return result;
  };

  // Returns ReadBit bound to the class object.
  // The callback captures |this|; it must not be used after the decoder is
  // destroyed.
  std::function GetReadBitCallback() {
    return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
  }

  // Reads a single non-id word from bit stream. operand_.type determines if
  // the word needs to be decoded and how.
  spv_result_t DecodeNonIdWord(uint32_t* word);

  // Reads and decodes both opcode and num_operands as a single code.
  // Returns SPV_UNSUPPORTED iff no suitable codec was found.
  spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
                                               uint32_t* num_operands);

  // Reads mtf rank from bit stream. |mtf| is used to determine the codec
  // scheme. |fallback_method| is used if no codec defined for |mtf|.
  // NOTE(review): the encoder counterpart EncodeMtfRankHuffman takes
  // |fallback_method| as uint64_t (an mtf handle). Here it is uint32_t, which
  // would truncate handles from the 32-bit id descriptor space. Consider
  // aligning both (declaration and definition) — TODO confirm.
  spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
                                    uint32_t* rank);

  // Reads id using coding based on mtf associated with the id descriptor.
  // Returns SPV_UNSUPPORTED iff fallback method needs to be used.
  spv_result_t DecodeIdWithDescriptor(uint32_t* id);

  // Reads id using coding based on the given |mtf|, which is expected to
  // contain the needed |id|.
  spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);

  // Reads type id of the current instruction if can't be inferred.
  spv_result_t DecodeTypeId();

  // Reads result id of the current instruction if can't be inferred.
  spv_result_t DecodeResultId();

  // Reads id which is neither type nor result id.
  spv_result_t DecodeRefId(uint32_t* id);

  // Reads and discards bits until the beginning of the next byte if the
  // number of bits until the next byte is less than |byte_break_if_less_than|.
  bool ReadToByteBreak(size_t byte_break_if_less_than);

  // Returns instruction words decoded up to this point.
  const uint32_t* GetInstWords() const override {
    return inst_words_.data();
  }

  // Reads a literal number as it is described in |operand| from the bit stream,
  // decodes and writes it to spirv_.
  spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);

  // Reads instruction from bit stream, decodes and validates it.
  // Decoded instruction is valid until the next call of DecodeInstruction().
  spv_result_t DecodeInstruction();

  // Reads an operand from the stream, decodes and validates it.
  spv_result_t DecodeOperand(size_t operand_offset,
                             const spv_operand_type_t type,
                             spv_operand_pattern_t* expected_operands);

  // Records the numeric type for an operand according to the type information
  // associated with the given non-zero type Id.  This can fail if the type Id
  // is not a type Id, or if the type Id does not reference a scalar numeric
  // type.  On success, return SPV_SUCCESS and populates the num_words,
  // number_kind, and number_bit_width fields of parsed_operand.
  spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
                                         uint32_t type_id);

  // Records the number type for the current instruction, if it generates a
  // type.  For types that aren't scalar numbers, record something with number
  // kind SPV_NUMBER_NONE.
  void RecordNumberType();

  // Stored but currently unused.
  spv_const_markv_decoder_options options_;

  // Temporary sink where decoded SPIR-V words are written. Once it contains the
  // entire module, the container is moved and returned.
  std::vector spirv_;

  // Bit stream containing encoded data.
  BitReaderWord64 reader_;

  // Temporary storage for operands of the currently parsed instruction.
  // Valid until next DecodeInstruction call.
  std::vector parsed_operands_;

  // Temporary storage for current instruction words.
  // Valid until next DecodeInstruction call.
  std::vector inst_words_;

  // Maps a type ID to its number type description.
  std::unordered_map type_id_to_number_type_info_;

  // Maps an ExtInstImport id to the extended instruction type.
  std::unordered_map import_id_to_ext_inst_type_;
};

void MarkvCodecBase::ProcessCurInstruction() {
  const SpvOp opcode = SpvOp(inst_.opcode);

  if (inst_.result_id) {
    // Collect ids local to the current function.
    if (cur_function_id_){
      ids_local_to_cur_function_.push_back(inst_.result_id);
    }

    // Starting new function.
    if (opcode == SpvOpFunction) {
      cur_function_id_ = inst_.result_id;
      cur_function_return_type_ = inst_.type_id;
      multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
                        inst_.result_id);

      // Store function parameter types in a queue, so that we know which types
      // to expect in the following OpFunctionParameter instructions.
      // (words[4] is presumably OpFunction's function-type operand, and
      // OpTypeFunction parameter types presumably start at word 3 — TODO
      // confirm against the SPIR-V spec.)
      const Instruction* def_inst = GetDefInst(inst_.words[4]);
      assert(def_inst);
      assert(def_inst->opcode() == SpvOpTypeFunction);
      for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
        remaining_function_parameter_types_.push_back(def_inst->word(i));
      }
    }
  }

  // Remove local ids from MTFs if function end.
  if (opcode == SpvOpFunctionEnd) {
    cur_function_id_ = 0;
    for (uint32_t id : ids_local_to_cur_function_)
      multi_mtf_.RemoveFromAll(id);
    ids_local_to_cur_function_.clear();
    assert(remaining_function_parameter_types_.empty());
  }

  if (!inst_.result_id)
    return;

  {
    // Save the result ID to type ID mapping.
    // In the grammar, type ID always appears before result ID.
    // A regular value maps to its type. Some instructions (e.g. OpLabel)
    // have no type Id, and will map to 0. The result Id for a
    // type-generating instruction (e.g. OpTypeInt) maps to itself.
    auto insertion_result = id_to_type_id_.emplace(
        inst_.result_id,
        spvOpcodeGeneratesType(SpvOp(inst_.opcode)) ?
        inst_.result_id : inst_.type_id);
    (void)insertion_result;
    assert(insertion_result.second);
  }

  // Add result_id to MTFs.
switch (opcode) { case SpvOpTypeFloat: case SpvOpTypeInt: case SpvOpTypeBool: case SpvOpTypeVector: case SpvOpTypePointer: case SpvOpExtInstImport: case SpvOpTypeSampledImage: case SpvOpTypeImage: case SpvOpTypeSampler: multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id); break; default: break; } if (spvOpcodeIsComposite(opcode)) { multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id); } if (opcode == SpvOpLabel) { multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id); } if (opcode == SpvOpTypeInt) { multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id); multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id); } if (opcode == SpvOpTypeFloat) { multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id); multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id); } if (opcode == SpvOpTypeBool) { multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id); multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id); } if (opcode == SpvOpTypeVector) { const uint32_t component_type_id = inst_.words[2]; const uint32_t size = inst_.words[3]; if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat), component_type_id)) { multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id); } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt), component_type_id)) { multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id); } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool), component_type_id)) { multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id); } multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id); } if (inst_.opcode == SpvOpTypeFunction) { const uint32_t return_type = inst_.words[2]; multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type); multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type), inst_.result_id); } if (inst_.type_id) { const Instruction* type_inst = GetDefInst(inst_.type_id); assert(type_inst); multi_mtf_.Insert(kMtfObject, inst_.result_id); 
multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id); if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) { multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id); } if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id)) multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id); if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id)) multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id); if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id)) multi_mtf_.Insert(kMtfComposite, inst_.result_id); if (inst_.opcode == SpvOpConstant) { if (multi_mtf_.HasValue( GetMtfIdGeneratedByOpcode(SpvOpTypeInt), inst_.type_id)) { multi_mtf_.Insert(kMtfConstInteger, inst_.result_id); const uint32_t presumed_index = inst_.words[3]; if (presumed_index <= kMarkvMaxPresumedAccessIndex) { const auto result = presumed_index_to_id_.emplace(presumed_index, inst_.result_id); if (result.second) { id_to_presumed_index_.emplace(inst_.result_id, presumed_index); } } } } switch (type_inst->opcode()) { case SpvOpTypeInt: case SpvOpTypeBool: case SpvOpTypePointer: case SpvOpTypeVector: case SpvOpTypeImage: case SpvOpTypeSampledImage: case SpvOpTypeSampler: multi_mtf_.Insert(GetMtfIdWithTypeGeneratedByOpcode( type_inst->opcode()), inst_.result_id); break; default: break; } if (type_inst->opcode() == SpvOpTypeVector) { const uint32_t component_type = type_inst->word(2); multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type), inst_.result_id); } if (type_inst->opcode() == SpvOpTypePointer) { assert(type_inst->operands().size() > 2); assert(type_inst->words().size() > type_inst->operands()[2].offset); const uint32_t data_type = type_inst->word(type_inst->operands()[2].offset); multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id); if (multi_mtf_.HasValue(kMtfTypeComposite, data_type)) multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id); } } if (spvOpcodeGeneratesType(opcode)) { if (opcode != 
SpvOpTypeFunction) { multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id); } } const uint32_t descriptor = id_descriptors_.ProcessInstruction(inst_); multi_mtf_.Insert(GetMtfIdDescriptor(descriptor), inst_.result_id); } uint64_t MarkvCodecBase::GetRuleBasedMtf() { // This function is only called for id operands (but not result ids). assert(spvIsIdType(operand_.type) || operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID); assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID); const SpvOp opcode = static_cast(inst_.opcode); // All operand slots which expect label id. if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) || (inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) || (inst_.opcode == SpvOpBranch && operand_index_ == 0) || (inst_.opcode == SpvOpBranchConditional && (operand_index_ == 1 || operand_index_ == 2 )) || (inst_.opcode == SpvOpPhi && operand_index_ >= 3 && operand_index_ % 2 == 1) || (inst_.opcode == SpvOpSwitch && operand_index_ > 0)) { return kMtfLabel; } switch (opcode) { case SpvOpFAdd: case SpvOpFSub: case SpvOpFMul: case SpvOpFDiv: case SpvOpFRem: case SpvOpFMod: case SpvOpFNegate: { if (operand_index_ == 0) return kMtfTypeFloatScalarOrVector; return GetMtfIdOfType(inst_.type_id); } case SpvOpISub: case SpvOpIAdd: case SpvOpIMul: case SpvOpSDiv: case SpvOpUDiv: case SpvOpSMod: case SpvOpUMod: case SpvOpSRem: case SpvOpSNegate: { if (operand_index_ == 0) return kMtfTypeIntScalarOrVector; return kMtfIntScalarOrVector; } // TODO(atgoo@github.com) Add OpConvertFToU and other opcodes. 
case SpvOpFOrdEqual: case SpvOpFUnordEqual: case SpvOpFOrdNotEqual: case SpvOpFUnordNotEqual: case SpvOpFOrdLessThan: case SpvOpFUnordLessThan: case SpvOpFOrdGreaterThan: case SpvOpFUnordGreaterThan: case SpvOpFOrdLessThanEqual: case SpvOpFUnordLessThanEqual: case SpvOpFOrdGreaterThanEqual: case SpvOpFUnordGreaterThanEqual: { if (operand_index_ == 0) return kMtfTypeBoolScalarOrVector; if (operand_index_ == 2) return kMtfFloatScalarOrVector; if (operand_index_ == 3) { const uint32_t first_operand_id = GetInstWords()[3]; const uint32_t first_operand_type = id_to_type_id_.at(first_operand_id); return GetMtfIdOfType(first_operand_type); } break; } case SpvOpVectorShuffle: { if (operand_index_ == 0) { assert(inst_.num_operands > 4); return GetMtfTypeVectorOfSize(inst_.num_operands - 4); } assert(inst_.type_id); if (operand_index_ == 2 || operand_index_ == 3) return GetMtfVectorOfComponentType( GetVectorComponentType(inst_.type_id)); break; } case SpvOpVectorTimesScalar: { if (operand_index_ == 0) { // TODO(atgoo@github.com) Could be narrowed to vector of floats. 
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector); } assert(inst_.type_id); if (operand_index_ == 2) return GetMtfIdOfType(inst_.type_id); if (operand_index_ == 3) return GetMtfIdOfType(GetVectorComponentType(inst_.type_id)); break; } case SpvOpDot: { if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat); assert(inst_.type_id); if (operand_index_ == 2) return GetMtfVectorOfComponentType(inst_.type_id); if (operand_index_ == 3) { const uint32_t vector_id = GetInstWords()[3]; const uint32_t vector_type = id_to_type_id_.at(vector_id); return GetMtfIdOfType(vector_type); } break; } case SpvOpTypeVector: { if (operand_index_ == 1) { return kMtfTypeScalar; } break; } case SpvOpTypeMatrix: { if (operand_index_ == 1) { return GetMtfIdGeneratedByOpcode(SpvOpTypeVector); } break; } case SpvOpTypePointer: { if (operand_index_ == 2) { return kMtfTypeNonFunction; } break; } case SpvOpTypeStruct: { if (operand_index_ >= 1) { return kMtfTypeNonFunction; } break; } case SpvOpTypeFunction: { if (operand_index_ == 1) { return kMtfTypeNonFunction; } if (operand_index_ >= 2) { return kMtfTypeNonFunction; } break; } case SpvOpLoad: { if (operand_index_ == 0) return kMtfTypeNonFunction; if (operand_index_ == 2) { assert(inst_.type_id); return GetMtfPointerToType(inst_.type_id); } break; } case SpvOpStore: { if (operand_index_ == 0) return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer); if (operand_index_ == 1) { const uint32_t pointer_id = GetInstWords()[1]; const uint32_t pointer_type = id_to_type_id_.at(pointer_id); const auto it = vstate_.all_definitions().find(pointer_type); assert(it != vstate_.all_definitions().end()); const Instruction* pointer_inst = it->second; assert(pointer_inst->opcode() == SpvOpTypePointer); const uint32_t data_type = pointer_inst->word(pointer_inst->operands()[2].offset); return GetMtfIdOfType(data_type); } break; } case SpvOpVariable: { if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypePointer); break; } case 
SpvOpAccessChain: { if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypePointer); if (operand_index_ == 2) return kMtfTypePointerToComposite; if (operand_index_ >= 3) return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt); break; } case SpvOpCompositeConstruct: { if (operand_index_ == 0) return kMtfTypeComposite; if (operand_index_ >= 2) { const uint32_t composite_type = GetInstWords()[1]; if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type)) return kMtfFloatScalarOrVector; if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type)) return kMtfIntScalarOrVector; if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type)) return kMtfBoolScalarOrVector; } break; } case SpvOpCompositeExtract: { if (operand_index_ == 2) return kMtfComposite; break; } case SpvOpConstantComposite: { if (operand_index_ == 0) return kMtfTypeComposite; if (operand_index_ >= 2) { const Instruction* composite_type_inst = GetDefInst(inst_.type_id); assert(composite_type_inst); if (composite_type_inst->opcode() == SpvOpTypeVector) { return GetMtfIdOfType(composite_type_inst->word(2)); } } break; } case SpvOpExtInst: { if (operand_index_ == 2) return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport); if (operand_index_ >= 4) { const uint32_t return_type = GetInstWords()[1]; const uint32_t ext_inst_type = inst_.ext_inst_type; const uint32_t ext_inst_index = GetInstWords()[4]; // TODO(atgoo@github.com) The list of extended instructions is // incomplete. Only common instructions and low-hanging fruits listed. 
if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) { switch (ext_inst_index) { case GLSLstd450FAbs: case GLSLstd450FClamp: case GLSLstd450FMax: case GLSLstd450FMin: case GLSLstd450FMix: case GLSLstd450Step: case GLSLstd450SmoothStep: case GLSLstd450Fma: case GLSLstd450Pow: case GLSLstd450Exp: case GLSLstd450Exp2: case GLSLstd450Log: case GLSLstd450Log2: case GLSLstd450Sqrt: case GLSLstd450InverseSqrt: case GLSLstd450Fract: case GLSLstd450Floor: case GLSLstd450Ceil: case GLSLstd450Radians: case GLSLstd450Degrees: case GLSLstd450Sin: case GLSLstd450Cos: case GLSLstd450Tan: case GLSLstd450Sinh: case GLSLstd450Cosh: case GLSLstd450Tanh: case GLSLstd450Asin: case GLSLstd450Acos: case GLSLstd450Atan: case GLSLstd450Atan2: case GLSLstd450Asinh: case GLSLstd450Acosh: case GLSLstd450Atanh: case GLSLstd450MatrixInverse: case GLSLstd450Cross: case GLSLstd450Normalize: case GLSLstd450Reflect: case GLSLstd450FaceForward: return GetMtfIdOfType(return_type); case GLSLstd450Length: case GLSLstd450Distance: case GLSLstd450Refract: return kMtfFloatScalarOrVector; default: break; } } else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) { switch (ext_inst_index) { case OpenCLLIB::Fabs: case OpenCLLIB::FClamp: case OpenCLLIB::Fmax: case OpenCLLIB::Fmin: case OpenCLLIB::Step: case OpenCLLIB::Smoothstep: case OpenCLLIB::Fma: case OpenCLLIB::Pow: case OpenCLLIB::Exp: case OpenCLLIB::Exp2: case OpenCLLIB::Log: case OpenCLLIB::Log2: case OpenCLLIB::Sqrt: case OpenCLLIB::Rsqrt: case OpenCLLIB::Fract: case OpenCLLIB::Floor: case OpenCLLIB::Ceil: case OpenCLLIB::Radians: case OpenCLLIB::Degrees: case OpenCLLIB::Sin: case OpenCLLIB::Cos: case OpenCLLIB::Tan: case OpenCLLIB::Sinh: case OpenCLLIB::Cosh: case OpenCLLIB::Tanh: case OpenCLLIB::Asin: case OpenCLLIB::Acos: case OpenCLLIB::Atan: case OpenCLLIB::Atan2: case OpenCLLIB::Asinh: case OpenCLLIB::Acosh: case OpenCLLIB::Atanh: case OpenCLLIB::Cross: case OpenCLLIB::Normalize: return GetMtfIdOfType(return_type); case OpenCLLIB::Length: 
case OpenCLLIB::Distance: return kMtfFloatScalarOrVector; default: break; } } } break; } case SpvOpFunction: { if (operand_index_ == 0) return kMtfTypeReturnedByFunction; if (operand_index_ == 3) { const uint32_t return_type = GetInstWords()[1]; return GetMtfFunctionTypeWithReturnType(return_type); } break; } case SpvOpFunctionCall: { if (operand_index_ == 0) return kMtfTypeReturnedByFunction; if (operand_index_ == 2) { const uint32_t return_type = GetInstWords()[1]; return GetMtfFunctionWithReturnType(return_type); } if (operand_index_ >= 3) { const uint32_t function_id = GetInstWords()[3]; const auto function_it = vstate_.all_definitions().find(function_id); if (function_it == vstate_.all_definitions().end()) return kMtfObject; const Instruction* function_inst = function_it->second; assert(function_inst->opcode() == SpvOpFunction); const uint32_t function_type_id = function_inst->word(4); const auto function_type_it = vstate_.all_definitions().find(function_type_id); assert(function_type_it != vstate_.all_definitions().end()); const Instruction* function_type_inst = function_type_it->second; assert(function_type_inst->opcode() == SpvOpTypeFunction); const uint32_t argument_type = function_type_inst->word(operand_index_); return GetMtfIdOfType(argument_type); } break; } case SpvOpReturnValue: { if (operand_index_ == 0) return GetMtfIdOfType(cur_function_return_type_); break; } case SpvOpBranchConditional: { if (operand_index_ == 0) return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool); break; } case SpvOpSampledImage: { if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage); if (operand_index_ == 2) return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage); if (operand_index_ == 3) return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler); break; } case SpvOpImageSampleImplicitLod: { if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeVector); if (operand_index_ == 2) return 
GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage); if (operand_index_ == 3) return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector); break; } default: break; } return kMtfNone; } spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) { auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); if (codec) { uint64_t bits = 0; size_t num_bits = 0; if (codec->Encode(word, &bits, &num_bits)) { // Encoding successful. writer_.WriteBits(bits, num_bits); return SPV_SUCCESS; } else { // Encoding failed, write kMarkvNoneOfTheAbove flag. if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Non-id word Huffman table for " << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " << operand_index_ << " is missing kMarkvNoneOfTheAbove"; writer_.WriteBits(bits, num_bits); } } // Fallback encoding. const size_t chunk_length = GetOperandVariableWidthChunkLength(operand_.type); if (chunk_length) { writer_.WriteVariableWidthU32(word, chunk_length); } else { writer_.WriteUnencoded(word); } return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) { auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); if (codec) { uint64_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to decode non-id word with Huffman"; if (decoded_value != kMarkvNoneOfTheAbove) { // The word decoded successfully. *word = uint32_t(decoded_value); assert(*word == decoded_value); return SPV_SUCCESS; } // Received kMarkvNoneOfTheAbove signal, use fallback decoding. 
} const size_t chunk_length = GetOperandVariableWidthChunkLength(operand_.type); if (chunk_length) { if (!reader_.ReadVariableWidthU32(word, chunk_length)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to decode non-id word with varint"; } else { if (!reader_.ReadUnencoded(word)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read unencoded non-id word"; } return SPV_SUCCESS; } spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands( uint32_t opcode, uint32_t num_operands) { uint64_t bits = 0; size_t num_bits = 0; const uint32_t word = opcode | (num_operands << 16); // First try to use the Markov chain codec. auto* codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); if (codec) { if (codec->Encode(word, &bits, &num_bits)) { // The word was successfully encoded into bits/num_bits. writer_.WriteBits(bits, num_bits); return SPV_SUCCESS; } else { // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove // and use fallback encoding. if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) return vstate_.diag(SPV_ERROR_INTERNAL) << "opcode_and_num_operands Huffman table for " << spvOpcodeString(GetPrevOpcode()) << "is missing kMarkvNoneOfTheAbove"; writer_.WriteBits(bits, num_bits); } } // Fallback to base-rate codec. codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); assert(codec); if (codec->Encode(word, &bits, &num_bits)) { // The word was successfully encoded into bits/num_bits. writer_.WriteBits(bits, num_bits); return SPV_SUCCESS; } else { // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove // and return false. 
if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Global opcode_and_num_operands Huffman table is missing " << "kMarkvNoneOfTheAbove"; writer_.WriteBits(bits, num_bits); return SPV_UNSUPPORTED; } } spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands( uint32_t* opcode, uint32_t* num_operands) { // First try to use the Markov chain codec. auto* codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); if (codec) { uint64_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Failed to decode opcode_and_num_operands, previous opcode is " << spvOpcodeString(GetPrevOpcode()); if (decoded_value != kMarkvNoneOfTheAbove) { // The word was successfully decoded. *opcode = uint32_t(decoded_value & 0xFFFF); *num_operands = uint32_t(decoded_value >> 16); return SPV_SUCCESS; } // Received kMarkvNoneOfTheAbove signal, use fallback decoding. } // Fallback to base-rate codec. codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); assert(codec); uint64_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Failed to decode opcode_and_num_operands with global codec"; if (decoded_value == kMarkvNoneOfTheAbove) { // Received kMarkvNoneOfTheAbove signal, fallback further. 
return SPV_UNSUPPORTED; } *opcode = uint32_t(decoded_value & 0xFFFF); *num_operands = uint32_t(decoded_value >> 16); return SPV_SUCCESS; } spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf, uint64_t fallback_method) { const auto* codec = model_->GetMtfHuffmanCodec(mtf); if (!codec) { assert(fallback_method != kMtfNone); codec = model_->GetMtfHuffmanCodec(fallback_method); } if (!codec) return vstate_.diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank"; uint64_t bits = 0; size_t num_bits = 0; if (rank < kMtfSmallestRankEncodedByValue) { // Encode using Huffman coding. if (!codec->Encode(rank, &bits, &num_bits)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Failed to encode MTF rank with Huffman"; writer_.WriteBits(bits, num_bits); } else { // Encode by value. if (!codec->Encode(kMtfRankEncodedByValueSignal, &bits, &num_bits)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Failed to encode kMtfRankEncodedByValueSignal"; writer_.WriteBits(bits, num_bits); writer_.WriteVariableWidthU32(rank - kMtfSmallestRankEncodedByValue, model_->mtf_rank_chunk_length()); } return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeMtfRankHuffman( uint64_t mtf, uint32_t fallback_method, uint32_t* rank) { const auto* codec = model_->GetMtfHuffmanCodec(mtf); if (!codec) { assert(fallback_method != kMtfNone); codec = model_->GetMtfHuffmanCodec(fallback_method); } if (!codec) return vstate_.diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank"; uint32_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman"; if (decoded_value == kMtfRankEncodedByValueSignal) { // Decode by value. if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length())) return vstate_.diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with varint"; *rank += kMtfSmallestRankEncodedByValue; } else { // Decode using Huffman coding. 
assert(decoded_value < kMtfSmallestRankEncodedByValue); *rank = decoded_value; } return SPV_SUCCESS; } spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) { auto* codec = model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); if (!codec) return SPV_UNSUPPORTED; uint64_t bits = 0; size_t num_bits = 0; // Get the descriptor for id. const uint32_t descriptor = id_descriptors_.GetDescriptor(id); if (descriptor && codec->Encode(descriptor, &bits, &num_bits)) { // If the descriptor exists and is in the table, write the descriptor and // proceed to encoding the rank. writer_.WriteBits(bits, num_bits); } else { // The descriptor doesn't exist or we have no coding for it. Write // kMarkvNoneOfTheAbove and go to fallback method. if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Descriptor Huffman table for " << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " << operand_index_ << " is missing kMarkvNoneOfTheAbove"; writer_.WriteBits(bits, num_bits); return SPV_UNSUPPORTED; } // Descriptor has been encoded. Now encode the rank of the id in the // associated mtf sequence. const uint64_t mtf = GetMtfIdDescriptor(descriptor); return EncodeExistingId(mtf, id); } spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) { auto* codec = model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); if (!codec) return SPV_UNSUPPORTED; uint64_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Failed to decode descriptor with Huffman"; if (decoded_value == kMarkvNoneOfTheAbove) return SPV_UNSUPPORTED; // If descriptor exists then the id was encoded through descriptor mtf. 
const uint32_t descriptor = uint32_t(decoded_value); assert(descriptor == decoded_value); assert(descriptor); const uint64_t mtf = GetMtfIdDescriptor(descriptor); return DecodeExistingId(mtf, id); } spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) { assert(multi_mtf_.GetSize(mtf) > 0); if (multi_mtf_.GetSize(mtf) == 1) { // If the sequence has only one element no need to write rank, the decoder // would make the same decision. return SPV_SUCCESS; } uint32_t rank = 0; if (!multi_mtf_.RankFromValue(mtf, id, &rank)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence"; return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank); } spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) { assert(multi_mtf_.GetSize(mtf) > 0); *id = 0; uint32_t rank = 0; if (multi_mtf_.GetSize(mtf) == 1) { rank = 1; } else { const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank); if (result != SPV_SUCCESS) return result; } assert(rank); if (!multi_mtf_.ValueFromRank(mtf, rank, id)) return vstate_.diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds"; return SPV_SUCCESS; } spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) { // TODO(atgoo@github.com) This might not be needed as EncodeIdWithDescriptor // can handle SpvOpAccessChain indices if enough statistics is collected. if (inst_.opcode == SpvOpAccessChain && operand_index_ >= 3) { const auto it = id_to_presumed_index_.find(id); if (it != id_to_presumed_index_.end()) { writer_.WriteBits(1, 1); writer_.WriteFixedWidth(it->second, kMarkvMaxPresumedAccessIndex); return SPV_SUCCESS; } writer_.WriteBits(0, 1); } { // Try to encode using id descriptor mtfs. const spv_result_t result = EncodeIdWithDescriptor(id); if (result != SPV_UNSUPPORTED) return result; // If can't be done continue with other methods. } // Encode using rule-based mtf. 
uint64_t mtf = GetRuleBasedMtf(); const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( SpvOp(inst_.opcode))(operand_index_); if (mtf != kMtfNone && !can_forward_declare) { assert(multi_mtf_.HasValue(kMtfAll, id)); return EncodeExistingId(mtf, id); } if (mtf == kMtfNone) mtf = kMtfAll; uint32_t rank = 0; if (!multi_mtf_.RankFromValue(mtf, id, &rank)) { // This is the first occurrence of a forward declared id. multi_mtf_.Insert(kMtfAll, id); multi_mtf_.Insert(kMtfForwardDeclared, id); if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id); rank = 0; } return EncodeMtfRankHuffman(rank, mtf, kMtfAll); } spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) { if (inst_.opcode == SpvOpAccessChain && operand_index_ >= 3) { uint64_t use_presumed_index_technique = 0; if (!reader_.ReadBits(&use_presumed_index_technique, 1)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read use_presumed_index_technique flag"; if (use_presumed_index_technique) { uint64_t value = 0; if (!reader_.ReadFixedWidth(&value, kMarkvMaxPresumedAccessIndex)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read presumed_index"; const uint32_t presumed_index = static_cast(value); const auto it = presumed_index_to_id_.find(presumed_index); if (it == presumed_index_to_id_.end()) { assert(0); return vstate_.diag(SPV_ERROR_INTERNAL) << "Presumed index id not found"; } *id = it->second; return SPV_SUCCESS; } } { const spv_result_t result = DecodeIdWithDescriptor(id); if (result != SPV_UNSUPPORTED) return result; } uint64_t mtf = GetRuleBasedMtf(); const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( SpvOp(inst_.opcode))(operand_index_); if (mtf != kMtfNone && !can_forward_declare) { return DecodeExistingId(mtf, id); } if (mtf == kMtfNone) mtf = kMtfAll; *id = 0; uint32_t rank = 0; { const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank); if (result != SPV_SUCCESS) return result; } if (rank == 0) { // This is the first occurrence of a 
forward declared id. *id = vstate_.getIdBound(); vstate_.setIdBound(*id + 1); multi_mtf_.Insert(kMtfAll, *id); multi_mtf_.Insert(kMtfForwardDeclared, *id); if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id); } else { if (!multi_mtf_.ValueFromRank(mtf, rank, id)) return vstate_.diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; } assert(*id); return SPV_SUCCESS; } spv_result_t MarkvEncoder::EncodeTypeId() { if (inst_.opcode == SpvOpFunctionParameter) { assert(!remaining_function_parameter_types_.empty()); assert(inst_.type_id == remaining_function_parameter_types_.front()); remaining_function_parameter_types_.pop_front(); return SPV_SUCCESS; } { // Try to encode using id descriptor mtfs. const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id); if (result != SPV_UNSUPPORTED) return result; // If can't be done continue with other methods. } uint64_t mtf = GetRuleBasedMtf(); assert(!spvOperandCanBeForwardDeclaredFunction( SpvOp(inst_.opcode))(operand_index_)); if (mtf == kMtfNone) { mtf = kMtfTypeNonFunction; // Function types should have been handled by GetRuleBasedMtf. assert(inst_.opcode != SpvOpFunction); } return EncodeExistingId(mtf, inst_.type_id); } spv_result_t MarkvDecoder::DecodeTypeId() { if (inst_.opcode == SpvOpFunctionParameter) { assert(!remaining_function_parameter_types_.empty()); inst_.type_id = remaining_function_parameter_types_.front(); remaining_function_parameter_types_.pop_front(); return SPV_SUCCESS; } { const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id); if (result != SPV_UNSUPPORTED) return result; } uint64_t mtf = GetRuleBasedMtf(); assert(!spvOperandCanBeForwardDeclaredFunction( SpvOp(inst_.opcode))(operand_index_)); if (mtf == kMtfNone) { mtf = kMtfTypeNonFunction; // Function types should have been handled by GetRuleBasedMtf. 
assert(inst_.opcode != SpvOpFunction); } return DecodeExistingId(mtf, &inst_.type_id); } spv_result_t MarkvEncoder::EncodeResultId() { uint32_t rank = 0; const uint64_t num_still_forward_declared = multi_mtf_.GetSize(kMtfForwardDeclared); if (num_still_forward_declared) { // We write the rank only if kMtfForwardDeclared is not empty. If it is // empty the decoder knows that there are no forward declared ids to expect. if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) { // This is a definition of a forward declared id. We can remove the id // from kMtfForwardDeclared. if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Failed to remove id from kMtfForwardDeclared"; writer_.WriteBits(1, 1); writer_.WriteVariableWidthU32( rank, model_->mtf_rank_chunk_length()); } else { rank = 0; writer_.WriteBits(0, 1); } } if (!rank) { multi_mtf_.Insert(kMtfAll, inst_.result_id); } return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeResultId() { uint32_t rank = 0; const uint64_t num_still_forward_declared = multi_mtf_.GetSize(kMtfForwardDeclared); if (num_still_forward_declared) { // Some ids were forward declared. Check if this id is one of them. uint64_t id_was_forward_declared; if (!reader_.ReadBits(&id_was_forward_declared, 1)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id_was_forward_declared flag"; if (id_was_forward_declared) { if (!reader_.ReadVariableWidthU32( &rank, model_->mtf_rank_chunk_length())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read MTF rank of forward declared id"; if (rank) { // The id was forward declared, recover it from kMtfForwardDeclared. if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, &inst_.result_id)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Forward declared MTF rank is out of bounds"; // We can now remove the id from kMtfForwardDeclared. 
if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) return vstate_.diag(SPV_ERROR_INTERNAL) << "Failed to remove id from kMtfForwardDeclared"; } } } if (inst_.result_id == 0) { // The id was not forward declared, issue a new id. inst_.result_id = vstate_.getIdBound(); vstate_.setIdBound(inst_.result_id + 1); } if (!rank) { multi_mtf_.Insert(kMtfAll, inst_.result_id); } return SPV_SUCCESS; } spv_result_t MarkvEncoder::EncodeLiteralNumber( const Instruction& instruction, const spv_parsed_operand_t& operand) { if (operand.number_bit_width <= 32) { const uint32_t word = instruction.word(operand.offset); return EncodeNonIdWord(word); } else { assert(operand.number_bit_width <= 64); const uint64_t word = uint64_t(instruction.word(operand.offset)) | (uint64_t(instruction.word(operand.offset + 1)) << 32); if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { writer_.WriteVariableWidthU64(word, model_->u64_chunk_length()); } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int64_t val = 0; std::memcpy(&val, &word, 8); writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(), model_->s64_block_exponent()); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { writer_.WriteUnencoded(word); } else { return vstate_.diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; } } return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeLiteralNumber( const spv_parsed_operand_t& operand) { if (operand.number_bit_width <= 32) { uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); } else { assert(operand.number_bit_width <= 64); uint64_t word = 0; if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64"; } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int64_t val = 0; if (!reader_.ReadVariableWidthS64(&val, 
model_->s64_chunk_length(), model_->s64_block_exponent())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64"; std::memcpy(&word, &val, 8); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { if (!reader_.ReadUnencoded(&word)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64"; } else { return vstate_.diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; } inst_words_.push_back(static_cast(word)); inst_words_.push_back(static_cast(word >> 32)); } return SPV_SUCCESS; } void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) { const size_t num_bits_to_next_byte = GetNumBitsToNextByte(writer_.GetNumBits()); if (num_bits_to_next_byte == 0 || num_bits_to_next_byte > byte_break_if_less_than) return; if (logger_) { logger_->AppendWhitespaces(kCommentNumWhitespaces); logger_->AppendText(""); } writer_.WriteBits(0, num_bits_to_next_byte); } bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) { const size_t num_bits_to_next_byte = GetNumBitsToNextByte(reader_.GetNumReadBits()); if (num_bits_to_next_byte == 0 || num_bits_to_next_byte > byte_break_if_less_than) return true; uint64_t bits = 0; if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false; assert(bits == 0); if (bits != 0) return false; return true; } spv_result_t MarkvEncoder::EncodeInstruction( const spv_parsed_instruction_t& inst) { SpvOp opcode = SpvOp(inst.opcode); inst_ = inst; const spv_result_t validation_result = UpdateValidationState(inst); if (validation_result != SPV_SUCCESS) return validation_result; const Instruction& instruction = vstate_.ordered_instructions().back(); const auto& operands = instruction.operands(); LogDisassemblyInstruction(); const spv_result_t opcode_encodig_result = EncodeOpcodeAndNumOperands(opcode, inst.num_operands); if (opcode_encodig_result < 0) return opcode_encodig_result; if (opcode_encodig_result != SPV_SUCCESS) { // Fallback encoding for opcode and num_operands. 
writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length()); if (!OpcodeHasFixedNumberOfOperands(opcode)) { // If the opcode has a variable number of operands, encode the number of // operands with the instruction. if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces); writer_.WriteVariableWidthU16(inst.num_operands, model_->num_operands_chunk_length()); } } // Write operands. for (operand_index_ = 0; operand_index_ < operands.size(); ++operand_index_) { operand_ = operands[operand_index_]; if (logger_) { logger_->AppendWhitespaces(kCommentNumWhitespaces); logger_->AppendText("<"); logger_->AppendText(spvOperandTypeStr(operand_.type)); logger_->AppendText(">"); } switch (operand_.type) { case SPV_OPERAND_TYPE_RESULT_ID: case SPV_OPERAND_TYPE_TYPE_ID: case SPV_OPERAND_TYPE_ID: case SPV_OPERAND_TYPE_OPTIONAL_ID: case SPV_OPERAND_TYPE_SCOPE_ID: case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { const uint32_t id = instruction.word(operand_.offset); if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) { const spv_result_t result = EncodeTypeId(); if (result != SPV_SUCCESS) return result; } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) { const spv_result_t result = EncodeResultId(); if (result != SPV_SUCCESS) return result; } else { const spv_result_t result = EncodeRefId(id); if (result != SPV_SUCCESS) return result; } multi_mtf_.Promote(id); break; } case SPV_OPERAND_TYPE_LITERAL_INTEGER: { const spv_result_t result = EncodeNonIdWord(instruction.word(operand_.offset)); if (result != SPV_SUCCESS) return result; break; } case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: { const spv_result_t result = EncodeLiteralNumber(instruction, operand_); if (result != SPV_SUCCESS) return result; break; } case SPV_OPERAND_TYPE_LITERAL_STRING: { const char* src = reinterpret_cast( &instruction.words()[operand_.offset]); auto* codec = model_->GetLiteralStringHuffmanCodec(opcode); if (codec) { uint64_t bits = 0; size_t num_bits = 0; const std::string str = reinterpret_cast( 
&instruction.words()[operand_.offset]); if (codec->Encode(str, &bits, &num_bits)) { writer_.WriteBits(bits, num_bits); break; } else { bool result = codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits); (void)result; assert(result); writer_.WriteBits(bits, num_bits); } } const size_t length = spv_strnlen_s(src, operand_.num_words * 4); if (length == operand_.num_words * 4) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to find terminal character of literal string"; for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]); break; } default: { for (int i = 0; i < operand_.num_words; ++i) { const uint32_t word = instruction.word(operand_.offset + i); const spv_result_t result = EncodeNonIdWord(word); if (result != SPV_SUCCESS) return result; } break; } } } AddByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte); if (logger_) { logger_->NewLine(); logger_->NewLine(); } ProcessCurInstruction(); instructions_.push_back(&instruction); return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeModule(std::vector* spirv_binary) { const bool header_read_success = reader_.ReadUnencoded(&header_.magic_number) && reader_.ReadUnencoded(&header_.markv_version) && reader_.ReadUnencoded(&header_.markv_model) && reader_.ReadUnencoded(&header_.markv_length_in_bits) && reader_.ReadUnencoded(&header_.spirv_version) && reader_.ReadUnencoded(&header_.spirv_generator); if (!header_read_success) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header"; if (header_.markv_length_in_bits == 0) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Header markv_length_in_bits field is zero"; if (header_.magic_number != kMarkvMagicNumber) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary has incorrect magic number"; // TODO(atgoo@github.com): Print version strings. 
if (header_.markv_version != GetMarkvVersion()) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary and the codec have different versions"; spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic. spirv_.resize(5, 0); spirv_[0] = kSpirvMagicNumber; spirv_[1] = header_.spirv_version; spirv_[2] = header_.spirv_generator; while (reader_.GetNumReadBits() < header_.markv_length_in_bits) { inst_ = {}; const spv_result_t decode_result = DecodeInstruction(); if (decode_result != SPV_SUCCESS) return decode_result; const spv_result_t validation_result = UpdateValidationState(inst_); if (validation_result != SPV_SUCCESS) return validation_result; instructions_.push_back(&vstate_.ordered_instructions().back()); } if (reader_.GetNumReadBits() != header_.markv_length_in_bits || !reader_.OnlyZeroesLeft()) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary has wrong stated bit length " << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits; } // Decoding of the module is finished, validation state should have correct // id bound. spirv_[3] = vstate_.getIdBound(); *spirv_binary = std::move(spirv_); return SPV_SUCCESS; } // TODO(atgoo@github.com): The implementation borrows heavily from // Parser::parseOperand. // Consider coupling them together in some way once MARK-V codec is more mature. // For now it's better to keep the code independent for experimentation // purposes. spv_result_t MarkvDecoder::DecodeOperand( size_t operand_offset, const spv_operand_type_t type, spv_operand_pattern_t* expected_operands) { const SpvOp opcode = static_cast(inst_.opcode); memset(&operand_, 0, sizeof(operand_)); assert((operand_offset >> 16) == 0); operand_.offset = static_cast(operand_offset); operand_.type = type; // Set default values, may be updated later. 
operand_.number_kind = SPV_NUMBER_NONE; operand_.number_bit_width = 0; const size_t first_word_index = inst_words_.size(); switch (type) { case SPV_OPERAND_TYPE_RESULT_ID: { const spv_result_t result = DecodeResultId(); if (result != SPV_SUCCESS) return result; inst_words_.push_back(inst_.result_id); vstate_.setIdBound(std::max(vstate_.getIdBound(), inst_.result_id + 1)); multi_mtf_.Promote(inst_.result_id); break; } case SPV_OPERAND_TYPE_TYPE_ID: { const spv_result_t result = DecodeTypeId(); if (result != SPV_SUCCESS) return result; inst_words_.push_back(inst_.type_id); vstate_.setIdBound(std::max(vstate_.getIdBound(), inst_.type_id + 1)); multi_mtf_.Promote(inst_.type_id); break; } case SPV_OPERAND_TYPE_ID: case SPV_OPERAND_TYPE_OPTIONAL_ID: case SPV_OPERAND_TYPE_SCOPE_ID: case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { uint32_t id = 0; const spv_result_t result = DecodeRefId(&id); if (result != SPV_SUCCESS) return result; if (id == 0) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0"; if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) { operand_.type = SPV_OPERAND_TYPE_ID; if (opcode == SpvOpExtInst && operand_.offset == 3) { // The current word is the extended instruction set id. // Set the extended instruction set type for the current // instruction. 
auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { return vstate_.diag(SPV_ERROR_INVALID_ID) << "OpExtInst set id " << id << " does not reference an OpExtInstImport result Id"; } inst_.ext_inst_type = ext_inst_type_iter->second; } } inst_words_.push_back(id); vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1)); multi_mtf_.Promote(id); break; } case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); assert(SpvOpExtInst == opcode); assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE); spv_ext_inst_desc ext_inst; if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid extended instruction number: " << word; spvPushOperandTypes(ext_inst->operandTypes, expected_operands); break; } case SPV_OPERAND_TYPE_LITERAL_INTEGER: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { // These are regular single-word literal integer operands. // Post-parsing validation should check the range of the parsed value. operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; // It turns out they are always unsigned integers! operand_.number_kind = SPV_NUMBER_UNSIGNED_INT; operand_.number_bit_width = 32; uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); break; } case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: { operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; if (opcode == SpvOpSwitch) { // The literal operands have the same type as the value // referenced by the selector Id. 
const uint32_t selector_id = inst_words_.at(1); const auto type_id_iter = id_to_type_id_.find(selector_id); if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid OpSwitch: selector id " << selector_id << " has no type"; } uint32_t type_id = type_id_iter->second; if (selector_id == type_id) { // Recall that by convention, a result ID that is a type definition // maps to itself. return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid OpSwitch: selector id " << selector_id << " is a type, not a value"; } if (auto error = SetNumericTypeInfoForType(&operand_, type_id)) return error; if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT && operand_.number_kind != SPV_NUMBER_SIGNED_INT) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid OpSwitch: selector id " << selector_id << " is not a scalar integer"; } } else { assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); // The literal number type is determined by the type Id for the // constant. assert(inst_.type_id); if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id)) return error; } if (auto error = DecodeLiteralNumber(operand_)) return error; break; } case SPV_OPERAND_TYPE_LITERAL_STRING: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING; std::vector str; auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode); if (codec) { std::string decoded_string; const bool huffman_result = codec->DecodeFromStream(GetReadBitCallback(), &decoded_string); assert(huffman_result); if (!huffman_result) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal string"; if (decoded_string != "kMarkvNoneOfTheAbove") { std::copy(decoded_string.begin(), decoded_string.end(), std::back_inserter(str)); str.push_back('\0'); } } // The loop is expected to terminate once we encounter '\0' or exhaust // the bit stream. 
if (str.empty()) { while (true) { char ch = 0; if (!reader_.ReadUnencoded(&ch)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal string"; str.push_back(ch); if (ch == '\0') break; } } while (str.size() % 4 != 0) str.push_back('\0'); inst_words_.resize(inst_words_.size() + str.size() / 4); std::memcpy(&inst_words_[first_word_index], str.data(), str.size()); if (SpvOpExtInstImport == opcode) { // Record the extended instruction type for the ID for this import. // There is only one string literal argument to OpExtInstImport, // so it's sufficient to guard this just on the opcode. const spv_ext_inst_type_t ext_inst_type = spvExtInstImportTypeGet(str.data()); if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid extended instruction import '" << str.data() << "'"; } // We must have parsed a valid result ID. It's a condition // of the grammar, and we only accept non-zero result Ids. assert(inst_.result_id); const bool inserted = import_id_to_ext_inst_type_.emplace( inst_.result_id, ext_inst_type).second; (void)inserted; assert(inserted); } break; } case SPV_OPERAND_TYPE_CAPABILITY: case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: case SPV_OPERAND_TYPE_EXECUTION_MODEL: case SPV_OPERAND_TYPE_ADDRESSING_MODEL: case SPV_OPERAND_TYPE_MEMORY_MODEL: case SPV_OPERAND_TYPE_EXECUTION_MODE: case SPV_OPERAND_TYPE_STORAGE_CLASS: case SPV_OPERAND_TYPE_DIMENSIONALITY: case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: case SPV_OPERAND_TYPE_LINKAGE_TYPE: case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: case SPV_OPERAND_TYPE_DECORATION: case SPV_OPERAND_TYPE_BUILT_IN: case SPV_OPERAND_TYPE_GROUP_OPERATION: case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: { // A single word that 
is a plain enum value. uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); // Map an optional operand type to its corresponding concrete type. if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; spv_operand_desc entry; if (grammar_.lookupOperand(type, word, &entry)) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid " << spvOperandTypeStr(operand_.type) << " operand: " << word; } // Prepare to accept operands to this operand, if needed. spvPushOperandTypes(entry->operandTypes, expected_operands); break; } case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: case SPV_OPERAND_TYPE_FUNCTION_CONTROL: case SPV_OPERAND_TYPE_LOOP_CONTROL: case SPV_OPERAND_TYPE_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: case SPV_OPERAND_TYPE_SELECTION_CONTROL: { // This operand is a mask. uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); // Map an optional operand type to its corresponding concrete type. if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) operand_.type = SPV_OPERAND_TYPE_IMAGE; else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; // Check validity of set mask bits. Also prepare for operands for those // masks if they have any. To get operand order correct, scan from // MSB to LSB since we can only prepend operands to a pattern. // The only case in the grammar where you have more than one mask bit // having an operand is for image operands. See SPIR-V 3.14 Image // Operands. 
uint32_t remaining_word = word; for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { if (remaining_word & mask) { spv_operand_desc entry; if (grammar_.lookupOperand(type, mask, &entry)) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid " << spvOperandTypeStr(operand_.type) << " operand: " << word << " has invalid mask component " << mask; } remaining_word ^= mask; spvPushOperandTypes(entry->operandTypes, expected_operands); } } if (word == 0) { // An all-zeroes mask *might* also be valid. spv_operand_desc entry; if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { // Prepare for its operands, if any. spvPushOperandTypes(entry->operandTypes, expected_operands); } } break; } default: return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Internal error: Unhandled operand type: " << type; } operand_.num_words = uint16_t(inst_words_.size() - first_word_index); assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(operand_.type)); assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(operand_.type)); parsed_operands_.push_back(operand_); return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeInstruction() { parsed_operands_.clear(); inst_words_.clear(); // Opcode/num_words placeholder, the word will be filled in later. 
inst_words_.push_back(0); bool num_operands_still_unknown = true; { uint32_t opcode = 0; uint32_t num_operands = 0; const spv_result_t opcode_decoding_result = DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands); if (opcode_decoding_result < 0) return opcode_decoding_result; if (opcode_decoding_result == SPV_SUCCESS) { inst_.num_operands = static_cast(num_operands); num_operands_still_unknown = false; } else { if (!reader_.ReadVariableWidthU32( &opcode, model_->opcode_chunk_length())) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read opcode of instruction"; } } inst_.opcode = static_cast(opcode); } const SpvOp opcode = static_cast(inst_.opcode); spv_opcode_desc opcode_desc; if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode"; } spv_operand_pattern_t expected_operands; expected_operands.reserve(opcode_desc->numTypes); for (auto i = 0; i < opcode_desc->numTypes; i++) { expected_operands.push_back( opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); } if (num_operands_still_unknown) { if (!OpcodeHasFixedNumberOfOperands(opcode)) { if (!reader_.ReadVariableWidthU16(&inst_.num_operands, model_->num_operands_chunk_length())) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read num_operands of instruction"; } else { inst_.num_operands = static_cast(expected_operands.size()); } } for (operand_index_ = 0; operand_index_ < static_cast(inst_.num_operands); ++operand_index_) { assert(!expected_operands.empty()); const spv_operand_type_t type = spvTakeFirstMatchableOperand(&expected_operands); const size_t operand_offset = inst_words_.size(); const spv_result_t decode_result = DecodeOperand( operand_offset, type, &expected_operands); if (decode_result != SPV_SUCCESS) return decode_result; } assert(inst_.num_operands == parsed_operands_.size()); // Only valid while inst_words_ and parsed_operands_ remain unchanged (until // next DecodeInstruction call). 
inst_.words = inst_words_.data(); inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data(); inst_.num_words = static_cast(inst_words_.size()); inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode)); std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_)); assert(inst_.num_words == std::accumulate( parsed_operands_.begin(), parsed_operands_.end(), 1, [](int num_words, const spv_parsed_operand_t& operand) { return num_words += operand.num_words; }) && "num_words in instruction doesn't correspond to the sum of num_words" "in the operands"); RecordNumberType(); ProcessCurInstruction(); if (!ReadToByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte)) return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break"; return SPV_SUCCESS; } spv_result_t MarkvDecoder::SetNumericTypeInfoForType( spv_parsed_operand_t* parsed_operand, uint32_t type_id) { assert(type_id != 0); auto type_info_iter = type_id_to_number_type_info_.find(type_id); if (type_info_iter == type_id_to_number_type_info_.end()) { return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Type Id " << type_id << " is not a type"; } const NumberType& info = type_info_iter->second; if (info.type == SPV_NUMBER_NONE) { // This is a valid type, but for something other than a scalar number. return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Type Id " << type_id << " is not a scalar numeric type"; } parsed_operand->number_kind = info.type; parsed_operand->number_bit_width = info.bit_width; // Round up the word count. parsed_operand->num_words = static_cast((info.bit_width + 31) / 32); return SPV_SUCCESS; } void MarkvDecoder::RecordNumberType() { const SpvOp opcode = static_cast(inst_.opcode); if (spvOpcodeGeneratesType(opcode)) { NumberType info = {SPV_NUMBER_NONE, 0}; if (SpvOpTypeInt == opcode) { info.bit_width = inst_.words[inst_.operands[1].offset]; info.type = inst_.words[inst_.operands[2].offset] ? 
SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT; } else if (SpvOpTypeFloat == opcode) { info.bit_width = inst_.words[inst_.operands[1].offset]; info.type = SPV_NUMBER_FLOATING; } // The *result* Id of a type generating instruction is the type Id. type_id_to_number_type_info_[inst_.result_id] = info; } } spv_result_t EncodeHeader( void* user_data, spv_endianness_t endian, uint32_t magic, uint32_t version, uint32_t generator, uint32_t id_bound, uint32_t schema) { MarkvEncoder* encoder = reinterpret_cast(user_data); return encoder->EncodeHeader( endian, magic, version, generator, id_bound, schema); } spv_result_t EncodeInstruction( void* user_data, const spv_parsed_instruction_t* inst) { MarkvEncoder* encoder = reinterpret_cast(user_data); return encoder->EncodeInstruction(*inst); } } // namespace spv_result_t spvSpirvToMarkv(spv_const_context context, const uint32_t* spirv_words, const size_t spirv_num_words, spv_const_markv_encoder_options options, spv_markv_binary* markv_binary, spv_text* comments, spv_diagnostic* diagnostic) { spv_context_t hijack_context = *context; if (diagnostic) { *diagnostic = nullptr; libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); } spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words}; spv_endianness_t endian; spv_position_t position = {}; if (spvBinaryEndianness(&spirv_binary, &endian)) { return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V magic number."; } spv_header_t header; if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) { return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V header."; } MarkvEncoder encoder(&hijack_context, options); if (comments) { encoder.CreateCommentsLogger(); spv_text text = nullptr; if (spvBinaryToText(&hijack_context, spirv_words, spirv_num_words, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr) != SPV_SUCCESS) { return 
libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Failed to disassemble SPIR-V binary."; } assert(text); encoder.SetDisassembly(std::string(text->str, text->length)); spvTextDestroy(text); } if (spvBinaryParse( &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader, EncodeInstruction, diagnostic) != SPV_SUCCESS) { return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Unable to encode to MARK-V."; } if (comments) *comments = CreateSpvText(encoder.GetComments()); *markv_binary = encoder.GetMarkvBinary(); return SPV_SUCCESS; } spv_result_t spvMarkvToSpirv(spv_const_context context, const uint8_t* markv_data, size_t markv_size_bytes, spv_const_markv_decoder_options options, spv_binary* spirv_binary, spv_text* /* comments */, spv_diagnostic* diagnostic) { spv_position_t position = {}; spv_context_t hijack_context = *context; if (diagnostic) { *diagnostic = nullptr; libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); } MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options); std::vector words; if (decoder.DecodeModule(&words) != SPV_SUCCESS) { return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Unable to decode MARK-V."; } assert(!words.empty()); *spirv_binary = new spv_binary_t(); (*spirv_binary)->code = new uint32_t[words.size()]; (*spirv_binary)->wordCount = words.size(); std::memcpy((*spirv_binary)->code, words.data(), 4 * words.size()); return SPV_SUCCESS; } void spvMarkvBinaryDestroy(spv_markv_binary binary) { if (!binary) return; delete[] binary->data; delete binary; } spv_markv_encoder_options spvMarkvEncoderOptionsCreate() { return new spv_markv_encoder_options_t; } void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) { delete options; } spv_markv_decoder_options spvMarkvDecoderOptionsCreate() { return new spv_markv_decoder_options_t; } void 
spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) { delete options; }