mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2025-01-05 22:41:07 +00:00
b6319c3a43
This CL breaks the monolithic markv_codec file into files for the base class, encoder, decoder and logger.
487 lines
16 KiB
C++
487 lines
16 KiB
C++
// Copyright (c) 2018 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "source/comp/markv_encoder.h"
|
|
|
|
#include "source/binary.h"
|
|
#include "source/opcode.h"
|
|
#include "spirv-tools/libspirv.hpp"
|
|
|
|
namespace spvtools {
|
|
namespace comp {
|
|
namespace {
|
|
|
|
// Number of blank characters appended to the log before each inline
// annotation emitted by the encoder.
constexpr size_t kCommentNumWhitespaces = 2;
|
|
|
|
} // namespace
|
|
|
|
// Writes one non-id operand word.
//
// When the model provides a Huffman table for the current opcode and operand
// index, the word is written with that table. A word missing from the table is
// announced with the table's kMarkvNoneOfTheAbove code and then written with
// the fallback scheme: variable-width chunks if the model defines a chunk
// length for the operand type, the raw word otherwise.
//
// Returns SPV_SUCCESS, or SPV_ERROR_INTERNAL if the table has no
// kMarkvNoneOfTheAbove entry.
spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) {
  auto* huffman =
      model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);

  if (huffman != nullptr) {
    uint64_t code = 0;
    size_t code_length = 0;

    if (huffman->Encode(word, &code, &code_length)) {
      // The table knows this word; nothing else needs to be written.
      writer_.WriteBits(code, code_length);
      return SPV_SUCCESS;
    }

    // The table does not know this word: emit the escape code first.
    const bool has_escape = huffman->Encode(
        MarkvModel::GetMarkvNoneOfTheAbove(), &code, &code_length);
    if (!has_escape) {
      return Diag(SPV_ERROR_INTERNAL)
             << "Non-id word Huffman table for "
             << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
             << operand_index_ << " is missing kMarkvNoneOfTheAbove";
    }
    writer_.WriteBits(code, code_length);
  }

  // No table, or the escape code was written: use the fallback scheme.
  const size_t chunk_length =
      model_->GetOperandVariableWidthChunkLength(operand_.type);
  if (chunk_length == 0) {
    writer_.WriteUnencoded(word);
  } else {
    writer_.WriteVariableWidthU32(word, chunk_length);
  }
  return SPV_SUCCESS;
}
|
|
|
|
// Encodes the instruction header (opcode combined with the operand count)
// using the model's Huffman tables.
//
// The table conditioned on the previous opcode is tried first. If that table
// exists but does not contain the word, its kMarkvNoneOfTheAbove code is
// written and the base-rate table (the one registered for SpvOpNop) is tried.
//
// Returns:
//   SPV_SUCCESS        - the word was written with one of the tables.
//   SPV_UNSUPPORTED    - the word is in neither table; the base-rate
//                        kMarkvNoneOfTheAbove code was written and the caller
//                        has to emit opcode and operand count itself.
//   SPV_ERROR_INTERNAL - a required table, or its kMarkvNoneOfTheAbove entry,
//                        is missing.
spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode,
                                                      uint32_t num_operands) {
  uint64_t bits = 0;
  size_t num_bits = 0;

  // The opcode is OR-ed with the operand count shifted into the upper bits.
  const uint32_t word = opcode | (num_operands << 16);

  // First try to use the Markov chain codec.
  auto* codec =
      model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
  if (codec) {
    if (codec->Encode(word, &bits, &num_bits)) {
      // The word was successfully encoded into bits/num_bits.
      writer_.WriteBits(bits, num_bits);
      return SPV_SUCCESS;
    }

    // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
    // and use fallback encoding.
    if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
                       &num_bits))
      return Diag(SPV_ERROR_INTERNAL)
             << "opcode_and_num_operands Huffman table for "
             << spvOpcodeString(GetPrevOpcode())
             << " is missing kMarkvNoneOfTheAbove";
    writer_.WriteBits(bits, num_bits);
  }

  // Fallback to base-rate codec.
  codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
  if (!codec) {
    // Report the broken model instead of dereferencing a null codec when
    // asserts are compiled out.
    return Diag(SPV_ERROR_INTERNAL)
           << "Global opcode_and_num_operands Huffman table is missing";
  }

  if (codec->Encode(word, &bits, &num_bits)) {
    // The word was successfully encoded into bits/num_bits.
    writer_.WriteBits(bits, num_bits);
    return SPV_SUCCESS;
  }

  // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove and let
  // the caller fall back to its own encoding.
  if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits))
    return Diag(SPV_ERROR_INTERNAL)
           << "Global opcode_and_num_operands Huffman table is missing "
           << "kMarkvNoneOfTheAbove";
  writer_.WriteBits(bits, num_bits);
  return SPV_UNSUPPORTED;
}
|
|
|
|
// Writes |rank| using the Huffman table associated with the move-to-front
// sequence |mtf|. If that sequence has no table, the table of
// |fallback_method| is used instead.
//
// Ranks below MarkvCodec::kMtfSmallestRankEncodedByValue are Huffman coded
// directly. Larger ranks are written as the kMtfRankEncodedByValueSignal code
// followed by the variable-width encoded distance from that threshold.
//
// Returns SPV_SUCCESS, or SPV_ERROR_INTERNAL when no table is available or
// the table cannot encode the required symbol.
spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
                                                uint64_t fallback_method) {
  const auto* huffman = GetMtfHuffmanCodec(mtf);
  if (huffman == nullptr) {
    assert(fallback_method != kMtfNone);
    huffman = GetMtfHuffmanCodec(fallback_method);
    if (huffman == nullptr)
      return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank";
  }

  uint64_t code = 0;
  size_t code_length = 0;

  const bool encode_by_value =
      rank >= MarkvCodec::kMtfSmallestRankEncodedByValue;

  if (encode_by_value) {
    // Large rank: signal code first, then the offset from the threshold.
    if (!huffman->Encode(MarkvCodec::kMtfRankEncodedByValueSignal, &code,
                         &code_length))
      return Diag(SPV_ERROR_INTERNAL)
             << "Failed to encode kMtfRankEncodedByValueSignal";

    writer_.WriteBits(code, code_length);
    writer_.WriteVariableWidthU32(
        rank - MarkvCodec::kMtfSmallestRankEncodedByValue,
        model_->mtf_rank_chunk_length());
    return SPV_SUCCESS;
  }

  // Small rank: Huffman code only.
  if (!huffman->Encode(rank, &code, &code_length))
    return Diag(SPV_ERROR_INTERNAL)
           << "Failed to encode MTF rank with Huffman";

  writer_.WriteBits(code, code_length);
  return SPV_SUCCESS;
}
|
|
|
|
// Tries to encode |id| through its descriptor.
//
// The long descriptor is written with the Huffman table for the current
// opcode and operand index when both exist and the table contains it. If not,
// the table's kMarkvNoneOfTheAbove code is written (when there is a table)
// and, only under the kShortDescriptor fallback strategy, the short
// descriptor is written verbatim. In both successful cases the rank of |id|
// in the descriptor's move-to-front sequence is encoded afterwards.
//
// Returns SPV_UNSUPPORTED when the id has to be encoded by other means (the
// strategy is not kShortDescriptor, or the short descriptor is zero),
// SPV_ERROR_INTERNAL if the table lacks kMarkvNoneOfTheAbove, and otherwise
// the result of EncodeExistingId.
spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) {
  const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id);
  auto* huffman =
      model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);

  uint64_t code = 0;
  size_t code_length = 0;

  if (long_descriptor && huffman &&
      huffman->Encode(long_descriptor, &code, &code_length)) {
    // Descriptor is known to the table: write it, then the rank of the id in
    // the sequence of that descriptor.
    writer_.WriteBits(code, code_length);
    return EncodeExistingId(GetMtfLongIdDescriptor(long_descriptor), id);
  }

  if (huffman) {
    // No descriptor, or no code for it: announce the fallback.
    if (!huffman->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &code,
                         &code_length))
      return Diag(SPV_ERROR_INTERNAL)
             << "Descriptor Huffman table for "
             << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
             << operand_index_ << " is missing kMarkvNoneOfTheAbove";

    writer_.WriteBits(code, code_length);
  }

  const bool use_short_descriptor =
      model_->id_fallback_strategy() ==
      MarkvModel::IdFallbackStrategy::kShortDescriptor;
  if (!use_short_descriptor) return SPV_UNSUPPORTED;

  const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id);
  writer_.WriteBits(short_descriptor, MarkvCodec::kShortDescriptorNumBits);

  // A zero short descriptor marks a forward declared id.
  if (short_descriptor == 0) return SPV_UNSUPPORTED;

  // Descriptor written; finish with the rank of the id in its sequence.
  return EncodeExistingId(GetMtfShortIdDescriptor(short_descriptor), id);
}
|
|
|
|
// Encodes the rank of |id| inside the move-to-front sequence |mtf|.
//
// Nothing is written for a sequence holding a single element. Otherwise the
// rank is looked up and written with EncodeMtfRankHuffman, using
// kMtfGenericNonZeroRank as the fallback table.
//
// Returns SPV_ERROR_INTERNAL if |id| is not part of the sequence.
spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) {
  const auto sequence_size = multi_mtf_.GetSize(mtf);
  assert(sequence_size > 0);

  // With exactly one candidate the decoder needs no rank to pick it.
  if (sequence_size == 1) return SPV_SUCCESS;

  uint32_t rank = 0;
  const bool found = multi_mtf_.RankFromValue(mtf, id, &rank);
  if (!found) {
    return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence";
  }

  return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank);
}
|
|
|
|
// Encodes a reference to |id| (an id operand that is neither the result id
// nor the type id of the current instruction).
//
// Descriptor-based encoding is attempted first. If it reports
// SPV_UNSUPPORTED, the id is encoded according to the model's fallback
// strategy:
//  * kRuleBased: by rank in the rule-based sequence, or in kMtfAll when no
//    rule applies. An id that is not yet in the sequence is registered in the
//    relevant sequences and written with rank 0.
//  * otherwise: by rank in kMtfForwardDeclared, written as a variable-width
//    integer; an unseen id is inserted and written with rank 0.
spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) {
  // Descriptor sequences take priority over every other method.
  const spv_result_t descriptor_result = EncodeIdWithDescriptor(id);
  if (descriptor_result != SPV_UNSUPPORTED) return descriptor_result;

  const bool forward_declarable = spvOperandCanBeForwardDeclaredFunction(
      SpvOp(inst_.opcode))(operand_index_);
  uint32_t rank = 0;

  const bool rule_based = model_->id_fallback_strategy() ==
                          MarkvModel::IdFallbackStrategy::kRuleBased;

  if (!rule_based) {
    assert(forward_declarable);

    const bool seen_before =
        multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank);
    if (!seen_before) {
      // First occurrence of a forward declared id.
      multi_mtf_.Insert(kMtfForwardDeclared, id);
      rank = 0;
    }

    writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
    return SPV_SUCCESS;
  }

  // Rule-based strategy.
  uint64_t mtf = GetRuleBasedMtf();

  if (mtf != kMtfNone && !forward_declarable) {
    assert(multi_mtf_.HasValue(kMtfAll, id));
    return EncodeExistingId(mtf, id);
  }

  if (mtf == kMtfNone) mtf = kMtfAll;

  const bool seen_before = multi_mtf_.RankFromValue(mtf, id, &rank);
  if (!seen_before) {
    // First occurrence of a forward declared id.
    multi_mtf_.Insert(kMtfAll, id);
    multi_mtf_.Insert(kMtfForwardDeclared, id);
    if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id);
    rank = 0;
  }

  return EncodeMtfRankHuffman(rank, mtf, kMtfAll);
}
|
|
|
|
// Encodes the type id of the current instruction.
//
// For OpFunctionParameter nothing is written: the type is popped from the
// queue of remaining function parameter types. Otherwise descriptor-based
// encoding is attempted, and if it reports SPV_UNSUPPORTED the id is encoded
// by rank in the rule-based sequence (kMtfTypeNonFunction when no rule
// applies).
spv_result_t MarkvEncoder::EncodeTypeId() {
  if (inst_.opcode == SpvOpFunctionParameter) {
    assert(!remaining_function_parameter_types_.empty());
    assert(inst_.type_id == remaining_function_parameter_types_.front());
    remaining_function_parameter_types_.pop_front();
    return SPV_SUCCESS;
  }

  // Descriptor sequences take priority over the rule-based method.
  const spv_result_t descriptor_result =
      EncodeIdWithDescriptor(inst_.type_id);
  if (descriptor_result != SPV_UNSUPPORTED) return descriptor_result;

  assert(model_->id_fallback_strategy() ==
         MarkvModel::IdFallbackStrategy::kRuleBased);

  const uint64_t rule_mtf = GetRuleBasedMtf();
  assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
      operand_index_));

  if (rule_mtf != kMtfNone) return EncodeExistingId(rule_mtf, inst_.type_id);

  // Function types should have been handled by GetRuleBasedMtf.
  assert(inst_.opcode != SpvOpFunction);
  return EncodeExistingId(kMtfTypeNonFunction, inst_.type_id);
}
|
|
|
|
// Encodes the result id of the current instruction.
//
// Output is produced only while kMtfForwardDeclared is non-empty: one flag
// bit telling whether this instruction defines a forward declared id and, if
// it does, the variable-width rank of that id in kMtfForwardDeclared (the id
// is removed from the sequence). Under the kRuleBased strategy the id is
// additionally inserted into kMtfAll when the recorded rank is zero.
//
// Returns SPV_ERROR_INTERNAL if the id cannot be removed from
// kMtfForwardDeclared, SPV_SUCCESS otherwise.
spv_result_t MarkvEncoder::EncodeResultId() {
  uint32_t rank = 0;

  // With an empty kMtfForwardDeclared the decoder knows there is nothing to
  // expect, so no flag bit is written at all.
  if (multi_mtf_.GetSize(kMtfForwardDeclared) != 0) {
    const bool was_forward_declared =
        multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank);

    if (!was_forward_declared) {
      rank = 0;
      writer_.WriteBits(0, 1);
    } else {
      // The forward declared id is being defined here; it no longer belongs
      // in kMtfForwardDeclared.
      if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) {
        return Diag(SPV_ERROR_INTERNAL)
               << "Failed to remove id from kMtfForwardDeclared";
      }
      writer_.WriteBits(1, 1);
      writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
    }
  }

  const bool rule_based = model_->id_fallback_strategy() ==
                          MarkvModel::IdFallbackStrategy::kRuleBased;
  if (rule_based && rank == 0) {
    multi_mtf_.Insert(kMtfAll, inst_.result_id);
  }

  return SPV_SUCCESS;
}
|
|
|
|
// Encodes a typed literal number operand of the current instruction.
//
// Literals of at most 32 bits are handled as a plain non-id word. Literals of
// 33 to 64 bits are assembled from two consecutive instruction words (low
// word first) and written according to their number kind: unsigned and signed
// integers with the model's variable-width settings, floating point values
// unencoded.
//
// Returns SPV_ERROR_INTERNAL for literals wider than 64 bits or with a number
// kind other than the three above; otherwise SPV_SUCCESS or the result of
// EncodeNonIdWord.
spv_result_t MarkvEncoder::EncodeLiteralNumber(
    const spv_parsed_operand_t& operand) {
  if (operand.number_bit_width <= 32) {
    const uint32_t word = inst_.words[operand.offset];
    return EncodeNonIdWord(word);
  }

  // Reject wider literals explicitly. With only an assert here, builds
  // without asserts would encode just the low 64 bits and drop the rest.
  if (operand.number_bit_width > 64) {
    return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
  }

  const uint64_t word = uint64_t(inst_.words[operand.offset]) |
                        (uint64_t(inst_.words[operand.offset + 1]) << 32);

  if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
    writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
  } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
    // Reinterpret the bit pattern as signed without a value conversion.
    int64_t val = 0;
    std::memcpy(&val, &word, sizeof(val));
    writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
                                  model_->s64_block_exponent());
  } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
    writer_.WriteUnencoded(word);
  } else {
    return Diag(SPV_ERROR_INTERNAL) << "Unsupported number kind";
  }

  return SPV_SUCCESS;
}
|
|
|
|
void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) {
|
|
const size_t num_bits_to_next_byte =
|
|
GetNumBitsToNextByte(writer_.GetNumBits());
|
|
if (num_bits_to_next_byte == 0 ||
|
|
num_bits_to_next_byte > byte_break_if_less_than)
|
|
return;
|
|
|
|
if (logger_) {
|
|
logger_->AppendWhitespaces(kCommentNumWhitespaces);
|
|
logger_->AppendText("<byte break>");
|
|
}
|
|
|
|
writer_.WriteBits(0, num_bits_to_next_byte);
|
|
}
|
|
|
|
// Encodes one parsed SPIR-V instruction.
//
// Steps:
//  1. Store |inst| as the current instruction and log its disassembly.
//  2. Encode opcode and operand count with the Huffman tables; if that
//     reports a non-error, non-success result, write the opcode (and, for
//     opcodes with a variable operand count, the count) as variable-width
//     integers.
//  3. Encode every operand according to its type: ids, literal integers,
//     typed literal numbers, literal strings, and everything else as plain
//     non-id words.
//  4. Optionally pad to a byte boundary, flush the log entry and update the
//     codec state with ProcessCurInstruction.
//
// Returns SPV_SUCCESS, the first non-success result of an operand encoder,
// SPV_ERROR_INVALID_BINARY for a literal string without terminator inside its
// operand, SPV_ERROR_INTERNAL for a string table lacking the
// "kMarkvNoneOfTheAbove" entry, or SPV_REQUESTED_TERMINATION when the logger
// asks to stop.
spv_result_t MarkvEncoder::EncodeInstruction(
    const spv_parsed_instruction_t& inst) {
  SpvOp opcode = SpvOp(inst.opcode);
  inst_ = inst;

  LogDisassemblyInstruction();

  const spv_result_t opcode_encoding_result =
      EncodeOpcodeAndNumOperands(opcode, inst.num_operands);
  // Negative results are errors; positive ones request the fallback below.
  if (opcode_encoding_result < 0) return opcode_encoding_result;

  if (opcode_encoding_result != SPV_SUCCESS) {
    // Fallback encoding for opcode and num_operands.
    writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length());

    if (!OpcodeHasFixedNumberOfOperands(opcode)) {
      // If the opcode has a variable number of operands, encode the number of
      // operands with the instruction.

      if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);

      writer_.WriteVariableWidthU16(inst.num_operands,
                                    model_->num_operands_chunk_length());
    }
  }

  // Write operands.
  const uint32_t num_operands = inst_.num_operands;
  for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) {
    operand_ = inst_.operands[operand_index_];

    if (logger_) {
      logger_->AppendWhitespaces(kCommentNumWhitespaces);
      logger_->AppendText("<");
      logger_->AppendText(spvOperandTypeStr(operand_.type));
      logger_->AppendText(">");
    }

    switch (operand_.type) {
      case SPV_OPERAND_TYPE_RESULT_ID:
      case SPV_OPERAND_TYPE_TYPE_ID:
      case SPV_OPERAND_TYPE_ID:
      case SPV_OPERAND_TYPE_OPTIONAL_ID:
      case SPV_OPERAND_TYPE_SCOPE_ID:
      case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
        const uint32_t id = inst_.words[operand_.offset];
        if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) {
          const spv_result_t result = EncodeTypeId();
          if (result != SPV_SUCCESS) return result;
        } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) {
          const spv_result_t result = EncodeResultId();
          if (result != SPV_SUCCESS) return result;
        } else {
          const spv_result_t result = EncodeRefId(id);
          if (result != SPV_SUCCESS) return result;
        }

        PromoteIfNeeded(id);
        break;
      }

      case SPV_OPERAND_TYPE_LITERAL_INTEGER: {
        const spv_result_t result =
            EncodeNonIdWord(inst_.words[operand_.offset]);
        if (result != SPV_SUCCESS) return result;
        break;
      }

      case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
        const spv_result_t result = EncodeLiteralNumber(operand_);
        if (result != SPV_SUCCESS) return result;
        break;
      }

      case SPV_OPERAND_TYPE_LITERAL_STRING: {
        const char* src =
            reinterpret_cast<const char*>(&inst_.words[operand_.offset]);

        // Locate the terminator before anything treats |src| as a C string,
        // so that an unterminated literal is rejected instead of being read
        // past the end of the operand.
        const size_t max_length =
            static_cast<size_t>(operand_.num_words) * 4;
        const size_t length = spv_strnlen_s(src, max_length);
        if (length == max_length)
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Failed to find terminal character of literal string";

        auto* codec = model_->GetLiteralStringHuffmanCodec(opcode);
        if (codec) {
          uint64_t bits = 0;
          size_t num_bits = 0;
          const std::string str(src, length);
          if (codec->Encode(str, &bits, &num_bits)) {
            writer_.WriteBits(bits, num_bits);
            break;
          }

          // The string is not in the table: write the escape entry, then the
          // characters themselves below.
          if (!codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits))
            return Diag(SPV_ERROR_INTERNAL)
                   << "Literal string Huffman table for "
                   << spvOpcodeString(opcode)
                   << " is missing kMarkvNoneOfTheAbove";
          writer_.WriteBits(bits, num_bits);
        }

        // Characters are written one by one, terminator included.
        for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]);
        break;
      }

      default: {
        for (uint32_t i = 0; i < operand_.num_words; ++i) {
          const uint32_t word = inst_.words[operand_.offset + i];
          const spv_result_t result = EncodeNonIdWord(word);
          if (result != SPV_SUCCESS) return result;
        }
        break;
      }
    }
  }

  AddByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte);

  if (logger_) {
    logger_->NewLine();
    logger_->NewLine();
    if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
  }

  ProcessCurInstruction();

  return SPV_SUCCESS;
}
|
|
|
|
} // namespace comp
|
|
} // namespace spvtools
|