// SPIRV-Tools/source/val/validate_constants.cpp
// (469 lines, 18 KiB, C++)

// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/opcode.h"
#include "source/val/instruction.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"
namespace spvtools {
namespace val {
namespace {
// Checks that the Result Type of a boolean constant instruction is an
// OpTypeBool. Emits an SPV_ERROR_INVALID_ID diagnostic against |inst| when the
// type is undefined or is any other kind of type.
spv_result_t ValidateConstantBool(ValidationState_t& _,
                                  const Instruction* inst) {
  const auto bool_type = _.FindDef(inst->type_id());
  const bool is_bool_type = bool_type && bool_type->opcode() == SpvOpTypeBool;
  if (is_bool_type) {
    return SPV_SUCCESS;
  }

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '"
         << _.getIdName(inst->type_id()) << "' is not a boolean type.";
}
// Validates a composite constant instruction (dispatched from ConstantPass for
// OpConstantComposite and OpSpecConstantComposite).
//
// The Result Type must be a composite type. For vector, matrix, array, struct
// and cooperative-matrix result types the constituents are then checked for:
//   * count (against component/column/length/member count),
//   * being a constant or undef instruction, and
//   * type agreement with the matching component/column/element/member type.
// Other composite result types get no further checks here.
//
// Returns SPV_SUCCESS when all checks pass; otherwise emits a diagnostic with
// SPV_ERROR_INVALID_ID and returns its result.
spv_result_t ValidateConstantComposite(ValidationState_t& _,
                                       const Instruction* inst) {
  std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());

  const auto result_type = _.FindDef(inst->type_id());
  if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << opcode_name << " Result Type <id> '"
           << _.getIdName(inst->type_id()) << "' is not a composite type.";
  }

  // Constituents start at operand index 2 (see the loops below), i.e. after
  // the three leading words of the instruction.
  const auto constituent_count = inst->words().size() - 3;

  switch (result_type->opcode()) {
    case SpvOpTypeVector: {
      const auto component_count = result_type->GetOperandAs<uint32_t>(2);
      if (component_count != constituent_count) {
        // TODO: Output ID's on diagnostic
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name
               << " Constituent <id> count does not match "
                  "Result Type <id> '"
               << _.getIdName(result_type->id())
               << "'s vector component count.";
      }

      const auto component_type =
          _.FindDef(result_type->GetOperandAs<uint32_t>(1));
      if (!component_type) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Component type is not defined.";
      }

      for (size_t constituent_index = 2;
           constituent_index < inst->operands().size(); constituent_index++) {
        const auto constituent_id =
            inst->GetOperandAs<uint32_t>(constituent_index);
        const auto constituent = _.FindDef(constituent_id);
        if (!constituent ||
            !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "' is not a constant or undef.";
        }

        // Compare type ids, not just opcodes: matching opcodes alone would
        // accept e.g. a float constituent of a different width. This mirrors
        // the id comparison used by the array, struct and matrix cases.
        const auto constituent_result_type = _.FindDef(constituent->type_id());
        if (!constituent_result_type ||
            component_type->id() != constituent_result_type->id()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "'s type does not match Result Type <id> '"
                 << _.getIdName(result_type->id()) << "'s vector element type.";
        }
      }
    } break;

    case SpvOpTypeMatrix: {
      const auto column_count = result_type->GetOperandAs<uint32_t>(2);
      if (column_count != constituent_count) {
        // TODO: Output ID's on diagnostic
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name
               << " Constituent <id> count does not match "
                  "Result Type <id> '"
               << _.getIdName(result_type->id()) << "'s matrix column count.";
      }

      const auto column_type = _.FindDef(result_type->words()[2]);
      if (!column_type) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Column type is not defined.";
      }

      const auto component_count = column_type->GetOperandAs<uint32_t>(2);
      const auto component_type =
          _.FindDef(column_type->GetOperandAs<uint32_t>(1));
      if (!component_type) {
        return _.diag(SPV_ERROR_INVALID_ID, column_type)
               << "Component type is not defined.";
      }

      for (size_t constituent_index = 2;
           constituent_index < inst->operands().size(); constituent_index++) {
        const auto constituent_id =
            inst->GetOperandAs<uint32_t>(constituent_index);
        const auto constituent = _.FindDef(constituent_id);
        if (!constituent ||
            !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
          // The message says "... or undef" because the spec does not say
          // undef is a constant.
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "' is not a constant or undef.";
        }

        const auto vector = _.FindDef(constituent->type_id());
        if (!vector) {
          return _.diag(SPV_ERROR_INVALID_ID, constituent)
                 << "Result type is not defined.";
        }
        if (column_type->opcode() != vector->opcode()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "' type does not match Result Type <id> '"
                 << _.getIdName(result_type->id()) << "'s matrix column type.";
        }

        const auto vector_component_type =
            _.FindDef(vector->GetOperandAs<uint32_t>(1));
        // Guard the lookup before dereferencing it below.
        if (!vector_component_type) {
          return _.diag(SPV_ERROR_INVALID_ID, vector)
                 << "Component type is not defined.";
        }
        if (component_type->id() != vector_component_type->id()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "' component type does not match Result Type <id> '"
                 << _.getIdName(result_type->id())
                 << "'s matrix column component type.";
        }
        if (component_count != vector->words()[3]) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "' vector component count does not match Result Type <id> '"
                 << _.getIdName(result_type->id())
                 << "'s vector component count.";
        }
      }
    } break;

    case SpvOpTypeArray: {
      auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
      if (!element_type) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Element type is not defined.";
      }

      const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
      if (!length) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Length is not defined.";
      }

      // The count is only enforced when the length evaluates to a constant
      // 32-bit integer; otherwise the count check is skipped.
      bool is_int32 = false;
      bool is_const = false;
      uint32_t value = 0;
      std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
      if (is_int32 && is_const && value != constituent_count) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name
               << " Constituent count does not match "
                  "Result Type <id> '"
               << _.getIdName(result_type->id()) << "'s array length.";
      }

      for (size_t constituent_index = 2;
           constituent_index < inst->operands().size(); constituent_index++) {
        const auto constituent_id =
            inst->GetOperandAs<uint32_t>(constituent_index);
        const auto constituent = _.FindDef(constituent_id);
        if (!constituent ||
            !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "' is not a constant or undef.";
        }

        const auto constituent_type = _.FindDef(constituent->type_id());
        if (!constituent_type) {
          return _.diag(SPV_ERROR_INVALID_ID, constituent)
                 << "Result type is not defined.";
        }
        if (element_type->id() != constituent_type->id()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "'s type does not match Result Type <id> '"
                 << _.getIdName(result_type->id()) << "'s array element type.";
        }
      }
    } break;

    case SpvOpTypeStruct: {
      const auto member_count = result_type->words().size() - 2;
      if (member_count != constituent_count) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name << " Constituent <id> '"
               << _.getIdName(inst->type_id())
               << "' count does not match Result Type <id> '"
               << _.getIdName(result_type->id()) << "'s struct member count.";
      }

      // Constituents (operand 2 onward of |inst|) are walked in lock-step with
      // the member types (operand 1 onward of the struct type).
      for (uint32_t constituent_index = 2, member_index = 1;
           constituent_index < inst->operands().size();
           constituent_index++, member_index++) {
        const auto constituent_id =
            inst->GetOperandAs<uint32_t>(constituent_index);
        const auto constituent = _.FindDef(constituent_id);
        if (!constituent ||
            !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "' is not a constant or undef.";
        }

        const auto constituent_type = _.FindDef(constituent->type_id());
        if (!constituent_type) {
          return _.diag(SPV_ERROR_INVALID_ID, constituent)
                 << "Result type is not defined.";
        }

        const auto member_type_id =
            result_type->GetOperandAs<uint32_t>(member_index);
        const auto member_type = _.FindDef(member_type_id);
        if (!member_type || member_type->id() != constituent_type->id()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> '"
                 << _.getIdName(constituent_id)
                 << "' type does not match the Result Type <id> '"
                 << _.getIdName(result_type->id()) << "'s member type.";
        }
      }
    } break;

    case SpvOpTypeCooperativeMatrixNV: {
      if (1 != constituent_count) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name << " Constituent <id> '"
               << _.getIdName(inst->type_id()) << "' count must be one.";
      }

      const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
      const auto constituent = _.FindDef(constituent_id);
      if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name << " Constituent <id> '"
               << _.getIdName(constituent_id)
               << "' is not a constant or undef.";
      }

      const auto constituent_type = _.FindDef(constituent->type_id());
      if (!constituent_type) {
        return _.diag(SPV_ERROR_INVALID_ID, constituent)
               << "Result type is not defined.";
      }

      const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
      const auto component_type = _.FindDef(component_type_id);
      if (!component_type || component_type->id() != constituent_type->id()) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name << " Constituent <id> '"
               << _.getIdName(constituent_id)
               << "' type does not match the Result Type <id> '"
               << _.getIdName(result_type->id()) << "'s component type.";
      }
    } break;

    default:
      break;
  }

  return SPV_SUCCESS;
}
// Checks that the Result Type of OpConstantSampler is an OpTypeSampler.
//
// The diagnostic is reported against |inst| (the constant itself), matching
// the other constant validators in this file. |result_type| is not used as the
// diagnostic anchor because it is null on one of the two failing paths.
spv_result_t ValidateConstantSampler(ValidationState_t& _,
                                     const Instruction* inst) {
  const auto result_type = _.FindDef(inst->type_id());
  if (!result_type || result_type->opcode() != SpvOpTypeSampler) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpConstantSampler Result Type <id> '"
           << _.getIdName(inst->type_id()) << "' is not a sampler type.";
  }
  return SPV_SUCCESS;
}
// True if instruction defines a type that can have a null value, as defined by
// the SPIR-V spec. Tracks composite-type components through module to check
// nullability transitively.
bool IsTypeNullable(const std::vector<uint32_t>& instruction,
const ValidationState_t& _) {
uint16_t opcode;
uint16_t word_count;
spvOpcodeSplit(instruction[0], &word_count, &opcode);
switch (static_cast<SpvOp>(opcode)) {
case SpvOpTypeBool:
case SpvOpTypeInt:
case SpvOpTypeFloat:
case SpvOpTypeEvent:
case SpvOpTypeDeviceEvent:
case SpvOpTypeReserveId:
case SpvOpTypeQueue:
return true;
case SpvOpTypeArray:
case SpvOpTypeMatrix:
case SpvOpTypeCooperativeMatrixNV:
case SpvOpTypeVector: {
auto base_type = _.FindDef(instruction[2]);
return base_type && IsTypeNullable(base_type->words(), _);
}
case SpvOpTypeStruct: {
for (size_t elementIndex = 2; elementIndex < instruction.size();
++elementIndex) {
auto element = _.FindDef(instruction[elementIndex]);
if (!element || !IsTypeNullable(element->words(), _)) return false;
}
return true;
}
case SpvOpTypePointer:
if (instruction[2] == SpvStorageClassPhysicalStorageBuffer) {
return false;
}
return true;
default:
return false;
}
}
// Checks that the Result Type of OpConstantNull is defined and is a type that
// IsTypeNullable accepts; otherwise emits an SPV_ERROR_INVALID_ID diagnostic
// against |inst|.
spv_result_t ValidateConstantNull(ValidationState_t& _,
                                  const Instruction* inst) {
  const auto null_type = _.FindDef(inst->type_id());
  const bool nullable = null_type && IsTypeNullable(null_type->words(), _);
  if (nullable) {
    return SPV_SUCCESS;
  }

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << "OpConstantNull Result Type <id> '"
         << _.getIdName(inst->type_id()) << "' cannot have a null value.";
}
// Validates that OpSpecConstant specializes to either int or float type.
//
// Emits an SPV_ERROR_INVALID_DATA diagnostic against |inst| when the type is
// anything else, including when the type <id> has no definition.
spv_result_t ValidateSpecConstant(ValidationState_t& _,
                                  const Instruction* inst) {
  // Operand 0 is the <id> of the type that we're specializing to.
  const auto type_id = inst->GetOperandAs<const uint32_t>(0);
  const auto type_instruction = _.FindDef(type_id);

  // Check the lookup before dereferencing it: an undefined type cannot be an
  // integer or floating-point type either.
  const bool is_int_or_float =
      type_instruction && (type_instruction->opcode() == SpvOpTypeInt ||
                           type_instruction->opcode() == SpvOpTypeFloat);
  if (!is_int_or_float) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
                                                   "must be an integer or "
                                                   "floating-point number.";
  }
  return SPV_SUCCESS;
}
// Validates the operation selected by an OpSpecConstantOp instruction.
//
// Operand 2 of |inst| is read as the opcode of the wrapped operation. Only
// capability/feature restrictions on that opcode are checked here:
//   * QuantizeToF16 requires the Shader capability.
//   * UConvert requires either the uconvert_spec_constant_op feature or the
//     Kernel capability.
//   * The listed conversion, float-arithmetic and access-chain opcodes
//     require the Kernel capability.
// Any other opcode passes. Failures emit an SPV_ERROR_INVALID_ID diagnostic
// against |inst|.
spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
                                    const Instruction* inst) {
  const auto op = inst->GetOperandAs<SpvOp>(2);

  // The binary parser already ensures that the op is valid for *some*
  // environment.  Here we check restrictions.
  switch (op) {
    case SpvOpQuantizeToF16:
      if (!_.HasCapability(SpvCapabilityShader)) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << "Specialization constant operation " << spvOpcodeString(op)
               << " requires Shader capability";
      }
      break;

    case SpvOpUConvert:
      if (!_.features().uconvert_spec_constant_op &&
          !_.HasCapability(SpvCapabilityKernel)) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << "Prior to SPIR-V 1.4, specialization constant operation "
                  "UConvert requires Kernel capability or extension "
                  "SPV_AMD_gpu_shader_int16";
      }
      break;

    case SpvOpConvertFToS:
    case SpvOpConvertSToF:
    case SpvOpConvertFToU:
    case SpvOpConvertUToF:
    case SpvOpConvertPtrToU:
    case SpvOpConvertUToPtr:
    case SpvOpGenericCastToPtr:
    case SpvOpPtrCastToGeneric:
    case SpvOpBitcast:
    case SpvOpFNegate:
    case SpvOpFAdd:
    case SpvOpFSub:
    case SpvOpFMul:
    case SpvOpFDiv:
    case SpvOpFRem:
    case SpvOpFMod:
    case SpvOpAccessChain:
    case SpvOpInBoundsAccessChain:
    case SpvOpPtrAccessChain:
    case SpvOpInBoundsPtrAccessChain:
      if (!_.HasCapability(SpvCapabilityKernel)) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << "Specialization constant operation " << spvOpcodeString(op)
               << " requires Kernel capability";
      }
      break;

    default:
      break;
  }

  // TODO(dneto): Validate result type and arguments to the various operations.
  return SPV_SUCCESS;
}
} // namespace
// Entry point of the constant validation pass for a single instruction.
//
// Selects the validator matching the opcode of |inst| (if any) and returns its
// error when it reports one. Afterwards, for any constant instruction in a
// module with the Shader capability whose result type is not a pointer, the
// constant is rejected if its type contains a limited-use int or float type.
spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
  using ConstantValidator = spv_result_t (*)(ValidationState_t&,
                                             const Instruction*);

  // Pick the per-opcode validator; instructions without one skip this step.
  ConstantValidator validator = nullptr;
  switch (inst->opcode()) {
    case SpvOpConstantTrue:
    case SpvOpConstantFalse:
    case SpvOpSpecConstantTrue:
    case SpvOpSpecConstantFalse:
      validator = ValidateConstantBool;
      break;
    case SpvOpConstantComposite:
    case SpvOpSpecConstantComposite:
      validator = ValidateConstantComposite;
      break;
    case SpvOpConstantSampler:
      validator = ValidateConstantSampler;
      break;
    case SpvOpConstantNull:
      validator = ValidateConstantNull;
      break;
    case SpvOpSpecConstant:
      validator = ValidateSpecConstant;
      break;
    case SpvOpSpecConstantOp:
      validator = ValidateSpecConstantOp;
      break;
    default:
      break;
  }

  if (validator) {
    if (const auto error = validator(_, inst)) return error;
  }

  // Generally disallow creating 8- or 16-bit constants unless the full
  // capabilities are present. The guards run in the same order as the
  // original short-circuit expression.
  if (!spvOpcodeIsConstant(inst->opcode())) return SPV_SUCCESS;
  if (!_.HasCapability(SpvCapabilityShader)) return SPV_SUCCESS;
  if (_.IsPointerType(inst->type_id())) return SPV_SUCCESS;
  if (!_.ContainsLimitedUseIntOrFloatType(inst->type_id())) return SPV_SUCCESS;

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << "Cannot form constants of 8- or 16-bit types";
}
} // namespace val
} // namespace spvtools