Enable OpTypeCooperativeMatrix specialization (#2927)

This commit is contained in:
Jeremy Hayes 2019-10-07 13:52:48 +00:00 committed by Steven Perron
parent c18c9ff6bc
commit 3c7ff8d4f0
4 changed files with 110 additions and 1 deletions

View File

@ -409,6 +409,22 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
{static_cast<uint32_t>(
type->AsForwardPointer()->storage_class())}}});
break;
case Type::kCooperativeMatrixNV: {
auto coop_mat = type->AsCooperativeMatrixNV();
uint32_t const component_type =
GetTypeInstruction(coop_mat->component_type());
if (component_type == 0) {
return 0;
}
typeInst = MakeUnique<Instruction>(
context(), SpvOpTypeCooperativeMatrixNV, 0, id,
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_ID, {component_type}},
{SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}},
{SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}},
{SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}});
break;
}
default:
assert(false && "Unexpected type");
break;
@ -604,6 +620,14 @@ Type* TypeManager::RebuildType(const Type& type) {
}
break;
}
case Type::kCooperativeMatrixNV: {
const CooperativeMatrixNV* cm_type = type.AsCooperativeMatrixNV();
const Type* component_type = cm_type->component_type();
rebuilt_ty = MakeUnique<CooperativeMatrixNV>(
RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
cm_type->columns_id());
break;
}
default:
assert(false && "Unhandled type");
return nullptr;
@ -832,6 +856,12 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
case SpvOpTypeAccelerationStructureNV:
type = new AccelerationStructureNV();
break;
case SpvOpTypeCooperativeMatrixNV:
type = new CooperativeMatrixNV(GetType(inst.GetSingleWordInOperand(0)),
inst.GetSingleWordInOperand(1),
inst.GetSingleWordInOperand(2),
inst.GetSingleWordInOperand(3));
break;
default:
SPIRV_UNIMPLEMENTED(consumer_, "unhandled type");
break;

View File

@ -127,6 +127,7 @@ std::unique_ptr<Type> Type::Clone() const {
DeclareKindCase(PipeStorage);
DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
#undef DeclareKindCase
default:
assert(false && "Unhandled type");
@ -171,6 +172,7 @@ bool Type::operator==(const Type& other) const {
DeclareKindCase(PipeStorage);
DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
#undef DeclareKindCase
default:
assert(false && "Unhandled type");
@ -220,6 +222,7 @@ void Type::GetHashWords(std::vector<uint32_t>* words,
DeclareKindCase(PipeStorage);
DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
#undef DeclareKindCase
default:
assert(false && "Unhandled type");
@ -654,6 +657,44 @@ void ForwardPointer::GetExtraHashWords(
if (pointer_) pointer_->GetHashWords(words, seen);
}
CooperativeMatrixNV::CooperativeMatrixNV(const Type* type, const uint32_t scope,
const uint32_t rows,
const uint32_t columns)
: Type(kCooperativeMatrixNV),
component_type_(type),
scope_id_(scope),
rows_id_(rows),
columns_id_(columns) {
assert(type != nullptr);
assert(scope != 0);
assert(rows != 0);
assert(columns != 0);
}
std::string CooperativeMatrixNV::str() const {
std::ostringstream oss;
oss << "<" << component_type_->str() << ", " << scope_id_ << ", " << rows_id_
<< ", " << columns_id_ << ">";
return oss.str();
}
void CooperativeMatrixNV::GetExtraHashWords(
std::vector<uint32_t>* words, std::unordered_set<const Type*>* pSet) const {
component_type_->GetHashWords(words, pSet);
words->push_back(scope_id_);
words->push_back(rows_id_);
words->push_back(columns_id_);
}
bool CooperativeMatrixNV::IsSameImpl(const Type* that,
IsSameCache* seen) const {
const CooperativeMatrixNV* mt = that->AsCooperativeMatrixNV();
if (!mt) return false;
return component_type_->IsSameImpl(mt->component_type_, seen) &&
scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
}
} // namespace analysis
} // namespace opt
} // namespace spvtools

View File

@ -58,6 +58,7 @@ class ForwardPointer;
class PipeStorage;
class NamedBarrier;
class AccelerationStructureNV;
class CooperativeMatrixNV;
// Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
// which is used as a way to probe the actual <subclass>.
@ -93,6 +94,7 @@ class Type {
kPipeStorage,
kNamedBarrier,
kAccelerationStructureNV,
kCooperativeMatrixNV
};
Type(Kind k) : kind_(k) {}
@ -196,6 +198,7 @@ class Type {
DeclareCastMethod(PipeStorage)
DeclareCastMethod(NamedBarrier)
DeclareCastMethod(AccelerationStructureNV)
DeclareCastMethod(CooperativeMatrixNV)
#undef DeclareCastMethod
protected:
@ -597,6 +600,36 @@ class ForwardPointer : public Type {
const Pointer* pointer_;
};
class CooperativeMatrixNV : public Type {
public:
CooperativeMatrixNV(const Type* type, const uint32_t scope,
const uint32_t rows, const uint32_t columns);
CooperativeMatrixNV(const CooperativeMatrixNV&) = default;
std::string str() const override;
CooperativeMatrixNV* AsCooperativeMatrixNV() override { return this; }
const CooperativeMatrixNV* AsCooperativeMatrixNV() const override {
return this;
}
void GetExtraHashWords(std::vector<uint32_t>*,
std::unordered_set<const Type*>*) const override;
const Type* component_type() const { return component_type_; }
uint32_t scope_id() const { return scope_id_; }
uint32_t rows_id() const { return rows_id_; }
uint32_t columns_id() const { return columns_id_; }
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const Type* component_type_;
const uint32_t scope_id_;
const uint32_t rows_id_;
const uint32_t columns_id_;
};
#define DefineParameterlessType(type, name) \
class type : public Type { \
public: \

View File

@ -156,7 +156,8 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
types.emplace_back(new ReserveId());
types.emplace_back(new Queue());
// Pipe, Forward Pointer, PipeStorage, NamedBarrier, AccelerationStructureNV
// Pipe, Forward Pointer, PipeStorage, NamedBarrier, AccelerationStructureNV,
// CooperativeMatrixNV
types.emplace_back(new Pipe(SpvAccessQualifierReadWrite));
types.emplace_back(new Pipe(SpvAccessQualifierReadOnly));
types.emplace_back(new ForwardPointer(1, SpvStorageClassInput));
@ -165,6 +166,7 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
types.emplace_back(new PipeStorage());
types.emplace_back(new NamedBarrier());
types.emplace_back(new AccelerationStructureNV());
types.emplace_back(new CooperativeMatrixNV(f32, 24, 24, 24));
return types;
}
@ -214,6 +216,7 @@ TEST(TypeManager, TypeStrings) {
%arr_spec_const_with_id = OpTypeArray %s32 %spec_const_with_id
%arr_long_constant = OpTypeArray %s32 %long_constant
%arr_spec_const_op = OpTypeArray %s32 %spec_const_op
%cm = OpTypeCooperativeMatrixNV %f64 %id4 %id4 %id4
)";
std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
@ -251,6 +254,7 @@ TEST(TypeManager, TypeStrings) {
{36, "[sint32, id(1), words(1,99,42)]"},
{37, "[sint32, id(33), words(0,705032704,1)]"},
{38, "[sint32, id(34), words(2,34)]"},
{39, "<float64, 6, 6, 6>"},
};
std::unique_ptr<IRContext> context =
@ -1060,6 +1064,7 @@ TEST(TypeManager, GetTypeInstructionAllTypes) {
; CHECK: OpTypePipeStorage
; CHECK: OpTypeNamedBarrier
; CHECK: OpTypeAccelerationStructureNV
; CHECK: OpTypeCooperativeMatrixNV [[f32]] [[uint24]] [[uint24]] [[uint24]]
OpCapability Shader
OpCapability Int64
OpCapability Linkage