Remove ValidateInstructionAndUpdateValidationState (#1784)

This CL changes the stats aggregator to use
ValidateBinaryAndKeepValidationState to process the binary. This means
we can remove ValidateInstructionAndUpdateValidationState which expects
to be able to call ProcessInstruction in the validate anonymous
namespace. This decouples the stats aggregator from how validation
processes the binary.
This commit is contained in:
dan sinclair 2018-08-02 12:01:26 -04:00 committed by GitHub
parent ce644d4a24
commit 1946fb4ddb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 154 deletions

View File

@ -19,9 +19,7 @@
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "binary.h"
#include "diagnostic.h"
#include "enum_string_mapping.h"
#include "extensions.h"
@ -30,8 +28,6 @@
#include "opcode.h"
#include "operand.h"
#include "spirv-tools/libspirv.h"
#include "spirv_endian.h"
#include "spirv_validator_options.h"
#include "val/instruction.h"
#include "val/validate.h"
#include "val/validation_state.h"
@ -44,68 +40,54 @@ namespace {
// instruction.
class StatsAggregator {
public:
StatsAggregator(SpirvStats* in_out_stats, const spv_const_context context,
const uint32_t* words, size_t num_words) {
stats_ = in_out_stats;
vstate_.reset(new val::ValidationState_t(context, &validator_options_,
words, num_words));
}
StatsAggregator(SpirvStats* in_out_stats, const val::ValidationState_t* state)
: stats_(in_out_stats), vstate_(state) {}
// Collects header statistics and sets correct id_bound.
spv_result_t ProcessHeader(spv_endianness_t /* endian */,
uint32_t /* magic */, uint32_t version,
uint32_t generator, uint32_t id_bound,
uint32_t /* schema */) {
vstate_->setIdBound(id_bound);
++stats_->version_hist[version];
++stats_->generator_hist[generator];
return SPV_SUCCESS;
}
// Processes the instructions to collect stats.
void aggregate() {
const auto& instructions = vstate_->ordered_instructions();
// Runs validator to validate the instruction and update vstate_,
// then procession the instruction to collect stats.
spv_result_t ProcessInstruction(const spv_parsed_instruction_t* inst) {
const spv_result_t validation_result =
ValidateInstructionAndUpdateValidationState(vstate_.get(), inst);
if (validation_result != SPV_SUCCESS) return validation_result;
++stats_->version_hist[vstate_->version()];
++stats_->generator_hist[vstate_->generator()];
ProcessOpcode();
ProcessCapability();
ProcessExtension();
ProcessConstant();
ProcessEnums();
ProcessLiteralStrings();
ProcessNonIdWords();
ProcessIdDescriptors();
for (size_t i = 0; i < instructions.size(); ++i) {
const auto& inst = instructions[i];
return SPV_SUCCESS;
ProcessOpcode(&inst, i);
ProcessCapability(&inst);
ProcessExtension(&inst);
ProcessConstant(&inst);
ProcessEnums(&inst);
ProcessLiteralStrings(&inst);
ProcessNonIdWords(&inst);
ProcessIdDescriptors(&inst);
}
}
// Collects statistics of descriptors generated by IdDescriptorCollection.
void ProcessIdDescriptors() {
const val::Instruction& inst = GetCurrentInstruction();
void ProcessIdDescriptors(const val::Instruction* inst) {
const uint32_t new_descriptor =
id_descriptors_.ProcessInstruction(inst.c_inst());
id_descriptors_.ProcessInstruction(inst->c_inst());
if (new_descriptor) {
std::stringstream ss;
ss << spvOpcodeString(inst.opcode());
for (size_t i = 1; i < inst.words().size(); ++i) {
ss << " " << inst.word(i);
ss << spvOpcodeString(inst->opcode());
for (size_t i = 1; i < inst->words().size(); ++i) {
ss << " " << inst->word(i);
}
stats_->id_descriptor_labels.emplace(new_descriptor, ss.str());
}
uint32_t index = 0;
for (const auto& operand : inst.operands()) {
for (const auto& operand : inst->operands()) {
if (spvIsIdType(operand.type)) {
const uint32_t descriptor =
id_descriptors_.GetDescriptor(inst.word(operand.offset));
id_descriptors_.GetDescriptor(inst->word(operand.offset));
if (descriptor) {
++stats_->id_descriptor_hist[descriptor];
++stats_
->operand_slot_id_descriptor_hist[std::pair<uint32_t, uint32_t>(
inst.opcode(), index)][descriptor];
inst->opcode(), index)][descriptor];
}
}
++index;
@ -113,9 +95,8 @@ class StatsAggregator {
}
// Collects statistics of enum words for operands of specific types.
void ProcessEnums() {
const val::Instruction& inst = GetCurrentInstruction();
for (const auto& operand : inst.operands()) {
void ProcessEnums(const val::Instruction* inst) {
for (const auto& operand : inst->operands()) {
switch (operand.type) {
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
@ -139,7 +120,7 @@ class StatsAggregator {
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
case SPV_OPERAND_TYPE_CAPABILITY: {
++stats_->enum_hist[operand.type][inst.word(operand.offset)];
++stats_->enum_hist[operand.type][inst->word(operand.offset)];
break;
}
default:
@ -149,79 +130,74 @@ class StatsAggregator {
}
// Collects statistics of literal strings used by opcodes.
void ProcessLiteralStrings() {
const val::Instruction& inst = GetCurrentInstruction();
for (const auto& operand : inst.operands()) {
void ProcessLiteralStrings(const val::Instruction* inst) {
for (const auto& operand : inst->operands()) {
if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
const std::string str =
reinterpret_cast<const char*>(&inst.words()[operand.offset]);
++stats_->literal_strings_hist[inst.opcode()][str];
reinterpret_cast<const char*>(&inst->words()[operand.offset]);
++stats_->literal_strings_hist[inst->opcode()][str];
}
}
}
// Collects statistics of all single word non-id operand slots.
void ProcessNonIdWords() {
const val::Instruction& inst = GetCurrentInstruction();
void ProcessNonIdWords(const val::Instruction* inst) {
uint32_t index = 0;
for (const auto& operand : inst.operands()) {
for (const auto& operand : inst->operands()) {
if (operand.num_words == 1 && !spvIsIdType(operand.type)) {
++stats_->operand_slot_non_id_words_hist[std::pair<uint32_t, uint32_t>(
inst.opcode(), index)][inst.word(operand.offset)];
inst->opcode(), index)][inst->word(operand.offset)];
}
++index;
}
}
// Collects OpCapability statistics.
void ProcessCapability() {
const val::Instruction& inst = GetCurrentInstruction();
if (inst.opcode() != SpvOpCapability) return;
const uint32_t capability = inst.word(inst.operands()[0].offset);
void ProcessCapability(const val::Instruction* inst) {
if (inst->opcode() != SpvOpCapability) return;
const uint32_t capability = inst->word(inst->operands()[0].offset);
++stats_->capability_hist[capability];
}
// Collects OpExtension statistics.
void ProcessExtension() {
const val::Instruction& inst = GetCurrentInstruction();
if (inst.opcode() != SpvOpExtension) return;
const std::string extension = GetExtensionString(&inst.c_inst());
void ProcessExtension(const val::Instruction* inst) {
if (inst->opcode() != SpvOpExtension) return;
const std::string extension = GetExtensionString(&inst->c_inst());
++stats_->extension_hist[extension];
}
// Collects OpCode statistics.
void ProcessOpcode() {
auto inst_it = vstate_->ordered_instructions().rbegin();
const SpvOp opcode = inst_it->opcode();
void ProcessOpcode(const val::Instruction* inst, size_t idx) {
const SpvOp opcode = inst->opcode();
++stats_->opcode_hist[opcode];
const uint32_t opcode_and_num_operands =
(uint32_t(inst_it->operands().size()) << 16) | uint32_t(opcode);
(uint32_t(inst->operands().size()) << 16) | uint32_t(opcode);
++stats_->opcode_and_num_operands_hist[opcode_and_num_operands];
++inst_it;
if (idx == 0) return;
if (inst_it != vstate_->ordered_instructions().rend()) {
const SpvOp prev_opcode = inst_it->opcode();
++stats_->opcode_and_num_operands_markov_hist[prev_opcode]
[opcode_and_num_operands];
}
--idx;
const auto& instructions = vstate_->ordered_instructions();
const SpvOp prev_opcode = instructions[idx].opcode();
++stats_->opcode_and_num_operands_markov_hist[prev_opcode]
[opcode_and_num_operands];
auto step_it = stats_->opcode_markov_hist.begin();
for (; inst_it != vstate_->ordered_instructions().rend() &&
step_it != stats_->opcode_markov_hist.end();
++inst_it, ++step_it) {
auto& hist = (*step_it)[inst_it->opcode()];
for (; step_it != stats_->opcode_markov_hist.end(); --idx, ++step_it) {
auto& hist = (*step_it)[instructions[idx].opcode()];
++hist[opcode];
if (idx == 0) break;
}
}
// Collects OpConstant statistics.
void ProcessConstant() {
const val::Instruction& inst = GetCurrentInstruction();
if (inst.opcode() != SpvOpConstant) return;
void ProcessConstant(const val::Instruction* inst) {
if (inst->opcode() != SpvOpConstant) return;
const uint32_t type_id = inst.GetOperandAs<uint32_t>(0);
const uint32_t type_id = inst->GetOperandAs<uint32_t>(0);
const auto type_decl_it = vstate_->all_definitions().find(type_id);
assert(type_decl_it != vstate_->all_definitions().end());
@ -233,90 +209,54 @@ class StatsAggregator {
assert(is_signed == 0 || is_signed == 1);
if (bit_width == 16) {
if (is_signed)
++stats_->s16_constant_hist[inst.GetOperandAs<int16_t>(2)];
++stats_->s16_constant_hist[inst->GetOperandAs<int16_t>(2)];
else
++stats_->u16_constant_hist[inst.GetOperandAs<uint16_t>(2)];
++stats_->u16_constant_hist[inst->GetOperandAs<uint16_t>(2)];
} else if (bit_width == 32) {
if (is_signed)
++stats_->s32_constant_hist[inst.GetOperandAs<int32_t>(2)];
++stats_->s32_constant_hist[inst->GetOperandAs<int32_t>(2)];
else
++stats_->u32_constant_hist[inst.GetOperandAs<uint32_t>(2)];
++stats_->u32_constant_hist[inst->GetOperandAs<uint32_t>(2)];
} else if (bit_width == 64) {
if (is_signed)
++stats_->s64_constant_hist[inst.GetOperandAs<int64_t>(2)];
++stats_->s64_constant_hist[inst->GetOperandAs<int64_t>(2)];
else
++stats_->u64_constant_hist[inst.GetOperandAs<uint64_t>(2)];
++stats_->u64_constant_hist[inst->GetOperandAs<uint64_t>(2)];
} else {
assert(false && "TypeInt bit width is not 16, 32 or 64");
}
} else if (type_op == SpvOpTypeFloat) {
const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
if (bit_width == 32) {
++stats_->f32_constant_hist[inst.GetOperandAs<float>(2)];
++stats_->f32_constant_hist[inst->GetOperandAs<float>(2)];
} else if (bit_width == 64) {
++stats_->f64_constant_hist[inst.GetOperandAs<double>(2)];
++stats_->f64_constant_hist[inst->GetOperandAs<double>(2)];
} else {
assert(bit_width == 16);
}
}
}
SpirvStats* stats() { return stats_; }
private:
// Returns the current instruction (the one last processed by the validator).
const val::Instruction& GetCurrentInstruction() const {
return vstate_->ordered_instructions().back();
}
SpirvStats* stats_;
spv_validator_options_t validator_options_;
std::unique_ptr<val::ValidationState_t> vstate_;
const val::ValidationState_t* vstate_;
IdDescriptorCollection id_descriptors_;
};
spv_result_t ProcessHeader(void* user_data, spv_endianness_t endian,
uint32_t magic, uint32_t version, uint32_t generator,
uint32_t id_bound, uint32_t schema) {
StatsAggregator* stats_aggregator =
reinterpret_cast<StatsAggregator*>(user_data);
return stats_aggregator->ProcessHeader(endian, magic, version, generator,
id_bound, schema);
}
spv_result_t ProcessInstruction(void* user_data,
const spv_parsed_instruction_t* inst) {
StatsAggregator* stats_aggregator =
reinterpret_cast<StatsAggregator*>(user_data);
return stats_aggregator->ProcessInstruction(inst);
}
} // namespace
spv_result_t AggregateStats(const spv_context_t& context, const uint32_t* words,
const size_t num_words, spv_diagnostic* pDiagnostic,
SpirvStats* stats) {
spv_const_binary_t binary = {words, num_words};
std::unique_ptr<val::ValidationState_t> vstate;
spv_validator_options_t options;
spv_result_t result = ValidateBinaryAndKeepValidationState(
&context, &options, words, num_words, pDiagnostic, &vstate);
if (result != SPV_SUCCESS) return result;
spv_endianness_t endian;
spv_position_t position = {};
if (spvBinaryEndianness(&binary, &endian)) {
return DiagnosticStream(position, context.consumer, "",
SPV_ERROR_INVALID_BINARY)
<< "Invalid SPIR-V magic number.";
}
spv_header_t header;
if (spvBinaryHeaderGet(&binary, endian, &header)) {
return DiagnosticStream(position, context.consumer, "",
SPV_ERROR_INVALID_BINARY)
<< "Invalid SPIR-V header.";
}
StatsAggregator stats_aggregator(stats, &context, words, num_words);
return spvBinaryParse(&context, &stats_aggregator, words, num_words,
ProcessHeader, ProcessInstruction, pDiagnostic);
StatsAggregator stats_aggregator(stats, vstate.get());
stats_aggregator.aggregate();
return SPV_SUCCESS;
}
} // namespace spvtools

View File

@ -59,19 +59,15 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInsts,
// TODO(umar): Validate header
// TODO(umar): The binary parser validates the magic word, and the length of the
// header, but nothing else.
spv_result_t setHeader(void* user_data, spv_endianness_t endian, uint32_t magic,
spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t,
uint32_t version, uint32_t generator, uint32_t id_bound,
uint32_t reserved) {
uint32_t) {
// Record the ID bound so that the validator can ensure no ID is out of bound.
ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
_.setIdBound(id_bound);
_.setGenerator(generator);
_.setVersion(version);
(void)endian;
(void)magic;
(void)version;
(void)generator;
(void)id_bound;
(void)reserved;
return SPV_SUCCESS;
}
@ -354,11 +350,6 @@ spv_result_t ValidateBinaryAndKeepValidationState(
hijack_context, words, num_words, pDiagnostic, vstate->get());
}
spv_result_t ValidateInstructionAndUpdateValidationState(
ValidationState_t* vstate, const spv_parsed_instruction_t* inst) {
return ProcessInstruction(vstate, inst);
}
} // namespace val
} // namespace spvtools

View File

@ -219,11 +219,6 @@ spv_result_t ValidateBinaryAndKeepValidationState(
const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic,
std::unique_ptr<ValidationState_t>* vstate);
// Performs validation for a single instruction and updates given validation
// state.
spv_result_t ValidateInstructionAndUpdateValidationState(
ValidationState_t* vstate, const spv_parsed_instruction_t* inst);
} // namespace val
} // namespace spvtools

View File

@ -101,6 +101,18 @@ class ValidationState_t {
/// Returns the command line options
spv_const_validator_options options() const { return options_; }
/// Sets the ID of the generator for this module.
void setGenerator(uint32_t gen) { generator_ = gen; }
/// Returns the ID of the generator for this module.
uint32_t generator() const { return generator_; }
/// Sets the SPIR-V version of this module.
void setVersion(uint32_t ver) { version_ = ver; }
/// Gets the SPIR-V version of this module.
uint32_t version() const { return version_; }
/// Forward declares the id in the module
spv_result_t ForwardDeclareId(uint32_t id);
@ -523,6 +535,12 @@ class ValidationState_t {
const uint32_t* words_;
const size_t num_words_;
/// The generator of the SPIR-V.
uint32_t generator_ = 0;
/// The version of the SPIR-V.
uint32_t version_ = 0;
/// The total number of instructions in the binary.
size_t total_instructions_ = 0;
/// The total number of functions in the binary.