diff --git a/source/comp/markv_codec.cpp b/source/comp/markv_codec.cpp index 4d51293ae..825f2eb5a 100644 --- a/source/comp/markv_codec.cpp +++ b/source/comp/markv_codec.cpp @@ -49,6 +49,7 @@ #include "markv_model.h" #include "opcode.h" #include "operand.h" +#include "source/assembly_grammar.h" #include "spirv-tools/libspirv.h" #include "spirv_endian.h" #include "spirv_validator_options.h" @@ -58,7 +59,6 @@ #include "util/parse_number.h" #include "val/instruction.h" #include "val/validate.h" -#include "val/validation_state.h" namespace spvtools { namespace comp { diff --git a/source/operand.h b/source/operand.h index 984a62328..d90f6bf83 100644 --- a/source/operand.h +++ b/source/operand.h @@ -15,7 +15,6 @@ #ifndef LIBSPIRV_OPERAND_H_ #define LIBSPIRV_OPERAND_H_ -#include #include #include "spirv-tools/libspirv.h" diff --git a/source/spirv_stats.cpp b/source/spirv_stats.cpp index a790b9031..a9200f377 100644 --- a/source/spirv_stats.cpp +++ b/source/spirv_stats.cpp @@ -220,9 +220,11 @@ class StatsAggregator { void ProcessConstant() { const val::Instruction& inst = GetCurrentInstruction(); if (inst.opcode() != SpvOpConstant) return; + const uint32_t type_id = inst.GetOperandAs(0); const auto type_decl_it = vstate_->all_definitions().find(type_id); assert(type_decl_it != vstate_->all_definitions().end()); + const val::Instruction& type_decl_inst = *type_decl_it->second; const SpvOp type_op = type_decl_inst.opcode(); if (type_op == SpvOpTypeInt) { diff --git a/source/val/validate.cpp b/source/val/validate.cpp index 7971a9a32..74627a220 100644 --- a/source/val/validate.cpp +++ b/source/val/validate.cpp @@ -238,16 +238,16 @@ spv_result_t ValidateBinaryUsingContextAndValidationState( } // Look for OpExtension instructions and register extensions. - // Diagnostics if any will be produced in the next pass (ProcessInstruction). spvBinaryParse(&context, vstate, words, num_words, /* parsed_header = */ nullptr, ProcessExtensions, /* diagnostic = */ nullptr); - // NOTE: Parse the module and perform inline validation checks. These - // checks do not require the the knowledge of the whole module. + // Parse the module and perform inline validation checks. These checks do + // not require the the knowledge of the whole module. if (auto error = spvBinaryParse(&context, vstate, words, num_words, setHeader, - ProcessInstruction, pDiagnostic)) + ProcessInstruction, pDiagnostic)) { return error; + } if (!vstate->has_memory_model_specified()) return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr) diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 94bd59223..6fb05fb7c 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -18,18 +18,12 @@ #include #include "opcode.h" +#include "spirv-tools/libspirv.h" #include "spirv_target_env.h" #include "val/basic_block.h" #include "val/construct.h" #include "val/function.h" -using std::deque; -using std::make_pair; -using std::pair; -using std::string; -using std::unordered_map; -using std::vector; - namespace spvtools { namespace val { namespace { @@ -141,6 +135,16 @@ bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) { return out; } +// Counts the number of instructions and functions in the file. +spv_result_t CountInstructions(void* user_data, + const spv_parsed_instruction_t* inst) { + ValidationState_t& _ = *(reinterpret_cast(user_data)); + if (inst->opcode == SpvOpFunction) _.increment_total_functions(); + _.increment_total_instructions(); + + return SPV_SUCCESS; +} + } // namespace ValidationState_t::ValidationState_t(const spv_const_context ctx, @@ -187,6 +191,21 @@ ValidationState_t::ValidationState_t(const spv_const_context ctx, default: break; } + + // Only attempt to count if we have words, otherwise let the other validation + // fail and generate an error. + if (num_words > 0) { + // Count the number of instructions in the binary. + spvBinaryParse(ctx, this, words, num_words, + /* parsed_header = */ nullptr, CountInstructions, + /* diagnostic = */ nullptr); + preallocateStorage(); + } +} + +void ValidationState_t::preallocateStorage() { + ordered_instructions_.reserve(total_instructions_); + module_functions_.reserve(total_functions_); } spv_result_t ValidationState_t::ForwardDeclareId(uint32_t id) { @@ -208,11 +227,11 @@ bool ValidationState_t::IsForwardPointer(uint32_t id) const { return (forward_pointer_ids_.find(id) != forward_pointer_ids_.end()); } -void ValidationState_t::AssignNameToId(uint32_t id, string name) { +void ValidationState_t::AssignNameToId(uint32_t id, std::string name) { operand_names_[id] = name; } -string ValidationState_t::getIdName(uint32_t id) const { +std::string ValidationState_t::getIdName(uint32_t id) const { std::stringstream out; out << id; if (operand_names_.find(id) != end(operand_names_)) { @@ -221,7 +240,7 @@ string ValidationState_t::getIdName(uint32_t id) const { return out.str(); } -string ValidationState_t::getIdOrName(uint32_t id) const { +std::string ValidationState_t::getIdOrName(uint32_t id) const { std::stringstream out; if (operand_names_.find(id) != end(operand_names_)) { out << operand_names_.at(id); @@ -235,9 +254,9 @@ size_t ValidationState_t::unresolved_forward_id_count() const { return unresolved_forward_ids_.size(); } -vector ValidationState_t::UnresolvedForwardIds() const { - vector out(begin(unresolved_forward_ids_), - end(unresolved_forward_ids_)); +std::vector ValidationState_t::UnresolvedForwardIds() const { + std::vector out(std::begin(unresolved_forward_ids_), + std::end(unresolved_forward_ids_)); return out; } @@ -300,7 +319,9 @@ DiagnosticStream ValidationState_t::diag(spv_result_t error_code, error_code); } -deque& ValidationState_t::functions() { return module_functions_; } +std::vector& ValidationState_t::functions() { + return module_functions_; +} Function& ValidationState_t::current_function() { assert(in_function_body()); @@ -497,7 +518,7 @@ void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) { } void ValidationState_t::RegisterInstruction(Instruction* inst) { - if (inst->id()) all_definitions_.insert(make_pair(inst->id(), inst)); + if (inst->id()) all_definitions_.insert(std::make_pair(inst->id(), inst)); // If the instruction is using an OpTypeSampledImage as an operand, it should // be recorded. The validator will ensure that all usages of an diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 6195a4b42..7725cfe74 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -15,7 +15,6 @@ #ifndef LIBSPIRV_VAL_VALIDATIONSTATE_H_ #define LIBSPIRV_VAL_VALIDATIONSTATE_H_ -#include #include #include #include @@ -144,6 +143,17 @@ class ValidationState_t { /// Increments the instruction count. Used for diagnostic int increment_instruction_count(); + /// Increments the total number of instructions in the file. + void increment_total_instructions() { total_instructions_++; } + + /// Increments the total number of functions in the file. + void increment_total_functions() { total_functions_++; } + + /// Allocates internal storage. Note, calling this will invalidate any + /// pointers to |ordered_instructions_| or |module_functions_| and, hence, + /// should only be called at the beginning of validation. + void preallocateStorage(); + /// Returns the current layout section which is being processed ModuleLayoutSection current_layout_section() const; @@ -157,7 +167,7 @@ class ValidationState_t { DiagnosticStream diag(spv_result_t error_code, const Instruction* inst) const; /// Returns the function states - std::deque& functions(); + std::vector& functions(); /// Returns the function states Function& current_function(); @@ -355,8 +365,8 @@ class ValidationState_t { /// nullptr Instruction* FindDef(uint32_t id); - /// Returns a deque of instructions in the order they appear in the binary - const std::deque& ordered_instructions() const { + /// Returns the instructions in the order they appear in the binary + const std::vector& ordered_instructions() const { return ordered_instructions_; } @@ -520,6 +530,11 @@ class ValidationState_t { const uint32_t* words_; const size_t num_words_; + /// The total number of instructions in the binary. + size_t total_instructions_ = 0; + /// The total number of functions in the binary. + size_t total_functions_ = 0; + /// Tracks the number of instructions evaluated by the validator int instruction_counter_; @@ -542,7 +557,7 @@ class ValidationState_t { /// A list of functions in the module. /// Pointers to objects in this container are guaranteed to be stable and /// valid until the end of lifetime of the validation state. - std::deque module_functions_; + std::vector module_functions_; /// Capabilities declared in the module CapabilitySet module_capabilities_; @@ -551,9 +566,7 @@ class ValidationState_t { ExtensionSet module_extensions_; /// List of all instructions in the order they appear in the binary - /// Pointers to objects in this container are guaranteed to be stable and - /// valid until the end of lifetime of the validation state. - std::deque ordered_instructions_; + std::vector ordered_instructions_; /// Instructions that can be referenced by Ids std::unordered_map all_definitions_;