mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-19 11:30:15 +00:00
9b14dd0cb4
- now includes a table of all descriptors with coding scheme (improves performance by 5% by allowing to avoid creation of move-to-front sequences which will never be used) - increased the size of markv_autogen.inc, clang doesn't seem to have the long compilation time problem now (probably was inadvertently fixed by using Huffman codec serialization)
3117 lines
103 KiB
C++
3117 lines
103 KiB
C++
// Copyright (c) 2017 Google Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Contains
|
|
// - SPIR-V to MARK-V encoder
|
|
// - MARK-V to SPIR-V decoder
|
|
//
|
|
// MARK-V is a compression format for SPIR-V binaries. It strips away
|
|
// non-essential information (such as result ids which can be regenerated) and
|
|
// uses various bit reduction techniques to reduce the size of the binary.
|
|
//
|
|
// MarkvModel is a flatbuffers object containing a set of rules defining how
|
|
// compression/decompression is done (coding schemes, dictionaries).
|
|
|
|
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
|
|
|
|
#include "spirv/1.2/GLSL.std.450.h"
|
|
#include "spirv/1.2/OpenCL.std.h"
|
|
#include "spirv/1.2/spirv.h"
|
|
|
|
#include "binary.h"
|
|
#include "diagnostic.h"
|
|
#include "enum_string_mapping.h"
|
|
#include "extensions.h"
|
|
#include "ext_inst.h"
|
|
#include "id_descriptor.h"
|
|
#include "instruction.h"
|
|
#include "markv_autogen.h"
|
|
#include "opcode.h"
|
|
#include "operand.h"
|
|
#include "spirv-tools/libspirv.h"
|
|
#include "spirv-tools/markv.h"
|
|
#include "spirv_endian.h"
|
|
#include "spirv_validator_options.h"
|
|
#include "util/bit_stream.h"
|
|
#include "util/huffman_codec.h"
|
|
#include "util/move_to_front.h"
|
|
#include "util/parse_number.h"
|
|
#include "validate.h"
|
|
#include "val/instruction.h"
|
|
#include "val/validation_state.h"
|
|
|
|
using libspirv::IdDescriptorCollection;
|
|
using libspirv::Instruction;
|
|
using libspirv::ValidationState_t;
|
|
using spvtools::ValidateInstructionAndUpdateValidationState;
|
|
using spvutils::BitReaderWord64;
|
|
using spvutils::BitWriterWord64;
|
|
using spvutils::HuffmanCodec;
|
|
using MoveToFront = spvutils::MoveToFront<uint32_t>;
|
|
using MultiMoveToFront = spvutils::MultiMoveToFront<uint32_t>;
|
|
|
|
// Options controlling the MARK-V encoder. Carries no settings yet; it exists
// so the public API can accept an options object (presumably through the
// spv_const_markv_encoder_options typedef from markv.h).
struct spv_markv_encoder_options_t {};

// Options controlling the MARK-V decoder. Likewise empty for now.
struct spv_markv_decoder_options_t {};
|
|
|
|
namespace {
|
|
|
|
// Magic number of a SPIR-V module (re-exported from spirv.h).
const uint32_t kSpirvMagicNumber = SpvMagicNumber;
// Magic number written to MarkvHeader::magic_number of every MARK-V binary.
constexpr uint32_t kMarkvMagicNumber = 0x07230303;
|
|
|
|
// Handles for move-to-front sequences. Enums which end with "Begin" define
// handle spaces which start at that value and span 16 or 32 bit wide.
// NOTE(review): kMtfIdGeneratedByOpcode is also the start of a handle space
// despite lacking the "Begin" suffix. The 16-bit spaces are filled by adding a
// uint32_t id/opcode/size to the base, so values >= 0x10000 would spill into
// the next space — confirm ids stay below that bound.
enum : uint64_t {
  kMtfNone = 0,
  // All ids.
  kMtfAll = 1,
  // All forward declared ids.
  kMtfForwardDeclared = 2,
  // All type ids except for generated by OpTypeFunction.
  kMtfTypeNonFunction = 3,
  // All labels.
  kMtfLabel = 4,
  // All ids created by instructions which had type_id.
  kMtfObject = 5,
  // All types generated by OpTypeFloat, OpTypeInt, OpTypeBool.
  kMtfTypeScalar = 6,
  // All composite types.
  kMtfTypeComposite = 7,
  // Boolean type or any vector type of it.
  kMtfTypeBoolScalarOrVector = 8,
  // All float types or any vector floats type.
  kMtfTypeFloatScalarOrVector = 9,
  // All int types or any vector int type.
  kMtfTypeIntScalarOrVector = 10,
  // All types declared as return types in OpTypeFunction.
  kMtfTypeReturnedByFunction = 11,
  // All object ids which are integer constants.
  kMtfConstInteger = 12,
  // All composite objects.
  kMtfComposite = 13,
  // All bool objects or vectors of bools.
  kMtfBoolScalarOrVector = 14,
  // All float objects or vectors of float.
  kMtfFloatScalarOrVector = 15,
  // All int objects or vectors of int.
  kMtfIntScalarOrVector = 16,
  // All pointer types which point to composites.
  kMtfTypePointerToComposite = 17,
  // Used by EncodeMtfRankHuffman.
  kMtfGenericNonZeroRank = 18,

  // Handle space for ids of specific type.
  kMtfIdOfTypeBegin = 0x10000,
  // Handle space for ids generated by specific opcode.
  kMtfIdGeneratedByOpcode = 0x20000,
  // Handle space for ids of objects with type generated by specific opcode.
  kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000,
  // All vectors of specific component type.
  kMtfVectorOfComponentTypeBegin = 0x40000,
  // All vector types of specific size.
  kMtfTypeVectorOfSizeBegin = 0x50000,
  // All pointer types to specific type.
  kMtfPointerToTypeBegin = 0x60000,
  // All function types which return specific type.
  kMtfFunctionTypeWithReturnTypeBegin = 0x70000,
  // All function objects which return specific type.
  kMtfFunctionWithReturnTypeBegin = 0x80000,
  // All float vectors of specific size.
  kMtfFloatVectorOfSizeBegin = 0x90000,
  // Id descriptor space (32-bit).
  kMtfIdDescriptorSpaceBegin = 0x100000000,
};
|
|
|
|
// Used by "presumed index" technique which does special treatment of integer
|
|
// constants no greater than this value.
|
|
const uint32_t kMarkvMaxPresumedAccessIndex = 31;
|
|
|
|
// Signals that the value is not in the coding scheme and a fallback method
|
|
// needs to be used.
|
|
const uint64_t kMarkvNoneOfTheAbove = GetMarkvNonOfTheAbove();
|
|
|
|
// Mtf ranks below this threshold are encoded with Huffman coding.
constexpr uint32_t kMtfSmallestRankEncodedByValue = 10;

// Huffman symbol signalling that the mtf rank is too large for the Huffman
// table and is written by value instead.
constexpr uint32_t kMtfRankEncodedByValueSignal =
    std::numeric_limits<uint32_t>::max();

// Number of spaces used when laying out comment output.
constexpr size_t kCommentNumWhitespaces = 2;

// Byte-break threshold applied after each encoded instruction.
constexpr size_t kByteBreakAfterInstIfLessThanUntilNextByte = 8;
|
|
|
|
// Returns a set of mtf rank codecs based on a plausible hand-coded
|
|
// distribution.
|
|
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
|
|
GetMtfHuffmanCodecs() {
|
|
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> codecs;
|
|
|
|
std::unique_ptr<HuffmanCodec<uint32_t>> codec;
|
|
|
|
codec.reset(new HuffmanCodec<uint32_t>(std::map<uint32_t, uint32_t>({
|
|
{ 0, 5 },
|
|
{ 1, 40 },
|
|
{ 2, 10 },
|
|
{ 3, 5 },
|
|
{ 4, 5 },
|
|
{ 5, 5 },
|
|
{ 6, 3 },
|
|
{ 7, 3 },
|
|
{ 8, 3 },
|
|
{ 9, 3 },
|
|
{ kMtfRankEncodedByValueSignal, 10 },
|
|
})));
|
|
codecs.emplace(kMtfAll, std::move(codec));
|
|
|
|
codec.reset(new HuffmanCodec<uint32_t>(std::map<uint32_t, uint32_t>({
|
|
{ 1, 50 },
|
|
{ 2, 20 },
|
|
{ 3, 5 },
|
|
{ 4, 5 },
|
|
{ 5, 2 },
|
|
{ 6, 1 },
|
|
{ 7, 1 },
|
|
{ 8, 1 },
|
|
{ 9, 1 },
|
|
{ kMtfRankEncodedByValueSignal, 10 },
|
|
})));
|
|
codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));
|
|
|
|
return codecs;
|
|
}
|
|
|
|
// Encoding/decoding model containing various constants and codecs.
|
|
class MarkvModel {
|
|
public:
|
|
MarkvModel()
|
|
: mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
|
|
opcode_and_num_operands_huffman_codec_(GetOpcodeAndNumOperandsHist()),
|
|
opcode_and_num_operands_markov_huffman_codecs_(
|
|
GetOpcodeAndNumOperandsMarkovHuffmanCodecs()),
|
|
non_id_word_huffman_codecs_(GetNonIdWordHuffmanCodecs()),
|
|
id_descriptor_huffman_codecs_(GetIdDescriptorHuffmanCodecs()),
|
|
descriptors_with_coding_scheme_(GetDescriptorsWithCodingScheme()),
|
|
literal_string_huffman_codecs_(GetLiteralStringHuffmanCodecs()) {}
|
|
|
|
size_t opcode_chunk_length() const { return 7; }
|
|
size_t num_operands_chunk_length() const { return 3; }
|
|
size_t mtf_rank_chunk_length() const { return 5; }
|
|
|
|
size_t u16_chunk_length() const { return 4; }
|
|
size_t s16_chunk_length() const { return 4; }
|
|
size_t s16_block_exponent() const { return 6; }
|
|
|
|
size_t u32_chunk_length() const { return 8; }
|
|
size_t s32_chunk_length() const { return 8; }
|
|
size_t s32_block_exponent() const { return 10; }
|
|
|
|
size_t u64_chunk_length() const { return 8; }
|
|
size_t s64_chunk_length() const { return 8; }
|
|
size_t s64_block_exponent() const { return 10; }
|
|
|
|
// Returns Huffman codec for ranks of the mtf with given |handle|.
|
|
// Different mtfs can use different rank distributions.
|
|
// May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<uint32_t>* GetMtfHuffmanCodec(uint64_t handle) const {
|
|
const auto it = mtf_huffman_codecs_.find(handle);
|
|
if (it == mtf_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
// Returns a codec for common opcode_and_num_operands words for the given
|
|
// previous opcode. May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<uint64_t>* GetOpcodeAndNumOperandsMarkovHuffmanCodec(
|
|
uint32_t prev_opcode) const {
|
|
if (prev_opcode == SpvOpNop)
|
|
return &opcode_and_num_operands_huffman_codec_;
|
|
|
|
const auto it =
|
|
opcode_and_num_operands_markov_huffman_codecs_.find(prev_opcode);
|
|
if (it == opcode_and_num_operands_markov_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
// Returns a codec for common non-id words used for given operand slot.
|
|
// Operand slot is defined by the opcode and the operand index.
|
|
// May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<uint64_t>* GetNonIdWordHuffmanCodec(
|
|
uint32_t opcode, uint32_t operand_index) const {
|
|
const auto it = non_id_word_huffman_codecs_.find(
|
|
std::pair<uint32_t, uint32_t>(opcode, operand_index));
|
|
if (it == non_id_word_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
// Returns a codec for common id descriptos used for given operand slot.
|
|
// Operand slot is defined by the opcode and the operand index.
|
|
// May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<uint64_t>* GetIdDescriptorHuffmanCodec(
|
|
uint32_t opcode, uint32_t operand_index) const {
|
|
const auto it = id_descriptor_huffman_codecs_.find(
|
|
std::pair<uint32_t, uint32_t>(opcode, operand_index));
|
|
if (it == id_descriptor_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
// Returns a codec for common strings used by the given opcode.
|
|
// Operand slot is defined by the opcode and the operand index.
|
|
// May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<std::string>* GetLiteralStringHuffmanCodec(
|
|
uint32_t opcode) const {
|
|
const auto it = literal_string_huffman_codecs_.find(opcode);
|
|
if (it == literal_string_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
bool DescriptorHasCodingScheme(uint32_t descriptor) const {
|
|
return descriptors_with_coding_scheme_.count(descriptor);
|
|
}
|
|
|
|
private:
|
|
// Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
|
|
// need to contain a different codec for every handle as most use one and the
|
|
// same.
|
|
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
|
|
mtf_huffman_codecs_;
|
|
|
|
// Huffman codec for base-rate of opcode_and_num_operands.
|
|
HuffmanCodec<uint64_t> opcode_and_num_operands_huffman_codec_;
|
|
|
|
// Huffman codecs for opcode_and_num_operands. The map key is previous opcode.
|
|
std::map<uint32_t, std::unique_ptr<HuffmanCodec<uint64_t>>>
|
|
opcode_and_num_operands_markov_huffman_codecs_;
|
|
|
|
// Huffman codecs for non-id single-word operand values.
|
|
// The map key is pair <opcode, operand_index>.
|
|
std::map<std::pair<uint32_t, uint32_t>,
|
|
std::unique_ptr<HuffmanCodec<uint64_t>>>
|
|
non_id_word_huffman_codecs_;
|
|
|
|
// Huffman codecs for id descriptors. The map key is pair
|
|
// <opcode, operand_index>.
|
|
std::map<std::pair<uint32_t, uint32_t>,
|
|
std::unique_ptr<HuffmanCodec<uint64_t>>>
|
|
id_descriptor_huffman_codecs_;
|
|
|
|
std::unordered_set<uint32_t> descriptors_with_coding_scheme_;
|
|
|
|
// Huffman codecs for literal strings. The map key is the opcode of the
|
|
// current instruction. This assumes, that there is no more than one literal
|
|
// string operand per instruction, but would still work even if this is not
|
|
// the case. Names and debug information strings are not collected.
|
|
std::map<uint32_t, std::unique_ptr<HuffmanCodec<std::string>>>
|
|
literal_string_huffman_codecs_;
|
|
};
|
|
|
|
// Returns the shared default model. It is built on the first call (function-
// local static) and lives until program exit.
const MarkvModel* GetDefaultModel() {
  static const MarkvModel kDefaultModel;
  return &kDefaultModel;
}
|
|
|
|
// Returns chunk length used for variable length encoding of spirv operand
|
|
// words. Returns zero if operand type corresponds to potentially multiple
|
|
// words or a word which is not expected to profit from variable width encoding.
|
|
// Chunk length is selected based on the size of expected value.
|
|
// Most of these values will later be encoded with probability-based coding,
|
|
// but variable width integer coding is a good quick solution.
|
|
// TODO(atgoo@github.com): Put this in MarkvModel flatbuffer.
|
|
size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) {
|
|
switch (type) {
|
|
case SPV_OPERAND_TYPE_TYPE_ID:
|
|
return 4;
|
|
case SPV_OPERAND_TYPE_RESULT_ID:
|
|
case SPV_OPERAND_TYPE_ID:
|
|
case SPV_OPERAND_TYPE_SCOPE_ID:
|
|
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
|
|
return 8;
|
|
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_CAPABILITY:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
|
|
return 3;
|
|
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
|
|
case SPV_OPERAND_TYPE_MEMORY_MODEL:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODE:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_STORAGE_CLASS:
|
|
return 4;
|
|
case SPV_OPERAND_TYPE_DIMENSIONALITY:
|
|
case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
|
|
return 3;
|
|
case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
|
|
case SPV_OPERAND_TYPE_LINKAGE_TYPE:
|
|
case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
|
|
return 3;
|
|
case SPV_OPERAND_TYPE_DECORATION:
|
|
case SPV_OPERAND_TYPE_BUILT_IN:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_GROUP_OPERATION:
|
|
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
|
|
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
|
|
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
|
|
case SPV_OPERAND_TYPE_LOOP_CONTROL:
|
|
case SPV_OPERAND_TYPE_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
|
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
|
|
return 4;
|
|
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER:
|
|
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
|
|
return 6;
|
|
default:
|
|
return 0;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Returns true if the opcode has a fixed number of operands. May return a
|
|
// false negative.
|
|
bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
|
|
switch (opcode) {
|
|
// TODO(atgoo@github.com) This is not a complete list.
|
|
case SpvOpNop:
|
|
case SpvOpName:
|
|
case SpvOpUndef:
|
|
case SpvOpSizeOf:
|
|
case SpvOpLine:
|
|
case SpvOpNoLine:
|
|
case SpvOpDecorationGroup:
|
|
case SpvOpExtension:
|
|
case SpvOpExtInstImport:
|
|
case SpvOpMemoryModel:
|
|
case SpvOpCapability:
|
|
case SpvOpTypeVoid:
|
|
case SpvOpTypeBool:
|
|
case SpvOpTypeInt:
|
|
case SpvOpTypeFloat:
|
|
case SpvOpTypeVector:
|
|
case SpvOpTypeMatrix:
|
|
case SpvOpTypeSampler:
|
|
case SpvOpTypeSampledImage:
|
|
case SpvOpTypeArray:
|
|
case SpvOpTypePointer:
|
|
case SpvOpConstantTrue:
|
|
case SpvOpConstantFalse:
|
|
case SpvOpLabel:
|
|
case SpvOpBranch:
|
|
case SpvOpFunction:
|
|
case SpvOpFunctionParameter:
|
|
case SpvOpFunctionEnd:
|
|
case SpvOpBitcast:
|
|
case SpvOpCopyObject:
|
|
case SpvOpTranspose:
|
|
case SpvOpSNegate:
|
|
case SpvOpFNegate:
|
|
case SpvOpIAdd:
|
|
case SpvOpFAdd:
|
|
case SpvOpISub:
|
|
case SpvOpFSub:
|
|
case SpvOpIMul:
|
|
case SpvOpFMul:
|
|
case SpvOpUDiv:
|
|
case SpvOpSDiv:
|
|
case SpvOpFDiv:
|
|
case SpvOpUMod:
|
|
case SpvOpSRem:
|
|
case SpvOpSMod:
|
|
case SpvOpFRem:
|
|
case SpvOpFMod:
|
|
case SpvOpVectorTimesScalar:
|
|
case SpvOpMatrixTimesScalar:
|
|
case SpvOpVectorTimesMatrix:
|
|
case SpvOpMatrixTimesVector:
|
|
case SpvOpMatrixTimesMatrix:
|
|
case SpvOpOuterProduct:
|
|
case SpvOpDot:
|
|
return true;
|
|
default:
|
|
break;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Returns how many bits lie between |bit_pos| and the next byte boundary;
// zero when |bit_pos| is already byte-aligned.
size_t GetNumBitsToNextByte(size_t bit_pos) {
  const size_t bits_into_byte = bit_pos % 8;
  if (bits_into_byte == 0)
    return 0;
  return 8 - bits_into_byte;
}
|
|
|
|
// Returns the current MARK-V version packed into one word: the major version
// in the upper 16 bits and the minor version in the lower 16 bits.
uint32_t GetMarkvVersion() {
  const uint32_t major = 1;
  const uint32_t minor = 2;
  const uint32_t packed_major = major << 16;
  return packed_major | minor;
}
|
|
|
|
// Accumulates human-readable comments describing the encoding process.
// Consecutive bit sequences are joined with '-', and any other append breaks
// the chain, producing output such as: 1100-0 1110-1-1100-1-1111-0 110-0.
class CommentLogger {
 public:
  // Appends free text; the next bit sequence is not prefixed with '-'.
  void AppendText(const std::string& str) { Write(str, false); }

  // Appends free text followed by a newline.
  void AppendTextNewLine(const std::string& str) {
    Write(str + "\n", false);
  }

  // Appends a bit sequence, joined to a preceding one with '-'.
  void AppendBitSequence(const std::string& str) {
    if (use_delimiter_)
      text_ << '-';
    Write(str, true);
  }

  // Appends |num| spaces.
  void AppendWhitespaces(size_t num) { Write(std::string(num, ' '), false); }

  // Appends a newline.
  void NewLine() { Write("\n", false); }

  // Returns everything logged so far.
  std::string GetText() const { return text_.str(); }

 private:
  // Writes |str| and records whether the next bit sequence needs a '-'.
  void Write(const std::string& str, bool delimit_next) {
    text_ << str;
    use_delimiter_ = delimit_next;
  }

  std::stringstream text_;

  // If true a delimiter will be appended before the next bit sequence.
  bool use_delimiter_ = false;
};
|
|
|
|
// Creates spv_text object containing text from |str|.
|
|
// The returned value is owned by the caller and needs to be destroyed with
|
|
// spvTextDestroy.
|
|
spv_text CreateSpvText(const std::string& str) {
|
|
spv_text out = new spv_text_t();
|
|
assert(out);
|
|
char* cstr = new char[str.length() + 1];
|
|
assert(cstr);
|
|
std::strncpy(cstr, str.c_str(), str.length());
|
|
cstr[str.length()] = '\0';
|
|
out->str = cstr;
|
|
out->length = str.length();
|
|
return out;
|
|
}
|
|
|
|
// Base class for MARK-V encoder and decoder. Contains common functionality
|
|
// such as:
|
|
// - Validator connection and validation state.
|
|
// - SPIR-V grammar and helper functions.
|
|
class MarkvCodecBase {
|
|
public:
|
|
virtual ~MarkvCodecBase() {
|
|
spvValidatorOptionsDestroy(validator_options_);
|
|
}
|
|
|
|
MarkvCodecBase() = delete;
|
|
|
|
void SetModel(const MarkvModel* model) {
|
|
model_ = model;
|
|
}
|
|
|
|
protected:
|
|
struct MarkvHeader {
|
|
MarkvHeader() {
|
|
magic_number = kMarkvMagicNumber;
|
|
markv_version = GetMarkvVersion();
|
|
markv_model = 0;
|
|
markv_length_in_bits = 0;
|
|
spirv_version = 0;
|
|
spirv_generator = 0;
|
|
}
|
|
|
|
uint32_t magic_number;
|
|
uint32_t markv_version;
|
|
// Magic number to identify or verify MarkvModel used for encoding.
|
|
uint32_t markv_model;
|
|
uint32_t markv_length_in_bits;
|
|
uint32_t spirv_version;
|
|
uint32_t spirv_generator;
|
|
};
|
|
|
|
explicit MarkvCodecBase(spv_const_context context,
|
|
spv_validator_options validator_options)
|
|
: validator_options_(validator_options),
|
|
vstate_(context, validator_options_), grammar_(context),
|
|
model_(GetDefaultModel()) {}
|
|
|
|
// Validates a single instruction and updates validation state of the module.
|
|
spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
|
|
return ValidateInstructionAndUpdateValidationState(&vstate_, &inst);
|
|
}
|
|
|
|
// Returns instruction which created |id| or nullptr if such instruction was
|
|
// not registered.
|
|
const Instruction* GetDefInst(uint32_t id) const {
|
|
const auto it = vstate_.all_definitions().find(id);
|
|
if (it == vstate_.all_definitions().end())
|
|
return nullptr;
|
|
return it->second;
|
|
}
|
|
|
|
// Returns type id of vector type component.
|
|
uint32_t GetVectorComponentType(uint32_t vector_type_id) const {
|
|
const auto it = vstate_.all_definitions().find(vector_type_id);
|
|
assert(it != vstate_.all_definitions().end());
|
|
const Instruction* type_inst = it->second;
|
|
assert(type_inst->opcode() == SpvOpTypeVector);
|
|
|
|
const uint32_t component_type =
|
|
type_inst->word(type_inst->operands()[1].offset);
|
|
return component_type;
|
|
}
|
|
|
|
// Returns mtf handle for ids of given type.
|
|
uint64_t GetMtfIdOfType(uint32_t type_id) const {
|
|
return kMtfIdOfTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for ids generated by given opcode.
|
|
uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const {
|
|
return kMtfIdGeneratedByOpcode + opcode;
|
|
}
|
|
|
|
// Returns mtf handle for ids of type generated by given opcode.
|
|
uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const {
|
|
return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode;
|
|
}
|
|
|
|
// Returns mtf handle for vectors of specific component type.
|
|
uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const {
|
|
return kMtfVectorOfComponentTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for float vectors of specific size.
|
|
uint64_t GetMtfFloatVectorOfSize(uint32_t size) const {
|
|
return kMtfFloatVectorOfSizeBegin + size;
|
|
}
|
|
|
|
// Returns mtf handle for vector type of specific size.
|
|
uint64_t GetMtfTypeVectorOfSize(uint32_t size) const {
|
|
return kMtfTypeVectorOfSizeBegin + size;
|
|
}
|
|
|
|
// Returns mtf handle for pointers to specific size.
|
|
uint64_t GetMtfPointerToType(uint32_t type_id) const {
|
|
return kMtfPointerToTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for function types with given return type.
|
|
uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const {
|
|
return kMtfFunctionTypeWithReturnTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for functions with given return type.
|
|
uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const {
|
|
return kMtfFunctionWithReturnTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for the given id descriptor.
|
|
uint64_t GetMtfIdDescriptor(uint32_t descriptor) const {
|
|
return kMtfIdDescriptorSpaceBegin + descriptor;
|
|
}
|
|
|
|
// Process data from the current instruction. This would update MTFs and
|
|
// other data containers.
|
|
void ProcessCurInstruction();
|
|
|
|
// Returns move-to-front handle to be used for the current operand slot.
|
|
// Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
|
|
uint64_t GetRuleBasedMtf();
|
|
|
|
// Returns words of the current instruction. Decoder has a different
|
|
// implementation and the array is valid only until the previously decoded
|
|
// word.
|
|
virtual const uint32_t* GetInstWords() const {
|
|
return inst_.words;
|
|
}
|
|
|
|
// Returns the opcode of the previous instruction.
|
|
SpvOp GetPrevOpcode() const {
|
|
if (instructions_.empty())
|
|
return SpvOpNop;
|
|
|
|
return instructions_.back()->opcode();
|
|
}
|
|
|
|
spv_validator_options validator_options_ = nullptr;
|
|
ValidationState_t vstate_;
|
|
const libspirv::AssemblyGrammar grammar_;
|
|
MarkvHeader header_;
|
|
const MarkvModel* model_ = nullptr;
|
|
|
|
// Current instruction, current operand and current operand index.
|
|
spv_parsed_instruction_t inst_;
|
|
spv_parsed_operand_t operand_;
|
|
uint32_t operand_index_;
|
|
|
|
// Maps a result ID to its type ID. By convention:
|
|
// - a result ID that is a type definition maps to itself.
|
|
// - a result ID without a type maps to 0. (E.g. for OpLabel)
|
|
std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
|
|
|
|
// Container for all move-to-front sequences.
|
|
MultiMoveToFront multi_mtf_;
|
|
|
|
// Id of the current function or zero if outside of function.
|
|
uint32_t cur_function_id_ = 0;
|
|
|
|
// Return type of the current function.
|
|
uint32_t cur_function_return_type_ = 0;
|
|
|
|
// Remaining function parameter types. This container is filled on OpFunction,
|
|
// and drained on OpFunctionParameter.
|
|
std::list<uint32_t> remaining_function_parameter_types_;
|
|
|
|
// List of ids local to the current function.
|
|
std::vector<uint32_t> ids_local_to_cur_function_;
|
|
|
|
// List of instructions in the order they are given in the module.
|
|
std::vector<const Instruction*> instructions_;
|
|
|
|
// Maps used for the 'presumed id' techniques. Maps small constant integer
|
|
// value to its id and back.
|
|
std::map<uint32_t, uint32_t> presumed_index_to_id_;
|
|
std::map<uint32_t, uint32_t> id_to_presumed_index_;
|
|
|
|
// Container/computer for id descriptors.
|
|
IdDescriptorCollection id_descriptors_;
|
|
};
|
|
|
|
// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
|
|
// EncodeInstruction which can be used as callback by spvBinaryParse.
|
|
// Encoded binary is written to an internally maintained bitstream.
|
|
// After the last instruction is encoded, the resulting MARK-V binary can be
|
|
// acquired by calling GetMarkvBinary().
|
|
// The encoder uses SPIR-V validator to keep internal state, therefore
|
|
// SPIR-V binary needs to be able to pass validator checks.
|
|
// CreateCommentsLogger() can be used to enable the encoder to write comments
|
|
// on how encoding was done, which can later be accessed with GetComments().
|
|
class MarkvEncoder : public MarkvCodecBase {
|
|
public:
|
|
MarkvEncoder(spv_const_context context,
|
|
spv_const_markv_encoder_options options)
|
|
: MarkvCodecBase(context, GetValidatorOptions(options)),
|
|
options_(options) {
|
|
(void) options_;
|
|
}
|
|
|
|
// Writes data from SPIR-V header to MARK-V header.
|
|
spv_result_t EncodeHeader(
|
|
spv_endianness_t /* endian */, uint32_t /* magic */,
|
|
uint32_t version, uint32_t generator, uint32_t id_bound,
|
|
uint32_t /* schema */) {
|
|
vstate_.setIdBound(id_bound);
|
|
header_.spirv_version = version;
|
|
header_.spirv_generator = generator;
|
|
return SPV_SUCCESS;
|
|
}
|
|
|
|
// Encodes SPIR-V instruction to MARK-V and writes to bit stream.
|
|
// Operation can fail if the instruction fails to pass the validator or if
|
|
// the encoder stubmles on something unexpected.
|
|
spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
|
|
|
|
// Concatenates MARK-V header and the bit stream with encoded instructions
|
|
// into a single buffer and returns it as spv_markv_binary. The returned
|
|
// value is owned by the caller and needs to be destroyed with
|
|
// spvMarkvBinaryDestroy().
|
|
spv_markv_binary GetMarkvBinary() {
|
|
header_.markv_length_in_bits =
|
|
static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
|
|
const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
|
|
|
|
spv_markv_binary markv_binary = new spv_markv_binary_t();
|
|
markv_binary->data = new uint8_t[num_bytes];
|
|
markv_binary->length = num_bytes;
|
|
assert(writer_.GetData());
|
|
std::memcpy(markv_binary->data, &header_, sizeof(header_));
|
|
std::memcpy(markv_binary->data + sizeof(header_),
|
|
writer_.GetData(), writer_.GetDataSizeBytes());
|
|
return markv_binary;
|
|
}
|
|
|
|
// Creates an internal logger which writes comments on the encoding process.
|
|
// Output can later be accessed with GetComments().
|
|
void CreateCommentsLogger() {
|
|
logger_.reset(new CommentLogger());
|
|
writer_.SetCallback([this](const std::string& str){
|
|
logger_->AppendBitSequence(str);
|
|
});
|
|
}
|
|
|
|
// Optionally adds disassembly to the comments.
|
|
// Disassembly should contain all instructions in the module separated by
|
|
// \n, and no header.
|
|
void SetDisassembly(std::string&& disassembly) {
|
|
disassembly_.reset(new std::stringstream(std::move(disassembly)));
|
|
}
|
|
|
|
// Extracts the next instruction line from the disassembly and logs it.
|
|
void LogDisassemblyInstruction() {
|
|
if (logger_ && disassembly_) {
|
|
std::string line;
|
|
std::getline(*disassembly_, line, '\n');
|
|
logger_->AppendTextNewLine(line);
|
|
}
|
|
}
|
|
|
|
// Extracts the text from the comment logger.
|
|
std::string GetComments() const {
|
|
if (!logger_)
|
|
return "";
|
|
return logger_->GetText();
|
|
}
|
|
|
|
private:
|
|
// Creates and returns validator options. Return value owned by the caller.
|
|
static spv_validator_options GetValidatorOptions(
|
|
spv_const_markv_encoder_options) {
|
|
return spvValidatorOptionsCreate();
|
|
}
|
|
|
|
// Writes a single word to bit stream. operand_.type determines if the word is
|
|
// encoded and how.
|
|
spv_result_t EncodeNonIdWord(uint32_t word);
|
|
|
|
// Writes both opcode and num_operands as a single code.
|
|
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
|
|
spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode, uint32_t num_operands);
|
|
|
|
// Writes mtf rank to bit stream. |mtf| is used to determine the codec
|
|
// scheme. |fallback_method| is used if no codec defined for |mtf|.
|
|
spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
|
|
uint64_t fallback_method);
|
|
|
|
// Writes id using coding based on mtf associated with the id descriptor.
|
|
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
|
|
spv_result_t EncodeIdWithDescriptor(uint32_t id);
|
|
|
|
// Writes id using coding based on the given |mtf|, which is expected to
|
|
// contain the given |id|.
|
|
spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id);
|
|
|
|
// Writes type id of the current instruction if can't be inferred.
|
|
spv_result_t EncodeTypeId();
|
|
|
|
// Writes result id of the current instruction if can't be inferred.
|
|
spv_result_t EncodeResultId();
|
|
|
|
// Writes ids which are neither type nor result ids.
|
|
spv_result_t EncodeRefId(uint32_t id);
|
|
|
|
// Writes bits to the stream until the beginning of the next byte if the
|
|
// number of bits until the next byte is less than |byte_break_if_less_than|.
|
|
void AddByteBreak(size_t byte_break_if_less_than);
|
|
|
|
// Encodes a literal number operand and writes it to the bit stream.
|
|
spv_result_t EncodeLiteralNumber(const Instruction& instruction,
|
|
const spv_parsed_operand_t& operand);
|
|
|
|
spv_const_markv_encoder_options options_;
|
|
|
|
// Bit stream where encoded instructions are written.
|
|
BitWriterWord64 writer_;
|
|
|
|
// If not nullptr, encoder will write comments.
|
|
std::unique_ptr<CommentLogger> logger_;
|
|
|
|
// If not nullptr, disassembled instruction lines will be written to comments.
|
|
// Format: \n separated instruction lines, no header.
|
|
std::unique_ptr<std::stringstream> disassembly_;
|
|
};
|
|
|
|
// Decodes MARK-V buffers written by MarkvEncoder.
|
|
class MarkvDecoder : public MarkvCodecBase {
|
|
public:
|
|
MarkvDecoder(spv_const_context context,
             const uint8_t* markv_data,
             size_t markv_size_bytes,
             spv_const_markv_decoder_options options)
    : MarkvCodecBase(context, GetValidatorOptions(options)),
      options_(options),
      reader_(markv_data, markv_size_bytes) {
  // |options_| is currently unused; presumably cast to silence warnings.
  static_cast<void>(options_);
  // Initial id bound; presumably raised as ids are decoded — TODO confirm.
  vstate_.setIdBound(1);
  // Pre-size per-instruction scratch buffers.
  inst_words_.reserve(25);
  parsed_operands_.reserve(25);
}
|
|
|
|
// Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
|
|
// Can be called only once. Fails if data of wrong format or ends prematurely,
|
|
// of if validation fails.
|
|
spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
|
|
|
|
private:
|
|
// Describes the format of a typed literal number.
|
|
struct NumberType {
|
|
spv_number_kind_t type;
|
|
uint32_t bit_width;
|
|
};
|
|
|
|
// Creates and returns validator options. Return value owned by the caller.
|
|
static spv_validator_options GetValidatorOptions(
|
|
spv_const_markv_decoder_options) {
|
|
return spvValidatorOptionsCreate();
|
|
}
|
|
|
|
// Reads a single bit from reader_. The read bit is stored in |bit|.
|
|
// Returns false iff reader_ fails.
|
|
bool ReadBit(bool* bit) {
|
|
uint64_t bits = 0;
|
|
const bool result = reader_.ReadBits(&bits, 1);
|
|
if (result)
|
|
*bit = bits ? true : false;
|
|
return result;
|
|
};
|
|
|
|
// Returns ReadBit bound to the class object.
|
|
std::function<bool(bool*)> GetReadBitCallback() {
|
|
return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
|
|
}
|
|
|
|
// Reads a single non-id word from bit stream. operand_.type determines if
|
|
// the word needs to be decoded and how.
|
|
spv_result_t DecodeNonIdWord(uint32_t* word);
|
|
|
|
// Reads and decodes both opcode and num_operands as a single code.
|
|
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
|
|
spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
|
|
uint32_t* num_operands);
|
|
|
|
// Reads mtf rank from bit stream. |mtf| is used to determine the codec
|
|
// scheme. |fallback_method| is used if no codec defined for |mtf|.
|
|
spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
|
|
uint32_t* rank);
|
|
|
|
// Reads id using coding based on mtf associated with the id descriptor.
|
|
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
|
|
spv_result_t DecodeIdWithDescriptor(uint32_t* id);
|
|
|
|
// Reads id using coding based on the given |mtf|, which is expected to
|
|
// contain the needed |id|.
|
|
spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
|
|
|
|
// Reads type id of the current instruction if can't be inferred.
|
|
spv_result_t DecodeTypeId();
|
|
|
|
// Reads result id of the current instruction if can't be inferred.
|
|
spv_result_t DecodeResultId();
|
|
|
|
// Reads id which is neither type nor result id.
|
|
spv_result_t DecodeRefId(uint32_t* id);
|
|
|
|
// Reads and discards bits until the beginning of the next byte if the
|
|
// number of bits until the next byte is less than |byte_break_if_less_than|.
|
|
bool ReadToByteBreak(size_t byte_break_if_less_than);
|
|
|
|
// Returns instruction words decoded up to this point.
|
|
const uint32_t* GetInstWords() const override {
|
|
return inst_words_.data();
|
|
}
|
|
|
|
// Reads a literal number as it is described in |operand| from the bit stream,
|
|
// decodes and writes it to spirv_.
|
|
spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
|
|
|
|
// Reads instruction from bit stream, decodes and validates it.
|
|
// Decoded instruction is valid until the next call of DecodeInstruction().
|
|
spv_result_t DecodeInstruction();
|
|
|
|
// Read operand from the stream decodes and validates it.
|
|
spv_result_t DecodeOperand(size_t operand_offset,
|
|
const spv_operand_type_t type,
|
|
spv_operand_pattern_t* expected_operands);
|
|
|
|
// Records the numeric type for an operand according to the type information
|
|
// associated with the given non-zero type Id. This can fail if the type Id
|
|
// is not a type Id, or if the type Id does not reference a scalar numeric
|
|
// type. On success, return SPV_SUCCESS and populates the num_words,
|
|
// number_kind, and number_bit_width fields of parsed_operand.
|
|
spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
|
|
uint32_t type_id);
|
|
|
|
// Records the number type for the current instruction, if it generates a
|
|
// type. For types that aren't scalar numbers, record something with number
|
|
// kind SPV_NUMBER_NONE.
|
|
void RecordNumberType();
|
|
|
|
spv_const_markv_decoder_options options_;
|
|
|
|
// Temporary sink where decoded SPIR-V words are written. Once it contains the
|
|
// entire module, the container is moved and returned.
|
|
std::vector<uint32_t> spirv_;
|
|
|
|
// Bit stream containing encoded data.
|
|
BitReaderWord64 reader_;
|
|
|
|
// Temporary storage for operands of the currently parsed instruction.
|
|
// Valid until next DecodeInstruction call.
|
|
std::vector<spv_parsed_operand_t> parsed_operands_;
|
|
|
|
// Temporary storage for current instruction words.
|
|
// Valid until next DecodeInstruction call.
|
|
std::vector<uint32_t> inst_words_;
|
|
|
|
// Maps a type ID to its number type description.
|
|
std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
|
|
|
|
// Maps an ExtInstImport id to the extended instruction type.
|
|
std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
|
|
};
|
|
|
|
void MarkvCodecBase::ProcessCurInstruction() {
|
|
const SpvOp opcode = SpvOp(inst_.opcode);
|
|
|
|
if (inst_.result_id) {
|
|
// Collect ids local to the current function.
|
|
if (cur_function_id_){
|
|
ids_local_to_cur_function_.push_back(inst_.result_id);
|
|
}
|
|
|
|
// Starting new function.
|
|
if (opcode == SpvOpFunction) {
|
|
cur_function_id_ = inst_.result_id;
|
|
cur_function_return_type_ = inst_.type_id;
|
|
multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
|
|
inst_.result_id);
|
|
|
|
// Store function parameter types in a queue, so that we know which types
|
|
// to expect in the following OpFunctionParameter instructions.
|
|
const Instruction* def_inst = GetDefInst(inst_.words[4]);
|
|
assert(def_inst);
|
|
assert(def_inst->opcode() == SpvOpTypeFunction);
|
|
for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
|
|
remaining_function_parameter_types_.push_back(def_inst->word(i));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove local ids from MTFs if function end.
|
|
if (opcode == SpvOpFunctionEnd) {
|
|
cur_function_id_ = 0;
|
|
for (uint32_t id : ids_local_to_cur_function_)
|
|
multi_mtf_.RemoveFromAll(id);
|
|
ids_local_to_cur_function_.clear();
|
|
assert(remaining_function_parameter_types_.empty());
|
|
}
|
|
|
|
if (!inst_.result_id)
|
|
return;
|
|
|
|
{
|
|
// Save the result ID to type ID mapping.
|
|
// In the grammar, type ID always appears before result ID.
|
|
// A regular value maps to its type. Some instructions (e.g. OpLabel)
|
|
// have no type Id, and will map to 0. The result Id for a
|
|
// type-generating instruction (e.g. OpTypeInt) maps to itself.
|
|
auto insertion_result = id_to_type_id_.emplace(
|
|
inst_.result_id,
|
|
spvOpcodeGeneratesType(SpvOp(inst_.opcode)) ?
|
|
inst_.result_id : inst_.type_id);
|
|
(void)insertion_result;
|
|
assert(insertion_result.second);
|
|
}
|
|
|
|
// Add result_id to MTFs.
|
|
|
|
switch (opcode) {
|
|
case SpvOpTypeFloat:
|
|
case SpvOpTypeInt:
|
|
case SpvOpTypeBool:
|
|
case SpvOpTypeVector:
|
|
case SpvOpTypePointer:
|
|
case SpvOpExtInstImport:
|
|
case SpvOpTypeSampledImage:
|
|
case SpvOpTypeImage:
|
|
case SpvOpTypeSampler:
|
|
multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
|
|
if (spvOpcodeIsComposite(opcode)) {
|
|
multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpLabel) {
|
|
multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpTypeInt) {
|
|
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
|
|
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpTypeFloat) {
|
|
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
|
|
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpTypeBool) {
|
|
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
|
|
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpTypeVector) {
|
|
const uint32_t component_type_id = inst_.words[2];
|
|
const uint32_t size = inst_.words[3];
|
|
if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat),
|
|
component_type_id)) {
|
|
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
|
|
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt),
|
|
component_type_id)) {
|
|
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
|
|
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool),
|
|
component_type_id)) {
|
|
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
|
|
}
|
|
multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id);
|
|
}
|
|
|
|
if (inst_.opcode == SpvOpTypeFunction) {
|
|
const uint32_t return_type = inst_.words[2];
|
|
multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type);
|
|
multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type),
|
|
inst_.result_id);
|
|
}
|
|
|
|
if (inst_.type_id) {
|
|
const Instruction* type_inst = GetDefInst(inst_.type_id);
|
|
assert(type_inst);
|
|
|
|
multi_mtf_.Insert(kMtfObject, inst_.result_id);
|
|
|
|
multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id);
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) {
|
|
multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id);
|
|
}
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id))
|
|
multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id);
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id))
|
|
multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id);
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id))
|
|
multi_mtf_.Insert(kMtfComposite, inst_.result_id);
|
|
|
|
if (inst_.opcode == SpvOpConstant) {
|
|
if (multi_mtf_.HasValue(
|
|
GetMtfIdGeneratedByOpcode(SpvOpTypeInt), inst_.type_id)) {
|
|
multi_mtf_.Insert(kMtfConstInteger, inst_.result_id);
|
|
const uint32_t presumed_index = inst_.words[3];
|
|
if (presumed_index <= kMarkvMaxPresumedAccessIndex) {
|
|
const auto result =
|
|
presumed_index_to_id_.emplace(presumed_index, inst_.result_id);
|
|
if (result.second) {
|
|
id_to_presumed_index_.emplace(inst_.result_id, presumed_index);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
switch (type_inst->opcode()) {
|
|
case SpvOpTypeInt:
|
|
case SpvOpTypeBool:
|
|
case SpvOpTypePointer:
|
|
case SpvOpTypeVector:
|
|
case SpvOpTypeImage:
|
|
case SpvOpTypeSampledImage:
|
|
case SpvOpTypeSampler:
|
|
multi_mtf_.Insert(GetMtfIdWithTypeGeneratedByOpcode(
|
|
type_inst->opcode()), inst_.result_id);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
|
|
if (type_inst->opcode() == SpvOpTypeVector) {
|
|
const uint32_t component_type = type_inst->word(2);
|
|
multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type),
|
|
inst_.result_id);
|
|
}
|
|
|
|
if (type_inst->opcode() == SpvOpTypePointer) {
|
|
assert(type_inst->operands().size() > 2);
|
|
assert(type_inst->words().size() > type_inst->operands()[2].offset);
|
|
const uint32_t data_type =
|
|
type_inst->word(type_inst->operands()[2].offset);
|
|
multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id);
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeComposite, data_type))
|
|
multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id);
|
|
}
|
|
}
|
|
|
|
if (spvOpcodeGeneratesType(opcode)) {
|
|
if (opcode != SpvOpTypeFunction) {
|
|
multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id);
|
|
}
|
|
}
|
|
|
|
const uint32_t descriptor = id_descriptors_.ProcessInstruction(inst_);
|
|
if (model_->DescriptorHasCodingScheme(descriptor))
|
|
multi_mtf_.Insert(GetMtfIdDescriptor(descriptor), inst_.result_id);
|
|
}
|
|
|
|
uint64_t MarkvCodecBase::GetRuleBasedMtf() {
|
|
// This function is only called for id operands (but not result ids).
|
|
assert(spvIsIdType(operand_.type) ||
|
|
operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID);
|
|
assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID);
|
|
|
|
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
|
|
|
// All operand slots which expect label id.
|
|
if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) ||
|
|
(inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) ||
|
|
(inst_.opcode == SpvOpBranch && operand_index_ == 0) ||
|
|
(inst_.opcode == SpvOpBranchConditional &&
|
|
(operand_index_ == 1 || operand_index_ == 2 )) ||
|
|
(inst_.opcode == SpvOpPhi && operand_index_ >= 3 &&
|
|
operand_index_ % 2 == 1) ||
|
|
(inst_.opcode == SpvOpSwitch && operand_index_ > 0)) {
|
|
return kMtfLabel;
|
|
}
|
|
|
|
switch (opcode) {
|
|
case SpvOpFAdd:
|
|
case SpvOpFSub:
|
|
case SpvOpFMul:
|
|
case SpvOpFDiv:
|
|
case SpvOpFRem:
|
|
case SpvOpFMod:
|
|
case SpvOpFNegate: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeFloatScalarOrVector;
|
|
|
|
return GetMtfIdOfType(inst_.type_id);
|
|
}
|
|
|
|
case SpvOpISub:
|
|
case SpvOpIAdd:
|
|
case SpvOpIMul:
|
|
case SpvOpSDiv:
|
|
case SpvOpUDiv:
|
|
case SpvOpSMod:
|
|
case SpvOpUMod:
|
|
case SpvOpSRem:
|
|
case SpvOpSNegate: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeIntScalarOrVector;
|
|
|
|
return kMtfIntScalarOrVector;
|
|
}
|
|
|
|
// TODO(atgoo@github.com) Add OpConvertFToU and other opcodes.
|
|
|
|
case SpvOpFOrdEqual:
|
|
case SpvOpFUnordEqual:
|
|
case SpvOpFOrdNotEqual:
|
|
case SpvOpFUnordNotEqual:
|
|
case SpvOpFOrdLessThan:
|
|
case SpvOpFUnordLessThan:
|
|
case SpvOpFOrdGreaterThan:
|
|
case SpvOpFUnordGreaterThan:
|
|
case SpvOpFOrdLessThanEqual:
|
|
case SpvOpFUnordLessThanEqual:
|
|
case SpvOpFOrdGreaterThanEqual:
|
|
case SpvOpFUnordGreaterThanEqual: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeBoolScalarOrVector;
|
|
if (operand_index_ == 2)
|
|
return kMtfFloatScalarOrVector;
|
|
if (operand_index_ == 3) {
|
|
const uint32_t first_operand_id = GetInstWords()[3];
|
|
const uint32_t first_operand_type =
|
|
id_to_type_id_.at(first_operand_id);
|
|
return GetMtfIdOfType(first_operand_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpVectorShuffle: {
|
|
if (operand_index_ == 0) {
|
|
assert(inst_.num_operands > 4);
|
|
return GetMtfTypeVectorOfSize(inst_.num_operands - 4);
|
|
}
|
|
|
|
assert(inst_.type_id);
|
|
if (operand_index_ == 2 || operand_index_ == 3)
|
|
return GetMtfVectorOfComponentType(
|
|
GetVectorComponentType(inst_.type_id));
|
|
break;
|
|
}
|
|
|
|
case SpvOpVectorTimesScalar: {
|
|
if (operand_index_ == 0) {
|
|
// TODO(atgoo@github.com) Could be narrowed to vector of floats.
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
|
}
|
|
|
|
assert(inst_.type_id);
|
|
if (operand_index_ == 2)
|
|
return GetMtfIdOfType(inst_.type_id);
|
|
if (operand_index_ == 3)
|
|
return GetMtfIdOfType(GetVectorComponentType(inst_.type_id));
|
|
break;
|
|
}
|
|
|
|
case SpvOpDot: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat);
|
|
|
|
assert(inst_.type_id);
|
|
if (operand_index_ == 2)
|
|
return GetMtfVectorOfComponentType(inst_.type_id);
|
|
if (operand_index_ == 3) {
|
|
const uint32_t vector_id = GetInstWords()[3];
|
|
const uint32_t vector_type = id_to_type_id_.at(vector_id);
|
|
return GetMtfIdOfType(vector_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypeVector: {
|
|
if (operand_index_ == 1) {
|
|
return kMtfTypeScalar;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypeMatrix: {
|
|
if (operand_index_ == 1) {
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypePointer: {
|
|
if (operand_index_ == 2) {
|
|
return kMtfTypeNonFunction;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypeStruct: {
|
|
if (operand_index_ >= 1) {
|
|
return kMtfTypeNonFunction;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypeFunction: {
|
|
if (operand_index_ == 1) {
|
|
return kMtfTypeNonFunction;
|
|
}
|
|
|
|
if (operand_index_ >= 2) {
|
|
return kMtfTypeNonFunction;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpLoad: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeNonFunction;
|
|
|
|
if (operand_index_ == 2) {
|
|
assert(inst_.type_id);
|
|
return GetMtfPointerToType(inst_.type_id);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpStore: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer);
|
|
if (operand_index_ == 1) {
|
|
const uint32_t pointer_id = GetInstWords()[1];
|
|
const uint32_t pointer_type = id_to_type_id_.at(pointer_id);
|
|
const auto it = vstate_.all_definitions().find(pointer_type);
|
|
assert(it != vstate_.all_definitions().end());
|
|
const Instruction* pointer_inst = it->second;
|
|
|
|
assert(pointer_inst->opcode() == SpvOpTypePointer);
|
|
const uint32_t data_type =
|
|
pointer_inst->word(pointer_inst->operands()[2].offset);
|
|
return GetMtfIdOfType(data_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpVariable: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
|
|
break;
|
|
}
|
|
|
|
case SpvOpAccessChain: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
|
|
if (operand_index_ == 2)
|
|
return kMtfTypePointerToComposite;
|
|
if (operand_index_ >= 3)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt);
|
|
break;
|
|
}
|
|
|
|
case SpvOpCompositeConstruct: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeComposite;
|
|
if (operand_index_ >= 2) {
|
|
const uint32_t composite_type = GetInstWords()[1];
|
|
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type))
|
|
return kMtfFloatScalarOrVector;
|
|
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type))
|
|
return kMtfIntScalarOrVector;
|
|
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type))
|
|
return kMtfBoolScalarOrVector;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpCompositeExtract: {
|
|
if (operand_index_ == 2)
|
|
return kMtfComposite;
|
|
break;
|
|
}
|
|
|
|
case SpvOpConstantComposite: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeComposite;
|
|
if (operand_index_ >= 2) {
|
|
const Instruction* composite_type_inst = GetDefInst(inst_.type_id);
|
|
assert(composite_type_inst);
|
|
if (composite_type_inst->opcode() == SpvOpTypeVector) {
|
|
return GetMtfIdOfType(composite_type_inst->word(2));
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpExtInst: {
|
|
if (operand_index_ == 2)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport);
|
|
if (operand_index_ >= 4) {
|
|
const uint32_t return_type = GetInstWords()[1];
|
|
const uint32_t ext_inst_type = inst_.ext_inst_type;
|
|
const uint32_t ext_inst_index = GetInstWords()[4];
|
|
// TODO(atgoo@github.com) The list of extended instructions is
|
|
// incomplete. Only common instructions and low-hanging fruits listed.
|
|
if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) {
|
|
switch (ext_inst_index) {
|
|
case GLSLstd450FAbs:
|
|
case GLSLstd450FClamp:
|
|
case GLSLstd450FMax:
|
|
case GLSLstd450FMin:
|
|
case GLSLstd450FMix:
|
|
case GLSLstd450Step:
|
|
case GLSLstd450SmoothStep:
|
|
case GLSLstd450Fma:
|
|
case GLSLstd450Pow:
|
|
case GLSLstd450Exp:
|
|
case GLSLstd450Exp2:
|
|
case GLSLstd450Log:
|
|
case GLSLstd450Log2:
|
|
case GLSLstd450Sqrt:
|
|
case GLSLstd450InverseSqrt:
|
|
case GLSLstd450Fract:
|
|
case GLSLstd450Floor:
|
|
case GLSLstd450Ceil:
|
|
case GLSLstd450Radians:
|
|
case GLSLstd450Degrees:
|
|
case GLSLstd450Sin:
|
|
case GLSLstd450Cos:
|
|
case GLSLstd450Tan:
|
|
case GLSLstd450Sinh:
|
|
case GLSLstd450Cosh:
|
|
case GLSLstd450Tanh:
|
|
case GLSLstd450Asin:
|
|
case GLSLstd450Acos:
|
|
case GLSLstd450Atan:
|
|
case GLSLstd450Atan2:
|
|
case GLSLstd450Asinh:
|
|
case GLSLstd450Acosh:
|
|
case GLSLstd450Atanh:
|
|
case GLSLstd450MatrixInverse:
|
|
case GLSLstd450Cross:
|
|
case GLSLstd450Normalize:
|
|
case GLSLstd450Reflect:
|
|
case GLSLstd450FaceForward:
|
|
return GetMtfIdOfType(return_type);
|
|
case GLSLstd450Length:
|
|
case GLSLstd450Distance:
|
|
case GLSLstd450Refract:
|
|
return kMtfFloatScalarOrVector;
|
|
default:
|
|
break;
|
|
}
|
|
} else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) {
|
|
switch (ext_inst_index) {
|
|
case OpenCLLIB::Fabs:
|
|
case OpenCLLIB::FClamp:
|
|
case OpenCLLIB::Fmax:
|
|
case OpenCLLIB::Fmin:
|
|
case OpenCLLIB::Step:
|
|
case OpenCLLIB::Smoothstep:
|
|
case OpenCLLIB::Fma:
|
|
case OpenCLLIB::Pow:
|
|
case OpenCLLIB::Exp:
|
|
case OpenCLLIB::Exp2:
|
|
case OpenCLLIB::Log:
|
|
case OpenCLLIB::Log2:
|
|
case OpenCLLIB::Sqrt:
|
|
case OpenCLLIB::Rsqrt:
|
|
case OpenCLLIB::Fract:
|
|
case OpenCLLIB::Floor:
|
|
case OpenCLLIB::Ceil:
|
|
case OpenCLLIB::Radians:
|
|
case OpenCLLIB::Degrees:
|
|
case OpenCLLIB::Sin:
|
|
case OpenCLLIB::Cos:
|
|
case OpenCLLIB::Tan:
|
|
case OpenCLLIB::Sinh:
|
|
case OpenCLLIB::Cosh:
|
|
case OpenCLLIB::Tanh:
|
|
case OpenCLLIB::Asin:
|
|
case OpenCLLIB::Acos:
|
|
case OpenCLLIB::Atan:
|
|
case OpenCLLIB::Atan2:
|
|
case OpenCLLIB::Asinh:
|
|
case OpenCLLIB::Acosh:
|
|
case OpenCLLIB::Atanh:
|
|
case OpenCLLIB::Cross:
|
|
case OpenCLLIB::Normalize:
|
|
return GetMtfIdOfType(return_type);
|
|
case OpenCLLIB::Length:
|
|
case OpenCLLIB::Distance:
|
|
return kMtfFloatScalarOrVector;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpFunction: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeReturnedByFunction;
|
|
|
|
if (operand_index_ == 3) {
|
|
const uint32_t return_type = GetInstWords()[1];
|
|
return GetMtfFunctionTypeWithReturnType(return_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpFunctionCall: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeReturnedByFunction;
|
|
|
|
if (operand_index_ == 2) {
|
|
const uint32_t return_type = GetInstWords()[1];
|
|
return GetMtfFunctionWithReturnType(return_type);
|
|
}
|
|
|
|
if (operand_index_ >= 3) {
|
|
const uint32_t function_id = GetInstWords()[3];
|
|
const auto function_it = vstate_.all_definitions().find(function_id);
|
|
if (function_it == vstate_.all_definitions().end())
|
|
return kMtfObject;
|
|
|
|
const Instruction* function_inst = function_it->second;
|
|
assert(function_inst->opcode() == SpvOpFunction);
|
|
|
|
const uint32_t function_type_id = function_inst->word(4);
|
|
const auto function_type_it =
|
|
vstate_.all_definitions().find(function_type_id);
|
|
assert(function_type_it != vstate_.all_definitions().end());
|
|
const Instruction* function_type_inst = function_type_it->second;
|
|
assert(function_type_inst->opcode() == SpvOpTypeFunction);
|
|
|
|
const uint32_t argument_type =
|
|
function_type_inst->word(operand_index_);
|
|
return GetMtfIdOfType(argument_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpReturnValue: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdOfType(cur_function_return_type_);
|
|
break;
|
|
}
|
|
|
|
case SpvOpBranchConditional: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool);
|
|
break;
|
|
}
|
|
|
|
case SpvOpSampledImage: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage);
|
|
if (operand_index_ == 2)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage);
|
|
if (operand_index_ == 3)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler);
|
|
break;
|
|
}
|
|
|
|
case SpvOpImageSampleImplicitLod: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
|
if (operand_index_ == 2)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage);
|
|
if (operand_index_ == 3)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector);
|
|
break;
|
|
}
|
|
|
|
default:
|
|
break;
|
|
}
|
|
|
|
return kMtfNone;
|
|
}
|
|
|
|
// Writes a non-id |word| of the current operand.
// If the model has a Huffman codec for (opcode, operand index) and |word| is in
// its table, only the Huffman code is written. Otherwise kMarkvNoneOfTheAbove
// is written (when a codec exists), followed by the word in fallback form:
// a varint if the operand type has a chunk length, raw bits if not.
spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) {
  auto* huffman = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);

  if (huffman) {
    uint64_t code = 0;
    size_t code_length = 0;
    const bool word_in_table = huffman->Encode(word, &code, &code_length);
    if (!word_in_table &&
        !huffman->Encode(kMarkvNoneOfTheAbove, &code, &code_length)) {
      return vstate_.diag(SPV_ERROR_INTERNAL)
          << "Non-id word Huffman table for "
          << spvOpcodeString(SpvOp(inst_.opcode))
          << " operand index " << operand_index_
          << " is missing kMarkvNoneOfTheAbove";
    }

    // Either the word itself or the "none of the above" escape.
    writer_.WriteBits(code, code_length);
    if (word_in_table)
      return SPV_SUCCESS;
  }

  // Fallback encoding.
  const size_t chunk_length = GetOperandVariableWidthChunkLength(operand_.type);
  if (chunk_length == 0) {
    writer_.WriteUnencoded(word);
  } else {
    writer_.WriteVariableWidthU32(word, chunk_length);
  }
  return SPV_SUCCESS;
}
|
|
|
|
// Reads a non-id word of the current operand into |word|; mirrors
// MarkvEncoder::EncodeNonIdWord. A Huffman-decoded kMarkvNoneOfTheAbove means
// the word follows in fallback form (varint or raw, by operand type).
spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
  if (auto* huffman =
          model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_)) {
    uint64_t value = 0;
    if (!huffman->DecodeFromStream(GetReadBitCallback(), &value)) {
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
          << "Failed to decode non-id word with Huffman";
    }

    if (value != kMarkvNoneOfTheAbove) {
      *word = uint32_t(value);
      assert(*word == value);
      return SPV_SUCCESS;
    }
    // Escape symbol received: the word follows in fallback encoding.
  }

  const size_t chunk_length = GetOperandVariableWidthChunkLength(operand_.type);
  if (chunk_length == 0) {
    if (!reader_.ReadUnencoded(word)) {
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
          << "Failed to read unencoded non-id word";
    }
    return SPV_SUCCESS;
  }

  if (!reader_.ReadVariableWidthU32(word, chunk_length)) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
        << "Failed to decode non-id word with varint";
  }
  return SPV_SUCCESS;
}
|
|
|
|
// Writes |opcode| and |num_operands| as the single word
// opcode | (num_operands << 16).
// It first tries the Markov chain codec conditioned on the previous opcode.
// If that codec exists but lacks the word, kMarkvNoneOfTheAbove is written
// and the global codec (keyed by SpvOpNop) is tried next.
// Returns SPV_UNSUPPORTED, after writing kMarkvNoneOfTheAbove, if the global
// table lacks the word too. The caller must then encode the values some
// other way.
spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(
    uint32_t opcode, uint32_t num_operands) {
  uint64_t bits = 0;
  size_t num_bits = 0;

  const uint32_t word = opcode | (num_operands << 16);

  // First try to use the Markov chain codec.
  auto* codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
  if (codec) {
    if (codec->Encode(word, &bits, &num_bits)) {
      // The word was successfully encoded into bits/num_bits.
      writer_.WriteBits(bits, num_bits);
      return SPV_SUCCESS;
    } else {
      // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
      // and use fallback encoding.
      if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
        return vstate_.diag(SPV_ERROR_INTERNAL)
            << "opcode_and_num_operands Huffman table for "
            << spvOpcodeString(GetPrevOpcode())
            << " is missing kMarkvNoneOfTheAbove";

      writer_.WriteBits(bits, num_bits);
    }
  }

  // Fallback to base-rate codec.
  codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
  assert(codec);
  if (codec->Encode(word, &bits, &num_bits)) {
    // The word was successfully encoded into bits/num_bits.
    writer_.WriteBits(bits, num_bits);
    return SPV_SUCCESS;
  } else {
    // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
    // and return SPV_UNSUPPORTED so the caller falls back further.
    if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
      return vstate_.diag(SPV_ERROR_INTERNAL)
          << "Global opcode_and_num_operands Huffman table is missing "
          << "kMarkvNoneOfTheAbove";

    writer_.WriteBits(bits, num_bits);
    return SPV_UNSUPPORTED;
  }
}
|
|
|
|
// Reads the combined opcode | (num_operands << 16) word; mirrors
// MarkvEncoder::EncodeOpcodeAndNumOperands.
// Reading starts with the Markov codec for the previous opcode (if any).
// On kMarkvNoneOfTheAbove, or when there is no such codec, the global codec
// (keyed by SpvOpNop) is used.
// Returns SPV_UNSUPPORTED if the global codec also yields
// kMarkvNoneOfTheAbove.
spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
    uint32_t* opcode, uint32_t* num_operands) {
  uint64_t decoded_value = 0;
  bool have_word = false;

  // Preferred: Markov chain codec conditioned on the previous opcode.
  auto* markov_codec =
      model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
  if (markov_codec) {
    if (!markov_codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) {
      return vstate_.diag(SPV_ERROR_INTERNAL)
          << "Failed to decode opcode_and_num_operands, previous opcode is "
          << spvOpcodeString(GetPrevOpcode());
    }
    have_word = decoded_value != kMarkvNoneOfTheAbove;
  }

  if (!have_word) {
    // Base-rate codec.
    auto* global_codec =
        model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
    assert(global_codec);
    decoded_value = 0;
    if (!global_codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) {
      return vstate_.diag(SPV_ERROR_INTERNAL)
          << "Failed to decode opcode_and_num_operands with global codec";
    }
    // Escape from the global table too: the caller must fall back further.
    if (decoded_value == kMarkvNoneOfTheAbove)
      return SPV_UNSUPPORTED;
  }

  *opcode = uint32_t(decoded_value & 0xFFFF);
  *num_operands = uint32_t(decoded_value >> 16);
  return SPV_SUCCESS;
}
|
|
|
|
// Writes |rank| using the Huffman codec for |mtf|, or the codec for
// |fallback_method| if the model has none for |mtf|.
// Ranks below kMtfSmallestRankEncodedByValue are written as Huffman symbols.
// Larger ranks are written as kMtfRankEncodedByValueSignal followed by a
// varint of (rank - kMtfSmallestRankEncodedByValue).
spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
                                                uint64_t fallback_method) {
  const auto* codec = model_->GetMtfHuffmanCodec(mtf);
  if (codec == nullptr) {
    assert(fallback_method != kMtfNone);
    codec = model_->GetMtfHuffmanCodec(fallback_method);
    if (codec == nullptr)
      return vstate_.diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank";
  }

  uint64_t bits = 0;
  size_t num_bits = 0;

  if (rank >= kMtfSmallestRankEncodedByValue) {
    // Large rank: escape symbol plus varint payload.
    if (!codec->Encode(kMtfRankEncodedByValueSignal, &bits, &num_bits)) {
      return vstate_.diag(SPV_ERROR_INTERNAL)
          << "Failed to encode kMtfRankEncodedByValueSignal";
    }
    writer_.WriteBits(bits, num_bits);
    writer_.WriteVariableWidthU32(rank - kMtfSmallestRankEncodedByValue,
                                  model_->mtf_rank_chunk_length());
    return SPV_SUCCESS;
  }

  // Small rank: the rank itself is a Huffman symbol.
  if (!codec->Encode(rank, &bits, &num_bits)) {
    return vstate_.diag(SPV_ERROR_INTERNAL)
        << "Failed to encode MTF rank with Huffman";
  }
  writer_.WriteBits(bits, num_bits);
  return SPV_SUCCESS;
}
|
|
|
|
// Reads an MTF rank into |rank|; mirrors MarkvEncoder::EncodeMtfRankHuffman.
// The codec for |mtf| is used, or the codec for |fallback_method| if the model
// has none. kMtfRankEncodedByValueSignal means a varint holding
// (rank - kMtfSmallestRankEncodedByValue) follows.
spv_result_t MarkvDecoder::DecodeMtfRankHuffman(
    uint64_t mtf, uint32_t fallback_method, uint32_t* rank) {
  const auto* codec = model_->GetMtfHuffmanCodec(mtf);
  if (codec == nullptr) {
    assert(fallback_method != kMtfNone);
    codec = model_->GetMtfHuffmanCodec(fallback_method);
    if (codec == nullptr)
      return vstate_.diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
  }

  uint32_t decoded_value = 0;
  if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) {
    return vstate_.diag(SPV_ERROR_INTERNAL)
        << "Failed to decode MTF rank with Huffman";
  }

  if (decoded_value != kMtfRankEncodedByValueSignal) {
    // Small rank: the Huffman symbol is the rank itself.
    assert(decoded_value < kMtfSmallestRankEncodedByValue);
    *rank = decoded_value;
    return SPV_SUCCESS;
  }

  // Large rank: read the offset from kMtfSmallestRankEncodedByValue.
  if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length())) {
    return vstate_.diag(SPV_ERROR_INTERNAL)
        << "Failed to decode MTF rank with varint";
  }
  *rank += kMtfSmallestRankEncodedByValue;
  return SPV_SUCCESS;
}
|
|
|
|
// Writes |id| as (descriptor Huffman code, rank in that descriptor's MTF).
// Returns SPV_UNSUPPORTED if there is no descriptor codec for the current
// (opcode, operand index). It also returns SPV_UNSUPPORTED, after writing
// kMarkvNoneOfTheAbove, when |id| has no descriptor or the descriptor is not
// in the table. In both cases the caller must use a fallback method.
spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) {
  auto* codec = model_->GetIdDescriptorHuffmanCodec(inst_.opcode,
                                                    operand_index_);
  if (codec == nullptr)
    return SPV_UNSUPPORTED;

  uint64_t bits = 0;
  size_t num_bits = 0;

  const uint32_t descriptor = id_descriptors_.GetDescriptor(id);
  const bool descriptor_coded =
      descriptor && codec->Encode(descriptor, &bits, &num_bits);

  if (!descriptor_coded) {
    // No usable descriptor: signal the decoder to use the fallback method.
    if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) {
      return vstate_.diag(SPV_ERROR_INTERNAL)
          << "Descriptor Huffman table for "
          << spvOpcodeString(SpvOp(inst_.opcode))
          << " operand index " << operand_index_
          << " is missing kMarkvNoneOfTheAbove";
    }
    writer_.WriteBits(bits, num_bits);
    return SPV_UNSUPPORTED;
  }

  // Descriptor written; now the id's rank in the descriptor's sequence.
  writer_.WriteBits(bits, num_bits);
  return EncodeExistingId(GetMtfIdDescriptor(descriptor), id);
}
|
|
|
|
// Decoder counterpart of MarkvEncoder::EncodeIdWithDescriptor.
//
// Returns SPV_UNSUPPORTED if the current operand slot has no descriptor codec,
// or if the stream carries the kMarkvNoneOfTheAbove marker (the id was coded
// with a fallback method). Otherwise reads the descriptor and decodes |id| from
// the descriptor's MTF sequence.
spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
  auto* const descriptor_codec =
      model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
  if (descriptor_codec == nullptr)
    return SPV_UNSUPPORTED;

  uint64_t symbol = 0;
  if (!descriptor_codec->DecodeFromStream(GetReadBitCallback(), &symbol)) {
    return vstate_.diag(SPV_ERROR_INTERNAL)
           << "Failed to decode descriptor with Huffman";
  }

  if (symbol == kMarkvNoneOfTheAbove)
    return SPV_UNSUPPORTED;

  // Any other symbol is a (non-zero, 32-bit) descriptor whose MTF holds the id.
  const uint32_t descriptor = static_cast<uint32_t>(symbol);
  assert(descriptor == symbol);
  assert(descriptor);

  return DecodeExistingId(GetMtfIdDescriptor(descriptor), id);
}
|
|
|
|
// Writes the rank of |id| within the non-empty MTF sequence |mtf|.
// Nothing is written when the sequence holds a single element, because the
// decoder can infer the id in that case.
spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) {
  const auto sequence_size = multi_mtf_.GetSize(mtf);
  assert(sequence_size > 0);

  if (sequence_size == 1)
    return SPV_SUCCESS;

  uint32_t rank = 0;
  if (multi_mtf_.RankFromValue(mtf, id, &rank))
    return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank);

  return vstate_.diag(SPV_ERROR_INTERNAL)
         << "Id is not in the MTF sequence";
}
|
|
|
|
// Reads an id that must already be present in the non-empty MTF sequence |mtf|
// and stores it in |id|. A single-element sequence needs no bits: rank 1 is
// implied (mirrors MarkvEncoder::EncodeExistingId).
spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
  const auto sequence_size = multi_mtf_.GetSize(mtf);
  assert(sequence_size > 0);
  *id = 0;

  uint32_t rank = 0;
  if (sequence_size != 1) {
    const spv_result_t rank_status =
        DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
    if (rank_status != SPV_SUCCESS)
      return rank_status;
  } else {
    rank = 1;
  }

  assert(rank);
  if (multi_mtf_.ValueFromRank(mtf, rank, id))
    return SPV_SUCCESS;

  return vstate_.diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
}
|
|
|
|
// Encodes an operand that references an id (not a result id or type id).
//
// Methods are tried in order:
//  1. For OpAccessChain indices (operand index >= 3): a 1-bit flag, plus the
//     presumed index if the id has one.
//  2. The id-descriptor codec (EncodeIdWithDescriptor).
//  3. The rule-based MTF, when the operand cannot be forward declared.
//  4. kMtfAll (or the rule-based MTF), where rank 0 marks the first occurrence
//     of a forward declared id.
spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) {
  // TODO(atgoo@github.com) This might not be needed as EncodeIdWithDescriptor
  // can handle SpvOpAccessChain indices if enough statistics is collected.
  if (inst_.opcode == SpvOpAccessChain && operand_index_ >= 3) {
    const auto presumed = id_to_presumed_index_.find(id);
    const bool has_presumed_index = presumed != id_to_presumed_index_.end();
    writer_.WriteBits(has_presumed_index ? 1 : 0, 1);
    if (has_presumed_index) {
      writer_.WriteFixedWidth(presumed->second, kMarkvMaxPresumedAccessIndex);
      return SPV_SUCCESS;
    }
  }

  // SPV_UNSUPPORTED means "use another method"; anything else is final.
  const spv_result_t descriptor_status = EncodeIdWithDescriptor(id);
  if (descriptor_status != SPV_UNSUPPORTED)
    return descriptor_status;

  const uint64_t rule_mtf = GetRuleBasedMtf();
  const bool can_forward_declare =
      spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
          operand_index_);

  if (rule_mtf != kMtfNone && !can_forward_declare) {
    assert(multi_mtf_.HasValue(kMtfAll, id));
    return EncodeExistingId(rule_mtf, id);
  }

  const uint64_t mtf = (rule_mtf == kMtfNone) ? kMtfAll : rule_mtf;

  uint32_t rank = 0;
  if (!multi_mtf_.RankFromValue(mtf, id, &rank)) {
    // First occurrence of a forward declared id: register it everywhere it
    // will be looked up later and emit rank 0 (the decoder then issues a new
    // id).
    multi_mtf_.Insert(kMtfAll, id);
    multi_mtf_.Insert(kMtfForwardDeclared, id);
    if (mtf != kMtfAll)
      multi_mtf_.Insert(mtf, id);
    rank = 0;
  }

  return EncodeMtfRankHuffman(rank, mtf, kMtfAll);
}
|
|
|
|
// Decodes an operand referencing an id into |id|. Mirrors
// MarkvEncoder::EncodeRefId. Methods are tried in order:
//  1. For OpAccessChain indices: the presumed-index shortcut.
//  2. The id-descriptor codec.
//  3. The rule-based MTF.
//  4. kMtfAll, where rank 0 introduces a new forward declared id taken from
//     the current id bound.
//
// Values read from the MARK-V stream are external input, so inconsistencies
// are reported through vstate_.diag instead of being asserted on.
spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
  if (inst_.opcode == SpvOpAccessChain && operand_index_ >= 3) {
    uint64_t use_presumed_index_technique = 0;
    if (!reader_.ReadBits(&use_presumed_index_technique, 1))
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to read use_presumed_index_technique flag";

    if (use_presumed_index_technique) {
      uint64_t value = 0;
      if (!reader_.ReadFixedWidth(&value, kMarkvMaxPresumedAccessIndex))
        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read presumed_index";

      const uint32_t presumed_index = static_cast<uint32_t>(value);

      const auto it = presumed_index_to_id_.find(presumed_index);
      if (it == presumed_index_to_id_.end()) {
        // |presumed_index| was read from the stream, so an unknown value can
        // simply mean corrupt input: report it rather than abort.
        return vstate_.diag(SPV_ERROR_INTERNAL)
               << "Presumed index id not found";
      }

      *id = it->second;
      return SPV_SUCCESS;
    }
  }

  {
    const spv_result_t result = DecodeIdWithDescriptor(id);
    if (result != SPV_UNSUPPORTED)
      return result;
  }

  uint64_t mtf = GetRuleBasedMtf();
  const bool can_forward_declare =
      spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
          operand_index_);

  if (mtf != kMtfNone && !can_forward_declare)
    return DecodeExistingId(mtf, id);

  if (mtf == kMtfNone)
    mtf = kMtfAll;

  *id = 0;

  uint32_t rank = 0;
  {
    const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
    if (result != SPV_SUCCESS)
      return result;
  }

  if (rank == 0) {
    // First occurrence of a forward declared id: issue the next free id.
    *id = vstate_.getIdBound();
    vstate_.setIdBound(*id + 1);
    multi_mtf_.Insert(kMtfAll, *id);
    multi_mtf_.Insert(kMtfForwardDeclared, *id);
    if (mtf != kMtfAll)
      multi_mtf_.Insert(mtf, *id);
  } else if (!multi_mtf_.ValueFromRank(mtf, rank, id)) {
    return vstate_.diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
  }

  assert(*id);
  return SPV_SUCCESS;
}
|
|
|
|
// Encodes inst_.type_id for the current instruction.
//
// For OpFunctionParameter nothing is written. The type is consumed from
// remaining_function_parameter_types_ (presumably filled from the enclosing
// function's type; the decoder reads the same list). Otherwise the
// id-descriptor codec is tried first. If that is unsupported, the rule-based
// MTF is used, defaulting to kMtfTypeNonFunction.
spv_result_t MarkvEncoder::EncodeTypeId() {
  if (inst_.opcode == SpvOpFunctionParameter) {
    assert(!remaining_function_parameter_types_.empty());
    assert(inst_.type_id == remaining_function_parameter_types_.front());
    remaining_function_parameter_types_.pop_front();
    return SPV_SUCCESS;
  }

  const spv_result_t descriptor_status = EncodeIdWithDescriptor(inst_.type_id);
  if (descriptor_status != SPV_UNSUPPORTED)
    return descriptor_status;

  const uint64_t rule_mtf = GetRuleBasedMtf();
  assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
      operand_index_));

  if (rule_mtf != kMtfNone)
    return EncodeExistingId(rule_mtf, inst_.type_id);

  // Function types should have been handled by GetRuleBasedMtf.
  assert(inst_.opcode != SpvOpFunction);
  return EncodeExistingId(kMtfTypeNonFunction, inst_.type_id);
}
|
|
|
|
// Decodes inst_.type_id for the current instruction (mirrors
// MarkvEncoder::EncodeTypeId).
//
// For OpFunctionParameter the type is taken from
// remaining_function_parameter_types_ without reading the stream. Otherwise
// the id-descriptor codec is tried first. If that is unsupported, the
// rule-based MTF is used, defaulting to kMtfTypeNonFunction.
spv_result_t MarkvDecoder::DecodeTypeId() {
  if (inst_.opcode == SpvOpFunctionParameter) {
    // This runs before the instruction is validated, so a malformed stream
    // (e.g. a stray OpFunctionParameter) can reach here with no pending
    // parameter type; calling front() on an empty list would be undefined.
    if (remaining_function_parameter_types_.empty())
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
             << "OpFunctionParameter has no matching function parameter "
                "type";

    inst_.type_id = remaining_function_parameter_types_.front();
    remaining_function_parameter_types_.pop_front();
    return SPV_SUCCESS;
  }

  {
    const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
    if (result != SPV_UNSUPPORTED)
      return result;
  }

  uint64_t mtf = GetRuleBasedMtf();
  assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
      operand_index_));

  if (mtf == kMtfNone) {
    mtf = kMtfTypeNonFunction;
    // Function types should have been handled by GetRuleBasedMtf.
    assert(inst_.opcode != SpvOpFunction);
  }

  return DecodeExistingId(mtf, &inst_.type_id);
}
|
|
|
|
// Encodes inst_.result_id.
//
// While ids are still waiting in kMtfForwardDeclared, a 1-bit flag tells
// whether this result id is one of them. If set, the flag is followed by the
// id's rank in that sequence, and the id is removed from it. When no ids are
// pending, nothing is written, because the decoder can infer that too. An id
// that was not forward declared is added to kMtfAll; a forward declared one
// was already inserted there when first referenced.
spv_result_t MarkvEncoder::EncodeResultId() {
  uint32_t rank = 0;

  if (multi_mtf_.GetSize(kMtfForwardDeclared) != 0) {
    const bool was_forward_declared = multi_mtf_.RankFromValue(
        kMtfForwardDeclared, inst_.result_id, &rank);

    if (!was_forward_declared) {
      rank = 0;
      writer_.WriteBits(0, 1);
    } else {
      if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) {
        return vstate_.diag(SPV_ERROR_INTERNAL)
               << "Failed to remove id from kMtfForwardDeclared";
      }
      writer_.WriteBits(1, 1);
      writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
    }
  }

  if (rank == 0)
    multi_mtf_.Insert(kMtfAll, inst_.result_id);

  return SPV_SUCCESS;
}
|
|
|
|
// Decodes inst_.result_id (mirrors MarkvEncoder::EncodeResultId).
//
// When kMtfForwardDeclared is non-empty, reads a 1-bit flag, plus a rank if
// the flag is set. A non-zero rank recovers a previously forward declared id
// and removes it from kMtfForwardDeclared. Otherwise, a fresh id is issued from
// the validation state's id bound and added to kMtfAll.
spv_result_t MarkvDecoder::DecodeResultId() {
  uint32_t rank = 0;

  const bool may_be_forward_declared =
      multi_mtf_.GetSize(kMtfForwardDeclared) != 0;

  if (may_be_forward_declared) {
    uint64_t flag = 0;
    if (!reader_.ReadBits(&flag, 1)) {
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to read id_was_forward_declared flag";
    }

    if (flag != 0 &&
        !reader_.ReadVariableWidthU32(&rank,
                                      model_->mtf_rank_chunk_length())) {
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to read MTF rank of forward declared id";
    }

    // rank stays 0 when the flag is clear.
    if (rank != 0) {
      if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
                                    &inst_.result_id)) {
        return vstate_.diag(SPV_ERROR_INTERNAL)
               << "Forward declared MTF rank is out of bounds";
      }

      if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) {
        return vstate_.diag(SPV_ERROR_INTERNAL)
               << "Failed to remove id from kMtfForwardDeclared";
      }
    }
  }

  if (inst_.result_id == 0) {
    // Not forward declared: take the next id.
    inst_.result_id = vstate_.getIdBound();
    vstate_.setIdBound(inst_.result_id + 1);
  }

  if (rank == 0)
    multi_mtf_.Insert(kMtfAll, inst_.result_id);

  return SPV_SUCCESS;
}
|
|
|
|
// Writes the numeric literal described by |operand|, taken from
// |instruction|, to the MARK-V stream.
//
// Literals of at most 32 bits go through EncodeNonIdWord. 64-bit literals are
// assembled from two consecutive words (low word first). They are written with
// a coding chosen by operand.number_kind: variable width for unsigned/signed
// integers, unencoded for floating point. Wider literals and unknown number
// kinds are rejected with an error.
spv_result_t MarkvEncoder::EncodeLiteralNumber(
    const Instruction& instruction, const spv_parsed_operand_t& operand) {
  if (operand.number_bit_width <= 32)
    return EncodeNonIdWord(instruction.word(operand.offset));

  // Only one- and two-word literals fit the codings below; reject anything
  // wider explicitly so it cannot be silently truncated.
  if (operand.number_bit_width > 64)
    return vstate_.diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";

  const uint64_t word =
      static_cast<uint64_t>(instruction.word(operand.offset)) |
      (static_cast<uint64_t>(instruction.word(operand.offset + 1)) << 32);

  if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
    writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
  } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
    int64_t val = 0;
    std::memcpy(&val, &word, sizeof(val));
    writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
                                  model_->s64_block_exponent());
  } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
    writer_.WriteUnencoded(word);
  } else {
    return vstate_.diag(SPV_ERROR_INTERNAL) << "Unsupported number kind";
  }

  return SPV_SUCCESS;
}
|
|
|
|
// Reads a numeric literal described by |operand| and appends its words to
// inst_words_ (mirrors MarkvEncoder::EncodeLiteralNumber).
//
// Literals of at most 32 bits are one word decoded by DecodeNonIdWord. 64-bit
// literals are read with the coding selected by operand.number_kind and
// appended as two words, low word first.
spv_result_t MarkvDecoder::DecodeLiteralNumber(
    const spv_parsed_operand_t& operand) {
  if (operand.number_bit_width <= 32) {
    uint32_t word = 0;
    const spv_result_t result = DecodeNonIdWord(&word);
    if (result != SPV_SUCCESS)
      return result;
    inst_words_.push_back(word);
    return SPV_SUCCESS;
  }

  // The width comes from a type declared in the decoded (external) stream, so
  // widths beyond two words must be rejected rather than asserted on.
  if (operand.number_bit_width > 64)
    return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Unsupported bit length";

  uint64_t word = 0;
  if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
    if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to read literal U64";
  } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
    int64_t val = 0;
    if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
                                      model_->s64_block_exponent()))
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to read literal S64";
    std::memcpy(&word, &val, sizeof(word));
  } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
    if (!reader_.ReadUnencoded(&word))
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to read literal F64";
  } else {
    return vstate_.diag(SPV_ERROR_INTERNAL) << "Unsupported number kind";
  }

  inst_words_.push_back(static_cast<uint32_t>(word));
  inst_words_.push_back(static_cast<uint32_t>(word >> 32));
  return SPV_SUCCESS;
}
|
|
|
|
void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) {
|
|
const size_t num_bits_to_next_byte =
|
|
GetNumBitsToNextByte(writer_.GetNumBits());
|
|
if (num_bits_to_next_byte == 0 ||
|
|
num_bits_to_next_byte > byte_break_if_less_than)
|
|
return;
|
|
|
|
if (logger_) {
|
|
logger_->AppendWhitespaces(kCommentNumWhitespaces);
|
|
logger_->AppendText("<byte break>");
|
|
}
|
|
|
|
writer_.WriteBits(0, num_bits_to_next_byte);
|
|
}
|
|
|
|
bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
|
|
const size_t num_bits_to_next_byte =
|
|
GetNumBitsToNextByte(reader_.GetNumReadBits());
|
|
if (num_bits_to_next_byte == 0 ||
|
|
num_bits_to_next_byte > byte_break_if_less_than)
|
|
return true;
|
|
|
|
|
|
uint64_t bits = 0;
|
|
if (!reader_.ReadBits(&bits, num_bits_to_next_byte))
|
|
return false;
|
|
|
|
assert(bits == 0);
|
|
if (bits != 0)
|
|
return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
// Encodes one parsed SPIR-V instruction into the MARK-V stream.
//
// Steps:
//  1. Update the validation state with |inst|.
//  2. Write the opcode, plus the operand count when the opcode does not fix
//     it, if EncodeOpcodeAndNumOperands could not code them.
//  3. Write each operand with an operand-type-specific coding.
//  4. Optionally pad to a byte boundary.
//  5. Update per-instruction bookkeeping.
spv_result_t MarkvEncoder::EncodeInstruction(
    const spv_parsed_instruction_t& inst) {
  SpvOp opcode = SpvOp(inst.opcode);
  inst_ = inst;

  const spv_result_t validation_result = UpdateValidationState(inst);
  if (validation_result != SPV_SUCCESS)
    return validation_result;

  const Instruction& instruction = vstate_.ordered_instructions().back();
  const auto& operands = instruction.operands();

  LogDisassemblyInstruction();

  const spv_result_t opcode_encoding_result =
      EncodeOpcodeAndNumOperands(opcode, inst.num_operands);
  if (opcode_encoding_result < 0)
    return opcode_encoding_result;

  if (opcode_encoding_result != SPV_SUCCESS) {
    // Fallback encoding for opcode and num_operands.
    writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length());

    if (!OpcodeHasFixedNumberOfOperands(opcode)) {
      // If the opcode has a variable number of operands, encode the number of
      // operands with the instruction.
      if (logger_)
        logger_->AppendWhitespaces(kCommentNumWhitespaces);

      writer_.WriteVariableWidthU16(inst.num_operands,
                                    model_->num_operands_chunk_length());
    }
  }

  // Write operands.
  for (operand_index_ = 0; operand_index_ < operands.size(); ++operand_index_) {
    operand_ = operands[operand_index_];

    if (logger_) {
      logger_->AppendWhitespaces(kCommentNumWhitespaces);
      logger_->AppendText("<");
      logger_->AppendText(spvOperandTypeStr(operand_.type));
      logger_->AppendText(">");
    }

    switch (operand_.type) {
      case SPV_OPERAND_TYPE_RESULT_ID:
      case SPV_OPERAND_TYPE_TYPE_ID:
      case SPV_OPERAND_TYPE_ID:
      case SPV_OPERAND_TYPE_OPTIONAL_ID:
      case SPV_OPERAND_TYPE_SCOPE_ID:
      case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
        const uint32_t id = instruction.word(operand_.offset);
        spv_result_t result = SPV_SUCCESS;
        if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID)
          result = EncodeTypeId();
        else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID)
          result = EncodeResultId();
        else
          result = EncodeRefId(id);
        if (result != SPV_SUCCESS)
          return result;

        multi_mtf_.Promote(id);
        break;
      }

      case SPV_OPERAND_TYPE_LITERAL_INTEGER: {
        const spv_result_t result =
            EncodeNonIdWord(instruction.word(operand_.offset));
        if (result != SPV_SUCCESS)
          return result;
        break;
      }

      case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
        const spv_result_t result = EncodeLiteralNumber(instruction, operand_);
        if (result != SPV_SUCCESS)
          return result;
        break;
      }

      case SPV_OPERAND_TYPE_LITERAL_STRING: {
        const char* src = reinterpret_cast<const char*>(
            &instruction.words()[operand_.offset]);
        const size_t max_length = static_cast<size_t>(operand_.num_words) * 4;

        // Find the terminator before anything reads the string, so neither
        // the Huffman path nor the raw path can run past the operand words.
        const size_t length = spv_strnlen_s(src, max_length);
        if (length == max_length)
          return vstate_.diag(SPV_ERROR_INVALID_BINARY)
                 << "Failed to find terminal character of literal string";

        auto* codec = model_->GetLiteralStringHuffmanCodec(opcode);
        if (codec) {
          uint64_t bits = 0;
          size_t num_bits = 0;
          const std::string str(src, length);
          if (codec->Encode(str, &bits, &num_bits)) {
            writer_.WriteBits(bits, num_bits);
            break;
          }

          // String not in the table: write the sentinel, then the raw string.
          if (!codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits))
            return vstate_.diag(SPV_ERROR_INTERNAL)
                   << "Literal string Huffman table for "
                   << spvOpcodeString(opcode)
                   << " is missing kMarkvNoneOfTheAbove";
          writer_.WriteBits(bits, num_bits);
        }

        // Raw fallback: every character including the terminating '\0'.
        for (size_t i = 0; i < length + 1; ++i)
          writer_.WriteUnencoded(src[i]);
        break;
      }

      default: {
        for (size_t i = 0; i < operand_.num_words; ++i) {
          const uint32_t word = instruction.word(operand_.offset + i);
          const spv_result_t result = EncodeNonIdWord(word);
          if (result != SPV_SUCCESS)
            return result;
        }
        break;
      }
    }
  }

  AddByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte);

  if (logger_) {
    logger_->NewLine();
    logger_->NewLine();
  }

  ProcessCurInstruction();
  instructions_.push_back(&instruction);

  return SPV_SUCCESS;
}
|
|
|
|
// Decodes the whole MARK-V stream held by reader_ into a SPIR-V binary.
//
// Steps:
//  1. Read and check the MARK-V header.
//  2. Lay out the five SPIR-V header words: magic, version, generator, id
//     bound (filled in at the end) and schema (0).
//  3. Decode and validate instructions until the stated bit length is
//     consumed.
// On success the words are moved into |spirv_binary|.
spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
  // Header fields are read in a fixed order; stop at the first failed read.
  if (!reader_.ReadUnencoded(&header_.magic_number) ||
      !reader_.ReadUnencoded(&header_.markv_version) ||
      !reader_.ReadUnencoded(&header_.markv_model) ||
      !reader_.ReadUnencoded(&header_.markv_length_in_bits) ||
      !reader_.ReadUnencoded(&header_.spirv_version) ||
      !reader_.ReadUnencoded(&header_.spirv_generator)) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "Unable to read MARK-V header";
  }

  if (header_.markv_length_in_bits == 0) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "Header markv_length_in_bits field is zero";
  }

  if (header_.magic_number != kMarkvMagicNumber) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary has incorrect magic number";
  }

  // TODO(atgoo@github.com): Print version strings.
  if (header_.markv_version != GetMarkvVersion()) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary and the codec have different versions";
  }

  spirv_.reserve(header_.markv_length_in_bits / 2);  // Heuristic.
  spirv_.resize(5, 0);
  spirv_[0] = kSpirvMagicNumber;
  spirv_[1] = header_.spirv_version;
  spirv_[2] = header_.spirv_generator;

  while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
    inst_ = {};

    spv_result_t status = DecodeInstruction();
    if (status == SPV_SUCCESS)
      status = UpdateValidationState(inst_);
    if (status != SPV_SUCCESS)
      return status;

    instructions_.push_back(&vstate_.ordered_instructions().back());
  }

  const bool consumed_exactly =
      reader_.GetNumReadBits() == header_.markv_length_in_bits &&
      reader_.OnlyZeroesLeft();
  if (!consumed_exactly) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary has wrong stated bit length "
           << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
  }

  // Every id has been issued by now, so the validation state holds the final
  // id bound.
  spirv_[3] = vstate_.getIdBound();

  *spirv_binary = std::move(spirv_);
  return SPV_SUCCESS;
}
|
|
|
|
// TODO(atgoo@github.com): The implementation borrows heavily from
|
|
// Parser::parseOperand.
|
|
// Consider coupling them together in some way once MARK-V codec is more mature.
|
|
// For now it's better to keep the code independent for experimentation
|
|
// purposes.
|
|
spv_result_t MarkvDecoder::DecodeOperand(
|
|
size_t operand_offset,
|
|
const spv_operand_type_t type,
|
|
spv_operand_pattern_t* expected_operands) {
|
|
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
|
|
|
memset(&operand_, 0, sizeof(operand_));
|
|
|
|
assert((operand_offset >> 16) == 0);
|
|
operand_.offset = static_cast<uint16_t>(operand_offset);
|
|
operand_.type = type;
|
|
|
|
// Set default values, may be updated later.
|
|
operand_.number_kind = SPV_NUMBER_NONE;
|
|
operand_.number_bit_width = 0;
|
|
|
|
const size_t first_word_index = inst_words_.size();
|
|
|
|
switch (type) {
|
|
case SPV_OPERAND_TYPE_RESULT_ID: {
|
|
const spv_result_t result = DecodeResultId();
|
|
if (result != SPV_SUCCESS)
|
|
return result;
|
|
|
|
inst_words_.push_back(inst_.result_id);
|
|
vstate_.setIdBound(std::max(vstate_.getIdBound(), inst_.result_id + 1));
|
|
multi_mtf_.Promote(inst_.result_id);
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_TYPE_ID: {
|
|
const spv_result_t result = DecodeTypeId();
|
|
if (result != SPV_SUCCESS)
|
|
return result;
|
|
|
|
inst_words_.push_back(inst_.type_id);
|
|
vstate_.setIdBound(std::max(vstate_.getIdBound(), inst_.type_id + 1));
|
|
multi_mtf_.Promote(inst_.type_id);
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_ID:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_ID:
|
|
case SPV_OPERAND_TYPE_SCOPE_ID:
|
|
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
|
|
uint32_t id = 0;
|
|
const spv_result_t result = DecodeRefId(&id);
|
|
if (result != SPV_SUCCESS)
|
|
return result;
|
|
|
|
if (id == 0)
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
|
|
|
|
if (type == SPV_OPERAND_TYPE_ID ||
|
|
type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
|
|
|
|
operand_.type = SPV_OPERAND_TYPE_ID;
|
|
|
|
if (opcode == SpvOpExtInst && operand_.offset == 3) {
|
|
// The current word is the extended instruction set id.
|
|
// Set the extended instruction set type for the current
|
|
// instruction.
|
|
auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
|
|
if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_ID)
|
|
<< "OpExtInst set id " << id
|
|
<< " does not reference an OpExtInstImport result Id";
|
|
}
|
|
inst_.ext_inst_type = ext_inst_type_iter->second;
|
|
}
|
|
}
|
|
|
|
inst_words_.push_back(id);
|
|
vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1));
|
|
multi_mtf_.Promote(id);
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
|
|
uint32_t word = 0;
|
|
const spv_result_t result = DecodeNonIdWord(&word);
|
|
if (result != SPV_SUCCESS)
|
|
return result;
|
|
|
|
inst_words_.push_back(word);
|
|
|
|
assert(SpvOpExtInst == opcode);
|
|
assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
|
|
spv_ext_inst_desc ext_inst;
|
|
if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid extended instruction number: " << word;
|
|
spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
|
|
// These are regular single-word literal integer operands.
|
|
// Post-parsing validation should check the range of the parsed value.
|
|
operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
|
|
// It turns out they are always unsigned integers!
|
|
operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
|
|
operand_.number_bit_width = 32;
|
|
|
|
uint32_t word = 0;
|
|
const spv_result_t result = DecodeNonIdWord(&word);
|
|
if (result != SPV_SUCCESS)
|
|
return result;
|
|
|
|
inst_words_.push_back(word);
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
|
|
operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
|
|
if (opcode == SpvOpSwitch) {
|
|
// The literal operands have the same type as the value
|
|
// referenced by the selector Id.
|
|
const uint32_t selector_id = inst_words_.at(1);
|
|
const auto type_id_iter = id_to_type_id_.find(selector_id);
|
|
if (type_id_iter == id_to_type_id_.end() ||
|
|
type_id_iter->second == 0) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid OpSwitch: selector id " << selector_id
|
|
<< " has no type";
|
|
}
|
|
uint32_t type_id = type_id_iter->second;
|
|
|
|
if (selector_id == type_id) {
|
|
// Recall that by convention, a result ID that is a type definition
|
|
// maps to itself.
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid OpSwitch: selector id " << selector_id
|
|
<< " is a type, not a value";
|
|
}
|
|
if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
|
|
return error;
|
|
if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
|
|
operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid OpSwitch: selector id " << selector_id
|
|
<< " is not a scalar integer";
|
|
}
|
|
} else {
|
|
assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
|
|
// The literal number type is determined by the type Id for the
|
|
// constant.
|
|
assert(inst_.type_id);
|
|
if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
|
|
return error;
|
|
}
|
|
|
|
if (auto error = DecodeLiteralNumber(operand_))
|
|
return error;
|
|
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_LITERAL_STRING:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
|
|
operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
|
|
std::vector<char> str;
|
|
auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
|
|
|
|
if (codec) {
|
|
std::string decoded_string;
|
|
const bool huffman_result =
|
|
codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
|
|
assert(huffman_result);
|
|
if (!huffman_result)
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Failed to read literal string";
|
|
|
|
if (decoded_string != "kMarkvNoneOfTheAbove") {
|
|
std::copy(decoded_string.begin(), decoded_string.end(),
|
|
std::back_inserter(str));
|
|
str.push_back('\0');
|
|
}
|
|
}
|
|
|
|
// The loop is expected to terminate once we encounter '\0' or exhaust
|
|
// the bit stream.
|
|
if (str.empty()) {
|
|
while (true) {
|
|
char ch = 0;
|
|
if (!reader_.ReadUnencoded(&ch))
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Failed to read literal string";
|
|
|
|
str.push_back(ch);
|
|
|
|
if (ch == '\0')
|
|
break;
|
|
}
|
|
}
|
|
|
|
while (str.size() % 4 != 0)
|
|
str.push_back('\0');
|
|
|
|
inst_words_.resize(inst_words_.size() + str.size() / 4);
|
|
std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
|
|
|
|
if (SpvOpExtInstImport == opcode) {
|
|
// Record the extended instruction type for the ID for this import.
|
|
// There is only one string literal argument to OpExtInstImport,
|
|
// so it's sufficient to guard this just on the opcode.
|
|
const spv_ext_inst_type_t ext_inst_type =
|
|
spvExtInstImportTypeGet(str.data());
|
|
if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid extended instruction import '" << str.data() << "'";
|
|
}
|
|
// We must have parsed a valid result ID. It's a condition
|
|
// of the grammar, and we only accept non-zero result Ids.
|
|
assert(inst_.result_id);
|
|
const bool inserted = import_id_to_ext_inst_type_.emplace(
|
|
inst_.result_id, ext_inst_type).second;
|
|
(void)inserted;
|
|
assert(inserted);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_CAPABILITY:
|
|
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
|
|
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
|
|
case SPV_OPERAND_TYPE_MEMORY_MODEL:
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODE:
|
|
case SPV_OPERAND_TYPE_STORAGE_CLASS:
|
|
case SPV_OPERAND_TYPE_DIMENSIONALITY:
|
|
case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
|
|
case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
|
|
case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
|
|
case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
|
|
case SPV_OPERAND_TYPE_LINKAGE_TYPE:
|
|
case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
|
|
case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
|
|
case SPV_OPERAND_TYPE_DECORATION:
|
|
case SPV_OPERAND_TYPE_BUILT_IN:
|
|
case SPV_OPERAND_TYPE_GROUP_OPERATION:
|
|
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
|
|
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
|
|
// A single word that is a plain enum value.
|
|
uint32_t word = 0;
|
|
const spv_result_t result = DecodeNonIdWord(&word);
|
|
if (result != SPV_SUCCESS)
|
|
return result;
|
|
|
|
inst_words_.push_back(word);
|
|
|
|
// Map an optional operand type to its corresponding concrete type.
|
|
if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
|
|
operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
|
|
|
|
spv_operand_desc entry;
|
|
if (grammar_.lookupOperand(type, word, &entry)) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid "
|
|
<< spvOperandTypeStr(operand_.type)
|
|
<< " operand: " << word;
|
|
}
|
|
|
|
// Prepare to accept operands to this operand, if needed.
|
|
spvPushOperandTypes(entry->operandTypes, expected_operands);
|
|
break;
|
|
}
|
|
|
|
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
|
|
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
|
|
case SPV_OPERAND_TYPE_LOOP_CONTROL:
|
|
case SPV_OPERAND_TYPE_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
|
case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
|
|
// This operand is a mask.
|
|
uint32_t word = 0;
|
|
const spv_result_t result = DecodeNonIdWord(&word);
|
|
if (result != SPV_SUCCESS)
|
|
return result;
|
|
|
|
inst_words_.push_back(word);
|
|
|
|
// Map an optional operand type to its corresponding concrete type.
|
|
if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
|
|
operand_.type = SPV_OPERAND_TYPE_IMAGE;
|
|
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
|
|
operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
|
|
|
|
// Check validity of set mask bits. Also prepare for operands for those
|
|
// masks if they have any. To get operand order correct, scan from
|
|
// MSB to LSB since we can only prepend operands to a pattern.
|
|
// The only case in the grammar where you have more than one mask bit
|
|
// having an operand is for image operands. See SPIR-V 3.14 Image
|
|
// Operands.
|
|
uint32_t remaining_word = word;
|
|
for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
|
|
if (remaining_word & mask) {
|
|
spv_operand_desc entry;
|
|
if (grammar_.lookupOperand(type, mask, &entry)) {
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Invalid " << spvOperandTypeStr(operand_.type)
|
|
<< " operand: " << word << " has invalid mask component "
|
|
<< mask;
|
|
}
|
|
remaining_word ^= mask;
|
|
spvPushOperandTypes(entry->operandTypes, expected_operands);
|
|
}
|
|
}
|
|
if (word == 0) {
|
|
// An all-zeroes mask *might* also be valid.
|
|
spv_operand_desc entry;
|
|
if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
|
|
// Prepare for its operands, if any.
|
|
spvPushOperandTypes(entry->operandTypes, expected_operands);
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
default:
|
|
return vstate_.diag(SPV_ERROR_INVALID_BINARY)
|
|
<< "Internal error: Unhandled operand type: " << type;
|
|
}
|
|
|
|
operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
|
|
|
|
assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(operand_.type));
|
|
assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(operand_.type));
|
|
|
|
parsed_operands_.push_back(operand_);
|
|
|
|
return SPV_SUCCESS;
|
|
}
|
|
|
|
// Decodes one instruction from the MARK-V stream: reads the opcode and the
// operand count, decodes every operand, assembles the SPIR-V words in
// inst_words_ / inst_, and appends them to spirv_. Returns SPV_SUCCESS, or an
// error diagnostic when the stream is malformed.
spv_result_t MarkvDecoder::DecodeInstruction() {
  parsed_operands_.clear();
  inst_words_.clear();

  // Opcode/num_words placeholder, the word will be filled in later.
  inst_words_.push_back(0);

  bool num_operands_still_unknown = true;
  {
    uint32_t opcode = 0;
    uint32_t num_operands = 0;

    const spv_result_t opcode_decoding_result =
        DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
    if (opcode_decoding_result < 0)
      return opcode_decoding_result;

    if (opcode_decoding_result == SPV_SUCCESS) {
      // Opcode and operand count were decoded together.
      inst_.num_operands = static_cast<uint16_t>(num_operands);
      num_operands_still_unknown = false;
    } else {
      // Otherwise the opcode is read on its own; the operand count is
      // determined further below.
      if (!reader_.ReadVariableWidthU32(
              &opcode, model_->opcode_chunk_length())) {
        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read opcode of instruction";
      }
    }

    inst_.opcode = static_cast<uint16_t>(opcode);
  }

  const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);

  spv_opcode_desc opcode_desc;
  if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
  }

  // The operand pattern is stored in reverse order: the first operand type of
  // the opcode ends up at the back, so the next expected operand is back().
  spv_operand_pattern_t expected_operands;
  expected_operands.reserve(opcode_desc->numTypes);
  for (auto i = 0; i < opcode_desc->numTypes; i++) {
    expected_operands.push_back(
        opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
  }

  if (num_operands_still_unknown) {
    if (!OpcodeHasFixedNumberOfOperands(opcode)) {
      if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
                                        model_->num_operands_chunk_length()))
        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read num_operands of instruction";
    } else {
      inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
    }
  }

  for (operand_index_ = 0;
       operand_index_ < static_cast<size_t>(inst_.num_operands);
       ++operand_index_) {
    // num_operands comes from the MARK-V stream and may claim more operands
    // than the opcode's pattern allows. Taking an operand from an empty
    // pattern is not permitted, and an assert is compiled out in release
    // builds, so reject the instruction explicitly.
    if (expected_operands.empty()) {
      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
             << "Too many operands for Op" << opcode_desc->name
             << ": num_operands is " << inst_.num_operands;
    }

    const spv_operand_type_t type =
        spvTakeFirstMatchableOperand(&expected_operands);

    const size_t operand_offset = inst_words_.size();

    const spv_result_t decode_result =
        DecodeOperand(operand_offset, type, &expected_operands);

    if (decode_result != SPV_SUCCESS)
      return decode_result;
  }

  // The stream may also claim fewer operands than the opcode requires. Any
  // operand still expected must be optional, otherwise the emitted
  // instruction would be missing required operands.
  if (!expected_operands.empty() &&
      !spvOperandIsOptional(expected_operands.back())) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "Too few operands for Op" << opcode_desc->name
           << ": num_operands is " << inst_.num_operands;
  }

  assert(static_cast<size_t>(inst_.num_operands) == parsed_operands_.size());

  // The word count is stored in 16 bits (see inst_.num_words below and the
  // high half of the first instruction word). Refuse to silently truncate it.
  if (inst_words_.size() > 0xFFFFu) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "Instruction Op" << opcode_desc->name << " has "
           << inst_words_.size() << " words, more than the maximum of "
           << 0xFFFF;
  }

  // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
  // next DecodeInstruction call).
  inst_.words = inst_words_.data();
  inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
  inst_.num_words = static_cast<uint16_t>(inst_words_.size());
  inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));

  std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));

  // One word for opcode/word count plus the words of every operand.
  assert(static_cast<size_t>(inst_.num_words) ==
             std::accumulate(
                 parsed_operands_.begin(), parsed_operands_.end(), size_t{1},
                 [](size_t sum, const spv_parsed_operand_t& operand) {
                   return sum + operand.num_words;
                 }) &&
         "num_words in instruction doesn't correspond to the sum of num_words "
         "in the operands");

  RecordNumberType();
  ProcessCurInstruction();

  if (!ReadToByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte))
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "Failed to read to byte break";

  return SPV_SUCCESS;
}
|
|
|
|
// Copies the numeric description recorded for |type_id| (number kind, bit
// width, and the resulting word count) into |parsed_operand|. Fails when
// |type_id| has no recorded entry or its entry is SPV_NUMBER_NONE.
spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
    spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
  assert(type_id != 0);

  const auto found = type_id_to_number_type_info_.find(type_id);
  const bool is_recorded_type = found != type_id_to_number_type_info_.end();
  if (!is_recorded_type) {
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "Type Id " << type_id << " is not a type";
  }

  const NumberType& number_type = found->second;
  const bool is_scalar_number = number_type.type != SPV_NUMBER_NONE;
  if (!is_scalar_number) {
    // Recorded as a type, but not as a scalar number.
    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
           << "Type Id " << type_id << " is not a scalar numeric type";
  }

  parsed_operand->number_kind = number_type.type;
  parsed_operand->number_bit_width = number_type.bit_width;
  // Whole 32-bit words needed to hold the value (bit width rounded up).
  parsed_operand->num_words =
      static_cast<uint16_t>((number_type.bit_width + 31) / 32);
  return SPV_SUCCESS;
}
|
|
|
|
void MarkvDecoder::RecordNumberType() {
|
|
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
|
if (spvOpcodeGeneratesType(opcode)) {
|
|
NumberType info = {SPV_NUMBER_NONE, 0};
|
|
if (SpvOpTypeInt == opcode) {
|
|
info.bit_width = inst_.words[inst_.operands[1].offset];
|
|
info.type = inst_.words[inst_.operands[2].offset] ?
|
|
SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
|
|
} else if (SpvOpTypeFloat == opcode) {
|
|
info.bit_width = inst_.words[inst_.operands[1].offset];
|
|
info.type = SPV_NUMBER_FLOATING;
|
|
}
|
|
// The *result* Id of a type generating instruction is the type Id.
|
|
type_id_to_number_type_info_[inst_.result_id] = info;
|
|
}
|
|
}
|
|
|
|
// Header callback handed to spvBinaryParse: forwards the module header fields
// to the MarkvEncoder that was passed as |user_data|.
spv_result_t EncodeHeader(
    void* user_data, spv_endianness_t endian, uint32_t magic,
    uint32_t version, uint32_t generator, uint32_t id_bound,
    uint32_t schema) {
  // |user_data| round-trips a MarkvEncoder* through void*; static_cast is the
  // conversion intended for void* -> T*.
  MarkvEncoder* encoder = static_cast<MarkvEncoder*>(user_data);
  return encoder->EncodeHeader(
      endian, magic, version, generator, id_bound, schema);
}
|
|
|
|
// Instruction callback handed to spvBinaryParse: forwards each parsed
// instruction to the MarkvEncoder that was passed as |user_data|.
spv_result_t EncodeInstruction(
    void* user_data, const spv_parsed_instruction_t* inst) {
  // |user_data| round-trips a MarkvEncoder* through void*; static_cast is the
  // conversion intended for void* -> T*.
  MarkvEncoder* encoder = static_cast<MarkvEncoder*>(user_data);
  return encoder->EncodeInstruction(*inst);
}
|
|
|
|
} // namespace
|
|
|
|
// Encodes the SPIR-V module in |spirv_words| (|spirv_num_words| words) into
// MARK-V and stores the result in |*markv_binary|.
//
// If |comments| is non-null, the encoder's comments are also returned there
// (built on top of a disassembly of the input). If |diagnostic| is non-null,
// it is reset and receives any error message.
//
// Returns SPV_ERROR_INVALID_POINTER if |markv_binary| is null, and
// SPV_ERROR_INVALID_BINARY if the input cannot be parsed or encoded.
spv_result_t spvSpirvToMarkv(spv_const_context context,
                             const uint32_t* spirv_words,
                             const size_t spirv_num_words,
                             spv_const_markv_encoder_options options,
                             spv_markv_binary* markv_binary,
                             spv_text* comments, spv_diagnostic* diagnostic) {
  spv_context_t hijack_context = *context;
  if (diagnostic) {
    *diagnostic = nullptr;
    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
  }

  spv_position_t position = {};

  // Fail before doing any work instead of dereferencing a null output
  // pointer after the whole module has been encoded.
  if (!markv_binary) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_POINTER)
           << "markv_binary must not be null.";
  }

  spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words};

  spv_endianness_t endian;
  if (spvBinaryEndianness(&spirv_binary, &endian)) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
           << "Invalid SPIR-V magic number.";
  }

  spv_header_t header;
  if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
           << "Invalid SPIR-V header.";
  }

  MarkvEncoder encoder(&hijack_context, options);

  if (comments) {
    encoder.CreateCommentsLogger();

    // The comments logger works against a disassembly of the input.
    spv_text text = nullptr;
    if (spvBinaryToText(&hijack_context, spirv_words, spirv_num_words,
                        SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr)
        != SPV_SUCCESS) {
      return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                        SPV_ERROR_INVALID_BINARY)
             << "Failed to disassemble SPIR-V binary.";
    }
    assert(text);
    encoder.SetDisassembly(std::string(text->str, text->length));
    spvTextDestroy(text);
  }

  if (spvBinaryParse(
          &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader,
          EncodeInstruction, diagnostic) != SPV_SUCCESS) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
           << "Unable to encode to MARK-V.";
  }

  if (comments)
    *comments = CreateSpvText(encoder.GetComments());

  *markv_binary = encoder.GetMarkvBinary();
  return SPV_SUCCESS;
}
|
|
|
|
// Decodes |markv_size_bytes| bytes of MARK-V data at |markv_data| back into a
// SPIR-V module, stored in a newly allocated |*spirv_binary|.
//
// If |diagnostic| is non-null, it is reset and receives any error message.
// The |comments| parameter is currently unused.
//
// Returns SPV_ERROR_INVALID_POINTER if |spirv_binary| is null, and
// SPV_ERROR_INVALID_BINARY if decoding fails.
spv_result_t spvMarkvToSpirv(spv_const_context context,
                             const uint8_t* markv_data,
                             size_t markv_size_bytes,
                             spv_const_markv_decoder_options options,
                             spv_binary* spirv_binary,
                             spv_text* /* comments */,
                             spv_diagnostic* diagnostic) {
  spv_position_t position = {};
  spv_context_t hijack_context = *context;
  if (diagnostic) {
    *diagnostic = nullptr;
    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
  }

  // Fail before decoding instead of dereferencing a null output pointer
  // after all the work is done.
  if (!spirv_binary) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_POINTER)
           << "spirv_binary must not be null.";
  }

  MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options);

  std::vector<uint32_t> words;

  if (decoder.DecodeModule(&words) != SPV_SUCCESS) {
    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                      SPV_ERROR_INVALID_BINARY)
           << "Unable to decode MARK-V.";
  }

  assert(!words.empty());

  *spirv_binary = new spv_binary_t();
  (*spirv_binary)->code = new uint32_t[words.size()];
  (*spirv_binary)->wordCount = words.size();
  // Typed copy: no hand-computed byte count.
  std::copy(words.begin(), words.end(), (*spirv_binary)->code);

  return SPV_SUCCESS;
}
|
|
|
|
// Frees a MARK-V binary: first its |data| array, then the struct itself.
// A null |binary| is accepted and ignored.
void spvMarkvBinaryDestroy(spv_markv_binary binary) {
  if (binary) {
    delete[] binary->data;
    delete binary;
  }
}
|
|
|
|
// Heap-allocates a new encoder options object; free it with
// spvMarkvEncoderOptionsDestroy.
spv_markv_encoder_options spvMarkvEncoderOptionsCreate() {
  spv_markv_encoder_options created = new spv_markv_encoder_options_t;
  return created;
}
|
|
|
|
// Releases an encoder options object obtained from
// spvMarkvEncoderOptionsCreate. Deleting a null pointer is a no-op, so null
// is accepted.
void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) {
  delete options;
}
|
|
|
|
// Heap-allocates a new decoder options object; free it with
// spvMarkvDecoderOptionsDestroy.
spv_markv_decoder_options spvMarkvDecoderOptionsCreate() {
  spv_markv_decoder_options created = new spv_markv_decoder_options_t;
  return created;
}
|
|
|
|
// Releases a decoder options object obtained from
// spvMarkvDecoderOptionsCreate. Deleting a null pointer is a no-op, so null
// is accepted.
void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) {
  delete options;
}
|