mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-29 22:41:03 +00:00
1324 lines
40 KiB
C++
1324 lines
40 KiB
C++
// Copyright (c) 2015-2016 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
#include "source/val/validation_state.h"
|
|
|
|
#include <cassert>
|
|
#include <stack>
|
|
#include <utility>
|
|
|
|
#include "source/opcode.h"
|
|
#include "source/spirv_constant.h"
|
|
#include "source/spirv_target_env.h"
|
|
#include "source/val/basic_block.h"
|
|
#include "source/val/construct.h"
|
|
#include "source/val/function.h"
|
|
#include "spirv-tools/libspirv.h"
|
|
|
|
namespace spvtools {
|
|
namespace val {
|
|
namespace {
|
|
|
|
// Returns true when opcode |op| is accepted in module layout section |layout|.
// See Section 2.4.
bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) {
  // clang-format off
  switch (layout) {
    case kLayoutCapabilities:  return op == SpvOpCapability;
    case kLayoutExtensions:    return op == SpvOpExtension;
    case kLayoutExtInstImport: return op == SpvOpExtInstImport;
    case kLayoutMemoryModel:   return op == SpvOpMemoryModel;
    case kLayoutEntryPoint:    return op == SpvOpEntryPoint;
    case kLayoutExecutionMode:
      return op == SpvOpExecutionMode || op == SpvOpExecutionModeId;

    case kLayoutDebug1:
      switch (op) {
        case SpvOpSourceContinued:
        case SpvOpSource:
        case SpvOpSourceExtension:
        case SpvOpString:
          return true;
        default:
          return false;
      }

    case kLayoutDebug2:
      switch (op) {
        case SpvOpName:
        case SpvOpMemberName:
          return true;
        default:
          return false;
      }

    case kLayoutDebug3:
      // Only OpModuleProcessed is allowed here.
      return op == SpvOpModuleProcessed;

    case kLayoutAnnotations:
      switch (op) {
        case SpvOpDecorate:
        case SpvOpMemberDecorate:
        case SpvOpGroupDecorate:
        case SpvOpGroupMemberDecorate:
        case SpvOpDecorationGroup:
        case SpvOpDecorateId:
        case SpvOpDecorateStringGOOGLE:
        case SpvOpMemberDecorateStringGOOGLE:
          return true;
        default:
          return false;
      }

    case kLayoutTypes:
      // Every type-generating or constant opcode belongs to this section.
      if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) return true;
      switch (op) {
        case SpvOpTypeForwardPointer:
        case SpvOpVariable:
        case SpvOpLine:
        case SpvOpNoLine:
        case SpvOpUndef:
        // SpvOpExtInst is only allowed here for certain extended instruction
        // sets. This will be checked separately.
        case SpvOpExtInst:
          return true;
        default:
          return false;
      }

    case kLayoutFunctionDeclarations:
    case kLayoutFunctionDefinitions:
      // NOTE: the opcodes rejected below should NOT be in these layout
      // sections; anything else is accepted.
      if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) return false;
      switch (op) {
        case SpvOpCapability:
        case SpvOpExtension:
        case SpvOpExtInstImport:
        case SpvOpMemoryModel:
        case SpvOpEntryPoint:
        case SpvOpExecutionMode:
        case SpvOpExecutionModeId:
        case SpvOpSourceContinued:
        case SpvOpSource:
        case SpvOpSourceExtension:
        case SpvOpString:
        case SpvOpName:
        case SpvOpMemberName:
        case SpvOpModuleProcessed:
        case SpvOpDecorate:
        case SpvOpMemberDecorate:
        case SpvOpGroupDecorate:
        case SpvOpGroupMemberDecorate:
        case SpvOpDecorationGroup:
        case SpvOpTypeForwardPointer:
          return false;
        default:
          return true;
      }
  }
  // clang-format on

  // A layout value not handled above accepts nothing.
  return false;
}
|
|
|
|
// Counts the number of instructions and functions in the file.
|
|
spv_result_t CountInstructions(void* user_data,
|
|
const spv_parsed_instruction_t* inst) {
|
|
ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
|
|
if (inst->opcode == SpvOpFunction) _.increment_total_functions();
|
|
_.increment_total_instructions();
|
|
|
|
return SPV_SUCCESS;
|
|
}
|
|
|
|
// Header parse callback: copies the id bound, generator and version words of
// the module header into the ValidationState_t passed through |user_data|.
// The endianness, magic and schema arguments are ignored.
spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t,
                       uint32_t version, uint32_t generator, uint32_t id_bound,
                       uint32_t) {
  auto* state = static_cast<ValidationState_t*>(user_data);
  state->setIdBound(id_bound);
  state->setGenerator(generator);
  state->setVersion(version);
  return SPV_SUCCESS;
}
|
|
|
|
// Add features based on SPIR-V core version number.
|
|
void UpdateFeaturesBasedOnSpirvVersion(ValidationState_t::Feature* features,
|
|
uint32_t version) {
|
|
assert(features);
|
|
if (version >= SPV_SPIRV_VERSION_WORD(1, 4)) {
|
|
features->select_between_composites = true;
|
|
features->copy_memory_permits_two_memory_accesses = true;
|
|
features->uconvert_spec_constant_op = true;
|
|
features->nonwritable_var_in_function_or_private = true;
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
ValidationState_t::ValidationState_t(const spv_const_context ctx,
|
|
const spv_const_validator_options opt,
|
|
const uint32_t* words,
|
|
const size_t num_words,
|
|
const uint32_t max_warnings)
|
|
: context_(ctx),
|
|
options_(opt),
|
|
words_(words),
|
|
num_words_(num_words),
|
|
unresolved_forward_ids_{},
|
|
operand_names_{},
|
|
current_layout_section_(kLayoutCapabilities),
|
|
module_functions_(),
|
|
module_capabilities_(),
|
|
module_extensions_(),
|
|
ordered_instructions_(),
|
|
all_definitions_(),
|
|
global_vars_(),
|
|
local_vars_(),
|
|
struct_nesting_depth_(),
|
|
struct_has_nested_blockorbufferblock_struct_(),
|
|
grammar_(ctx),
|
|
addressing_model_(SpvAddressingModelMax),
|
|
memory_model_(SpvMemoryModelMax),
|
|
pointer_size_and_alignment_(0),
|
|
in_function_(false),
|
|
num_of_warnings_(0),
|
|
max_num_of_warnings_(max_warnings) {
|
|
assert(opt && "Validator options may not be Null.");
|
|
|
|
const auto env = context_->target_env;
|
|
|
|
if (spvIsVulkanEnv(env)) {
|
|
// Vulkan 1.1 includes VK_KHR_relaxed_block_layout in core.
|
|
if (env != SPV_ENV_VULKAN_1_0) {
|
|
features_.env_relaxed_block_layout = true;
|
|
}
|
|
}
|
|
|
|
// Only attempt to count if we have words, otherwise let the other validation
|
|
// fail and generate an error.
|
|
if (num_words > 0) {
|
|
// Count the number of instructions in the binary.
|
|
// This parse should not produce any error messages. Hijack the context and
|
|
// replace the message consumer so that we do not pollute any state in input
|
|
// consumer.
|
|
spv_context_t hijacked_context = *ctx;
|
|
hijacked_context.consumer = [](spv_message_level_t, const char*,
|
|
const spv_position_t&, const char*) {};
|
|
spvBinaryParse(&hijacked_context, this, words, num_words, setHeader,
|
|
CountInstructions,
|
|
/* diagnostic = */ nullptr);
|
|
preallocateStorage();
|
|
}
|
|
UpdateFeaturesBasedOnSpirvVersion(&features_, version_);
|
|
|
|
friendly_mapper_ = spvtools::MakeUnique<spvtools::FriendlyNameMapper>(
|
|
context_, words_, num_words_);
|
|
name_mapper_ = friendly_mapper_->GetNameMapper();
|
|
}
|
|
|
|
// Reserves capacity for the function and instruction containers using the
// totals counted so far.
void ValidationState_t::preallocateStorage() {
  module_functions_.reserve(total_functions_);
  ordered_instructions_.reserve(total_instructions_);
}
|
|
|
|
// Records |id| as forward declared (used before its definition was seen).
// Always succeeds.
spv_result_t ValidationState_t::ForwardDeclareId(uint32_t id) {
  (void)unresolved_forward_ids_.insert(id);
  return SPV_SUCCESS;
}
|
|
|
|
// Drops |id| from the set of unresolved forward declarations; a no-op when
// the id was never forward declared. Always succeeds.
spv_result_t ValidationState_t::RemoveIfForwardDeclared(uint32_t id) {
  (void)unresolved_forward_ids_.erase(id);
  return SPV_SUCCESS;
}
|
|
|
|
// Remembers |id| as a forward pointer id. Always succeeds.
spv_result_t ValidationState_t::RegisterForwardPointer(uint32_t id) {
  (void)forward_pointer_ids_.insert(id);
  return SPV_SUCCESS;
}
|
|
|
|
bool ValidationState_t::IsForwardPointer(uint32_t id) const {
|
|
return (forward_pointer_ids_.find(id) != forward_pointer_ids_.end());
|
|
}
|
|
|
|
// Associates |name| with |id|, replacing any name stored for that id before.
// |name| is taken by value, so it is moved into the table rather than copied
// a second time.
void ValidationState_t::AssignNameToId(uint32_t id, std::string name) {
  operand_names_[id] = std::move(name);
}
|
|
|
|
std::string ValidationState_t::getIdName(uint32_t id) const {
|
|
const std::string id_name = name_mapper_(id);
|
|
|
|
std::stringstream out;
|
|
out << id << "[%" << id_name << "]";
|
|
return out.str();
|
|
}
|
|
|
|
size_t ValidationState_t::unresolved_forward_id_count() const {
|
|
return unresolved_forward_ids_.size();
|
|
}
|
|
|
|
std::vector<uint32_t> ValidationState_t::UnresolvedForwardIds() const {
|
|
std::vector<uint32_t> out(std::begin(unresolved_forward_ids_),
|
|
std::end(unresolved_forward_ids_));
|
|
return out;
|
|
}
|
|
|
|
bool ValidationState_t::IsDefinedId(uint32_t id) const {
|
|
return all_definitions_.find(id) != std::end(all_definitions_);
|
|
}
|
|
|
|
// Returns the instruction registered as the definition of |id|, or nullptr
// when no definition is known.
const Instruction* ValidationState_t::FindDef(uint32_t id) const {
  const auto found = all_definitions_.find(id);
  if (found != all_definitions_.end()) return found->second;
  return nullptr;
}
|
|
|
|
// Mutable overload: returns the instruction registered as the definition of
// |id|, or nullptr when no definition is known.
Instruction* ValidationState_t::FindDef(uint32_t id) {
  const auto found = all_definitions_.find(id);
  if (found != all_definitions_.end()) return found->second;
  return nullptr;
}
|
|
|
|
// Returns the module layout section the validator is currently in.
ModuleLayoutSection ValidationState_t::current_layout_section() const {
  const ModuleLayoutSection section = current_layout_section_;
  return section;
}
|
|
|
|
void ValidationState_t::ProgressToNextLayoutSectionOrder() {
|
|
// Guard against going past the last element(kLayoutFunctionDefinitions)
|
|
if (current_layout_section_ <= kLayoutFunctionDefinitions) {
|
|
current_layout_section_ =
|
|
static_cast<ModuleLayoutSection>(current_layout_section_ + 1);
|
|
}
|
|
}
|
|
|
|
// Returns true when |op| is accepted in the layout section the validator is
// currently in.
bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) {
  const ModuleLayoutSection section = current_layout_section_;
  return IsInstructionInLayoutSection(section, op);
}
|
|
|
|
// Creates the diagnostic stream for |error_code|, attaching the disassembly
// and line number of |inst| when one is given.
//
// Warnings are rate limited: once max_num_of_warnings_ warnings have been
// issued, further warnings get a stream with no consumer, and a suppression
// notice is sent to the real consumer.
// NOTE(review): the counter stops advancing once the limit is reached, so the
// notice appears to be emitted for every suppressed warning, not just the
// first one — confirm this is intended.
DiagnosticStream ValidationState_t::diag(spv_result_t error_code,
                                         const Instruction* inst) {
  if (error_code == SPV_WARNING) {
    if (num_of_warnings_ >= max_num_of_warnings_) {
      if (num_of_warnings_ == max_num_of_warnings_) {
        // The temporary stream delivers the notice when it is destroyed at
        // the end of this statement.
        DiagnosticStream({0, 0, 0}, context_->consumer, "", error_code)
            << "Other warnings have been suppressed.\n";
      }
      return DiagnosticStream({0, 0, 0}, nullptr, "", error_code);
    }
    ++num_of_warnings_;
  }

  std::string disassembly;
  if (inst != nullptr) disassembly = Disassemble(*inst);

  return DiagnosticStream({0, 0, inst ? inst->LineNum() : 0},
                          context_->consumer, disassembly, error_code);
}
|
|
|
|
std::vector<Function>& ValidationState_t::functions() {
|
|
return module_functions_;
|
|
}
|
|
|
|
// Returns the most recently registered function. Must only be called while
// inside a function body (asserted).
Function& ValidationState_t::current_function() {
  assert(in_function_body());
  Function& latest = module_functions_.back();
  return latest;
}
|
|
|
|
// Const overload: returns the most recently registered function. Must only
// be called while inside a function body (asserted).
const Function& ValidationState_t::current_function() const {
  assert(in_function_body());
  const Function& latest = module_functions_.back();
  return latest;
}
|
|
|
|
// Looks up the function registered under result id |id|; returns nullptr when
// there is none.
const Function* ValidationState_t::function(uint32_t id) const {
  const auto entry = id_to_function_.find(id);
  if (entry != id_to_function_.end()) return entry->second;
  return nullptr;
}
|
|
|
|
// Mutable overload: looks up the function registered under result id |id|;
// returns nullptr when there is none.
Function* ValidationState_t::function(uint32_t id) {
  const auto entry = id_to_function_.find(id);
  if (entry != id_to_function_.end()) return entry->second;
  return nullptr;
}
|
|
|
|
// Returns true between RegisterFunction() and RegisterFunctionEnd().
bool ValidationState_t::in_function_body() const {
  return in_function_;
}
|
|
|
|
bool ValidationState_t::in_block() const {
|
|
return module_functions_.empty() == false &&
|
|
module_functions_.back().current_block() != nullptr;
|
|
}
|
|
|
|
// Declares capability |cap| for the module, recursively declares every
// capability the grammar lists for it, and switches on the validator
// features that |cap| enables.
void ValidationState_t::RegisterCapability(SpvCapability cap) {
  // Avoid redundant work. Otherwise the recursion could induce work
  // quadratic in the capability dependency depth. (Ok, not much, but
  // it's something.)
  if (module_capabilities_.Contains(cap)) return;
  module_capabilities_.Add(cap);

  spv_operand_desc cap_desc;
  const spv_result_t lookup =
      grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &cap_desc);
  if (lookup == SPV_SUCCESS) {
    CapabilitySet(cap_desc->numCapabilities, cap_desc->capabilities)
        .ForEach([this](SpvCapability implied) {
          RegisterCapability(implied);
        });
  }

  switch (cap) {
    case SpvCapabilityKernel:
      features_.group_ops_reduce_and_scans = true;
      break;

    case SpvCapabilityInt8:
      features_.use_int8_type = true;
      features_.declare_int8_type = true;
      break;

    case SpvCapabilityStorageBuffer8BitAccess:
    case SpvCapabilityUniformAndStorageBuffer8BitAccess:
    case SpvCapabilityStoragePushConstant8:
      features_.declare_int8_type = true;
      break;

    case SpvCapabilityInt16:
      features_.declare_int16_type = true;
      break;

    case SpvCapabilityFloat16:
    case SpvCapabilityFloat16Buffer:
      features_.declare_float16_type = true;
      break;

    case SpvCapabilityStorageUniformBufferBlock16:
    case SpvCapabilityStorageUniform16:
    case SpvCapabilityStoragePushConstant16:
    case SpvCapabilityStorageInputOutput16:
      features_.declare_int16_type = true;
      features_.declare_float16_type = true;
      features_.free_fp_rounding_mode = true;
      break;

    case SpvCapabilityVariablePointers:
      features_.variable_pointers = true;
      features_.variable_pointers_storage_buffer = true;
      break;

    case SpvCapabilityVariablePointersStorageBuffer:
      features_.variable_pointers_storage_buffer = true;
      break;

    default:
      break;
  }
}
|
|
|
|
// Declares extension |ext| for the module and switches on the validator
// features it enables. Registering the same extension again is a no-op.
void ValidationState_t::RegisterExtension(Extension ext) {
  const bool already_registered = module_extensions_.Contains(ext);
  if (already_registered) return;
  module_extensions_.Add(ext);

  switch (ext) {
    // SPV_AMD_gpu_shader_half_float enables float16 type.
    // https://github.com/KhronosGroup/SPIRV-Tools/issues/1375
    case kSPV_AMD_gpu_shader_half_float:
    case kSPV_AMD_gpu_shader_half_float_fetch:
      features_.declare_float16_type = true;
      break;

    // This is not yet in the extension, but it's recommended for it.
    // See https://github.com/KhronosGroup/glslang/issues/848
    case kSPV_AMD_gpu_shader_int16:
      features_.uconvert_spec_constant_op = true;
      break;

    // The grammar doesn't encode the fact that SPV_AMD_shader_ballot
    // enables the use of group operations Reduce, InclusiveScan,
    // and ExclusiveScan. Enable it manually.
    // https://github.com/KhronosGroup/SPIRV-Tools/issues/991
    case kSPV_AMD_shader_ballot:
      features_.group_ops_reduce_and_scans = true;
      break;

    default:
      break;
  }
}
|
|
|
|
bool ValidationState_t::HasAnyOfCapabilities(
|
|
const CapabilitySet& capabilities) const {
|
|
return module_capabilities_.HasAnyOf(capabilities);
|
|
}
|
|
|
|
bool ValidationState_t::HasAnyOfExtensions(
|
|
const ExtensionSet& extensions) const {
|
|
return module_extensions_.HasAnyOf(extensions);
|
|
}
|
|
|
|
// Stores the module's addressing model and derives the pointer size and
// alignment from it: 4 for Physical32, 8 for Physical64,
// PhysicalStorageBuffer64EXT and every other model.
void ValidationState_t::set_addressing_model(SpvAddressingModel am) {
  addressing_model_ = am;
  if (am == SpvAddressingModelPhysical32) {
    pointer_size_and_alignment_ = 4;
  } else {
    pointer_size_and_alignment_ = 8;
  }
}
|
|
|
|
// Returns the addressing model stored by set_addressing_model()
// (SpvAddressingModelMax until one is set).
SpvAddressingModel ValidationState_t::addressing_model() const {
  const SpvAddressingModel model = addressing_model_;
  return model;
}
|
|
|
|
// Stores the module's memory model.
void ValidationState_t::set_memory_model(SpvMemoryModel mm) {
  this->memory_model_ = mm;
}
|
|
|
|
// Returns the memory model stored by set_memory_model() (SpvMemoryModelMax
// until one is set).
SpvMemoryModel ValidationState_t::memory_model() const {
  return memory_model_;
}
|
|
|
|
// Registers a new function with result id |id| and makes it the current
// function. Must be called outside of any function body (asserted).
// Always returns SPV_SUCCESS.
spv_result_t ValidationState_t::RegisterFunction(
    uint32_t id, uint32_t ret_type_id, SpvFunctionControlMask function_control,
    uint32_t function_type_id) {
  assert(in_function_body() == false &&
         "RegisterFunction can only be called when parsing the binary outside "
         "of another function");
  in_function_ = true;
  module_functions_.emplace_back(id, ret_type_id, function_control,
                                 function_type_id);
  // The map stores the address of the element just appended to
  // module_functions_; that address stays valid only as long as the vector
  // does not reallocate.
  id_to_function_.emplace(id, &current_function());

  // TODO(umar): validate function type and type_id

  return SPV_SUCCESS;
}
|
|
|
|
// Closes the current function. Must be called inside a function body and
// outside of any block (both asserted). Always returns SPV_SUCCESS.
spv_result_t ValidationState_t::RegisterFunctionEnd() {
  assert(in_function_body() == true &&
         "RegisterFunctionEnd can only be called when parsing the binary "
         "inside of another function");
  assert(in_block() == false &&
         "RegisterFunctionEnd can only be called when parsing the binary "
         "outside of a block");
  current_function().RegisterFunctionEnd();
  in_function_ = false;
  return SPV_SUCCESS;
}
|
|
|
|
// Appends an Instruction built from |inst| to the ordered instruction list,
// sets its line number to its 1-based position in that list, and returns a
// pointer to the stored element.
Instruction* ValidationState_t::AddOrderedInstruction(
    const spv_parsed_instruction_t* inst) {
  ordered_instructions_.emplace_back(inst);
  Instruction& appended = ordered_instructions_.back();
  appended.SetLineNum(ordered_instructions_.size());
  return &appended;
}
|
|
|
|
// Improves diagnostic messages by collecting names of IDs
|
|
void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) {
|
|
switch (inst->opcode()) {
|
|
case SpvOpName: {
|
|
const auto target = inst->GetOperandAs<uint32_t>(0);
|
|
const auto* str = reinterpret_cast<const char*>(inst->words().data() +
|
|
inst->operand(1).offset);
|
|
AssignNameToId(target, str);
|
|
break;
|
|
}
|
|
case SpvOpMemberName: {
|
|
const auto target = inst->GetOperandAs<uint32_t>(0);
|
|
const auto* str = reinterpret_cast<const char*>(inst->words().data() +
|
|
inst->operand(2).offset);
|
|
AssignNameToId(target, str);
|
|
break;
|
|
}
|
|
case SpvOpSourceContinued:
|
|
case SpvOpSource:
|
|
case SpvOpSourceExtension:
|
|
case SpvOpString:
|
|
case SpvOpLine:
|
|
case SpvOpNoLine:
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Records |inst| as the definition of its result id (when it has one) and
// registers it as a consumer of every OpSampledImage it uses as an operand.
void ValidationState_t::RegisterInstruction(Instruction* inst) {
  const auto result_id = inst->id();
  if (result_id) all_definitions_.insert(std::make_pair(result_id, inst));

  // If the instruction is using an OpTypeSampledImage as an operand, it should
  // be recorded. The validator will ensure that all usages of an
  // OpTypeSampledImage and its definition are in the same basic block.
  for (uint16_t index = 0; index < inst->operands().size(); ++index) {
    const spv_parsed_operand_t& operand = inst->operand(index);
    if (operand.type != SPV_OPERAND_TYPE_ID) continue;

    const uint32_t used_id = inst->word(operand.offset);
    Instruction* used_def = FindDef(used_id);
    if (used_def == nullptr) continue;
    if (used_def->opcode() != SpvOpSampledImage) continue;

    RegisterSampledImageConsumer(used_id, inst);
  }
}
|
|
|
|
std::vector<Instruction*> ValidationState_t::getSampledImageConsumers(
|
|
uint32_t sampled_image_id) const {
|
|
std::vector<Instruction*> result;
|
|
auto iter = sampled_image_consumers_.find(sampled_image_id);
|
|
if (iter != sampled_image_consumers_.end()) {
|
|
result = iter->second;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
void ValidationState_t::RegisterSampledImageConsumer(uint32_t sampled_image_id,
|
|
Instruction* consumer) {
|
|
sampled_image_consumers_[sampled_image_id].push_back(consumer);
|
|
}
|
|
|
|
// Returns the id bound stored by setIdBound().
uint32_t ValidationState_t::getIdBound() const {
  return id_bound_;
}
|
|
|
|
// Stores |bound| as the module's id bound.
void ValidationState_t::setIdBound(const uint32_t bound) {
  id_bound_ = bound;
}
|
|
|
|
bool ValidationState_t::RegisterUniqueTypeDeclaration(const Instruction* inst) {
|
|
std::vector<uint32_t> key;
|
|
key.push_back(static_cast<uint32_t>(inst->opcode()));
|
|
for (size_t index = 0; index < inst->operands().size(); ++index) {
|
|
const spv_parsed_operand_t& operand = inst->operand(index);
|
|
|
|
if (operand.type == SPV_OPERAND_TYPE_RESULT_ID) continue;
|
|
|
|
const int words_begin = operand.offset;
|
|
const int words_end = words_begin + operand.num_words;
|
|
assert(words_end <= static_cast<int>(inst->words().size()));
|
|
|
|
key.insert(key.end(), inst->words().begin() + words_begin,
|
|
inst->words().begin() + words_end);
|
|
}
|
|
|
|
return unique_type_declarations_.insert(std::move(key)).second;
|
|
}
|
|
|
|
uint32_t ValidationState_t::GetTypeId(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
return inst ? inst->type_id() : 0;
|
|
}
|
|
|
|
// Returns the opcode of the instruction defining |id|, or SpvOpNop when |id|
// has no known definition.
SpvOp ValidationState_t::GetIdOpcode(uint32_t id) const {
  const Instruction* def = FindDef(id);
  if (def == nullptr) return SpvOpNop;
  return def->opcode();
}
|
|
|
|
uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
switch (inst->opcode()) {
|
|
case SpvOpTypeFloat:
|
|
case SpvOpTypeInt:
|
|
case SpvOpTypeBool:
|
|
return id;
|
|
|
|
case SpvOpTypeVector:
|
|
return inst->word(2);
|
|
|
|
case SpvOpTypeMatrix:
|
|
return GetComponentType(inst->word(2));
|
|
|
|
case SpvOpTypeCooperativeMatrixNV:
|
|
return inst->word(2);
|
|
|
|
default:
|
|
break;
|
|
}
|
|
|
|
if (inst->type_id()) return GetComponentType(inst->type_id());
|
|
|
|
assert(0);
|
|
return 0;
|
|
}
|
|
|
|
uint32_t ValidationState_t::GetDimension(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
switch (inst->opcode()) {
|
|
case SpvOpTypeFloat:
|
|
case SpvOpTypeInt:
|
|
case SpvOpTypeBool:
|
|
return 1;
|
|
|
|
case SpvOpTypeVector:
|
|
case SpvOpTypeMatrix:
|
|
return inst->word(3);
|
|
|
|
case SpvOpTypeCooperativeMatrixNV:
|
|
// Actual dimension isn't known, return 0
|
|
return 0;
|
|
|
|
default:
|
|
break;
|
|
}
|
|
|
|
if (inst->type_id()) return GetDimension(inst->type_id());
|
|
|
|
assert(0);
|
|
return 0;
|
|
}
|
|
|
|
uint32_t ValidationState_t::GetBitWidth(uint32_t id) const {
|
|
const uint32_t component_type_id = GetComponentType(id);
|
|
const Instruction* inst = FindDef(component_type_id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeFloat || inst->opcode() == SpvOpTypeInt)
|
|
return inst->word(2);
|
|
|
|
if (inst->opcode() == SpvOpTypeBool) return 1;
|
|
|
|
assert(0);
|
|
return 0;
|
|
}
|
|
|
|
bool ValidationState_t::IsVoidType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
return inst->opcode() == SpvOpTypeVoid;
|
|
}
|
|
|
|
bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
return inst->opcode() == SpvOpTypeFloat;
|
|
}
|
|
|
|
bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeVector) {
|
|
return IsFloatScalarType(GetComponentType(id));
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeFloat) {
|
|
return true;
|
|
}
|
|
|
|
if (inst->opcode() == SpvOpTypeVector) {
|
|
return IsFloatScalarType(GetComponentType(id));
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ValidationState_t::IsIntScalarType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
return inst->opcode() == SpvOpTypeInt;
|
|
}
|
|
|
|
bool ValidationState_t::IsIntVectorType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeVector) {
|
|
return IsIntScalarType(GetComponentType(id));
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeInt) {
|
|
return true;
|
|
}
|
|
|
|
if (inst->opcode() == SpvOpTypeVector) {
|
|
return IsIntScalarType(GetComponentType(id));
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ValidationState_t::IsUnsignedIntScalarType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
return inst->opcode() == SpvOpTypeInt && inst->word(3) == 0;
|
|
}
|
|
|
|
bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeVector) {
|
|
return IsUnsignedIntScalarType(GetComponentType(id));
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ValidationState_t::IsSignedIntScalarType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
return inst->opcode() == SpvOpTypeInt && inst->word(3) == 1;
|
|
}
|
|
|
|
bool ValidationState_t::IsSignedIntVectorType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeVector) {
|
|
return IsSignedIntScalarType(GetComponentType(id));
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ValidationState_t::IsBoolScalarType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
return inst->opcode() == SpvOpTypeBool;
|
|
}
|
|
|
|
bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeVector) {
|
|
return IsBoolScalarType(GetComponentType(id));
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ValidationState_t::IsBoolScalarOrVectorType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeBool) {
|
|
return true;
|
|
}
|
|
|
|
if (inst->opcode() == SpvOpTypeVector) {
|
|
return IsBoolScalarType(GetComponentType(id));
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
|
|
if (inst->opcode() == SpvOpTypeMatrix) {
|
|
return IsFloatScalarType(GetComponentType(id));
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ValidationState_t::GetMatrixTypeInfo(uint32_t id, uint32_t* num_rows,
|
|
uint32_t* num_cols,
|
|
uint32_t* column_type,
|
|
uint32_t* component_type) const {
|
|
if (!id) return false;
|
|
|
|
const Instruction* mat_inst = FindDef(id);
|
|
assert(mat_inst);
|
|
if (mat_inst->opcode() != SpvOpTypeMatrix) return false;
|
|
|
|
const uint32_t vec_type = mat_inst->word(2);
|
|
const Instruction* vec_inst = FindDef(vec_type);
|
|
assert(vec_inst);
|
|
|
|
if (vec_inst->opcode() != SpvOpTypeVector) {
|
|
assert(0);
|
|
return false;
|
|
}
|
|
|
|
*num_cols = mat_inst->word(3);
|
|
*num_rows = vec_inst->word(3);
|
|
*column_type = mat_inst->word(2);
|
|
*component_type = vec_inst->word(2);
|
|
|
|
return true;
|
|
}
|
|
|
|
bool ValidationState_t::GetStructMemberTypes(
|
|
uint32_t struct_type_id, std::vector<uint32_t>* member_types) const {
|
|
member_types->clear();
|
|
if (!struct_type_id) return false;
|
|
|
|
const Instruction* inst = FindDef(struct_type_id);
|
|
assert(inst);
|
|
if (inst->opcode() != SpvOpTypeStruct) return false;
|
|
|
|
*member_types =
|
|
std::vector<uint32_t>(inst->words().cbegin() + 2, inst->words().cend());
|
|
|
|
if (member_types->empty()) return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
bool ValidationState_t::IsPointerType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
return inst->opcode() == SpvOpTypePointer;
|
|
}
|
|
|
|
bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
|
|
uint32_t* storage_class) const {
|
|
if (!id) return false;
|
|
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
if (inst->opcode() != SpvOpTypePointer) return false;
|
|
|
|
*storage_class = inst->word(2);
|
|
*data_type = inst->word(3);
|
|
return true;
|
|
}
|
|
|
|
bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
|
|
const Instruction* inst = FindDef(id);
|
|
assert(inst);
|
|
return inst->opcode() == SpvOpTypeCooperativeMatrixNV;
|
|
}
|
|
|
|
bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
|
|
if (!IsCooperativeMatrixType(id)) return false;
|
|
return IsFloatScalarType(FindDef(id)->word(2));
|
|
}
|
|
|
|
bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
|
|
if (!IsCooperativeMatrixType(id)) return false;
|
|
return IsIntScalarType(FindDef(id)->word(2));
|
|
}
|
|
|
|
bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
|
|
if (!IsCooperativeMatrixType(id)) return false;
|
|
return IsUnsignedIntScalarType(FindDef(id)->word(2));
|
|
}
|
|
|
|
// Checks that the cooperative matrix types |m1| and |m2| agree in scope, row
// count and column count. Each of the three is compared only when both sides
// evaluate to non-spec 32-bit integer constants; otherwise it is accepted.
//
// Returns SPV_SUCCESS on a match. Otherwise emits a diagnostic attributed to
// |inst| and returns SPV_ERROR_INVALID_DATA. An id that has no definition is
// reported the same way as an id that is not a cooperative matrix type.
spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
    const Instruction* inst, uint32_t m1, uint32_t m2) {
  const auto m1_type = FindDef(m1);
  const auto m2_type = FindDef(m2);

  // FindDef() yields nullptr for unknown ids; reject those up front instead
  // of dereferencing them.
  if (!m1_type || !m2_type ||
      m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
      m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
    return diag(SPV_ERROR_INVALID_DATA, inst)
           << "Expected cooperative matrix types";
  }

  // True only when both ids are evaluable int32 constants with different
  // values. EvalInt32IfConst() yields (is_int32, is_const_int32, value).
  const auto known_and_different = [this](uint32_t lhs_id, uint32_t rhs_id) {
    const auto lhs = EvalInt32IfConst(lhs_id);
    const auto rhs = EvalInt32IfConst(rhs_id);
    return std::get<1>(lhs) && std::get<1>(rhs) &&
           std::get<2>(lhs) != std::get<2>(rhs);
  };

  // Operand 2 holds the scope id, operand 3 the rows id, operand 4 the
  // columns id.
  if (known_and_different(m1_type->GetOperandAs<uint32_t>(2),
                          m2_type->GetOperandAs<uint32_t>(2))) {
    return diag(SPV_ERROR_INVALID_DATA, inst)
           << "Expected scopes of Matrix and Result Type to be "
           << "identical";
  }

  if (known_and_different(m1_type->GetOperandAs<uint32_t>(3),
                          m2_type->GetOperandAs<uint32_t>(3))) {
    return diag(SPV_ERROR_INVALID_DATA, inst)
           << "Expected rows of Matrix type and Result Type to be "
           << "identical";
  }

  if (known_and_different(m1_type->GetOperandAs<uint32_t>(4),
                          m2_type->GetOperandAs<uint32_t>(4))) {
    return diag(SPV_ERROR_INVALID_DATA, inst)
           << "Expected columns of Matrix type and Result Type to be "
           << "identical";
  }

  return SPV_SUCCESS;
}
|
|
|
|
uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
|
|
size_t operand_index) const {
|
|
return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
|
|
}
|
|
|
|
bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
|
|
const Instruction* inst = FindDef(id);
|
|
if (!inst) {
|
|
assert(0 && "Instruction not found");
|
|
return false;
|
|
}
|
|
|
|
if (inst->opcode() != SpvOpConstant && inst->opcode() != SpvOpSpecConstant)
|
|
return false;
|
|
|
|
if (!IsIntScalarType(inst->type_id())) return false;
|
|
|
|
if (inst->words().size() == 4) {
|
|
*val = inst->word(3);
|
|
} else {
|
|
assert(inst->words().size() == 5);
|
|
*val = inst->word(3);
|
|
*val |= uint64_t(inst->word(4)) << 32;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
|
|
uint32_t id) const {
|
|
const Instruction* const inst = FindDef(id);
|
|
assert(inst);
|
|
const uint32_t type = inst->type_id();
|
|
|
|
if (type == 0 || !IsIntScalarType(type) || GetBitWidth(type) != 32) {
|
|
return std::make_tuple(false, false, 0);
|
|
}
|
|
|
|
// Spec constant values cannot be evaluated so don't consider constant for
|
|
// the purpose of this method.
|
|
if (!spvOpcodeIsConstant(inst->opcode()) ||
|
|
spvOpcodeIsSpecConstant(inst->opcode())) {
|
|
return std::make_tuple(true, false, 0);
|
|
}
|
|
|
|
if (inst->opcode() == SpvOpConstantNull) {
|
|
return std::make_tuple(true, true, 0);
|
|
}
|
|
|
|
assert(inst->words().size() == 4);
|
|
return std::make_tuple(true, true, inst->word(3));
|
|
}
|
|
|
|
void ValidationState_t::ComputeFunctionToEntryPointMapping() {
|
|
for (const uint32_t entry_point : entry_points()) {
|
|
std::stack<uint32_t> call_stack;
|
|
std::set<uint32_t> visited;
|
|
call_stack.push(entry_point);
|
|
while (!call_stack.empty()) {
|
|
const uint32_t called_func_id = call_stack.top();
|
|
call_stack.pop();
|
|
if (!visited.insert(called_func_id).second) continue;
|
|
|
|
function_to_entry_points_[called_func_id].push_back(entry_point);
|
|
|
|
const Function* called_func = function(called_func_id);
|
|
if (called_func) {
|
|
// Other checks should error out on this invalid SPIR-V.
|
|
for (const uint32_t new_call : called_func->function_call_targets()) {
|
|
call_stack.push(new_call);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void ValidationState_t::ComputeRecursiveEntryPoints() {
|
|
for (const Function& func : functions()) {
|
|
std::stack<uint32_t> call_stack;
|
|
std::set<uint32_t> visited;
|
|
|
|
for (const uint32_t new_call : func.function_call_targets()) {
|
|
call_stack.push(new_call);
|
|
}
|
|
|
|
while (!call_stack.empty()) {
|
|
const uint32_t called_func_id = call_stack.top();
|
|
call_stack.pop();
|
|
|
|
if (!visited.insert(called_func_id).second) continue;
|
|
|
|
if (called_func_id == func.id()) {
|
|
for (const uint32_t entry_point :
|
|
function_to_entry_points_[called_func_id])
|
|
recursive_entry_points_.insert(entry_point);
|
|
break;
|
|
}
|
|
|
|
const Function* called_func = function(called_func_id);
|
|
if (called_func) {
|
|
// Other checks should error out on this invalid SPIR-V.
|
|
for (const uint32_t new_call : called_func->function_call_targets()) {
|
|
call_stack.push(new_call);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
|
|
uint32_t func) const {
|
|
auto iter = function_to_entry_points_.find(func);
|
|
if (iter == function_to_entry_points_.end()) {
|
|
return empty_ids_;
|
|
} else {
|
|
return iter->second;
|
|
}
|
|
}
|
|
|
|
std::set<uint32_t> ValidationState_t::EntryPointReferences(uint32_t id) const {
|
|
std::set<uint32_t> referenced_entry_points;
|
|
const auto inst = FindDef(id);
|
|
if (!inst) return referenced_entry_points;
|
|
|
|
std::vector<const Instruction*> stack;
|
|
stack.push_back(inst);
|
|
while (!stack.empty()) {
|
|
const auto current_inst = stack.back();
|
|
stack.pop_back();
|
|
|
|
if (const auto func = current_inst->function()) {
|
|
// Instruction lives in a function, we can stop searching.
|
|
const auto function_entry_points = FunctionEntryPoints(func->id());
|
|
referenced_entry_points.insert(function_entry_points.begin(),
|
|
function_entry_points.end());
|
|
} else {
|
|
// Instruction is in the global scope, keep searching its uses.
|
|
for (auto pair : current_inst->uses()) {
|
|
const auto next_inst = pair.first;
|
|
stack.push_back(next_inst);
|
|
}
|
|
}
|
|
}
|
|
|
|
return referenced_entry_points;
|
|
}
|
|
|
|
std::string ValidationState_t::Disassemble(const Instruction& inst) const {
|
|
const spv_parsed_instruction_t& c_inst(inst.c_inst());
|
|
return Disassemble(c_inst.words, c_inst.num_words);
|
|
}
|
|
|
|
std::string ValidationState_t::Disassemble(const uint32_t* words,
|
|
uint16_t num_words) const {
|
|
uint32_t disassembly_options = SPV_BINARY_TO_TEXT_OPTION_NO_HEADER |
|
|
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES;
|
|
|
|
return spvInstructionBinaryToText(context()->target_env, words, num_words,
|
|
words_, num_words_, disassembly_options);
|
|
}
|
|
|
|
bool ValidationState_t::LogicallyMatch(const Instruction* lhs,
|
|
const Instruction* rhs,
|
|
bool check_decorations) {
|
|
if (lhs->opcode() != rhs->opcode()) {
|
|
return false;
|
|
}
|
|
|
|
if (check_decorations) {
|
|
const auto& dec_a = id_decorations(lhs->id());
|
|
const auto& dec_b = id_decorations(rhs->id());
|
|
|
|
for (const auto& dec : dec_b) {
|
|
if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (lhs->opcode() == SpvOpTypeArray) {
|
|
// Size operands must match.
|
|
if (lhs->GetOperandAs<uint32_t>(2u) != rhs->GetOperandAs<uint32_t>(2u)) {
|
|
return false;
|
|
}
|
|
|
|
// Elements must match or logically match.
|
|
const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(1u);
|
|
const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(1u);
|
|
if (lhs_ele_id == rhs_ele_id) {
|
|
return true;
|
|
}
|
|
|
|
const auto lhs_ele = FindDef(lhs_ele_id);
|
|
const auto rhs_ele = FindDef(rhs_ele_id);
|
|
if (!lhs_ele || !rhs_ele) {
|
|
return false;
|
|
}
|
|
return LogicallyMatch(lhs_ele, rhs_ele, check_decorations);
|
|
} else if (lhs->opcode() == SpvOpTypeStruct) {
|
|
// Number of elements must match.
|
|
if (lhs->operands().size() != rhs->operands().size()) {
|
|
return false;
|
|
}
|
|
|
|
for (size_t i = 1u; i < lhs->operands().size(); ++i) {
|
|
const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(i);
|
|
const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(i);
|
|
// Elements must match or logically match.
|
|
if (lhs_ele_id == rhs_ele_id) {
|
|
continue;
|
|
}
|
|
|
|
const auto lhs_ele = FindDef(lhs_ele_id);
|
|
const auto rhs_ele = FindDef(rhs_ele_id);
|
|
if (!lhs_ele || !rhs_ele) {
|
|
return false;
|
|
}
|
|
|
|
if (!LogicallyMatch(lhs_ele, rhs_ele, check_decorations)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// All checks passed.
|
|
return true;
|
|
}
|
|
|
|
// No other opcodes are acceptable at this point. Arrays and structs are
|
|
// caught above and if they're elements are not arrays or structs they are
|
|
// required to match exactly.
|
|
return false;
|
|
}
|
|
|
|
// Walks back from |inst| through access-chain and copy-object instructions,
// each time moving to the definition of operand 2, and returns the first
// instruction that is none of those opcodes.
const Instruction* ValidationState_t::TracePointer(
    const Instruction* inst) const {
  const Instruction* current = inst;
  for (;;) {
    switch (current->opcode()) {
      case SpvOpAccessChain:
      case SpvOpInBoundsAccessChain:
      case SpvOpPtrAccessChain:
      case SpvOpInBoundsPtrAccessChain:
      case SpvOpCopyObject:
        // Step to the instruction that defines the source operand.
        current = FindDef(current->GetOperandAs<uint32_t>(2u));
        break;
      default:
        return current;
    }
  }
}
|
|
|
|
bool ValidationState_t::ContainsSizedIntOrFloatType(uint32_t id, SpvOp type,
|
|
uint32_t width) const {
|
|
if (type != SpvOpTypeInt && type != SpvOpTypeFloat) return false;
|
|
|
|
const auto inst = FindDef(id);
|
|
if (!inst) return false;
|
|
|
|
if (inst->opcode() == type) {
|
|
return inst->GetOperandAs<uint32_t>(1u) == width;
|
|
}
|
|
|
|
switch (inst->opcode()) {
|
|
case SpvOpTypeArray:
|
|
case SpvOpTypeRuntimeArray:
|
|
case SpvOpTypeVector:
|
|
case SpvOpTypeMatrix:
|
|
case SpvOpTypeImage:
|
|
case SpvOpTypeSampledImage:
|
|
case SpvOpTypeCooperativeMatrixNV:
|
|
return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(1u), type,
|
|
width);
|
|
case SpvOpTypePointer:
|
|
if (IsForwardPointer(id)) return false;
|
|
return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(2u), type,
|
|
width);
|
|
case SpvOpTypeFunction:
|
|
case SpvOpTypeStruct: {
|
|
for (uint32_t i = 1; i < inst->operands().size(); ++i) {
|
|
if (ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(i), type,
|
|
width))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool ValidationState_t::ContainsLimitedUseIntOrFloatType(uint32_t id) const {
|
|
if ((!HasCapability(SpvCapabilityInt16) &&
|
|
ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 16)) ||
|
|
(!HasCapability(SpvCapabilityInt8) &&
|
|
ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 8)) ||
|
|
(!HasCapability(SpvCapabilityFloat16) &&
|
|
ContainsSizedIntOrFloatType(id, SpvOpTypeFloat, 16))) {
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Reports whether |storage_class| is accepted for the current target
// environment. WebGPU and Vulkan environments each accept a fixed list of
// storage classes; every other environment accepts any storage class.
bool ValidationState_t::IsValidStorageClass(
    SpvStorageClass storage_class) const {
  const auto env = context()->target_env;

  // The WebGPU check takes precedence; the Vulkan check is only consulted when
  // the environment is not WebGPU.
  const bool is_webgpu = spvIsWebGPUEnv(env);
  const bool is_vulkan = !is_webgpu && spvIsVulkanEnv(env);
  if (!is_webgpu && !is_vulkan) return true;

  switch (storage_class) {
    // Accepted by both WebGPU and Vulkan environments.
    case SpvStorageClassUniformConstant:
    case SpvStorageClassUniform:
    case SpvStorageClassStorageBuffer:
    case SpvStorageClassInput:
    case SpvStorageClassOutput:
    case SpvStorageClassImage:
    case SpvStorageClassWorkgroup:
    case SpvStorageClassPrivate:
    case SpvStorageClassFunction:
      return true;
    // Accepted by Vulkan environments only.
    case SpvStorageClassPushConstant:
    case SpvStorageClassPhysicalStorageBuffer:
    case SpvStorageClassRayPayloadNV:
    case SpvStorageClassIncomingRayPayloadNV:
    case SpvStorageClassHitAttributeNV:
    case SpvStorageClassCallableDataNV:
    case SpvStorageClassIncomingCallableDataNV:
    case SpvStorageClassShaderRecordBufferNV:
      return is_vulkan;
    default:
      return false;
  }
}
|
|
|
|
} // namespace val
|
|
} // namespace spvtools
|