spirv-val: Make Constant evaluation consistent (#5587)

Bring 64-bit evaluation in line with 32-bit evaluation.
This commit is contained in:
Spencer Fricke 2024-02-22 07:52:13 +09:00 committed by GitHub
parent dc6676445b
commit 1b643eac5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 78 additions and 70 deletions

View File

@ -1120,7 +1120,7 @@ spv_result_t BuiltInsValidator::ValidateF32ArrHelper(
if (num_components != 0) {
uint64_t actual_num_components = 0;
if (!_.GetConstantValUint64(type_inst->word(3), &actual_num_components)) {
if (!_.EvalConstantValUint64(type_inst->word(3), &actual_num_components)) {
assert(0 && "Array type definition is corrupt");
}
if (actual_num_components != num_components) {

View File

@ -94,7 +94,7 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
break;
}
if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
assert(0 && "Array type definition is corrupt");
}
if (component_index >= array_size) {
@ -289,7 +289,7 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
}
uint64_t array_size = 0;
if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) {
if (!_.EvalConstantValUint64(array_inst->word(3), &array_size)) {
assert(0 && "Array type definition is corrupt");
}

View File

@ -3100,7 +3100,7 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) {
uint32_t vector_count = inst->word(6);
uint64_t const_val;
if (!_.GetConstantValUint64(vector_count, &const_val)) {
if (!_.EvalConstantValUint64(vector_count, &const_val)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< ext_inst_name()
<< ": Vector Count must be 32-bit integer OpConstant";
@ -3191,7 +3191,7 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) {
uint32_t component_count = inst->word(6);
if (vulkanDebugInfo) {
uint64_t const_val;
if (!_.GetConstantValUint64(component_count, &const_val)) {
if (!_.EvalConstantValUint64(component_count, &const_val)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< ext_inst_name()
<< ": Component Count must be 32-bit integer OpConstant";

View File

@ -495,7 +495,7 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
}
uint64_t array_size = 0;
if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
assert(0 && "Array type definition is corrupt");
}
@ -1210,7 +1210,7 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
if (info.multisampled == 0) {
uint64_t ms = 0;
if (!_.GetConstantValUint64(inst->GetOperandAs<uint32_t>(4), &ms) ||
if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(4), &ms) ||
ms != 0) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Sample for Image with MS 0 to be a valid <id> for "

View File

@ -1374,22 +1374,18 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
case spv::Op::OpTypeStruct: {
// In case of structures, there is an additional constraint on the
// index: the index must be an OpConstant.
if (spv::Op::OpConstant != cur_word_instr->opcode()) {
int64_t cur_index;
if (!_.EvalConstantValInt64(cur_word, &cur_index)) {
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
<< "The <id> passed to " << instr_name
<< " to index into a "
"structure must be an OpConstant.";
}
// Get the index value from the OpConstant (word 3 of OpConstant).
// OpConstant could be a signed integer. But it's okay to treat it as
// unsigned because a negative constant int would never be seen as
// correct as a struct offset, since structs can't have more than 2
// billion members.
const uint32_t cur_index = cur_word_instr->word(3);
// The index points to the struct member we want, therefore, the index
// should be less than the number of struct members.
const uint32_t num_struct_members =
static_cast<uint32_t>(type_pointee->words().size() - 2);
const int64_t num_struct_members =
static_cast<int64_t>(type_pointee->words().size() - 2);
if (cur_index >= num_struct_members) {
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
<< "Index is out of bounds: " << instr_name
@ -1400,7 +1396,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
<< num_struct_members - 1 << ".";
}
// Struct members IDs start at word 2 of OpTypeStruct.
auto structMemberId = type_pointee->word(cur_index + 2);
const size_t word_index = static_cast<size_t>(cur_index) + 2;
auto structMemberId = type_pointee->word(word_index);
type_pointee = _.FindDef(structMemberId);
break;
}

View File

@ -389,20 +389,25 @@ spv_result_t ValidateGroupNonUniformRotateKHR(ValidationState_t& _,
if (inst->words().size() > 6) {
const uint32_t cluster_size_op_id = inst->GetOperandAs<uint32_t>(5);
const uint32_t cluster_size_type = _.GetTypeId(cluster_size_op_id);
const Instruction* cluster_size_inst = _.FindDef(cluster_size_op_id);
const uint32_t cluster_size_type =
cluster_size_inst ? cluster_size_inst->type_id() : 0;
if (!_.IsUnsignedIntScalarType(cluster_size_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "ClusterSize must be a scalar of integer type, whose "
"Signedness operand is 0.";
}
uint64_t cluster_size;
if (!_.GetConstantValUint64(cluster_size_op_id, &cluster_size)) {
if (!spvOpcodeIsConstant(cluster_size_inst->opcode())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "ClusterSize must come from a constant instruction.";
}
if ((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0)) {
uint64_t cluster_size;
const bool valid_const =
_.EvalConstantValUint64(cluster_size_op_id, &cluster_size);
if (valid_const &&
((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0))) {
return _.diag(SPV_WARNING, inst)
<< "Behavior is undefined unless ClusterSize is at least 1 and a "
"power of 2.";

View File

@ -24,21 +24,6 @@ namespace spvtools {
namespace val {
namespace {
// Returns, as an int64_t, the literal value from an OpConstant or the
// default value of an OpSpecConstant, assuming it is an integral type.
// For signed integers, relies the rule that literal value is sign extended
// to fill out to word granularity. Assumes that the constant value
// has
int64_t ConstantLiteralAsInt64(uint32_t width,
const std::vector<uint32_t>& const_words) {
const uint32_t lo_word = const_words[3];
if (width <= 32) return int32_t(lo_word);
assert(width <= 64);
assert(const_words.size() > 4);
const uint32_t hi_word = const_words[4]; // Must exist, per spec.
return static_cast<int64_t>(uint64_t(lo_word) | uint64_t(hi_word) << 32);
}
// Validates that type declarations are unique, unless multiple declarations
// of the same data type are allowed by the specification.
// (see section 2.8 Types and Variables)
@ -252,29 +237,17 @@ spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) {
<< " is not a constant integer type.";
}
switch (length->opcode()) {
case spv::Op::OpSpecConstant:
case spv::Op::OpConstant: {
int64_t length_value;
if (_.EvalConstantValInt64(length_id, &length_value)) {
auto& type_words = const_result_type->words();
const bool is_signed = type_words[3] > 0;
const uint32_t width = type_words[2];
const int64_t ivalue = ConstantLiteralAsInt64(width, length->words());
if (ivalue == 0 || (ivalue < 0 && is_signed)) {
if (length_value == 0 || (length_value < 0 && is_signed)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeArray Length <id> " << _.getIdName(length_id)
<< " default value must be at least 1: found " << ivalue;
<< " default value must be at least 1: found " << length_value;
}
} break;
case spv::Op::OpConstantNull:
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeArray Length <id> " << _.getIdName(length_id)
<< " default value must be at least 1.";
case spv::Op::OpSpecConstantOp:
// Assume it's OK, rather than try to evaluate the operation.
break;
default:
assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int");
}
return SPV_SUCCESS;
}

View File

@ -1209,7 +1209,7 @@ bool ValidationState_t::IsCooperativeMatrixAType(uint32_t id) const {
if (!IsCooperativeMatrixKHRType(id)) return false;
const Instruction* inst = FindDef(id);
uint64_t matrixUse = 0;
if (GetConstantValUint64(inst->word(6), &matrixUse)) {
if (EvalConstantValUint64(inst->word(6), &matrixUse)) {
return matrixUse ==
static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixAKHR);
}
@ -1220,7 +1220,7 @@ bool ValidationState_t::IsCooperativeMatrixBType(uint32_t id) const {
if (!IsCooperativeMatrixKHRType(id)) return false;
const Instruction* inst = FindDef(id);
uint64_t matrixUse = 0;
if (GetConstantValUint64(inst->word(6), &matrixUse)) {
if (EvalConstantValUint64(inst->word(6), &matrixUse)) {
return matrixUse ==
static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixBKHR);
}
@ -1230,7 +1230,7 @@ bool ValidationState_t::IsCooperativeMatrixAccType(uint32_t id) const {
if (!IsCooperativeMatrixKHRType(id)) return false;
const Instruction* inst = FindDef(id);
uint64_t matrixUse = 0;
if (GetConstantValUint64(inst->word(6), &matrixUse)) {
if (EvalConstantValUint64(inst->word(6), &matrixUse)) {
return matrixUse == static_cast<uint64_t>(
spv::CooperativeMatrixUse::MatrixAccumulatorKHR);
}
@ -1340,20 +1340,23 @@ uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
}
bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
bool ValidationState_t::EvalConstantValUint64(uint32_t id,
uint64_t* val) const {
const Instruction* inst = FindDef(id);
if (!inst) {
assert(0 && "Instruction not found");
return false;
}
if (inst->opcode() != spv::Op::OpConstant &&
inst->opcode() != spv::Op::OpSpecConstant)
return false;
if (!IsIntScalarType(inst->type_id())) return false;
if (inst->words().size() == 4) {
if (inst->opcode() == spv::Op::OpConstantNull) {
*val = 0;
} else if (inst->opcode() != spv::Op::OpConstant) {
// Spec constant values cannot be evaluated so don't consider constant for
// static validation
return false;
} else if (inst->words().size() == 4) {
*val = inst->word(3);
} else {
assert(inst->words().size() == 5);
@ -1363,6 +1366,32 @@ bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
return true;
}
bool ValidationState_t::EvalConstantValInt64(uint32_t id, int64_t* val) const {
const Instruction* inst = FindDef(id);
if (!inst) {
assert(0 && "Instruction not found");
return false;
}
if (!IsIntScalarType(inst->type_id())) return false;
if (inst->opcode() == spv::Op::OpConstantNull) {
*val = 0;
} else if (inst->opcode() != spv::Op::OpConstant) {
// Spec constant values cannot be evaluated so don't consider constant for
// static validation
return false;
} else if (inst->words().size() == 4) {
*val = int32_t(inst->word(3));
} else {
assert(inst->words().size() == 5);
const uint32_t lo_word = inst->word(3);
const uint32_t hi_word = inst->word(4);
*val = static_cast<int64_t>(uint64_t(lo_word) | uint64_t(hi_word) << 32);
}
return true;
}
std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
uint32_t id) const {
const Instruction* const inst = FindDef(id);

View File

@ -648,10 +648,6 @@ class ValidationState_t {
const std::function<bool(const Instruction*)>& f,
bool traverse_all_types = true) const;
// Gets value from OpConstant and OpSpecConstant as uint64.
// Returns false on failure (no instruction, wrong instruction, not int).
bool GetConstantValUint64(uint32_t id, uint64_t* val) const;
// Returns type_id if id has type or zero otherwise.
uint32_t GetTypeId(uint32_t id) const;
@ -726,6 +722,14 @@ class ValidationState_t {
pointer_to_storage_image_.insert(type_id);
}
// Tries to evaluate a any scalar integer OpConstant as uint64.
// OpConstantNull is defined as zero for scalar int (will return true)
// OpSpecConstant* return false since their values cannot be relied upon
// during validation.
bool EvalConstantValUint64(uint32_t id, uint64_t* val) const;
// Same as EvalConstantValUint64 but returns a signed int
bool EvalConstantValInt64(uint32_t id, int64_t* val) const;
// Tries to evaluate a 32-bit signed or unsigned scalar integer constant.
// Returns tuple <is_int32, is_const_int32, value>.
// OpSpecConstant* return |is_const_int32| as false since their values cannot

View File

@ -1056,7 +1056,7 @@ TEST_P(ValidateIdWithMessage, OpTypeArrayLengthNull) {
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr(make_message("OpTypeArray Length <id> '2[%2]' default "
"value must be at least 1.")));
"value must be at least 1: found 0")));
}
TEST_P(ValidateIdWithMessage, OpTypeArrayLengthSpecConst) {