From 3c7ff8d4f0a1c0f27328871fe64879170a4f0930 Mon Sep 17 00:00:00 2001 From: Jeremy Hayes Date: Mon, 7 Oct 2019 13:52:48 +0000 Subject: [PATCH] Enable OpTypeCooperativeMatrix specialization (#2927) --- source/opt/type_manager.cpp | 30 +++++++++++++++++++++++++ source/opt/types.cpp | 41 ++++++++++++++++++++++++++++++++++ source/opt/types.h | 33 +++++++++++++++++++++++++++ test/opt/type_manager_test.cpp | 7 +++++- 4 files changed, 110 insertions(+), 1 deletion(-) diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp index d34948194..166b8281f 100644 --- a/source/opt/type_manager.cpp +++ b/source/opt/type_manager.cpp @@ -409,6 +409,22 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { {static_cast( type->AsForwardPointer()->storage_class())}}}); break; + case Type::kCooperativeMatrixNV: { + auto coop_mat = type->AsCooperativeMatrixNV(); + uint32_t const component_type = + GetTypeInstruction(coop_mat->component_type()); + if (component_type == 0) { + return 0; + } + typeInst = MakeUnique( + context(), SpvOpTypeCooperativeMatrixNV, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {component_type}}, + {SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}}, + {SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}}, + {SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}}); + break; + } default: assert(false && "Unexpected type"); break; @@ -604,6 +620,14 @@ Type* TypeManager::RebuildType(const Type& type) { } break; } + case Type::kCooperativeMatrixNV: { + const CooperativeMatrixNV* cm_type = type.AsCooperativeMatrixNV(); + const Type* component_type = cm_type->component_type(); + rebuilt_ty = MakeUnique( + RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(), + cm_type->columns_id()); + break; + } default: assert(false && "Unhandled type"); return nullptr; @@ -832,6 +856,12 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) { case SpvOpTypeAccelerationStructureNV: type = new AccelerationStructureNV(); break; + case SpvOpTypeCooperativeMatrixNV: + type = new CooperativeMatrixNV(GetType(inst.GetSingleWordInOperand(0)), + inst.GetSingleWordInOperand(1), + inst.GetSingleWordInOperand(2), + inst.GetSingleWordInOperand(3)); + break; default: SPIRV_UNIMPLEMENTED(consumer_, "unhandled type"); break; diff --git a/source/opt/types.cpp b/source/opt/types.cpp index 4f7150fc3..17f8fe920 100644 --- a/source/opt/types.cpp +++ b/source/opt/types.cpp @@ -127,6 +127,7 @@ std::unique_ptr Type::Clone() const { DeclareKindCase(PipeStorage); DeclareKindCase(NamedBarrier); DeclareKindCase(AccelerationStructureNV); + DeclareKindCase(CooperativeMatrixNV); #undef DeclareKindCase default: assert(false && "Unhandled type"); @@ -171,6 +172,7 @@ bool Type::operator==(const Type& other) const { DeclareKindCase(PipeStorage); DeclareKindCase(NamedBarrier); DeclareKindCase(AccelerationStructureNV); + DeclareKindCase(CooperativeMatrixNV); #undef DeclareKindCase default: assert(false && "Unhandled type"); @@ -220,6 +222,7 @@ void Type::GetHashWords(std::vector* words, DeclareKindCase(PipeStorage); DeclareKindCase(NamedBarrier); DeclareKindCase(AccelerationStructureNV); + DeclareKindCase(CooperativeMatrixNV); #undef DeclareKindCase default: assert(false && "Unhandled type"); @@ -654,6 +657,44 @@ void ForwardPointer::GetExtraHashWords( if (pointer_) pointer_->GetHashWords(words, seen); } +CooperativeMatrixNV::CooperativeMatrixNV(const Type* type, const uint32_t scope, + const uint32_t rows, + const uint32_t columns) + : Type(kCooperativeMatrixNV), + component_type_(type), + scope_id_(scope), + rows_id_(rows), + columns_id_(columns) { + assert(type != nullptr); + assert(scope != 0); + assert(rows != 0); + assert(columns != 0); +} + +std::string CooperativeMatrixNV::str() const { + std::ostringstream oss; + oss << "<" << component_type_->str() << ", " << scope_id_ << ", " << rows_id_ + << ", " << columns_id_ << ">"; + return oss.str(); +} + +void CooperativeMatrixNV::GetExtraHashWords( + std::vector* words, std::unordered_set* pSet) const { + component_type_->GetHashWords(words, pSet); + words->push_back(scope_id_); + words->push_back(rows_id_); + words->push_back(columns_id_); +} + +bool CooperativeMatrixNV::IsSameImpl(const Type* that, + IsSameCache* seen) const { + const CooperativeMatrixNV* mt = that->AsCooperativeMatrixNV(); + if (!mt) return false; + return component_type_->IsSameImpl(mt->component_type_, seen) && + scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ && + columns_id_ == mt->columns_id_ && HasSameDecorations(that); +} + } // namespace analysis } // namespace opt } // namespace spvtools diff --git a/source/opt/types.h b/source/opt/types.h index 57920df96..69071ea17 100644 --- a/source/opt/types.h +++ b/source/opt/types.h @@ -58,6 +58,7 @@ class ForwardPointer; class PipeStorage; class NamedBarrier; class AccelerationStructureNV; +class CooperativeMatrixNV; // Abstract class for a SPIR-V type. It has a bunch of As() methods, // which is used as a way to probe the actual . @@ -93,6 +94,7 @@ class Type { kPipeStorage, kNamedBarrier, kAccelerationStructureNV, + kCooperativeMatrixNV }; Type(Kind k) : kind_(k) {} @@ -196,6 +198,7 @@ class Type { DeclareCastMethod(PipeStorage) DeclareCastMethod(NamedBarrier) DeclareCastMethod(AccelerationStructureNV) + DeclareCastMethod(CooperativeMatrixNV) #undef DeclareCastMethod protected: @@ -597,6 +600,36 @@ class ForwardPointer : public Type { const Pointer* pointer_; }; +class CooperativeMatrixNV : public Type { + public: + CooperativeMatrixNV(const Type* type, const uint32_t scope, + const uint32_t rows, const uint32_t columns); + CooperativeMatrixNV(const CooperativeMatrixNV&) = default; + + std::string str() const override; + + CooperativeMatrixNV* AsCooperativeMatrixNV() override { return this; } + const CooperativeMatrixNV* AsCooperativeMatrixNV() const override { + return this; + } + + void GetExtraHashWords(std::vector*, + std::unordered_set*) const override; + + const Type* component_type() const { return component_type_; } + uint32_t scope_id() const { return scope_id_; } + uint32_t rows_id() const { return rows_id_; } + uint32_t columns_id() const { return columns_id_; } + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* component_type_; + const uint32_t scope_id_; + const uint32_t rows_id_; + const uint32_t columns_id_; +}; + #define DefineParameterlessType(type, name) \ class type : public Type { \ public: \ diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp index 267d98c3f..743d0b616 100644 --- a/test/opt/type_manager_test.cpp +++ b/test/opt/type_manager_test.cpp @@ -156,7 +156,8 @@ std::vector> GenerateAllTypes() { types.emplace_back(new ReserveId()); types.emplace_back(new Queue()); - // Pipe, Forward Pointer, PipeStorage, NamedBarrier, AccelerationStructureNV + // Pipe, Forward Pointer, PipeStorage, NamedBarrier, AccelerationStructureNV, + // CooperativeMatrixNV types.emplace_back(new Pipe(SpvAccessQualifierReadWrite)); types.emplace_back(new Pipe(SpvAccessQualifierReadOnly)); types.emplace_back(new ForwardPointer(1, SpvStorageClassInput)); @@ -165,6 +166,7 @@ std::vector> GenerateAllTypes() { types.emplace_back(new PipeStorage()); types.emplace_back(new NamedBarrier()); types.emplace_back(new AccelerationStructureNV()); + types.emplace_back(new CooperativeMatrixNV(f32, 24, 24, 24)); return types; } @@ -214,6 +216,7 @@ TEST(TypeManager, TypeStrings) { %arr_spec_const_with_id = OpTypeArray %s32 %spec_const_with_id %arr_long_constant = OpTypeArray %s32 %long_constant %arr_spec_const_op = OpTypeArray %s32 %spec_const_op + %cm = OpTypeCooperativeMatrixNV %f64 %id4 %id4 %id4 )"; std::vector> type_id_strs = { @@ -251,6 +254,7 @@ TEST(TypeManager, TypeStrings) { {36, "[sint32, id(1), words(1,99,42)]"}, {37, "[sint32, id(33), words(0,705032704,1)]"}, {38, "[sint32, id(34), words(2,34)]"}, + {39, ""}, }; std::unique_ptr context = @@ -1060,6 +1064,7 @@ TEST(TypeManager, GetTypeInstructionAllTypes) { ; CHECK: OpTypePipeStorage ; CHECK: OpTypeNamedBarrier ; CHECK: OpTypeAccelerationStructureNV +; CHECK: OpTypeCooperativeMatrixNV [[f32]] [[uint24]] [[uint24]] [[uint24]] OpCapability Shader OpCapability Int64 OpCapability Linkage