diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp index 42fbc52a7..a7e9942a0 100644 --- a/source/val/validate_builtins.cpp +++ b/source/val/validate_builtins.cpp @@ -1120,7 +1120,7 @@ spv_result_t BuiltInsValidator::ValidateF32ArrHelper( if (num_components != 0) { uint64_t actual_num_components = 0; - if (!_.GetConstantValUint64(type_inst->word(3), &actual_num_components)) { + if (!_.EvalConstantValUint64(type_inst->word(3), &actual_num_components)) { assert(0 && "Array type definition is corrupt"); } if (actual_num_components != num_components) { diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp index ed043b688..26486dac7 100644 --- a/source/val/validate_composites.cpp +++ b/source/val/validate_composites.cpp @@ -94,7 +94,7 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _, break; } - if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) { + if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) { assert(0 && "Array type definition is corrupt"); } if (component_index >= array_size) { @@ -289,7 +289,7 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _, } uint64_t array_size = 0; - if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) { + if (!_.EvalConstantValUint64(array_inst->word(3), &array_size)) { assert(0 && "Array type definition is corrupt"); } diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp index 0334b6064..7b73c9c6e 100644 --- a/source/val/validate_extensions.cpp +++ b/source/val/validate_extensions.cpp @@ -3100,7 +3100,7 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { uint32_t vector_count = inst->word(6); uint64_t const_val; - if (!_.GetConstantValUint64(vector_count, &const_val)) { + if (!_.EvalConstantValUint64(vector_count, &const_val)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": Vector Count must be 32-bit integer OpConstant"; @@ -3191,7 +3191,7 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { uint32_t component_count = inst->word(6); if (vulkanDebugInfo) { uint64_t const_val; - if (!_.GetConstantValUint64(component_count, &const_val)) { + if (!_.EvalConstantValUint64(component_count, &const_val)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": Component Count must be 32-bit integer OpConstant"; diff --git a/source/val/validate_image.cpp b/source/val/validate_image.cpp index 46a32f241..543f345ec 100644 --- a/source/val/validate_image.cpp +++ b/source/val/validate_image.cpp @@ -495,7 +495,7 @@ spv_result_t ValidateImageOperands(ValidationState_t& _, } uint64_t array_size = 0; - if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) { + if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) { assert(0 && "Array type definition is corrupt"); } @@ -1210,7 +1210,7 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _, if (info.multisampled == 0) { uint64_t ms = 0; - if (!_.GetConstantValUint64(inst->GetOperandAs(4), &ms) || + if (!_.EvalConstantValUint64(inst->GetOperandAs(4), &ms) || ms != 0) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Sample for Image with MS 0 to be a valid for " diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index 5b25eeb3c..41dd71e9f 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp @@ -1374,22 +1374,18 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, case spv::Op::OpTypeStruct: { // In case of structures, there is an additional constraint on the // index: the index must be an OpConstant. - if (spv::Op::OpConstant != cur_word_instr->opcode()) { + int64_t cur_index; + if (!_.EvalConstantValInt64(cur_word, &cur_index)) { return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) << "The passed to " << instr_name << " to index into a " "structure must be an OpConstant."; } - // Get the index value from the OpConstant (word 3 of OpConstant). - // OpConstant could be a signed integer. But it's okay to treat it as - // unsigned because a negative constant int would never be seen as - // correct as a struct offset, since structs can't have more than 2 - // billion members. - const uint32_t cur_index = cur_word_instr->word(3); + // The index points to the struct member we want, therefore, the index // should be less than the number of struct members. - const uint32_t num_struct_members = - static_cast(type_pointee->words().size() - 2); + const int64_t num_struct_members = + static_cast(type_pointee->words().size() - 2); if (cur_index >= num_struct_members) { return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) << "Index is out of bounds: " << instr_name @@ -1400,7 +1396,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, << num_struct_members - 1 << "."; } // Struct members IDs start at word 2 of OpTypeStruct. - auto structMemberId = type_pointee->word(cur_index + 2); + const size_t word_index = static_cast(cur_index) + 2; + auto structMemberId = type_pointee->word(word_index); type_pointee = _.FindDef(structMemberId); break; } diff --git a/source/val/validate_non_uniform.cpp b/source/val/validate_non_uniform.cpp index 74449e9dc..75967d2ff 100644 --- a/source/val/validate_non_uniform.cpp +++ b/source/val/validate_non_uniform.cpp @@ -389,20 +389,25 @@ spv_result_t ValidateGroupNonUniformRotateKHR(ValidationState_t& _, if (inst->words().size() > 6) { const uint32_t cluster_size_op_id = inst->GetOperandAs(5); - const uint32_t cluster_size_type = _.GetTypeId(cluster_size_op_id); + const Instruction* cluster_size_inst = _.FindDef(cluster_size_op_id); + const uint32_t cluster_size_type = + cluster_size_inst ? cluster_size_inst->type_id() : 0; if (!_.IsUnsignedIntScalarType(cluster_size_type)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "ClusterSize must be a scalar of integer type, whose " "Signedness operand is 0."; } - uint64_t cluster_size; - if (!_.GetConstantValUint64(cluster_size_op_id, &cluster_size)) { + if (!spvOpcodeIsConstant(cluster_size_inst->opcode())) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "ClusterSize must come from a constant instruction."; } - if ((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0)) { + uint64_t cluster_size; + const bool valid_const = + _.EvalConstantValUint64(cluster_size_op_id, &cluster_size); + if (valid_const && + ((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0))) { return _.diag(SPV_WARNING, inst) << "Behavior is undefined unless ClusterSize is at least 1 and a " "power of 2."; diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp index 7edd12ffa..cb26a527c 100644 --- a/source/val/validate_type.cpp +++ b/source/val/validate_type.cpp @@ -24,21 +24,6 @@ namespace spvtools { namespace val { namespace { -// Returns, as an int64_t, the literal value from an OpConstant or the -// default value of an OpSpecConstant, assuming it is an integral type. -// For signed integers, relies the rule that literal value is sign extended -// to fill out to word granularity. Assumes that the constant value -// has -int64_t ConstantLiteralAsInt64(uint32_t width, - const std::vector& const_words) { - const uint32_t lo_word = const_words[3]; - if (width <= 32) return int32_t(lo_word); - assert(width <= 64); - assert(const_words.size() > 4); - const uint32_t hi_word = const_words[4]; // Must exist, per spec. - return static_cast(uint64_t(lo_word) | uint64_t(hi_word) << 32); -} - // Validates that type declarations are unique, unless multiple declarations // of the same data type are allowed by the specification. // (see section 2.8 Types and Variables) @@ -252,29 +237,17 @@ spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) { << " is not a constant integer type."; } - switch (length->opcode()) { - case spv::Op::OpSpecConstant: - case spv::Op::OpConstant: { - auto& type_words = const_result_type->words(); - const bool is_signed = type_words[3] > 0; - const uint32_t width = type_words[2]; - const int64_t ivalue = ConstantLiteralAsInt64(width, length->words()); - if (ivalue == 0 || (ivalue < 0 && is_signed)) { - return _.diag(SPV_ERROR_INVALID_ID, inst) - << "OpTypeArray Length " << _.getIdName(length_id) - << " default value must be at least 1: found " << ivalue; - } - } break; - case spv::Op::OpConstantNull: + int64_t length_value; + if (_.EvalConstantValInt64(length_id, &length_value)) { + auto& type_words = const_result_type->words(); + const bool is_signed = type_words[3] > 0; + if (length_value == 0 || (length_value < 0 && is_signed)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpTypeArray Length " << _.getIdName(length_id) - << " default value must be at least 1."; - case spv::Op::OpSpecConstantOp: - // Assume it's OK, rather than try to evaluate the operation. - break; - default: - assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int"); + << " default value must be at least 1: found " << length_value; + } } + return SPV_SUCCESS; } diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 25b374de1..fa5ae3e00 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -1209,7 +1209,7 @@ bool ValidationState_t::IsCooperativeMatrixAType(uint32_t id) const { if (!IsCooperativeMatrixKHRType(id)) return false; const Instruction* inst = FindDef(id); uint64_t matrixUse = 0; - if (GetConstantValUint64(inst->word(6), &matrixUse)) { + if (EvalConstantValUint64(inst->word(6), &matrixUse)) { return matrixUse == static_cast(spv::CooperativeMatrixUse::MatrixAKHR); } @@ -1220,7 +1220,7 @@ bool ValidationState_t::IsCooperativeMatrixBType(uint32_t id) const { if (!IsCooperativeMatrixKHRType(id)) return false; const Instruction* inst = FindDef(id); uint64_t matrixUse = 0; - if (GetConstantValUint64(inst->word(6), &matrixUse)) { + if (EvalConstantValUint64(inst->word(6), &matrixUse)) { return matrixUse == static_cast(spv::CooperativeMatrixUse::MatrixBKHR); } @@ -1230,7 +1230,7 @@ bool ValidationState_t::IsCooperativeMatrixAccType(uint32_t id) const { if (!IsCooperativeMatrixKHRType(id)) return false; const Instruction* inst = FindDef(id); uint64_t matrixUse = 0; - if (GetConstantValUint64(inst->word(6), &matrixUse)) { + if (EvalConstantValUint64(inst->word(6), &matrixUse)) { return matrixUse == static_cast( spv::CooperativeMatrixUse::MatrixAccumulatorKHR); } @@ -1340,20 +1340,23 @@ uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst, return GetTypeId(inst->GetOperandAs(operand_index)); } -bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const { +bool ValidationState_t::EvalConstantValUint64(uint32_t id, + uint64_t* val) const { const Instruction* inst = FindDef(id); if (!inst) { assert(0 && "Instruction not found"); return false; } - if (inst->opcode() != spv::Op::OpConstant && - inst->opcode() != spv::Op::OpSpecConstant) - return false; - if (!IsIntScalarType(inst->type_id())) return false; - if (inst->words().size() == 4) { + if (inst->opcode() == spv::Op::OpConstantNull) { + *val = 0; + } else if (inst->opcode() != spv::Op::OpConstant) { + // Spec constant values cannot be evaluated so don't consider constant for + // static validation + return false; + } else if (inst->words().size() == 4) { *val = inst->word(3); } else { assert(inst->words().size() == 5); @@ -1363,6 +1366,32 @@ bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const { return true; } +bool ValidationState_t::EvalConstantValInt64(uint32_t id, int64_t* val) const { + const Instruction* inst = FindDef(id); + if (!inst) { + assert(0 && "Instruction not found"); + return false; + } + + if (!IsIntScalarType(inst->type_id())) return false; + + if (inst->opcode() == spv::Op::OpConstantNull) { + *val = 0; + } else if (inst->opcode() != spv::Op::OpConstant) { + // Spec constant values cannot be evaluated so don't consider constant for + // static validation + return false; + } else if (inst->words().size() == 4) { + *val = int32_t(inst->word(3)); + } else { + assert(inst->words().size() == 5); + const uint32_t lo_word = inst->word(3); + const uint32_t hi_word = inst->word(4); + *val = static_cast(uint64_t(lo_word) | uint64_t(hi_word) << 32); + } + return true; +} + std::tuple ValidationState_t::EvalInt32IfConst( uint32_t id) const { const Instruction* const inst = FindDef(id); diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 46a8cbfa6..27acdcc2f 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -648,10 +648,6 @@ class ValidationState_t { const std::function& f, bool traverse_all_types = true) const; - // Gets value from OpConstant and OpSpecConstant as uint64. - // Returns false on failure (no instruction, wrong instruction, not int). - bool GetConstantValUint64(uint32_t id, uint64_t* val) const; - // Returns type_id if id has type or zero otherwise. uint32_t GetTypeId(uint32_t id) const; @@ -726,6 +722,14 @@ class ValidationState_t { pointer_to_storage_image_.insert(type_id); } + // Tries to evaluate a any scalar integer OpConstant as uint64. + // OpConstantNull is defined as zero for scalar int (will return true) + // OpSpecConstant* return false since their values cannot be relied upon + // during validation. + bool EvalConstantValUint64(uint32_t id, uint64_t* val) const; + // Same as EvalConstantValUint64 but returns a signed int + bool EvalConstantValInt64(uint32_t id, int64_t* val) const; + // Tries to evaluate a 32-bit signed or unsigned scalar integer constant. // Returns tuple . // OpSpecConstant* return |is_const_int32| as false since their values cannot diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp index 7acac563e..e236134b7 100644 --- a/test/val/val_id_test.cpp +++ b/test/val/val_id_test.cpp @@ -1056,7 +1056,7 @@ TEST_P(ValidateIdWithMessage, OpTypeArrayLengthNull) { EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr(make_message("OpTypeArray Length '2[%2]' default " - "value must be at least 1."))); + "value must be at least 1: found 0"))); } TEST_P(ValidateIdWithMessage, OpTypeArrayLengthSpecConst) {