// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/comp/markv_decoder.h" #include #include #include #include "source/ext_inst.h" #include "source/opcode.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { namespace comp { spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) { auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); if (codec) { uint64_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to decode non-id word with Huffman"; if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { // The word decoded successfully. *word = uint32_t(decoded_value); assert(*word == decoded_value); return SPV_SUCCESS; } // Received kMarkvNoneOfTheAbove signal, use fallback decoding. } const size_t chunk_length = model_->GetOperandVariableWidthChunkLength(operand_.type); if (chunk_length) { if (!reader_.ReadVariableWidthU32(word, chunk_length)) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to decode non-id word with varint"; } else { if (!reader_.ReadUnencoded(word)) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read unencoded non-id word"; } return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands( uint32_t* opcode, uint32_t* num_operands) { // First try to use the Markov chain codec. auto* codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); if (codec) { uint64_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return Diag(SPV_ERROR_INTERNAL) << "Failed to decode opcode_and_num_operands, previous opcode is " << spvOpcodeString(GetPrevOpcode()); if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { // The word was successfully decoded. *opcode = uint32_t(decoded_value & 0xFFFF); *num_operands = uint32_t(decoded_value >> 16); return SPV_SUCCESS; } // Received kMarkvNoneOfTheAbove signal, use fallback decoding. } // Fallback to base-rate codec. codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); assert(codec); uint64_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return Diag(SPV_ERROR_INTERNAL) << "Failed to decode opcode_and_num_operands with global codec"; if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) { // Received kMarkvNoneOfTheAbove signal, fallback further. return SPV_UNSUPPORTED; } *opcode = uint32_t(decoded_value & 0xFFFF); *num_operands = uint32_t(decoded_value >> 16); return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method, uint32_t* rank) { const auto* codec = GetMtfHuffmanCodec(mtf); if (!codec) { assert(fallback_method != kMtfNone); codec = GetMtfHuffmanCodec(fallback_method); } if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank"; uint32_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman"; if (decoded_value == kMtfRankEncodedByValueSignal) { // Decode by value. if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length())) return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with varint"; *rank += MarkvCodec::kMtfSmallestRankEncodedByValue; } else { // Decode using Huffman coding. assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue); *rank = decoded_value; } return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) { auto* codec = model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); uint64_t mtf = kMtfNone; if (codec) { uint64_t decoded_value = 0; if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) return Diag(SPV_ERROR_INTERNAL) << "Failed to decode descriptor with Huffman"; if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { const uint32_t long_descriptor = uint32_t(decoded_value); mtf = GetMtfLongIdDescriptor(long_descriptor); } } if (mtf == kMtfNone) { if (model_->id_fallback_strategy() != MarkvModel::IdFallbackStrategy::kShortDescriptor) { return SPV_UNSUPPORTED; } uint64_t decoded_value = 0; if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits)) return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor"; const uint32_t short_descriptor = uint32_t(decoded_value); if (short_descriptor == 0) { // Forward declared id. return SPV_UNSUPPORTED; } mtf = GetMtfShortIdDescriptor(short_descriptor); } return DecodeExistingId(mtf, id); } spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) { assert(multi_mtf_.GetSize(mtf) > 0); *id = 0; uint32_t rank = 0; if (multi_mtf_.GetSize(mtf) == 1) { rank = 1; } else { const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank); if (result != SPV_SUCCESS) return result; } assert(rank); if (!multi_mtf_.ValueFromRank(mtf, rank, id)) return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds"; return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) { { const spv_result_t result = DecodeIdWithDescriptor(id); if (result != SPV_UNSUPPORTED) return result; } const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( SpvOp(inst_.opcode))(operand_index_); uint32_t rank = 0; *id = 0; if (model_->id_fallback_strategy() == MarkvModel::IdFallbackStrategy::kRuleBased) { uint64_t mtf = GetRuleBasedMtf(); if (mtf != kMtfNone && !can_forward_declare) { return DecodeExistingId(mtf, id); } if (mtf == kMtfNone) mtf = kMtfAll; { const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank); if (result != SPV_SUCCESS) return result; } if (rank == 0) { // This is the first occurrence of a forward declared id. *id = GetIdBound(); SetIdBound(*id + 1); multi_mtf_.Insert(kMtfAll, *id); multi_mtf_.Insert(kMtfForwardDeclared, *id); if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id); } else { if (!multi_mtf_.ValueFromRank(mtf, rank, id)) return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; } } else { assert(can_forward_declare); if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with varint"; if (rank == 0) { // This is the first occurrence of a forward declared id. *id = GetIdBound(); SetIdBound(*id + 1); multi_mtf_.Insert(kMtfForwardDeclared, *id); } else { if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id)) return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; } } assert(*id); return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeTypeId() { if (inst_.opcode == SpvOpFunctionParameter) { assert(!remaining_function_parameter_types_.empty()); inst_.type_id = remaining_function_parameter_types_.front(); remaining_function_parameter_types_.pop_front(); return SPV_SUCCESS; } { const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id); if (result != SPV_UNSUPPORTED) return result; } assert(model_->id_fallback_strategy() == MarkvModel::IdFallbackStrategy::kRuleBased); uint64_t mtf = GetRuleBasedMtf(); assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( operand_index_)); if (mtf == kMtfNone) { mtf = kMtfTypeNonFunction; // Function types should have been handled by GetRuleBasedMtf. assert(inst_.opcode != SpvOpFunction); } return DecodeExistingId(mtf, &inst_.type_id); } spv_result_t MarkvDecoder::DecodeResultId() { uint32_t rank = 0; const uint64_t num_still_forward_declared = multi_mtf_.GetSize(kMtfForwardDeclared); if (num_still_forward_declared) { // Some ids were forward declared. Check if this id is one of them. uint64_t id_was_forward_declared; if (!reader_.ReadBits(&id_was_forward_declared, 1)) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id_was_forward_declared flag"; if (id_was_forward_declared) { if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read MTF rank of forward declared id"; if (rank) { // The id was forward declared, recover it from kMtfForwardDeclared. if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, &inst_.result_id)) return Diag(SPV_ERROR_INTERNAL) << "Forward declared MTF rank is out of bounds"; // We can now remove the id from kMtfForwardDeclared. if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) return Diag(SPV_ERROR_INTERNAL) << "Failed to remove id from kMtfForwardDeclared"; } } } if (inst_.result_id == 0) { // The id was not forward declared, issue a new id. inst_.result_id = GetIdBound(); SetIdBound(inst_.result_id + 1); } if (model_->id_fallback_strategy() == MarkvModel::IdFallbackStrategy::kRuleBased) { if (!rank) { multi_mtf_.Insert(kMtfAll, inst_.result_id); } } return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeLiteralNumber( const spv_parsed_operand_t& operand) { if (operand.number_bit_width <= 32) { uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); } else { assert(operand.number_bit_width <= 64); uint64_t word = 0; if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length())) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64"; } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int64_t val = 0; if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(), model_->s64_block_exponent())) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64"; std::memcpy(&word, &val, 8); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { if (!reader_.ReadUnencoded(&word)) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64"; } else { return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; } inst_words_.push_back(static_cast(word)); inst_words_.push_back(static_cast(word >> 32)); } return SPV_SUCCESS; } bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) { const size_t num_bits_to_next_byte = GetNumBitsToNextByte(reader_.GetNumReadBits()); if (num_bits_to_next_byte == 0 || num_bits_to_next_byte > byte_break_if_less_than) return true; uint64_t bits = 0; if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false; assert(bits == 0); if (bits != 0) return false; return true; } spv_result_t MarkvDecoder::DecodeModule(std::vector* spirv_binary) { const bool header_read_success = reader_.ReadUnencoded(&header_.magic_number) && reader_.ReadUnencoded(&header_.markv_version) && reader_.ReadUnencoded(&header_.markv_model) && reader_.ReadUnencoded(&header_.markv_length_in_bits) && reader_.ReadUnencoded(&header_.spirv_version) && reader_.ReadUnencoded(&header_.spirv_generator); if (!header_read_success) return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header"; if (header_.markv_length_in_bits == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Header markv_length_in_bits field is zero"; if (header_.magic_number != MarkvCodec::kMarkvMagicNumber) return Diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary has incorrect magic number"; // TODO(atgoo@github.com): Print version strings. if (header_.markv_version != MarkvCodec::GetMarkvVersion()) return Diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary and the codec have different versions"; const uint32_t model_type = header_.markv_model >> 16; const uint32_t model_version = header_.markv_model & 0xFFFF; if (model_type != model_->model_type()) return Diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary and the codec use different MARK-V models"; if (model_version != model_->model_version()) return Diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary and the codec use different versions if the same " << "MARK-V model"; spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic. spirv_.resize(5, 0); spirv_[0] = SpvMagicNumber; spirv_[1] = header_.spirv_version; spirv_[2] = header_.spirv_generator; if (logger_) { reader_.SetCallback( [this](const std::string& str) { logger_->AppendBitSequence(str); }); } while (reader_.GetNumReadBits() < header_.markv_length_in_bits) { inst_ = {}; const spv_result_t decode_result = DecodeInstruction(); if (decode_result != SPV_SUCCESS) return decode_result; } if (validator_options_) { spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()}; const spv_result_t result = spvValidateWithOptions( context_, validator_options_, &validation_binary, nullptr); if (result != SPV_SUCCESS) return result; } // Validate the decode binary if (reader_.GetNumReadBits() != header_.markv_length_in_bits || !reader_.OnlyZeroesLeft()) { return Diag(SPV_ERROR_INVALID_BINARY) << "MARK-V binary has wrong stated bit length " << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits; } // Decoding of the module is finished, validation state should have correct // id bound. spirv_[3] = GetIdBound(); *spirv_binary = std::move(spirv_); return SPV_SUCCESS; } // TODO(atgoo@github.com): The implementation borrows heavily from // Parser::parseOperand. // Consider coupling them together in some way once MARK-V codec is more mature. // For now it's better to keep the code independent for experimentation // purposes. spv_result_t MarkvDecoder::DecodeOperand( size_t operand_offset, const spv_operand_type_t type, spv_operand_pattern_t* expected_operands) { const SpvOp opcode = static_cast(inst_.opcode); memset(&operand_, 0, sizeof(operand_)); assert((operand_offset >> 16) == 0); operand_.offset = static_cast(operand_offset); operand_.type = type; // Set default values, may be updated later. operand_.number_kind = SPV_NUMBER_NONE; operand_.number_bit_width = 0; const size_t first_word_index = inst_words_.size(); switch (type) { case SPV_OPERAND_TYPE_RESULT_ID: { const spv_result_t result = DecodeResultId(); if (result != SPV_SUCCESS) return result; inst_words_.push_back(inst_.result_id); SetIdBound(std::max(GetIdBound(), inst_.result_id + 1)); PromoteIfNeeded(inst_.result_id); break; } case SPV_OPERAND_TYPE_TYPE_ID: { const spv_result_t result = DecodeTypeId(); if (result != SPV_SUCCESS) return result; inst_words_.push_back(inst_.type_id); SetIdBound(std::max(GetIdBound(), inst_.type_id + 1)); PromoteIfNeeded(inst_.type_id); break; } case SPV_OPERAND_TYPE_ID: case SPV_OPERAND_TYPE_OPTIONAL_ID: case SPV_OPERAND_TYPE_SCOPE_ID: case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { uint32_t id = 0; const spv_result_t result = DecodeRefId(&id); if (result != SPV_SUCCESS) return result; if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0"; if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) { operand_.type = SPV_OPERAND_TYPE_ID; if (opcode == SpvOpExtInst && operand_.offset == 3) { // The current word is the extended instruction set id. // Set the extended instruction set type for the current // instruction. auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { return Diag(SPV_ERROR_INVALID_ID) << "OpExtInst set id " << id << " does not reference an OpExtInstImport result Id"; } inst_.ext_inst_type = ext_inst_type_iter->second; } } inst_words_.push_back(id); SetIdBound(std::max(GetIdBound(), id + 1)); PromoteIfNeeded(id); break; } case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); assert(SpvOpExtInst == opcode); assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE); spv_ext_inst_desc ext_inst; if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst)) return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid extended instruction number: " << word; spvPushOperandTypes(ext_inst->operandTypes, expected_operands); break; } case SPV_OPERAND_TYPE_LITERAL_INTEGER: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { // These are regular single-word literal integer operands. // Post-parsing validation should check the range of the parsed value. operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; // It turns out they are always unsigned integers! operand_.number_kind = SPV_NUMBER_UNSIGNED_INT; operand_.number_bit_width = 32; uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); break; } case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: { operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; if (opcode == SpvOpSwitch) { // The literal operands have the same type as the value // referenced by the selector Id. const uint32_t selector_id = inst_words_.at(1); const auto type_id_iter = id_to_type_id_.find(selector_id); if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) { return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid OpSwitch: selector id " << selector_id << " has no type"; } uint32_t type_id = type_id_iter->second; if (selector_id == type_id) { // Recall that by convention, a result ID that is a type definition // maps to itself. return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid OpSwitch: selector id " << selector_id << " is a type, not a value"; } if (auto error = SetNumericTypeInfoForType(&operand_, type_id)) return error; if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT && operand_.number_kind != SPV_NUMBER_SIGNED_INT) { return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid OpSwitch: selector id " << selector_id << " is not a scalar integer"; } } else { assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); // The literal number type is determined by the type Id for the // constant. assert(inst_.type_id); if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id)) return error; } if (auto error = DecodeLiteralNumber(operand_)) return error; break; } case SPV_OPERAND_TYPE_LITERAL_STRING: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING; std::vector str; auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode); if (codec) { std::string decoded_string; const bool huffman_result = codec->DecodeFromStream(GetReadBitCallback(), &decoded_string); assert(huffman_result); if (!huffman_result) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal string"; if (decoded_string != "kMarkvNoneOfTheAbove") { std::copy(decoded_string.begin(), decoded_string.end(), std::back_inserter(str)); str.push_back('\0'); } } // The loop is expected to terminate once we encounter '\0' or exhaust // the bit stream. if (str.empty()) { while (true) { char ch = 0; if (!reader_.ReadUnencoded(&ch)) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal string"; str.push_back(ch); if (ch == '\0') break; } } while (str.size() % 4 != 0) str.push_back('\0'); inst_words_.resize(inst_words_.size() + str.size() / 4); std::memcpy(&inst_words_[first_word_index], str.data(), str.size()); if (SpvOpExtInstImport == opcode) { // Record the extended instruction type for the ID for this import. // There is only one string literal argument to OpExtInstImport, // so it's sufficient to guard this just on the opcode. const spv_ext_inst_type_t ext_inst_type = spvExtInstImportTypeGet(str.data()); if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid extended instruction import '" << str.data() << "'"; } // We must have parsed a valid result ID. It's a condition // of the grammar, and we only accept non-zero result Ids. assert(inst_.result_id); const bool inserted = import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type) .second; (void)inserted; assert(inserted); } break; } case SPV_OPERAND_TYPE_CAPABILITY: case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: case SPV_OPERAND_TYPE_EXECUTION_MODEL: case SPV_OPERAND_TYPE_ADDRESSING_MODEL: case SPV_OPERAND_TYPE_MEMORY_MODEL: case SPV_OPERAND_TYPE_EXECUTION_MODE: case SPV_OPERAND_TYPE_STORAGE_CLASS: case SPV_OPERAND_TYPE_DIMENSIONALITY: case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: case SPV_OPERAND_TYPE_LINKAGE_TYPE: case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: case SPV_OPERAND_TYPE_DECORATION: case SPV_OPERAND_TYPE_BUILT_IN: case SPV_OPERAND_TYPE_GROUP_OPERATION: case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: { // A single word that is a plain enum value. uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); // Map an optional operand type to its corresponding concrete type. if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; spv_operand_desc entry; if (grammar_.lookupOperand(type, word, &entry)) { return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid " << spvOperandTypeStr(operand_.type) << " operand: " << word; } // Prepare to accept operands to this operand, if needed. spvPushOperandTypes(entry->operandTypes, expected_operands); break; } case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: case SPV_OPERAND_TYPE_FUNCTION_CONTROL: case SPV_OPERAND_TYPE_LOOP_CONTROL: case SPV_OPERAND_TYPE_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: case SPV_OPERAND_TYPE_SELECTION_CONTROL: { // This operand is a mask. uint32_t word = 0; const spv_result_t result = DecodeNonIdWord(&word); if (result != SPV_SUCCESS) return result; inst_words_.push_back(word); // Map an optional operand type to its corresponding concrete type. if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) operand_.type = SPV_OPERAND_TYPE_IMAGE; else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; // Check validity of set mask bits. Also prepare for operands for those // masks if they have any. To get operand order correct, scan from // MSB to LSB since we can only prepend operands to a pattern. // The only case in the grammar where you have more than one mask bit // having an operand is for image operands. See SPIR-V 3.14 Image // Operands. uint32_t remaining_word = word; for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { if (remaining_word & mask) { spv_operand_desc entry; if (grammar_.lookupOperand(type, mask, &entry)) { return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid " << spvOperandTypeStr(operand_.type) << " operand: " << word << " has invalid mask component " << mask; } remaining_word ^= mask; spvPushOperandTypes(entry->operandTypes, expected_operands); } } if (word == 0) { // An all-zeroes mask *might* also be valid. spv_operand_desc entry; if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { // Prepare for its operands, if any. spvPushOperandTypes(entry->operandTypes, expected_operands); } } break; } default: return Diag(SPV_ERROR_INVALID_BINARY) << "Internal error: Unhandled operand type: " << type; } operand_.num_words = uint16_t(inst_words_.size() - first_word_index); assert(spvOperandIsConcrete(operand_.type)); parsed_operands_.push_back(operand_); return SPV_SUCCESS; } spv_result_t MarkvDecoder::DecodeInstruction() { parsed_operands_.clear(); inst_words_.clear(); // Opcode/num_words placeholder, the word will be filled in later. inst_words_.push_back(0); bool num_operands_still_unknown = true; { uint32_t opcode = 0; uint32_t num_operands = 0; const spv_result_t opcode_decoding_result = DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands); if (opcode_decoding_result < 0) return opcode_decoding_result; if (opcode_decoding_result == SPV_SUCCESS) { inst_.num_operands = static_cast(num_operands); num_operands_still_unknown = false; } else { if (!reader_.ReadVariableWidthU32(&opcode, model_->opcode_chunk_length())) { return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read opcode of instruction"; } } inst_.opcode = static_cast(opcode); } const SpvOp opcode = static_cast(inst_.opcode); spv_opcode_desc opcode_desc; if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) { return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode"; } spv_operand_pattern_t expected_operands; expected_operands.reserve(opcode_desc->numTypes); for (auto i = 0; i < opcode_desc->numTypes; i++) { expected_operands.push_back( opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); } if (num_operands_still_unknown) { if (!OpcodeHasFixedNumberOfOperands(opcode)) { if (!reader_.ReadVariableWidthU16(&inst_.num_operands, model_->num_operands_chunk_length())) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read num_operands of instruction"; } else { inst_.num_operands = static_cast(expected_operands.size()); } } for (operand_index_ = 0; operand_index_ < static_cast(inst_.num_operands); ++operand_index_) { assert(!expected_operands.empty()); const spv_operand_type_t type = spvTakeFirstMatchableOperand(&expected_operands); const size_t operand_offset = inst_words_.size(); const spv_result_t decode_result = DecodeOperand(operand_offset, type, &expected_operands); if (decode_result != SPV_SUCCESS) return decode_result; } assert(inst_.num_operands == parsed_operands_.size()); // Only valid while inst_words_ and parsed_operands_ remain unchanged (until // next DecodeInstruction call). inst_.words = inst_words_.data(); inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data(); inst_.num_words = static_cast(inst_words_.size()); inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode)); std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_)); assert(inst_.num_words == std::accumulate( parsed_operands_.begin(), parsed_operands_.end(), 1, [](int num_words, const spv_parsed_operand_t& operand) { return num_words += operand.num_words; }) && "num_words in instruction doesn't correspond to the sum of num_words" "in the operands"); RecordNumberType(); ProcessCurInstruction(); if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte)) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break"; if (logger_) { logger_->NewLine(); std::stringstream ss; ss << spvOpcodeString(opcode) << " "; for (size_t index = 1; index < inst_words_.size(); ++index) ss << inst_words_[index] << " "; logger_->AppendText(ss.str()); logger_->NewLine(); logger_->NewLine(); if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; } return SPV_SUCCESS; } spv_result_t MarkvDecoder::SetNumericTypeInfoForType( spv_parsed_operand_t* parsed_operand, uint32_t type_id) { assert(type_id != 0); auto type_info_iter = type_id_to_number_type_info_.find(type_id); if (type_info_iter == type_id_to_number_type_info_.end()) { return Diag(SPV_ERROR_INVALID_BINARY) << "Type Id " << type_id << " is not a type"; } const NumberType& info = type_info_iter->second; if (info.type == SPV_NUMBER_NONE) { // This is a valid type, but for something other than a scalar number. return Diag(SPV_ERROR_INVALID_BINARY) << "Type Id " << type_id << " is not a scalar numeric type"; } parsed_operand->number_kind = info.type; parsed_operand->number_bit_width = info.bit_width; // Round up the word count. parsed_operand->num_words = static_cast((info.bit_width + 31) / 32); return SPV_SUCCESS; } void MarkvDecoder::RecordNumberType() { const SpvOp opcode = static_cast(inst_.opcode); if (spvOpcodeGeneratesType(opcode)) { NumberType info = {SPV_NUMBER_NONE, 0}; if (SpvOpTypeInt == opcode) { info.bit_width = inst_.words[inst_.operands[1].offset]; info.type = inst_.words[inst_.operands[2].offset] ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT; } else if (SpvOpTypeFloat == opcode) { info.bit_width = inst_.words[inst_.operands[1].offset]; info.type = SPV_NUMBER_FLOATING; } // The *result* Id of a type generating instruction is the type Id. type_id_to_number_type_info_[inst_.result_id] = info; } } } // namespace comp } // namespace spvtools