mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-25 04:50:04 +00:00
SPV_KHR_cooperative_matrix (#5286)
* SPV_KHR_cooperative_matrix * Update DEPS with headers * Update according to review recommendations * Bugfix and formatting * Formatting missed or damaged by VS2022
This commit is contained in:
parent
16098b3c10
commit
04cdb2d344
2
DEPS
2
DEPS
@ -13,7 +13,7 @@ vars = {
|
|||||||
'protobuf_revision': 'v21.12',
|
'protobuf_revision': 'v21.12',
|
||||||
|
|
||||||
're2_revision': '7c5e396af825562ec8321fdbf2f1cf276b26e3ae',
|
're2_revision': '7c5e396af825562ec8321fdbf2f1cf276b26e3ae',
|
||||||
'spirv_headers_revision': '10db9d4e194246a020a4148e220837ac7c68cfd9',
|
'spirv_headers_revision': '3469b164e25cee24435029a569933cb42578db5d',
|
||||||
}
|
}
|
||||||
|
|
||||||
deps = {
|
deps = {
|
||||||
|
@ -285,6 +285,13 @@ typedef enum spv_operand_type_t {
|
|||||||
// An optional packed vector format
|
// An optional packed vector format
|
||||||
SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT,
|
SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT,
|
||||||
|
|
||||||
|
// Concrete operand types for cooperative matrix.
|
||||||
|
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS,
|
||||||
|
// An optional cooperative matrix operands
|
||||||
|
SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS,
|
||||||
|
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT,
|
||||||
|
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE,
|
||||||
|
|
||||||
// This is a sentinel value, and does not represent an operand type.
|
// This is a sentinel value, and does not represent an operand type.
|
||||||
// It should come last.
|
// It should come last.
|
||||||
SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,
|
SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,
|
||||||
|
@ -154,11 +154,12 @@ const SpecConstantOpcodeEntry kOpSpecConstantOpcodes[] = {
|
|||||||
CASE(InBoundsAccessChain),
|
CASE(InBoundsAccessChain),
|
||||||
CASE(PtrAccessChain),
|
CASE(PtrAccessChain),
|
||||||
CASE(InBoundsPtrAccessChain),
|
CASE(InBoundsPtrAccessChain),
|
||||||
CASE(CooperativeMatrixLengthNV)
|
CASE(CooperativeMatrixLengthNV),
|
||||||
|
CASE(CooperativeMatrixLengthKHR)
|
||||||
};
|
};
|
||||||
|
|
||||||
// The 60 is determined by counting the opcodes listed in the spec.
|
// The 60 is determined by counting the opcodes listed in the spec.
|
||||||
static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
|
static_assert(61 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
|
||||||
"OpSpecConstantOp opcode table is incomplete");
|
"OpSpecConstantOp opcode table is incomplete");
|
||||||
#undef CASE
|
#undef CASE
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
@ -691,7 +691,9 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
|
|||||||
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
||||||
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
|
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
|
||||||
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
|
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
|
||||||
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: {
|
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
|
||||||
|
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
|
||||||
|
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
|
||||||
// This operand is a mask.
|
// This operand is a mask.
|
||||||
|
|
||||||
// Map an optional operand type to its corresponding concrete type.
|
// Map an optional operand type to its corresponding concrete type.
|
||||||
@ -699,6 +701,8 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
|
|||||||
parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
|
parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
|
||||||
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
|
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
|
||||||
parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
|
parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
|
||||||
|
if (type == SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS)
|
||||||
|
parsed_operand.type = SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS;
|
||||||
|
|
||||||
// Check validity of set mask bits. Also prepare for operands for those
|
// Check validity of set mask bits. Also prepare for operands for those
|
||||||
// masks if they have any. To get operand order correct, scan from
|
// masks if they have any. To get operand order correct, scan from
|
||||||
|
@ -274,6 +274,7 @@ int32_t spvOpcodeIsComposite(const spv::Op opcode) {
|
|||||||
case spv::Op::OpTypeArray:
|
case spv::Op::OpTypeArray:
|
||||||
case spv::Op::OpTypeStruct:
|
case spv::Op::OpTypeStruct:
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -340,6 +341,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
|
|||||||
case spv::Op::OpTypeNamedBarrier:
|
case spv::Op::OpTypeNamedBarrier:
|
||||||
case spv::Op::OpTypeAccelerationStructureNV:
|
case spv::Op::OpTypeAccelerationStructureNV:
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
// case spv::Op::OpTypeAccelerationStructureKHR: covered by
|
// case spv::Op::OpTypeAccelerationStructureKHR: covered by
|
||||||
// spv::Op::OpTypeAccelerationStructureNV
|
// spv::Op::OpTypeAccelerationStructureNV
|
||||||
case spv::Op::OpTypeRayQueryKHR:
|
case spv::Op::OpTypeRayQueryKHR:
|
||||||
|
@ -236,6 +236,13 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
|
|||||||
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
|
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
|
||||||
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
|
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
|
||||||
return "packed vector format";
|
return "packed vector format";
|
||||||
|
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
|
||||||
|
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
|
||||||
|
return "cooperative matrix operands";
|
||||||
|
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
|
||||||
|
return "cooperative matrix layout";
|
||||||
|
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
|
||||||
|
return "cooperative matrix use";
|
||||||
case SPV_OPERAND_TYPE_IMAGE:
|
case SPV_OPERAND_TYPE_IMAGE:
|
||||||
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
|
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
|
||||||
return "image";
|
return "image";
|
||||||
@ -369,6 +376,8 @@ bool spvOperandIsConcrete(spv_operand_type_t type) {
|
|||||||
case SPV_OPERAND_TYPE_QUANTIZATION_MODES:
|
case SPV_OPERAND_TYPE_QUANTIZATION_MODES:
|
||||||
case SPV_OPERAND_TYPE_OVERFLOW_MODES:
|
case SPV_OPERAND_TYPE_OVERFLOW_MODES:
|
||||||
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
|
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
|
||||||
|
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
|
||||||
|
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
@ -387,6 +396,7 @@ bool spvOperandIsConcreteMask(spv_operand_type_t type) {
|
|||||||
case SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE:
|
case SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE:
|
||||||
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
|
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
|
||||||
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
|
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
|
||||||
|
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
@ -405,6 +415,7 @@ bool spvOperandIsOptional(spv_operand_type_t type) {
|
|||||||
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING:
|
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING:
|
||||||
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
|
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
|
||||||
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
|
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
|
||||||
|
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
|
||||||
case SPV_OPERAND_TYPE_OPTIONAL_CIV:
|
case SPV_OPERAND_TYPE_OPTIONAL_CIV:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
|
@ -423,6 +423,23 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
|
|||||||
{SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}});
|
{SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}});
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case Type::kCooperativeMatrixKHR: {
|
||||||
|
auto coop_mat = type->AsCooperativeMatrixKHR();
|
||||||
|
uint32_t const component_type =
|
||||||
|
GetTypeInstruction(coop_mat->component_type());
|
||||||
|
if (component_type == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
typeInst = MakeUnique<Instruction>(
|
||||||
|
context(), spv::Op::OpTypeCooperativeMatrixKHR, 0, id,
|
||||||
|
std::initializer_list<Operand>{
|
||||||
|
{SPV_OPERAND_TYPE_ID, {component_type}},
|
||||||
|
{SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}},
|
||||||
|
{SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}},
|
||||||
|
{SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}},
|
||||||
|
{SPV_OPERAND_TYPE_ID, {coop_mat->use_id()}}});
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
assert(false && "Unexpected type");
|
assert(false && "Unexpected type");
|
||||||
break;
|
break;
|
||||||
@ -628,6 +645,14 @@ Type* TypeManager::RebuildType(const Type& type) {
|
|||||||
cm_type->columns_id());
|
cm_type->columns_id());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case Type::kCooperativeMatrixKHR: {
|
||||||
|
const CooperativeMatrixKHR* cm_type = type.AsCooperativeMatrixKHR();
|
||||||
|
const Type* component_type = cm_type->component_type();
|
||||||
|
rebuilt_ty = MakeUnique<CooperativeMatrixKHR>(
|
||||||
|
RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
|
||||||
|
cm_type->columns_id(), cm_type->use_id());
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
assert(false && "Unhandled type");
|
assert(false && "Unhandled type");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -863,6 +888,12 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
|
|||||||
inst.GetSingleWordInOperand(2),
|
inst.GetSingleWordInOperand(2),
|
||||||
inst.GetSingleWordInOperand(3));
|
inst.GetSingleWordInOperand(3));
|
||||||
break;
|
break;
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
|
type = new CooperativeMatrixKHR(
|
||||||
|
GetType(inst.GetSingleWordInOperand(0)),
|
||||||
|
inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2),
|
||||||
|
inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4));
|
||||||
|
break;
|
||||||
case spv::Op::OpTypeRayQueryKHR:
|
case spv::Op::OpTypeRayQueryKHR:
|
||||||
type = new RayQueryKHR();
|
type = new RayQueryKHR();
|
||||||
break;
|
break;
|
||||||
|
@ -128,6 +128,7 @@ std::unique_ptr<Type> Type::Clone() const {
|
|||||||
DeclareKindCase(NamedBarrier);
|
DeclareKindCase(NamedBarrier);
|
||||||
DeclareKindCase(AccelerationStructureNV);
|
DeclareKindCase(AccelerationStructureNV);
|
||||||
DeclareKindCase(CooperativeMatrixNV);
|
DeclareKindCase(CooperativeMatrixNV);
|
||||||
|
DeclareKindCase(CooperativeMatrixKHR);
|
||||||
DeclareKindCase(RayQueryKHR);
|
DeclareKindCase(RayQueryKHR);
|
||||||
DeclareKindCase(HitObjectNV);
|
DeclareKindCase(HitObjectNV);
|
||||||
#undef DeclareKindCase
|
#undef DeclareKindCase
|
||||||
@ -175,6 +176,7 @@ bool Type::operator==(const Type& other) const {
|
|||||||
DeclareKindCase(NamedBarrier);
|
DeclareKindCase(NamedBarrier);
|
||||||
DeclareKindCase(AccelerationStructureNV);
|
DeclareKindCase(AccelerationStructureNV);
|
||||||
DeclareKindCase(CooperativeMatrixNV);
|
DeclareKindCase(CooperativeMatrixNV);
|
||||||
|
DeclareKindCase(CooperativeMatrixKHR);
|
||||||
DeclareKindCase(RayQueryKHR);
|
DeclareKindCase(RayQueryKHR);
|
||||||
DeclareKindCase(HitObjectNV);
|
DeclareKindCase(HitObjectNV);
|
||||||
#undef DeclareKindCase
|
#undef DeclareKindCase
|
||||||
@ -230,6 +232,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
|
|||||||
DeclareKindCase(NamedBarrier);
|
DeclareKindCase(NamedBarrier);
|
||||||
DeclareKindCase(AccelerationStructureNV);
|
DeclareKindCase(AccelerationStructureNV);
|
||||||
DeclareKindCase(CooperativeMatrixNV);
|
DeclareKindCase(CooperativeMatrixNV);
|
||||||
|
DeclareKindCase(CooperativeMatrixKHR);
|
||||||
DeclareKindCase(RayQueryKHR);
|
DeclareKindCase(RayQueryKHR);
|
||||||
DeclareKindCase(HitObjectNV);
|
DeclareKindCase(HitObjectNV);
|
||||||
#undef DeclareKindCase
|
#undef DeclareKindCase
|
||||||
@ -708,6 +711,45 @@ bool CooperativeMatrixNV::IsSameImpl(const Type* that,
|
|||||||
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
|
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CooperativeMatrixKHR::CooperativeMatrixKHR(const Type* type,
|
||||||
|
const uint32_t scope,
|
||||||
|
const uint32_t rows,
|
||||||
|
const uint32_t columns,
|
||||||
|
const uint32_t use)
|
||||||
|
: Type(kCooperativeMatrixKHR),
|
||||||
|
component_type_(type),
|
||||||
|
scope_id_(scope),
|
||||||
|
rows_id_(rows),
|
||||||
|
columns_id_(columns),
|
||||||
|
use_id_(use) {
|
||||||
|
assert(type != nullptr);
|
||||||
|
assert(scope != 0);
|
||||||
|
assert(rows != 0);
|
||||||
|
assert(columns != 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CooperativeMatrixKHR::str() const {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "<" << component_type_->str() << ", " << scope_id_ << ", " << rows_id_
|
||||||
|
<< ", " << columns_id_ << ", " << use_id_ << ">";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t CooperativeMatrixKHR::ComputeExtraStateHash(size_t hash,
|
||||||
|
SeenTypes* seen) const {
|
||||||
|
hash = hash_combine(hash, scope_id_, rows_id_, columns_id_, use_id_);
|
||||||
|
return component_type_->ComputeHashValue(hash, seen);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CooperativeMatrixKHR::IsSameImpl(const Type* that,
|
||||||
|
IsSameCache* seen) const {
|
||||||
|
const CooperativeMatrixKHR* mt = that->AsCooperativeMatrixKHR();
|
||||||
|
if (!mt) return false;
|
||||||
|
return component_type_->IsSameImpl(mt->component_type_, seen) &&
|
||||||
|
scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
|
||||||
|
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace analysis
|
} // namespace analysis
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace spvtools
|
} // namespace spvtools
|
||||||
|
@ -60,6 +60,7 @@ class PipeStorage;
|
|||||||
class NamedBarrier;
|
class NamedBarrier;
|
||||||
class AccelerationStructureNV;
|
class AccelerationStructureNV;
|
||||||
class CooperativeMatrixNV;
|
class CooperativeMatrixNV;
|
||||||
|
class CooperativeMatrixKHR;
|
||||||
class RayQueryKHR;
|
class RayQueryKHR;
|
||||||
class HitObjectNV;
|
class HitObjectNV;
|
||||||
|
|
||||||
@ -100,6 +101,7 @@ class Type {
|
|||||||
kNamedBarrier,
|
kNamedBarrier,
|
||||||
kAccelerationStructureNV,
|
kAccelerationStructureNV,
|
||||||
kCooperativeMatrixNV,
|
kCooperativeMatrixNV,
|
||||||
|
kCooperativeMatrixKHR,
|
||||||
kRayQueryKHR,
|
kRayQueryKHR,
|
||||||
kHitObjectNV,
|
kHitObjectNV,
|
||||||
kLast
|
kLast
|
||||||
@ -201,6 +203,7 @@ class Type {
|
|||||||
DeclareCastMethod(NamedBarrier)
|
DeclareCastMethod(NamedBarrier)
|
||||||
DeclareCastMethod(AccelerationStructureNV)
|
DeclareCastMethod(AccelerationStructureNV)
|
||||||
DeclareCastMethod(CooperativeMatrixNV)
|
DeclareCastMethod(CooperativeMatrixNV)
|
||||||
|
DeclareCastMethod(CooperativeMatrixKHR)
|
||||||
DeclareCastMethod(RayQueryKHR)
|
DeclareCastMethod(RayQueryKHR)
|
||||||
DeclareCastMethod(HitObjectNV)
|
DeclareCastMethod(HitObjectNV)
|
||||||
#undef DeclareCastMethod
|
#undef DeclareCastMethod
|
||||||
@ -624,6 +627,38 @@ class CooperativeMatrixNV : public Type {
|
|||||||
const uint32_t columns_id_;
|
const uint32_t columns_id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class CooperativeMatrixKHR : public Type {
|
||||||
|
public:
|
||||||
|
CooperativeMatrixKHR(const Type* type, const uint32_t scope,
|
||||||
|
const uint32_t rows, const uint32_t columns,
|
||||||
|
const uint32_t use);
|
||||||
|
CooperativeMatrixKHR(const CooperativeMatrixKHR&) = default;
|
||||||
|
|
||||||
|
std::string str() const override;
|
||||||
|
|
||||||
|
CooperativeMatrixKHR* AsCooperativeMatrixKHR() override { return this; }
|
||||||
|
const CooperativeMatrixKHR* AsCooperativeMatrixKHR() const override {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
|
||||||
|
|
||||||
|
const Type* component_type() const { return component_type_; }
|
||||||
|
uint32_t scope_id() const { return scope_id_; }
|
||||||
|
uint32_t rows_id() const { return rows_id_; }
|
||||||
|
uint32_t columns_id() const { return columns_id_; }
|
||||||
|
uint32_t use_id() const { return use_id_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool IsSameImpl(const Type* that, IsSameCache*) const override;
|
||||||
|
|
||||||
|
const Type* component_type_;
|
||||||
|
const uint32_t scope_id_;
|
||||||
|
const uint32_t rows_id_;
|
||||||
|
const uint32_t columns_id_;
|
||||||
|
const uint32_t use_id_;
|
||||||
|
};
|
||||||
|
|
||||||
#define DefineParameterlessType(type, name) \
|
#define DefineParameterlessType(type, name) \
|
||||||
class type : public Type { \
|
class type : public Type { \
|
||||||
public: \
|
public: \
|
||||||
|
@ -402,7 +402,8 @@ spv_result_t spvTextEncodeOperand(const spvtools::AssemblyGrammar& grammar,
|
|||||||
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
|
||||||
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
|
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
|
||||||
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
|
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
|
||||||
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: {
|
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
|
||||||
|
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
|
||||||
uint32_t value;
|
uint32_t value;
|
||||||
if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
|
if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
|
||||||
return context->diagnostic(error)
|
return context->diagnostic(error)
|
||||||
|
@ -42,14 +42,29 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
opcode != spv::Op::OpFMod);
|
opcode != spv::Op::OpFMod);
|
||||||
if (!_.IsFloatScalarType(result_type) &&
|
if (!_.IsFloatScalarType(result_type) &&
|
||||||
!_.IsFloatVectorType(result_type) &&
|
!_.IsFloatVectorType(result_type) &&
|
||||||
!(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
|
!(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
|
||||||
|
!(opcode == spv::Op::OpFMul &&
|
||||||
|
_.IsCooperativeMatrixKHRType(result_type) &&
|
||||||
|
_.IsFloatCooperativeMatrixType(result_type)))
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected floating scalar or vector type as Result Type: "
|
<< "Expected floating scalar or vector type as Result Type: "
|
||||||
<< spvOpcodeString(opcode);
|
<< spvOpcodeString(opcode);
|
||||||
|
|
||||||
for (size_t operand_index = 2; operand_index < inst->operands().size();
|
for (size_t operand_index = 2; operand_index < inst->operands().size();
|
||||||
++operand_index) {
|
++operand_index) {
|
||||||
if (_.GetOperandTypeId(inst, operand_index) != result_type)
|
if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
|
||||||
|
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
|
||||||
|
if (!_.IsCooperativeMatrixKHRType(type_id) ||
|
||||||
|
!_.IsFloatCooperativeMatrixType(type_id)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Expected arithmetic operands to be of Result Type: "
|
||||||
|
<< spvOpcodeString(opcode) << " operand index "
|
||||||
|
<< operand_index;
|
||||||
|
}
|
||||||
|
spv_result_t ret =
|
||||||
|
_.CooperativeMatrixShapesMatch(inst, type_id, result_type);
|
||||||
|
if (ret != SPV_SUCCESS) return ret;
|
||||||
|
} else if (_.GetOperandTypeId(inst, operand_index) != result_type)
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected arithmetic operands to be of Result Type: "
|
<< "Expected arithmetic operands to be of Result Type: "
|
||||||
<< spvOpcodeString(opcode) << " operand index "
|
<< spvOpcodeString(opcode) << " operand index "
|
||||||
@ -71,7 +86,19 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
|
|
||||||
for (size_t operand_index = 2; operand_index < inst->operands().size();
|
for (size_t operand_index = 2; operand_index < inst->operands().size();
|
||||||
++operand_index) {
|
++operand_index) {
|
||||||
if (_.GetOperandTypeId(inst, operand_index) != result_type)
|
if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
|
||||||
|
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
|
||||||
|
if (!_.IsCooperativeMatrixKHRType(type_id) ||
|
||||||
|
!_.IsUnsignedIntCooperativeMatrixType(type_id)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Expected arithmetic operands to be of Result Type: "
|
||||||
|
<< spvOpcodeString(opcode) << " operand index "
|
||||||
|
<< operand_index;
|
||||||
|
}
|
||||||
|
spv_result_t ret =
|
||||||
|
_.CooperativeMatrixShapesMatch(inst, type_id, result_type);
|
||||||
|
if (ret != SPV_SUCCESS) return ret;
|
||||||
|
} else if (_.GetOperandTypeId(inst, operand_index) != result_type)
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected arithmetic operands to be of Result Type: "
|
<< "Expected arithmetic operands to be of Result Type: "
|
||||||
<< spvOpcodeString(opcode) << " operand index "
|
<< spvOpcodeString(opcode) << " operand index "
|
||||||
@ -91,7 +118,10 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
(opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
|
(opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
|
||||||
opcode != spv::Op::OpSMod);
|
opcode != spv::Op::OpSMod);
|
||||||
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
|
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
|
||||||
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
|
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
|
||||||
|
!(opcode == spv::Op::OpIMul &&
|
||||||
|
_.IsCooperativeMatrixKHRType(result_type) &&
|
||||||
|
_.IsIntCooperativeMatrixType(result_type)))
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected int scalar or vector type as Result Type: "
|
<< "Expected int scalar or vector type as Result Type: "
|
||||||
<< spvOpcodeString(opcode);
|
<< spvOpcodeString(opcode);
|
||||||
@ -102,9 +132,26 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
for (size_t operand_index = 2; operand_index < inst->operands().size();
|
for (size_t operand_index = 2; operand_index < inst->operands().size();
|
||||||
++operand_index) {
|
++operand_index) {
|
||||||
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
|
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
|
||||||
|
|
||||||
|
if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
|
||||||
|
if (!_.IsCooperativeMatrixKHRType(type_id) ||
|
||||||
|
!_.IsIntCooperativeMatrixType(type_id)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Expected arithmetic operands to be of Result Type: "
|
||||||
|
<< spvOpcodeString(opcode) << " operand index "
|
||||||
|
<< operand_index;
|
||||||
|
}
|
||||||
|
spv_result_t ret =
|
||||||
|
_.CooperativeMatrixShapesMatch(inst, type_id, result_type);
|
||||||
|
if (ret != SPV_SUCCESS) return ret;
|
||||||
|
}
|
||||||
|
|
||||||
if (!type_id ||
|
if (!type_id ||
|
||||||
(!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
|
(!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
|
||||||
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
|
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
|
||||||
|
!(opcode == spv::Op::OpIMul &&
|
||||||
|
_.IsCooperativeMatrixKHRType(result_type) &&
|
||||||
|
_.IsIntCooperativeMatrixType(result_type))))
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected int scalar or vector type as operand: "
|
<< "Expected int scalar or vector type as operand: "
|
||||||
<< spvOpcodeString(opcode) << " operand index "
|
<< spvOpcodeString(opcode) << " operand index "
|
||||||
@ -187,7 +234,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
|
|
||||||
case spv::Op::OpMatrixTimesScalar: {
|
case spv::Op::OpMatrixTimesScalar: {
|
||||||
if (!_.IsFloatMatrixType(result_type) &&
|
if (!_.IsFloatMatrixType(result_type) &&
|
||||||
!_.IsCooperativeMatrixType(result_type))
|
!(_.IsCooperativeMatrixType(result_type)))
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected float matrix type as Result Type: "
|
<< "Expected float matrix type as Result Type: "
|
||||||
<< spvOpcodeString(opcode);
|
<< spvOpcodeString(opcode);
|
||||||
@ -459,22 +506,108 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
|
const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
|
||||||
const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
|
const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
|
||||||
|
|
||||||
if (!_.IsCooperativeMatrixType(A_type_id)) {
|
if (!_.IsCooperativeMatrixNVType(A_type_id)) {
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected cooperative matrix type as A Type: "
|
<< "Expected cooperative matrix type as A Type: "
|
||||||
<< spvOpcodeString(opcode);
|
<< spvOpcodeString(opcode);
|
||||||
}
|
}
|
||||||
if (!_.IsCooperativeMatrixType(B_type_id)) {
|
if (!_.IsCooperativeMatrixNVType(B_type_id)) {
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected cooperative matrix type as B Type: "
|
<< "Expected cooperative matrix type as B Type: "
|
||||||
<< spvOpcodeString(opcode);
|
<< spvOpcodeString(opcode);
|
||||||
}
|
}
|
||||||
if (!_.IsCooperativeMatrixType(C_type_id)) {
|
if (!_.IsCooperativeMatrixNVType(C_type_id)) {
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected cooperative matrix type as C Type: "
|
<< "Expected cooperative matrix type as C Type: "
|
||||||
<< spvOpcodeString(opcode);
|
<< spvOpcodeString(opcode);
|
||||||
}
|
}
|
||||||
if (!_.IsCooperativeMatrixType(D_type_id)) {
|
if (!_.IsCooperativeMatrixNVType(D_type_id)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Expected cooperative matrix type as Result Type: "
|
||||||
|
<< spvOpcodeString(opcode);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto A = _.FindDef(A_type_id);
|
||||||
|
const auto B = _.FindDef(B_type_id);
|
||||||
|
const auto C = _.FindDef(C_type_id);
|
||||||
|
const auto D = _.FindDef(D_type_id);
|
||||||
|
|
||||||
|
std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
|
||||||
|
A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
|
||||||
|
|
||||||
|
A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
|
||||||
|
B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
|
||||||
|
C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
|
||||||
|
D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
|
||||||
|
|
||||||
|
A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
|
||||||
|
B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
|
||||||
|
C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
|
||||||
|
D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
|
||||||
|
|
||||||
|
A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
|
||||||
|
B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
|
||||||
|
C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
|
||||||
|
D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
|
||||||
|
|
||||||
|
const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
|
||||||
|
std::tuple<bool, bool, uint32_t> Y) {
|
||||||
|
return (std::get<1>(X) && std::get<1>(Y) &&
|
||||||
|
std::get<2>(X) != std::get<2>(Y));
|
||||||
|
};
|
||||||
|
|
||||||
|
if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
|
||||||
|
notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
|
||||||
|
notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Cooperative matrix scopes must match: "
|
||||||
|
<< spvOpcodeString(opcode);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
|
||||||
|
notEqual(C_rows, D_rows)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Cooperative matrix 'M' mismatch: "
|
||||||
|
<< spvOpcodeString(opcode);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
|
||||||
|
notEqual(C_cols, D_cols)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Cooperative matrix 'N' mismatch: "
|
||||||
|
<< spvOpcodeString(opcode);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (notEqual(A_cols, B_rows)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Cooperative matrix 'K' mismatch: "
|
||||||
|
<< spvOpcodeString(opcode);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case spv::Op::OpCooperativeMatrixMulAddKHR: {
|
||||||
|
const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
|
||||||
|
const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
|
||||||
|
const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
|
||||||
|
const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
|
||||||
|
|
||||||
|
if (!_.IsCooperativeMatrixAType(A_type_id)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Cooperative matrix type must be A Type: "
|
||||||
|
<< spvOpcodeString(opcode);
|
||||||
|
}
|
||||||
|
if (!_.IsCooperativeMatrixBType(B_type_id)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Cooperative matrix type must be B Type: "
|
||||||
|
<< spvOpcodeString(opcode);
|
||||||
|
}
|
||||||
|
if (!_.IsCooperativeMatrixAccType(C_type_id)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Cooperative matrix type must be Accumulator Type: "
|
||||||
|
<< spvOpcodeString(opcode);
|
||||||
|
}
|
||||||
|
if (!_.IsCooperativeMatrixKHRType(D_type_id)) {
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected cooperative matrix type as Result Type: "
|
<< "Expected cooperative matrix type as Result Type: "
|
||||||
<< spvOpcodeString(opcode);
|
<< spvOpcodeString(opcode);
|
||||||
|
@ -122,6 +122,7 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
|
|||||||
*member_type = type_inst->word(component_index + 2);
|
*member_type = type_inst->word(component_index + 2);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV: {
|
case spv::Op::OpTypeCooperativeMatrixNV: {
|
||||||
*member_type = type_inst->word(2);
|
*member_type = type_inst->word(2);
|
||||||
break;
|
break;
|
||||||
@ -335,6 +336,25 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
|
|||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR: {
|
||||||
|
const auto result_type_inst = _.FindDef(result_type);
|
||||||
|
assert(result_type_inst);
|
||||||
|
const auto component_type_id =
|
||||||
|
result_type_inst->GetOperandAs<uint32_t>(1);
|
||||||
|
|
||||||
|
if (3 != num_operands) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Must be only one constituent";
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
|
||||||
|
|
||||||
|
if (operand_type_id != component_type_id) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Expected Constituent type to be equal to the component type";
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV: {
|
case spv::Op::OpTypeCooperativeMatrixNV: {
|
||||||
const auto result_type_inst = _.FindDef(result_type);
|
const auto result_type_inst = _.FindDef(result_type);
|
||||||
assert(result_type_inst);
|
assert(result_type_inst);
|
||||||
|
@ -243,6 +243,7 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV: {
|
case spv::Op::OpTypeCooperativeMatrixNV: {
|
||||||
if (1 != constituent_count) {
|
if (1 != constituent_count) {
|
||||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
@ -310,6 +311,7 @@ bool IsTypeNullable(const std::vector<uint32_t>& instruction,
|
|||||||
case spv::Op::OpTypeArray:
|
case spv::Op::OpTypeArray:
|
||||||
case spv::Op::OpTypeMatrix:
|
case spv::Op::OpTypeMatrix:
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
case spv::Op::OpTypeVector: {
|
case spv::Op::OpTypeVector: {
|
||||||
auto base_type = _.FindDef(instruction[2]);
|
auto base_type = _.FindDef(instruction[2]);
|
||||||
return base_type && IsTypeNullable(base_type->words(), _);
|
return base_type && IsTypeNullable(base_type->words(), _);
|
||||||
|
@ -473,7 +473,10 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
const bool input_is_pointer = _.IsPointerType(input_type);
|
const bool input_is_pointer = _.IsPointerType(input_type);
|
||||||
const bool input_is_int_scalar = _.IsIntScalarType(input_type);
|
const bool input_is_int_scalar = _.IsIntScalarType(input_type);
|
||||||
|
|
||||||
if (!result_is_pointer && !result_is_int_scalar &&
|
const bool result_is_coopmat = _.IsCooperativeMatrixType(result_type);
|
||||||
|
const bool input_is_coopmat = _.IsCooperativeMatrixType(input_type);
|
||||||
|
|
||||||
|
if (!result_is_pointer && !result_is_int_scalar && !result_is_coopmat &&
|
||||||
!_.IsIntVectorType(result_type) &&
|
!_.IsIntVectorType(result_type) &&
|
||||||
!_.IsFloatScalarType(result_type) &&
|
!_.IsFloatScalarType(result_type) &&
|
||||||
!_.IsFloatVectorType(result_type))
|
!_.IsFloatVectorType(result_type))
|
||||||
@ -481,13 +484,24 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
<< "Expected Result Type to be a pointer or int or float vector "
|
<< "Expected Result Type to be a pointer or int or float vector "
|
||||||
<< "or scalar type: " << spvOpcodeString(opcode);
|
<< "or scalar type: " << spvOpcodeString(opcode);
|
||||||
|
|
||||||
if (!input_is_pointer && !input_is_int_scalar &&
|
if (!input_is_pointer && !input_is_int_scalar && !input_is_coopmat &&
|
||||||
!_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) &&
|
!_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) &&
|
||||||
!_.IsFloatVectorType(input_type))
|
!_.IsFloatVectorType(input_type))
|
||||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected input to be a pointer or int or float vector "
|
<< "Expected input to be a pointer or int or float vector "
|
||||||
<< "or scalar: " << spvOpcodeString(opcode);
|
<< "or scalar: " << spvOpcodeString(opcode);
|
||||||
|
|
||||||
|
if (result_is_coopmat != input_is_coopmat)
|
||||||
|
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Cooperative matrix can only be cast to another cooperative "
|
||||||
|
<< "matrix: " << spvOpcodeString(opcode);
|
||||||
|
|
||||||
|
if (result_is_coopmat) {
|
||||||
|
spv_result_t ret =
|
||||||
|
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
|
||||||
|
if (ret != SPV_SUCCESS) return ret;
|
||||||
|
}
|
||||||
|
|
||||||
if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 5) ||
|
if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 5) ||
|
||||||
_.HasExtension(kSPV_KHR_physical_storage_buffer)) {
|
_.HasExtension(kSPV_KHR_physical_storage_buffer)) {
|
||||||
const bool result_is_int_vector = _.IsIntVectorType(result_type);
|
const bool result_is_int_vector = _.IsIntVectorType(result_type);
|
||||||
|
@ -163,9 +163,12 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
|
|||||||
!inst->IsDebugInfo() && !inst->IsNonSemantic() &&
|
!inst->IsDebugInfo() && !inst->IsNonSemantic() &&
|
||||||
!spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction &&
|
!spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction &&
|
||||||
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
|
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
|
||||||
|
opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
|
||||||
!(opcode == spv::Op::OpSpecConstantOp &&
|
!(opcode == spv::Op::OpSpecConstantOp &&
|
||||||
|
(spv::Op(inst->word(3)) ==
|
||||||
|
spv::Op::OpCooperativeMatrixLengthNV ||
|
||||||
spv::Op(inst->word(3)) ==
|
spv::Op(inst->word(3)) ==
|
||||||
spv::Op::OpCooperativeMatrixLengthNV)) {
|
spv::Op::OpCooperativeMatrixLengthKHR))) {
|
||||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
<< "Operand " << _.getIdName(operand_word)
|
<< "Operand " << _.getIdName(operand_word)
|
||||||
<< " cannot be a type";
|
<< " cannot be a type";
|
||||||
@ -179,9 +182,12 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
|
|||||||
opcode != spv::Op::OpLoopMerge &&
|
opcode != spv::Op::OpLoopMerge &&
|
||||||
opcode != spv::Op::OpFunction &&
|
opcode != spv::Op::OpFunction &&
|
||||||
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
|
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
|
||||||
|
opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
|
||||||
!(opcode == spv::Op::OpSpecConstantOp &&
|
!(opcode == spv::Op::OpSpecConstantOp &&
|
||||||
|
(spv::Op(inst->word(3)) ==
|
||||||
|
spv::Op::OpCooperativeMatrixLengthNV ||
|
||||||
spv::Op(inst->word(3)) ==
|
spv::Op(inst->word(3)) ==
|
||||||
spv::Op::OpCooperativeMatrixLengthNV)) {
|
spv::Op::OpCooperativeMatrixLengthKHR))) {
|
||||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
<< "Operand " << _.getIdName(operand_word)
|
<< "Operand " << _.getIdName(operand_word)
|
||||||
<< " requires a type";
|
<< " requires a type";
|
||||||
|
@ -204,6 +204,7 @@ bool ContainsCooperativeMatrix(ValidationState_t& _,
|
|||||||
|
|
||||||
switch (storage->opcode()) {
|
switch (storage->opcode()) {
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
return true;
|
return true;
|
||||||
case spv::Op::OpTypeArray:
|
case spv::Op::OpTypeArray:
|
||||||
case spv::Op::OpTypeRuntimeArray:
|
case spv::Op::OpTypeRuntimeArray:
|
||||||
@ -232,6 +233,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
|
|||||||
spv::StorageClass src_sc = spv::StorageClass::Max;
|
spv::StorageClass src_sc = spv::StorageClass::Max;
|
||||||
switch (inst->opcode()) {
|
switch (inst->opcode()) {
|
||||||
case spv::Op::OpCooperativeMatrixLoadNV:
|
case spv::Op::OpCooperativeMatrixLoadNV:
|
||||||
|
case spv::Op::OpCooperativeMatrixLoadKHR:
|
||||||
case spv::Op::OpLoad: {
|
case spv::Op::OpLoad: {
|
||||||
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
|
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
|
||||||
auto load_pointer_type = _.FindDef(load_pointer->type_id());
|
auto load_pointer_type = _.FindDef(load_pointer->type_id());
|
||||||
@ -239,6 +241,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case spv::Op::OpCooperativeMatrixStoreNV:
|
case spv::Op::OpCooperativeMatrixStoreNV:
|
||||||
|
case spv::Op::OpCooperativeMatrixStoreKHR:
|
||||||
case spv::Op::OpStore: {
|
case spv::Op::OpStore: {
|
||||||
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
|
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
|
||||||
auto store_pointer_type = _.FindDef(store_pointer->type_id());
|
auto store_pointer_type = _.FindDef(store_pointer->type_id());
|
||||||
@ -326,7 +329,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
|
|||||||
const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
|
const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
|
||||||
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
|
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
|
||||||
if (inst->opcode() == spv::Op::OpLoad ||
|
if (inst->opcode() == spv::Op::OpLoad ||
|
||||||
inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) {
|
inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV ||
|
||||||
|
inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
|
||||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
<< "MakePointerAvailableKHR cannot be used with OpLoad.";
|
<< "MakePointerAvailableKHR cannot be used with OpLoad.";
|
||||||
}
|
}
|
||||||
@ -1357,6 +1361,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
|
|||||||
case spv::Op::OpTypeMatrix:
|
case spv::Op::OpTypeMatrix:
|
||||||
case spv::Op::OpTypeVector:
|
case spv::Op::OpTypeVector:
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
case spv::Op::OpTypeArray:
|
case spv::Op::OpTypeArray:
|
||||||
case spv::Op::OpTypeRuntimeArray: {
|
case spv::Op::OpTypeRuntimeArray: {
|
||||||
// In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
|
// In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
|
||||||
@ -1554,9 +1559,15 @@ spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
|
|||||||
<< " must be OpTypeInt with width 32 and signedness 0.";
|
<< " must be OpTypeInt with width 32 and signedness 0.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isKhr = inst->opcode() == spv::Op::OpCooperativeMatrixLengthKHR;
|
||||||
auto type_id = inst->GetOperandAs<uint32_t>(2);
|
auto type_id = inst->GetOperandAs<uint32_t>(2);
|
||||||
auto type = state.FindDef(type_id);
|
auto type = state.FindDef(type_id);
|
||||||
if (type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
|
if (isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
|
||||||
|
return state.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< "The type in " << instr_name << " <id> "
|
||||||
|
<< state.getIdName(type_id)
|
||||||
|
<< " must be OpTypeCooperativeMatrixKHR.";
|
||||||
|
} else if (!isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
|
||||||
return state.diag(SPV_ERROR_INVALID_ID, inst)
|
return state.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
<< "The type in " << instr_name << " <id> "
|
<< "The type in " << instr_name << " <id> "
|
||||||
<< state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
|
<< state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
|
||||||
@ -1668,6 +1679,112 @@ spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
|
|||||||
return SPV_SUCCESS;
|
return SPV_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
|
||||||
|
const Instruction* inst) {
|
||||||
|
uint32_t type_id;
|
||||||
|
const char* opname;
|
||||||
|
if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
|
||||||
|
type_id = inst->type_id();
|
||||||
|
opname = "spv::Op::OpCooperativeMatrixLoadKHR";
|
||||||
|
} else {
|
||||||
|
// get Object operand's type
|
||||||
|
type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
|
||||||
|
opname = "spv::Op::OpCooperativeMatrixStoreKHR";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto matrix_type = _.FindDef(type_id);
|
||||||
|
|
||||||
|
if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
|
||||||
|
if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< "spv::Op::OpCooperativeMatrixLoadKHR Result Type <id> "
|
||||||
|
<< _.getIdName(type_id) << " is not a cooperative matrix type.";
|
||||||
|
} else {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< "spv::Op::OpCooperativeMatrixStoreKHR Object type <id> "
|
||||||
|
<< _.getIdName(type_id) << " is not a cooperative matrix type.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto pointer_index =
|
||||||
|
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 2u : 0u;
|
||||||
|
const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
|
||||||
|
const auto pointer = _.FindDef(pointer_id);
|
||||||
|
if (!pointer ||
|
||||||
|
((_.addressing_model() == spv::AddressingModel::Logical) &&
|
||||||
|
((!_.features().variable_pointers &&
|
||||||
|
!spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
|
||||||
|
(_.features().variable_pointers &&
|
||||||
|
!spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< opname << " Pointer <id> " << _.getIdName(pointer_id)
|
||||||
|
<< " is not a logical pointer.";
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto pointer_type_id = pointer->type_id();
|
||||||
|
const auto pointer_type = _.FindDef(pointer_type_id);
|
||||||
|
if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< opname << " type for pointer <id> " << _.getIdName(pointer_id)
|
||||||
|
<< " is not a pointer type.";
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto storage_class_index = 1u;
|
||||||
|
const auto storage_class =
|
||||||
|
pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
|
||||||
|
|
||||||
|
if (storage_class != spv::StorageClass::Workgroup &&
|
||||||
|
storage_class != spv::StorageClass::StorageBuffer &&
|
||||||
|
storage_class != spv::StorageClass::PhysicalStorageBuffer) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< opname << " storage class for pointer type <id> "
|
||||||
|
<< _.getIdName(pointer_type_id)
|
||||||
|
<< " is not Workgroup or StorageBuffer.";
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
|
||||||
|
const auto pointee_type = _.FindDef(pointee_id);
|
||||||
|
if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
|
||||||
|
_.IsFloatScalarOrVectorType(pointee_id))) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< opname << " Pointer <id> " << _.getIdName(pointer->id())
|
||||||
|
<< "s Type must be a scalar or vector type.";
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto layout_index =
|
||||||
|
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u;
|
||||||
|
const auto colmajor_id = inst->GetOperandAs<uint32_t>(layout_index);
|
||||||
|
const auto colmajor = _.FindDef(colmajor_id);
|
||||||
|
if (!colmajor || !_.IsIntScalarType(colmajor->type_id()) ||
|
||||||
|
!(spvOpcodeIsConstant(colmajor->opcode()) ||
|
||||||
|
spvOpcodeIsSpecConstant(colmajor->opcode()))) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< "MemoryLayout operand <id> " << _.getIdName(colmajor_id)
|
||||||
|
<< " must be a 32-bit integer constant instruction.";
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto stride_index =
|
||||||
|
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
|
||||||
|
if (inst->operands().size() > stride_index) {
|
||||||
|
const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
|
||||||
|
const auto stride = _.FindDef(stride_id);
|
||||||
|
if (!stride || !_.IsIntScalarType(stride->type_id())) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< "Stride operand <id> " << _.getIdName(stride_id)
|
||||||
|
<< " must be a scalar integer type.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto memory_access_index =
|
||||||
|
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 5u : 4u;
|
||||||
|
if (inst->operands().size() > memory_access_index) {
|
||||||
|
if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return SPV_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
spv_result_t ValidatePtrComparison(ValidationState_t& _,
|
spv_result_t ValidatePtrComparison(ValidationState_t& _,
|
||||||
const Instruction* inst) {
|
const Instruction* inst) {
|
||||||
if (_.addressing_model() == spv::AddressingModel::Logical &&
|
if (_.addressing_model() == spv::AddressingModel::Logical &&
|
||||||
@ -1757,9 +1874,15 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
|
if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
|
||||||
return error;
|
return error;
|
||||||
break;
|
break;
|
||||||
|
case spv::Op::OpCooperativeMatrixLengthKHR:
|
||||||
case spv::Op::OpCooperativeMatrixLengthNV:
|
case spv::Op::OpCooperativeMatrixLengthNV:
|
||||||
if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
|
if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
|
||||||
break;
|
break;
|
||||||
|
case spv::Op::OpCooperativeMatrixLoadKHR:
|
||||||
|
case spv::Op::OpCooperativeMatrixStoreKHR:
|
||||||
|
if (auto error = ValidateCooperativeMatrixLoadStoreKHR(_, inst))
|
||||||
|
return error;
|
||||||
|
break;
|
||||||
case spv::Op::OpPtrEqual:
|
case spv::Op::OpPtrEqual:
|
||||||
case spv::Op::OpPtrNotEqual:
|
case spv::Op::OpPtrNotEqual:
|
||||||
case spv::Op::OpPtrDiff:
|
case spv::Op::OpPtrDiff:
|
||||||
|
@ -552,7 +552,7 @@ spv_result_t ValidateTypeForwardPointer(ValidationState_t& _,
|
|||||||
return SPV_SUCCESS;
|
return SPV_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
|
spv_result_t ValidateTypeCooperativeMatrix(ValidationState_t& _,
|
||||||
const Instruction* inst) {
|
const Instruction* inst) {
|
||||||
const auto component_type_index = 1;
|
const auto component_type_index = 1;
|
||||||
const auto component_type_id =
|
const auto component_type_id =
|
||||||
@ -561,7 +561,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
|
|||||||
if (!component_type || (spv::Op::OpTypeFloat != component_type->opcode() &&
|
if (!component_type || (spv::Op::OpTypeFloat != component_type->opcode() &&
|
||||||
spv::Op::OpTypeInt != component_type->opcode())) {
|
spv::Op::OpTypeInt != component_type->opcode())) {
|
||||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
<< "OpTypeCooperativeMatrixNV Component Type <id> "
|
<< "OpTypeCooperativeMatrix Component Type <id> "
|
||||||
<< _.getIdName(component_type_id)
|
<< _.getIdName(component_type_id)
|
||||||
<< " is not a scalar numerical type.";
|
<< " is not a scalar numerical type.";
|
||||||
}
|
}
|
||||||
@ -572,7 +572,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
|
|||||||
if (!scope || !_.IsIntScalarType(scope->type_id()) ||
|
if (!scope || !_.IsIntScalarType(scope->type_id()) ||
|
||||||
!spvOpcodeIsConstant(scope->opcode())) {
|
!spvOpcodeIsConstant(scope->opcode())) {
|
||||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
<< "OpTypeCooperativeMatrixNV Scope <id> " << _.getIdName(scope_id)
|
<< "OpTypeCooperativeMatrix Scope <id> " << _.getIdName(scope_id)
|
||||||
<< " is not a constant instruction with scalar integer type.";
|
<< " is not a constant instruction with scalar integer type.";
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -582,7 +582,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
|
|||||||
if (!rows || !_.IsIntScalarType(rows->type_id()) ||
|
if (!rows || !_.IsIntScalarType(rows->type_id()) ||
|
||||||
!spvOpcodeIsConstant(rows->opcode())) {
|
!spvOpcodeIsConstant(rows->opcode())) {
|
||||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
<< "OpTypeCooperativeMatrixNV Rows <id> " << _.getIdName(rows_id)
|
<< "OpTypeCooperativeMatrix Rows <id> " << _.getIdName(rows_id)
|
||||||
<< " is not a constant instruction with scalar integer type.";
|
<< " is not a constant instruction with scalar integer type.";
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -592,10 +592,22 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
|
|||||||
if (!cols || !_.IsIntScalarType(cols->type_id()) ||
|
if (!cols || !_.IsIntScalarType(cols->type_id()) ||
|
||||||
!spvOpcodeIsConstant(cols->opcode())) {
|
!spvOpcodeIsConstant(cols->opcode())) {
|
||||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
<< "OpTypeCooperativeMatrixNV Cols <id> " << _.getIdName(cols_id)
|
<< "OpTypeCooperativeMatrix Cols <id> " << _.getIdName(cols_id)
|
||||||
<< " is not a constant instruction with scalar integer type.";
|
<< " is not a constant instruction with scalar integer type.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
|
||||||
|
const auto use_index = 5;
|
||||||
|
const auto use_id = inst->GetOperandAs<uint32_t>(use_index);
|
||||||
|
const auto use = _.FindDef(use_id);
|
||||||
|
if (!use || !_.IsIntScalarType(use->type_id()) ||
|
||||||
|
!spvOpcodeIsConstant(use->opcode())) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< "OpTypeCooperativeMatrixKHR Use <id> " << _.getIdName(use_id)
|
||||||
|
<< " is not a constant instruction with scalar integer type.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return SPV_SUCCESS;
|
return SPV_SUCCESS;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -640,7 +652,8 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
|
|||||||
if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
|
if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
|
||||||
break;
|
break;
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||||
if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
|
if (auto error = ValidateTypeCooperativeMatrix(_, inst)) return error;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
|
@ -859,6 +859,7 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
|
|||||||
return GetComponentType(inst->word(2));
|
return GetComponentType(inst->word(2));
|
||||||
|
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
return inst->word(2);
|
return inst->word(2);
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@ -886,6 +887,7 @@ uint32_t ValidationState_t::GetDimension(uint32_t id) const {
|
|||||||
return inst->word(3);
|
return inst->word(3);
|
||||||
|
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
// Actual dimension isn't known, return 0
|
// Actual dimension isn't known, return 0
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
@ -1142,22 +1144,68 @@ bool ValidationState_t::IsAccelerationStructureType(uint32_t id) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
|
bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
|
||||||
|
const Instruction* inst = FindDef(id);
|
||||||
|
return inst && (inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV ||
|
||||||
|
inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ValidationState_t::IsCooperativeMatrixNVType(uint32_t id) const {
|
||||||
const Instruction* inst = FindDef(id);
|
const Instruction* inst = FindDef(id);
|
||||||
return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV;
|
return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ValidationState_t::IsCooperativeMatrixKHRType(uint32_t id) const {
|
||||||
|
const Instruction* inst = FindDef(id);
|
||||||
|
return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ValidationState_t::IsCooperativeMatrixAType(uint32_t id) const {
|
||||||
|
if (!IsCooperativeMatrixKHRType(id)) return false;
|
||||||
|
const Instruction* inst = FindDef(id);
|
||||||
|
uint64_t matrixUse = 0;
|
||||||
|
if (GetConstantValUint64(inst->word(6), &matrixUse)) {
|
||||||
|
return matrixUse ==
|
||||||
|
static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixAKHR);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ValidationState_t::IsCooperativeMatrixBType(uint32_t id) const {
|
||||||
|
if (!IsCooperativeMatrixKHRType(id)) return false;
|
||||||
|
const Instruction* inst = FindDef(id);
|
||||||
|
uint64_t matrixUse = 0;
|
||||||
|
if (GetConstantValUint64(inst->word(6), &matrixUse)) {
|
||||||
|
return matrixUse ==
|
||||||
|
static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixBKHR);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool ValidationState_t::IsCooperativeMatrixAccType(uint32_t id) const {
|
||||||
|
if (!IsCooperativeMatrixKHRType(id)) return false;
|
||||||
|
const Instruction* inst = FindDef(id);
|
||||||
|
uint64_t matrixUse = 0;
|
||||||
|
if (GetConstantValUint64(inst->word(6), &matrixUse)) {
|
||||||
|
return matrixUse == static_cast<uint64_t>(
|
||||||
|
spv::CooperativeMatrixUse::MatrixAccumulatorKHR);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
|
bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
|
||||||
if (!IsCooperativeMatrixType(id)) return false;
|
if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
|
||||||
|
return false;
|
||||||
return IsFloatScalarType(FindDef(id)->word(2));
|
return IsFloatScalarType(FindDef(id)->word(2));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
|
bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
|
||||||
if (!IsCooperativeMatrixType(id)) return false;
|
if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
|
||||||
|
return false;
|
||||||
return IsIntScalarType(FindDef(id)->word(2));
|
return IsIntScalarType(FindDef(id)->word(2));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
|
bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
|
||||||
if (!IsCooperativeMatrixType(id)) return false;
|
if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
|
||||||
|
return false;
|
||||||
return IsUnsignedIntScalarType(FindDef(id)->word(2));
|
return IsUnsignedIntScalarType(FindDef(id)->word(2));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1173,8 +1221,7 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
|
|||||||
const auto m1_type = FindDef(m1);
|
const auto m1_type = FindDef(m1);
|
||||||
const auto m2_type = FindDef(m2);
|
const auto m2_type = FindDef(m2);
|
||||||
|
|
||||||
if (m1_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV ||
|
if (m1_type->opcode() != m2_type->opcode()) {
|
||||||
m2_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
|
|
||||||
return diag(SPV_ERROR_INVALID_DATA, inst)
|
return diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
<< "Expected cooperative matrix types";
|
<< "Expected cooperative matrix types";
|
||||||
}
|
}
|
||||||
@ -1224,6 +1271,21 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
|
|||||||
<< "identical";
|
<< "identical";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (m1_type->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
|
||||||
|
uint32_t m1_use_id = m1_type->GetOperandAs<uint32_t>(5);
|
||||||
|
uint32_t m2_use_id = m2_type->GetOperandAs<uint32_t>(5);
|
||||||
|
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
|
||||||
|
EvalInt32IfConst(m1_use_id);
|
||||||
|
std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
|
||||||
|
EvalInt32IfConst(m2_use_id);
|
||||||
|
|
||||||
|
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
|
||||||
|
return diag(SPV_ERROR_INVALID_DATA, inst)
|
||||||
|
<< "Expected Use of Matrix type and Result Type to be "
|
||||||
|
<< "identical";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return SPV_SUCCESS;
|
return SPV_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1489,6 +1551,7 @@ bool ValidationState_t::ContainsType(
|
|||||||
case spv::Op::OpTypeImage:
|
case spv::Op::OpTypeImage:
|
||||||
case spv::Op::OpTypeSampledImage:
|
case spv::Op::OpTypeSampledImage:
|
||||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||||
|
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||||
return ContainsType(inst->GetOperandAs<uint32_t>(1u), f,
|
return ContainsType(inst->GetOperandAs<uint32_t>(1u), f,
|
||||||
traverse_all_types);
|
traverse_all_types);
|
||||||
case spv::Op::OpTypePointer:
|
case spv::Op::OpTypePointer:
|
||||||
|
@ -610,6 +610,11 @@ class ValidationState_t {
|
|||||||
bool IsPointerType(uint32_t id) const;
|
bool IsPointerType(uint32_t id) const;
|
||||||
bool IsAccelerationStructureType(uint32_t id) const;
|
bool IsAccelerationStructureType(uint32_t id) const;
|
||||||
bool IsCooperativeMatrixType(uint32_t id) const;
|
bool IsCooperativeMatrixType(uint32_t id) const;
|
||||||
|
bool IsCooperativeMatrixNVType(uint32_t id) const;
|
||||||
|
bool IsCooperativeMatrixKHRType(uint32_t id) const;
|
||||||
|
bool IsCooperativeMatrixAType(uint32_t id) const;
|
||||||
|
bool IsCooperativeMatrixBType(uint32_t id) const;
|
||||||
|
bool IsCooperativeMatrixAccType(uint32_t id) const;
|
||||||
bool IsFloatCooperativeMatrixType(uint32_t id) const;
|
bool IsFloatCooperativeMatrixType(uint32_t id) const;
|
||||||
bool IsIntCooperativeMatrixType(uint32_t id) const;
|
bool IsIntCooperativeMatrixType(uint32_t id) const;
|
||||||
bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;
|
bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;
|
||||||
|
@ -171,6 +171,7 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
|
|||||||
types.emplace_back(new NamedBarrier());
|
types.emplace_back(new NamedBarrier());
|
||||||
types.emplace_back(new AccelerationStructureNV());
|
types.emplace_back(new AccelerationStructureNV());
|
||||||
types.emplace_back(new CooperativeMatrixNV(f32, 24, 24, 24));
|
types.emplace_back(new CooperativeMatrixNV(f32, 24, 24, 24));
|
||||||
|
types.emplace_back(new CooperativeMatrixKHR(f32, 8, 8, 8, 1002));
|
||||||
types.emplace_back(new RayQueryKHR());
|
types.emplace_back(new RayQueryKHR());
|
||||||
types.emplace_back(new HitObjectNV());
|
types.emplace_back(new HitObjectNV());
|
||||||
|
|
||||||
@ -237,6 +238,8 @@ TEST(TypeManager, TypeStrings) {
|
|||||||
%arr_long_constant = OpTypeArray %s32 %long_constant
|
%arr_long_constant = OpTypeArray %s32 %long_constant
|
||||||
%arr_spec_const_op = OpTypeArray %s32 %spec_const_op
|
%arr_spec_const_op = OpTypeArray %s32 %spec_const_op
|
||||||
%cm = OpTypeCooperativeMatrixNV %f64 %id4 %id4 %id4
|
%cm = OpTypeCooperativeMatrixNV %f64 %id4 %id4 %id4
|
||||||
|
%id2 = OpConstant %u32 2
|
||||||
|
%cmkhr = OpTypeCooperativeMatrixKHR %f64 %id4 %id4 %id4 %id2
|
||||||
)";
|
)";
|
||||||
|
|
||||||
std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
|
std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
|
||||||
@ -275,6 +278,7 @@ TEST(TypeManager, TypeStrings) {
|
|||||||
{37, "[sint32, id(33), words(0,705032704,1)]"},
|
{37, "[sint32, id(33), words(0,705032704,1)]"},
|
||||||
{38, "[sint32, id(34), words(2,34)]"},
|
{38, "[sint32, id(34), words(2,34)]"},
|
||||||
{39, "<float64, 6, 6, 6>"},
|
{39, "<float64, 6, 6, 6>"},
|
||||||
|
{41, "<float64, 6, 6, 6, 40>"},
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<IRContext> context =
|
std::unique_ptr<IRContext> context =
|
||||||
@ -940,12 +944,15 @@ OpMemoryModel Logical GLSL450
|
|||||||
std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
|
std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
|
||||||
uint32_t id = 1u;
|
uint32_t id = 1u;
|
||||||
for (auto& t : types) {
|
for (auto& t : types) {
|
||||||
|
std::cout << ". id " << id << std::endl;
|
||||||
context->get_type_mgr()->RegisterType(id, *t);
|
context->get_type_mgr()->RegisterType(id, *t);
|
||||||
EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id));
|
EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id));
|
||||||
}
|
}
|
||||||
|
std::cout << "clear" << id << std::endl;
|
||||||
types.clear();
|
types.clear();
|
||||||
|
|
||||||
for (; id > 0; --id) {
|
for (; id > 0; --id) {
|
||||||
|
std::cout << ". remove id " << id << std::endl;
|
||||||
context->get_type_mgr()->RemoveId(id);
|
context->get_type_mgr()->RemoveId(id);
|
||||||
EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id));
|
EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id));
|
||||||
}
|
}
|
||||||
@ -1030,6 +1037,8 @@ TEST(TypeManager, GetTypeInstructionAllTypes) {
|
|||||||
; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
|
; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
|
||||||
; CHECK: [[input_ptr:%\w+]] = OpTypePointer Input [[uint]]
|
; CHECK: [[input_ptr:%\w+]] = OpTypePointer Input [[uint]]
|
||||||
; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]]
|
; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]]
|
||||||
|
; CHECK: [[uint2:%\w+]] = OpConstant [[uint]] 2
|
||||||
|
; CHECK: [[uint8:%\w+]] = OpConstant [[uint]] 8
|
||||||
; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24
|
; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24
|
||||||
; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42
|
; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42
|
||||||
; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100
|
; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100
|
||||||
@ -1085,6 +1094,7 @@ TEST(TypeManager, GetTypeInstructionAllTypes) {
|
|||||||
; CHECK: OpTypeNamedBarrier
|
; CHECK: OpTypeNamedBarrier
|
||||||
; CHECK: OpTypeAccelerationStructureKHR
|
; CHECK: OpTypeAccelerationStructureKHR
|
||||||
; CHECK: OpTypeCooperativeMatrixNV [[f32]] [[uint24]] [[uint24]] [[uint24]]
|
; CHECK: OpTypeCooperativeMatrixNV [[f32]] [[uint24]] [[uint24]] [[uint24]]
|
||||||
|
; CHECK: OpTypeCooperativeMatrixKHR [[f32]] [[uint8]] [[uint8]] [[uint8]] [[uint2]]
|
||||||
; CHECK: OpTypeRayQueryKHR
|
; CHECK: OpTypeRayQueryKHR
|
||||||
; CHECK: OpTypeHitObjectNV
|
; CHECK: OpTypeHitObjectNV
|
||||||
OpCapability Shader
|
OpCapability Shader
|
||||||
@ -1094,6 +1104,8 @@ OpMemoryModel Logical GLSL450
|
|||||||
%uint = OpTypeInt 32 0
|
%uint = OpTypeInt 32 0
|
||||||
%1 = OpTypePointer Input %uint
|
%1 = OpTypePointer Input %uint
|
||||||
%2 = OpTypePointer Uniform %uint
|
%2 = OpTypePointer Uniform %uint
|
||||||
|
%1002 = OpConstant %uint 2
|
||||||
|
%8 = OpConstant %uint 8
|
||||||
%24 = OpConstant %uint 24
|
%24 = OpConstant %uint 24
|
||||||
%42 = OpConstant %uint 42
|
%42 = OpConstant %uint 42
|
||||||
%100 = OpConstant %uint 100
|
%100 = OpConstant %uint 100
|
||||||
|
@ -1318,7 +1318,7 @@ TEST_F(ValidateArithmetics, CoopMatComponentTypeNotScalarNumeric) {
|
|||||||
CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
|
CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
|
||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||||
EXPECT_THAT(getDiagnosticString(),
|
EXPECT_THAT(getDiagnosticString(),
|
||||||
HasSubstr("OpTypeCooperativeMatrixNV Component Type <id> "
|
HasSubstr("OpTypeCooperativeMatrix Component Type <id> "
|
||||||
"'4[%bool]' is not a scalar numerical type."));
|
"'4[%bool]' is not a scalar numerical type."));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1331,7 +1331,7 @@ TEST_F(ValidateArithmetics, CoopMatScopeNotConstantInt) {
|
|||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
getDiagnosticString(),
|
getDiagnosticString(),
|
||||||
HasSubstr("OpTypeCooperativeMatrixNV Scope <id> '17[%float_1]' is not a "
|
HasSubstr("OpTypeCooperativeMatrix Scope <id> '17[%float_1]' is not a "
|
||||||
"constant instruction with scalar integer type."));
|
"constant instruction with scalar integer type."));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1344,7 +1344,7 @@ TEST_F(ValidateArithmetics, CoopMatRowsNotConstantInt) {
|
|||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
getDiagnosticString(),
|
getDiagnosticString(),
|
||||||
HasSubstr("OpTypeCooperativeMatrixNV Rows <id> '17[%float_1]' is not a "
|
HasSubstr("OpTypeCooperativeMatrix Rows <id> '17[%float_1]' is not a "
|
||||||
"constant instruction with scalar integer type."));
|
"constant instruction with scalar integer type."));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1357,7 +1357,7 @@ TEST_F(ValidateArithmetics, CoopMatColumnsNotConstantInt) {
|
|||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
getDiagnosticString(),
|
getDiagnosticString(),
|
||||||
HasSubstr("OpTypeCooperativeMatrixNV Cols <id> '17[%float_1]' is not a "
|
HasSubstr("OpTypeCooperativeMatrix Cols <id> '17[%float_1]' is not a "
|
||||||
"constant instruction with scalar integer type."));
|
"constant instruction with scalar integer type."));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1469,6 +1469,146 @@ TEST_F(ValidateArithmetics, SMulExtendedResultTypeMembersNotIdentical) {
|
|||||||
"SMulExtended"));
|
"SMulExtended"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string GenerateCoopMatKHRCode(const std::string& extra_types,
|
||||||
|
const std::string& main_body) {
|
||||||
|
const std::string prefix = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%bool = OpTypeBool
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%f32 = OpTypeFloat 32
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
%s32 = OpTypeInt 32 1
|
||||||
|
|
||||||
|
%u32_16 = OpConstant %u32 16
|
||||||
|
%u32_4 = OpConstant %u32 4
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
%useA = OpConstant %u32 0
|
||||||
|
%useB = OpConstant %u32 1
|
||||||
|
%useC = OpConstant %u32 2
|
||||||
|
|
||||||
|
%f16matA = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
|
||||||
|
%u32matA = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useA
|
||||||
|
%s32matA = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useA
|
||||||
|
|
||||||
|
%f16matB = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useB
|
||||||
|
%u32matB = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useB
|
||||||
|
%s32matB = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useB
|
||||||
|
|
||||||
|
%f16matC = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useC
|
||||||
|
%f32matC = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_16 %u32_16 %useC
|
||||||
|
%u32matC = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useC
|
||||||
|
%s32matC = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useC
|
||||||
|
|
||||||
|
%f16_1 = OpConstant %f16 1
|
||||||
|
%f32_1 = OpConstant %f32 1
|
||||||
|
%u32_1 = OpConstant %u32 1
|
||||||
|
%s32_1 = OpConstant %s32 1
|
||||||
|
|
||||||
|
%f16mat_A_1 = OpConstantComposite %f16matA %f16_1
|
||||||
|
%u32mat_A_1 = OpConstantComposite %u32matA %u32_1
|
||||||
|
%s32mat_A_1 = OpConstantComposite %s32matA %s32_1
|
||||||
|
|
||||||
|
%f16mat_B_1 = OpConstantComposite %f16matB %f16_1
|
||||||
|
%u32mat_B_1 = OpConstantComposite %u32matB %u32_1
|
||||||
|
%s32mat_B_1 = OpConstantComposite %s32matB %s32_1
|
||||||
|
|
||||||
|
%f16mat_C_1 = OpConstantComposite %f16matC %f16_1
|
||||||
|
%u32mat_C_1 = OpConstantComposite %u32matC %u32_1
|
||||||
|
%s32mat_C_1 = OpConstantComposite %s32matC %s32_1
|
||||||
|
|
||||||
|
)";
|
||||||
|
|
||||||
|
const std::string func_begin = R"(
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel)";
|
||||||
|
|
||||||
|
const std::string suffix = R"(
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
return prefix + extra_types + func_begin + main_body + suffix;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateArithmetics, CoopMatKHRSuccess) {
|
||||||
|
const std::string body = R"(
|
||||||
|
%val1 = OpFAdd %f16matA %f16mat_A_1 %f16mat_A_1
|
||||||
|
%val2 = OpFSub %f16matA %f16mat_A_1 %f16mat_A_1
|
||||||
|
%val3 = OpFMul %f16matA %f16mat_A_1 %f16mat_A_1
|
||||||
|
%val4 = OpFDiv %f16matA %f16mat_A_1 %f16mat_A_1
|
||||||
|
%val5 = OpFNegate %f16matA %f16mat_A_1
|
||||||
|
%val6 = OpIAdd %u32matA %u32mat_A_1 %u32mat_A_1
|
||||||
|
%val7 = OpISub %u32matA %u32mat_A_1 %u32mat_A_1
|
||||||
|
%val8 = OpUDiv %u32matA %u32mat_A_1 %u32mat_A_1
|
||||||
|
%val9 = OpIAdd %s32matA %s32mat_A_1 %s32mat_A_1
|
||||||
|
%val10 = OpISub %s32matA %s32mat_A_1 %s32mat_A_1
|
||||||
|
%val11 = OpSDiv %s32matA %s32mat_A_1 %s32mat_A_1
|
||||||
|
%val12 = OpSNegate %s32matA %s32mat_A_1
|
||||||
|
%val13 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f16_1
|
||||||
|
%val14 = OpMatrixTimesScalar %u32matA %u32mat_A_1 %u32_1
|
||||||
|
%val15 = OpMatrixTimesScalar %s32matA %s32mat_A_1 %s32_1
|
||||||
|
%val16 = OpCooperativeMatrixMulAddKHR %f32matC %f16mat_A_1 %f16mat_B_1 %f16mat_C_1
|
||||||
|
%val17 = OpCooperativeMatrixMulAddKHR %s32matC %s32mat_A_1 %s32mat_B_1 %s32mat_C_1)";
|
||||||
|
|
||||||
|
CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
|
||||||
|
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateArithmetics, CoopMatMatrixKHRTimesScalarMismatchFail) {
|
||||||
|
const std::string body = R"(
|
||||||
|
%val1 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f32_1
|
||||||
|
)";
|
||||||
|
|
||||||
|
CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr("Expected scalar operand type to be equal to the component "
|
||||||
|
"type of the matrix operand: MatrixTimesScalar"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateArithmetics, CoopMatKHRScopeFail) {
|
||||||
|
const std::string types = R"(
|
||||||
|
%workgroup = OpConstant %u32 2
|
||||||
|
%mat16x16_wg = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %useC
|
||||||
|
%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
|
||||||
|
)";
|
||||||
|
|
||||||
|
const std::string body = R"(
|
||||||
|
%val1 = OpFAdd %f16matA %f16matwg_16x16_1 %f16mat_A_1
|
||||||
|
)";
|
||||||
|
|
||||||
|
CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr("Expected scopes of Matrix and Result Type to be identical"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateArithmetics, CoopMatKHRDimFail) {
|
||||||
|
const std::string types = R"(
|
||||||
|
%mat16x4 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_4 %useC
|
||||||
|
%mat16x4_C_1 = OpConstantComposite %mat16x4 %f16_1
|
||||||
|
)";
|
||||||
|
|
||||||
|
const std::string body = R"(
|
||||||
|
%val1 = OpCooperativeMatrixMulAddKHR %mat16x4 %f16mat_A_1 %f16mat_B_1 %mat16x4_C_1
|
||||||
|
)";
|
||||||
|
|
||||||
|
CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr("Cooperative matrix 'N' mismatch: CooperativeMatrixMulAddKHR"));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace val
|
} // namespace val
|
||||||
} // namespace spvtools
|
} // namespace spvtools
|
||||||
|
@ -1486,8 +1486,7 @@ OpFunctionEnd
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) {
|
TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) {
|
||||||
const std::string body =
|
const std::string body = R"(
|
||||||
R"(
|
|
||||||
OpCapability Shader
|
OpCapability Shader
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability CooperativeMatrixNV
|
OpCapability CooperativeMatrixNV
|
||||||
@ -1525,8 +1524,7 @@ OpFunctionEnd)";
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) {
|
TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) {
|
||||||
const std::string body =
|
const std::string body = R"(
|
||||||
R"(
|
|
||||||
OpCapability Shader
|
OpCapability Shader
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability CooperativeMatrixNV
|
OpCapability CooperativeMatrixNV
|
||||||
@ -1562,6 +1560,86 @@ OpFunctionEnd)";
|
|||||||
HasSubstr("Expected Constituent type to be equal to the component type"));
|
HasSubstr("Expected Constituent type to be equal to the component type"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateComposites, CoopMatKHRConstantCompositeMismatchFail) {
|
||||||
|
const std::string body = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%bool = OpTypeBool
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%f32 = OpTypeFloat 32
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
|
||||||
|
%u32_16 = OpConstant %u32 16
|
||||||
|
%useA = OpConstant %u32 0
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
|
||||||
|
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
|
||||||
|
|
||||||
|
%f32_1 = OpConstant %f32 1
|
||||||
|
|
||||||
|
%f16mat_1 = OpConstantComposite %f16mat %f32_1
|
||||||
|
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel
|
||||||
|
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
CompileSuccessfully(body.c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr(
|
||||||
|
"OpConstantComposite Constituent <id> '12[%float_1]' type "
|
||||||
|
"does not match the Result Type <id> '11[%11]'s component type."));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateComposites, CoopMatKHRCompositeConstructMismatchFail) {
|
||||||
|
const std::string body = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%bool = OpTypeBool
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%f32 = OpTypeFloat 32
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
|
||||||
|
%u32_16 = OpConstant %u32 16
|
||||||
|
%useA = OpConstant %u32 0
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
|
||||||
|
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
|
||||||
|
|
||||||
|
%f32_1 = OpConstant %f32 1
|
||||||
|
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel
|
||||||
|
|
||||||
|
%f16mat_1 = OpCompositeConstruct %f16mat %f32_1
|
||||||
|
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
CompileSuccessfully(body.c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr("Expected Constituent type to be equal to the component type"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ValidateComposites, ExtractDynamicLabelIndex) {
|
TEST_F(ValidateComposites, ExtractDynamicLabelIndex) {
|
||||||
const std::string spirv = R"(
|
const std::string spirv = R"(
|
||||||
OpCapability Shader
|
OpCapability Shader
|
||||||
|
@ -1149,8 +1149,7 @@ OpFunctionEnd)";
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) {
|
TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) {
|
||||||
const std::string body =
|
const std::string body = R"(
|
||||||
R"(
|
|
||||||
OpCapability Shader
|
OpCapability Shader
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability Int16
|
OpCapability Int16
|
||||||
@ -1191,6 +1190,179 @@ OpFunctionEnd)";
|
|||||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateConversion, CoopMatKHRConversionSuccess) {
|
||||||
|
const std::string body = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Int16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%bool = OpTypeBool
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%f32 = OpTypeFloat 32
|
||||||
|
%u16 = OpTypeInt 16 0
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
%s16 = OpTypeInt 16 1
|
||||||
|
%s32 = OpTypeInt 32 1
|
||||||
|
|
||||||
|
%u32_8 = OpConstant %u32 8
|
||||||
|
%use_A = OpConstant %u32 0
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
|
||||||
|
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
%u16mat = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
%u32mat = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
%s16mat = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
%s32mat = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
|
||||||
|
%f16_1 = OpConstant %f16 1
|
||||||
|
%f32_1 = OpConstant %f32 1
|
||||||
|
%u16_1 = OpConstant %u16 1
|
||||||
|
%u32_1 = OpConstant %u32 1
|
||||||
|
%s16_1 = OpConstant %s16 1
|
||||||
|
%s32_1 = OpConstant %s32 1
|
||||||
|
|
||||||
|
%f16mat_1 = OpConstantComposite %f16mat %f16_1
|
||||||
|
%f32mat_1 = OpConstantComposite %f32mat %f32_1
|
||||||
|
%u16mat_1 = OpConstantComposite %u16mat %u16_1
|
||||||
|
%u32mat_1 = OpConstantComposite %u32mat %u32_1
|
||||||
|
%s16mat_1 = OpConstantComposite %s16mat %s16_1
|
||||||
|
%s32mat_1 = OpConstantComposite %s32mat %s32_1
|
||||||
|
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel
|
||||||
|
|
||||||
|
%val11 = OpConvertFToU %u16mat %f16mat_1
|
||||||
|
%val12 = OpConvertFToU %u32mat %f16mat_1
|
||||||
|
%val13 = OpConvertFToS %s16mat %f16mat_1
|
||||||
|
%val14 = OpConvertFToS %s32mat %f16mat_1
|
||||||
|
%val15 = OpFConvert %f32mat %f16mat_1
|
||||||
|
|
||||||
|
%val21 = OpConvertFToU %u16mat %f32mat_1
|
||||||
|
%val22 = OpConvertFToU %u32mat %f32mat_1
|
||||||
|
%val23 = OpConvertFToS %s16mat %f32mat_1
|
||||||
|
%val24 = OpConvertFToS %s32mat %f32mat_1
|
||||||
|
%val25 = OpFConvert %f16mat %f32mat_1
|
||||||
|
|
||||||
|
%val31 = OpConvertUToF %f16mat %u16mat_1
|
||||||
|
%val32 = OpConvertUToF %f32mat %u16mat_1
|
||||||
|
%val33 = OpUConvert %u32mat %u16mat_1
|
||||||
|
%val34 = OpSConvert %s32mat %u16mat_1
|
||||||
|
|
||||||
|
%val41 = OpConvertSToF %f16mat %s16mat_1
|
||||||
|
%val42 = OpConvertSToF %f32mat %s16mat_1
|
||||||
|
%val43 = OpUConvert %u32mat %s16mat_1
|
||||||
|
%val44 = OpSConvert %s32mat %s16mat_1
|
||||||
|
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
CompileSuccessfully(body.c_str());
|
||||||
|
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateConversion, CoopMatKHRConversionUseMismatchFail) {
|
||||||
|
const std::string body = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Int16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%bool = OpTypeBool
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%f32 = OpTypeFloat 32
|
||||||
|
%u16 = OpTypeInt 16 0
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
%s16 = OpTypeInt 16 1
|
||||||
|
%s32 = OpTypeInt 32 1
|
||||||
|
|
||||||
|
%u32_8 = OpConstant %u32 8
|
||||||
|
%u32_4 = OpConstant %u32 4
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
%use_A = OpConstant %u32 0
|
||||||
|
%use_B = OpConstant %u32 1
|
||||||
|
|
||||||
|
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_B
|
||||||
|
|
||||||
|
%f16_1 = OpConstant %f16 1
|
||||||
|
|
||||||
|
%f16mat_1 = OpConstantComposite %f16mat %f16_1
|
||||||
|
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel
|
||||||
|
|
||||||
|
%val1 = OpFConvert %f32mat %f16mat_1
|
||||||
|
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
CompileSuccessfully(body.c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr("Expected Use of Matrix type and Result Type to be identical"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateConversion, CoopMatKHRConversionScopeMismatchFail) {
|
||||||
|
const std::string body = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Int16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%bool = OpTypeBool
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%f32 = OpTypeFloat 32
|
||||||
|
%u16 = OpTypeInt 16 0
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
%s16 = OpTypeInt 16 1
|
||||||
|
%s32 = OpTypeInt 32 1
|
||||||
|
|
||||||
|
%u32_8 = OpConstant %u32 8
|
||||||
|
%u32_4 = OpConstant %u32 4
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
%workgroup = OpConstant %u32 2
|
||||||
|
%use_A = OpConstant %u32 0
|
||||||
|
|
||||||
|
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
%f32mat = OpTypeCooperativeMatrixKHR %f32 %workgroup %u32_8 %u32_8 %use_A
|
||||||
|
|
||||||
|
%f16_1 = OpConstant %f16 1
|
||||||
|
|
||||||
|
%f16mat_1 = OpConstantComposite %f16mat %f16_1
|
||||||
|
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel
|
||||||
|
|
||||||
|
%val1 = OpFConvert %f32mat %f16mat_1
|
||||||
|
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
CompileSuccessfully(body.c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr("Expected scopes of Matrix and Result Type to be identical"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ValidateConversion, BitcastSuccess) {
|
TEST_F(ValidateConversion, BitcastSuccess) {
|
||||||
const std::string body = R"(
|
const std::string body = R"(
|
||||||
%ptr = OpVariable %f32ptr_func Function
|
%ptr = OpVariable %f32ptr_func Function
|
||||||
|
@ -23,12 +23,14 @@
|
|||||||
#include "test/val/val_fixtures.h"
|
#include "test/val/val_fixtures.h"
|
||||||
|
|
||||||
// For pretty-printing tuples with spv_target_env.
|
// For pretty-printing tuples with spv_target_env.
|
||||||
std::ostream& operator<<(std::ostream& stream, spv_target_env target)
|
std::ostream& operator<<(std::ostream& stream, spv_target_env target) {
|
||||||
{
|
|
||||||
switch (target) {
|
switch (target) {
|
||||||
case SPV_ENV_UNIVERSAL_1_3: return stream << "SPV_ENV_UNIVERSAL_1_3";
|
case SPV_ENV_UNIVERSAL_1_3:
|
||||||
case SPV_ENV_UNIVERSAL_1_4: return stream << "SPV_ENV_UNIVERSAL_1_4";
|
return stream << "SPV_ENV_UNIVERSAL_1_3";
|
||||||
default: return stream << (unsigned)target;
|
case SPV_ENV_UNIVERSAL_1_4:
|
||||||
|
return stream << "SPV_ENV_UNIVERSAL_1_4";
|
||||||
|
default:
|
||||||
|
return stream << (unsigned)target;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2346,6 +2348,186 @@ OpFunctionEnd)";
|
|||||||
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
|
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateMemory, CoopMatKHRLoadStoreSuccess) {
|
||||||
|
std::string spirv =
|
||||||
|
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
|
||||||
|
"MakePointerVisibleKHR|NonPrivatePointerKHR");
|
||||||
|
|
||||||
|
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
|
||||||
|
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) {
|
||||||
|
std::string spirv =
|
||||||
|
GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR",
|
||||||
|
"MakePointerVisibleKHR|NonPrivatePointerKHR");
|
||||||
|
|
||||||
|
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
|
||||||
|
EXPECT_THAT(getDiagnosticString(),
|
||||||
|
HasSubstr("MakePointerVisibleKHR cannot be used with OpStore"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateMemory, CoopMatKHRLoadMemoryAccessFail) {
|
||||||
|
std::string spirv =
|
||||||
|
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
|
||||||
|
"MakePointerAvailableKHR|NonPrivatePointerKHR");
|
||||||
|
|
||||||
|
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
|
||||||
|
EXPECT_THAT(getDiagnosticString(),
|
||||||
|
HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateMemory, CoopMatKHRInvalidStorageClassFail) {
|
||||||
|
const std::string body = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
|
||||||
|
%u32_8 = OpConstant %u32 8
|
||||||
|
%use_A = OpConstant %u32 0
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
|
||||||
|
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
|
||||||
|
%str = OpTypeStruct %f16mat
|
||||||
|
%str_ptr = OpTypePointer Workgroup %str
|
||||||
|
%sh = OpVariable %str_ptr Workgroup
|
||||||
|
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel
|
||||||
|
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
CompileSuccessfully(body.c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr(
|
||||||
|
"Cooperative matrix types (or types containing them) can only be "
|
||||||
|
"allocated in Function or Private storage classes or as function "
|
||||||
|
"parameters"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateMemory, CoopMatMatrixKHRLengthResultTypeBad) {
|
||||||
|
const std::string body = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
%i32 = OpTypeInt 32 1
|
||||||
|
|
||||||
|
%u32_8 = OpConstant %u32 8
|
||||||
|
%use_A = OpConstant %u32 0
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
|
||||||
|
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel
|
||||||
|
|
||||||
|
%1 = OpCooperativeMatrixLengthKHR %i32 %f16mat
|
||||||
|
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
CompileSuccessfully(body.c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr("The Result Type of OpCooperativeMatrixLengthKHR <id> "
|
||||||
|
"'12[%12]' must be OpTypeInt with width 32 and signedness 0"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateMemory, CoopMatMatrixKHRLengthOperandTypeBad) {
|
||||||
|
const std::string body =
|
||||||
|
R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
%i32 = OpTypeInt 32 1
|
||||||
|
|
||||||
|
%u32_8 = OpConstant %u32 8
|
||||||
|
%use_A = OpConstant %u32 0
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
|
||||||
|
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel
|
||||||
|
|
||||||
|
%1 = OpCooperativeMatrixLengthKHR %u32 %u32
|
||||||
|
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
CompileSuccessfully(body.c_str());
|
||||||
|
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||||
|
EXPECT_THAT(
|
||||||
|
getDiagnosticString(),
|
||||||
|
HasSubstr("The type in OpCooperativeMatrixLengthKHR <id> '5[%uint]' "
|
||||||
|
"must be OpTypeCooperativeMatrixKHR"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ValidateMemory, CoopMatMatrixKHRLengthGood) {
|
||||||
|
const std::string body =
|
||||||
|
R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability CooperativeMatrixKHR
|
||||||
|
OpExtension "SPV_KHR_cooperative_matrix"
|
||||||
|
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%func = OpTypeFunction %void
|
||||||
|
%f16 = OpTypeFloat 16
|
||||||
|
%u32 = OpTypeInt 32 0
|
||||||
|
%i32 = OpTypeInt 32 1
|
||||||
|
|
||||||
|
%u32_8 = OpConstant %u32 8
|
||||||
|
%use_A = OpConstant %u32 0
|
||||||
|
%subgroup = OpConstant %u32 3
|
||||||
|
|
||||||
|
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
|
||||||
|
|
||||||
|
%main = OpFunction %void None %func
|
||||||
|
%main_entry = OpLabel
|
||||||
|
|
||||||
|
%1 = OpCooperativeMatrixLengthKHR %u32 %f16mat
|
||||||
|
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd)";
|
||||||
|
|
||||||
|
CompileSuccessfully(body.c_str());
|
||||||
|
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) {
|
TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) {
|
||||||
std::string spirv = R"(
|
std::string spirv = R"(
|
||||||
OpCapability Shader
|
OpCapability Shader
|
||||||
@ -3765,9 +3947,8 @@ OpFunctionEnd
|
|||||||
HasSubstr("In the Vulkan environment, cannot store to Uniform Blocks"));
|
HasSubstr("In the Vulkan environment, cannot store to Uniform Blocks"));
|
||||||
}
|
}
|
||||||
|
|
||||||
using ValidateSizedVariable =
|
using ValidateSizedVariable = spvtest::ValidateBase<
|
||||||
spvtest::ValidateBase<std::tuple<std::string, std::string,
|
std::tuple<std::string, std::string, std::string, spv_target_env>>;
|
||||||
std::string, spv_target_env>>;
|
|
||||||
|
|
||||||
CodeGenerator GetSizedVariableCodeGenerator(bool is_8bit, bool buffer_block) {
|
CodeGenerator GetSizedVariableCodeGenerator(bool is_8bit, bool buffer_block) {
|
||||||
CodeGenerator generator;
|
CodeGenerator generator;
|
||||||
@ -3777,7 +3958,8 @@ CodeGenerator GetSizedVariableCodeGenerator(bool is_8bit, bool buffer_block) {
|
|||||||
"\"SPV_KHR_8bit_storage\"\n";
|
"\"SPV_KHR_8bit_storage\"\n";
|
||||||
generator.memory_model_ = "OpMemoryModel Logical GLSL450\n";
|
generator.memory_model_ = "OpMemoryModel Logical GLSL450\n";
|
||||||
if (is_8bit) {
|
if (is_8bit) {
|
||||||
generator.before_types_ = "OpMemberDecorate %char_buffer_block 0 Offset 0\n";
|
generator.before_types_ =
|
||||||
|
"OpMemberDecorate %char_buffer_block 0 Offset 0\n";
|
||||||
if (buffer_block)
|
if (buffer_block)
|
||||||
generator.before_types_ += "OpDecorate %char_buffer_block BufferBlock\n";
|
generator.before_types_ += "OpDecorate %char_buffer_block BufferBlock\n";
|
||||||
|
|
||||||
|
@ -540,7 +540,7 @@ def generate_operand_kind_table(enums):
|
|||||||
|
|
||||||
# We have a few operand kinds that require their optional counterpart to
|
# We have a few operand kinds that require their optional counterpart to
|
||||||
# exist in the operand info table.
|
# exist in the operand info table.
|
||||||
optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat']
|
optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat', 'CooperativeMatrixOperands']
|
||||||
optional_enums = [e for e in enums if e[0] in optional_enums]
|
optional_enums = [e for e in enums if e[0] in optional_enums]
|
||||||
enums.extend(optional_enums)
|
enums.extend(optional_enums)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user