SPV_KHR_cooperative_matrix (#5286)

* SPV_KHR_cooperative_matrix

* Update DEPS with headers

* Update according to review recommendations

* Bugfix and formatting

* Formatting missed or damaged by VS2022
This commit is contained in:
archimedus 2023-06-23 00:33:36 +02:00 committed by GitHub
parent 16098b3c10
commit 04cdb2d344
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1152 additions and 55 deletions

2
DEPS
View File

@ -13,7 +13,7 @@ vars = {
'protobuf_revision': 'v21.12', 'protobuf_revision': 'v21.12',
're2_revision': '7c5e396af825562ec8321fdbf2f1cf276b26e3ae', 're2_revision': '7c5e396af825562ec8321fdbf2f1cf276b26e3ae',
'spirv_headers_revision': '10db9d4e194246a020a4148e220837ac7c68cfd9', 'spirv_headers_revision': '3469b164e25cee24435029a569933cb42578db5d',
} }
deps = { deps = {

View File

@ -285,6 +285,13 @@ typedef enum spv_operand_type_t {
// An optional packed vector format // An optional packed vector format
SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT, SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT,
// Concrete operand types for cooperative matrix.
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS,
// An optional cooperative matrix operands
SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS,
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT,
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE,
// This is a sentinel value, and does not represent an operand type. // This is a sentinel value, and does not represent an operand type.
// It should come last. // It should come last.
SPV_OPERAND_TYPE_NUM_OPERAND_TYPES, SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,

View File

@ -154,11 +154,12 @@ const SpecConstantOpcodeEntry kOpSpecConstantOpcodes[] = {
CASE(InBoundsAccessChain), CASE(InBoundsAccessChain),
CASE(PtrAccessChain), CASE(PtrAccessChain),
CASE(InBoundsPtrAccessChain), CASE(InBoundsPtrAccessChain),
CASE(CooperativeMatrixLengthNV) CASE(CooperativeMatrixLengthNV),
CASE(CooperativeMatrixLengthKHR)
}; };
// The 60 is determined by counting the opcodes listed in the spec. // The 60 is determined by counting the opcodes listed in the spec.
static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]), static_assert(61 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
"OpSpecConstantOp opcode table is incomplete"); "OpSpecConstantOp opcode table is incomplete");
#undef CASE #undef CASE
// clang-format on // clang-format on

View File

@ -691,7 +691,9 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL: case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: { case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
// This operand is a mask. // This operand is a mask.
// Map an optional operand type to its corresponding concrete type. // Map an optional operand type to its corresponding concrete type.
@ -699,6 +701,8 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
parsed_operand.type = SPV_OPERAND_TYPE_IMAGE; parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
if (type == SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS)
parsed_operand.type = SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS;
// Check validity of set mask bits. Also prepare for operands for those // Check validity of set mask bits. Also prepare for operands for those
// masks if they have any. To get operand order correct, scan from // masks if they have any. To get operand order correct, scan from

View File

@ -274,6 +274,7 @@ int32_t spvOpcodeIsComposite(const spv::Op opcode) {
case spv::Op::OpTypeArray: case spv::Op::OpTypeArray:
case spv::Op::OpTypeStruct: case spv::Op::OpTypeStruct:
case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
return true; return true;
default: default:
return false; return false;
@ -340,6 +341,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
case spv::Op::OpTypeNamedBarrier: case spv::Op::OpTypeNamedBarrier:
case spv::Op::OpTypeAccelerationStructureNV: case spv::Op::OpTypeAccelerationStructureNV:
case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
// case spv::Op::OpTypeAccelerationStructureKHR: covered by // case spv::Op::OpTypeAccelerationStructureKHR: covered by
// spv::Op::OpTypeAccelerationStructureNV // spv::Op::OpTypeAccelerationStructureNV
case spv::Op::OpTypeRayQueryKHR: case spv::Op::OpTypeRayQueryKHR:

View File

@ -236,6 +236,13 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT: case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT: case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
return "packed vector format"; return "packed vector format";
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
return "cooperative matrix operands";
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
return "cooperative matrix layout";
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
return "cooperative matrix use";
case SPV_OPERAND_TYPE_IMAGE: case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
return "image"; return "image";
@ -369,6 +376,8 @@ bool spvOperandIsConcrete(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_QUANTIZATION_MODES: case SPV_OPERAND_TYPE_QUANTIZATION_MODES:
case SPV_OPERAND_TYPE_OVERFLOW_MODES: case SPV_OPERAND_TYPE_OVERFLOW_MODES:
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT: case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
return true; return true;
default: default:
break; break;
@ -387,6 +396,7 @@ bool spvOperandIsConcreteMask(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE: case SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
return true; return true;
default: default:
break; break;
@ -405,6 +415,7 @@ bool spvOperandIsOptional(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING:
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT: case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_CIV: case SPV_OPERAND_TYPE_OPTIONAL_CIV:
return true; return true;
default: default:

View File

@ -423,6 +423,23 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
{SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}}); {SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}});
break; break;
} }
case Type::kCooperativeMatrixKHR: {
auto coop_mat = type->AsCooperativeMatrixKHR();
uint32_t const component_type =
GetTypeInstruction(coop_mat->component_type());
if (component_type == 0) {
return 0;
}
typeInst = MakeUnique<Instruction>(
context(), spv::Op::OpTypeCooperativeMatrixKHR, 0, id,
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_ID, {component_type}},
{SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}},
{SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}},
{SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}},
{SPV_OPERAND_TYPE_ID, {coop_mat->use_id()}}});
break;
}
default: default:
assert(false && "Unexpected type"); assert(false && "Unexpected type");
break; break;
@ -628,6 +645,14 @@ Type* TypeManager::RebuildType(const Type& type) {
cm_type->columns_id()); cm_type->columns_id());
break; break;
} }
case Type::kCooperativeMatrixKHR: {
const CooperativeMatrixKHR* cm_type = type.AsCooperativeMatrixKHR();
const Type* component_type = cm_type->component_type();
rebuilt_ty = MakeUnique<CooperativeMatrixKHR>(
RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
cm_type->columns_id(), cm_type->use_id());
break;
}
default: default:
assert(false && "Unhandled type"); assert(false && "Unhandled type");
return nullptr; return nullptr;
@ -863,6 +888,12 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
inst.GetSingleWordInOperand(2), inst.GetSingleWordInOperand(2),
inst.GetSingleWordInOperand(3)); inst.GetSingleWordInOperand(3));
break; break;
case spv::Op::OpTypeCooperativeMatrixKHR:
type = new CooperativeMatrixKHR(
GetType(inst.GetSingleWordInOperand(0)),
inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2),
inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4));
break;
case spv::Op::OpTypeRayQueryKHR: case spv::Op::OpTypeRayQueryKHR:
type = new RayQueryKHR(); type = new RayQueryKHR();
break; break;

View File

@ -128,6 +128,7 @@ std::unique_ptr<Type> Type::Clone() const {
DeclareKindCase(NamedBarrier); DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV); DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV); DeclareKindCase(CooperativeMatrixNV);
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR); DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV); DeclareKindCase(HitObjectNV);
#undef DeclareKindCase #undef DeclareKindCase
@ -175,6 +176,7 @@ bool Type::operator==(const Type& other) const {
DeclareKindCase(NamedBarrier); DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV); DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV); DeclareKindCase(CooperativeMatrixNV);
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR); DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV); DeclareKindCase(HitObjectNV);
#undef DeclareKindCase #undef DeclareKindCase
@ -230,6 +232,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
DeclareKindCase(NamedBarrier); DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV); DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV); DeclareKindCase(CooperativeMatrixNV);
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR); DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV); DeclareKindCase(HitObjectNV);
#undef DeclareKindCase #undef DeclareKindCase
@ -708,6 +711,45 @@ bool CooperativeMatrixNV::IsSameImpl(const Type* that,
columns_id_ == mt->columns_id_ && HasSameDecorations(that); columns_id_ == mt->columns_id_ && HasSameDecorations(that);
} }
CooperativeMatrixKHR::CooperativeMatrixKHR(const Type* type,
const uint32_t scope,
const uint32_t rows,
const uint32_t columns,
const uint32_t use)
: Type(kCooperativeMatrixKHR),
component_type_(type),
scope_id_(scope),
rows_id_(rows),
columns_id_(columns),
use_id_(use) {
assert(type != nullptr);
assert(scope != 0);
assert(rows != 0);
assert(columns != 0);
}
std::string CooperativeMatrixKHR::str() const {
std::ostringstream oss;
oss << "<" << component_type_->str() << ", " << scope_id_ << ", " << rows_id_
<< ", " << columns_id_ << ", " << use_id_ << ">";
return oss.str();
}
size_t CooperativeMatrixKHR::ComputeExtraStateHash(size_t hash,
SeenTypes* seen) const {
hash = hash_combine(hash, scope_id_, rows_id_, columns_id_, use_id_);
return component_type_->ComputeHashValue(hash, seen);
}
bool CooperativeMatrixKHR::IsSameImpl(const Type* that,
IsSameCache* seen) const {
const CooperativeMatrixKHR* mt = that->AsCooperativeMatrixKHR();
if (!mt) return false;
return component_type_->IsSameImpl(mt->component_type_, seen) &&
scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
}
} // namespace analysis } // namespace analysis
} // namespace opt } // namespace opt
} // namespace spvtools } // namespace spvtools

View File

@ -60,6 +60,7 @@ class PipeStorage;
class NamedBarrier; class NamedBarrier;
class AccelerationStructureNV; class AccelerationStructureNV;
class CooperativeMatrixNV; class CooperativeMatrixNV;
class CooperativeMatrixKHR;
class RayQueryKHR; class RayQueryKHR;
class HitObjectNV; class HitObjectNV;
@ -100,6 +101,7 @@ class Type {
kNamedBarrier, kNamedBarrier,
kAccelerationStructureNV, kAccelerationStructureNV,
kCooperativeMatrixNV, kCooperativeMatrixNV,
kCooperativeMatrixKHR,
kRayQueryKHR, kRayQueryKHR,
kHitObjectNV, kHitObjectNV,
kLast kLast
@ -201,6 +203,7 @@ class Type {
DeclareCastMethod(NamedBarrier) DeclareCastMethod(NamedBarrier)
DeclareCastMethod(AccelerationStructureNV) DeclareCastMethod(AccelerationStructureNV)
DeclareCastMethod(CooperativeMatrixNV) DeclareCastMethod(CooperativeMatrixNV)
DeclareCastMethod(CooperativeMatrixKHR)
DeclareCastMethod(RayQueryKHR) DeclareCastMethod(RayQueryKHR)
DeclareCastMethod(HitObjectNV) DeclareCastMethod(HitObjectNV)
#undef DeclareCastMethod #undef DeclareCastMethod
@ -624,6 +627,38 @@ class CooperativeMatrixNV : public Type {
const uint32_t columns_id_; const uint32_t columns_id_;
}; };
class CooperativeMatrixKHR : public Type {
public:
CooperativeMatrixKHR(const Type* type, const uint32_t scope,
const uint32_t rows, const uint32_t columns,
const uint32_t use);
CooperativeMatrixKHR(const CooperativeMatrixKHR&) = default;
std::string str() const override;
CooperativeMatrixKHR* AsCooperativeMatrixKHR() override { return this; }
const CooperativeMatrixKHR* AsCooperativeMatrixKHR() const override {
return this;
}
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
const Type* component_type() const { return component_type_; }
uint32_t scope_id() const { return scope_id_; }
uint32_t rows_id() const { return rows_id_; }
uint32_t columns_id() const { return columns_id_; }
uint32_t use_id() const { return use_id_; }
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const Type* component_type_;
const uint32_t scope_id_;
const uint32_t rows_id_;
const uint32_t columns_id_;
const uint32_t use_id_;
};
#define DefineParameterlessType(type, name) \ #define DefineParameterlessType(type, name) \
class type : public Type { \ class type : public Type { \
public: \ public: \

View File

@ -402,7 +402,8 @@ spv_result_t spvTextEncodeOperand(const spvtools::AssemblyGrammar& grammar,
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL: case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: { case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
uint32_t value; uint32_t value;
if (auto error = grammar.parseMaskOperand(type, textValue, &value)) { if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
return context->diagnostic(error) return context->diagnostic(error)

View File

@ -42,14 +42,29 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
opcode != spv::Op::OpFMod); opcode != spv::Op::OpFMod);
if (!_.IsFloatScalarType(result_type) && if (!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type) && !_.IsFloatVectorType(result_type) &&
!(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type))) !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
!(opcode == spv::Op::OpFMul &&
_.IsCooperativeMatrixKHRType(result_type) &&
_.IsFloatCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected floating scalar or vector type as Result Type: " << "Expected floating scalar or vector type as Result Type: "
<< spvOpcodeString(opcode); << spvOpcodeString(opcode);
for (size_t operand_index = 2; operand_index < inst->operands().size(); for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) { ++operand_index) {
if (_.GetOperandTypeId(inst, operand_index) != result_type) if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeMatrixKHRType(type_id) ||
!_.IsFloatCooperativeMatrixType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
<< operand_index;
}
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, type_id, result_type);
if (ret != SPV_SUCCESS) return ret;
} else if (_.GetOperandTypeId(inst, operand_index) != result_type)
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: " << "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index " << spvOpcodeString(opcode) << " operand index "
@ -71,7 +86,19 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
for (size_t operand_index = 2; operand_index < inst->operands().size(); for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) { ++operand_index) {
if (_.GetOperandTypeId(inst, operand_index) != result_type) if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeMatrixKHRType(type_id) ||
!_.IsUnsignedIntCooperativeMatrixType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
<< operand_index;
}
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, type_id, result_type);
if (ret != SPV_SUCCESS) return ret;
} else if (_.GetOperandTypeId(inst, operand_index) != result_type)
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: " << "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index " << spvOpcodeString(opcode) << " operand index "
@ -91,7 +118,10 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
(opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem && (opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
opcode != spv::Op::OpSMod); opcode != spv::Op::OpSMod);
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))) !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
!(opcode == spv::Op::OpIMul &&
_.IsCooperativeMatrixKHRType(result_type) &&
_.IsIntCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: " << "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode); << spvOpcodeString(opcode);
@ -102,9 +132,26 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
for (size_t operand_index = 2; operand_index < inst->operands().size(); for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) { ++operand_index) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
if (!_.IsCooperativeMatrixKHRType(type_id) ||
!_.IsIntCooperativeMatrixType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
<< operand_index;
}
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, type_id, result_type);
if (ret != SPV_SUCCESS) return ret;
}
if (!type_id || if (!type_id ||
(!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) && (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))) !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
!(opcode == spv::Op::OpIMul &&
_.IsCooperativeMatrixKHRType(result_type) &&
_.IsIntCooperativeMatrixType(result_type))))
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as operand: " << "Expected int scalar or vector type as operand: "
<< spvOpcodeString(opcode) << " operand index " << spvOpcodeString(opcode) << " operand index "
@ -187,7 +234,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpMatrixTimesScalar: { case spv::Op::OpMatrixTimesScalar: {
if (!_.IsFloatMatrixType(result_type) && if (!_.IsFloatMatrixType(result_type) &&
!_.IsCooperativeMatrixType(result_type)) !(_.IsCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float matrix type as Result Type: " << "Expected float matrix type as Result Type: "
<< spvOpcodeString(opcode); << spvOpcodeString(opcode);
@ -459,22 +506,108 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
const uint32_t B_type_id = _.GetOperandTypeId(inst, 3); const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
const uint32_t C_type_id = _.GetOperandTypeId(inst, 4); const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
if (!_.IsCooperativeMatrixType(A_type_id)) { if (!_.IsCooperativeMatrixNVType(A_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as A Type: " << "Expected cooperative matrix type as A Type: "
<< spvOpcodeString(opcode); << spvOpcodeString(opcode);
} }
if (!_.IsCooperativeMatrixType(B_type_id)) { if (!_.IsCooperativeMatrixNVType(B_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as B Type: " << "Expected cooperative matrix type as B Type: "
<< spvOpcodeString(opcode); << spvOpcodeString(opcode);
} }
if (!_.IsCooperativeMatrixType(C_type_id)) { if (!_.IsCooperativeMatrixNVType(C_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as C Type: " << "Expected cooperative matrix type as C Type: "
<< spvOpcodeString(opcode); << spvOpcodeString(opcode);
} }
if (!_.IsCooperativeMatrixType(D_type_id)) { if (!_.IsCooperativeMatrixNVType(D_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as Result Type: "
<< spvOpcodeString(opcode);
}
const auto A = _.FindDef(A_type_id);
const auto B = _.FindDef(B_type_id);
const auto C = _.FindDef(C_type_id);
const auto D = _.FindDef(D_type_id);
std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
std::tuple<bool, bool, uint32_t> Y) {
return (std::get<1>(X) && std::get<1>(Y) &&
std::get<2>(X) != std::get<2>(Y));
};
if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix scopes must match: "
<< spvOpcodeString(opcode);
}
if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
notEqual(C_rows, D_rows)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix 'M' mismatch: "
<< spvOpcodeString(opcode);
}
if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
notEqual(C_cols, D_cols)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix 'N' mismatch: "
<< spvOpcodeString(opcode);
}
if (notEqual(A_cols, B_rows)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix 'K' mismatch: "
<< spvOpcodeString(opcode);
}
break;
}
case spv::Op::OpCooperativeMatrixMulAddKHR: {
const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
if (!_.IsCooperativeMatrixAType(A_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix type must be A Type: "
<< spvOpcodeString(opcode);
}
if (!_.IsCooperativeMatrixBType(B_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix type must be B Type: "
<< spvOpcodeString(opcode);
}
if (!_.IsCooperativeMatrixAccType(C_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix type must be Accumulator Type: "
<< spvOpcodeString(opcode);
}
if (!_.IsCooperativeMatrixKHRType(D_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as Result Type: " << "Expected cooperative matrix type as Result Type: "
<< spvOpcodeString(opcode); << spvOpcodeString(opcode);

View File

@ -122,6 +122,7 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
*member_type = type_inst->word(component_index + 2); *member_type = type_inst->word(component_index + 2);
break; break;
} }
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeMatrixNV: { case spv::Op::OpTypeCooperativeMatrixNV: {
*member_type = type_inst->word(2); *member_type = type_inst->word(2);
break; break;
@ -335,6 +336,25 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
break; break;
} }
case spv::Op::OpTypeCooperativeMatrixKHR: {
const auto result_type_inst = _.FindDef(result_type);
assert(result_type_inst);
const auto component_type_id =
result_type_inst->GetOperandAs<uint32_t>(1);
if (3 != num_operands) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Must be only one constituent";
}
const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
if (operand_type_id != component_type_id) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Constituent type to be equal to the component type";
}
break;
}
case spv::Op::OpTypeCooperativeMatrixNV: { case spv::Op::OpTypeCooperativeMatrixNV: {
const auto result_type_inst = _.FindDef(result_type); const auto result_type_inst = _.FindDef(result_type);
assert(result_type_inst); assert(result_type_inst);

View File

@ -243,6 +243,7 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
} }
} }
} break; } break;
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeMatrixNV: { case spv::Op::OpTypeCooperativeMatrixNV: {
if (1 != constituent_count) { if (1 != constituent_count) {
return _.diag(SPV_ERROR_INVALID_ID, inst) return _.diag(SPV_ERROR_INVALID_ID, inst)
@ -310,6 +311,7 @@ bool IsTypeNullable(const std::vector<uint32_t>& instruction,
case spv::Op::OpTypeArray: case spv::Op::OpTypeArray:
case spv::Op::OpTypeMatrix: case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeVector: { case spv::Op::OpTypeVector: {
auto base_type = _.FindDef(instruction[2]); auto base_type = _.FindDef(instruction[2]);
return base_type && IsTypeNullable(base_type->words(), _); return base_type && IsTypeNullable(base_type->words(), _);

View File

@ -473,7 +473,10 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
const bool input_is_pointer = _.IsPointerType(input_type); const bool input_is_pointer = _.IsPointerType(input_type);
const bool input_is_int_scalar = _.IsIntScalarType(input_type); const bool input_is_int_scalar = _.IsIntScalarType(input_type);
if (!result_is_pointer && !result_is_int_scalar && const bool result_is_coopmat = _.IsCooperativeMatrixType(result_type);
const bool input_is_coopmat = _.IsCooperativeMatrixType(input_type);
if (!result_is_pointer && !result_is_int_scalar && !result_is_coopmat &&
!_.IsIntVectorType(result_type) && !_.IsIntVectorType(result_type) &&
!_.IsFloatScalarType(result_type) && !_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type)) !_.IsFloatVectorType(result_type))
@ -481,13 +484,24 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
<< "Expected Result Type to be a pointer or int or float vector " << "Expected Result Type to be a pointer or int or float vector "
<< "or scalar type: " << spvOpcodeString(opcode); << "or scalar type: " << spvOpcodeString(opcode);
if (!input_is_pointer && !input_is_int_scalar && if (!input_is_pointer && !input_is_int_scalar && !input_is_coopmat &&
!_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) && !_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) &&
!_.IsFloatVectorType(input_type)) !_.IsFloatVectorType(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst) return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be a pointer or int or float vector " << "Expected input to be a pointer or int or float vector "
<< "or scalar: " << spvOpcodeString(opcode); << "or scalar: " << spvOpcodeString(opcode);
if (result_is_coopmat != input_is_coopmat)
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix can only be cast to another cooperative "
<< "matrix: " << spvOpcodeString(opcode);
if (result_is_coopmat) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
if (ret != SPV_SUCCESS) return ret;
}
if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 5) || if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 5) ||
_.HasExtension(kSPV_KHR_physical_storage_buffer)) { _.HasExtension(kSPV_KHR_physical_storage_buffer)) {
const bool result_is_int_vector = _.IsIntVectorType(result_type); const bool result_is_int_vector = _.IsIntVectorType(result_type);

View File

@ -163,9 +163,12 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
!inst->IsDebugInfo() && !inst->IsNonSemantic() && !inst->IsDebugInfo() && !inst->IsNonSemantic() &&
!spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction && !spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction &&
opcode != spv::Op::OpCooperativeMatrixLengthNV && opcode != spv::Op::OpCooperativeMatrixLengthNV &&
opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
!(opcode == spv::Op::OpSpecConstantOp && !(opcode == spv::Op::OpSpecConstantOp &&
spv::Op(inst->word(3)) == (spv::Op(inst->word(3)) ==
spv::Op::OpCooperativeMatrixLengthNV)) { spv::Op::OpCooperativeMatrixLengthNV ||
spv::Op(inst->word(3)) ==
spv::Op::OpCooperativeMatrixLengthKHR))) {
return _.diag(SPV_ERROR_INVALID_ID, inst) return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word) << "Operand " << _.getIdName(operand_word)
<< " cannot be a type"; << " cannot be a type";
@ -179,9 +182,12 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
opcode != spv::Op::OpLoopMerge && opcode != spv::Op::OpLoopMerge &&
opcode != spv::Op::OpFunction && opcode != spv::Op::OpFunction &&
opcode != spv::Op::OpCooperativeMatrixLengthNV && opcode != spv::Op::OpCooperativeMatrixLengthNV &&
opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
!(opcode == spv::Op::OpSpecConstantOp && !(opcode == spv::Op::OpSpecConstantOp &&
spv::Op(inst->word(3)) == (spv::Op(inst->word(3)) ==
spv::Op::OpCooperativeMatrixLengthNV)) { spv::Op::OpCooperativeMatrixLengthNV ||
spv::Op(inst->word(3)) ==
spv::Op::OpCooperativeMatrixLengthKHR))) {
return _.diag(SPV_ERROR_INVALID_ID, inst) return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word) << "Operand " << _.getIdName(operand_word)
<< " requires a type"; << " requires a type";

View File

@ -204,6 +204,7 @@ bool ContainsCooperativeMatrix(ValidationState_t& _,
switch (storage->opcode()) { switch (storage->opcode()) {
case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
return true; return true;
case spv::Op::OpTypeArray: case spv::Op::OpTypeArray:
case spv::Op::OpTypeRuntimeArray: case spv::Op::OpTypeRuntimeArray:
@ -232,6 +233,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
spv::StorageClass src_sc = spv::StorageClass::Max; spv::StorageClass src_sc = spv::StorageClass::Max;
switch (inst->opcode()) { switch (inst->opcode()) {
case spv::Op::OpCooperativeMatrixLoadNV: case spv::Op::OpCooperativeMatrixLoadNV:
case spv::Op::OpCooperativeMatrixLoadKHR:
case spv::Op::OpLoad: { case spv::Op::OpLoad: {
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2)); auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
auto load_pointer_type = _.FindDef(load_pointer->type_id()); auto load_pointer_type = _.FindDef(load_pointer->type_id());
@ -239,6 +241,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
break; break;
} }
case spv::Op::OpCooperativeMatrixStoreNV: case spv::Op::OpCooperativeMatrixStoreNV:
case spv::Op::OpCooperativeMatrixStoreKHR:
case spv::Op::OpStore: { case spv::Op::OpStore: {
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0)); auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
auto store_pointer_type = _.FindDef(store_pointer->type_id()); auto store_pointer_type = _.FindDef(store_pointer->type_id());
@ -326,7 +329,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
const uint32_t mask = inst->GetOperandAs<uint32_t>(index); const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) { if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
if (inst->opcode() == spv::Op::OpLoad || if (inst->opcode() == spv::Op::OpLoad ||
inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) { inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV ||
inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
return _.diag(SPV_ERROR_INVALID_ID, inst) return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MakePointerAvailableKHR cannot be used with OpLoad."; << "MakePointerAvailableKHR cannot be used with OpLoad.";
} }
@ -1357,6 +1361,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
case spv::Op::OpTypeMatrix: case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeVector: case spv::Op::OpTypeVector:
case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeArray: case spv::Op::OpTypeArray:
case spv::Op::OpTypeRuntimeArray: { case spv::Op::OpTypeRuntimeArray: {
// In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV, // In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
@ -1554,9 +1559,15 @@ spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
<< " must be OpTypeInt with width 32 and signedness 0."; << " must be OpTypeInt with width 32 and signedness 0.";
} }
bool isKhr = inst->opcode() == spv::Op::OpCooperativeMatrixLengthKHR;
auto type_id = inst->GetOperandAs<uint32_t>(2); auto type_id = inst->GetOperandAs<uint32_t>(2);
auto type = state.FindDef(type_id); auto type = state.FindDef(type_id);
if (type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) { if (isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
return state.diag(SPV_ERROR_INVALID_ID, inst)
<< "The type in " << instr_name << " <id> "
<< state.getIdName(type_id)
<< " must be OpTypeCooperativeMatrixKHR.";
} else if (!isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
return state.diag(SPV_ERROR_INVALID_ID, inst) return state.diag(SPV_ERROR_INVALID_ID, inst)
<< "The type in " << instr_name << " <id> " << "The type in " << instr_name << " <id> "
<< state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV."; << state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
@ -1668,6 +1679,112 @@ spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
return SPV_SUCCESS; return SPV_SUCCESS;
} }
spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
const Instruction* inst) {
uint32_t type_id;
const char* opname;
if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
type_id = inst->type_id();
opname = "spv::Op::OpCooperativeMatrixLoadKHR";
} else {
// get Object operand's type
type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
opname = "spv::Op::OpCooperativeMatrixStoreKHR";
}
auto matrix_type = _.FindDef(type_id);
if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "spv::Op::OpCooperativeMatrixLoadKHR Result Type <id> "
<< _.getIdName(type_id) << " is not a cooperative matrix type.";
} else {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "spv::Op::OpCooperativeMatrixStoreKHR Object type <id> "
<< _.getIdName(type_id) << " is not a cooperative matrix type.";
}
}
const auto pointer_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 2u : 0u;
const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
const auto pointer = _.FindDef(pointer_id);
if (!pointer ||
((_.addressing_model() == spv::AddressingModel::Logical) &&
((!_.features().variable_pointers &&
!spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
(_.features().variable_pointers &&
!spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " Pointer <id> " << _.getIdName(pointer_id)
<< " is not a logical pointer.";
}
const auto pointer_type_id = pointer->type_id();
const auto pointer_type = _.FindDef(pointer_type_id);
if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " type for pointer <id> " << _.getIdName(pointer_id)
<< " is not a pointer type.";
}
const auto storage_class_index = 1u;
const auto storage_class =
pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
if (storage_class != spv::StorageClass::Workgroup &&
storage_class != spv::StorageClass::StorageBuffer &&
storage_class != spv::StorageClass::PhysicalStorageBuffer) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " storage class for pointer type <id> "
<< _.getIdName(pointer_type_id)
<< " is not Workgroup or StorageBuffer.";
}
const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
const auto pointee_type = _.FindDef(pointee_id);
if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
_.IsFloatScalarOrVectorType(pointee_id))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " Pointer <id> " << _.getIdName(pointer->id())
<< "s Type must be a scalar or vector type.";
}
const auto layout_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u;
const auto colmajor_id = inst->GetOperandAs<uint32_t>(layout_index);
const auto colmajor = _.FindDef(colmajor_id);
if (!colmajor || !_.IsIntScalarType(colmajor->type_id()) ||
!(spvOpcodeIsConstant(colmajor->opcode()) ||
spvOpcodeIsSpecConstant(colmajor->opcode()))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MemoryLayout operand <id> " << _.getIdName(colmajor_id)
<< " must be a 32-bit integer constant instruction.";
}
const auto stride_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
if (inst->operands().size() > stride_index) {
const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
const auto stride = _.FindDef(stride_id);
if (!stride || !_.IsIntScalarType(stride->type_id())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Stride operand <id> " << _.getIdName(stride_id)
<< " must be a scalar integer type.";
}
}
const auto memory_access_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 5u : 4u;
if (inst->operands().size() > memory_access_index) {
if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
return error;
}
return SPV_SUCCESS;
}
spv_result_t ValidatePtrComparison(ValidationState_t& _, spv_result_t ValidatePtrComparison(ValidationState_t& _,
const Instruction* inst) { const Instruction* inst) {
if (_.addressing_model() == spv::AddressingModel::Logical && if (_.addressing_model() == spv::AddressingModel::Logical &&
@ -1757,9 +1874,15 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst)) if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
return error; return error;
break; break;
case spv::Op::OpCooperativeMatrixLengthKHR:
case spv::Op::OpCooperativeMatrixLengthNV: case spv::Op::OpCooperativeMatrixLengthNV:
if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error; if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
break; break;
case spv::Op::OpCooperativeMatrixLoadKHR:
case spv::Op::OpCooperativeMatrixStoreKHR:
if (auto error = ValidateCooperativeMatrixLoadStoreKHR(_, inst))
return error;
break;
case spv::Op::OpPtrEqual: case spv::Op::OpPtrEqual:
case spv::Op::OpPtrNotEqual: case spv::Op::OpPtrNotEqual:
case spv::Op::OpPtrDiff: case spv::Op::OpPtrDiff:

View File

@ -552,8 +552,8 @@ spv_result_t ValidateTypeForwardPointer(ValidationState_t& _,
return SPV_SUCCESS; return SPV_SUCCESS;
} }
spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _, spv_result_t ValidateTypeCooperativeMatrix(ValidationState_t& _,
const Instruction* inst) { const Instruction* inst) {
const auto component_type_index = 1; const auto component_type_index = 1;
const auto component_type_id = const auto component_type_id =
inst->GetOperandAs<uint32_t>(component_type_index); inst->GetOperandAs<uint32_t>(component_type_index);
@ -561,7 +561,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
if (!component_type || (spv::Op::OpTypeFloat != component_type->opcode() && if (!component_type || (spv::Op::OpTypeFloat != component_type->opcode() &&
spv::Op::OpTypeInt != component_type->opcode())) { spv::Op::OpTypeInt != component_type->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst) return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixNV Component Type <id> " << "OpTypeCooperativeMatrix Component Type <id> "
<< _.getIdName(component_type_id) << _.getIdName(component_type_id)
<< " is not a scalar numerical type."; << " is not a scalar numerical type.";
} }
@ -572,7 +572,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
if (!scope || !_.IsIntScalarType(scope->type_id()) || if (!scope || !_.IsIntScalarType(scope->type_id()) ||
!spvOpcodeIsConstant(scope->opcode())) { !spvOpcodeIsConstant(scope->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst) return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixNV Scope <id> " << _.getIdName(scope_id) << "OpTypeCooperativeMatrix Scope <id> " << _.getIdName(scope_id)
<< " is not a constant instruction with scalar integer type."; << " is not a constant instruction with scalar integer type.";
} }
@ -582,7 +582,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
if (!rows || !_.IsIntScalarType(rows->type_id()) || if (!rows || !_.IsIntScalarType(rows->type_id()) ||
!spvOpcodeIsConstant(rows->opcode())) { !spvOpcodeIsConstant(rows->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst) return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixNV Rows <id> " << _.getIdName(rows_id) << "OpTypeCooperativeMatrix Rows <id> " << _.getIdName(rows_id)
<< " is not a constant instruction with scalar integer type."; << " is not a constant instruction with scalar integer type.";
} }
@ -592,10 +592,22 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
if (!cols || !_.IsIntScalarType(cols->type_id()) || if (!cols || !_.IsIntScalarType(cols->type_id()) ||
!spvOpcodeIsConstant(cols->opcode())) { !spvOpcodeIsConstant(cols->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst) return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixNV Cols <id> " << _.getIdName(cols_id) << "OpTypeCooperativeMatrix Cols <id> " << _.getIdName(cols_id)
<< " is not a constant instruction with scalar integer type."; << " is not a constant instruction with scalar integer type.";
} }
if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
const auto use_index = 5;
const auto use_id = inst->GetOperandAs<uint32_t>(use_index);
const auto use = _.FindDef(use_id);
if (!use || !_.IsIntScalarType(use->type_id()) ||
!spvOpcodeIsConstant(use->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixKHR Use <id> " << _.getIdName(use_id)
<< " is not a constant instruction with scalar integer type.";
}
}
return SPV_SUCCESS; return SPV_SUCCESS;
} }
} // namespace } // namespace
@ -640,7 +652,8 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
if (auto error = ValidateTypeForwardPointer(_, inst)) return error; if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
break; break;
case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixNV:
if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error; case spv::Op::OpTypeCooperativeMatrixKHR:
if (auto error = ValidateTypeCooperativeMatrix(_, inst)) return error;
break; break;
default: default:
break; break;

View File

@ -859,6 +859,7 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
return GetComponentType(inst->word(2)); return GetComponentType(inst->word(2));
case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
return inst->word(2); return inst->word(2);
default: default:
@ -886,6 +887,7 @@ uint32_t ValidationState_t::GetDimension(uint32_t id) const {
return inst->word(3); return inst->word(3);
case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
// Actual dimension isn't known, return 0 // Actual dimension isn't known, return 0
return 0; return 0;
@ -1142,22 +1144,68 @@ bool ValidationState_t::IsAccelerationStructureType(uint32_t id) const {
} }
bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const { bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
const Instruction* inst = FindDef(id);
return inst && (inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV ||
inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR);
}
bool ValidationState_t::IsCooperativeMatrixNVType(uint32_t id) const {
const Instruction* inst = FindDef(id); const Instruction* inst = FindDef(id);
return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV; return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV;
} }
bool ValidationState_t::IsCooperativeMatrixKHRType(uint32_t id) const {
const Instruction* inst = FindDef(id);
return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR;
}
bool ValidationState_t::IsCooperativeMatrixAType(uint32_t id) const {
if (!IsCooperativeMatrixKHRType(id)) return false;
const Instruction* inst = FindDef(id);
uint64_t matrixUse = 0;
if (GetConstantValUint64(inst->word(6), &matrixUse)) {
return matrixUse ==
static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixAKHR);
}
return false;
}
bool ValidationState_t::IsCooperativeMatrixBType(uint32_t id) const {
if (!IsCooperativeMatrixKHRType(id)) return false;
const Instruction* inst = FindDef(id);
uint64_t matrixUse = 0;
if (GetConstantValUint64(inst->word(6), &matrixUse)) {
return matrixUse ==
static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixBKHR);
}
return false;
}
bool ValidationState_t::IsCooperativeMatrixAccType(uint32_t id) const {
if (!IsCooperativeMatrixKHRType(id)) return false;
const Instruction* inst = FindDef(id);
uint64_t matrixUse = 0;
if (GetConstantValUint64(inst->word(6), &matrixUse)) {
return matrixUse == static_cast<uint64_t>(
spv::CooperativeMatrixUse::MatrixAccumulatorKHR);
}
return false;
}
bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const { bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
if (!IsCooperativeMatrixType(id)) return false; if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
return false;
return IsFloatScalarType(FindDef(id)->word(2)); return IsFloatScalarType(FindDef(id)->word(2));
} }
bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const { bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
if (!IsCooperativeMatrixType(id)) return false; if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
return false;
return IsIntScalarType(FindDef(id)->word(2)); return IsIntScalarType(FindDef(id)->word(2));
} }
bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const { bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
if (!IsCooperativeMatrixType(id)) return false; if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
return false;
return IsUnsignedIntScalarType(FindDef(id)->word(2)); return IsUnsignedIntScalarType(FindDef(id)->word(2));
} }
@ -1173,8 +1221,7 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
const auto m1_type = FindDef(m1); const auto m1_type = FindDef(m1);
const auto m2_type = FindDef(m2); const auto m2_type = FindDef(m2);
if (m1_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV || if (m1_type->opcode() != m2_type->opcode()) {
m2_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
return diag(SPV_ERROR_INVALID_DATA, inst) return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix types"; << "Expected cooperative matrix types";
} }
@ -1224,6 +1271,21 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
<< "identical"; << "identical";
} }
if (m1_type->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
uint32_t m1_use_id = m1_type->GetOperandAs<uint32_t>(5);
uint32_t m2_use_id = m2_type->GetOperandAs<uint32_t>(5);
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
EvalInt32IfConst(m1_use_id);
std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
EvalInt32IfConst(m2_use_id);
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Use of Matrix type and Result Type to be "
<< "identical";
}
}
return SPV_SUCCESS; return SPV_SUCCESS;
} }
@ -1489,6 +1551,7 @@ bool ValidationState_t::ContainsType(
case spv::Op::OpTypeImage: case spv::Op::OpTypeImage:
case spv::Op::OpTypeSampledImage: case spv::Op::OpTypeSampledImage:
case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
return ContainsType(inst->GetOperandAs<uint32_t>(1u), f, return ContainsType(inst->GetOperandAs<uint32_t>(1u), f,
traverse_all_types); traverse_all_types);
case spv::Op::OpTypePointer: case spv::Op::OpTypePointer:

View File

@ -610,6 +610,11 @@ class ValidationState_t {
bool IsPointerType(uint32_t id) const; bool IsPointerType(uint32_t id) const;
bool IsAccelerationStructureType(uint32_t id) const; bool IsAccelerationStructureType(uint32_t id) const;
bool IsCooperativeMatrixType(uint32_t id) const; bool IsCooperativeMatrixType(uint32_t id) const;
bool IsCooperativeMatrixNVType(uint32_t id) const;
bool IsCooperativeMatrixKHRType(uint32_t id) const;
bool IsCooperativeMatrixAType(uint32_t id) const;
bool IsCooperativeMatrixBType(uint32_t id) const;
bool IsCooperativeMatrixAccType(uint32_t id) const;
bool IsFloatCooperativeMatrixType(uint32_t id) const; bool IsFloatCooperativeMatrixType(uint32_t id) const;
bool IsIntCooperativeMatrixType(uint32_t id) const; bool IsIntCooperativeMatrixType(uint32_t id) const;
bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const; bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;

View File

@ -171,6 +171,7 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
types.emplace_back(new NamedBarrier()); types.emplace_back(new NamedBarrier());
types.emplace_back(new AccelerationStructureNV()); types.emplace_back(new AccelerationStructureNV());
types.emplace_back(new CooperativeMatrixNV(f32, 24, 24, 24)); types.emplace_back(new CooperativeMatrixNV(f32, 24, 24, 24));
types.emplace_back(new CooperativeMatrixKHR(f32, 8, 8, 8, 1002));
types.emplace_back(new RayQueryKHR()); types.emplace_back(new RayQueryKHR());
types.emplace_back(new HitObjectNV()); types.emplace_back(new HitObjectNV());
@ -237,6 +238,8 @@ TEST(TypeManager, TypeStrings) {
%arr_long_constant = OpTypeArray %s32 %long_constant %arr_long_constant = OpTypeArray %s32 %long_constant
%arr_spec_const_op = OpTypeArray %s32 %spec_const_op %arr_spec_const_op = OpTypeArray %s32 %spec_const_op
%cm = OpTypeCooperativeMatrixNV %f64 %id4 %id4 %id4 %cm = OpTypeCooperativeMatrixNV %f64 %id4 %id4 %id4
%id2 = OpConstant %u32 2
%cmkhr = OpTypeCooperativeMatrixKHR %f64 %id4 %id4 %id4 %id2
)"; )";
std::vector<std::pair<uint32_t, std::string>> type_id_strs = { std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
@ -275,6 +278,7 @@ TEST(TypeManager, TypeStrings) {
{37, "[sint32, id(33), words(0,705032704,1)]"}, {37, "[sint32, id(33), words(0,705032704,1)]"},
{38, "[sint32, id(34), words(2,34)]"}, {38, "[sint32, id(34), words(2,34)]"},
{39, "<float64, 6, 6, 6>"}, {39, "<float64, 6, 6, 6>"},
{41, "<float64, 6, 6, 6, 40>"},
}; };
std::unique_ptr<IRContext> context = std::unique_ptr<IRContext> context =
@ -940,12 +944,15 @@ OpMemoryModel Logical GLSL450
std::vector<std::unique_ptr<Type>> types = GenerateAllTypes(); std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
uint32_t id = 1u; uint32_t id = 1u;
for (auto& t : types) { for (auto& t : types) {
std::cout << ". id " << id << std::endl;
context->get_type_mgr()->RegisterType(id, *t); context->get_type_mgr()->RegisterType(id, *t);
EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id)); EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id));
} }
std::cout << "clear" << id << std::endl;
types.clear(); types.clear();
for (; id > 0; --id) { for (; id > 0; --id) {
std::cout << ". remove id " << id << std::endl;
context->get_type_mgr()->RemoveId(id); context->get_type_mgr()->RemoveId(id);
EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id)); EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id));
} }
@ -1030,6 +1037,8 @@ TEST(TypeManager, GetTypeInstructionAllTypes) {
; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 ; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
; CHECK: [[input_ptr:%\w+]] = OpTypePointer Input [[uint]] ; CHECK: [[input_ptr:%\w+]] = OpTypePointer Input [[uint]]
; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]] ; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]]
; CHECK: [[uint2:%\w+]] = OpConstant [[uint]] 2
; CHECK: [[uint8:%\w+]] = OpConstant [[uint]] 8
; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24 ; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24
; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42 ; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42
; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100 ; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100
@ -1085,6 +1094,7 @@ TEST(TypeManager, GetTypeInstructionAllTypes) {
; CHECK: OpTypeNamedBarrier ; CHECK: OpTypeNamedBarrier
; CHECK: OpTypeAccelerationStructureKHR ; CHECK: OpTypeAccelerationStructureKHR
; CHECK: OpTypeCooperativeMatrixNV [[f32]] [[uint24]] [[uint24]] [[uint24]] ; CHECK: OpTypeCooperativeMatrixNV [[f32]] [[uint24]] [[uint24]] [[uint24]]
; CHECK: OpTypeCooperativeMatrixKHR [[f32]] [[uint8]] [[uint8]] [[uint8]] [[uint2]]
; CHECK: OpTypeRayQueryKHR ; CHECK: OpTypeRayQueryKHR
; CHECK: OpTypeHitObjectNV ; CHECK: OpTypeHitObjectNV
OpCapability Shader OpCapability Shader
@ -1094,6 +1104,8 @@ OpMemoryModel Logical GLSL450
%uint = OpTypeInt 32 0 %uint = OpTypeInt 32 0
%1 = OpTypePointer Input %uint %1 = OpTypePointer Input %uint
%2 = OpTypePointer Uniform %uint %2 = OpTypePointer Uniform %uint
%1002 = OpConstant %uint 2
%8 = OpConstant %uint 8
%24 = OpConstant %uint 24 %24 = OpConstant %uint 24
%42 = OpConstant %uint 42 %42 = OpConstant %uint 42
%100 = OpConstant %uint 100 %100 = OpConstant %uint 100

View File

@ -1318,7 +1318,7 @@ TEST_F(ValidateArithmetics, CoopMatComponentTypeNotScalarNumeric) {
CompileSuccessfully(GenerateCoopMatCode(types, "").c_str()); CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), EXPECT_THAT(getDiagnosticString(),
HasSubstr("OpTypeCooperativeMatrixNV Component Type <id> " HasSubstr("OpTypeCooperativeMatrix Component Type <id> "
"'4[%bool]' is not a scalar numerical type.")); "'4[%bool]' is not a scalar numerical type."));
} }
@ -1331,7 +1331,7 @@ TEST_F(ValidateArithmetics, CoopMatScopeNotConstantInt) {
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT( EXPECT_THAT(
getDiagnosticString(), getDiagnosticString(),
HasSubstr("OpTypeCooperativeMatrixNV Scope <id> '17[%float_1]' is not a " HasSubstr("OpTypeCooperativeMatrix Scope <id> '17[%float_1]' is not a "
"constant instruction with scalar integer type.")); "constant instruction with scalar integer type."));
} }
@ -1344,7 +1344,7 @@ TEST_F(ValidateArithmetics, CoopMatRowsNotConstantInt) {
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT( EXPECT_THAT(
getDiagnosticString(), getDiagnosticString(),
HasSubstr("OpTypeCooperativeMatrixNV Rows <id> '17[%float_1]' is not a " HasSubstr("OpTypeCooperativeMatrix Rows <id> '17[%float_1]' is not a "
"constant instruction with scalar integer type.")); "constant instruction with scalar integer type."));
} }
@ -1357,7 +1357,7 @@ TEST_F(ValidateArithmetics, CoopMatColumnsNotConstantInt) {
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT( EXPECT_THAT(
getDiagnosticString(), getDiagnosticString(),
HasSubstr("OpTypeCooperativeMatrixNV Cols <id> '17[%float_1]' is not a " HasSubstr("OpTypeCooperativeMatrix Cols <id> '17[%float_1]' is not a "
"constant instruction with scalar integer type.")); "constant instruction with scalar integer type."));
} }
@ -1469,6 +1469,146 @@ TEST_F(ValidateArithmetics, SMulExtendedResultTypeMembersNotIdentical) {
"SMulExtended")); "SMulExtended"));
} }
std::string GenerateCoopMatKHRCode(const std::string& extra_types,
const std::string& main_body) {
const std::string prefix = R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%u32_16 = OpConstant %u32 16
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3
%useA = OpConstant %u32 0
%useB = OpConstant %u32 1
%useC = OpConstant %u32 2
%f16matA = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
%u32matA = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useA
%s32matA = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useA
%f16matB = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useB
%u32matB = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useB
%s32matB = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useB
%f16matC = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useC
%f32matC = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_16 %u32_16 %useC
%u32matC = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useC
%s32matC = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useC
%f16_1 = OpConstant %f16 1
%f32_1 = OpConstant %f32 1
%u32_1 = OpConstant %u32 1
%s32_1 = OpConstant %s32 1
%f16mat_A_1 = OpConstantComposite %f16matA %f16_1
%u32mat_A_1 = OpConstantComposite %u32matA %u32_1
%s32mat_A_1 = OpConstantComposite %s32matA %s32_1
%f16mat_B_1 = OpConstantComposite %f16matB %f16_1
%u32mat_B_1 = OpConstantComposite %u32matB %u32_1
%s32mat_B_1 = OpConstantComposite %s32matB %s32_1
%f16mat_C_1 = OpConstantComposite %f16matC %f16_1
%u32mat_C_1 = OpConstantComposite %u32matC %u32_1
%s32mat_C_1 = OpConstantComposite %s32matC %s32_1
)";
const std::string func_begin = R"(
%main = OpFunction %void None %func
%main_entry = OpLabel)";
const std::string suffix = R"(
OpReturn
OpFunctionEnd)";
return prefix + extra_types + func_begin + main_body + suffix;
}
TEST_F(ValidateArithmetics, CoopMatKHRSuccess) {
const std::string body = R"(
%val1 = OpFAdd %f16matA %f16mat_A_1 %f16mat_A_1
%val2 = OpFSub %f16matA %f16mat_A_1 %f16mat_A_1
%val3 = OpFMul %f16matA %f16mat_A_1 %f16mat_A_1
%val4 = OpFDiv %f16matA %f16mat_A_1 %f16mat_A_1
%val5 = OpFNegate %f16matA %f16mat_A_1
%val6 = OpIAdd %u32matA %u32mat_A_1 %u32mat_A_1
%val7 = OpISub %u32matA %u32mat_A_1 %u32mat_A_1
%val8 = OpUDiv %u32matA %u32mat_A_1 %u32mat_A_1
%val9 = OpIAdd %s32matA %s32mat_A_1 %s32mat_A_1
%val10 = OpISub %s32matA %s32mat_A_1 %s32mat_A_1
%val11 = OpSDiv %s32matA %s32mat_A_1 %s32mat_A_1
%val12 = OpSNegate %s32matA %s32mat_A_1
%val13 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f16_1
%val14 = OpMatrixTimesScalar %u32matA %u32mat_A_1 %u32_1
%val15 = OpMatrixTimesScalar %s32matA %s32mat_A_1 %s32_1
%val16 = OpCooperativeMatrixMulAddKHR %f32matC %f16mat_A_1 %f16mat_B_1 %f16mat_C_1
%val17 = OpCooperativeMatrixMulAddKHR %s32matC %s32mat_A_1 %s32mat_B_1 %s32mat_C_1)";
CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, CoopMatMatrixKHRTimesScalarMismatchFail) {
const std::string body = R"(
%val1 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f32_1
)";
CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Expected scalar operand type to be equal to the component "
"type of the matrix operand: MatrixTimesScalar"));
}
TEST_F(ValidateArithmetics, CoopMatKHRScopeFail) {
const std::string types = R"(
%workgroup = OpConstant %u32 2
%mat16x16_wg = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %useC
%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
)";
const std::string body = R"(
%val1 = OpFAdd %f16matA %f16matwg_16x16_1 %f16mat_A_1
)";
CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Expected scopes of Matrix and Result Type to be identical"));
}
TEST_F(ValidateArithmetics, CoopMatKHRDimFail) {
const std::string types = R"(
%mat16x4 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_4 %useC
%mat16x4_C_1 = OpConstantComposite %mat16x4 %f16_1
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixMulAddKHR %mat16x4 %f16mat_A_1 %f16mat_B_1 %mat16x4_C_1
)";
CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Cooperative matrix 'N' mismatch: CooperativeMatrixMulAddKHR"));
}
} // namespace } // namespace
} // namespace val } // namespace val
} // namespace spvtools } // namespace spvtools

View File

@ -1486,8 +1486,7 @@ OpFunctionEnd
} }
TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) { TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) {
const std::string body = const std::string body = R"(
R"(
OpCapability Shader OpCapability Shader
OpCapability Float16 OpCapability Float16
OpCapability CooperativeMatrixNV OpCapability CooperativeMatrixNV
@ -1525,8 +1524,7 @@ OpFunctionEnd)";
} }
TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) { TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) {
const std::string body = const std::string body = R"(
R"(
OpCapability Shader OpCapability Shader
OpCapability Float16 OpCapability Float16
OpCapability CooperativeMatrixNV OpCapability CooperativeMatrixNV
@ -1562,6 +1560,86 @@ OpFunctionEnd)";
HasSubstr("Expected Constituent type to be equal to the component type")); HasSubstr("Expected Constituent type to be equal to the component type"));
} }
TEST_F(ValidateComposites, CoopMatKHRConstantCompositeMismatchFail) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%u32_16 = OpConstant %u32 16
%useA = OpConstant %u32 0
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
%f32_1 = OpConstant %f32 1
%f16mat_1 = OpConstantComposite %f16mat %f32_1
%main = OpFunction %void None %func
%main_entry = OpLabel
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"OpConstantComposite Constituent <id> '12[%float_1]' type "
"does not match the Result Type <id> '11[%11]'s component type."));
}
TEST_F(ValidateComposites, CoopMatKHRCompositeConstructMismatchFail) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%u32_16 = OpConstant %u32 16
%useA = OpConstant %u32 0
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
%f32_1 = OpConstant %f32 1
%main = OpFunction %void None %func
%main_entry = OpLabel
%f16mat_1 = OpCompositeConstruct %f16mat %f32_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Expected Constituent type to be equal to the component type"));
}
TEST_F(ValidateComposites, ExtractDynamicLabelIndex) { TEST_F(ValidateComposites, ExtractDynamicLabelIndex) {
const std::string spirv = R"( const std::string spirv = R"(
OpCapability Shader OpCapability Shader

View File

@ -1149,8 +1149,7 @@ OpFunctionEnd)";
} }
TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) { TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) {
const std::string body = const std::string body = R"(
R"(
OpCapability Shader OpCapability Shader
OpCapability Float16 OpCapability Float16
OpCapability Int16 OpCapability Int16
@ -1191,6 +1190,179 @@ OpFunctionEnd)";
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateConversion, CoopMatKHRConversionSuccess) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u16 = OpTypeInt 16 0
%u32 = OpTypeInt 32 0
%s16 = OpTypeInt 16 1
%s32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%use_A = OpConstant %u32 0
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_A
%u16mat = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_A
%u32mat = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_A
%s16mat = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_A
%s32mat = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_A
%f16_1 = OpConstant %f16 1
%f32_1 = OpConstant %f32 1
%u16_1 = OpConstant %u16 1
%u32_1 = OpConstant %u32 1
%s16_1 = OpConstant %s16 1
%s32_1 = OpConstant %s32 1
%f16mat_1 = OpConstantComposite %f16mat %f16_1
%f32mat_1 = OpConstantComposite %f32mat %f32_1
%u16mat_1 = OpConstantComposite %u16mat %u16_1
%u32mat_1 = OpConstantComposite %u32mat %u32_1
%s16mat_1 = OpConstantComposite %s16mat %s16_1
%s32mat_1 = OpConstantComposite %s32mat %s32_1
%main = OpFunction %void None %func
%main_entry = OpLabel
%val11 = OpConvertFToU %u16mat %f16mat_1
%val12 = OpConvertFToU %u32mat %f16mat_1
%val13 = OpConvertFToS %s16mat %f16mat_1
%val14 = OpConvertFToS %s32mat %f16mat_1
%val15 = OpFConvert %f32mat %f16mat_1
%val21 = OpConvertFToU %u16mat %f32mat_1
%val22 = OpConvertFToU %u32mat %f32mat_1
%val23 = OpConvertFToS %s16mat %f32mat_1
%val24 = OpConvertFToS %s32mat %f32mat_1
%val25 = OpFConvert %f16mat %f32mat_1
%val31 = OpConvertUToF %f16mat %u16mat_1
%val32 = OpConvertUToF %f32mat %u16mat_1
%val33 = OpUConvert %u32mat %u16mat_1
%val34 = OpSConvert %s32mat %u16mat_1
%val41 = OpConvertSToF %f16mat %s16mat_1
%val42 = OpConvertSToF %f32mat %s16mat_1
%val43 = OpUConvert %u32mat %s16mat_1
%val44 = OpSConvert %s32mat %s16mat_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateConversion, CoopMatKHRConversionUseMismatchFail) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u16 = OpTypeInt 16 0
%u32 = OpTypeInt 32 0
%s16 = OpTypeInt 16 1
%s32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3
%use_A = OpConstant %u32 0
%use_B = OpConstant %u32 1
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_B
%f16_1 = OpConstant %f16 1
%f16mat_1 = OpConstantComposite %f16mat %f16_1
%main = OpFunction %void None %func
%main_entry = OpLabel
%val1 = OpFConvert %f32mat %f16mat_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Expected Use of Matrix type and Result Type to be identical"));
}
TEST_F(ValidateConversion, CoopMatKHRConversionScopeMismatchFail) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u16 = OpTypeInt 16 0
%u32 = OpTypeInt 32 0
%s16 = OpTypeInt 16 1
%s32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3
%workgroup = OpConstant %u32 2
%use_A = OpConstant %u32 0
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
%f32mat = OpTypeCooperativeMatrixKHR %f32 %workgroup %u32_8 %u32_8 %use_A
%f16_1 = OpConstant %f16 1
%f16mat_1 = OpConstantComposite %f16mat %f16_1
%main = OpFunction %void None %func
%main_entry = OpLabel
%val1 = OpFConvert %f32mat %f16mat_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Expected scopes of Matrix and Result Type to be identical"));
}
TEST_F(ValidateConversion, BitcastSuccess) { TEST_F(ValidateConversion, BitcastSuccess) {
const std::string body = R"( const std::string body = R"(
%ptr = OpVariable %f32ptr_func Function %ptr = OpVariable %f32ptr_func Function

View File

@ -23,12 +23,14 @@
#include "test/val/val_fixtures.h" #include "test/val/val_fixtures.h"
// For pretty-printing tuples with spv_target_env. // For pretty-printing tuples with spv_target_env.
std::ostream& operator<<(std::ostream& stream, spv_target_env target) std::ostream& operator<<(std::ostream& stream, spv_target_env target) {
{
switch (target) { switch (target) {
case SPV_ENV_UNIVERSAL_1_3: return stream << "SPV_ENV_UNIVERSAL_1_3"; case SPV_ENV_UNIVERSAL_1_3:
case SPV_ENV_UNIVERSAL_1_4: return stream << "SPV_ENV_UNIVERSAL_1_4"; return stream << "SPV_ENV_UNIVERSAL_1_3";
default: return stream << (unsigned)target; case SPV_ENV_UNIVERSAL_1_4:
return stream << "SPV_ENV_UNIVERSAL_1_4";
default:
return stream << (unsigned)target;
} }
} }
@ -2346,6 +2348,186 @@ OpFunctionEnd)";
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateMemory, CoopMatKHRLoadStoreSuccess) {
std::string spirv =
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR");
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
}
TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) {
std::string spirv =
GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR");
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("MakePointerVisibleKHR cannot be used with OpStore"));
}
TEST_F(ValidateMemory, CoopMatKHRLoadMemoryAccessFail) {
std::string spirv =
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerAvailableKHR|NonPrivatePointerKHR");
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad"));
}
TEST_F(ValidateMemory, CoopMatKHRInvalidStorageClassFail) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%u32_8 = OpConstant %u32 8
%use_A = OpConstant %u32 0
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
%str = OpTypeStruct %f16mat
%str_ptr = OpTypePointer Workgroup %str
%sh = OpVariable %str_ptr Workgroup
%main = OpFunction %void None %func
%main_entry = OpLabel
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Cooperative matrix types (or types containing them) can only be "
"allocated in Function or Private storage classes or as function "
"parameters"));
}
TEST_F(ValidateMemory, CoopMatMatrixKHRLengthResultTypeBad) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%i32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%use_A = OpConstant %u32 0
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
%main = OpFunction %void None %func
%main_entry = OpLabel
%1 = OpCooperativeMatrixLengthKHR %i32 %f16mat
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("The Result Type of OpCooperativeMatrixLengthKHR <id> "
"'12[%12]' must be OpTypeInt with width 32 and signedness 0"));
}
TEST_F(ValidateMemory, CoopMatMatrixKHRLengthOperandTypeBad) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%i32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%use_A = OpConstant %u32 0
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
%main = OpFunction %void None %func
%main_entry = OpLabel
%1 = OpCooperativeMatrixLengthKHR %u32 %u32
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("The type in OpCooperativeMatrixLengthKHR <id> '5[%uint]' "
"must be OpTypeCooperativeMatrixKHR"));
}
TEST_F(ValidateMemory, CoopMatMatrixKHRLengthGood) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%i32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%use_A = OpConstant %u32 0
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
%main = OpFunction %void None %func
%main_entry = OpLabel
%1 = OpCooperativeMatrixLengthKHR %u32 %f16mat
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) { TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) {
std::string spirv = R"( std::string spirv = R"(
OpCapability Shader OpCapability Shader
@ -3765,9 +3947,8 @@ OpFunctionEnd
HasSubstr("In the Vulkan environment, cannot store to Uniform Blocks")); HasSubstr("In the Vulkan environment, cannot store to Uniform Blocks"));
} }
using ValidateSizedVariable = using ValidateSizedVariable = spvtest::ValidateBase<
spvtest::ValidateBase<std::tuple<std::string, std::string, std::tuple<std::string, std::string, std::string, spv_target_env>>;
std::string, spv_target_env>>;
CodeGenerator GetSizedVariableCodeGenerator(bool is_8bit, bool buffer_block) { CodeGenerator GetSizedVariableCodeGenerator(bool is_8bit, bool buffer_block) {
CodeGenerator generator; CodeGenerator generator;
@ -3777,7 +3958,8 @@ CodeGenerator GetSizedVariableCodeGenerator(bool is_8bit, bool buffer_block) {
"\"SPV_KHR_8bit_storage\"\n"; "\"SPV_KHR_8bit_storage\"\n";
generator.memory_model_ = "OpMemoryModel Logical GLSL450\n"; generator.memory_model_ = "OpMemoryModel Logical GLSL450\n";
if (is_8bit) { if (is_8bit) {
generator.before_types_ = "OpMemberDecorate %char_buffer_block 0 Offset 0\n"; generator.before_types_ =
"OpMemberDecorate %char_buffer_block 0 Offset 0\n";
if (buffer_block) if (buffer_block)
generator.before_types_ += "OpDecorate %char_buffer_block BufferBlock\n"; generator.before_types_ += "OpDecorate %char_buffer_block BufferBlock\n";

View File

@ -540,7 +540,7 @@ def generate_operand_kind_table(enums):
# We have a few operand kinds that require their optional counterpart to # We have a few operand kinds that require their optional counterpart to
# exist in the operand info table. # exist in the operand info table.
optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat'] optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat', 'CooperativeMatrixOperands']
optional_enums = [e for e in enums if e[0] in optional_enums] optional_enums = [e for e in enums if e[0] in optional_enums]
enums.extend(optional_enums) enums.extend(optional_enums)