diff --git a/source/spirv_stats.cpp b/source/spirv_stats.cpp index e240b3d00..19da719e1 100644 --- a/source/spirv_stats.cpp +++ b/source/spirv_stats.cpp @@ -34,20 +34,27 @@ using libspirv::SpirvStats; namespace { +struct StatsContext { + SpirvStats* stats; + + // Opcodes of already processed instructions in the order as they appear in + // the module. + std::vector opcodes; +}; + // Collects statistics from SPIR-V header (version, generator). spv_result_t ProcessHeader( void* user_data, spv_endianness_t /* endian */, uint32_t /* magic */, uint32_t version, uint32_t generator, uint32_t /* id_bound */, uint32_t /* schema */) { - SpirvStats* stats = - reinterpret_cast(user_data); - ++stats->version_hist[version]; - ++stats->generator_hist[generator]; + StatsContext* ctx = reinterpret_cast(user_data); + ++ctx->stats->version_hist[version]; + ++ctx->stats->generator_hist[generator]; return SPV_SUCCESS; } // Collects OpCapability statistics. -void ProcessCapability(SpirvStats* stats, +void ProcessCapability(StatsContext* ctx, const spv_parsed_instruction_t* inst) { if (static_cast(inst->opcode) != SpvOpCapability) return; assert(inst->num_operands == 1); @@ -55,28 +62,44 @@ void ProcessCapability(SpirvStats* stats, assert(operand.num_words == 1); assert(operand.offset < inst->num_words); const uint32_t capability = inst->words[operand.offset]; - ++stats->capability_hist[capability]; + ++ctx->stats->capability_hist[capability]; } // Collects OpExtension statistics. -void ProcessExtension(SpirvStats* stats, +void ProcessExtension(StatsContext* ctx, const spv_parsed_instruction_t* inst) { if (static_cast(inst->opcode) != SpvOpExtension) return; const std::string extension = libspirv::GetExtensionString(inst); - ++stats->extension_hist[extension]; + ++ctx->stats->extension_hist[extension]; +} + +// Collects OpCode statistics. +void ProcessOpcode(StatsContext* ctx, + const spv_parsed_instruction_t* inst) { + const SpvOp opcode = static_cast(inst->opcode); + ++ctx->stats->opcode_hist[opcode]; + + auto opcode_it = ctx->opcodes.rbegin(); + auto step_it = ctx->stats->opcode_markov_hist.begin(); + for (; opcode_it != ctx->opcodes.rend() && + step_it != ctx->stats->opcode_markov_hist.end(); + ++opcode_it, ++step_it) { + auto& hist = (*step_it)[*opcode_it]; + ++hist[opcode]; + } } // Collects opcode usage statistics and calls other collectors. spv_result_t ProcessInstruction( void* user_data, const spv_parsed_instruction_t* inst) { - SpirvStats* stats = - reinterpret_cast(user_data); + StatsContext* ctx = reinterpret_cast(user_data); + + ProcessOpcode(ctx, inst); + ProcessCapability(ctx, inst); + ProcessExtension(ctx, inst); const SpvOp opcode = static_cast(inst->opcode); - ++stats->opcode_hist[opcode]; - - ProcessCapability(stats, inst); - ProcessExtension(stats, inst); + ctx->opcodes.push_back(opcode); return SPV_SUCCESS; } @@ -86,28 +109,30 @@ spv_result_t ProcessInstruction( namespace libspirv { spv_result_t AggregateStats( - const spv_context_t& context, const uint32_t* words, const size_t num_words, + const spv_context_t& spv_context, const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, SpirvStats* stats) { spv_const_binary_t binary = {words, num_words}; spv_endianness_t endian; spv_position_t position = {}; if (spvBinaryEndianness(&binary, &endian)) { - return libspirv::DiagnosticStream(position, context.consumer, + return libspirv::DiagnosticStream(position, spv_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V magic number."; } spv_header_t header; if (spvBinaryHeaderGet(&binary, endian, &header)) { - return libspirv::DiagnosticStream(position, context.consumer, + return libspirv::DiagnosticStream(position, spv_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V header."; } - return spvBinaryParse(&context, stats, words, num_words, - ProcessHeader, ProcessInstruction, - pDiagnostic); + StatsContext stats_context; + stats_context.stats = stats; + + return spvBinaryParse(&spv_context, &stats_context, words, num_words, + ProcessHeader, ProcessInstruction, pDiagnostic); } } // namespace libspirv diff --git a/source/spirv_stats.h b/source/spirv_stats.h index 862cb828d..d63992496 100644 --- a/source/spirv_stats.h +++ b/source/spirv_stats.h @@ -38,6 +38,20 @@ struct SpirvStats { // Opcode histogram, SpvOpXXX -> count. std::unordered_map opcode_hist; + + // Used to collect statistics on opcodes triggering other opcodes. + // Container scheme: gap between instructions -> cue opcode -> later opcode + // -> count. + // For example opcode_markov_hist[2][OpFMul][OpFAdd] corresponds to + // the number of times an OpMul appears, followed by 2 other instructions, + // followed by OpFAdd. + // opcode_markov_hist[0][OpFMul][OpFAdd] corresponds to how many times + // OpFMul appears, directly followed by OpFAdd. + // The size of the outer std::vector also serves as an input parameter, + // determining how many steps will be collected. + // I.e. do opcode_markov_hist.resize(1) to collect data for one step only. + std::vector>> opcode_markov_hist; }; // Aggregates existing |stats| with new stats extracted from |binary|. diff --git a/test/stats/stats_aggregate_test.cpp b/test/stats/stats_aggregate_test.cpp index 1fc3e71b7..463b7820e 100644 --- a/test/stats/stats_aggregate_test.cpp +++ b/test/stats/stats_aggregate_test.cpp @@ -235,4 +235,85 @@ OpMemoryModel Logical GLSL450 EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpExtension)); } +TEST(AggregateStats, OpcodeMarkovHistogram) { + const std::string code1 = R"( +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_NV_viewport_array2" +OpMemoryModel Logical GLSL450 +)"; + + const std::string code2 = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +%i32 = OpTypeInt 32 1 +%u32 = OpTypeInt 32 0 +%f32 = OpTypeFloat 32 +)"; + + SpirvStats stats; + stats.opcode_markov_hist.resize(2); + + CompileAndAggregateStats(code1, &stats); + ASSERT_EQ(2u, stats.opcode_markov_hist.size()); + EXPECT_EQ(2u, stats.opcode_markov_hist[0].size()); + EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpCapability).size()); + EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size()); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel)); + + EXPECT_EQ(1u, stats.opcode_markov_hist[1].size()); + EXPECT_EQ(2u, stats.opcode_markov_hist[1].at(SpvOpCapability).size()); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel)); + + CompileAndAggregateStats(code2, &stats); + ASSERT_EQ(2u, stats.opcode_markov_hist.size()); + EXPECT_EQ(4u, stats.opcode_markov_hist[0].size()); + EXPECT_EQ(3u, stats.opcode_markov_hist[0].at(SpvOpCapability).size()); + EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size()); + EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpMemoryModel).size()); + EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).size()); + EXPECT_EQ( + 4u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpMemoryModel)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpMemoryModel).at(SpvOpTypeInt)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeInt)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeFloat)); + + EXPECT_EQ(3u, stats.opcode_markov_hist[1].size()); + EXPECT_EQ(4u, stats.opcode_markov_hist[1].at(SpvOpCapability).size()); + EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpMemoryModel).size()); + EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpTypeInt).size()); + EXPECT_EQ( + 2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpCapability)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension)); + EXPECT_EQ( + 2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpTypeInt)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[1].at(SpvOpMemoryModel).at(SpvOpTypeInt)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[1].at(SpvOpTypeInt).at(SpvOpTypeFloat)); +} + } // namespace diff --git a/test/stats/stats_analyzer_test.cpp b/test/stats/stats_analyzer_test.cpp index 6659e8558..6dcaac4fd 100644 --- a/test/stats/stats_analyzer_test.cpp +++ b/test/stats/stats_analyzer_test.cpp @@ -140,4 +140,33 @@ TEST(StatsAnalyzer, Opcode) { EXPECT_EQ(expected_output, output); } +TEST(StatsAnalyzer, OpcodeMarkov) { + SpirvStats stats; + FillDefaultStats(&stats); + + stats.opcode_hist[SpvOpFMul] = 400; + stats.opcode_hist[SpvOpFAdd] = 200; + stats.opcode_hist[SpvOpFSub] = 400; + + stats.opcode_markov_hist.resize(1); + auto& hist = stats.opcode_markov_hist[0]; + hist[SpvOpFMul][SpvOpFAdd] = 100; + hist[SpvOpFMul][SpvOpFSub] = 300; + hist[SpvOpFAdd][SpvOpFMul] = 100; + hist[SpvOpFAdd][SpvOpFAdd] = 100; + + StatsAnalyzer analyzer(stats); + + std::stringstream ss; + analyzer.WriteOpcodeMarkov(ss); + const std::string output = ss.str(); + const std::string expected_output = + "FMul -> FSub 75% (base rate 40%, pair occurrences 300)\n" + "FMul -> FAdd 25% (base rate 20%, pair occurrences 100)\n" + "FAdd -> FAdd 50% (base rate 20%, pair occurrences 100)\n" + "FAdd -> FMul 50% (base rate 40%, pair occurrences 100)\n"; + + EXPECT_EQ(expected_output, output); +} + } // namespace diff --git a/tools/stats/stats.cpp b/tools/stats/stats.cpp index 15abbabfe..40e8f28f5 100644 --- a/tools/stats/stats.cpp +++ b/tools/stats/stats.cpp @@ -105,13 +105,13 @@ int main(int argc, char** argv) { return return_code; } - std::cerr << "Processing " << paths.size() - << " files..." << std::endl; + std::cerr << "Processing " << paths.size() << " files..." << std::endl; ScopedContext ctx(SPV_ENV_UNIVERSAL_1_1); SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler); libspirv::SpirvStats stats; + stats.opcode_markov_hist.resize(1); for (size_t index = 0; index < paths.size(); ++index) { const size_t kMilestonePeriod = 1000; @@ -148,5 +148,8 @@ int main(int argc, char** argv) { out << std::endl; analyzer.WriteOpcode(out); + out << std::endl; + analyzer.WriteOpcodeMarkov(out); + return 0; } diff --git a/tools/stats/stats_analyzer.cpp b/tools/stats/stats_analyzer.cpp index b2a514717..428265c43 100644 --- a/tools/stats/stats_analyzer.cpp +++ b/tools/stats/stats_analyzer.cpp @@ -74,6 +74,11 @@ std::unordered_map GetPrevalence( return GetRecall(hist, total); } +// Writes |freq| to |out| sorted by frequency in the following format: +// LABEL3 70% +// LABEL1 20% +// LABEL2 10% +// |label_from_key| is used to convert |Key| to label. template void WriteFreq(std::ostream& out, const std::unordered_map& freq, std::string (*label_from_key)(Key)) { @@ -90,6 +95,26 @@ void WriteFreq(std::ostream& out, const std::unordered_map& freq, } } +// Writes |hist| to |out| sorted by count in the following format: +// LABEL3 100 +// LABEL1 50 +// LABEL2 10 +// |label_from_key| is used to convert |Key| to label. +template +void WriteHist(std::ostream& out, const std::unordered_map& hist, + std::string (*label_from_key)(Key)) { + std::vector> sorted_hist(hist.begin(), hist.end()); + std::sort(sorted_hist.begin(), sorted_hist.end(), + [](const std::pair& left, + const std::pair& right) { + return left.second > right.second; + }); + + for (const auto& pair : sorted_hist) { + out << label_from_key(pair.first) << " " << pair.second << std::endl; + } +} + } // namespace StatsAnalyzer::StatsAnalyzer(const SpirvStats& stats) : stats_(stats) { @@ -125,3 +150,59 @@ void StatsAnalyzer::WriteOpcode(std::ostream& out) { out << "Total unique opcodes used: " << opcode_freq_.size() << std::endl; WriteFreq(out, opcode_freq_, GetOpcodeString); } + +void StatsAnalyzer::WriteOpcodeMarkov(std::ostream& out) { + if (stats_.opcode_markov_hist.empty()) + return; + + const std::unordered_map>& + cue_to_hist = stats_.opcode_markov_hist[0]; + + // Sort by prevalence of the opcodes in opcode_freq_ (descending). + std::vector>> + sorted_cue_to_hist(cue_to_hist.begin(), cue_to_hist.end()); + std::sort(sorted_cue_to_hist.begin(), sorted_cue_to_hist.end(), + [this]( + const std::pair>& left, + const std::pair>& right) { + const double lf = opcode_freq_[left.first]; + const double rf = opcode_freq_[right.first]; + if (lf == rf) + return right.first > left.first; + return lf > rf; + }); + + for (const auto& kv : sorted_cue_to_hist) { + const uint32_t cue = kv.first; + const double kFrequentEnoughToAnalyze = 0.0001; + if (opcode_freq_[cue] < kFrequentEnoughToAnalyze) continue; + + const std::unordered_map& hist = kv.second; + + uint32_t total = 0; + for (const auto& pair : hist) { + total += pair.second; + } + + std::vector> + sorted_hist(hist.begin(), hist.end()); + std::sort(sorted_hist.begin(), sorted_hist.end(), + [](const std::pair& left, + const std::pair& right) { + if (left.second == right.second) + return right.first > left.first; + return left.second > right.second; + }); + + for (const auto& pair : sorted_hist) { + const double prior = opcode_freq_[pair.first]; + const double posterior = + static_cast(pair.second) / static_cast(total); + out << GetOpcodeString(cue) << " -> " << GetOpcodeString(pair.first) + << " " << posterior * 100 << "% (base rate " << prior * 100 + << "%, pair occurrences " << pair.second << ")" << std::endl; + } + } +} diff --git a/tools/stats/stats_analyzer.h b/tools/stats/stats_analyzer.h index 457c20407..809cde289 100644 --- a/tools/stats/stats_analyzer.h +++ b/tools/stats/stats_analyzer.h @@ -23,12 +23,18 @@ class StatsAnalyzer { public: explicit StatsAnalyzer(const libspirv::SpirvStats& stats); + // Writes respective histograms to |out|. void WriteVersion(std::ostream& out); void WriteGenerator(std::ostream& out); void WriteCapability(std::ostream& out); void WriteExtension(std::ostream& out); void WriteOpcode(std::ostream& out); + // Writes first order Markov analysis to |out|. + // stats_.opcode_markov_hist needs to contain raw data for at least one + // level. + void WriteOpcodeMarkov(std::ostream& out); + private: const libspirv::SpirvStats& stats_;