Make the histogram clustering function more generic.

Change the template parameter to be the histogram class
instead of the alphabet size of the histogram.
This commit is contained in:
Zoltan Szabadka 2014-10-28 13:36:21 +01:00
parent c6c08e492e
commit f321ba1964

View File

@ -20,6 +20,7 @@
#include <math.h> #include <math.h>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <algorithm>
#include <complex> #include <complex>
#include <map> #include <map>
#include <set> #include <set>
@ -59,8 +60,8 @@ inline double ClusterCostDiff(int size_a, int size_b) {
// Computes the bit cost reduction by combining out[idx1] and out[idx2] and if // Computes the bit cost reduction by combining out[idx1] and out[idx2] and if
// it is below a threshold, stores the pair (idx1, idx2) in the *pairs heap. // it is below a threshold, stores the pair (idx1, idx2) in the *pairs heap.
template<int kSize> template<typename HistogramType>
void CompareAndPushToHeap(const Histogram<kSize>* out, void CompareAndPushToHeap(const HistogramType* out,
const int* cluster_size, const int* cluster_size,
int idx1, int idx2, int idx1, int idx2,
std::vector<HistogramPair>* pairs) { std::vector<HistogramPair>* pairs) {
@ -90,7 +91,7 @@ void CompareAndPushToHeap(const Histogram<kSize>* out,
} else { } else {
double threshold = pairs->empty() ? 1e99 : double threshold = pairs->empty() ? 1e99 :
std::max(0.0, (*pairs)[0].cost_diff); std::max(0.0, (*pairs)[0].cost_diff);
Histogram<kSize> combo = out[idx1]; HistogramType combo = out[idx1];
combo.AddHistogram(out[idx2]); combo.AddHistogram(out[idx2]);
double cost_combo = PopulationCost(combo); double cost_combo = PopulationCost(combo);
if (cost_combo < threshold - p.cost_diff) { if (cost_combo < threshold - p.cost_diff) {
@ -105,8 +106,8 @@ void CompareAndPushToHeap(const Histogram<kSize>* out,
} }
} }
template<int kSize> template<typename HistogramType>
void HistogramCombine(Histogram<kSize>* out, void HistogramCombine(HistogramType* out,
int* cluster_size, int* cluster_size,
int* symbols, int* symbols,
int symbols_size, int symbols_size,
@ -178,22 +179,22 @@ void HistogramCombine(Histogram<kSize>* out,
// Histogram refinement // Histogram refinement
// What is the bit cost of moving histogram from cur_symbol to candidate. // What is the bit cost of moving histogram from cur_symbol to candidate.
template<int kSize> template<typename HistogramType>
double HistogramBitCostDistance(const Histogram<kSize>& histogram, double HistogramBitCostDistance(const HistogramType& histogram,
const Histogram<kSize>& candidate) { const HistogramType& candidate) {
if (histogram.total_count_ == 0) { if (histogram.total_count_ == 0) {
return 0.0; return 0.0;
} }
Histogram<kSize> tmp = histogram; HistogramType tmp = histogram;
tmp.AddHistogram(candidate); tmp.AddHistogram(candidate);
return PopulationCost(tmp) - candidate.bit_cost_; return PopulationCost(tmp) - candidate.bit_cost_;
} }
// Find the best 'out' histogram for each of the 'in' histograms. // Find the best 'out' histogram for each of the 'in' histograms.
// Note: we assume that out[]->bit_cost_ is already up-to-date. // Note: we assume that out[]->bit_cost_ is already up-to-date.
template<int kSize> template<typename HistogramType>
void HistogramRemap(const Histogram<kSize>* in, int in_size, void HistogramRemap(const HistogramType* in, int in_size,
Histogram<kSize>* out, int* symbols) { HistogramType* out, int* symbols) {
std::set<int> all_symbols; std::set<int> all_symbols;
for (int i = 0; i < in_size; ++i) { for (int i = 0; i < in_size; ++i) {
all_symbols.insert(symbols[i]); all_symbols.insert(symbols[i]);
@ -224,10 +225,10 @@ void HistogramRemap(const Histogram<kSize>* in, int in_size,
// Reorder histograms in *out so that the new symbols in *symbols come in // Reorder histograms in *out so that the new symbols in *symbols come in
// increasing order. // increasing order.
template<int kSize> template<typename HistogramType>
void HistogramReindex(std::vector<Histogram<kSize> >* out, void HistogramReindex(std::vector<HistogramType>* out,
std::vector<int>* symbols) { std::vector<int>* symbols) {
std::vector<Histogram<kSize> > tmp(*out); std::vector<HistogramType> tmp(*out);
std::map<int, int> new_index; std::map<int, int> new_index;
int next_index = 0; int next_index = 0;
for (int i = 0; i < symbols->size(); ++i) { for (int i = 0; i < symbols->size(); ++i) {
@ -246,11 +247,11 @@ void HistogramReindex(std::vector<Histogram<kSize> >* out,
// Clusters similar histograms in 'in' together, the selected histograms are // Clusters similar histograms in 'in' together, the selected histograms are
// placed in 'out', and for each index in 'in', *histogram_symbols will // placed in 'out', and for each index in 'in', *histogram_symbols will
// indicate which of the 'out' histograms is the best approximation. // indicate which of the 'out' histograms is the best approximation.
template<int kSize> template<typename HistogramType>
void ClusterHistograms(const std::vector<Histogram<kSize> >& in, void ClusterHistograms(const std::vector<HistogramType>& in,
int num_contexts, int num_blocks, int num_contexts, int num_blocks,
int max_histograms, int max_histograms,
std::vector<Histogram<kSize> >* out, std::vector<HistogramType>* out,
std::vector<int>* histogram_symbols) { std::vector<int>* histogram_symbols) {
const int in_size = num_contexts * num_blocks; const int in_size = num_contexts * num_blocks;
std::vector<int> cluster_size(in_size, 1); std::vector<int> cluster_size(in_size, 1);