SPIRV-Tools/source/util/huffman_codec.h
Andrey Tuganov 40a2829611 Added Huffman codec to utils
Attached ids to Huffman nodes for deterministic internal node
comparison.
2017-06-29 14:51:01 -04:00

300 lines
9.1 KiB
C++

// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Contains utils for generating and applying Huffman coding schemes.
#ifndef LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
#define LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iomanip>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
namespace spvutils {
// Used to generate and apply a Huffman coding scheme.
// |Val| is the type of variable being encoded (for example a string or a
// literal).
template <class Val>
class HuffmanCodec {
  struct Node;

 public:
  // Creates a Huffman codec from a histogram mapping each value to its count.
  // Histogram counts must not be zero (checked with an assert).
  // An empty histogram yields an empty codec: the encoding table is empty,
  // Encode() and DecodeFromStream() return false and the Print functions
  // produce no output.
  explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
    if (hist.empty()) return;

    // Every merge below consumes two nodes and creates one, so N leaves end
    // up in a tree of 2 * N - 1 nodes.
    all_nodes_.reserve(2 * hist.size());

    // The queue is sorted in ascending order by weight (or by node id if
    // weights are equal), i.e. top() is always the lightest node.
    using NodeCompare = bool (*)(const Node*, const Node*);
    std::vector<Node*> queue_storage;
    queue_storage.reserve(hist.size());
    std::priority_queue<Node*, std::vector<Node*>, NodeCompare> queue(
        LeftIsBigger, std::move(queue_storage));

    // Put all leaves in the queue. Leaves get their ids in histogram (map)
    // iteration order.
    for (const auto& pair : hist) {
      Node* leaf = CreateNode();
      leaf->val = pair.first;
      leaf->weight = pair.second;
      assert(leaf->weight);
      queue.push(leaf);
    }

    // Form the tree by repeatedly combining the two subtrees with the least
    // weight and pushing the root of the combined tree back in the queue.
    while (queue.size() > 1) {
      Node* right = queue.top();
      queue.pop();
      Node* left = queue.top();
      queue.pop();

      Node* parent = CreateNode();
      // Weights are accumulated in 64 bits: a 32-bit sum would wrap around
      // for large counts and distort the shape of the tree.
      parent->weight = right->weight + left->weight;
      parent->left = left;
      parent->right = right;
      queue.push(parent);
    }

    // The last remaining node is the root of the complete Huffman tree.
    root_ = queue.top();

    // Traverse the tree and form encoding table.
    CreateEncodingTable();
  }

  // Prints the Huffman tree in the following format:
  // w------w------'x'
  //        w------'y'
  // Where w stands for the weight of the node.
  // Right tree branches appear above left branches. Taking the right path
  // adds 1 to the code, taking the left adds 0.
  // Prints nothing if the codec is empty.
  void PrintTree(std::ostream& out) const {
    PrintTreeInternal(out, root_, 0);
  }

  // Traverses the tree breadth-first and prints the Huffman table: value,
  // optionally node weight, and code (as a string of '0' and '1', first
  // branch taken from the root first) for every leaf.
  // Prints nothing if the codec is empty.
  void PrintTable(std::ostream& out, bool print_weights = true) const {
    if (!root_) return;

    std::queue<std::pair<const Node*, std::string>> pending;
    pending.emplace(root_, "");

    while (!pending.empty()) {
      const Node* node = pending.front().first;
      const std::string code = std::move(pending.front().second);
      pending.pop();

      if (!node->right && !node->left) {
        out << node->val;
        if (print_weights) out << " " << node->weight;
        out << " " << code << std::endl;
        continue;
      }

      if (node->left) pending.emplace(node->left, code + "0");
      if (node->right) pending.emplace(node->right, code + "1");
    }
  }

  // Returns the Huffman table. The table was built at construction time,
  // this function just returns a const reference.
  // Each entry maps a value to {bits, num_bits}; bit i of |bits| (counting
  // from the least significant bit) is the i-th branch taken from the root,
  // 1 for right and 0 for left.
  const std::unordered_map<Val, std::pair<uint64_t, size_t>>&
  GetEncodingTable() const {
    return encoding_table_;
  }

  // Encodes |val| and stores its Huffman code in the lower |num_bits| of
  // |bits|. Returns false if |val| is not in the Huffman table, in which case
  // |bits| and |num_bits| are left untouched.
  bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const {
    const auto it = encoding_table_.find(val);
    if (it == encoding_table_.end()) return false;
    *bits = it->second.first;
    *num_bits = it->second.second;
    return true;
  }

  // Reads bits one-by-one using callback |read_bit| until a match is found.
  // Matching value is stored in |val|. Returns false if |read_bit| terminates
  // before a code was matched, or if the codec is empty.
  // |read_bit| has type bool func(bool* bit). When called, the next bit is
  // stored in |bit|. |read_bit| returns false if the stream terminates
  // prematurely.
  // If the tree consists of a single leaf, no bits are read.
  bool DecodeFromStream(const std::function<bool(bool*)>& read_bit,
                        Val* val) const {
    const Node* node = root_;
    // An empty codec has no tree to walk and can match nothing.
    if (!node) return false;

    while (node->left || node->right) {
      bool go_right = false;
      if (!read_bit(&go_right)) return false;
      node = go_right ? node->right : node->left;
      // Internal nodes are always created with both children.
      assert(node);
    }

    *val = node->val;
    return true;
  }

 private:
  // Huffman tree node.
  struct Node {
    Val val = Val();
    // Leaf weights are the 32-bit histogram counts; internal nodes hold the
    // sum of their children, which needs more than 32 bits in general.
    uint64_t weight = 0;
    // Ids are issued sequentially starting from 1. Ids are used as an ordering
    // tie-breaker, to make sure that the ordering (and resulting coding scheme)
    // are consistent across multiple platforms.
    uint32_t id = 0;
    Node* left = nullptr;
    Node* right = nullptr;
  };

  // Returns true if |left| has bigger weight than |right|. Node ids are
  // used as tie-breaker.
  static bool LeftIsBigger(const Node* left, const Node* right) {
    if (left->weight == right->weight) {
      assert(left->id != right->id);
      return left->id > right->id;
    }
    return left->weight > right->weight;
  }

  // Writes |weight| to |out| left-justified in a field of |width| characters
  // padded with '-'. A temporary stream is used so that the formatting flags
  // of |out| are not modified.
  static void PrintWeightLabel(std::ostream& out, uint64_t weight,
                               size_t width) {
    std::stringstream label;
    label << std::setfill('-') << std::left << std::setw(width) << weight;
    out << label.str();
  }

  // Prints subtree (helper function used by PrintTree).
  // |depth| is the depth of |node| and determines the indentation of its left
  // branch; the right branch continues the current output line.
  static void PrintTreeInternal(std::ostream& out, const Node* node,
                                size_t depth) {
    if (!node) return;

    const size_t kTextFieldWidth = 7;

    if (!node->right && !node->left) {
      out << node->val << std::endl;
      return;
    }

    if (node->right) {
      PrintWeightLabel(out, node->right->weight, kTextFieldWidth);
      PrintTreeInternal(out, node->right, depth + 1);
    }

    if (node->left) {
      out << std::string(depth * kTextFieldWidth, ' ');
      PrintWeightLabel(out, node->left->weight, kTextFieldWidth);
      PrintTreeInternal(out, node->left, depth + 1);
    }
  }

  // Traverses the Huffman tree breadth-first and saves paths to the leaves as
  // bit sequences to encoding_table_. Must only be called when root_ is set.
  void CreateEncodingTable() {
    struct Context {
      Context(const Node* in_node, uint64_t in_bits, size_t in_depth)
          : node(in_node), bits(in_bits), depth(in_depth) {}
      const Node* node;
      // Path from the root to |node|; only the lower |depth| bits are used.
      uint64_t bits;
      size_t depth;
    };

    // Codes are stored in a uint64_t, so no leaf may be deeper than this.
    const size_t kMaxCodeLength = 64;

    std::queue<Context> pending;
    pending.emplace(root_, 0, 0);

    while (!pending.empty()) {
      // Copy before pop(): pop() destroys the front element.
      const Context context = pending.front();
      pending.pop();

      const Node* node = context.node;

      if (!node->right && !node->left) {
        auto insertion_result = encoding_table_.emplace(
            node->val,
            std::pair<uint64_t, size_t>(context.bits, context.depth));
        assert(insertion_result.second);
        (void)insertion_result;
        continue;
      }

      // Children are one level deeper, and the shift below is only defined
      // for shift amounts smaller than the width of uint64_t.
      assert(context.depth < kMaxCodeLength);
      (void)kMaxCodeLength;

      if (node->left)
        pending.emplace(node->left, context.bits, context.depth + 1);
      if (node->right)
        pending.emplace(node->right, context.bits | (1ULL << context.depth),
                        context.depth + 1);
    }
  }

  // Creates new Huffman tree node and stores it in the deleter array.
  // The returned pointer stays valid for the lifetime of all_nodes_.
  Node* CreateNode() {
    // Wrap the node in a unique_ptr before touching the vector so that it is
    // not leaked if growing the vector throws.
    std::unique_ptr<Node> node(new Node());
    node->id = next_node_id_++;
    all_nodes_.push_back(std::move(node));
    return all_nodes_.back().get();
  }

  // Huffman tree root. Null if the codec was created from an empty histogram.
  Node* root_ = nullptr;

  // Huffman tree deleter.
  std::vector<std::unique_ptr<Node>> all_nodes_;

  // Encoding table value -> {bits, num_bits}.
  // Huffman codes are expected to never exceed 64 bit length; this is
  // asserted when the table is built.
  std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;

  // Next node id issued by CreateNode();
  uint32_t next_node_id_ = 1;
};
} // namespace spvutils
#endif // LIBSPIRV_UTIL_HUFFMAN_CODEC_H_