mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-12-12 03:30:09 +00:00
926 lines
32 KiB
C++
926 lines
32 KiB
C++
|
// Copyright (c) 2018 Google LLC
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
// you may not use this file except in compliance with the License.
|
||
|
// You may obtain a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// Unless required by applicable law or agreed to in writing, software
|
||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
// See the License for the specific language governing permissions and
|
||
|
// limitations under the License.
|
||
|
|
||
|
#include "source/comp/markv_decoder.h"
|
||
|
|
||
|
#include <cstring>
|
||
|
#include <iterator>
|
||
|
#include <numeric>
|
||
|
|
||
|
#include "source/ext_inst.h"
|
||
|
#include "source/opcode.h"
|
||
|
#include "spirv-tools/libspirv.hpp"
|
||
|
|
||
|
namespace spvtools {
|
||
|
namespace comp {
|
||
|
|
||
|
// Decodes one non-id operand word of the current instruction into |word|.
//
// If the model provides a Huffman codec for the current opcode/operand index,
// that codec is tried first. When no codec exists, or the codec yields the
// "none of the above" marker, the word is read either as a variable-width
// integer (when the model reports a non-zero chunk length for the operand
// type) or as an unencoded 32-bit value.
//
// Returns SPV_SUCCESS on success and SPV_ERROR_INVALID_BINARY if any read
// from the bit stream fails.
spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
  if (auto* huffman =
          model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_)) {
    uint64_t symbol = 0;
    const bool decoded =
        huffman->DecodeFromStream(GetReadBitCallback(), &symbol);
    if (!decoded) {
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to decode non-id word with Huffman";
    }

    if (symbol != MarkvModel::GetMarkvNoneOfTheAbove()) {
      // The codec produced the word itself.
      *word = static_cast<uint32_t>(symbol);
      assert(*word == symbol);
      return SPV_SUCCESS;
    }
    // Otherwise the marker was received: fall through to the generic path.
  }

  const size_t chunk_length =
      model_->GetOperandVariableWidthChunkLength(operand_.type);

  if (chunk_length == 0) {
    if (reader_.ReadUnencoded(word)) return SPV_SUCCESS;
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "Failed to read unencoded non-id word";
  }

  if (reader_.ReadVariableWidthU32(word, chunk_length)) return SPV_SUCCESS;
  return Diag(SPV_ERROR_INVALID_BINARY)
         << "Failed to decode non-id word with varint";
}
|
||
|
|
||
|
// Decodes a packed (opcode, number of operands) pair.
//
// The packed value stores the opcode in its low 16 bits and the operand count
// in the bits above. Two codecs are tried in order: the one keyed by the
// previous opcode, then the one keyed by SpvOpNop.
//
// Returns SPV_SUCCESS when either codec produced a value, SPV_UNSUPPORTED when
// the SpvOpNop codec also produced the "none of the above" marker (the caller
// must then decode the opcode some other way), and SPV_ERROR_INTERNAL when a
// codec fails to decode from the stream.
spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
    uint32_t* opcode, uint32_t* num_operands) {
  const auto unpack = [opcode, num_operands](uint64_t packed) {
    *opcode = uint32_t(packed & 0xFFFF);
    *num_operands = uint32_t(packed >> 16);
  };

  // Codec conditioned on the previous opcode, if the model has one.
  auto* codec =
      model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
  if (codec) {
    uint64_t packed = 0;
    if (!codec->DecodeFromStream(GetReadBitCallback(), &packed)) {
      return Diag(SPV_ERROR_INTERNAL)
             << "Failed to decode opcode_and_num_operands, previous opcode is "
             << spvOpcodeString(GetPrevOpcode());
    }

    if (packed != MarkvModel::GetMarkvNoneOfTheAbove()) {
      unpack(packed);
      return SPV_SUCCESS;
    }
    // Marker received: continue with the SpvOpNop codec below.
  }

  codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
  assert(codec);

  uint64_t packed = 0;
  if (!codec->DecodeFromStream(GetReadBitCallback(), &packed)) {
    return Diag(SPV_ERROR_INTERNAL)
           << "Failed to decode opcode_and_num_operands with global codec";
  }

  // The marker here means the caller has to fall back further.
  if (packed == MarkvModel::GetMarkvNoneOfTheAbove()) return SPV_UNSUPPORTED;

  unpack(packed);
  return SPV_SUCCESS;
}
|
||
|
|
||
|
// Decodes a move-to-front rank into |rank| using a Huffman codec.
//
// The codec for |mtf| is used if available; otherwise the codec for
// |fallback_method| is used. A decoded value equal to
// kMtfRankEncodedByValueSignal means the rank follows as a variable-width
// integer, offset by MarkvCodec::kMtfSmallestRankEncodedByValue.
//
// Returns SPV_SUCCESS on success and SPV_ERROR_INTERNAL if no codec is
// available or any read fails.
spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
                                                uint32_t fallback_method,
                                                uint32_t* rank) {
  const auto* codec = GetMtfHuffmanCodec(mtf);
  if (codec == nullptr) {
    assert(fallback_method != kMtfNone);
    codec = GetMtfHuffmanCodec(fallback_method);
    if (codec == nullptr) {
      return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
    }
  }

  uint32_t symbol = 0;
  if (!codec->DecodeFromStream(GetReadBitCallback(), &symbol)) {
    return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
  }

  if (symbol != kMtfRankEncodedByValueSignal) {
    // The Huffman symbol is the rank itself.
    assert(symbol < MarkvCodec::kMtfSmallestRankEncodedByValue);
    *rank = symbol;
    return SPV_SUCCESS;
  }

  // The rank was written by value after the signal symbol.
  if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length())) {
    return Diag(SPV_ERROR_INTERNAL)
           << "Failed to decode MTF rank with varint";
  }
  *rank += MarkvCodec::kMtfSmallestRankEncodedByValue;
  return SPV_SUCCESS;
}
|
||
|
|
||
|
// Decodes an id by first decoding a descriptor that selects a move-to-front
// sequence, then decoding the id from that sequence via DecodeExistingId.
//
// A long descriptor is decoded with the model's Huffman codec for the current
// opcode/operand index when one exists. If that does not select a sequence,
// a short descriptor is read from the stream, but only when the model's id
// fallback strategy is kShortDescriptor.
//
// Returns SPV_UNSUPPORTED when the id cannot be decoded this way (the fallback
// strategy is not kShortDescriptor, or the short descriptor is 0, which marks
// a forward declared id), SPV_ERROR_INTERNAL when a read fails, and otherwise
// the result of DecodeExistingId.
spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
  uint64_t mtf = kMtfNone;

  if (auto* codec =
          model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_)) {
    uint64_t symbol = 0;
    if (!codec->DecodeFromStream(GetReadBitCallback(), &symbol)) {
      return Diag(SPV_ERROR_INTERNAL)
             << "Failed to decode descriptor with Huffman";
    }

    if (symbol != MarkvModel::GetMarkvNoneOfTheAbove()) {
      mtf = GetMtfLongIdDescriptor(uint32_t(symbol));
    }
  }

  if (mtf != kMtfNone) return DecodeExistingId(mtf, id);

  const bool short_descriptors_enabled =
      model_->id_fallback_strategy() ==
      MarkvModel::IdFallbackStrategy::kShortDescriptor;
  if (!short_descriptors_enabled) return SPV_UNSUPPORTED;

  uint64_t raw_bits = 0;
  if (!reader_.ReadBits(&raw_bits, MarkvCodec::kShortDescriptorNumBits)) {
    return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
  }

  const uint32_t short_descriptor = uint32_t(raw_bits);
  // A zero descriptor marks a forward declared id.
  if (short_descriptor == 0) return SPV_UNSUPPORTED;

  mtf = GetMtfShortIdDescriptor(short_descriptor);
  return DecodeExistingId(mtf, id);
}
|
||
|
|
||
|
// Decodes an id that is already present in the move-to-front sequence |mtf|.
//
// When the sequence holds exactly one value no bits are read and rank 1 is
// used; otherwise the rank is decoded with DecodeMtfRankHuffman using
// kMtfGenericNonZeroRank as the fallback codec. |*id| is zeroed first and then
// filled from the sequence.
//
// Returns SPV_SUCCESS, the error from rank decoding, or SPV_ERROR_INTERNAL if
// the rank does not map to a value in the sequence.
spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
  assert(multi_mtf_.GetSize(mtf) > 0);
  *id = 0;

  uint32_t rank = 1;
  if (multi_mtf_.GetSize(mtf) != 1) {
    rank = 0;
    const spv_result_t rank_status =
        DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
    if (rank_status != SPV_SUCCESS) return rank_status;
  }

  assert(rank);
  if (multi_mtf_.ValueFromRank(mtf, rank, id)) return SPV_SUCCESS;

  return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
}
|
||
|
|
||
|
// Decodes an id referenced by the current operand into |id|.
//
// Descriptor-based decoding is attempted first; any result other than
// SPV_UNSUPPORTED is returned as is. Otherwise the id is decoded from a
// move-to-front rank:
//  - With the kRuleBased fallback strategy the sequence comes from
//    GetRuleBasedMtf() (kMtfAll when that yields kMtfNone). If a rule-based
//    sequence exists and the operand cannot be forward declared, the id is
//    decoded with DecodeExistingId.
//  - With any other strategy the rank is read as a variable-width integer and
//    refers to kMtfForwardDeclared.
// In both cases rank 0 denotes the first occurrence of a forward declared id:
// a new id is allocated from the id bound and recorded in the relevant
// sequences.
spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
  const spv_result_t descriptor_status = DecodeIdWithDescriptor(id);
  if (descriptor_status != SPV_UNSUPPORTED) return descriptor_status;

  const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
      static_cast<SpvOp>(inst_.opcode))(operand_index_);
  uint32_t rank = 0;
  *id = 0;

  const bool rule_based = model_->id_fallback_strategy() ==
                          MarkvModel::IdFallbackStrategy::kRuleBased;

  if (!rule_based) {
    assert(can_forward_declare);

    if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) {
      return Diag(SPV_ERROR_INTERNAL)
             << "Failed to decode MTF rank with varint";
    }

    if (rank != 0) {
      if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id)) {
        return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
      }
    } else {
      // First occurrence of a forward declared id: allocate a new one.
      *id = GetIdBound();
      SetIdBound(*id + 1);
      multi_mtf_.Insert(kMtfForwardDeclared, *id);
    }

    assert(*id);
    return SPV_SUCCESS;
  }

  uint64_t mtf = GetRuleBasedMtf();
  if (mtf != kMtfNone && !can_forward_declare) {
    return DecodeExistingId(mtf, id);
  }
  if (mtf == kMtfNone) mtf = kMtfAll;

  const spv_result_t rank_status = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
  if (rank_status != SPV_SUCCESS) return rank_status;

  if (rank != 0) {
    if (!multi_mtf_.ValueFromRank(mtf, rank, id)) {
      return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
    }
  } else {
    // First occurrence of a forward declared id: allocate a new one and
    // register it with every sequence it belongs to.
    *id = GetIdBound();
    SetIdBound(*id + 1);
    multi_mtf_.Insert(kMtfAll, *id);
    multi_mtf_.Insert(kMtfForwardDeclared, *id);
    if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
  }

  assert(*id);
  return SPV_SUCCESS;
}
|
||
|
|
||
|
// Decodes the type id of the current instruction into inst_.type_id.
//
// For OpFunctionParameter no bits are read: the type is taken from the front
// of remaining_function_parameter_types_. Otherwise descriptor-based decoding
// is tried first, and if it reports SPV_UNSUPPORTED the id is decoded from the
// rule-based move-to-front sequence (kMtfTypeNonFunction when no rule
// applies).
spv_result_t MarkvDecoder::DecodeTypeId() {
  if (inst_.opcode == SpvOpFunctionParameter) {
    assert(!remaining_function_parameter_types_.empty());
    inst_.type_id = remaining_function_parameter_types_.front();
    remaining_function_parameter_types_.pop_front();
    return SPV_SUCCESS;
  }

  const spv_result_t descriptor_status =
      DecodeIdWithDescriptor(&inst_.type_id);
  if (descriptor_status != SPV_UNSUPPORTED) return descriptor_status;

  assert(model_->id_fallback_strategy() ==
         MarkvModel::IdFallbackStrategy::kRuleBased);

  const uint64_t rule_mtf = GetRuleBasedMtf();
  assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
      operand_index_));

  if (rule_mtf != kMtfNone) return DecodeExistingId(rule_mtf, &inst_.type_id);

  // Function types should have been handled by GetRuleBasedMtf.
  assert(inst_.opcode != SpvOpFunction);
  return DecodeExistingId(kMtfTypeNonFunction, &inst_.type_id);
}
|
||
|
|
||
|
// Decodes the result id of the current instruction into inst_.result_id.
//
// If kMtfForwardDeclared is non-empty, a one-bit flag tells whether this id
// was forward declared. If so, its rank in kMtfForwardDeclared is read; a
// non-zero rank recovers the id and removes it from that sequence. If
// inst_.result_id is still 0 afterwards, a fresh id is allocated from the id
// bound. With the kRuleBased fallback strategy, ids that were not recovered by
// rank are inserted into kMtfAll.
spv_result_t MarkvDecoder::DecodeResultId() {
  uint32_t rank = 0;

  if (multi_mtf_.GetSize(kMtfForwardDeclared) != 0) {
    // Some ids were forward declared; the stream says whether this is one.
    uint64_t was_forward_declared = 0;
    if (!reader_.ReadBits(&was_forward_declared, 1)) {
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to read id_was_forward_declared flag";
    }

    if (was_forward_declared) {
      if (!reader_.ReadVariableWidthU32(&rank,
                                        model_->mtf_rank_chunk_length())) {
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read MTF rank of forward declared id";
      }

      if (rank != 0) {
        // Recover the id and retire it from the forward declared sequence.
        if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
                                      &inst_.result_id)) {
          return Diag(SPV_ERROR_INTERNAL)
                 << "Forward declared MTF rank is out of bounds";
        }

        if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) {
          return Diag(SPV_ERROR_INTERNAL)
                 << "Failed to remove id from kMtfForwardDeclared";
        }
      }
    }
  }

  if (inst_.result_id == 0) {
    // Not recovered above: issue a new id.
    inst_.result_id = GetIdBound();
    SetIdBound(inst_.result_id + 1);
  }

  const bool rule_based = model_->id_fallback_strategy() ==
                          MarkvModel::IdFallbackStrategy::kRuleBased;
  if (rule_based && rank == 0) {
    multi_mtf_.Insert(kMtfAll, inst_.result_id);
  }

  return SPV_SUCCESS;
}
|
||
|
|
||
|
// Decodes a typed literal number described by |operand| and appends its
// word(s) to inst_words_.
//
// Numbers up to 32 bits wide are decoded with DecodeNonIdWord and occupy one
// word. Wider numbers (asserted to be at most 64 bits) are read according to
// operand.number_kind and appended as two words, low 32 bits first.
//
// Returns SPV_ERROR_INVALID_BINARY when a read fails and SPV_ERROR_INTERNAL
// for a wide number whose kind is not unsigned, signed or floating.
spv_result_t MarkvDecoder::DecodeLiteralNumber(
    const spv_parsed_operand_t& operand) {
  if (operand.number_bit_width <= 32) {
    uint32_t single_word = 0;
    const spv_result_t status = DecodeNonIdWord(&single_word);
    if (status != SPV_SUCCESS) return status;
    inst_words_.push_back(single_word);
    return SPV_SUCCESS;
  }

  assert(operand.number_bit_width <= 64);
  uint64_t bits = 0;

  switch (operand.number_kind) {
    case SPV_NUMBER_UNSIGNED_INT:
      if (!reader_.ReadVariableWidthU64(&bits, model_->u64_chunk_length()))
        return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
      break;

    case SPV_NUMBER_SIGNED_INT: {
      int64_t signed_value = 0;
      if (!reader_.ReadVariableWidthS64(&signed_value,
                                        model_->s64_chunk_length(),
                                        model_->s64_block_exponent()))
        return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
      // Keep the bit pattern of the signed value.
      std::memcpy(&bits, &signed_value, 8);
      break;
    }

    case SPV_NUMBER_FLOATING:
      if (!reader_.ReadUnencoded(&bits))
        return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
      break;

    default:
      return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
  }

  inst_words_.push_back(static_cast<uint32_t>(bits));
  inst_words_.push_back(static_cast<uint32_t>(bits >> 32));
  return SPV_SUCCESS;
}
|
||
|
|
||
|
bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
|
||
|
const size_t num_bits_to_next_byte =
|
||
|
GetNumBitsToNextByte(reader_.GetNumReadBits());
|
||
|
if (num_bits_to_next_byte == 0 ||
|
||
|
num_bits_to_next_byte > byte_break_if_less_than)
|
||
|
return true;
|
||
|
|
||
|
uint64_t bits = 0;
|
||
|
if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
|
||
|
|
||
|
assert(bits == 0);
|
||
|
if (bits != 0) return false;
|
||
|
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
// Decodes the whole MARK-V stream held by reader_ into a SPIR-V binary.
//
// Steps:
//  1. Read and check the MARK-V header (magic number, codec version, model
//     type and version, non-zero bit length).
//  2. Emit the five-word SPIR-V header, then decode instructions until the
//     number of bits stated in the header has been consumed.
//  3. Store the final id bound in the SPIR-V header, optionally validate the
//     produced binary, and check that exactly the stated number of bits was
//     read and only zero bits remain.
//
// On success the decoded binary is moved into |spirv_binary| and SPV_SUCCESS
// is returned. Otherwise the first error encountered is returned and
// |spirv_binary| is left untouched.
spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
  const bool header_read_success =
      reader_.ReadUnencoded(&header_.magic_number) &&
      reader_.ReadUnencoded(&header_.markv_version) &&
      reader_.ReadUnencoded(&header_.markv_model) &&
      reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
      reader_.ReadUnencoded(&header_.spirv_version) &&
      reader_.ReadUnencoded(&header_.spirv_generator);

  if (!header_read_success)
    return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";

  if (header_.markv_length_in_bits == 0)
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "Header markv_length_in_bits field is zero";

  if (header_.magic_number != MarkvCodec::kMarkvMagicNumber)
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary has incorrect magic number";

  // TODO(atgoo@github.com): Print version strings.
  if (header_.markv_version != MarkvCodec::GetMarkvVersion())
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary and the codec have different versions";

  // The model field packs the model type in the high 16 bits and the model
  // version in the low 16 bits.
  const uint32_t model_type = header_.markv_model >> 16;
  const uint32_t model_version = header_.markv_model & 0xFFFF;
  if (model_type != model_->model_type())
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary and the codec use different MARK-V models";

  if (model_version != model_->model_version())
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary and the codec use different versions if the same "
           << "MARK-V model";

  spirv_.reserve(header_.markv_length_in_bits / 2);  // Heuristic.
  // SPIR-V header: magic, version, generator, id bound, schema. The id bound
  // (word 3) is filled in once all instructions have been decoded.
  spirv_.resize(5, 0);
  spirv_[0] = SpvMagicNumber;
  spirv_[1] = header_.spirv_version;
  spirv_[2] = header_.spirv_generator;

  if (logger_) {
    reader_.SetCallback(
        [this](const std::string& str) { logger_->AppendBitSequence(str); });
  }

  while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
    inst_ = {};
    const spv_result_t decode_result = DecodeInstruction();
    if (decode_result != SPV_SUCCESS) return decode_result;
  }

  // All instructions are decoded, so the id bound is final. Store it before
  // validating so that the validator sees exactly the binary that is
  // returned, rather than one whose header still carries a zero bound.
  spirv_[3] = GetIdBound();

  if (validator_options_) {
    spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()};
    const spv_result_t result = spvValidateWithOptions(
        context_, validator_options_, &validation_binary, nullptr);
    if (result != SPV_SUCCESS) return result;
  }

  // The stream must end exactly where the header said it would.
  if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
      !reader_.OnlyZeroesLeft()) {
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary has wrong stated bit length "
           << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
  }

  *spirv_binary = std::move(spirv_);
  return SPV_SUCCESS;
}
|
||
|
|
||
|
// TODO(atgoo@github.com): The implementation borrows heavily from
|
||
|
// Parser::parseOperand.
|
||
|
// Consider coupling them together in some way once MARK-V codec is more mature.
|
||
|
// For now it's better to keep the code independent for experimentation
|
||
|
// purposes.
|
||
|
// Decodes one operand of the current instruction.
//
// |operand_offset| is the index in inst_words_ where the operand starts (it
// must fit in 16 bits), |type| is the operand type expected by the grammar,
// and |expected_operands| is the pattern of operand types still expected,
// which grows when the decoded value itself takes further operands.
//
// The decoded words are appended to inst_words_, operand_ is filled with the
// operand's description (optional types are mapped to concrete ones), and a
// copy of operand_ is appended to parsed_operands_.
//
// Returns SPV_SUCCESS, or the error of the first failing step.
spv_result_t MarkvDecoder::DecodeOperand(
    size_t operand_offset, const spv_operand_type_t type,
    spv_operand_pattern_t* expected_operands) {
  const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);

  std::memset(&operand_, 0, sizeof(operand_));

  assert((operand_offset >> 16) == 0);
  operand_.offset = static_cast<uint16_t>(operand_offset);
  operand_.type = type;

  // Defaults; the literal cases below override them.
  operand_.number_kind = SPV_NUMBER_NONE;
  operand_.number_bit_width = 0;

  const size_t words_before = inst_words_.size();

  switch (type) {
    case SPV_OPERAND_TYPE_RESULT_ID: {
      const spv_result_t status = DecodeResultId();
      if (status != SPV_SUCCESS) return status;

      inst_words_.push_back(inst_.result_id);
      SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
      PromoteIfNeeded(inst_.result_id);
      break;
    }

    case SPV_OPERAND_TYPE_TYPE_ID: {
      const spv_result_t status = DecodeTypeId();
      if (status != SPV_SUCCESS) return status;

      inst_words_.push_back(inst_.type_id);
      SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
      PromoteIfNeeded(inst_.type_id);
      break;
    }

    case SPV_OPERAND_TYPE_ID:
    case SPV_OPERAND_TYPE_OPTIONAL_ID:
    case SPV_OPERAND_TYPE_SCOPE_ID:
    case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
      uint32_t id = 0;
      const spv_result_t status = DecodeRefId(&id);
      if (status != SPV_SUCCESS) return status;

      if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";

      const bool is_plain_id =
          type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID;
      if (is_plain_id) {
        operand_.type = SPV_OPERAND_TYPE_ID;

        if (opcode == SpvOpExtInst && operand_.offset == 3) {
          // This word is the extended instruction set id; remember which
          // instruction set the current instruction uses.
          const auto import_iter = import_id_to_ext_inst_type_.find(id);
          if (import_iter == import_id_to_ext_inst_type_.end()) {
            return Diag(SPV_ERROR_INVALID_ID)
                   << "OpExtInst set id " << id
                   << " does not reference an OpExtInstImport result Id";
          }
          inst_.ext_inst_type = import_iter->second;
        }
      }

      inst_words_.push_back(id);
      SetIdBound(std::max(GetIdBound(), id + 1));
      PromoteIfNeeded(id);
      break;
    }

    case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
      uint32_t ext_inst_number = 0;
      const spv_result_t status = DecodeNonIdWord(&ext_inst_number);
      if (status != SPV_SUCCESS) return status;

      inst_words_.push_back(ext_inst_number);

      assert(SpvOpExtInst == opcode);
      assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
      spv_ext_inst_desc ext_inst;
      if (grammar_.lookupExtInst(inst_.ext_inst_type, ext_inst_number,
                                 &ext_inst) != SPV_SUCCESS) {
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Invalid extended instruction number: " << ext_inst_number;
      }
      spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
      break;
    }

    case SPV_OPERAND_TYPE_LITERAL_INTEGER:
    case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
      // Regular single-word literal integers, always treated as unsigned.
      // The operand description is set before decoding the word.
      operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
      operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
      operand_.number_bit_width = 32;

      uint32_t literal = 0;
      const spv_result_t status = DecodeNonIdWord(&literal);
      if (status != SPV_SUCCESS) return status;

      inst_words_.push_back(literal);
      break;
    }

    case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
    case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
      operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;

      if (opcode != SpvOpSwitch) {
        assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
        // The literal's numeric type comes from the constant's type id.
        assert(inst_.type_id);
        const spv_result_t type_status =
            SetNumericTypeInfoForType(&operand_, inst_.type_id);
        if (type_status != SPV_SUCCESS) return type_status;
      } else {
        // The literals have the same type as the value referenced by the
        // selector id, which is the first operand word.
        const uint32_t selector_id = inst_words_.at(1);
        const auto selector_type_iter = id_to_type_id_.find(selector_id);
        if (selector_type_iter == id_to_type_id_.end() ||
            selector_type_iter->second == 0) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid OpSwitch: selector id " << selector_id
                 << " has no type";
        }
        const uint32_t selector_type_id = selector_type_iter->second;

        if (selector_id == selector_type_id) {
          // By convention a result id that is a type definition maps to
          // itself.
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid OpSwitch: selector id " << selector_id
                 << " is a type, not a value";
        }

        const spv_result_t type_status =
            SetNumericTypeInfoForType(&operand_, selector_type_id);
        if (type_status != SPV_SUCCESS) return type_status;

        const bool is_integer =
            operand_.number_kind == SPV_NUMBER_UNSIGNED_INT ||
            operand_.number_kind == SPV_NUMBER_SIGNED_INT;
        if (!is_integer) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid OpSwitch: selector id " << selector_id
                 << " is not a scalar integer";
        }
      }

      const spv_result_t literal_status = DecodeLiteralNumber(operand_);
      if (literal_status != SPV_SUCCESS) return literal_status;
      break;
    }

    case SPV_OPERAND_TYPE_LITERAL_STRING:
    case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
      operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
      std::vector<char> chars;

      if (auto* string_codec =
              model_->GetLiteralStringHuffmanCodec(inst_.opcode)) {
        std::string decoded;
        const bool huffman_ok =
            string_codec->DecodeFromStream(GetReadBitCallback(), &decoded);
        assert(huffman_ok);
        if (!huffman_ok) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Failed to read literal string";
        }

        if (decoded != "kMarkvNoneOfTheAbove") {
          chars.insert(chars.end(), decoded.begin(), decoded.end());
          chars.push_back('\0');
        }
      }

      if (chars.empty()) {
        // Read raw characters up to and including the terminating '\0'; the
        // loop also ends when the bit stream is exhausted.
        char ch = 0;
        do {
          ch = 0;
          if (!reader_.ReadUnencoded(&ch)) {
            return Diag(SPV_ERROR_INVALID_BINARY)
                   << "Failed to read literal string";
          }
          chars.push_back(ch);
        } while (ch != '\0');
      }

      // Pad with '\0' to a whole number of 32-bit words.
      chars.resize((chars.size() + 3) / 4 * 4, '\0');

      inst_words_.resize(inst_words_.size() + chars.size() / 4);
      std::memcpy(&inst_words_[words_before], chars.data(), chars.size());

      if (SpvOpExtInstImport == opcode) {
        // OpExtInstImport has a single string literal, so guarding on the
        // opcode is enough. Record the instruction set for this import's id.
        const spv_ext_inst_type_t ext_inst_type =
            spvExtInstImportTypeGet(chars.data());
        if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid extended instruction import '" << chars.data()
                 << "'";
        }
        // The grammar places a non-zero result id before the string.
        assert(inst_.result_id);
        const bool inserted =
            import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
                .second;
        (void)inserted;
        assert(inserted);
      }
      break;
    }

    case SPV_OPERAND_TYPE_CAPABILITY:
    case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
    case SPV_OPERAND_TYPE_EXECUTION_MODEL:
    case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
    case SPV_OPERAND_TYPE_MEMORY_MODEL:
    case SPV_OPERAND_TYPE_EXECUTION_MODE:
    case SPV_OPERAND_TYPE_STORAGE_CLASS:
    case SPV_OPERAND_TYPE_DIMENSIONALITY:
    case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
    case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
    case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
    case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
    case SPV_OPERAND_TYPE_LINKAGE_TYPE:
    case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
    case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
    case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
    case SPV_OPERAND_TYPE_DECORATION:
    case SPV_OPERAND_TYPE_BUILT_IN:
    case SPV_OPERAND_TYPE_GROUP_OPERATION:
    case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
    case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
      // One word holding a plain enumerant.
      uint32_t enumerant = 0;
      const spv_result_t status = DecodeNonIdWord(&enumerant);
      if (status != SPV_SUCCESS) return status;

      inst_words_.push_back(enumerant);

      // Map the optional operand type to its concrete counterpart.
      if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
        operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;

      spv_operand_desc entry;
      if (grammar_.lookupOperand(type, enumerant, &entry) != SPV_SUCCESS) {
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Invalid " << spvOperandTypeStr(operand_.type)
               << " operand: " << enumerant;
      }

      // The enumerant may itself take operands.
      spvPushOperandTypes(entry->operandTypes, expected_operands);
      break;
    }

    case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
    case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
    case SPV_OPERAND_TYPE_LOOP_CONTROL:
    case SPV_OPERAND_TYPE_IMAGE:
    case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
    case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
    case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
      // One word holding a bit mask.
      uint32_t mask_word = 0;
      const spv_result_t status = DecodeNonIdWord(&mask_word);
      if (status != SPV_SUCCESS) return status;

      inst_words_.push_back(mask_word);

      // Map optional operand types to their concrete counterparts.
      if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) {
        operand_.type = SPV_OPERAND_TYPE_IMAGE;
      } else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) {
        operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
      }

      // Check every set bit and queue the operands it requires. Bits are
      // scanned from MSB to LSB because operands can only be prepended to the
      // pattern; image operands are the one case in the grammar where several
      // mask bits carry operands (SPIR-V 3.14 Image Operands).
      uint32_t pending_bits = mask_word;
      uint32_t bit = 1u << 31;
      while (pending_bits != 0) {
        if (pending_bits & bit) {
          spv_operand_desc entry;
          if (grammar_.lookupOperand(type, bit, &entry) != SPV_SUCCESS) {
            return Diag(SPV_ERROR_INVALID_BINARY)
                   << "Invalid " << spvOperandTypeStr(operand_.type)
                   << " operand: " << mask_word
                   << " has invalid mask component " << bit;
          }
          pending_bits ^= bit;
          spvPushOperandTypes(entry->operandTypes, expected_operands);
        }
        bit >>= 1;
      }

      if (mask_word == 0) {
        // An all-zeroes mask *might* also be valid and take operands.
        spv_operand_desc entry;
        if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
          spvPushOperandTypes(entry->operandTypes, expected_operands);
        }
      }
      break;
    }

    default:
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "Internal error: Unhandled operand type: " << type;
  }

  operand_.num_words = uint16_t(inst_words_.size() - words_before);

  assert(spvOperandIsConcrete(operand_.type));

  parsed_operands_.push_back(operand_);

  return SPV_SUCCESS;
}
|
||
|
|
||
|
// Decodes one instruction from the bit stream and appends it to spirv_.
//
// Steps:
//  1. Decode the opcode, together with the operand count when the Huffman
//     codecs provide it; otherwise read the opcode as a variable-width
//     integer.
//  2. Look up the opcode in the grammar and build the pattern of expected
//     operand types (stored reversed, so the next operand is at the back).
//  3. Determine the operand count if still unknown, then decode each operand.
//  4. Patch the first instruction word with the word count and opcode, copy
//     the instruction to spirv_, update decoder state, skip padding bits up to
//     the next byte boundary and, if a logger is attached, log the result.
//
// Returns SPV_SUCCESS on success, SPV_REQUESTED_TERMINATION if the logger asks
// to stop, and an error code when the stream is malformed. A stream that
// carries more operands than the grammar allows, or an instruction longer
// than 65535 words, is reported as SPV_ERROR_INVALID_BINARY.
spv_result_t MarkvDecoder::DecodeInstruction() {
  parsed_operands_.clear();
  inst_words_.clear();

  // Opcode/num_words placeholder, the word will be filled in later.
  inst_words_.push_back(0);

  bool num_operands_still_unknown = true;
  {
    uint32_t opcode = 0;
    uint32_t num_operands = 0;

    const spv_result_t opcode_decoding_result =
        DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
    if (opcode_decoding_result < 0) return opcode_decoding_result;

    if (opcode_decoding_result == SPV_SUCCESS) {
      inst_.num_operands = static_cast<uint16_t>(num_operands);
      num_operands_still_unknown = false;
    } else {
      // The codecs could not represent this opcode; it follows by value.
      if (!reader_.ReadVariableWidthU32(&opcode,
                                        model_->opcode_chunk_length())) {
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read opcode of instruction";
      }
    }

    inst_.opcode = static_cast<uint16_t>(opcode);
  }

  const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);

  spv_opcode_desc opcode_desc;
  if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
    return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
  }

  // Store the grammar's operand types in reverse order.
  spv_operand_pattern_t expected_operands;
  expected_operands.reserve(opcode_desc->numTypes);
  for (size_t remaining = opcode_desc->numTypes; remaining > 0; --remaining) {
    expected_operands.push_back(opcode_desc->operandTypes[remaining - 1]);
  }

  if (num_operands_still_unknown) {
    if (!OpcodeHasFixedNumberOfOperands(opcode)) {
      if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
                                        model_->num_operands_chunk_length()))
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read num_operands of instruction";
    } else {
      inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
    }
  }

  for (operand_index_ = 0;
       operand_index_ < static_cast<size_t>(inst_.num_operands);
       ++operand_index_) {
    // The operand count comes from the stream, so it may exceed what the
    // grammar allows. Taking an operand type from an empty pattern is not
    // valid, so report the malformed input instead of only asserting.
    if (expected_operands.empty()) {
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "Instruction " << spvOpcodeString(opcode)
             << " has more operands than its grammar allows";
    }
    const spv_operand_type_t type =
        spvTakeFirstMatchableOperand(&expected_operands);

    const size_t operand_offset = inst_words_.size();

    const spv_result_t decode_result =
        DecodeOperand(operand_offset, type, &expected_operands);

    if (decode_result != SPV_SUCCESS) return decode_result;
  }

  assert(inst_.num_operands == parsed_operands_.size());

  // The word count is stored in 16 bits; refuse to truncate it silently.
  if (inst_words_.size() > 0xFFFFu) {
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "Instruction " << spvOpcodeString(opcode)
           << " is too long: " << inst_words_.size() << " words";
  }

  // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
  // next DecodeInstruction call).
  inst_.words = inst_words_.data();
  inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
  inst_.num_words = static_cast<uint16_t>(inst_words_.size());
  inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));

  std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));

  assert(inst_.num_words ==
             std::accumulate(
                 parsed_operands_.begin(), parsed_operands_.end(), 1,
                 [](int num_words, const spv_parsed_operand_t& operand) {
                   return num_words += operand.num_words;
                 }) &&
         "num_words in instruction doesn't correspond to the sum of num_words"
         "in the operands");

  RecordNumberType();
  ProcessCurInstruction();

  if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte))
    return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";

  if (logger_) {
    logger_->NewLine();
    std::stringstream ss;
    ss << spvOpcodeString(opcode) << " ";
    for (size_t index = 1; index < inst_words_.size(); ++index)
      ss << inst_words_[index] << " ";
    logger_->AppendText(ss.str());
    logger_->NewLine();
    logger_->NewLine();
    if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
  }

  return SPV_SUCCESS;
}
|
||
|
|
||
|
// Copies the numeric description recorded for |type_id| into
// |parsed_operand|: number kind, bit width, and the number of 32-bit words
// needed for that width.
//
// Returns SPV_ERROR_INVALID_BINARY if |type_id| has no recorded type info or
// the recorded info is not for a scalar number.
spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
    spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
  assert(type_id != 0);

  const auto found = type_id_to_number_type_info_.find(type_id);
  if (found == type_id_to_number_type_info_.end()) {
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "Type Id " << type_id << " is not a type";
  }

  const NumberType& number_type = found->second;
  if (number_type.type == SPV_NUMBER_NONE) {
    // A known type, but not a scalar number.
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "Type Id " << type_id << " is not a scalar numeric type";
  }

  // Word count is the bit width divided by 32, rounded up.
  const auto word_count =
      static_cast<uint16_t>((number_type.bit_width + 31) / 32);

  parsed_operand->number_kind = number_type.type;
  parsed_operand->number_bit_width = number_type.bit_width;
  parsed_operand->num_words = word_count;
  return SPV_SUCCESS;
}
|
||
|
|
||
|
void MarkvDecoder::RecordNumberType() {
|
||
|
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
||
|
if (spvOpcodeGeneratesType(opcode)) {
|
||
|
NumberType info = {SPV_NUMBER_NONE, 0};
|
||
|
if (SpvOpTypeInt == opcode) {
|
||
|
info.bit_width = inst_.words[inst_.operands[1].offset];
|
||
|
info.type = inst_.words[inst_.operands[2].offset]
|
||
|
? SPV_NUMBER_SIGNED_INT
|
||
|
: SPV_NUMBER_UNSIGNED_INT;
|
||
|
} else if (SpvOpTypeFloat == opcode) {
|
||
|
info.bit_width = inst_.words[inst_.operands[1].offset];
|
||
|
info.type = SPV_NUMBER_FLOATING;
|
||
|
}
|
||
|
// The *result* Id of a type generating instruction is the type Id.
|
||
|
type_id_to_number_type_info_[inst_.result_id] = info;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
} // namespace comp
|
||
|
} // namespace spvtools
|