mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-19 03:20:14 +00:00
976e4218d5
MARK-V codec was previously dependent on the validation state. Now it doesn't need the validator to function, but can still optionally create it and validate every instruction once it's decoded.
3146 lines
103 KiB
C++
3146 lines
103 KiB
C++
// Copyright (c) 2017 Google Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Contains
|
|
// - SPIR-V to MARK-V encoder
|
|
// - MARK-V to SPIR-V decoder
|
|
//
|
|
// MARK-V is a compression format for SPIR-V binaries. It strips away
|
|
// non-essential information (such as result ids which can be regenerated) and
|
|
// uses various bit reduction techniques to reduce the size of the binary.
|
|
//
|
|
// MarkvModel is a flatbuffers object containing a set of rules defining how
|
|
// compression/decompression is done (coding schemes, dictionaries).
|
|
|
|
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
|
|
|
|
#include "spirv/1.2/GLSL.std.450.h"
|
|
#include "spirv/1.2/OpenCL.std.h"
|
|
#include "spirv/1.2/spirv.h"
|
|
|
|
#include "binary.h"
|
|
#include "diagnostic.h"
|
|
#include "enum_string_mapping.h"
|
|
#include "extensions.h"
|
|
#include "ext_inst.h"
|
|
#include "id_descriptor.h"
|
|
#include "instruction.h"
|
|
#include "markv_autogen.h"
|
|
#include "opcode.h"
|
|
#include "operand.h"
|
|
#include "spirv-tools/libspirv.h"
|
|
#include "spirv-tools/markv.h"
|
|
#include "spirv_endian.h"
|
|
#include "spirv_validator_options.h"
|
|
#include "util/bit_stream.h"
|
|
#include "util/huffman_codec.h"
|
|
#include "util/move_to_front.h"
|
|
#include "util/parse_number.h"
|
|
#include "validate.h"
|
|
#include "val/instruction.h"
|
|
#include "val/validation_state.h"
|
|
|
|
using libspirv::IdDescriptorCollection;
|
|
using libspirv::Instruction;
|
|
using libspirv::ValidationState_t;
|
|
using libspirv::DiagnosticStream;
|
|
using spvtools::ValidateInstructionAndUpdateValidationState;
|
|
using spvutils::BitReaderWord64;
|
|
using spvutils::BitWriterWord64;
|
|
using spvutils::HuffmanCodec;
|
|
using MoveToFront = spvutils::MoveToFront<uint32_t>;
|
|
using MultiMoveToFront = spvutils::MultiMoveToFront<uint32_t>;
|
|
|
|
// Options accepted by the SPIR-V -> MARK-V encoder.
struct spv_markv_encoder_options_t {
  // When true, the encoder creates validator options and a validation state
  // so that instructions can be checked by the SPIR-V validator.
  bool validate_spirv_binary{false};
};
|
|
|
|
// Options accepted by the MARK-V -> SPIR-V decoder.
struct spv_markv_decoder_options_t {
  // When true, the decoder creates validator options and a validation state
  // so that decoded instructions can be checked by the SPIR-V validator.
  bool validate_spirv_binary{false};
};
|
|
|
|
namespace {
|
|
|
|
// Magic number of SPIR-V binaries, as defined by the SPIR-V headers.
const uint32_t kSpirvMagicNumber = SpvMagicNumber;
// Magic number stored in the MarkvHeader of every MARK-V binary.
const uint32_t kMarkvMagicNumber = 0x07230303u;
|
|
|
|
// Handles for move-to-front sequences.
//
// Every enumerator below kMtfIdOfTypeBegin names one specific sequence.
// kMtfIdOfTypeBegin and every enumerator after it is the base of a handle
// space. kMtfIdGeneratedByOpcode is such a base too, despite lacking the
// "Begin" suffix. A concrete handle is the base plus a type id, opcode, size
// or id descriptor (see the GetMtf* helpers of MarkvCodecBase). The bases are
// spaced 0x10000 apart. The id descriptor space starts at 2^32 and is indexed
// by a 32-bit descriptor. All values are written out explicitly; keep them
// unique when adding new entries.
enum : uint64_t {
  kMtfNone = 0,
  // All ids.
  kMtfAll = 1,
  // All forward declared ids.
  kMtfForwardDeclared = 2,
  // All type ids except for those generated by OpTypeFunction.
  kMtfTypeNonFunction = 3,
  // All labels.
  kMtfLabel = 4,
  // All ids created by instructions which had a type_id.
  kMtfObject = 5,
  // All types generated by OpTypeFloat, OpTypeInt, OpTypeBool.
  kMtfTypeScalar = 6,
  // All composite types.
  kMtfTypeComposite = 7,
  // Boolean type or any vector type of it.
  kMtfTypeBoolScalarOrVector = 8,
  // All float types or any vector of float type.
  kMtfTypeFloatScalarOrVector = 9,
  // All int types or any vector of int type.
  kMtfTypeIntScalarOrVector = 10,
  // All types declared as return types in OpTypeFunction.
  kMtfTypeReturnedByFunction = 11,
  // All object ids which are integer constants.
  kMtfConstInteger = 12,
  // All composite objects.
  kMtfComposite = 13,
  // All bool objects or vectors of bools.
  kMtfBoolScalarOrVector = 14,
  // All float objects or vectors of float.
  kMtfFloatScalarOrVector = 15,
  // All int objects or vectors of int.
  kMtfIntScalarOrVector = 16,
  // All pointer types which point to composites.
  kMtfTypePointerToComposite = 17,
  // Used by EncodeMtfRankHuffman.
  kMtfGenericNonZeroRank = 18,

  // Handle space for ids of a specific type.
  kMtfIdOfTypeBegin = 0x10000,
  // Handle space for ids generated by a specific opcode.
  kMtfIdGeneratedByOpcode = 0x20000,
  // Handle space for ids of objects whose type was generated by a specific
  // opcode.
  kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000,
  // All vectors of a specific component type.
  kMtfVectorOfComponentTypeBegin = 0x40000,
  // All vector types of a specific size.
  kMtfTypeVectorOfSizeBegin = 0x50000,
  // All pointer types to a specific type.
  kMtfPointerToTypeBegin = 0x60000,
  // All function types which return a specific type.
  kMtfFunctionTypeWithReturnTypeBegin = 0x70000,
  // All function objects which return a specific type.
  kMtfFunctionWithReturnTypeBegin = 0x80000,
  // All float vectors of a specific size.
  kMtfFloatVectorOfSizeBegin = 0x90000,
  // Id descriptor space (32-bit).
  kMtfIdDescriptorSpaceBegin = 0x100000000,
};
|
|
|
|
// Used by "presumed index" technique which does special treatment of integer
|
|
// constants no greater than this value.
|
|
const uint32_t kMarkvMaxPresumedAccessIndex = 31;
|
|
|
|
// Signals that the value is not in the coding scheme and a fallback method
|
|
// needs to be used.
|
|
const uint64_t kMarkvNoneOfTheAbove = GetMarkvNonOfTheAbove();
|
|
|
|
// Mtf ranks smaller than this are encoded with Huffman coding.
|
|
const uint32_t kMtfSmallestRankEncodedByValue = 10;
|
|
|
|
// Signals that the mtf rank is too large to be encoded with Huffman.
|
|
const uint32_t kMtfRankEncodedByValueSignal =
|
|
std::numeric_limits<uint32_t>::max();
|
|
|
|
const size_t kCommentNumWhitespaces = 2;
|
|
|
|
const size_t kByteBreakAfterInstIfLessThanUntilNextByte = 8;
|
|
|
|
// Returns a set of mtf rank codecs based on a plausible hand-coded
|
|
// distribution.
|
|
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
|
|
GetMtfHuffmanCodecs() {
|
|
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> codecs;
|
|
|
|
std::unique_ptr<HuffmanCodec<uint32_t>> codec;
|
|
|
|
codec.reset(new HuffmanCodec<uint32_t>(std::map<uint32_t, uint32_t>({
|
|
{ 0, 5 },
|
|
{ 1, 40 },
|
|
{ 2, 10 },
|
|
{ 3, 5 },
|
|
{ 4, 5 },
|
|
{ 5, 5 },
|
|
{ 6, 3 },
|
|
{ 7, 3 },
|
|
{ 8, 3 },
|
|
{ 9, 3 },
|
|
{ kMtfRankEncodedByValueSignal, 10 },
|
|
})));
|
|
codecs.emplace(kMtfAll, std::move(codec));
|
|
|
|
codec.reset(new HuffmanCodec<uint32_t>(std::map<uint32_t, uint32_t>({
|
|
{ 1, 50 },
|
|
{ 2, 20 },
|
|
{ 3, 5 },
|
|
{ 4, 5 },
|
|
{ 5, 2 },
|
|
{ 6, 1 },
|
|
{ 7, 1 },
|
|
{ 8, 1 },
|
|
{ 9, 1 },
|
|
{ kMtfRankEncodedByValueSignal, 10 },
|
|
})));
|
|
codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));
|
|
|
|
return codecs;
|
|
}
|
|
|
|
// Encoding/decoding model containing various constants and codecs.
|
|
class MarkvModel {
|
|
public:
|
|
MarkvModel()
|
|
: mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
|
|
opcode_and_num_operands_huffman_codec_(GetOpcodeAndNumOperandsHist()),
|
|
opcode_and_num_operands_markov_huffman_codecs_(
|
|
GetOpcodeAndNumOperandsMarkovHuffmanCodecs()),
|
|
non_id_word_huffman_codecs_(GetNonIdWordHuffmanCodecs()),
|
|
id_descriptor_huffman_codecs_(GetIdDescriptorHuffmanCodecs()),
|
|
descriptors_with_coding_scheme_(GetDescriptorsWithCodingScheme()),
|
|
literal_string_huffman_codecs_(GetLiteralStringHuffmanCodecs()) {}
|
|
|
|
size_t opcode_chunk_length() const { return 7; }
|
|
size_t num_operands_chunk_length() const { return 3; }
|
|
size_t mtf_rank_chunk_length() const { return 5; }
|
|
|
|
size_t u16_chunk_length() const { return 4; }
|
|
size_t s16_chunk_length() const { return 4; }
|
|
size_t s16_block_exponent() const { return 6; }
|
|
|
|
size_t u32_chunk_length() const { return 8; }
|
|
size_t s32_chunk_length() const { return 8; }
|
|
size_t s32_block_exponent() const { return 10; }
|
|
|
|
size_t u64_chunk_length() const { return 8; }
|
|
size_t s64_chunk_length() const { return 8; }
|
|
size_t s64_block_exponent() const { return 10; }
|
|
|
|
// Returns Huffman codec for ranks of the mtf with given |handle|.
|
|
// Different mtfs can use different rank distributions.
|
|
// May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<uint32_t>* GetMtfHuffmanCodec(uint64_t handle) const {
|
|
const auto it = mtf_huffman_codecs_.find(handle);
|
|
if (it == mtf_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
// Returns a codec for common opcode_and_num_operands words for the given
|
|
// previous opcode. May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<uint64_t>* GetOpcodeAndNumOperandsMarkovHuffmanCodec(
|
|
uint32_t prev_opcode) const {
|
|
if (prev_opcode == SpvOpNop)
|
|
return &opcode_and_num_operands_huffman_codec_;
|
|
|
|
const auto it =
|
|
opcode_and_num_operands_markov_huffman_codecs_.find(prev_opcode);
|
|
if (it == opcode_and_num_operands_markov_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
// Returns a codec for common non-id words used for given operand slot.
|
|
// Operand slot is defined by the opcode and the operand index.
|
|
// May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<uint64_t>* GetNonIdWordHuffmanCodec(
|
|
uint32_t opcode, uint32_t operand_index) const {
|
|
const auto it = non_id_word_huffman_codecs_.find(
|
|
std::pair<uint32_t, uint32_t>(opcode, operand_index));
|
|
if (it == non_id_word_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
// Returns a codec for common id descriptos used for given operand slot.
|
|
// Operand slot is defined by the opcode and the operand index.
|
|
// May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<uint64_t>* GetIdDescriptorHuffmanCodec(
|
|
uint32_t opcode, uint32_t operand_index) const {
|
|
const auto it = id_descriptor_huffman_codecs_.find(
|
|
std::pair<uint32_t, uint32_t>(opcode, operand_index));
|
|
if (it == id_descriptor_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
// Returns a codec for common strings used by the given opcode.
|
|
// Operand slot is defined by the opcode and the operand index.
|
|
// May return nullptr if the codec doesn't exist.
|
|
const HuffmanCodec<std::string>* GetLiteralStringHuffmanCodec(
|
|
uint32_t opcode) const {
|
|
const auto it = literal_string_huffman_codecs_.find(opcode);
|
|
if (it == literal_string_huffman_codecs_.end())
|
|
return nullptr;
|
|
return it->second.get();
|
|
}
|
|
|
|
bool DescriptorHasCodingScheme(uint32_t descriptor) const {
|
|
return descriptors_with_coding_scheme_.count(descriptor);
|
|
}
|
|
|
|
private:
|
|
// Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
|
|
// need to contain a different codec for every handle as most use one and the
|
|
// same.
|
|
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
|
|
mtf_huffman_codecs_;
|
|
|
|
// Huffman codec for base-rate of opcode_and_num_operands.
|
|
HuffmanCodec<uint64_t> opcode_and_num_operands_huffman_codec_;
|
|
|
|
// Huffman codecs for opcode_and_num_operands. The map key is previous opcode.
|
|
std::map<uint32_t, std::unique_ptr<HuffmanCodec<uint64_t>>>
|
|
opcode_and_num_operands_markov_huffman_codecs_;
|
|
|
|
// Huffman codecs for non-id single-word operand values.
|
|
// The map key is pair <opcode, operand_index>.
|
|
std::map<std::pair<uint32_t, uint32_t>,
|
|
std::unique_ptr<HuffmanCodec<uint64_t>>>
|
|
non_id_word_huffman_codecs_;
|
|
|
|
// Huffman codecs for id descriptors. The map key is pair
|
|
// <opcode, operand_index>.
|
|
std::map<std::pair<uint32_t, uint32_t>,
|
|
std::unique_ptr<HuffmanCodec<uint64_t>>>
|
|
id_descriptor_huffman_codecs_;
|
|
|
|
std::unordered_set<uint32_t> descriptors_with_coding_scheme_;
|
|
|
|
// Huffman codecs for literal strings. The map key is the opcode of the
|
|
// current instruction. This assumes, that there is no more than one literal
|
|
// string operand per instruction, but would still work even if this is not
|
|
// the case. Names and debug information strings are not collected.
|
|
std::map<uint32_t, std::unique_ptr<HuffmanCodec<std::string>>>
|
|
literal_string_huffman_codecs_;
|
|
};
|
|
|
|
// Returns the process-wide default model. It is constructed on the first call
// (function-local static) and shared by all later callers; only const access
// is handed out.
const MarkvModel* GetDefaultModel() {
  static const MarkvModel kDefaultModel;
  return &kDefaultModel;
}
|
|
|
|
// Returns chunk length used for variable length encoding of spirv operand
|
|
// words. Returns zero if operand type corresponds to potentially multiple
|
|
// words or a word which is not expected to profit from variable width encoding.
|
|
// Chunk length is selected based on the size of expected value.
|
|
// Most of these values will later be encoded with probability-based coding,
|
|
// but variable width integer coding is a good quick solution.
|
|
// TODO(atgoo@github.com): Put this in MarkvModel flatbuffer.
|
|
size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) {
|
|
switch (type) {
|
|
case SPV_OPERAND_TYPE_TYPE_ID:
|
|
return 4;
|
|
case SPV_OPERAND_TYPE_RESULT_ID:
|
|
case SPV_OPERAND_TYPE_ID:
|
|
case SPV_OPERAND_TYPE_SCOPE_ID:
|
|
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
|
|
return 8;
|
|
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_CAPABILITY:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
|
|
return 3;
|
|
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
|
|
case SPV_OPERAND_TYPE_MEMORY_MODEL:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_EXECUTION_MODE:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_STORAGE_CLASS:
|
|
return 4;
|
|
case SPV_OPERAND_TYPE_DIMENSIONALITY:
|
|
case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
|
|
return 3;
|
|
case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
|
|
case SPV_OPERAND_TYPE_LINKAGE_TYPE:
|
|
case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
|
|
return 3;
|
|
case SPV_OPERAND_TYPE_DECORATION:
|
|
case SPV_OPERAND_TYPE_BUILT_IN:
|
|
return 6;
|
|
case SPV_OPERAND_TYPE_GROUP_OPERATION:
|
|
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
|
|
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
|
|
return 2;
|
|
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
|
|
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
|
|
case SPV_OPERAND_TYPE_LOOP_CONTROL:
|
|
case SPV_OPERAND_TYPE_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
|
|
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
|
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
|
|
return 4;
|
|
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER:
|
|
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
|
|
return 6;
|
|
default:
|
|
return 0;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Returns true if the opcode has a fixed number of operands. May return a
|
|
// false negative.
|
|
bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
|
|
switch (opcode) {
|
|
// TODO(atgoo@github.com) This is not a complete list.
|
|
case SpvOpNop:
|
|
case SpvOpName:
|
|
case SpvOpUndef:
|
|
case SpvOpSizeOf:
|
|
case SpvOpLine:
|
|
case SpvOpNoLine:
|
|
case SpvOpDecorationGroup:
|
|
case SpvOpExtension:
|
|
case SpvOpExtInstImport:
|
|
case SpvOpMemoryModel:
|
|
case SpvOpCapability:
|
|
case SpvOpTypeVoid:
|
|
case SpvOpTypeBool:
|
|
case SpvOpTypeInt:
|
|
case SpvOpTypeFloat:
|
|
case SpvOpTypeVector:
|
|
case SpvOpTypeMatrix:
|
|
case SpvOpTypeSampler:
|
|
case SpvOpTypeSampledImage:
|
|
case SpvOpTypeArray:
|
|
case SpvOpTypePointer:
|
|
case SpvOpConstantTrue:
|
|
case SpvOpConstantFalse:
|
|
case SpvOpLabel:
|
|
case SpvOpBranch:
|
|
case SpvOpFunction:
|
|
case SpvOpFunctionParameter:
|
|
case SpvOpFunctionEnd:
|
|
case SpvOpBitcast:
|
|
case SpvOpCopyObject:
|
|
case SpvOpTranspose:
|
|
case SpvOpSNegate:
|
|
case SpvOpFNegate:
|
|
case SpvOpIAdd:
|
|
case SpvOpFAdd:
|
|
case SpvOpISub:
|
|
case SpvOpFSub:
|
|
case SpvOpIMul:
|
|
case SpvOpFMul:
|
|
case SpvOpUDiv:
|
|
case SpvOpSDiv:
|
|
case SpvOpFDiv:
|
|
case SpvOpUMod:
|
|
case SpvOpSRem:
|
|
case SpvOpSMod:
|
|
case SpvOpFRem:
|
|
case SpvOpFMod:
|
|
case SpvOpVectorTimesScalar:
|
|
case SpvOpMatrixTimesScalar:
|
|
case SpvOpVectorTimesMatrix:
|
|
case SpvOpMatrixTimesVector:
|
|
case SpvOpMatrixTimesMatrix:
|
|
case SpvOpOuterProduct:
|
|
case SpvOpDot:
|
|
return true;
|
|
default:
|
|
break;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Returns how many bits have to follow position |bit_pos| to reach the next
// byte boundary; zero if |bit_pos| is already byte-aligned.
size_t GetNumBitsToNextByte(size_t bit_pos) {
  const size_t bits_into_byte = bit_pos % 8;
  if (bits_into_byte == 0)
    return 0;
  return 8 - bits_into_byte;
}
|
|
|
|
// Returns the current MARK-V format version, packed as (major << 16) | minor.
uint32_t GetMarkvVersion() {
  const uint32_t major = 1;
  const uint32_t minor = 2;
  return (major << 16) | minor;
}
|
|
|
|
// Accumulates human-readable comments about the encoding process.
// Bit sequences appended back to back are joined with '-'. Any other output
// (text, whitespace, line breaks) breaks that chain, so the next bit sequence
// starts without a delimiter.
class CommentLogger {
 public:
  // Appends free-form text.
  void AppendText(const std::string& str) {
    Append(str);
    use_delimiter_ = false;
  }

  // Appends free-form text followed by a line break.
  void AppendTextNewLine(const std::string& str) {
    AppendText(str);
    NewLine();
  }

  // Appends a bit sequence, preceded by '-' when the previous output was also
  // a bit sequence.
  void AppendBitSequence(const std::string& str) {
    if (use_delimiter_)
      Append("-");
    Append(str);
    use_delimiter_ = true;
  }

  // Appends |num| space characters.
  void AppendWhitespaces(size_t num) { AppendText(std::string(num, ' ')); }

  // Appends a line break.
  void NewLine() { AppendText("\n"); }

  // Returns everything logged so far.
  std::string GetText() const { return ss_.str(); }

 private:
  void Append(const std::string& str) { ss_ << str; }

  std::stringstream ss_;

  // True right after a bit sequence was appended; the next bit sequence is
  // then prefixed with a delimiter, producing output such as
  // 1100-0 1110-1-1100-1-1111-0 110-0.
  bool use_delimiter_ = false;
};
|
|
|
|
// Creates spv_text object containing text from |str|.
|
|
// The returned value is owned by the caller and needs to be destroyed with
|
|
// spvTextDestroy.
|
|
spv_text CreateSpvText(const std::string& str) {
|
|
spv_text out = new spv_text_t();
|
|
assert(out);
|
|
char* cstr = new char[str.length() + 1];
|
|
assert(cstr);
|
|
std::strncpy(cstr, str.c_str(), str.length());
|
|
cstr[str.length()] = '\0';
|
|
out->str = cstr;
|
|
out->length = str.length();
|
|
return out;
|
|
}
|
|
|
|
// Base class for MARK-V encoder and decoder. Contains common functionality
|
|
// such as:
|
|
// - Validator connection and validation state.
|
|
// - SPIR-V grammar and helper functions.
|
|
class MarkvCodecBase {
|
|
public:
|
|
virtual ~MarkvCodecBase() {
|
|
spvValidatorOptionsDestroy(validator_options_);
|
|
}
|
|
|
|
MarkvCodecBase() = delete;
|
|
|
|
void SetModel(const MarkvModel* model) {
|
|
model_ = model;
|
|
}
|
|
|
|
protected:
|
|
struct MarkvHeader {
|
|
MarkvHeader() {
|
|
magic_number = kMarkvMagicNumber;
|
|
markv_version = GetMarkvVersion();
|
|
markv_model = 0;
|
|
markv_length_in_bits = 0;
|
|
spirv_version = 0;
|
|
spirv_generator = 0;
|
|
}
|
|
|
|
uint32_t magic_number;
|
|
uint32_t markv_version;
|
|
// Magic number to identify or verify MarkvModel used for encoding.
|
|
uint32_t markv_model;
|
|
uint32_t markv_length_in_bits;
|
|
uint32_t spirv_version;
|
|
uint32_t spirv_generator;
|
|
};
|
|
|
|
explicit MarkvCodecBase(spv_const_context context,
|
|
spv_validator_options validator_options)
|
|
: validator_options_(validator_options), grammar_(context),
|
|
model_(GetDefaultModel()), context_(context),
|
|
vstate_(validator_options ?
|
|
new ValidationState_t(context, validator_options_) : nullptr) {}
|
|
|
|
// Validates a single instruction and updates validation state of the module.
|
|
// Does nothing and returns SPV_SUCCESS if validator was not created.
|
|
spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
|
|
if (!vstate_)
|
|
return SPV_SUCCESS;
|
|
|
|
return ValidateInstructionAndUpdateValidationState(vstate_.get(), &inst);
|
|
}
|
|
|
|
// Returns instruction which created |id| or nullptr if such instruction was
|
|
// not registered.
|
|
const Instruction* FindDef(uint32_t id) const {
|
|
const auto it = id_to_def_instruction_.find(id);
|
|
if (it == id_to_def_instruction_.end())
|
|
return nullptr;
|
|
return it->second;
|
|
}
|
|
|
|
// Returns type id of vector type component.
|
|
uint32_t GetVectorComponentType(uint32_t vector_type_id) const {
|
|
const Instruction* type_inst = FindDef(vector_type_id);
|
|
assert(type_inst);
|
|
assert(type_inst->opcode() == SpvOpTypeVector);
|
|
|
|
const uint32_t component_type =
|
|
type_inst->word(type_inst->operands()[1].offset);
|
|
return component_type;
|
|
}
|
|
|
|
// Returns mtf handle for ids of given type.
|
|
uint64_t GetMtfIdOfType(uint32_t type_id) const {
|
|
return kMtfIdOfTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for ids generated by given opcode.
|
|
uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const {
|
|
return kMtfIdGeneratedByOpcode + opcode;
|
|
}
|
|
|
|
// Returns mtf handle for ids of type generated by given opcode.
|
|
uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const {
|
|
return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode;
|
|
}
|
|
|
|
// Returns mtf handle for vectors of specific component type.
|
|
uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const {
|
|
return kMtfVectorOfComponentTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for float vectors of specific size.
|
|
uint64_t GetMtfFloatVectorOfSize(uint32_t size) const {
|
|
return kMtfFloatVectorOfSizeBegin + size;
|
|
}
|
|
|
|
// Returns mtf handle for vector type of specific size.
|
|
uint64_t GetMtfTypeVectorOfSize(uint32_t size) const {
|
|
return kMtfTypeVectorOfSizeBegin + size;
|
|
}
|
|
|
|
// Returns mtf handle for pointers to specific size.
|
|
uint64_t GetMtfPointerToType(uint32_t type_id) const {
|
|
return kMtfPointerToTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for function types with given return type.
|
|
uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const {
|
|
return kMtfFunctionTypeWithReturnTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for functions with given return type.
|
|
uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const {
|
|
return kMtfFunctionWithReturnTypeBegin + type_id;
|
|
}
|
|
|
|
// Returns mtf handle for the given id descriptor.
|
|
uint64_t GetMtfIdDescriptor(uint32_t descriptor) const {
|
|
return kMtfIdDescriptorSpaceBegin + descriptor;
|
|
}
|
|
|
|
// Process data from the current instruction. This would update MTFs and
|
|
// other data containers.
|
|
void ProcessCurInstruction();
|
|
|
|
// Returns move-to-front handle to be used for the current operand slot.
|
|
// Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
|
|
uint64_t GetRuleBasedMtf();
|
|
|
|
// Returns words of the current instruction. Decoder has a different
|
|
// implementation and the array is valid only until the previously decoded
|
|
// word.
|
|
virtual const uint32_t* GetInstWords() const {
|
|
return inst_.words;
|
|
}
|
|
|
|
// Returns the opcode of the previous instruction.
|
|
SpvOp GetPrevOpcode() const {
|
|
if (instructions_.empty())
|
|
return SpvOpNop;
|
|
|
|
return instructions_.back()->opcode();
|
|
}
|
|
|
|
// Returns diagnostic stream, position index is set to instruction number.
|
|
DiagnosticStream Diag(spv_result_t error_code) const {
|
|
return DiagnosticStream({0, 0, instructions_.size()},
|
|
context_->consumer, error_code);
|
|
}
|
|
|
|
// Returns current id bound.
|
|
uint32_t GetIdBound() const {
|
|
return id_bound_;
|
|
}
|
|
|
|
// Sets current id bound, expected to be no lower than the previous one.
|
|
void SetIdBound(uint32_t id_bound) {
|
|
assert(id_bound >= id_bound_);
|
|
id_bound_ = id_bound;
|
|
if (vstate_)
|
|
vstate_->setIdBound(id_bound);
|
|
}
|
|
|
|
spv_validator_options validator_options_ = nullptr;
|
|
const libspirv::AssemblyGrammar grammar_;
|
|
MarkvHeader header_;
|
|
const MarkvModel* model_ = nullptr;
|
|
|
|
// Current instruction, current operand and current operand index.
|
|
spv_parsed_instruction_t inst_;
|
|
spv_parsed_operand_t operand_;
|
|
uint32_t operand_index_;
|
|
|
|
// Maps a result ID to its type ID. By convention:
|
|
// - a result ID that is a type definition maps to itself.
|
|
// - a result ID without a type maps to 0. (E.g. for OpLabel)
|
|
std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
|
|
|
|
// Container for all move-to-front sequences.
|
|
MultiMoveToFront multi_mtf_;
|
|
|
|
// Id of the current function or zero if outside of function.
|
|
uint32_t cur_function_id_ = 0;
|
|
|
|
// Return type of the current function.
|
|
uint32_t cur_function_return_type_ = 0;
|
|
|
|
// Remaining function parameter types. This container is filled on OpFunction,
|
|
// and drained on OpFunctionParameter.
|
|
std::list<uint32_t> remaining_function_parameter_types_;
|
|
|
|
// List of ids local to the current function.
|
|
std::vector<uint32_t> ids_local_to_cur_function_;
|
|
|
|
// List of instructions in the order they are given in the module.
|
|
std::vector<std::unique_ptr<const Instruction>> instructions_;
|
|
|
|
// Maps used for the 'presumed id' techniques. Maps small constant integer
|
|
// value to its id and back.
|
|
std::map<uint32_t, uint32_t> presumed_index_to_id_;
|
|
std::map<uint32_t, uint32_t> id_to_presumed_index_;
|
|
|
|
// Container/computer for id descriptors.
|
|
IdDescriptorCollection id_descriptors_;
|
|
|
|
private:
|
|
spv_const_context context_ = nullptr;
|
|
|
|
std::unique_ptr<ValidationState_t> vstate_;
|
|
|
|
// Maps result id to the instruction which defined it.
|
|
std::unordered_map<uint32_t, const Instruction*> id_to_def_instruction_;
|
|
|
|
uint32_t id_bound_ = 1;
|
|
};
|
|
|
|
// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
|
|
// EncodeInstruction which can be used as callback by spvBinaryParse.
|
|
// Encoded binary is written to an internally maintained bitstream.
|
|
// After the last instruction is encoded, the resulting MARK-V binary can be
|
|
// acquired by calling GetMarkvBinary().
|
|
// The encoder uses SPIR-V validator to keep internal state, therefore
|
|
// SPIR-V binary needs to be able to pass validator checks.
|
|
// CreateCommentsLogger() can be used to enable the encoder to write comments
|
|
// on how encoding was done, which can later be accessed with GetComments().
|
|
class MarkvEncoder : public MarkvCodecBase {
|
|
public:
|
|
MarkvEncoder(spv_const_context context,
|
|
spv_const_markv_encoder_options options)
|
|
: MarkvCodecBase(context, GetValidatorOptions(options)),
|
|
options_(options) {
|
|
(void) options_;
|
|
}
|
|
|
|
// Writes data from SPIR-V header to MARK-V header.
|
|
spv_result_t EncodeHeader(
|
|
spv_endianness_t /* endian */, uint32_t /* magic */,
|
|
uint32_t version, uint32_t generator, uint32_t id_bound,
|
|
uint32_t /* schema */) {
|
|
SetIdBound(id_bound);
|
|
header_.spirv_version = version;
|
|
header_.spirv_generator = generator;
|
|
return SPV_SUCCESS;
|
|
}
|
|
|
|
// Encodes SPIR-V instruction to MARK-V and writes to bit stream.
|
|
// Operation can fail if the instruction fails to pass the validator or if
|
|
// the encoder stubmles on something unexpected.
|
|
spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
|
|
|
|
// Concatenates MARK-V header and the bit stream with encoded instructions
|
|
// into a single buffer and returns it as spv_markv_binary. The returned
|
|
// value is owned by the caller and needs to be destroyed with
|
|
// spvMarkvBinaryDestroy().
|
|
spv_markv_binary GetMarkvBinary() {
|
|
header_.markv_length_in_bits =
|
|
static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
|
|
const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
|
|
|
|
spv_markv_binary markv_binary = new spv_markv_binary_t();
|
|
markv_binary->data = new uint8_t[num_bytes];
|
|
markv_binary->length = num_bytes;
|
|
assert(writer_.GetData());
|
|
std::memcpy(markv_binary->data, &header_, sizeof(header_));
|
|
std::memcpy(markv_binary->data + sizeof(header_),
|
|
writer_.GetData(), writer_.GetDataSizeBytes());
|
|
return markv_binary;
|
|
}
|
|
|
|
// Creates an internal logger which writes comments on the encoding process.
|
|
// Output can later be accessed with GetComments().
|
|
void CreateCommentsLogger() {
|
|
logger_.reset(new CommentLogger());
|
|
writer_.SetCallback([this](const std::string& str){
|
|
logger_->AppendBitSequence(str);
|
|
});
|
|
}
|
|
|
|
// Optionally adds disassembly to the comments.
|
|
// Disassembly should contain all instructions in the module separated by
|
|
// \n, and no header.
|
|
void SetDisassembly(std::string&& disassembly) {
|
|
disassembly_.reset(new std::stringstream(std::move(disassembly)));
|
|
}
|
|
|
|
// Extracts the next instruction line from the disassembly and logs it.
|
|
void LogDisassemblyInstruction() {
|
|
if (logger_ && disassembly_) {
|
|
std::string line;
|
|
std::getline(*disassembly_, line, '\n');
|
|
logger_->AppendTextNewLine(line);
|
|
}
|
|
}
|
|
|
|
// Extracts the text from the comment logger.
|
|
std::string GetComments() const {
|
|
if (!logger_)
|
|
return "";
|
|
return logger_->GetText();
|
|
}
|
|
|
|
private:
|
|
// Creates and returns validator options. Returned value owned by the caller.
|
|
static spv_validator_options GetValidatorOptions(
|
|
spv_const_markv_encoder_options options) {
|
|
return options->validate_spirv_binary ?
|
|
spvValidatorOptionsCreate() : nullptr;
|
|
}
|
|
|
|
// Writes a single word to bit stream. operand_.type determines if the word is
|
|
// encoded and how.
|
|
spv_result_t EncodeNonIdWord(uint32_t word);
|
|
|
|
// Writes both opcode and num_operands as a single code.
|
|
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
|
|
spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode, uint32_t num_operands);
|
|
|
|
// Writes mtf rank to bit stream. |mtf| is used to determine the codec
|
|
// scheme. |fallback_method| is used if no codec defined for |mtf|.
|
|
spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
|
|
uint64_t fallback_method);
|
|
|
|
// Writes id using coding based on mtf associated with the id descriptor.
|
|
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
|
|
spv_result_t EncodeIdWithDescriptor(uint32_t id);
|
|
|
|
// Writes id using coding based on the given |mtf|, which is expected to
|
|
// contain the given |id|.
|
|
spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id);
|
|
|
|
// Writes type id of the current instruction if can't be inferred.
|
|
spv_result_t EncodeTypeId();
|
|
|
|
// Writes result id of the current instruction if can't be inferred.
|
|
spv_result_t EncodeResultId();
|
|
|
|
// Writes ids which are neither type nor result ids.
|
|
spv_result_t EncodeRefId(uint32_t id);
|
|
|
|
// Writes bits to the stream until the beginning of the next byte if the
|
|
// number of bits until the next byte is less than |byte_break_if_less_than|.
|
|
void AddByteBreak(size_t byte_break_if_less_than);
|
|
|
|
// Encodes a literal number operand and writes it to the bit stream.
|
|
spv_result_t EncodeLiteralNumber(const spv_parsed_operand_t& operand);
|
|
|
|
spv_const_markv_encoder_options options_;
|
|
|
|
// Bit stream where encoded instructions are written.
|
|
BitWriterWord64 writer_;
|
|
|
|
// If not nullptr, encoder will write comments.
|
|
std::unique_ptr<CommentLogger> logger_;
|
|
|
|
// If not nullptr, disassembled instruction lines will be written to comments.
|
|
// Format: \n separated instruction lines, no header.
|
|
std::unique_ptr<std::stringstream> disassembly_;
|
|
};
|
|
|
|
// Decodes MARK-V buffers written by MarkvEncoder.
|
|
class MarkvDecoder : public MarkvCodecBase {
|
|
public:
|
|
MarkvDecoder(spv_const_context context,
|
|
const uint8_t* markv_data,
|
|
size_t markv_size_bytes,
|
|
spv_const_markv_decoder_options options)
|
|
: MarkvCodecBase(context, GetValidatorOptions(options)),
|
|
options_(options), reader_(markv_data, markv_size_bytes) {
|
|
(void) options_;
|
|
SetIdBound(1);
|
|
parsed_operands_.reserve(25);
|
|
inst_words_.reserve(25);
|
|
}
|
|
|
|
// Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
|
|
// Can be called only once. Fails if data of wrong format or ends prematurely,
|
|
// of if validation fails.
|
|
spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
|
|
|
|
private:
|
|
// Describes the format of a typed literal number.
|
|
struct NumberType {
|
|
spv_number_kind_t type;
|
|
uint32_t bit_width;
|
|
};
|
|
|
|
// Creates and returns validator options. Returned value owned by the caller.
|
|
static spv_validator_options GetValidatorOptions(
|
|
spv_const_markv_decoder_options options) {
|
|
return options->validate_spirv_binary ?
|
|
spvValidatorOptionsCreate() : nullptr;
|
|
}
|
|
|
|
// Reads a single bit from reader_. The read bit is stored in |bit|.
|
|
// Returns false iff reader_ fails.
|
|
bool ReadBit(bool* bit) {
|
|
uint64_t bits = 0;
|
|
const bool result = reader_.ReadBits(&bits, 1);
|
|
if (result)
|
|
*bit = bits ? true : false;
|
|
return result;
|
|
};
|
|
|
|
// Returns ReadBit bound to the class object.
|
|
std::function<bool(bool*)> GetReadBitCallback() {
|
|
return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
|
|
}
|
|
|
|
// Reads a single non-id word from bit stream. operand_.type determines if
|
|
// the word needs to be decoded and how.
|
|
spv_result_t DecodeNonIdWord(uint32_t* word);
|
|
|
|
// Reads and decodes both opcode and num_operands as a single code.
|
|
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
|
|
spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
|
|
uint32_t* num_operands);
|
|
|
|
// Reads mtf rank from bit stream. |mtf| is used to determine the codec
|
|
// scheme. |fallback_method| is used if no codec defined for |mtf|.
|
|
spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
|
|
uint32_t* rank);
|
|
|
|
// Reads id using coding based on mtf associated with the id descriptor.
|
|
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
|
|
spv_result_t DecodeIdWithDescriptor(uint32_t* id);
|
|
|
|
// Reads id using coding based on the given |mtf|, which is expected to
|
|
// contain the needed |id|.
|
|
spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
|
|
|
|
// Reads type id of the current instruction if can't be inferred.
|
|
spv_result_t DecodeTypeId();
|
|
|
|
// Reads result id of the current instruction if can't be inferred.
|
|
spv_result_t DecodeResultId();
|
|
|
|
// Reads id which is neither type nor result id.
|
|
spv_result_t DecodeRefId(uint32_t* id);
|
|
|
|
// Reads and discards bits until the beginning of the next byte if the
|
|
// number of bits until the next byte is less than |byte_break_if_less_than|.
|
|
bool ReadToByteBreak(size_t byte_break_if_less_than);
|
|
|
|
// Returns instruction words decoded up to this point.
|
|
const uint32_t* GetInstWords() const override {
|
|
return inst_words_.data();
|
|
}
|
|
|
|
// Reads a literal number as it is described in |operand| from the bit stream,
|
|
// decodes and writes it to spirv_.
|
|
spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
|
|
|
|
// Reads instruction from bit stream, decodes and validates it.
|
|
// Decoded instruction is valid until the next call of DecodeInstruction().
|
|
spv_result_t DecodeInstruction();
|
|
|
|
// Read operand from the stream decodes and validates it.
|
|
spv_result_t DecodeOperand(size_t operand_offset,
|
|
const spv_operand_type_t type,
|
|
spv_operand_pattern_t* expected_operands);
|
|
|
|
// Records the numeric type for an operand according to the type information
|
|
// associated with the given non-zero type Id. This can fail if the type Id
|
|
// is not a type Id, or if the type Id does not reference a scalar numeric
|
|
// type. On success, return SPV_SUCCESS and populates the num_words,
|
|
// number_kind, and number_bit_width fields of parsed_operand.
|
|
spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
|
|
uint32_t type_id);
|
|
|
|
// Records the number type for the current instruction, if it generates a
|
|
// type. For types that aren't scalar numbers, record something with number
|
|
// kind SPV_NUMBER_NONE.
|
|
void RecordNumberType();
|
|
|
|
spv_const_markv_decoder_options options_;
|
|
|
|
// Temporary sink where decoded SPIR-V words are written. Once it contains the
|
|
// entire module, the container is moved and returned.
|
|
std::vector<uint32_t> spirv_;
|
|
|
|
// Bit stream containing encoded data.
|
|
BitReaderWord64 reader_;
|
|
|
|
// Temporary storage for operands of the currently parsed instruction.
|
|
// Valid until next DecodeInstruction call.
|
|
std::vector<spv_parsed_operand_t> parsed_operands_;
|
|
|
|
// Temporary storage for current instruction words.
|
|
// Valid until next DecodeInstruction call.
|
|
std::vector<uint32_t> inst_words_;
|
|
|
|
// Maps a type ID to its number type description.
|
|
std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
|
|
|
|
// Maps an ExtInstImport id to the extended instruction type.
|
|
std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
|
|
};
|
|
|
|
void MarkvCodecBase::ProcessCurInstruction() {
|
|
instructions_.emplace_back(new Instruction(&inst_));
|
|
|
|
const SpvOp opcode = SpvOp(inst_.opcode);
|
|
|
|
if (inst_.result_id) {
|
|
id_to_def_instruction_.emplace(inst_.result_id, instructions_.back().get());
|
|
|
|
// Collect ids local to the current function.
|
|
if (cur_function_id_){
|
|
ids_local_to_cur_function_.push_back(inst_.result_id);
|
|
}
|
|
|
|
// Starting new function.
|
|
if (opcode == SpvOpFunction) {
|
|
cur_function_id_ = inst_.result_id;
|
|
cur_function_return_type_ = inst_.type_id;
|
|
multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
|
|
inst_.result_id);
|
|
|
|
// Store function parameter types in a queue, so that we know which types
|
|
// to expect in the following OpFunctionParameter instructions.
|
|
const Instruction* def_inst = FindDef(inst_.words[4]);
|
|
assert(def_inst);
|
|
assert(def_inst->opcode() == SpvOpTypeFunction);
|
|
for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
|
|
remaining_function_parameter_types_.push_back(def_inst->word(i));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove local ids from MTFs if function end.
|
|
if (opcode == SpvOpFunctionEnd) {
|
|
cur_function_id_ = 0;
|
|
for (uint32_t id : ids_local_to_cur_function_)
|
|
multi_mtf_.RemoveFromAll(id);
|
|
ids_local_to_cur_function_.clear();
|
|
assert(remaining_function_parameter_types_.empty());
|
|
}
|
|
|
|
if (!inst_.result_id)
|
|
return;
|
|
|
|
{
|
|
// Save the result ID to type ID mapping.
|
|
// In the grammar, type ID always appears before result ID.
|
|
// A regular value maps to its type. Some instructions (e.g. OpLabel)
|
|
// have no type Id, and will map to 0. The result Id for a
|
|
// type-generating instruction (e.g. OpTypeInt) maps to itself.
|
|
auto insertion_result = id_to_type_id_.emplace(
|
|
inst_.result_id,
|
|
spvOpcodeGeneratesType(SpvOp(inst_.opcode)) ?
|
|
inst_.result_id : inst_.type_id);
|
|
(void)insertion_result;
|
|
assert(insertion_result.second);
|
|
}
|
|
|
|
// Add result_id to MTFs.
|
|
|
|
switch (opcode) {
|
|
case SpvOpTypeFloat:
|
|
case SpvOpTypeInt:
|
|
case SpvOpTypeBool:
|
|
case SpvOpTypeVector:
|
|
case SpvOpTypePointer:
|
|
case SpvOpExtInstImport:
|
|
case SpvOpTypeSampledImage:
|
|
case SpvOpTypeImage:
|
|
case SpvOpTypeSampler:
|
|
multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
|
|
if (spvOpcodeIsComposite(opcode)) {
|
|
multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpLabel) {
|
|
multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpTypeInt) {
|
|
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
|
|
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpTypeFloat) {
|
|
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
|
|
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpTypeBool) {
|
|
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
|
|
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
|
|
}
|
|
|
|
if (opcode == SpvOpTypeVector) {
|
|
const uint32_t component_type_id = inst_.words[2];
|
|
const uint32_t size = inst_.words[3];
|
|
if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat),
|
|
component_type_id)) {
|
|
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
|
|
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt),
|
|
component_type_id)) {
|
|
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
|
|
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool),
|
|
component_type_id)) {
|
|
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
|
|
}
|
|
multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id);
|
|
}
|
|
|
|
if (inst_.opcode == SpvOpTypeFunction) {
|
|
const uint32_t return_type = inst_.words[2];
|
|
multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type);
|
|
multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type),
|
|
inst_.result_id);
|
|
}
|
|
|
|
if (inst_.type_id) {
|
|
const Instruction* type_inst = FindDef(inst_.type_id);
|
|
assert(type_inst);
|
|
|
|
multi_mtf_.Insert(kMtfObject, inst_.result_id);
|
|
|
|
multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id);
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) {
|
|
multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id);
|
|
}
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id))
|
|
multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id);
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id))
|
|
multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id);
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id))
|
|
multi_mtf_.Insert(kMtfComposite, inst_.result_id);
|
|
|
|
if (inst_.opcode == SpvOpConstant) {
|
|
if (multi_mtf_.HasValue(
|
|
GetMtfIdGeneratedByOpcode(SpvOpTypeInt), inst_.type_id)) {
|
|
multi_mtf_.Insert(kMtfConstInteger, inst_.result_id);
|
|
const uint32_t presumed_index = inst_.words[3];
|
|
if (presumed_index <= kMarkvMaxPresumedAccessIndex) {
|
|
const auto result =
|
|
presumed_index_to_id_.emplace(presumed_index, inst_.result_id);
|
|
if (result.second) {
|
|
id_to_presumed_index_.emplace(inst_.result_id, presumed_index);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
switch (type_inst->opcode()) {
|
|
case SpvOpTypeInt:
|
|
case SpvOpTypeBool:
|
|
case SpvOpTypePointer:
|
|
case SpvOpTypeVector:
|
|
case SpvOpTypeImage:
|
|
case SpvOpTypeSampledImage:
|
|
case SpvOpTypeSampler:
|
|
multi_mtf_.Insert(GetMtfIdWithTypeGeneratedByOpcode(
|
|
type_inst->opcode()), inst_.result_id);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
|
|
if (type_inst->opcode() == SpvOpTypeVector) {
|
|
const uint32_t component_type = type_inst->word(2);
|
|
multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type),
|
|
inst_.result_id);
|
|
}
|
|
|
|
if (type_inst->opcode() == SpvOpTypePointer) {
|
|
assert(type_inst->operands().size() > 2);
|
|
assert(type_inst->words().size() > type_inst->operands()[2].offset);
|
|
const uint32_t data_type =
|
|
type_inst->word(type_inst->operands()[2].offset);
|
|
multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id);
|
|
|
|
if (multi_mtf_.HasValue(kMtfTypeComposite, data_type))
|
|
multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id);
|
|
}
|
|
}
|
|
|
|
if (spvOpcodeGeneratesType(opcode)) {
|
|
if (opcode != SpvOpTypeFunction) {
|
|
multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id);
|
|
}
|
|
}
|
|
|
|
const uint32_t descriptor = id_descriptors_.ProcessInstruction(inst_);
|
|
if (model_->DescriptorHasCodingScheme(descriptor))
|
|
multi_mtf_.Insert(GetMtfIdDescriptor(descriptor), inst_.result_id);
|
|
}
|
|
|
|
uint64_t MarkvCodecBase::GetRuleBasedMtf() {
|
|
// This function is only called for id operands (but not result ids).
|
|
assert(spvIsIdType(operand_.type) ||
|
|
operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID);
|
|
assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID);
|
|
|
|
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
|
|
|
// All operand slots which expect label id.
|
|
if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) ||
|
|
(inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) ||
|
|
(inst_.opcode == SpvOpBranch && operand_index_ == 0) ||
|
|
(inst_.opcode == SpvOpBranchConditional &&
|
|
(operand_index_ == 1 || operand_index_ == 2 )) ||
|
|
(inst_.opcode == SpvOpPhi && operand_index_ >= 3 &&
|
|
operand_index_ % 2 == 1) ||
|
|
(inst_.opcode == SpvOpSwitch && operand_index_ > 0)) {
|
|
return kMtfLabel;
|
|
}
|
|
|
|
switch (opcode) {
|
|
case SpvOpFAdd:
|
|
case SpvOpFSub:
|
|
case SpvOpFMul:
|
|
case SpvOpFDiv:
|
|
case SpvOpFRem:
|
|
case SpvOpFMod:
|
|
case SpvOpFNegate: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeFloatScalarOrVector;
|
|
|
|
return GetMtfIdOfType(inst_.type_id);
|
|
}
|
|
|
|
case SpvOpISub:
|
|
case SpvOpIAdd:
|
|
case SpvOpIMul:
|
|
case SpvOpSDiv:
|
|
case SpvOpUDiv:
|
|
case SpvOpSMod:
|
|
case SpvOpUMod:
|
|
case SpvOpSRem:
|
|
case SpvOpSNegate: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeIntScalarOrVector;
|
|
|
|
return kMtfIntScalarOrVector;
|
|
}
|
|
|
|
// TODO(atgoo@github.com) Add OpConvertFToU and other opcodes.
|
|
|
|
case SpvOpFOrdEqual:
|
|
case SpvOpFUnordEqual:
|
|
case SpvOpFOrdNotEqual:
|
|
case SpvOpFUnordNotEqual:
|
|
case SpvOpFOrdLessThan:
|
|
case SpvOpFUnordLessThan:
|
|
case SpvOpFOrdGreaterThan:
|
|
case SpvOpFUnordGreaterThan:
|
|
case SpvOpFOrdLessThanEqual:
|
|
case SpvOpFUnordLessThanEqual:
|
|
case SpvOpFOrdGreaterThanEqual:
|
|
case SpvOpFUnordGreaterThanEqual: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeBoolScalarOrVector;
|
|
if (operand_index_ == 2)
|
|
return kMtfFloatScalarOrVector;
|
|
if (operand_index_ == 3) {
|
|
const uint32_t first_operand_id = GetInstWords()[3];
|
|
const uint32_t first_operand_type =
|
|
id_to_type_id_.at(first_operand_id);
|
|
return GetMtfIdOfType(first_operand_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpVectorShuffle: {
|
|
if (operand_index_ == 0) {
|
|
assert(inst_.num_operands > 4);
|
|
return GetMtfTypeVectorOfSize(inst_.num_operands - 4);
|
|
}
|
|
|
|
assert(inst_.type_id);
|
|
if (operand_index_ == 2 || operand_index_ == 3)
|
|
return GetMtfVectorOfComponentType(
|
|
GetVectorComponentType(inst_.type_id));
|
|
break;
|
|
}
|
|
|
|
case SpvOpVectorTimesScalar: {
|
|
if (operand_index_ == 0) {
|
|
// TODO(atgoo@github.com) Could be narrowed to vector of floats.
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
|
}
|
|
|
|
assert(inst_.type_id);
|
|
if (operand_index_ == 2)
|
|
return GetMtfIdOfType(inst_.type_id);
|
|
if (operand_index_ == 3)
|
|
return GetMtfIdOfType(GetVectorComponentType(inst_.type_id));
|
|
break;
|
|
}
|
|
|
|
case SpvOpDot: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat);
|
|
|
|
assert(inst_.type_id);
|
|
if (operand_index_ == 2)
|
|
return GetMtfVectorOfComponentType(inst_.type_id);
|
|
if (operand_index_ == 3) {
|
|
const uint32_t vector_id = GetInstWords()[3];
|
|
const uint32_t vector_type = id_to_type_id_.at(vector_id);
|
|
return GetMtfIdOfType(vector_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypeVector: {
|
|
if (operand_index_ == 1) {
|
|
return kMtfTypeScalar;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypeMatrix: {
|
|
if (operand_index_ == 1) {
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypePointer: {
|
|
if (operand_index_ == 2) {
|
|
return kMtfTypeNonFunction;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypeStruct: {
|
|
if (operand_index_ >= 1) {
|
|
return kMtfTypeNonFunction;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpTypeFunction: {
|
|
if (operand_index_ == 1) {
|
|
return kMtfTypeNonFunction;
|
|
}
|
|
|
|
if (operand_index_ >= 2) {
|
|
return kMtfTypeNonFunction;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpLoad: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeNonFunction;
|
|
|
|
if (operand_index_ == 2) {
|
|
assert(inst_.type_id);
|
|
return GetMtfPointerToType(inst_.type_id);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpStore: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer);
|
|
if (operand_index_ == 1) {
|
|
const uint32_t pointer_id = GetInstWords()[1];
|
|
const uint32_t pointer_type = id_to_type_id_.at(pointer_id);
|
|
const Instruction* pointer_inst = FindDef(pointer_type);
|
|
assert(pointer_inst);
|
|
assert(pointer_inst->opcode() == SpvOpTypePointer);
|
|
const uint32_t data_type =
|
|
pointer_inst->word(pointer_inst->operands()[2].offset);
|
|
return GetMtfIdOfType(data_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpVariable: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
|
|
break;
|
|
}
|
|
|
|
case SpvOpAccessChain: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
|
|
if (operand_index_ == 2)
|
|
return kMtfTypePointerToComposite;
|
|
if (operand_index_ >= 3)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt);
|
|
break;
|
|
}
|
|
|
|
case SpvOpCompositeConstruct: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeComposite;
|
|
if (operand_index_ >= 2) {
|
|
const uint32_t composite_type = GetInstWords()[1];
|
|
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type))
|
|
return kMtfFloatScalarOrVector;
|
|
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type))
|
|
return kMtfIntScalarOrVector;
|
|
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type))
|
|
return kMtfBoolScalarOrVector;
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpCompositeExtract: {
|
|
if (operand_index_ == 2)
|
|
return kMtfComposite;
|
|
break;
|
|
}
|
|
|
|
case SpvOpConstantComposite: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeComposite;
|
|
if (operand_index_ >= 2) {
|
|
const Instruction* composite_type_inst = FindDef(inst_.type_id);
|
|
assert(composite_type_inst);
|
|
if (composite_type_inst->opcode() == SpvOpTypeVector) {
|
|
return GetMtfIdOfType(composite_type_inst->word(2));
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpExtInst: {
|
|
if (operand_index_ == 2)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport);
|
|
if (operand_index_ >= 4) {
|
|
const uint32_t return_type = GetInstWords()[1];
|
|
const uint32_t ext_inst_type = inst_.ext_inst_type;
|
|
const uint32_t ext_inst_index = GetInstWords()[4];
|
|
// TODO(atgoo@github.com) The list of extended instructions is
|
|
// incomplete. Only common instructions and low-hanging fruits listed.
|
|
if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) {
|
|
switch (ext_inst_index) {
|
|
case GLSLstd450FAbs:
|
|
case GLSLstd450FClamp:
|
|
case GLSLstd450FMax:
|
|
case GLSLstd450FMin:
|
|
case GLSLstd450FMix:
|
|
case GLSLstd450Step:
|
|
case GLSLstd450SmoothStep:
|
|
case GLSLstd450Fma:
|
|
case GLSLstd450Pow:
|
|
case GLSLstd450Exp:
|
|
case GLSLstd450Exp2:
|
|
case GLSLstd450Log:
|
|
case GLSLstd450Log2:
|
|
case GLSLstd450Sqrt:
|
|
case GLSLstd450InverseSqrt:
|
|
case GLSLstd450Fract:
|
|
case GLSLstd450Floor:
|
|
case GLSLstd450Ceil:
|
|
case GLSLstd450Radians:
|
|
case GLSLstd450Degrees:
|
|
case GLSLstd450Sin:
|
|
case GLSLstd450Cos:
|
|
case GLSLstd450Tan:
|
|
case GLSLstd450Sinh:
|
|
case GLSLstd450Cosh:
|
|
case GLSLstd450Tanh:
|
|
case GLSLstd450Asin:
|
|
case GLSLstd450Acos:
|
|
case GLSLstd450Atan:
|
|
case GLSLstd450Atan2:
|
|
case GLSLstd450Asinh:
|
|
case GLSLstd450Acosh:
|
|
case GLSLstd450Atanh:
|
|
case GLSLstd450MatrixInverse:
|
|
case GLSLstd450Cross:
|
|
case GLSLstd450Normalize:
|
|
case GLSLstd450Reflect:
|
|
case GLSLstd450FaceForward:
|
|
return GetMtfIdOfType(return_type);
|
|
case GLSLstd450Length:
|
|
case GLSLstd450Distance:
|
|
case GLSLstd450Refract:
|
|
return kMtfFloatScalarOrVector;
|
|
default:
|
|
break;
|
|
}
|
|
} else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) {
|
|
switch (ext_inst_index) {
|
|
case OpenCLLIB::Fabs:
|
|
case OpenCLLIB::FClamp:
|
|
case OpenCLLIB::Fmax:
|
|
case OpenCLLIB::Fmin:
|
|
case OpenCLLIB::Step:
|
|
case OpenCLLIB::Smoothstep:
|
|
case OpenCLLIB::Fma:
|
|
case OpenCLLIB::Pow:
|
|
case OpenCLLIB::Exp:
|
|
case OpenCLLIB::Exp2:
|
|
case OpenCLLIB::Log:
|
|
case OpenCLLIB::Log2:
|
|
case OpenCLLIB::Sqrt:
|
|
case OpenCLLIB::Rsqrt:
|
|
case OpenCLLIB::Fract:
|
|
case OpenCLLIB::Floor:
|
|
case OpenCLLIB::Ceil:
|
|
case OpenCLLIB::Radians:
|
|
case OpenCLLIB::Degrees:
|
|
case OpenCLLIB::Sin:
|
|
case OpenCLLIB::Cos:
|
|
case OpenCLLIB::Tan:
|
|
case OpenCLLIB::Sinh:
|
|
case OpenCLLIB::Cosh:
|
|
case OpenCLLIB::Tanh:
|
|
case OpenCLLIB::Asin:
|
|
case OpenCLLIB::Acos:
|
|
case OpenCLLIB::Atan:
|
|
case OpenCLLIB::Atan2:
|
|
case OpenCLLIB::Asinh:
|
|
case OpenCLLIB::Acosh:
|
|
case OpenCLLIB::Atanh:
|
|
case OpenCLLIB::Cross:
|
|
case OpenCLLIB::Normalize:
|
|
return GetMtfIdOfType(return_type);
|
|
case OpenCLLIB::Length:
|
|
case OpenCLLIB::Distance:
|
|
return kMtfFloatScalarOrVector;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpFunction: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeReturnedByFunction;
|
|
|
|
if (operand_index_ == 3) {
|
|
const uint32_t return_type = GetInstWords()[1];
|
|
return GetMtfFunctionTypeWithReturnType(return_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpFunctionCall: {
|
|
if (operand_index_ == 0)
|
|
return kMtfTypeReturnedByFunction;
|
|
|
|
if (operand_index_ == 2) {
|
|
const uint32_t return_type = GetInstWords()[1];
|
|
return GetMtfFunctionWithReturnType(return_type);
|
|
}
|
|
|
|
if (operand_index_ >= 3) {
|
|
const uint32_t function_id = GetInstWords()[3];
|
|
const Instruction* function_inst = FindDef(function_id);
|
|
if (!function_inst)
|
|
return kMtfObject;
|
|
|
|
assert(function_inst->opcode() == SpvOpFunction);
|
|
|
|
const uint32_t function_type_id = function_inst->word(4);
|
|
const Instruction* function_type_inst = FindDef(function_type_id);
|
|
assert(function_type_inst);
|
|
assert(function_type_inst->opcode() == SpvOpTypeFunction);
|
|
|
|
const uint32_t argument_type =
|
|
function_type_inst->word(operand_index_);
|
|
return GetMtfIdOfType(argument_type);
|
|
}
|
|
break;
|
|
}
|
|
|
|
case SpvOpReturnValue: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdOfType(cur_function_return_type_);
|
|
break;
|
|
}
|
|
|
|
case SpvOpBranchConditional: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool);
|
|
break;
|
|
}
|
|
|
|
case SpvOpSampledImage: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage);
|
|
if (operand_index_ == 2)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage);
|
|
if (operand_index_ == 3)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler);
|
|
break;
|
|
}
|
|
|
|
case SpvOpImageSampleImplicitLod: {
|
|
if (operand_index_ == 0)
|
|
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
|
if (operand_index_ == 2)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage);
|
|
if (operand_index_ == 3)
|
|
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector);
|
|
break;
|
|
}
|
|
|
|
default:
|
|
break;
|
|
}
|
|
|
|
return kMtfNone;
|
|
}
|
|
|
|
// Writes a non-id operand word of the current instruction.
//
// If the model has a Huffman table for this opcode/operand slot and the word
// is in it, only the Huffman code is written. If the table exists but lacks
// the word, kMarkvNoneOfTheAbove is written first and the word follows in
// fallback form. The fallback form is a varint when the operand type has a
// chunk length, otherwise the raw word.
spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) {
  if (auto* codec =
          model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_)) {
    uint64_t bits = 0;
    size_t num_bits = 0;
    const bool in_table = codec->Encode(word, &bits, &num_bits);
    if (!in_table && !codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) {
      return Diag(SPV_ERROR_INTERNAL)
          << "Non-id word Huffman table for "
          << spvOpcodeString(SpvOp(inst_.opcode))
          << " operand index " << operand_index_
          << " is missing kMarkvNoneOfTheAbove";
    }
    // Either the word itself or the escape symbol.
    writer_.WriteBits(bits, num_bits);
    if (in_table)
      return SPV_SUCCESS;
  }

  const size_t varint_chunk = GetOperandVariableWidthChunkLength(operand_.type);
  if (varint_chunk == 0)
    writer_.WriteUnencoded(word);
  else
    writer_.WriteVariableWidthU32(word, varint_chunk);
  return SPV_SUCCESS;
}
|
|
|
|
// Reads a non-id operand word written by MarkvEncoder::EncodeNonIdWord.
//
// When a Huffman table exists for this opcode/operand slot, a symbol is read
// from it first. Any symbol other than kMarkvNoneOfTheAbove is the word
// itself. Otherwise the word follows in fallback form: a varint when the
// operand type has a chunk length, else the raw word.
spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
  if (auto* codec =
          model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_)) {
    uint64_t symbol = 0;
    if (!codec->DecodeFromStream(GetReadBitCallback(), &symbol))
      return Diag(SPV_ERROR_INVALID_BINARY)
          << "Failed to decode non-id word with Huffman";

    if (symbol != kMarkvNoneOfTheAbove) {
      *word = static_cast<uint32_t>(symbol);
      assert(*word == symbol);
      return SPV_SUCCESS;
    }
    // Escape symbol: the word follows in fallback form.
  }

  const size_t varint_chunk = GetOperandVariableWidthChunkLength(operand_.type);
  if (varint_chunk == 0) {
    if (!reader_.ReadUnencoded(word))
      return Diag(SPV_ERROR_INVALID_BINARY)
          << "Failed to read unencoded non-id word";
    return SPV_SUCCESS;
  }

  if (!reader_.ReadVariableWidthU32(word, varint_chunk))
    return Diag(SPV_ERROR_INVALID_BINARY)
        << "Failed to decode non-id word with varint";
  return SPV_SUCCESS;
}
|
|
|
|
// Writes |opcode| and |num_operands| packed into one word
// (opcode | num_operands << 16). Codecs are tried in this order:
//   1. the Markov-chain Huffman codec conditioned on the previous opcode;
//   2. the global (base-rate) codec, registered under SpvOpNop.
// A codec that does not contain the word emits kMarkvNoneOfTheAbove so the
// decoder knows to move on to the next stage.
//
// Returns:
//   - SPV_SUCCESS when the word was encoded;
//   - SPV_UNSUPPORTED when neither codec contains it, in which case the
//     caller must fall back further;
//   - SPV_ERROR_INTERNAL when a table lacks kMarkvNoneOfTheAbove.
spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(
    uint32_t opcode, uint32_t num_operands) {
  uint64_t bits = 0;
  size_t num_bits = 0;

  const uint32_t word = opcode | (num_operands << 16);

  // First try to use the Markov chain codec.
  auto* codec =
      model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
  if (codec) {
    if (codec->Encode(word, &bits, &num_bits)) {
      // The word was successfully encoded into bits/num_bits.
      writer_.WriteBits(bits, num_bits);
      return SPV_SUCCESS;
    }

    // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
    // and use fallback encoding.
    if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
      return Diag(SPV_ERROR_INTERNAL)
          << "opcode_and_num_operands Huffman table for "
          << spvOpcodeString(GetPrevOpcode())
          << " is missing kMarkvNoneOfTheAbove";
    writer_.WriteBits(bits, num_bits);
  }

  // Fallback to base-rate codec.
  codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
  assert(codec);
  if (codec->Encode(word, &bits, &num_bits)) {
    // The word was successfully encoded into bits/num_bits.
    writer_.WriteBits(bits, num_bits);
    return SPV_SUCCESS;
  }

  // The word is not in the global table either. Write kMarkvNoneOfTheAbove
  // and return SPV_UNSUPPORTED so the caller falls back further.
  if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
    return Diag(SPV_ERROR_INTERNAL)
        << "Global opcode_and_num_operands Huffman table is missing "
        << "kMarkvNoneOfTheAbove";
  writer_.WriteBits(bits, num_bits);
  return SPV_UNSUPPORTED;
}
|
|
|
|
// Reads an opcode and operand count written by
// MarkvEncoder::EncodeOpcodeAndNumOperands.
//
// The two values are packed as opcode in the low 16 bits and num_operands
// above them. The Markov-chain codec for the previous opcode is consulted
// first (if the model has one), then the global codec registered under
// SpvOpNop. Returns SPV_UNSUPPORTED if the global codec also yields
// kMarkvNoneOfTheAbove.
spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
    uint32_t* opcode, uint32_t* num_operands) {
  const auto unpack = [opcode, num_operands](uint64_t packed) {
    *opcode = uint32_t(packed & 0xFFFF);
    *num_operands = uint32_t(packed >> 16);
  };

  uint64_t packed = 0;

  // Stage 1: codec conditioned on the previous opcode.
  if (auto* markov_codec =
          model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode())) {
    if (!markov_codec->DecodeFromStream(GetReadBitCallback(), &packed))
      return Diag(SPV_ERROR_INTERNAL)
          << "Failed to decode opcode_and_num_operands, previous opcode is "
          << spvOpcodeString(GetPrevOpcode());

    if (packed != kMarkvNoneOfTheAbove) {
      unpack(packed);
      return SPV_SUCCESS;
    }
    // Escape symbol: continue with the global codec.
  }

  // Stage 2: global (base-rate) codec.
  auto* global_codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
  assert(global_codec);
  packed = 0;
  if (!global_codec->DecodeFromStream(GetReadBitCallback(), &packed))
    return Diag(SPV_ERROR_INTERNAL)
        << "Failed to decode opcode_and_num_operands with global codec";

  if (packed == kMarkvNoneOfTheAbove)
    return SPV_UNSUPPORTED;

  unpack(packed);
  return SPV_SUCCESS;
}
|
|
|
|
// Writes |rank| (a position in an MTF sequence) using the Huffman codec for
// |mtf|, or the codec for |fallback_method| when |mtf| has none.
//
// Ranks below kMtfSmallestRankEncodedByValue are written as their Huffman
// code. Larger ranks are written as kMtfRankEncodedByValueSignal followed by
// (rank - kMtfSmallestRankEncodedByValue) as a varint with the model's MTF
// rank chunk length.
spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
                                                uint64_t fallback_method) {
  const auto* codec = model_->GetMtfHuffmanCodec(mtf);
  if (!codec) {
    assert(fallback_method != kMtfNone);
    codec = model_->GetMtfHuffmanCodec(fallback_method);
    if (!codec)
      return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank";
  }

  uint64_t bits = 0;
  size_t num_bits = 0;

  if (rank >= kMtfSmallestRankEncodedByValue) {
    // Too large for the table: escape symbol, then the offset as a varint.
    if (!codec->Encode(kMtfRankEncodedByValueSignal, &bits, &num_bits))
      return Diag(SPV_ERROR_INTERNAL)
          << "Failed to encode kMtfRankEncodedByValueSignal";
    writer_.WriteBits(bits, num_bits);
    writer_.WriteVariableWidthU32(rank - kMtfSmallestRankEncodedByValue,
                                  model_->mtf_rank_chunk_length());
    return SPV_SUCCESS;
  }

  if (!codec->Encode(rank, &bits, &num_bits))
    return Diag(SPV_ERROR_INTERNAL)
        << "Failed to encode MTF rank with Huffman";
  writer_.WriteBits(bits, num_bits);
  return SPV_SUCCESS;
}
|
|
|
|
// Reads an MTF rank written by MarkvEncoder::EncodeMtfRankHuffman into
// |rank|. Uses the Huffman codec for |mtf|, or the one for
// |fallback_method| when |mtf| has none.
//
// A kMtfRankEncodedByValueSignal symbol means the rank follows as a varint
// holding (rank - kMtfSmallestRankEncodedByValue). Any other symbol is the
// rank itself.
spv_result_t MarkvDecoder::DecodeMtfRankHuffman(
    uint64_t mtf, uint32_t fallback_method, uint32_t* rank) {
  const auto* codec = model_->GetMtfHuffmanCodec(mtf);
  if (!codec) {
    assert(fallback_method != kMtfNone);
    codec = model_->GetMtfHuffmanCodec(fallback_method);
    if (!codec)
      return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
  }

  uint32_t symbol = 0;
  if (!codec->DecodeFromStream(GetReadBitCallback(), &symbol))
    return Diag(SPV_ERROR_INTERNAL)
        << "Failed to decode MTF rank with Huffman";

  if (symbol != kMtfRankEncodedByValueSignal) {
    // The Huffman symbol is the rank itself.
    assert(symbol < kMtfSmallestRankEncodedByValue);
    *rank = symbol;
    return SPV_SUCCESS;
  }

  if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
    return Diag(SPV_ERROR_INTERNAL)
        << "Failed to decode MTF rank with varint";
  *rank += kMtfSmallestRankEncodedByValue;
  return SPV_SUCCESS;
}
|
|
|
|
// Tries to write |id| as (descriptor, rank within the descriptor's MTF).
//
// Returns SPV_UNSUPPORTED without writing anything if the model has no
// descriptor codec for this opcode/operand slot. Also returns
// SPV_UNSUPPORTED, after writing kMarkvNoneOfTheAbove, if |id| has no
// descriptor or its descriptor is not in the table. In both cases the
// caller must use another method.
spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) {
  auto* codec =
      model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
  if (codec == nullptr)
    return SPV_UNSUPPORTED;

  const uint32_t descriptor = id_descriptors_.GetDescriptor(id);
  uint64_t bits = 0;
  size_t num_bits = 0;

  const bool descriptor_coded =
      descriptor != 0 && codec->Encode(descriptor, &bits, &num_bits);
  if (!descriptor_coded) {
    // Signal the decoder that the fallback method follows.
    if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
      return Diag(SPV_ERROR_INTERNAL)
          << "Descriptor Huffman table for "
          << spvOpcodeString(SpvOp(inst_.opcode))
          << " operand index " << operand_index_
          << " is missing kMarkvNoneOfTheAbove";
    writer_.WriteBits(bits, num_bits);
    return SPV_UNSUPPORTED;
  }

  // Descriptor written; now the id's rank in that descriptor's sequence.
  writer_.WriteBits(bits, num_bits);
  return EncodeExistingId(GetMtfIdDescriptor(descriptor), id);
}
|
|
|
|
// Counterpart of MarkvEncoder::EncodeIdWithDescriptor. Reads a descriptor
// symbol and then the id's rank in that descriptor's MTF sequence.
// Returns SPV_UNSUPPORTED when:
// - no descriptor codec exists for the current opcode/operand index, or
// - the symbol read is kMarkvNoneOfTheAbove.
// In both cases the caller must decode the id with a fallback method.
spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
  auto* codec =
      model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
  if (codec == nullptr) return SPV_UNSUPPORTED;

  uint64_t symbol = 0;
  if (!codec->DecodeFromStream(GetReadBitCallback(), &symbol)) {
    return Diag(SPV_ERROR_INTERNAL)
           << "Failed to decode descriptor with Huffman";
  }

  if (symbol == kMarkvNoneOfTheAbove) return SPV_UNSUPPORTED;

  // A real descriptor was written, so the id is coded through its MTF.
  const uint32_t descriptor = static_cast<uint32_t>(symbol);
  assert(descriptor == symbol);
  assert(descriptor != 0);

  return DecodeExistingId(GetMtfIdDescriptor(descriptor), id);
}
|
|
|
|
// Encodes |id|, which must already be in the MTF sequence |mtf|, as its rank
// in that sequence. Nothing is written when the sequence holds exactly one
// element, because the decoder makes the same inference.
spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) {
  const auto num_elements = multi_mtf_.GetSize(mtf);
  assert(num_elements > 0);
  if (num_elements == 1) return SPV_SUCCESS;

  uint32_t rank = 0;
  if (multi_mtf_.RankFromValue(mtf, id, &rank))
    return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank);

  return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence";
}
|
|
|
|
// Counterpart of MarkvEncoder::EncodeExistingId. Recovers an id from its rank
// in the MTF sequence |mtf|. A single-element sequence has no rank in the
// stream; its only element is taken at rank 1.
spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
  const auto num_elements = multi_mtf_.GetSize(mtf);
  assert(num_elements > 0);
  *id = 0;

  uint32_t rank = 1;
  if (num_elements != 1) {
    const spv_result_t rank_result =
        DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
    if (rank_result != SPV_SUCCESS) return rank_result;
  }

  assert(rank);
  if (!multi_mtf_.ValueFromRank(mtf, rank, id))
    return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";

  return SPV_SUCCESS;
}
|
|
|
|
// Encodes a reference to |id|, i.e. an id operand other than the result id or
// result type id. Methods are tried in this order:
//   1. presumed access-chain index (OpAccessChain operands at index >= 3),
//      preceded by a one-bit flag;
//   2. id descriptor;
//   3. rule-based MTF, or kMtfAll when there is no rule.
// For operands that can be forward declared, an id not yet present in the
// chosen sequence is registered as forward declared and written as rank 0.
spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) {
  // TODO(atgoo@github.com) This might not be needed as EncodeIdWithDescriptor
  // can handle SpvOpAccessChain indices if enough statistics is collected.
  const bool is_access_chain_index =
      inst_.opcode == SpvOpAccessChain && operand_index_ >= 3;
  if (is_access_chain_index) {
    const auto presumed = id_to_presumed_index_.find(id);
    const bool has_presumed_index = presumed != id_to_presumed_index_.end();
    // The flag bit tells the decoder whether a presumed index follows.
    writer_.WriteBits(has_presumed_index ? 1 : 0, 1);
    if (has_presumed_index) {
      writer_.WriteFixedWidth(presumed->second, kMarkvMaxPresumedAccessIndex);
      return SPV_SUCCESS;
    }
  }

  // Anything other than SPV_UNSUPPORTED (success or a real error) is final.
  const spv_result_t descriptor_result = EncodeIdWithDescriptor(id);
  if (descriptor_result != SPV_UNSUPPORTED) return descriptor_result;

  uint64_t mtf = GetRuleBasedMtf();
  const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
      SpvOp(inst_.opcode))(operand_index_);

  if (mtf != kMtfNone && !can_forward_declare) {
    assert(multi_mtf_.HasValue(kMtfAll, id));
    return EncodeExistingId(mtf, id);
  }

  if (mtf == kMtfNone) mtf = kMtfAll;

  uint32_t rank = 0;
  const bool already_known = multi_mtf_.RankFromValue(mtf, id, &rank);
  if (!already_known) {
    // First occurrence of a forward declared id. Register it everywhere the
    // decoder will, and signal it with rank 0.
    multi_mtf_.Insert(kMtfAll, id);
    multi_mtf_.Insert(kMtfForwardDeclared, id);
    if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id);
    rank = 0;
  }

  return EncodeMtfRankHuffman(rank, mtf, kMtfAll);
}
|
|
|
|
// Counterpart of MarkvEncoder::EncodeRefId. Decodes a reference id operand
// into |id|. Methods are tried in this order:
//   1. presumed access-chain index (OpAccessChain operands at index >= 3),
//      gated by a one-bit flag;
//   2. id descriptor;
//   3. rule-based MTF, or kMtfAll when there is no rule.
// For forward-declarable operands, rank 0 means the first reference to a
// not-yet-defined id; a fresh id is issued and registered as forward declared.
spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
  if (inst_.opcode == SpvOpAccessChain && operand_index_ >= 3) {
    uint64_t use_presumed_index_technique = 0;
    if (!reader_.ReadBits(&use_presumed_index_technique, 1))
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to read use_presumed_index_technique flag";

    if (use_presumed_index_technique) {
      uint64_t value = 0;
      if (!reader_.ReadFixedWidth(&value, kMarkvMaxPresumedAccessIndex))
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read presumed_index";

      const uint32_t presumed_index = static_cast<uint32_t>(value);

      const auto it = presumed_index_to_id_.find(presumed_index);
      if (it == presumed_index_to_id_.end()) {
        // The index comes straight from the input stream, so a miss means a
        // corrupt or mismatched binary. Report it instead of asserting, so
        // debug builds do not abort on bad input.
        return Diag(SPV_ERROR_INTERNAL) << "Presumed index id not found";
      }

      *id = it->second;
      return SPV_SUCCESS;
    }
  }

  {
    // Anything other than SPV_UNSUPPORTED (success or a real error) is final.
    const spv_result_t result = DecodeIdWithDescriptor(id);
    if (result != SPV_UNSUPPORTED) return result;
  }

  uint64_t mtf = GetRuleBasedMtf();
  const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
      SpvOp(inst_.opcode))(operand_index_);

  if (mtf != kMtfNone && !can_forward_declare) {
    return DecodeExistingId(mtf, id);
  }

  if (mtf == kMtfNone) mtf = kMtfAll;

  *id = 0;

  uint32_t rank = 0;
  {
    const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
    if (result != SPV_SUCCESS) return result;
  }

  if (rank == 0) {
    // This is the first occurrence of a forward declared id.
    *id = GetIdBound();
    SetIdBound(*id + 1);
    multi_mtf_.Insert(kMtfAll, *id);
    multi_mtf_.Insert(kMtfForwardDeclared, *id);
    if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
  } else {
    if (!multi_mtf_.ValueFromRank(mtf, rank, id))
      return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
  }

  assert(*id);
  return SPV_SUCCESS;
}
|
|
|
|
// Encodes the result type id of the current instruction.
// OpFunctionParameter result types are not written: encoder and decoder both
// consume them, in order, from remaining_function_parameter_types_. All other
// type ids try the id descriptor method first. If that is unsupported, they
// use the rule-based MTF, falling back to kMtfTypeNonFunction.
spv_result_t MarkvEncoder::EncodeTypeId() {
  if (inst_.opcode == SpvOpFunctionParameter) {
    // Validation of the input is optional, so these invariants may not hold.
    // Report them instead of asserting:
    // - in release builds, front() on an empty container is undefined
    //   behavior;
    // - a mismatching type would otherwise decode as the expected type
    //   instead of the one in the input.
    if (remaining_function_parameter_types_.empty())
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "No remaining function parameter types for "
                "OpFunctionParameter";
    if (inst_.type_id != remaining_function_parameter_types_.front())
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "OpFunctionParameter result type " << inst_.type_id
             << " does not match the expected function parameter type "
             << remaining_function_parameter_types_.front();
    remaining_function_parameter_types_.pop_front();
    return SPV_SUCCESS;
  }

  {
    // Try to encode using id descriptor mtfs.
    const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id);
    if (result != SPV_UNSUPPORTED) return result;
    // If can't be done continue with other methods.
  }

  uint64_t mtf = GetRuleBasedMtf();
  assert(!spvOperandCanBeForwardDeclaredFunction(
      SpvOp(inst_.opcode))(operand_index_));

  if (mtf == kMtfNone) {
    mtf = kMtfTypeNonFunction;
    // Function types should have been handled by GetRuleBasedMtf.
    assert(inst_.opcode != SpvOpFunction);
  }

  return EncodeExistingId(mtf, inst_.type_id);
}
|
|
|
|
// Counterpart of MarkvEncoder::EncodeTypeId. Sets inst_.type_id.
// OpFunctionParameter takes its type from remaining_function_parameter_types_
// without reading the stream. Other opcodes try the id descriptor method
// first, then the rule-based MTF, falling back to kMtfTypeNonFunction.
spv_result_t MarkvDecoder::DecodeTypeId() {
  if (inst_.opcode == SpvOpFunctionParameter) {
    // The opcode sequence comes from the (possibly corrupt) input stream, so
    // there may be more OpFunctionParameter instructions than pending
    // parameter types. Calling front() on an empty container is undefined
    // behavior, so report it instead.
    if (remaining_function_parameter_types_.empty())
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "No remaining function parameter types for "
                "OpFunctionParameter";
    inst_.type_id = remaining_function_parameter_types_.front();
    remaining_function_parameter_types_.pop_front();
    return SPV_SUCCESS;
  }

  {
    const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
    if (result != SPV_UNSUPPORTED) return result;
  }

  uint64_t mtf = GetRuleBasedMtf();
  assert(!spvOperandCanBeForwardDeclaredFunction(
      SpvOp(inst_.opcode))(operand_index_));

  if (mtf == kMtfNone) {
    mtf = kMtfTypeNonFunction;
    // Function types should have been handled by GetRuleBasedMtf.
    assert(inst_.opcode != SpvOpFunction);
  }

  return DecodeExistingId(mtf, &inst_.type_id);
}
|
|
|
|
// Encodes the result id of the current instruction. Fresh ids are not written
// at all. When ids are pending in kMtfForwardDeclared, a one-bit flag says
// whether this result id is one of them; if it is, the flag is followed by its
// rank there. Ids that were not forward declared are added to kMtfAll here
// (forward declared ones were added at their first reference).
spv_result_t MarkvEncoder::EncodeResultId() {
  uint32_t rank = 0;

  // With no pending forward declared ids, the decoder knows none can be
  // defined here, so nothing is written.
  if (multi_mtf_.GetSize(kMtfForwardDeclared) != 0) {
    const bool was_forward_declared = multi_mtf_.RankFromValue(
        kMtfForwardDeclared, inst_.result_id, &rank);
    if (!was_forward_declared) {
      rank = 0;
      writer_.WriteBits(0, 1);
    } else {
      // This instruction defines a forward declared id, so the id no longer
      // needs to be tracked in kMtfForwardDeclared.
      if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
        return Diag(SPV_ERROR_INTERNAL)
               << "Failed to remove id from kMtfForwardDeclared";
      writer_.WriteBits(1, 1);
      writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
    }
  }

  if (rank == 0) multi_mtf_.Insert(kMtfAll, inst_.result_id);

  return SPV_SUCCESS;
}
|
|
|
|
// Counterpart of MarkvEncoder::EncodeResultId. Sets inst_.result_id either to
// a previously forward declared id, recovered from kMtfForwardDeclared by
// rank, or to a freshly issued id (the current id bound). Ids that were not
// forward declared are added to kMtfAll.
// Expects inst_.result_id to be 0 on entry.
spv_result_t MarkvDecoder::DecodeResultId() {
  uint32_t rank = 0;

  if (multi_mtf_.GetSize(kMtfForwardDeclared) != 0) {
    // Some ids are pending; a flag bit says whether this is one of them.
    uint64_t id_was_forward_declared = 0;
    if (!reader_.ReadBits(&id_was_forward_declared, 1))
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "Failed to read id_was_forward_declared flag";

    if (id_was_forward_declared) {
      if (!reader_.ReadVariableWidthU32(&rank,
                                        model_->mtf_rank_chunk_length()))
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read MTF rank of forward declared id";

      if (rank != 0) {
        // Recover the forward declared id, then stop tracking it.
        if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
                                      &inst_.result_id))
          return Diag(SPV_ERROR_INTERNAL)
                 << "Forward declared MTF rank is out of bounds";

        if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
          return Diag(SPV_ERROR_INTERNAL)
                 << "Failed to remove id from kMtfForwardDeclared";
      }
    }
  }

  if (inst_.result_id == 0) {
    // Not forward declared: issue a new id.
    inst_.result_id = GetIdBound();
    SetIdBound(inst_.result_id + 1);
  }

  if (rank == 0) multi_mtf_.Insert(kMtfAll, inst_.result_id);

  return SPV_SUCCESS;
}
|
|
|
|
// Encodes a typed literal number operand of the current instruction.
// Literals of up to 32 bits are one word and go through EncodeNonIdWord.
// 64-bit literals are assembled low word first and then written by kind:
// - unsigned: variable-width u64;
// - signed: variable-width s64;
// - floating point: unencoded raw bits.
spv_result_t MarkvEncoder::EncodeLiteralNumber(
    const spv_parsed_operand_t& operand) {
  if (operand.number_bit_width <= 32)
    return EncodeNonIdWord(inst_.words[operand.offset]);

  assert(operand.number_bit_width <= 64);
  const uint64_t low_word = inst_.words[operand.offset];
  const uint64_t high_word = inst_.words[operand.offset + 1];
  const uint64_t word = low_word | (high_word << 32);

  switch (operand.number_kind) {
    case SPV_NUMBER_UNSIGNED_INT:
      writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
      break;
    case SPV_NUMBER_SIGNED_INT: {
      // Reinterpret the bits as signed without relying on implementation-
      // defined conversions.
      int64_t val = 0;
      std::memcpy(&val, &word, sizeof(val));
      writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
                                    model_->s64_block_exponent());
      break;
    }
    case SPV_NUMBER_FLOATING:
      writer_.WriteUnencoded(word);
      break;
    default:
      return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
  }

  return SPV_SUCCESS;
}
|
|
|
|
// Counterpart of MarkvEncoder::EncodeLiteralNumber. Appends the decoded
// literal to inst_words_: one word for literals of up to 32 bits, otherwise
// two words, low word first.
spv_result_t MarkvDecoder::DecodeLiteralNumber(
    const spv_parsed_operand_t& operand) {
  if (operand.number_bit_width <= 32) {
    uint32_t single_word = 0;
    const spv_result_t result = DecodeNonIdWord(&single_word);
    if (result != SPV_SUCCESS) return result;
    inst_words_.push_back(single_word);
    return SPV_SUCCESS;
  }

  assert(operand.number_bit_width <= 64);
  uint64_t word = 0;

  switch (operand.number_kind) {
    case SPV_NUMBER_UNSIGNED_INT:
      if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read literal U64";
      break;
    case SPV_NUMBER_SIGNED_INT: {
      int64_t val = 0;
      if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
                                        model_->s64_block_exponent()))
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read literal S64";
      std::memcpy(&word, &val, sizeof(word));
      break;
    }
    case SPV_NUMBER_FLOATING:
      if (!reader_.ReadUnencoded(&word))
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Failed to read literal F64";
      break;
    default:
      return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
  }

  // Split back into SPIR-V words, low-order word first.
  inst_words_.push_back(static_cast<uint32_t>(word));
  inst_words_.push_back(static_cast<uint32_t>(word >> 32));
  return SPV_SUCCESS;
}
|
|
|
|
void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) {
|
|
const size_t num_bits_to_next_byte =
|
|
GetNumBitsToNextByte(writer_.GetNumBits());
|
|
if (num_bits_to_next_byte == 0 ||
|
|
num_bits_to_next_byte > byte_break_if_less_than)
|
|
return;
|
|
|
|
if (logger_) {
|
|
logger_->AppendWhitespaces(kCommentNumWhitespaces);
|
|
logger_->AppendText("<byte break>");
|
|
}
|
|
|
|
writer_.WriteBits(0, num_bits_to_next_byte);
|
|
}
|
|
|
|
bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
|
|
const size_t num_bits_to_next_byte =
|
|
GetNumBitsToNextByte(reader_.GetNumReadBits());
|
|
if (num_bits_to_next_byte == 0 ||
|
|
num_bits_to_next_byte > byte_break_if_less_than)
|
|
return true;
|
|
|
|
|
|
uint64_t bits = 0;
|
|
if (!reader_.ReadBits(&bits, num_bits_to_next_byte))
|
|
return false;
|
|
|
|
assert(bits == 0);
|
|
if (bits != 0)
|
|
return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
// Encodes one parsed SPIR-V instruction into the MARK-V stream. The steps are:
// 1. update validation state;
// 2. write the opcode and operand count, with a variable-width fallback when
//    the model-based coding is not used;
// 3. write each operand by type: ids through the MTF machinery, literals and
//    strings through their codecs or raw;
// 4. optionally pad to a byte break;
// 5. hand the instruction to ProcessCurInstruction.
spv_result_t MarkvEncoder::EncodeInstruction(
    const spv_parsed_instruction_t& inst) {
  SpvOp opcode = SpvOp(inst.opcode);
  inst_ = inst;

  const spv_result_t validation_result = UpdateValidationState(inst);
  if (validation_result != SPV_SUCCESS) return validation_result;

  LogDisassemblyInstruction();

  const spv_result_t opcode_encoding_result =
      EncodeOpcodeAndNumOperands(opcode, inst.num_operands);
  if (opcode_encoding_result < 0) return opcode_encoding_result;

  if (opcode_encoding_result != SPV_SUCCESS) {
    // Fallback encoding for opcode and num_operands.
    writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length());

    if (!OpcodeHasFixedNumberOfOperands(opcode)) {
      // If the opcode has a variable number of operands, encode the number of
      // operands with the instruction.
      if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);

      writer_.WriteVariableWidthU16(inst.num_operands,
                                    model_->num_operands_chunk_length());
    }
  }

  // Write operands.
  const uint32_t num_operands = inst_.num_operands;
  for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) {
    operand_ = inst_.operands[operand_index_];

    if (logger_) {
      logger_->AppendWhitespaces(kCommentNumWhitespaces);
      logger_->AppendText("<");
      logger_->AppendText(spvOperandTypeStr(operand_.type));
      logger_->AppendText(">");
    }

    switch (operand_.type) {
      case SPV_OPERAND_TYPE_RESULT_ID:
      case SPV_OPERAND_TYPE_TYPE_ID:
      case SPV_OPERAND_TYPE_ID:
      case SPV_OPERAND_TYPE_OPTIONAL_ID:
      case SPV_OPERAND_TYPE_SCOPE_ID:
      case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
        const uint32_t id = inst_.words[operand_.offset];
        if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) {
          const spv_result_t result = EncodeTypeId();
          if (result != SPV_SUCCESS) return result;
        } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) {
          const spv_result_t result = EncodeResultId();
          if (result != SPV_SUCCESS) return result;
        } else {
          const spv_result_t result = EncodeRefId(id);
          if (result != SPV_SUCCESS) return result;
        }

        multi_mtf_.Promote(id);
        break;
      }

      case SPV_OPERAND_TYPE_LITERAL_INTEGER: {
        const spv_result_t result =
            EncodeNonIdWord(inst_.words[operand_.offset]);
        if (result != SPV_SUCCESS) return result;
        break;
      }

      case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
        const spv_result_t result = EncodeLiteralNumber(operand_);
        if (result != SPV_SUCCESS) return result;
        break;
      }

      case SPV_OPERAND_TYPE_LITERAL_STRING: {
        const char* src =
            reinterpret_cast<const char*>(&inst_.words[operand_.offset]);

        // Locate the terminator within the operand's words before touching
        // the string, so no read goes past the operand.
        const size_t length = spv_strnlen_s(src, operand_.num_words * 4);
        if (length == operand_.num_words * 4)
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Failed to find terminal character of literal string";

        auto* codec = model_->GetLiteralStringHuffmanCodec(opcode);
        if (codec) {
          uint64_t bits = 0;
          size_t num_bits = 0;
          const std::string str(src, length);
          if (codec->Encode(str, &bits, &num_bits)) {
            writer_.WriteBits(bits, num_bits);
            break;
          }

          // The string is not in the table: write the sentinel so that the
          // decoder reads the raw characters that follow. If the table lacks
          // the sentinel, fail rather than write whatever the failed Encode
          // left in |bits|/|num_bits|, which would corrupt the stream.
          if (!codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits))
            return Diag(SPV_ERROR_INTERNAL)
                   << "Literal string Huffman table is missing "
                      "kMarkvNoneOfTheAbove";
          writer_.WriteBits(bits, num_bits);
        }

        // Raw fallback: every character including the terminating '\0'.
        for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]);
        break;
      }

      default: {
        for (size_t i = 0; i < operand_.num_words; ++i) {
          const uint32_t word = inst_.words[operand_.offset + i];
          const spv_result_t result = EncodeNonIdWord(word);
          if (result != SPV_SUCCESS) return result;
        }
        break;
      }
    }
  }

  AddByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte);

  if (logger_) {
    logger_->NewLine();
    logger_->NewLine();
  }

  ProcessCurInstruction();

  return SPV_SUCCESS;
}
|
|
|
|
// Decodes a complete MARK-V module into |spirv_binary|. The steps are:
// 1. read and check the MARK-V header;
// 2. rebuild the SPIR-V header words;
// 3. decode instructions until the stated bit length is consumed;
// 4. check that the stream ends exactly there, followed only by zeroes.
// |spirv_binary| is written only on success.
spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
  const bool header_read_success =
      reader_.ReadUnencoded(&header_.magic_number) &&
      reader_.ReadUnencoded(&header_.markv_version) &&
      reader_.ReadUnencoded(&header_.markv_model) &&
      reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
      reader_.ReadUnencoded(&header_.spirv_version) &&
      reader_.ReadUnencoded(&header_.spirv_generator);

  if (!header_read_success)
    return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";

  if (header_.markv_length_in_bits == 0)
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "Header markv_length_in_bits field is zero";

  if (header_.magic_number != kMarkvMagicNumber)
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary has incorrect magic number";

  // TODO(atgoo@github.com): Print version strings.
  if (header_.markv_version != GetMarkvVersion())
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary and the codec have different versions";

  // Pre-size the output using the stated length as a heuristic hint. The
  // header is untrusted input, so cap the reservation: a bogus
  // markv_length_in_bits must not trigger a huge allocation (and a
  // bad_alloc/length_error) before any of the stream is validated. The vector
  // still grows past the cap as needed.
  const size_t kMaxReservedWords = size_t(1) << 20;
  spirv_.reserve(std::min<size_t>(header_.markv_length_in_bits / 2,
                                  kMaxReservedWords));
  spirv_.resize(5, 0);
  spirv_[0] = kSpirvMagicNumber;
  spirv_[1] = header_.spirv_version;
  spirv_[2] = header_.spirv_generator;

  while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
    inst_ = {};
    const spv_result_t decode_result = DecodeInstruction();
    if (decode_result != SPV_SUCCESS) return decode_result;

    const spv_result_t validation_result = UpdateValidationState(inst_);
    if (validation_result != SPV_SUCCESS) return validation_result;
  }

  if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
      !reader_.OnlyZeroesLeft()) {
    return Diag(SPV_ERROR_INVALID_BINARY)
           << "MARK-V binary has wrong stated bit length "
           << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
  }

  // Decoding is finished, so GetIdBound() holds the final id bound for the
  // SPIR-V header.
  spirv_[3] = GetIdBound();

  *spirv_binary = std::move(spirv_);
  return SPV_SUCCESS;
}
|
|
|
|
// TODO(atgoo@github.com): The implementation borrows heavily from
|
|
// Parser::parseOperand.
|
|
// Consider coupling them together in some way once MARK-V codec is more mature.
|
|
// For now it's better to keep the code independent for experimentation
|
|
// purposes.
|
|
// Decodes one operand of type |type| for the current instruction.
// - Appends the operand's words to inst_words_.
// - Fills operand_ (offset |operand_offset|) and appends it to
//   parsed_operands_.
// - Pushes any operand types implied by the decoded value (extended
//   instruction operands, enum/mask parameters) onto |expected_operands|.
spv_result_t MarkvDecoder::DecodeOperand(
    size_t operand_offset,
    const spv_operand_type_t type,
    spv_operand_pattern_t* expected_operands) {
  const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);

  memset(&operand_, 0, sizeof(operand_));

  assert((operand_offset >> 16) == 0);
  operand_.offset = static_cast<uint16_t>(operand_offset);
  operand_.type = type;

  // Set default values, may be updated later.
  operand_.number_kind = SPV_NUMBER_NONE;
  operand_.number_bit_width = 0;

  const size_t first_word_index = inst_words_.size();

  switch (type) {
    case SPV_OPERAND_TYPE_RESULT_ID: {
      const spv_result_t result = DecodeResultId();
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(inst_.result_id);
      SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
      multi_mtf_.Promote(inst_.result_id);
      break;
    }

    case SPV_OPERAND_TYPE_TYPE_ID: {
      const spv_result_t result = DecodeTypeId();
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(inst_.type_id);
      SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
      multi_mtf_.Promote(inst_.type_id);
      break;
    }

    case SPV_OPERAND_TYPE_ID:
    case SPV_OPERAND_TYPE_OPTIONAL_ID:
    case SPV_OPERAND_TYPE_SCOPE_ID:
    case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
      uint32_t id = 0;
      const spv_result_t result = DecodeRefId(&id);
      if (result != SPV_SUCCESS) return result;

      if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";

      if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
        operand_.type = SPV_OPERAND_TYPE_ID;

        if (opcode == SpvOpExtInst && operand_.offset == 3) {
          // The current word is the extended instruction set id.
          // Set the extended instruction set type for the current
          // instruction.
          auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
          if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
            return Diag(SPV_ERROR_INVALID_ID)
                   << "OpExtInst set id " << id
                   << " does not reference an OpExtInstImport result Id";
          }
          inst_.ext_inst_type = ext_inst_type_iter->second;
        }
      }

      inst_words_.push_back(id);
      SetIdBound(std::max(GetIdBound(), id + 1));
      multi_mtf_.Promote(id);
      break;
    }

    case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
      uint32_t word = 0;
      const spv_result_t result = DecodeNonIdWord(&word);
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(word);

      assert(SpvOpExtInst == opcode);
      assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
      spv_ext_inst_desc ext_inst;
      if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Invalid extended instruction number: " << word;
      spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
      break;
    }

    case SPV_OPERAND_TYPE_LITERAL_INTEGER:
    case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
      // These are regular single-word literal integer operands.
      // Post-parsing validation should check the range of the parsed value.
      operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
      // It turns out they are always unsigned integers!
      operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
      operand_.number_bit_width = 32;

      uint32_t word = 0;
      const spv_result_t result = DecodeNonIdWord(&word);
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(word);
      break;
    }

    case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
    case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
      operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
      if (opcode == SpvOpSwitch) {
        // The literal operands have the same type as the value
        // referenced by the selector Id.
        const uint32_t selector_id = inst_words_.at(1);
        const auto type_id_iter = id_to_type_id_.find(selector_id);
        if (type_id_iter == id_to_type_id_.end() ||
            type_id_iter->second == 0) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid OpSwitch: selector id " << selector_id
                 << " has no type";
        }
        uint32_t type_id = type_id_iter->second;

        if (selector_id == type_id) {
          // Recall that by convention, a result ID that is a type definition
          // maps to itself.
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid OpSwitch: selector id " << selector_id
                 << " is a type, not a value";
        }
        if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
          return error;
        if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
            operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid OpSwitch: selector id " << selector_id
                 << " is not a scalar integer";
        }
      } else {
        assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
        // The literal number type is determined by the type Id for the
        // constant.
        assert(inst_.type_id);
        if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
          return error;
      }

      if (auto error = DecodeLiteralNumber(operand_)) return error;

      break;
    }

    case SPV_OPERAND_TYPE_LITERAL_STRING:
    case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
      operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
      std::vector<char> str;
      auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);

      if (codec) {
        std::string decoded_string;
        const bool huffman_result =
            codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
        // The bits come from the input stream, so a failure here means a
        // corrupt binary. Report it instead of asserting, so debug builds do
        // not abort on bad input.
        if (!huffman_result)
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Failed to read literal string";

        if (decoded_string != "kMarkvNoneOfTheAbove") {
          std::copy(decoded_string.begin(), decoded_string.end(),
                    std::back_inserter(str));
          str.push_back('\0');
        }
      }

      // The loop is expected to terminate once we encounter '\0' or exhaust
      // the bit stream.
      if (str.empty()) {
        while (true) {
          char ch = 0;
          if (!reader_.ReadUnencoded(&ch))
            return Diag(SPV_ERROR_INVALID_BINARY)
                   << "Failed to read literal string";

          str.push_back(ch);

          if (ch == '\0') break;
        }
      }

      // Pad to a whole number of SPIR-V words.
      while (str.size() % 4 != 0) str.push_back('\0');

      inst_words_.resize(inst_words_.size() + str.size() / 4);
      std::memcpy(&inst_words_[first_word_index], str.data(), str.size());

      if (SpvOpExtInstImport == opcode) {
        // Record the extended instruction type for the ID for this import.
        // There is only one string literal argument to OpExtInstImport,
        // so it's sufficient to guard this just on the opcode.
        const spv_ext_inst_type_t ext_inst_type =
            spvExtInstImportTypeGet(str.data());
        if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid extended instruction import '" << str.data()
                 << "'";
        }
        // We must have parsed a valid result ID. It's a condition
        // of the grammar, and we only accept non-zero result Ids.
        assert(inst_.result_id);
        const bool inserted =
            import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
                .second;
        (void)inserted;
        assert(inserted);
      }
      break;
    }

    case SPV_OPERAND_TYPE_CAPABILITY:
    case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
    case SPV_OPERAND_TYPE_EXECUTION_MODEL:
    case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
    case SPV_OPERAND_TYPE_MEMORY_MODEL:
    case SPV_OPERAND_TYPE_EXECUTION_MODE:
    case SPV_OPERAND_TYPE_STORAGE_CLASS:
    case SPV_OPERAND_TYPE_DIMENSIONALITY:
    case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
    case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
    case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
    case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
    case SPV_OPERAND_TYPE_LINKAGE_TYPE:
    case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
    case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
    case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
    case SPV_OPERAND_TYPE_DECORATION:
    case SPV_OPERAND_TYPE_BUILT_IN:
    case SPV_OPERAND_TYPE_GROUP_OPERATION:
    case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
    case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
      // A single word that is a plain enum value.
      uint32_t word = 0;
      const spv_result_t result = DecodeNonIdWord(&word);
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(word);

      // Map an optional operand type to its corresponding concrete type.
      if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
        operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;

      spv_operand_desc entry;
      if (grammar_.lookupOperand(type, word, &entry)) {
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Invalid " << spvOperandTypeStr(operand_.type)
               << " operand: " << word;
      }

      // Prepare to accept operands to this operand, if needed.
      spvPushOperandTypes(entry->operandTypes, expected_operands);
      break;
    }

    case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
    case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
    case SPV_OPERAND_TYPE_LOOP_CONTROL:
    case SPV_OPERAND_TYPE_IMAGE:
    case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
    case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
    case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
      // This operand is a mask.
      uint32_t word = 0;
      const spv_result_t result = DecodeNonIdWord(&word);
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(word);

      // Map an optional operand type to its corresponding concrete type.
      if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
        operand_.type = SPV_OPERAND_TYPE_IMAGE;
      else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
        operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;

      // Check validity of set mask bits. Also prepare for operands for those
      // masks if they have any. To get operand order correct, scan from
      // MSB to LSB since we can only prepend operands to a pattern.
      // The only case in the grammar where you have more than one mask bit
      // having an operand is for image operands. See SPIR-V 3.14 Image
      // Operands.
      uint32_t remaining_word = word;
      for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
        if (remaining_word & mask) {
          spv_operand_desc entry;
          if (grammar_.lookupOperand(type, mask, &entry)) {
            return Diag(SPV_ERROR_INVALID_BINARY)
                   << "Invalid " << spvOperandTypeStr(operand_.type)
                   << " operand: " << word << " has invalid mask component "
                   << mask;
          }
          remaining_word ^= mask;
          spvPushOperandTypes(entry->operandTypes, expected_operands);
        }
      }
      if (word == 0) {
        // An all-zeroes mask *might* also be valid.
        spv_operand_desc entry;
        if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
          // Prepare for its operands, if any.
          spvPushOperandTypes(entry->operandTypes, expected_operands);
        }
      }
      break;
    }
    default:
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "Internal error: Unhandled operand type: " << type;
  }

  operand_.num_words = uint16_t(inst_words_.size() - first_word_index);

  assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(operand_.type));
  assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(operand_.type));

  parsed_operands_.push_back(operand_);

  return SPV_SUCCESS;
}
|
|
|
|
// Decodes the next instruction from the MARK-V stream.
//
// Rebuilds the instruction's SPIR-V words in inst_words_ and its operand
// descriptions in parsed_operands_, points inst_ at them, appends the words to
// spirv_, then updates decoder bookkeeping (RecordNumberType,
// ProcessCurInstruction) and skips to the next byte break.
//
// Returns SPV_SUCCESS, the (negative) error from opcode decoding, the error
// from DecodeOperand, or SPV_ERROR_INVALID_BINARY if the stream is malformed.
spv_result_t MarkvDecoder::DecodeInstruction() {
  parsed_operands_.clear();
  inst_words_.clear();

  // Opcode/num_words placeholder, the word will be filled in later.
  inst_words_.push_back(0);

  bool num_operands_still_unknown = true;
  {
    uint32_t opcode = 0;
    uint32_t num_operands = 0;

    const spv_result_t opcode_decoding_result =
        DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
    if (opcode_decoding_result < 0)
      return opcode_decoding_result;

    if (opcode_decoding_result == SPV_SUCCESS) {
      inst_.num_operands = static_cast<uint16_t>(num_operands);
      num_operands_still_unknown = false;
    } else {
      if (!reader_.ReadVariableWidthU32(
          &opcode, model_->opcode_chunk_length())) {
        return Diag(SPV_ERROR_INVALID_BINARY)
            << "Failed to read opcode of instruction";
      }
    }

    // A SPIR-V opcode occupies the low 16 bits of the first instruction word.
    // Reject larger values rather than truncating them, which could silently
    // turn garbage into a different, valid opcode.
    if (opcode > 0xFFFFu) {
      return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode: " << opcode;
    }

    inst_.opcode = static_cast<uint16_t>(opcode);
  }

  const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);

  spv_opcode_desc opcode_desc;
  if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
    return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
  }

  // The operand pattern is consumed from the back, so the grammar's operand
  // types are stored in reverse order.
  spv_operand_pattern_t expected_operands;
  expected_operands.reserve(opcode_desc->numTypes);
  for (auto i = 0; i < opcode_desc->numTypes; i++) {
    expected_operands.push_back(
        opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
  }

  if (num_operands_still_unknown) {
    if (!OpcodeHasFixedNumberOfOperands(opcode)) {
      if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
                                        model_->num_operands_chunk_length()))
        return Diag(SPV_ERROR_INVALID_BINARY)
            << "Failed to read num_operands of instruction";
    } else {
      inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
    }
  }

  for (operand_index_ = 0;
       operand_index_ < static_cast<size_t>(inst_.num_operands);
       ++operand_index_) {
    // num_operands may have been read from the stream. If the grammar has no
    // operands left, the input is malformed. Taking an operand from an empty
    // pattern is undefined behavior, so report an error instead of relying on
    // an assert that disappears in release builds.
    if (expected_operands.empty()) {
      return Diag(SPV_ERROR_INVALID_BINARY)
          << "Too many operands for instruction: " << inst_.num_operands;
    }

    const spv_operand_type_t type =
        spvTakeFirstMatchableOperand(&expected_operands);

    const size_t operand_offset = inst_words_.size();

    const spv_result_t decode_result = DecodeOperand(
        operand_offset, type, &expected_operands);

    if (decode_result != SPV_SUCCESS)
      return decode_result;
  }

  assert(inst_.num_operands == parsed_operands_.size());

  // The word count is a 16-bit field of the first instruction word. Refuse to
  // emit an instruction whose count would be truncated.
  if (inst_words_.size() > 0xFFFFu) {
    return Diag(SPV_ERROR_INVALID_BINARY)
        << "Instruction is too long: " << inst_words_.size() << " words";
  }

  // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
  // next DecodeInstruction call).
  inst_.words = inst_words_.data();
  inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
  inst_.num_words = static_cast<uint16_t>(inst_words_.size());
  inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));

  std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));

  assert(inst_.num_words == std::accumulate(
      parsed_operands_.begin(), parsed_operands_.end(), 1,
      [](int num_words, const spv_parsed_operand_t& operand) {
        return num_words + operand.num_words;
      }) && "num_words in instruction doesn't correspond to the sum of "
      "num_words in the operands");

  RecordNumberType();
  ProcessCurInstruction();

  if (!ReadToByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte))
    return Diag(SPV_ERROR_INVALID_BINARY)
        << "Failed to read to byte break";

  return SPV_SUCCESS;
}
|
|
|
|
// Fills in the number kind, bit width and word count of a literal operand
// whose type is |type_id|. It uses the info that RecordNumberType() stored in
// type_id_to_number_type_info_ for earlier type declarations.
//
// Returns SPV_ERROR_INVALID_BINARY in three cases: |type_id| was never
// recorded as a type, it is not a scalar numeric type, or its width needs more
// words than the 16-bit num_words field can hold. |parsed_operand| is only
// modified on success.
spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
    spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
  assert(type_id != 0);
  auto type_info_iter = type_id_to_number_type_info_.find(type_id);
  if (type_info_iter == type_id_to_number_type_info_.end()) {
    return Diag(SPV_ERROR_INVALID_BINARY)
        << "Type Id " << type_id << " is not a type";
  }

  const NumberType& info = type_info_iter->second;
  if (info.type == SPV_NUMBER_NONE) {
    // This is a valid type, but for something other than a scalar number.
    return Diag(SPV_ERROR_INVALID_BINARY)
        << "Type Id " << type_id << " is not a scalar numeric type";
  }

  // Round up the word count. The quotient + remainder form avoids the
  // unsigned wrap-around of (bit_width + 31) / 32 when a malformed stream
  // declares an enormous width.
  const uint32_t num_words =
      info.bit_width / 32 + (info.bit_width % 32 != 0 ? 1 : 0);
  if (num_words > 0xFFFFu) {
    return Diag(SPV_ERROR_INVALID_BINARY)
        << "Type Id " << type_id << " has unsupported bit width "
        << info.bit_width;
  }

  parsed_operand->number_kind = info.type;
  parsed_operand->number_bit_width = info.bit_width;
  parsed_operand->num_words = static_cast<uint16_t>(num_words);
  return SPV_SUCCESS;
}
|
|
|
|
void MarkvDecoder::RecordNumberType() {
|
|
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
|
|
if (spvOpcodeGeneratesType(opcode)) {
|
|
NumberType info = {SPV_NUMBER_NONE, 0};
|
|
if (SpvOpTypeInt == opcode) {
|
|
info.bit_width = inst_.words[inst_.operands[1].offset];
|
|
info.type = inst_.words[inst_.operands[2].offset] ?
|
|
SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
|
|
} else if (SpvOpTypeFloat == opcode) {
|
|
info.bit_width = inst_.words[inst_.operands[1].offset];
|
|
info.type = SPV_NUMBER_FLOATING;
|
|
}
|
|
// The *result* Id of a type generating instruction is the type Id.
|
|
type_id_to_number_type_info_[inst_.result_id] = info;
|
|
}
|
|
}
|
|
|
|
// spvBinaryParse header callback. |user_data| is the MarkvEncoder passed to
// spvBinaryParse; the module header fields are forwarded to it unchanged.
spv_result_t EncodeHeader(
    void* user_data, spv_endianness_t endian, uint32_t magic,
    uint32_t version, uint32_t generator, uint32_t id_bound,
    uint32_t schema) {
  // void* -> T* is a static_cast; reinterpret_cast is not needed here.
  MarkvEncoder* encoder = static_cast<MarkvEncoder*>(user_data);
  return encoder->EncodeHeader(
      endian, magic, version, generator, id_bound, schema);
}
|
|
|
|
// spvBinaryParse instruction callback. |user_data| is the MarkvEncoder passed
// to spvBinaryParse; each parsed instruction is forwarded to it.
spv_result_t EncodeInstruction(
    void* user_data, const spv_parsed_instruction_t* inst) {
  // void* -> T* is a static_cast; reinterpret_cast is not needed here.
  MarkvEncoder* encoder = static_cast<MarkvEncoder*>(user_data);
  return encoder->EncodeInstruction(*inst);
}
|
|
|
|
} // namespace
|
|
|
|
// Encodes the SPIR-V module |spirv_words| (|spirv_num_words| words) into
// MARK-V. On success the result is stored in *markv_binary.
//
// If |comments| is non-null, the encoder's comment logger is enabled, the
// module disassembly is attached to the encoder, and the collected comments
// are returned in *comments.
//
// If |diagnostic| is non-null, it is reset to nullptr and receives messages
// emitted through the context's consumer.
//
// Returns SPV_ERROR_INVALID_BINARY in any of these cases: bad magic number,
// bad header, failed disassembly, or a failed encoding pass.
spv_result_t spvSpirvToMarkv(spv_const_context context,
                             const uint32_t* spirv_words,
                             const size_t spirv_num_words,
                             spv_const_markv_encoder_options options,
                             spv_markv_binary* markv_binary,
                             spv_text* comments, spv_diagnostic* diagnostic) {
  spv_context_t hijack_context = *context;
  if (diagnostic) {
    *diagnostic = nullptr;
    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
  }

  spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words};

  spv_endianness_t endian;
  spv_position_t position = {};
  if (spvBinaryEndianness(&spirv_binary, &endian)) {
    return DiagnosticStream(position, hijack_context.consumer,
                            SPV_ERROR_INVALID_BINARY)
        << "Invalid SPIR-V magic number.";
  }

  spv_header_t header;
  if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) {
    return DiagnosticStream(position, hijack_context.consumer,
                            SPV_ERROR_INVALID_BINARY)
        << "Invalid SPIR-V header.";
  }

  MarkvEncoder encoder(&hijack_context, options);

  if (comments) {
    encoder.CreateCommentsLogger();

    spv_text raw_text = nullptr;
    const spv_result_t disassemble_result = spvBinaryToText(
        &hijack_context, spirv_words, spirv_num_words,
        SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &raw_text, nullptr);

    // Take ownership right away so the disassembly is released on every
    // path, including when building the std::string below throws.
    std::unique_ptr<spv_text_t, decltype(&spvTextDestroy)> text(
        raw_text, &spvTextDestroy);

    if (disassemble_result != SPV_SUCCESS) {
      return DiagnosticStream(position, hijack_context.consumer,
                              SPV_ERROR_INVALID_BINARY)
          << "Failed to disassemble SPIR-V binary.";
    }

    assert(text);
    encoder.SetDisassembly(std::string(text->str, text->length));
  }

  if (spvBinaryParse(
      &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader,
      EncodeInstruction, diagnostic) != SPV_SUCCESS) {
    return DiagnosticStream(position, hijack_context.consumer,
                            SPV_ERROR_INVALID_BINARY)
        << "Unable to encode to MARK-V.";
  }

  if (comments)
    *comments = CreateSpvText(encoder.GetComments());

  *markv_binary = encoder.GetMarkvBinary();
  return SPV_SUCCESS;
}
|
|
|
|
// Decodes |markv_size_bytes| bytes of MARK-V at |markv_data| back into SPIR-V.
// On success *spirv_binary receives a newly allocated spv_binary whose code
// buffer was allocated with new[] and the struct itself with new.
//
// If |diagnostic| is non-null, it is reset to nullptr and receives messages
// emitted through the context's consumer. The comments parameter is currently
// unused.
//
// Returns SPV_ERROR_INVALID_BINARY if the decoder fails.
spv_result_t spvMarkvToSpirv(spv_const_context context,
                             const uint8_t* markv_data,
                             size_t markv_size_bytes,
                             spv_const_markv_decoder_options options,
                             spv_binary* spirv_binary,
                             spv_text* /* comments */,
                             spv_diagnostic* diagnostic) {
  spv_position_t position = {};
  spv_context_t hijack_context = *context;
  if (diagnostic) {
    *diagnostic = nullptr;
    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
  }

  MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options);

  std::vector<uint32_t> words;

  if (decoder.DecodeModule(&words) != SPV_SUCCESS) {
    return DiagnosticStream(position, hijack_context.consumer,
                            SPV_ERROR_INVALID_BINARY)
        << "Unable to decode MARK-V.";
  }

  assert(!words.empty());

  // Allocate the word buffer first and keep it owned until the wrapper struct
  // exists, so a failing allocation cannot leak the other object. Ownership
  // is handed over with new[]/new to match how binaries are destroyed.
  std::unique_ptr<uint32_t[]> code(new uint32_t[words.size()]);
  std::memcpy(code.get(), words.data(), words.size() * sizeof(uint32_t));

  *spirv_binary = new spv_binary_t();
  (*spirv_binary)->code = code.release();
  (*spirv_binary)->wordCount = words.size();

  return SPV_SUCCESS;
}
|
|
|
|
// Frees a MARK-V binary together with its data buffer (released with
// delete[]). A null |binary| is ignored.
void spvMarkvBinaryDestroy(spv_markv_binary binary) {
  if (binary != nullptr) {
    delete[] binary->data;
    delete binary;
  }
}
|
|
|
|
// Allocates a default-initialized MARK-V encoder options object. The caller
// owns it and releases it with spvMarkvEncoderOptionsDestroy.
spv_markv_encoder_options spvMarkvEncoderOptionsCreate() {
  auto* const encoder_options = new spv_markv_encoder_options_t;
  return encoder_options;
}
|
|
|
|
// Releases encoder options obtained from spvMarkvEncoderOptionsCreate.
// Deleting a null pointer is a no-op, so nullptr is accepted.
void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) {
  delete options;
}
|
|
|
|
// Allocates a default-initialized MARK-V decoder options object. The caller
// owns it and releases it with spvMarkvDecoderOptionsDestroy.
spv_markv_decoder_options spvMarkvDecoderOptionsCreate() {
  auto* const decoder_options = new spv_markv_decoder_options_t;
  return decoder_options;
}
|
|
|
|
// Releases decoder options obtained from spvMarkvDecoderOptionsCreate.
// Deleting a null pointer is a no-op, so nullptr is accepted.
void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) {
  delete options;
}
|