diff --git a/source/spirv_stats.cpp b/source/spirv_stats.cpp index a9200f377..9720d5502 100644 --- a/source/spirv_stats.cpp +++ b/source/spirv_stats.cpp @@ -19,9 +19,7 @@ #include #include #include -#include -#include "binary.h" #include "diagnostic.h" #include "enum_string_mapping.h" #include "extensions.h" @@ -30,8 +28,6 @@ #include "opcode.h" #include "operand.h" #include "spirv-tools/libspirv.h" -#include "spirv_endian.h" -#include "spirv_validator_options.h" #include "val/instruction.h" #include "val/validate.h" #include "val/validation_state.h" @@ -44,68 +40,54 @@ namespace { // instruction. class StatsAggregator { public: - StatsAggregator(SpirvStats* in_out_stats, const spv_const_context context, - const uint32_t* words, size_t num_words) { - stats_ = in_out_stats; - vstate_.reset(new val::ValidationState_t(context, &validator_options_, - words, num_words)); - } + StatsAggregator(SpirvStats* in_out_stats, const val::ValidationState_t* state) + : stats_(in_out_stats), vstate_(state) {} - // Collects header statistics and sets correct id_bound. - spv_result_t ProcessHeader(spv_endianness_t /* endian */, - uint32_t /* magic */, uint32_t version, - uint32_t generator, uint32_t id_bound, - uint32_t /* schema */) { - vstate_->setIdBound(id_bound); - ++stats_->version_hist[version]; - ++stats_->generator_hist[generator]; - return SPV_SUCCESS; - } + // Processes the instructions to collect stats. + void aggregate() { + const auto& instructions = vstate_->ordered_instructions(); - // Runs validator to validate the instruction and update vstate_, - // then procession the instruction to collect stats. - spv_result_t ProcessInstruction(const spv_parsed_instruction_t* inst) { - const spv_result_t validation_result = - ValidateInstructionAndUpdateValidationState(vstate_.get(), inst); - if (validation_result != SPV_SUCCESS) return validation_result; + ++stats_->version_hist[vstate_->version()]; + ++stats_->generator_hist[vstate_->generator()]; - ProcessOpcode(); - ProcessCapability(); - ProcessExtension(); - ProcessConstant(); - ProcessEnums(); - ProcessLiteralStrings(); - ProcessNonIdWords(); - ProcessIdDescriptors(); + for (size_t i = 0; i < instructions.size(); ++i) { + const auto& inst = instructions[i]; - return SPV_SUCCESS; + ProcessOpcode(&inst, i); + ProcessCapability(&inst); + ProcessExtension(&inst); + ProcessConstant(&inst); + ProcessEnums(&inst); + ProcessLiteralStrings(&inst); + ProcessNonIdWords(&inst); + ProcessIdDescriptors(&inst); + } } // Collects statistics of descriptors generated by IdDescriptorCollection. - void ProcessIdDescriptors() { - const val::Instruction& inst = GetCurrentInstruction(); + void ProcessIdDescriptors(const val::Instruction* inst) { const uint32_t new_descriptor = - id_descriptors_.ProcessInstruction(inst.c_inst()); + id_descriptors_.ProcessInstruction(inst->c_inst()); if (new_descriptor) { std::stringstream ss; - ss << spvOpcodeString(inst.opcode()); - for (size_t i = 1; i < inst.words().size(); ++i) { - ss << " " << inst.word(i); + ss << spvOpcodeString(inst->opcode()); + for (size_t i = 1; i < inst->words().size(); ++i) { + ss << " " << inst->word(i); } stats_->id_descriptor_labels.emplace(new_descriptor, ss.str()); } uint32_t index = 0; - for (const auto& operand : inst.operands()) { + for (const auto& operand : inst->operands()) { if (spvIsIdType(operand.type)) { const uint32_t descriptor = - id_descriptors_.GetDescriptor(inst.word(operand.offset)); + id_descriptors_.GetDescriptor(inst->word(operand.offset)); if (descriptor) { ++stats_->id_descriptor_hist[descriptor]; ++stats_ ->operand_slot_id_descriptor_hist[std::pair( - inst.opcode(), index)][descriptor]; + inst->opcode(), index)][descriptor]; } } ++index; @@ -113,9 +95,8 @@ class StatsAggregator { } // Collects statistics of enum words for operands of specific types. - void ProcessEnums() { - const val::Instruction& inst = GetCurrentInstruction(); - for (const auto& operand : inst.operands()) { + void ProcessEnums(const val::Instruction* inst) { + for (const auto& operand : inst->operands()) { switch (operand.type) { case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: case SPV_OPERAND_TYPE_EXECUTION_MODEL: @@ -139,7 +120,7 @@ class StatsAggregator { case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: case SPV_OPERAND_TYPE_CAPABILITY: { - ++stats_->enum_hist[operand.type][inst.word(operand.offset)]; + ++stats_->enum_hist[operand.type][inst->word(operand.offset)]; break; } default: @@ -149,79 +130,74 @@ class StatsAggregator { } // Collects statistics of literal strings used by opcodes. - void ProcessLiteralStrings() { - const val::Instruction& inst = GetCurrentInstruction(); - for (const auto& operand : inst.operands()) { + void ProcessLiteralStrings(const val::Instruction* inst) { + for (const auto& operand : inst->operands()) { if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) { const std::string str = - reinterpret_cast(&inst.words()[operand.offset]); - ++stats_->literal_strings_hist[inst.opcode()][str]; + reinterpret_cast(&inst->words()[operand.offset]); + ++stats_->literal_strings_hist[inst->opcode()][str]; } } } // Collects statistics of all single word non-id operand slots. - void ProcessNonIdWords() { - const val::Instruction& inst = GetCurrentInstruction(); + void ProcessNonIdWords(const val::Instruction* inst) { uint32_t index = 0; - for (const auto& operand : inst.operands()) { + for (const auto& operand : inst->operands()) { if (operand.num_words == 1 && !spvIsIdType(operand.type)) { ++stats_->operand_slot_non_id_words_hist[std::pair( - inst.opcode(), index)][inst.word(operand.offset)]; + inst->opcode(), index)][inst->word(operand.offset)]; } ++index; } } // Collects OpCapability statistics. - void ProcessCapability() { - const val::Instruction& inst = GetCurrentInstruction(); - if (inst.opcode() != SpvOpCapability) return; - const uint32_t capability = inst.word(inst.operands()[0].offset); + void ProcessCapability(const val::Instruction* inst) { + if (inst->opcode() != SpvOpCapability) return; + const uint32_t capability = inst->word(inst->operands()[0].offset); ++stats_->capability_hist[capability]; } // Collects OpExtension statistics. - void ProcessExtension() { - const val::Instruction& inst = GetCurrentInstruction(); - if (inst.opcode() != SpvOpExtension) return; - const std::string extension = GetExtensionString(&inst.c_inst()); + void ProcessExtension(const val::Instruction* inst) { + if (inst->opcode() != SpvOpExtension) return; + const std::string extension = GetExtensionString(&inst->c_inst()); ++stats_->extension_hist[extension]; } // Collects OpCode statistics. - void ProcessOpcode() { - auto inst_it = vstate_->ordered_instructions().rbegin(); - const SpvOp opcode = inst_it->opcode(); + void ProcessOpcode(const val::Instruction* inst, size_t idx) { + const SpvOp opcode = inst->opcode(); ++stats_->opcode_hist[opcode]; const uint32_t opcode_and_num_operands = - (uint32_t(inst_it->operands().size()) << 16) | uint32_t(opcode); + (uint32_t(inst->operands().size()) << 16) | uint32_t(opcode); ++stats_->opcode_and_num_operands_hist[opcode_and_num_operands]; - ++inst_it; + if (idx == 0) return; - if (inst_it != vstate_->ordered_instructions().rend()) { - const SpvOp prev_opcode = inst_it->opcode(); - ++stats_->opcode_and_num_operands_markov_hist[prev_opcode] - [opcode_and_num_operands]; - } + --idx; + + const auto& instructions = vstate_->ordered_instructions(); + const SpvOp prev_opcode = instructions[idx].opcode(); + ++stats_->opcode_and_num_operands_markov_hist[prev_opcode] + [opcode_and_num_operands]; auto step_it = stats_->opcode_markov_hist.begin(); - for (; inst_it != vstate_->ordered_instructions().rend() && - step_it != stats_->opcode_markov_hist.end(); - ++inst_it, ++step_it) { - auto& hist = (*step_it)[inst_it->opcode()]; + for (; step_it != stats_->opcode_markov_hist.end(); --idx, ++step_it) { + auto& hist = (*step_it)[instructions[idx].opcode()]; ++hist[opcode]; + + if (idx == 0) break; } } // Collects OpConstant statistics. - void ProcessConstant() { - const val::Instruction& inst = GetCurrentInstruction(); - if (inst.opcode() != SpvOpConstant) return; + void ProcessConstant(const val::Instruction* inst) { + if (inst->opcode() != SpvOpConstant) return; - const uint32_t type_id = inst.GetOperandAs(0); + const uint32_t type_id = inst->GetOperandAs(0); const auto type_decl_it = vstate_->all_definitions().find(type_id); assert(type_decl_it != vstate_->all_definitions().end()); @@ -233,90 +209,54 @@ class StatsAggregator { assert(is_signed == 0 || is_signed == 1); if (bit_width == 16) { if (is_signed) - ++stats_->s16_constant_hist[inst.GetOperandAs(2)]; + ++stats_->s16_constant_hist[inst->GetOperandAs(2)]; else - ++stats_->u16_constant_hist[inst.GetOperandAs(2)]; + ++stats_->u16_constant_hist[inst->GetOperandAs(2)]; } else if (bit_width == 32) { if (is_signed) - ++stats_->s32_constant_hist[inst.GetOperandAs(2)]; + ++stats_->s32_constant_hist[inst->GetOperandAs(2)]; else - ++stats_->u32_constant_hist[inst.GetOperandAs(2)]; + ++stats_->u32_constant_hist[inst->GetOperandAs(2)]; } else if (bit_width == 64) { if (is_signed) - ++stats_->s64_constant_hist[inst.GetOperandAs(2)]; + ++stats_->s64_constant_hist[inst->GetOperandAs(2)]; else - ++stats_->u64_constant_hist[inst.GetOperandAs(2)]; + ++stats_->u64_constant_hist[inst->GetOperandAs(2)]; } else { assert(false && "TypeInt bit width is not 16, 32 or 64"); } } else if (type_op == SpvOpTypeFloat) { const uint32_t bit_width = type_decl_inst.GetOperandAs(1); if (bit_width == 32) { - ++stats_->f32_constant_hist[inst.GetOperandAs(2)]; + ++stats_->f32_constant_hist[inst->GetOperandAs(2)]; } else if (bit_width == 64) { - ++stats_->f64_constant_hist[inst.GetOperandAs(2)]; + ++stats_->f64_constant_hist[inst->GetOperandAs(2)]; } else { assert(bit_width == 16); } } } - SpirvStats* stats() { return stats_; } - private: - // Returns the current instruction (the one last processed by the validator). - const val::Instruction& GetCurrentInstruction() const { - return vstate_->ordered_instructions().back(); - } - SpirvStats* stats_; - spv_validator_options_t validator_options_; - std::unique_ptr vstate_; + const val::ValidationState_t* vstate_; IdDescriptorCollection id_descriptors_; }; -spv_result_t ProcessHeader(void* user_data, spv_endianness_t endian, - uint32_t magic, uint32_t version, uint32_t generator, - uint32_t id_bound, uint32_t schema) { - StatsAggregator* stats_aggregator = - reinterpret_cast(user_data); - return stats_aggregator->ProcessHeader(endian, magic, version, generator, - id_bound, schema); -} - -spv_result_t ProcessInstruction(void* user_data, - const spv_parsed_instruction_t* inst) { - StatsAggregator* stats_aggregator = - reinterpret_cast(user_data); - return stats_aggregator->ProcessInstruction(inst); -} - } // namespace spv_result_t AggregateStats(const spv_context_t& context, const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, SpirvStats* stats) { - spv_const_binary_t binary = {words, num_words}; + std::unique_ptr vstate; + spv_validator_options_t options; + spv_result_t result = ValidateBinaryAndKeepValidationState( + &context, &options, words, num_words, pDiagnostic, &vstate); + if (result != SPV_SUCCESS) return result; - spv_endianness_t endian; - spv_position_t position = {}; - if (spvBinaryEndianness(&binary, &endian)) { - return DiagnosticStream(position, context.consumer, "", - SPV_ERROR_INVALID_BINARY) - << "Invalid SPIR-V magic number."; - } - - spv_header_t header; - if (spvBinaryHeaderGet(&binary, endian, &header)) { - return DiagnosticStream(position, context.consumer, "", - SPV_ERROR_INVALID_BINARY) - << "Invalid SPIR-V header."; - } - - StatsAggregator stats_aggregator(stats, &context, words, num_words); - - return spvBinaryParse(&context, &stats_aggregator, words, num_words, - ProcessHeader, ProcessInstruction, pDiagnostic); + StatsAggregator stats_aggregator(stats, vstate.get()); + stats_aggregator.aggregate(); + return SPV_SUCCESS; } } // namespace spvtools diff --git a/source/val/validate.cpp b/source/val/validate.cpp index a56019cbc..b1587c2b6 100644 --- a/source/val/validate.cpp +++ b/source/val/validate.cpp @@ -59,19 +59,15 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInsts, // TODO(umar): Validate header // TODO(umar): The binary parser validates the magic word, and the length of the // header, but nothing else. -spv_result_t setHeader(void* user_data, spv_endianness_t endian, uint32_t magic, +spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t, uint32_t version, uint32_t generator, uint32_t id_bound, - uint32_t reserved) { + uint32_t) { // Record the ID bound so that the validator can ensure no ID is out of bound. ValidationState_t& _ = *(reinterpret_cast(user_data)); _.setIdBound(id_bound); + _.setGenerator(generator); + _.setVersion(version); - (void)endian; - (void)magic; - (void)version; - (void)generator; - (void)id_bound; - (void)reserved; return SPV_SUCCESS; } @@ -354,11 +350,6 @@ spv_result_t ValidateBinaryAndKeepValidationState( hijack_context, words, num_words, pDiagnostic, vstate->get()); } -spv_result_t ValidateInstructionAndUpdateValidationState( - ValidationState_t* vstate, const spv_parsed_instruction_t* inst) { - return ProcessInstruction(vstate, inst); -} - } // namespace val } // namespace spvtools diff --git a/source/val/validate.h b/source/val/validate.h index 654e87f32..45709c9c7 100644 --- a/source/val/validate.h +++ b/source/val/validate.h @@ -219,11 +219,6 @@ spv_result_t ValidateBinaryAndKeepValidationState( const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, std::unique_ptr* vstate); -// Performs validation for a single instruction and updates given validation -// state. -spv_result_t ValidateInstructionAndUpdateValidationState( - ValidationState_t* vstate, const spv_parsed_instruction_t* inst); - } // namespace val } // namespace spvtools diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 72c563bfc..3f63047f9 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -101,6 +101,18 @@ class ValidationState_t { /// Returns the command line options spv_const_validator_options options() const { return options_; } + /// Sets the ID of the generator for this module. + void setGenerator(uint32_t gen) { generator_ = gen; } + + /// Returns the ID of the generator for this module. + uint32_t generator() const { return generator_; } + + /// Sets the SPIR-V version of this module. + void setVersion(uint32_t ver) { version_ = ver; } + + /// Gets the SPIR-V version of this module. + uint32_t version() const { return version_; } + /// Forward declares the id in the module spv_result_t ForwardDeclareId(uint32_t id); @@ -523,6 +535,12 @@ class ValidationState_t { const uint32_t* words_; const size_t num_words_; + /// The generator of the SPIR-V. + uint32_t generator_ = 0; + + /// The version of the SPIR-V. + uint32_t version_ = 0; + /// The total number of instructions in the binary. size_t total_instructions_ = 0; /// The total number of functions in the binary.