Added Markov chain analysis to stats

Added data structure to SpirvStats which is used to collect statistics
on opcodes following other opcodes.

Added a simple analysis print-out to spirv-stats.
This commit is contained in:
Andrey Tuganov 2017-04-20 15:32:38 -04:00 committed by David Neto
parent bad90d9f12
commit 87a3f651e2
7 changed files with 261 additions and 22 deletions

View File

@ -34,20 +34,27 @@ using libspirv::SpirvStats;
namespace {
struct StatsContext {
SpirvStats* stats;
// Opcodes of already processed instructions in the order as they appear in
// the module.
std::vector<uint32_t> opcodes;
};
// Collects statistics from SPIR-V header (version, generator).
spv_result_t ProcessHeader(
void* user_data, spv_endianness_t /* endian */, uint32_t /* magic */,
uint32_t version, uint32_t generator, uint32_t /* id_bound */,
uint32_t /* schema */) {
SpirvStats* stats =
reinterpret_cast<libspirv::SpirvStats*>(user_data);
++stats->version_hist[version];
++stats->generator_hist[generator];
StatsContext* ctx = reinterpret_cast<StatsContext*>(user_data);
++ctx->stats->version_hist[version];
++ctx->stats->generator_hist[generator];
return SPV_SUCCESS;
}
// Collects OpCapability statistics.
void ProcessCapability(SpirvStats* stats,
void ProcessCapability(StatsContext* ctx,
const spv_parsed_instruction_t* inst) {
if (static_cast<SpvOp>(inst->opcode) != SpvOpCapability) return;
assert(inst->num_operands == 1);
@ -55,28 +62,44 @@ void ProcessCapability(SpirvStats* stats,
assert(operand.num_words == 1);
assert(operand.offset < inst->num_words);
const uint32_t capability = inst->words[operand.offset];
++stats->capability_hist[capability];
++ctx->stats->capability_hist[capability];
}
// Collects OpExtension statistics.
void ProcessExtension(SpirvStats* stats,
void ProcessExtension(StatsContext* ctx,
const spv_parsed_instruction_t* inst) {
if (static_cast<SpvOp>(inst->opcode) != SpvOpExtension) return;
const std::string extension = libspirv::GetExtensionString(inst);
++stats->extension_hist[extension];
++ctx->stats->extension_hist[extension];
}
// Collects OpCode statistics.
void ProcessOpcode(StatsContext* ctx,
const spv_parsed_instruction_t* inst) {
const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
++ctx->stats->opcode_hist[opcode];
auto opcode_it = ctx->opcodes.rbegin();
auto step_it = ctx->stats->opcode_markov_hist.begin();
for (; opcode_it != ctx->opcodes.rend() &&
step_it != ctx->stats->opcode_markov_hist.end();
++opcode_it, ++step_it) {
auto& hist = (*step_it)[*opcode_it];
++hist[opcode];
}
}
// Collects opcode usage statistics and calls other collectors.
spv_result_t ProcessInstruction(
void* user_data, const spv_parsed_instruction_t* inst) {
SpirvStats* stats =
reinterpret_cast<libspirv::SpirvStats*>(user_data);
StatsContext* ctx = reinterpret_cast<StatsContext*>(user_data);
ProcessOpcode(ctx, inst);
ProcessCapability(ctx, inst);
ProcessExtension(ctx, inst);
const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
++stats->opcode_hist[opcode];
ProcessCapability(stats, inst);
ProcessExtension(stats, inst);
ctx->opcodes.push_back(opcode);
return SPV_SUCCESS;
}
@ -86,28 +109,30 @@ spv_result_t ProcessInstruction(
namespace libspirv {
spv_result_t AggregateStats(
const spv_context_t& context, const uint32_t* words, const size_t num_words,
const spv_context_t& spv_context, const uint32_t* words, const size_t num_words,
spv_diagnostic* pDiagnostic, SpirvStats* stats) {
spv_const_binary_t binary = {words, num_words};
spv_endianness_t endian;
spv_position_t position = {};
if (spvBinaryEndianness(&binary, &endian)) {
return libspirv::DiagnosticStream(position, context.consumer,
return libspirv::DiagnosticStream(position, spv_context.consumer,
SPV_ERROR_INVALID_BINARY)
<< "Invalid SPIR-V magic number.";
}
spv_header_t header;
if (spvBinaryHeaderGet(&binary, endian, &header)) {
return libspirv::DiagnosticStream(position, context.consumer,
return libspirv::DiagnosticStream(position, spv_context.consumer,
SPV_ERROR_INVALID_BINARY)
<< "Invalid SPIR-V header.";
}
return spvBinaryParse(&context, stats, words, num_words,
ProcessHeader, ProcessInstruction,
pDiagnostic);
StatsContext stats_context;
stats_context.stats = stats;
return spvBinaryParse(&spv_context, &stats_context, words, num_words,
ProcessHeader, ProcessInstruction, pDiagnostic);
}
} // namespace libspirv

View File

@ -38,6 +38,20 @@ struct SpirvStats {
// Opcode histogram, SpvOpXXX -> count.
std::unordered_map<uint32_t, uint32_t> opcode_hist;
// Used to collect statistics on opcodes triggering other opcodes.
// Container scheme: gap between instructions -> cue opcode -> later opcode
// -> count.
// For example opcode_markov_hist[2][OpFMul][OpFAdd] corresponds to
// the number of times an OpMul appears, followed by 2 other instructions,
// followed by OpFAdd.
// opcode_markov_hist[0][OpFMul][OpFAdd] corresponds to how many times
// OpFMul appears, directly followed by OpFAdd.
// The size of the outer std::vector also serves as an input parameter,
// determining how many steps will be collected.
// I.e. do opcode_markov_hist.resize(1) to collect data for one step only.
std::vector<std::unordered_map<uint32_t,
std::unordered_map<uint32_t, uint32_t>>> opcode_markov_hist;
};
// Aggregates existing |stats| with new stats extracted from |binary|.

View File

@ -235,4 +235,85 @@ OpMemoryModel Logical GLSL450
EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpExtension));
}
TEST(AggregateStats, OpcodeMarkovHistogram) {
const std::string code1 = R"(
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_NV_viewport_array2"
OpMemoryModel Logical GLSL450
)";
const std::string code2 = R"(
OpCapability Addresses
OpCapability Kernel
OpCapability GenericPointer
OpCapability Linkage
OpMemoryModel Physical32 OpenCL
%i32 = OpTypeInt 32 1
%u32 = OpTypeInt 32 0
%f32 = OpTypeFloat 32
)";
SpirvStats stats;
stats.opcode_markov_hist.resize(2);
CompileAndAggregateStats(code1, &stats);
ASSERT_EQ(2u, stats.opcode_markov_hist.size());
EXPECT_EQ(2u, stats.opcode_markov_hist[0].size());
EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpCapability).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size());
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel));
EXPECT_EQ(1u, stats.opcode_markov_hist[1].size());
EXPECT_EQ(2u, stats.opcode_markov_hist[1].at(SpvOpCapability).size());
EXPECT_EQ(
1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension));
EXPECT_EQ(
1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel));
CompileAndAggregateStats(code2, &stats);
ASSERT_EQ(2u, stats.opcode_markov_hist.size());
EXPECT_EQ(4u, stats.opcode_markov_hist[0].size());
EXPECT_EQ(3u, stats.opcode_markov_hist[0].at(SpvOpCapability).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpMemoryModel).size());
EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).size());
EXPECT_EQ(
4u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpMemoryModel));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpMemoryModel).at(SpvOpTypeInt));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeInt));
EXPECT_EQ(
1u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeFloat));
EXPECT_EQ(3u, stats.opcode_markov_hist[1].size());
EXPECT_EQ(4u, stats.opcode_markov_hist[1].at(SpvOpCapability).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpMemoryModel).size());
EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpTypeInt).size());
EXPECT_EQ(
2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpCapability));
EXPECT_EQ(
1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension));
EXPECT_EQ(
2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel));
EXPECT_EQ(
1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpTypeInt));
EXPECT_EQ(
1u, stats.opcode_markov_hist[1].at(SpvOpMemoryModel).at(SpvOpTypeInt));
EXPECT_EQ(
1u, stats.opcode_markov_hist[1].at(SpvOpTypeInt).at(SpvOpTypeFloat));
}
} // namespace

View File

@ -140,4 +140,33 @@ TEST(StatsAnalyzer, Opcode) {
EXPECT_EQ(expected_output, output);
}
TEST(StatsAnalyzer, OpcodeMarkov) {
SpirvStats stats;
FillDefaultStats(&stats);
stats.opcode_hist[SpvOpFMul] = 400;
stats.opcode_hist[SpvOpFAdd] = 200;
stats.opcode_hist[SpvOpFSub] = 400;
stats.opcode_markov_hist.resize(1);
auto& hist = stats.opcode_markov_hist[0];
hist[SpvOpFMul][SpvOpFAdd] = 100;
hist[SpvOpFMul][SpvOpFSub] = 300;
hist[SpvOpFAdd][SpvOpFMul] = 100;
hist[SpvOpFAdd][SpvOpFAdd] = 100;
StatsAnalyzer analyzer(stats);
std::stringstream ss;
analyzer.WriteOpcodeMarkov(ss);
const std::string output = ss.str();
const std::string expected_output =
"FMul -> FSub 75% (base rate 40%, pair occurrences 300)\n"
"FMul -> FAdd 25% (base rate 20%, pair occurrences 100)\n"
"FAdd -> FAdd 50% (base rate 20%, pair occurrences 100)\n"
"FAdd -> FMul 50% (base rate 40%, pair occurrences 100)\n";
EXPECT_EQ(expected_output, output);
}
} // namespace

View File

@ -105,13 +105,13 @@ int main(int argc, char** argv) {
return return_code;
}
std::cerr << "Processing " << paths.size()
<< " files..." << std::endl;
std::cerr << "Processing " << paths.size() << " files..." << std::endl;
ScopedContext ctx(SPV_ENV_UNIVERSAL_1_1);
SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler);
libspirv::SpirvStats stats;
stats.opcode_markov_hist.resize(1);
for (size_t index = 0; index < paths.size(); ++index) {
const size_t kMilestonePeriod = 1000;
@ -148,5 +148,8 @@ int main(int argc, char** argv) {
out << std::endl;
analyzer.WriteOpcode(out);
out << std::endl;
analyzer.WriteOpcodeMarkov(out);
return 0;
}

View File

@ -74,6 +74,11 @@ std::unordered_map<Key, double> GetPrevalence(
return GetRecall(hist, total);
}
// Writes |freq| to |out| sorted by frequency in the following format:
// LABEL3 70%
// LABEL1 20%
// LABEL2 10%
// |label_from_key| is used to convert |Key| to label.
template <class Key>
void WriteFreq(std::ostream& out, const std::unordered_map<Key, double>& freq,
std::string (*label_from_key)(Key)) {
@ -90,6 +95,26 @@ void WriteFreq(std::ostream& out, const std::unordered_map<Key, double>& freq,
}
}
// Writes |hist| to |out| sorted by count in the following format:
// LABEL3 100
// LABEL1 50
// LABEL2 10
// |label_from_key| is used to convert |Key| to label.
template <class Key>
void WriteHist(std::ostream& out, const std::unordered_map<Key, uint32_t>& hist,
std::string (*label_from_key)(Key)) {
std::vector<std::pair<Key, uint32_t>> sorted_hist(hist.begin(), hist.end());
std::sort(sorted_hist.begin(), sorted_hist.end(),
[](const std::pair<Key, uint32_t>& left,
const std::pair<Key, uint32_t>& right) {
return left.second > right.second;
});
for (const auto& pair : sorted_hist) {
out << label_from_key(pair.first) << " " << pair.second << std::endl;
}
}
} // namespace
StatsAnalyzer::StatsAnalyzer(const SpirvStats& stats) : stats_(stats) {
@ -125,3 +150,59 @@ void StatsAnalyzer::WriteOpcode(std::ostream& out) {
out << "Total unique opcodes used: " << opcode_freq_.size() << std::endl;
WriteFreq(out, opcode_freq_, GetOpcodeString);
}
void StatsAnalyzer::WriteOpcodeMarkov(std::ostream& out) {
if (stats_.opcode_markov_hist.empty())
return;
const std::unordered_map<uint32_t, std::unordered_map<uint32_t, uint32_t>>&
cue_to_hist = stats_.opcode_markov_hist[0];
// Sort by prevalence of the opcodes in opcode_freq_ (descending).
std::vector<std::pair<uint32_t, std::unordered_map<uint32_t, uint32_t>>>
sorted_cue_to_hist(cue_to_hist.begin(), cue_to_hist.end());
std::sort(sorted_cue_to_hist.begin(), sorted_cue_to_hist.end(),
[this](
const std::pair<uint32_t,
std::unordered_map<uint32_t, uint32_t>>& left,
const std::pair<uint32_t,
std::unordered_map<uint32_t, uint32_t>>& right) {
const double lf = opcode_freq_[left.first];
const double rf = opcode_freq_[right.first];
if (lf == rf)
return right.first > left.first;
return lf > rf;
});
for (const auto& kv : sorted_cue_to_hist) {
const uint32_t cue = kv.first;
const double kFrequentEnoughToAnalyze = 0.0001;
if (opcode_freq_[cue] < kFrequentEnoughToAnalyze) continue;
const std::unordered_map<uint32_t, uint32_t>& hist = kv.second;
uint32_t total = 0;
for (const auto& pair : hist) {
total += pair.second;
}
std::vector<std::pair<uint32_t, uint32_t>>
sorted_hist(hist.begin(), hist.end());
std::sort(sorted_hist.begin(), sorted_hist.end(),
[](const std::pair<uint32_t, uint32_t>& left,
const std::pair<uint32_t, uint32_t>& right) {
if (left.second == right.second)
return right.first > left.first;
return left.second > right.second;
});
for (const auto& pair : sorted_hist) {
const double prior = opcode_freq_[pair.first];
const double posterior =
static_cast<double>(pair.second) / static_cast<double>(total);
out << GetOpcodeString(cue) << " -> " << GetOpcodeString(pair.first)
<< " " << posterior * 100 << "% (base rate " << prior * 100
<< "%, pair occurrences " << pair.second << ")" << std::endl;
}
}
}

View File

@ -23,12 +23,18 @@ class StatsAnalyzer {
public:
explicit StatsAnalyzer(const libspirv::SpirvStats& stats);
// Writes respective histograms to |out|.
void WriteVersion(std::ostream& out);
void WriteGenerator(std::ostream& out);
void WriteCapability(std::ostream& out);
void WriteExtension(std::ostream& out);
void WriteOpcode(std::ostream& out);
// Writes first order Markov analysis to |out|.
// stats_.opcode_markov_hist needs to contain raw data for at least one
// level.
void WriteOpcodeMarkov(std::ostream& out);
private:
const libspirv::SpirvStats& stats_;