Add validation for SPV_NV_tensor_addressing and SPV_NV_cooperative_matrix2 (#5865)

This commit is contained in:
Jeff Bolz 2024-10-24 12:55:04 -05:00 committed by GitHub
parent 298055b25c
commit ce92630396
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 2247 additions and 38 deletions

View File

@ -75,6 +75,7 @@ SPVTOOLS_SRC_FILES := \
source/val/validate_ray_tracing_reorder.cpp \
source/val/validate_scopes.cpp \
source/val/validate_small_type_uses.cpp \
source/val/validate_tensor_layout.cpp \
source/val/validate_type.cpp
SPVTOOLS_OPT_SRC_FILES := \

View File

@ -557,6 +557,7 @@ static_library("spvtools_val") {
"source/val/validate_scopes.cpp",
"source/val/validate_scopes.h",
"source/val/validate_small_type_uses.cpp",
"source/val/validate_tensor_layout.cpp",
"source/val/validate_type.cpp",
"source/val/validation_state.cpp",
"source/val/validation_state.h",

2
DEPS
View File

@ -14,7 +14,7 @@ vars = {
're2_revision': '6dcd83d60f7944926bfd308cc13979fc53dd69ca',
'spirv_headers_revision': '50bc4debdc3eec5045edbeb8ce164090e29b91f3',
'spirv_headers_revision': '22c4d1b1e9d1c7d9aa5086c93e6491f21080019b',
}
deps = {

View File

@ -314,6 +314,12 @@ typedef enum spv_operand_type_t {
SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS,
// Optional enum type from SPV_NV_raw_access_chains
SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS,
// Enum type from SPV_NV_tensor_addressing
SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE,
// Enum type from SPV_NV_cooperative_matrix2
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE,
// Enum type from SPV_NV_cooperative_matrix2
SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS,
// This is a sentinel value, and does not represent an operand type.
// It should come last.

View File

@ -334,6 +334,7 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ray_tracing_reorder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_scopes.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_tensor_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h
${CMAKE_CURRENT_SOURCE_DIR}/val/basic_block.cpp

View File

@ -717,13 +717,16 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
case SPV_OPERAND_TYPE_LOOP_CONTROL:
case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
case SPV_OPERAND_TYPE_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS: {
// This operand is a mask.
// Map an optional operand type to its corresponding concrete type.

View File

@ -382,6 +382,8 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
case spv::Op::OpTypeRayQueryKHR:
case spv::Op::OpTypeHitObjectNV:
case spv::Op::OpTypeUntypedPointerKHR:
case spv::Op::OpTypeTensorLayoutNV:
case spv::Op::OpTypeTensorViewNV:
return true;
default:
// In particular, OpTypeForwardPointer does not generate a type,

View File

@ -231,6 +231,12 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
return "cooperative matrix layout";
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
return "cooperative matrix use";
case SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE:
return "tensor clamp mode";
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
return "cooperative matrix reduce";
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
return "tensor addressing operands";
case SPV_OPERAND_TYPE_INITIALIZATION_MODE_QUALIFIER:
return "initialization mode qualifier";
case SPV_OPERAND_TYPE_HOST_ACCESS_QUALIFIER:
@ -389,6 +395,7 @@ bool spvOperandIsConcrete(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_STORE_CACHE_CONTROL:
case SPV_OPERAND_TYPE_NAMED_MAXIMUM_NUMBER_OF_REGISTERS:
case SPV_OPERAND_TYPE_FPENCODING:
case SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE:
return true;
default:
break;
@ -409,6 +416,8 @@ bool spvOperandIsConcreteMask(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
return true;
default:
break;
@ -598,6 +607,16 @@ std::function<bool(unsigned)> spvOperandCanBeForwardDeclaredFunction(
case spv::Op::OpTypeArray:
out = [](unsigned index) { return index == 1; };
break;
case spv::Op::OpCooperativeMatrixPerElementOpNV:
out = [](unsigned index) { return index == 3; };
break;
case spv::Op::OpCooperativeMatrixReduceNV:
out = [](unsigned index) { return index == 4; };
break;
case spv::Op::OpCooperativeMatrixLoadTensorNV:
// approximate, due to variable operands
out = [](unsigned index) { return index > 6; };
break;
default:
out = [](unsigned) { return false; };
break;

View File

@ -43,6 +43,7 @@ constexpr uint32_t kGlobalVariableVariableIndex = 12;
constexpr uint32_t kExtInstSetInIdx = 0;
constexpr uint32_t kExtInstOpInIdx = 1;
constexpr uint32_t kInterpolantInIdx = 2;
constexpr uint32_t kCooperativeMatrixLoadSourceAddrInIdx = 0;
// Sorting functor to present annotation instructions in an easy-to-process
// order. The functor orders by opcode first and falls back on unique id
@ -438,6 +439,11 @@ uint32_t AggressiveDCEPass::GetLoadedVariableFromNonFunctionCalls(
}
break;
}
case spv::Op::OpCooperativeMatrixLoadNV:
case spv::Op::OpCooperativeMatrixLoadKHR:
case spv::Op::OpCooperativeMatrixLoadTensorNV:
return GetVariableId(
inst->GetSingleWordInOperand(kCooperativeMatrixLoadSourceAddrInIdx));
default:
break;
}

View File

@ -70,6 +70,11 @@ void EliminateDeadMembersPass::FindLiveMembers() {
MarkPointeeTypeAsFullUsed(inst.type_id());
break;
}
} else if (inst.opcode() == spv::Op::OpTypePointer) {
uint32_t storage_class = inst.GetSingleWordInOperand(0);
if (storage_class == uint32_t(spv::StorageClass::PhysicalStorageBuffer)) {
MarkTypeAsFullyUsed(inst.GetSingleWordInOperand(1));
}
}
}
@ -200,6 +205,8 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForExtract(
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
@ -246,6 +253,8 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForAccessChain(
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
@ -505,6 +514,8 @@ bool EliminateDeadMembersPass::UpdateAccessChain(Instruction* inst) {
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
new_operands.emplace_back(inst->GetInOperand(i));
type_id = type_inst->GetSingleWordInOperand(0);
break;
@ -578,6 +589,8 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) {
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
@ -639,6 +652,8 @@ bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) {
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:

View File

@ -926,9 +926,35 @@ uint32_t IRContext::GetBuiltinInputVarId(uint32_t builtin) {
void IRContext::AddCalls(const Function* func, std::queue<uint32_t>* todo) {
for (auto bi = func->begin(); bi != func->end(); ++bi)
for (auto ii = bi->begin(); ii != bi->end(); ++ii)
for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
if (ii->opcode() == spv::Op::OpFunctionCall)
todo->push(ii->GetSingleWordInOperand(0));
if (ii->opcode() == spv::Op::OpCooperativeMatrixPerElementOpNV)
todo->push(ii->GetSingleWordInOperand(1));
if (ii->opcode() == spv::Op::OpCooperativeMatrixReduceNV)
todo->push(ii->GetSingleWordInOperand(2));
if (ii->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
const auto memory_operands_index = 3;
auto mask = ii->GetSingleWordInOperand(memory_operands_index);
uint32_t count = 1;
if (mask & uint32_t(spv::MemoryAccessMask::Aligned)) ++count;
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR))
++count;
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR))
++count;
const auto tensor_operands_index = memory_operands_index + count;
mask = ii->GetSingleWordInOperand(tensor_operands_index);
count = 1;
if (mask & uint32_t(spv::TensorAddressingOperandsMask::TensorView))
++count;
if (mask & uint32_t(spv::TensorAddressingOperandsMask::DecodeFunc)) {
todo->push(ii->GetSingleWordInOperand(tensor_operands_index + count));
}
}
}
}
bool IRContext::ProcessEntryPointCallTree(ProcessFunction& pfn) {

View File

@ -441,6 +441,28 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
{SPV_OPERAND_TYPE_ID, {coop_mat->use_id()}}});
break;
}
case Type::kTensorLayoutNV: {
auto tensor_layout = type->AsTensorLayoutNV();
typeInst = MakeUnique<Instruction>(
context(), spv::Op::OpTypeTensorLayoutNV, 0, id,
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_ID, {tensor_layout->dim_id()}},
{SPV_OPERAND_TYPE_ID, {tensor_layout->clamp_mode_id()}}});
break;
}
case Type::kTensorViewNV: {
auto tensor_view = type->AsTensorViewNV();
std::vector<Operand> operands;
operands.push_back(Operand{SPV_OPERAND_TYPE_ID, {tensor_view->dim_id()}});
operands.push_back(
Operand{SPV_OPERAND_TYPE_ID, {tensor_view->has_dimensions_id()}});
for (auto p : tensor_view->perm()) {
operands.push_back(Operand{SPV_OPERAND_TYPE_ID, {p}});
}
typeInst = MakeUnique<Instruction>(context(), spv::Op::OpTypeTensorViewNV,
0, id, operands);
break;
}
default:
assert(false && "Unexpected type");
break;
@ -667,6 +689,18 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
cm_type->use_id());
break;
}
case Type::kTensorLayoutNV: {
const TensorLayoutNV* tl_type = type.AsTensorLayoutNV();
rebuilt_ty = MakeUnique<TensorLayoutNV>(tl_type->dim_id(),
tl_type->clamp_mode_id());
break;
}
case Type::kTensorViewNV: {
const TensorViewNV* tv_type = type.AsTensorViewNV();
rebuilt_ty = MakeUnique<TensorViewNV>(
tv_type->dim_id(), tv_type->has_dimensions_id(), tv_type->perm());
break;
}
default:
assert(false && "Unhandled type");
return nullptr;
@ -914,6 +948,20 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
case spv::Op::OpTypeHitObjectNV:
type = new HitObjectNV();
break;
case spv::Op::OpTypeTensorLayoutNV:
type = new TensorLayoutNV(inst.GetSingleWordInOperand(0),
inst.GetSingleWordInOperand(1));
break;
case spv::Op::OpTypeTensorViewNV: {
const auto count = inst.NumOperands();
std::vector<uint32_t> perm;
for (uint32_t i = 2; i < count; ++i) {
perm.push_back(inst.GetSingleWordOperand(i));
}
type = new TensorViewNV(inst.GetSingleWordInOperand(0),
inst.GetSingleWordInOperand(1), perm);
break;
}
default:
assert(false && "Type not handled by the type manager.");
break;

View File

@ -179,6 +179,8 @@ bool Type::operator==(const Type& other) const {
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
DeclareKindCase(TensorLayoutNV);
DeclareKindCase(TensorViewNV);
#undef DeclareKindCase
default:
assert(false && "Unhandled type");
@ -235,6 +237,8 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
DeclareKindCase(TensorLayoutNV);
DeclareKindCase(TensorViewNV);
#undef DeclareKindCase
default:
assert(false && "Unhandled type");
@ -747,7 +751,55 @@ bool CooperativeMatrixKHR::IsSameImpl(const Type* that,
if (!mt) return false;
return component_type_->IsSameImpl(mt->component_type_, seen) &&
scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
columns_id_ == mt->columns_id_ && use_id_ == mt->use_id_ &&
HasSameDecorations(that);
}
TensorLayoutNV::TensorLayoutNV(const uint32_t dim, const uint32_t clamp_mode)
: Type(kTensorLayoutNV), dim_id_(dim), clamp_mode_id_(clamp_mode) {}
std::string TensorLayoutNV::str() const {
std::ostringstream oss;
oss << "<" << dim_id_ << ", " << clamp_mode_id_ << ">";
return oss.str();
}
size_t TensorLayoutNV::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
return hash_combine(hash, dim_id_, clamp_mode_id_);
}
bool TensorLayoutNV::IsSameImpl(const Type* that, IsSameCache*) const {
const TensorLayoutNV* tl = that->AsTensorLayoutNV();
if (!tl) return false;
return dim_id_ == tl->dim_id_ && clamp_mode_id_ == tl->clamp_mode_id_;
}
TensorViewNV::TensorViewNV(const uint32_t dim, const uint32_t clamp_mode,
const std::vector<uint32_t>& perm)
: Type(kTensorViewNV),
dim_id_(dim),
has_dimensions_id_(clamp_mode),
perm_(perm) {}
std::string TensorViewNV::str() const {
std::ostringstream oss;
oss << "<" << dim_id_ << ", " << has_dimensions_id_;
for (auto p : perm_) {
oss << ", " << p;
}
oss << ">";
return oss.str();
}
size_t TensorViewNV::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
return hash_combine(hash, dim_id_, has_dimensions_id_, perm_);
}
bool TensorViewNV::IsSameImpl(const Type* that, IsSameCache*) const {
const TensorViewNV* tv = that->AsTensorViewNV();
if (!tv) return false;
return dim_id_ == tv->dim_id_ &&
has_dimensions_id_ == tv->has_dimensions_id_ && perm_ == tv->perm_;
}
} // namespace analysis

View File

@ -63,6 +63,8 @@ class CooperativeMatrixNV;
class CooperativeMatrixKHR;
class RayQueryKHR;
class HitObjectNV;
class TensorLayoutNV;
class TensorViewNV;
// Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
// which is used as a way to probe the actual <subclass>.
@ -104,6 +106,8 @@ class Type {
kCooperativeMatrixKHR,
kRayQueryKHR,
kHitObjectNV,
kTensorLayoutNV,
kTensorViewNV,
kLast
};
@ -206,6 +210,8 @@ class Type {
DeclareCastMethod(CooperativeMatrixKHR)
DeclareCastMethod(RayQueryKHR)
DeclareCastMethod(HitObjectNV)
DeclareCastMethod(TensorLayoutNV)
DeclareCastMethod(TensorViewNV)
#undef DeclareCastMethod
protected:
@ -659,6 +665,53 @@ class CooperativeMatrixKHR : public Type {
const uint32_t use_id_;
};
class TensorLayoutNV : public Type {
public:
TensorLayoutNV(const uint32_t dim, const uint32_t clamp_mode);
TensorLayoutNV(const TensorLayoutNV&) = default;
std::string str() const override;
TensorLayoutNV* AsTensorLayoutNV() override { return this; }
const TensorLayoutNV* AsTensorLayoutNV() const override { return this; }
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
uint32_t dim_id() const { return dim_id_; }
uint32_t clamp_mode_id() const { return clamp_mode_id_; }
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const uint32_t dim_id_;
const uint32_t clamp_mode_id_;
};
class TensorViewNV : public Type {
public:
TensorViewNV(const uint32_t dim, const uint32_t clamp_mode,
const std::vector<uint32_t>& perm);
TensorViewNV(const TensorViewNV&) = default;
std::string str() const override;
TensorViewNV* AsTensorViewNV() override { return this; }
const TensorViewNV* AsTensorViewNV() const override { return this; }
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
uint32_t dim_id() const { return dim_id_; }
uint32_t has_dimensions_id() const { return has_dimensions_id_; }
const std::vector<uint32_t>& perm() const { return perm_; }
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const uint32_t dim_id_;
const uint32_t has_dimensions_id_;
std::vector<uint32_t> perm_;
};
#define DefineParameterlessType(type, name) \
class type : public Type { \
public: \

View File

@ -414,7 +414,9 @@ spv_result_t spvTextEncodeOperand(const spvtools::AssemblyGrammar& grammar,
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE: {
uint32_t value;
if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
return context->diagnostic(error)

View File

@ -366,6 +366,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
if (auto error = RayTracingPass(*vstate, &instruction)) return error;
if (auto error = RayReorderNVPass(*vstate, &instruction)) return error;
if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
if (auto error = TensorLayoutPass(*vstate, &instruction)) return error;
}
// Validate the preconditions involving adjacent instructions. e.g.

View File

@ -226,6 +226,9 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst);
/// Calculates the reachability of basic blocks.
void ReachabilityPass(ValidationState_t& _);
/// Validates tensor layout and view instructions.
spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst);
/// Validates execution limitations.
///
/// Verifies execution models are allowed for all functionality they contain.

View File

@ -62,7 +62,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
<< operand_index;
}
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, type_id, result_type);
_.CooperativeMatrixShapesMatch(inst, result_type, type_id, false);
if (ret != SPV_SUCCESS) return ret;
} else if (_.GetOperandTypeId(inst, operand_index) != result_type)
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@ -96,7 +96,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
<< operand_index;
}
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, type_id, result_type);
_.CooperativeMatrixShapesMatch(inst, result_type, type_id, false);
if (ret != SPV_SUCCESS) return ret;
} else if (_.GetOperandTypeId(inst, operand_index) != result_type)
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@ -142,7 +142,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
<< operand_index;
}
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, type_id, result_type);
_.CooperativeMatrixShapesMatch(inst, result_type, type_id, false);
if (ret != SPV_SUCCESS) return ret;
}
@ -672,6 +672,128 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
break;
}
case spv::Op::OpCooperativeMatrixReduceNV: {
if (!_.IsCooperativeMatrixKHRType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result Type must be a cooperative matrix type: "
<< spvOpcodeString(opcode);
}
const auto result_comp_type_id =
_.FindDef(result_type)->GetOperandAs<uint32_t>(1);
const auto matrix_id = inst->GetOperandAs<uint32_t>(2);
const auto matrix = _.FindDef(matrix_id);
const auto matrix_type_id = matrix->type_id();
if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Matrix must have a cooperative matrix type: "
<< spvOpcodeString(opcode);
}
const auto matrix_type = _.FindDef(matrix_type_id);
const auto matrix_comp_type_id = matrix_type->GetOperandAs<uint32_t>(1);
if (matrix_comp_type_id != result_comp_type_id) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result Type and Matrix type must have the same component "
"type: "
<< spvOpcodeString(opcode);
}
if (_.FindDef(result_type)->GetOperandAs<uint32_t>(2) !=
matrix_type->GetOperandAs<uint32_t>(2)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result Type and Matrix type must have the same scope: "
<< spvOpcodeString(opcode);
}
if (!_.IsCooperativeMatrixAccType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result Type must have UseAccumulator: "
<< spvOpcodeString(opcode);
}
if (!_.IsCooperativeMatrixAccType(matrix_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Matrix type must have UseAccumulator: "
<< spvOpcodeString(opcode);
}
const auto reduce_value = inst->GetOperandAs<uint32_t>(3);
if ((reduce_value &
uint32_t(
spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) &&
(reduce_value & uint32_t(spv::CooperativeMatrixReduceMask::Row |
spv::CooperativeMatrixReduceMask::Column))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Reduce 2x2 must not be used with Row/Column: "
<< spvOpcodeString(opcode);
}
std::tuple<bool, bool, uint32_t> result_rows, result_cols, matrix_rows,
matrix_cols;
result_rows =
_.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs<uint32_t>(3));
result_cols =
_.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs<uint32_t>(4));
matrix_rows = _.EvalInt32IfConst(matrix_type->GetOperandAs<uint32_t>(3));
matrix_cols = _.EvalInt32IfConst(matrix_type->GetOperandAs<uint32_t>(4));
if (reduce_value &
uint32_t(
spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) {
if (std::get<1>(result_rows) && std::get<1>(result_cols) &&
std::get<1>(matrix_rows) && std::get<1>(matrix_cols) &&
(std::get<2>(result_rows) != std::get<2>(matrix_rows) / 2 ||
std::get<2>(result_cols) != std::get<2>(matrix_cols) / 2)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "For Reduce2x2, result rows/cols must be half of matrix "
"rows/cols: "
<< spvOpcodeString(opcode);
}
}
if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Row)) {
if (std::get<1>(result_rows) && std::get<1>(matrix_rows) &&
std::get<2>(result_rows) != std::get<2>(matrix_rows)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "For ReduceRow, result rows must match matrix rows: "
<< spvOpcodeString(opcode);
}
}
if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Column)) {
if (std::get<1>(result_cols) && std::get<1>(matrix_cols) &&
std::get<2>(result_cols) != std::get<2>(matrix_cols)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "For ReduceColumn, result cols must match matrix cols: "
<< spvOpcodeString(opcode);
}
}
const auto combine_func_id = inst->GetOperandAs<uint32_t>(4);
const auto combine_func = _.FindDef(combine_func_id);
if (!combine_func || combine_func->opcode() != spv::Op::OpFunction) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "CombineFunc must be a function: " << spvOpcodeString(opcode);
}
const auto function_type_id = combine_func->GetOperandAs<uint32_t>(3);
const auto function_type = _.FindDef(function_type_id);
if (function_type->operands().size() != 4) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "CombineFunc must have two parameters: "
<< spvOpcodeString(opcode);
}
for (uint32_t i = 0; i < 3; ++i) {
// checks return type and two params
const auto param_type_id = function_type->GetOperandAs<uint32_t>(i + 1);
if (param_type_id != matrix_comp_type_id) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "CombineFunc return type and parameters must match matrix "
"component type: "
<< spvOpcodeString(opcode);
}
}
break;
}
default:
break;
}

View File

@ -49,7 +49,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@ -79,7 +79,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@ -111,7 +111,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@ -142,7 +142,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@ -177,7 +177,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@ -213,7 +213,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@ -497,8 +497,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
<< "matrix: " << spvOpcodeString(opcode);
if (result_is_coopmat) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type,
input_type, false);
if (ret != SPV_SUCCESS) return ret;
}
@ -568,6 +568,43 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
break;
}
case spv::Op::OpCooperativeMatrixConvertNV:
case spv::Op::OpCooperativeMatrixTransposeNV: {
if (!_.IsCooperativeMatrixType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix Result Type: "
<< spvOpcodeString(opcode);
}
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!_.IsCooperativeMatrixType(input_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type for Matrix input: "
<< spvOpcodeString(opcode);
}
bool swap_row_col = (opcode == spv::Op::OpCooperativeMatrixTransposeNV);
if (auto error = _.CooperativeMatrixShapesMatch(
inst, result_type, input_type, true, swap_row_col))
return error;
if (opcode == spv::Op::OpCooperativeMatrixConvertNV) {
if (_.FindDef(result_type)->GetOperandAs<uint32_t>(1) !=
_.FindDef(input_type)->GetOperandAs<uint32_t>(1)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result Type and Matrix component types mismatch: "
<< spvOpcodeString(opcode);
}
}
if (opcode == spv::Op::OpCooperativeMatrixTransposeNV) {
if (!_.IsCooperativeMatrixBType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result Type must have UseB: " << spvOpcodeString(opcode);
}
}
break;
}
default:
break;
}

View File

@ -86,7 +86,10 @@ spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
spv::Op::OpGetKernelLocalSizeForSubgroupCount,
spv::Op::OpGetKernelMaxNumSubgroups,
spv::Op::OpName};
spv::Op::OpName,
spv::Op::OpCooperativeMatrixPerElementOpNV,
spv::Op::OpCooperativeMatrixReduceNV,
spv::Op::OpCooperativeMatrixLoadTensorNV};
for (auto& pair : inst->uses()) {
const auto* use = pair.first;
if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
@ -341,6 +344,80 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _,
return SPV_SUCCESS;
}
spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _,
const Instruction* inst) {
const auto function_id = inst->GetOperandAs<uint32_t>(3);
const auto function = _.FindDef(function_id);
if (!function || spv::Op::OpFunction != function->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpCooperativeMatrixPerElementOpNV Function <id> "
<< _.getIdName(function_id) << " is not a function.";
}
const auto matrix_id = inst->GetOperandAs<uint32_t>(2);
const auto matrix = _.FindDef(matrix_id);
const auto matrix_type_id = matrix->type_id();
if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpCooperativeMatrixPerElementOpNV Matrix <id> "
<< _.getIdName(matrix_id) << " is not a cooperative matrix.";
}
const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
if (matrix_type_id != result_type_id) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpCooperativeMatrixPerElementOpNV Result Type <id> "
<< _.getIdName(result_type_id) << " must match matrix type <id> "
<< _.getIdName(matrix_type_id) << ".";
}
const auto matrix_comp_type_id =
_.FindDef(matrix_type_id)->GetOperandAs<uint32_t>(1);
const auto function_type_id = function->GetOperandAs<uint32_t>(3);
const auto function_type = _.FindDef(function_type_id);
auto return_type_id = function_type->GetOperandAs<uint32_t>(1);
if (return_type_id != matrix_comp_type_id) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpCooperativeMatrixPerElementOpNV function return type <id> "
<< _.getIdName(return_type_id)
<< " must match matrix component type <id> "
<< _.getIdName(matrix_comp_type_id) << ".";
}
if (function_type->operands().size() < 5) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpCooperativeMatrixPerElementOpNV function type <id> "
<< _.getIdName(function_type_id)
<< " must have a least three parameters.";
}
const auto param0_id = function_type->GetOperandAs<uint32_t>(2);
const auto param1_id = function_type->GetOperandAs<uint32_t>(3);
const auto param2_id = function_type->GetOperandAs<uint32_t>(4);
if (!_.IsIntScalarType(param0_id) || _.GetBitWidth(param0_id) != 32) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpCooperativeMatrixPerElementOpNV function type first parameter "
"type <id> "
<< _.getIdName(param0_id) << " must be a 32-bit integer.";
}
if (!_.IsIntScalarType(param1_id) || _.GetBitWidth(param1_id) != 32) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpCooperativeMatrixPerElementOpNV function type second "
"parameter type <id> "
<< _.getIdName(param1_id) << " must be a 32-bit integer.";
}
if (param2_id != matrix_comp_type_id) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpCooperativeMatrixPerElementOpNV function type third parameter "
"type <id> "
<< _.getIdName(param2_id) << " must match matrix component type.";
}
return SPV_SUCCESS;
}
} // namespace
spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
@ -354,6 +431,10 @@ spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpFunctionCall:
if (auto error = ValidateFunctionCall(_, inst)) return error;
break;
case spv::Op::OpCooperativeMatrixPerElementOpNV:
if (auto error = ValidateCooperativeMatrixPerElementOp(_, inst))
return error;
break;
default:
break;
}

View File

@ -475,6 +475,12 @@ spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) {
const uint32_t entry_point = inst->word(1);
_.RegisterExecutionModeForEntryPoint(entry_point,
spv::ExecutionMode(inst->word(2)));
if (inst->GetOperandAs<spv::ExecutionMode>(1) ==
spv::ExecutionMode::LocalSize ||
inst->GetOperandAs<spv::ExecutionMode>(1) ==
spv::ExecutionMode::LocalSizeId) {
_.RegisterEntryPointLocalSize(entry_point, inst);
}
} else if (opcode == spv::Op::OpVariable) {
const auto storage_class = inst->GetOperandAs<spv::StorageClass>(2);
if (auto error = LimitCheckNumVars(_, inst->id(), storage_class)) {

View File

@ -233,6 +233,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
spv::StorageClass src_sc = spv::StorageClass::Max;
switch (inst->opcode()) {
case spv::Op::OpCooperativeMatrixLoadNV:
case spv::Op::OpCooperativeMatrixLoadTensorNV:
case spv::Op::OpCooperativeMatrixLoadKHR:
case spv::Op::OpLoad: {
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
@ -241,6 +242,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
break;
}
case spv::Op::OpCooperativeMatrixStoreNV:
case spv::Op::OpCooperativeMatrixStoreTensorNV:
case spv::Op::OpCooperativeMatrixStoreKHR:
case spv::Op::OpStore: {
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
@ -330,6 +332,7 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
if (inst->opcode() == spv::Op::OpLoad ||
inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV ||
inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV ||
inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MakePointerAvailableKHR cannot be used with OpLoad.";
@ -350,7 +353,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) {
if (inst->opcode() == spv::Op::OpStore ||
inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV ||
inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR) {
inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR ||
inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MakePointerVisibleKHR cannot be used with OpStore.";
}
@ -2176,6 +2180,222 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
return SPV_SUCCESS;
}
// Returns the number of instruction words taken up by a tensor addressing
// operands argument and its implied operands.
int TensorAddressingOperandsNumWords(spv::TensorAddressingOperandsMask mask) {
int result = 1; // Count the mask
if ((mask & spv::TensorAddressingOperandsMask::TensorView) !=
spv::TensorAddressingOperandsMask::MaskNone)
++result;
if ((mask & spv::TensorAddressingOperandsMask::DecodeFunc) !=
spv::TensorAddressingOperandsMask::MaskNone)
++result;
return result;
}
spv_result_t ValidateCooperativeMatrixLoadStoreTensorNV(
ValidationState_t& _, const Instruction* inst) {
uint32_t type_id;
const char* opname;
if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
type_id = inst->type_id();
opname = "spv::Op::OpCooperativeMatrixLoadTensorNV";
} else {
// get Object operand's type
type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
opname = "spv::Op::OpCooperativeMatrixStoreTensorNV";
}
auto matrix_type = _.FindDef(type_id);
if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "spv::Op::OpCooperativeMatrixLoadTensorNV Result Type <id> "
<< _.getIdName(type_id) << " is not a cooperative matrix type.";
} else {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "spv::Op::OpCooperativeMatrixStoreTensorNV Object type <id> "
<< _.getIdName(type_id) << " is not a cooperative matrix type.";
}
}
const auto pointer_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 2u : 0u;
const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
const auto pointer = _.FindDef(pointer_id);
if (!pointer ||
((_.addressing_model() == spv::AddressingModel::Logical) &&
((!_.features().variable_pointers &&
!spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
(_.features().variable_pointers &&
!spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " Pointer <id> " << _.getIdName(pointer_id)
<< " is not a logical pointer.";
}
const auto pointer_type_id = pointer->type_id();
const auto pointer_type = _.FindDef(pointer_type_id);
if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " type for pointer <id> " << _.getIdName(pointer_id)
<< " is not a pointer type.";
}
const auto storage_class_index = 1u;
const auto storage_class =
pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
if (storage_class != spv::StorageClass::Workgroup &&
storage_class != spv::StorageClass::StorageBuffer &&
storage_class != spv::StorageClass::PhysicalStorageBuffer) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< _.VkErrorID(8973) << opname
<< " storage class for pointer type <id> "
<< _.getIdName(pointer_type_id)
<< " is not Workgroup, StorageBuffer, or PhysicalStorageBuffer.";
}
if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
const auto object_index = 3;
const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
const auto object = _.FindDef(object_id);
if (!object || object->type_id() != type_id) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " Object <id> " << _.getIdName(object_id)
<< " type does not match Result Type.";
}
}
const auto tensor_layout_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 4u : 2u;
const auto tensor_layout_id =
inst->GetOperandAs<uint32_t>(tensor_layout_index);
const auto tensor_layout = _.FindDef(tensor_layout_id);
if (!tensor_layout || _.FindDef(tensor_layout->type_id())->opcode() !=
spv::Op::OpTypeTensorLayoutNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " TensorLayout <id> " << _.getIdName(tensor_layout_id)
<< " does not have a tensor layout type.";
}
const auto memory_access_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 5u : 3u;
if (inst->operands().size() > memory_access_index) {
if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
return error;
}
const auto memory_access_mask =
inst->GetOperandAs<uint32_t>(memory_access_index);
const auto tensor_operands_index =
memory_access_index + MemoryAccessNumWords(memory_access_mask);
const auto tensor_operands =
inst->GetOperandAs<spv::TensorAddressingOperandsMask>(
tensor_operands_index);
if (inst->operands().size() <
tensor_operands_index +
TensorAddressingOperandsNumWords(tensor_operands)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " not enough tensor addressing operands.";
}
uint32_t tensor_operand_index = tensor_operands_index + 1;
if ((tensor_operands & spv::TensorAddressingOperandsMask::TensorView) !=
spv::TensorAddressingOperandsMask::MaskNone) {
const auto tensor_view_id =
inst->GetOperandAs<uint32_t>(tensor_operand_index);
const auto tensor_view = _.FindDef(tensor_view_id);
if (!tensor_view || _.FindDef(tensor_view->type_id())->opcode() !=
spv::Op::OpTypeTensorViewNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " TensorView <id> " << _.getIdName(tensor_view_id)
<< " does not have a tensor view type.";
}
tensor_operand_index++;
}
if ((tensor_operands & spv::TensorAddressingOperandsMask::DecodeFunc) !=
spv::TensorAddressingOperandsMask::MaskNone) {
if (inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpCooperativeMatrixStoreTensorNV does not support DecodeFunc.";
}
const auto decode_func_id =
inst->GetOperandAs<uint32_t>(tensor_operand_index);
const auto decode_func = _.FindDef(decode_func_id);
if (!decode_func || decode_func->opcode() != spv::Op::OpFunction) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
<< " is not a function.";
}
const auto component_type_index = 1;
const auto component_type_id =
matrix_type->GetOperandAs<uint32_t>(component_type_index);
const auto function_type =
_.FindDef(decode_func->GetOperandAs<uint32_t>(3));
if (function_type->GetOperandAs<uint32_t>(1) != component_type_id) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
<< " return type must match matrix component type.";
}
const auto decode_ptr_type_id = function_type->GetOperandAs<uint32_t>(2);
const auto decode_ptr_type = _.FindDef(decode_ptr_type_id);
auto decode_storage_class =
decode_ptr_type->GetOperandAs<spv::StorageClass>(storage_class_index);
if (decode_storage_class != spv::StorageClass::PhysicalStorageBuffer) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
<< " first parameter must be pointer to PhysicalStorageBuffer.";
}
const auto tensor_layout_type = _.FindDef(tensor_layout->type_id());
for (uint32_t param = 3; param < 5; ++param) {
const auto param_type_id = function_type->GetOperandAs<uint32_t>(param);
const auto param_type = _.FindDef(param_type_id);
if (param_type->opcode() != spv::Op::OpTypeArray) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
<< " second/third parameter must be array of 32-bit integer "
"with "
<< " dimension equal to the tensor dimension.";
}
const auto length_index = 2u;
uint64_t array_length;
if (_.EvalConstantValUint64(
param_type->GetOperandAs<uint32_t>(length_index),
&array_length)) {
const auto tensor_layout_dim_id =
tensor_layout_type->GetOperandAs<uint32_t>(1);
uint64_t dim_value;
if (_.EvalConstantValUint64(tensor_layout_dim_id, &dim_value)) {
if (array_length != dim_value) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " DecodeFunc <id> "
<< _.getIdName(decode_func_id)
<< " second/third parameter must be array of 32-bit integer "
"with "
<< " dimension equal to the tensor dimension.";
}
}
}
}
tensor_operand_index++;
}
return SPV_SUCCESS;
}
spv_result_t ValidatePtrComparison(ValidationState_t& _,
const Instruction* inst) {
if (_.addressing_model() == spv::AddressingModel::Logical &&
@ -2284,6 +2504,11 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
if (auto error = ValidateCooperativeMatrixLoadStoreKHR(_, inst))
return error;
break;
case spv::Op::OpCooperativeMatrixLoadTensorNV:
case spv::Op::OpCooperativeMatrixStoreTensorNV:
if (auto error = ValidateCooperativeMatrixLoadStoreTensorNV(_, inst))
return error;
break;
case spv::Op::OpPtrEqual:
case spv::Op::OpPtrNotEqual:
case spv::Op::OpPtrDiff:

View File

@ -0,0 +1,184 @@
// Copyright (c) 2024 NVIDIA Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Validate instructions that manipulate tensor layout and view objects
#include "source/opcode.h"
#include "source/spirv_target_env.h"
#include "source/val/instruction.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"
namespace spvtools {
namespace val {
namespace {
spv_result_t ValidateTensorLayoutResultTypeNV(ValidationState_t& _,
const Instruction* inst) {
const auto result_type_index = 0;
const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index);
const auto result_type = _.FindDef(result_type_id);
if (!result_type || spv::Op::OpTypeTensorLayoutNV != result_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " Result Type <id> "
<< _.getIdName(result_type_id) << " is not a tensor layout type.";
}
return SPV_SUCCESS;
}
spv_result_t ValidateTensorViewResultTypeNV(ValidationState_t& _,
const Instruction* inst) {
const auto result_type_index = 0;
const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index);
const auto result_type = _.FindDef(result_type_id);
if (!result_type || spv::Op::OpTypeTensorViewNV != result_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " Result Type <id> "
<< _.getIdName(result_type_id) << " is not a tensor view type.";
}
return SPV_SUCCESS;
}
spv_result_t ValidateCreateTensorLayoutNV(ValidationState_t& _,
const Instruction* inst) {
if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error;
return SPV_SUCCESS;
}
spv_result_t ValidateCreateTensorViewNV(ValidationState_t& _,
const Instruction* inst) {
if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error;
return SPV_SUCCESS;
}
enum ExpectedNumValues {
DIM,
DIMx2,
ONE,
FOUR,
};
spv_result_t ValidateTensorTypeWithDimValuesNV(ValidationState_t& _,
const Instruction* inst,
ExpectedNumValues expected,
bool is_view) {
std::string type_str;
if (is_view) {
if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error;
type_str = "TensorView";
} else {
if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error;
type_str = "TensorLayout";
}
const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
const auto tensor_id = inst->GetOperandAs<uint32_t>(2);
const auto tensor = _.FindDef(tensor_id);
if (!tensor || result_type_id != tensor->type_id()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " Result Type <id> "
<< _.getIdName(result_type_id) << " does not match " << type_str
<< " type.";
}
const auto num_values = inst->operands().size() - 3;
const auto result_type = _.FindDef(result_type_id);
const auto dim_index = 1;
const auto dim_id = result_type->GetOperandAs<uint32_t>(dim_index);
uint64_t dim_value;
if (_.EvalConstantValUint64(dim_id, &dim_value)) {
uint64_t expected_num_values = 0;
switch (expected) {
case DIM:
expected_num_values = dim_value;
break;
case DIMx2:
expected_num_values = dim_value * 2;
break;
case ONE:
expected_num_values = 1;
break;
case FOUR:
expected_num_values = 4;
break;
}
if (num_values != expected_num_values) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode())
<< " unexpected number of operands.";
}
}
for (uint32_t i = 0; i < num_values; ++i) {
const auto val_id = inst->GetOperandAs<uint32_t>(i + 3);
const auto val = _.FindDef(val_id);
if (!val || !_.IsIntScalarType(val->type_id()) ||
_.GetBitWidth(val->type_id()) != 32) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " operand <id> "
<< _.getIdName(val_id) << " is not a 32-bit integer.";
}
}
return SPV_SUCCESS;
}
} // namespace
spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst) {
switch (inst->opcode()) {
case spv::Op::OpCreateTensorLayoutNV:
if (auto error = ValidateCreateTensorLayoutNV(_, inst)) return error;
break;
case spv::Op::OpCreateTensorViewNV:
if (auto error = ValidateCreateTensorViewNV(_, inst)) return error;
break;
case spv::Op::OpTensorLayoutSetBlockSizeNV:
case spv::Op::OpTensorLayoutSetDimensionNV:
case spv::Op::OpTensorLayoutSetStrideNV:
if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, false))
return error;
break;
case spv::Op::OpTensorLayoutSliceNV:
if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIMx2, false))
return error;
break;
case spv::Op::OpTensorLayoutSetClampValueNV:
if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, ONE, false))
return error;
break;
case spv::Op::OpTensorViewSetDimensionNV:
case spv::Op::OpTensorViewSetStrideNV:
if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, true))
return error;
break;
case spv::Op::OpTensorViewSetClipNV:
if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, FOUR, true))
return error;
break;
default:
break;
}
return SPV_SUCCESS;
}
} // namespace val
} // namespace spvtools

View File

@ -1,4 +1,5 @@
// Copyright (c) 2018 Google LLC.
// Copyright (c) 2024 NVIDIA Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -582,6 +583,37 @@ spv_result_t ValidateTypeCooperativeMatrix(ValidationState_t& _,
}
}
uint64_t scope_value;
if (_.EvalConstantValUint64(scope_id, &scope_value)) {
if (scope_value == static_cast<uint32_t>(spv::Scope::Workgroup)) {
for (auto entry_point_id : _.entry_points()) {
if (!_.EntryPointHasLocalSizeOrId(entry_point_id)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixKHR with ScopeWorkgroup "
<< "used without specifying LocalSize or LocalSizeId "
<< "for entry point <id> " << _.getIdName(entry_point_id);
}
const auto local_size = _.EntryPointLocalSizeOrId(entry_point_id);
const auto mode = local_size->GetOperandAs<spv::ExecutionMode>(1);
if (mode == spv::ExecutionMode::LocalSizeId) {
uint32_t local_size_ids[3] = {
local_size->GetOperandAs<uint32_t>(2),
local_size->GetOperandAs<uint32_t>(3),
local_size->GetOperandAs<uint32_t>(4),
};
for (auto id : local_size_ids) {
if (_.FindDef(id) > inst) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixKHR with ScopeWorkgroup "
<< "used before LocalSizeId constant value <id> "
<< _.getIdName(id) << " is defined.";
}
}
}
}
}
}
return SPV_SUCCESS;
}
@ -611,6 +643,115 @@ spv_result_t ValidateTypeUntypedPointerKHR(ValidationState_t& _,
}
return SPV_SUCCESS;
}
spv_result_t ValidateTensorDim(ValidationState_t& _, const Instruction* inst) {
const auto dim_index = 1;
const auto dim_id = inst->GetOperandAs<uint32_t>(dim_index);
const auto dim = _.FindDef(dim_id);
if (!dim || !_.IsIntScalarType(dim->type_id()) ||
_.GetBitWidth(dim->type_id()) != 32) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " Dim <id> "
<< _.getIdName(dim_id) << " is not a 32-bit integer.";
}
constexpr uint32_t max_tensor_dim = 5;
uint64_t dim_value;
if (_.EvalConstantValUint64(dim_id, &dim_value)) {
if (dim_value == 0 || dim_value > max_tensor_dim) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " Dim <id> "
<< _.getIdName(dim_id) << " must be between 1 and "
<< max_tensor_dim << ".";
}
}
return SPV_SUCCESS;
}
spv_result_t ValidateTypeTensorLayoutNV(ValidationState_t& _,
const Instruction* inst) {
if (auto error = ValidateTensorDim(_, inst)) return error;
const auto clamp_index = 2;
const auto clamp_id = inst->GetOperandAs<uint32_t>(clamp_index);
const auto clamp = _.FindDef(clamp_id);
if (!clamp || !_.IsIntScalarType(clamp->type_id()) ||
_.GetBitWidth(clamp->type_id()) != 32) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " ClampMode <id> "
<< _.getIdName(clamp_id) << " is not a 32-bit integer.";
}
uint64_t clamp_value;
if (_.EvalConstantValUint64(clamp_id, &clamp_value)) {
if (clamp_value >
static_cast<uint32_t>(spv::TensorClampMode::RepeatMirrored)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " ClampMode <id> "
<< _.getIdName(clamp_id) << " must be a valid TensorClampMode.";
}
}
return SPV_SUCCESS;
}
spv_result_t ValidateTypeTensorViewNV(ValidationState_t& _,
const Instruction* inst) {
if (auto error = ValidateTensorDim(_, inst)) return error;
const auto has_dim_index = 2;
const auto has_dim_id = inst->GetOperandAs<uint32_t>(has_dim_index);
const auto has_dim = _.FindDef(has_dim_id);
if (!has_dim || !_.IsBoolScalarType(has_dim->type_id())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " HasDimensions <id> "
<< _.getIdName(has_dim_id) << " is not a boolean value.";
}
uint32_t permutation_mask = 0;
bool all_constant = true;
const auto num_dim = inst->operands().size() - 3;
for (size_t p_index = 3; p_index < inst->operands().size(); ++p_index) {
auto p_id = inst->GetOperandAs<uint32_t>(p_index);
const auto p = _.FindDef(p_id);
if (!p || !_.IsIntScalarType(p->type_id()) ||
_.GetBitWidth(p->type_id()) != 32) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " Permutation <id> "
<< _.getIdName(p_id) << " is not a 32-bit integer.";
}
uint64_t p_value;
if (_.EvalConstantValUint64(p_id, &p_value)) {
if (p_value >= num_dim) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode()) << " Permutation <id> "
<< _.getIdName(p_id) << " must be a valid dimension.";
}
permutation_mask |= 1 << p_value;
} else {
all_constant = false;
}
}
if (all_constant && permutation_mask != (1U << num_dim) - 1U) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode())
<< " Permutation values don't form a valid permutation.";
}
uint64_t dim_value;
if (_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(1), &dim_value)) {
if (dim_value != num_dim) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< spvOpcodeString(inst->opcode())
<< " Incorrect number of permutation values.";
}
}
return SPV_SUCCESS;
}
} // namespace
spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
@ -659,6 +800,12 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpTypeUntypedPointerKHR:
if (auto error = ValidateTypeUntypedPointerKHR(_, inst)) return error;
break;
case spv::Op::OpTypeTensorLayoutNV:
if (auto error = ValidateTypeTensorLayoutNV(_, inst)) return error;
break;
case spv::Op::OpTypeTensorViewNV:
if (auto error = ValidateTypeTensorViewNV(_, inst)) return error;
break;
default:
break;
}

View File

@ -1290,8 +1290,9 @@ bool ValidationState_t::IsUnsigned64BitHandle(uint32_t id) const {
}
spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
const Instruction* inst, uint32_t m1, uint32_t m2) {
const auto m1_type = FindDef(m1);
const Instruction* inst, uint32_t result_type_id, uint32_t m2,
bool is_conversion, bool swap_row_col) {
const auto m1_type = FindDef(result_type_id);
const auto m2_type = FindDef(m2);
if (m1_type->opcode() != m2_type->opcode()) {
@ -1307,6 +1308,10 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
if (swap_row_col) {
std::swap(m1_rows_id, m1_cols_id);
}
bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
m2_is_const_int32 = false;
uint32_t m1_value = 0, m2_value = 0;
@ -1330,7 +1335,7 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected rows of Matrix type and Result Type to be "
<< "identical";
<< (swap_row_col ? "swapped with columns" : "identical");
}
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
@ -1341,7 +1346,7 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected columns of Matrix type and Result Type to be "
<< "identical";
<< (swap_row_col ? "swapped with rows" : "identical");
}
if (m1_type->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
@ -1352,7 +1357,12 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
EvalInt32IfConst(m2_use_id);
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value &&
// CooperativeMatrixConversionsNV allows conversions from Acc->A/B
!(is_conversion &&
HasCapability(spv::Capability::CooperativeMatrixConversionsNV) &&
m2_value ==
(uint32_t)spv::CooperativeMatrixUse::MatrixAccumulatorKHR)) {
return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Use of Matrix type and Result Type to be "
<< "identical";

View File

@ -240,6 +240,21 @@ class ValidationState_t {
entry_point_to_execution_modes_[entry_point].insert(execution_mode);
}
/// Registers that the entry point declares its local size
void RegisterEntryPointLocalSize(uint32_t entry_point,
const Instruction* inst) {
entry_point_to_local_size_or_id_[entry_point] = inst;
}
/// Returns whether the entry point declares its local size
bool EntryPointHasLocalSizeOrId(uint32_t entry_point) const {
return entry_point_to_local_size_or_id_.find(entry_point) !=
entry_point_to_local_size_or_id_.end();
}
/// Returns the id of the local size
const Instruction* EntryPointLocalSizeOrId(uint32_t entry_point) const {
return entry_point_to_local_size_or_id_.find(entry_point)->second;
}
/// Returns the interface descriptions of a given entry point.
const std::vector<EntryPointDescription>& entry_point_descriptions(
uint32_t entry_point) {
@ -759,11 +774,14 @@ class ValidationState_t {
return SpvDecorationString(uint32_t(decoration));
}
// Returns whether type m1 and type m2 are cooperative matrices with
// the same "shape" (matching scope, rows, cols). If any are specialization
// constants, we assume they can match because we can't prove they don't.
// Returns whether type result_type_id and type m2 are cooperative matrices
// with the same "shape" (matching scope, rows, cols). If any are
// specialization constants, we assume they can match because we can't prove
// they don't.
spv_result_t CooperativeMatrixShapesMatch(const Instruction* inst,
uint32_t m1, uint32_t m2);
uint32_t result_type_id,
uint32_t m2, bool is_conversion,
bool swap_row_col = false);
// Returns true if |lhs| and |rhs| logically match and, if the decorations of
// |rhs| are a subset of |lhs|.
@ -949,6 +967,10 @@ class ValidationState_t {
std::unordered_map<uint32_t, std::set<spv::ExecutionMode>>
entry_point_to_execution_modes_;
// Mapping entry point -> local size execution mode instruction
std::unordered_map<uint32_t, const Instruction*>
entry_point_to_local_size_or_id_;
/// Mapping function -> array of entry points inside this
/// module which can (indirectly) call the function.
std::unordered_map<uint32_t, std::vector<uint32_t>> function_to_entry_points_;

View File

@ -175,6 +175,9 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
types.emplace_back(new RayQueryKHR());
types.emplace_back(new HitObjectNV());
types.emplace_back(new TensorLayoutNV(1002, 1000));
types.emplace_back(new TensorViewNV(1002, 1003, {1000, 1001}));
return types;
}
@ -1102,11 +1105,14 @@ OpMemoryModel Logical GLSL450
%uint = OpTypeInt 32 0
%1 = OpTypePointer Input %uint
%2 = OpTypePointer Uniform %uint
%1000 = OpConstant %uint 0
%1001 = OpConstant %uint 1
%1002 = OpConstant %uint 2
%8 = OpConstant %uint 8
%24 = OpConstant %uint 24
%42 = OpConstant %uint 42
%100 = OpConstant %uint 100
%1003 = OpConstantFalse %bool
)";
std::unique_ptr<IRContext> context =

View File

@ -1280,14 +1280,14 @@ TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) {
TEST_F(ValidateArithmetics, CoopMatScopeFail) {
const std::string types = R"(
%workgroup = OpConstant %u32 2
%device = OpConstant %u32 1
%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16
%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
%mat16x16_dv = OpTypeCooperativeMatrixNV %f16 %device %u32_16 %u32_16
%f16matdv_16x16_1 = OpConstantComposite %mat16x16_dv %f16_1
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matdv_16x16_1
)";
CompileSuccessfully(GenerateCoopMatCode(types, body).c_str());
@ -1475,7 +1475,10 @@ std::string GenerateCoopMatKHRCode(const std::string& extra_types,
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixKHR
OpCapability CooperativeMatrixReductionsNV
OpCapability CooperativeMatrixPerElementOperationsNV
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_NV_cooperative_matrix2"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
@ -1487,6 +1490,7 @@ OpEntryPoint GLCompute %main "main"
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%u32_16 = OpConstant %u32 16
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3
@ -1579,13 +1583,13 @@ TEST_F(ValidateArithmetics, CoopMatMatrixKHRTimesScalarMismatchFail) {
TEST_F(ValidateArithmetics, CoopMatKHRScopeFail) {
const std::string types = R"(
%workgroup = OpConstant %u32 2
%mat16x16_wg = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %useC
%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
%device = OpConstant %u32 1
%mat16x16_dv = OpTypeCooperativeMatrixKHR %f16 %device %u32_16 %u32_16 %useC
%f16matdv_16x16_1 = OpConstantComposite %mat16x16_dv %f16_1
)";
const std::string body = R"(
%val1 = OpFAdd %f16matA %f16matwg_16x16_1 %f16mat_A_1
%val1 = OpFAdd %f16matA %f16matdv_16x16_1 %f16mat_A_1
)";
CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
@ -1612,6 +1616,241 @@ TEST_F(ValidateArithmetics, CoopMatKHRDimFail) {
HasSubstr("Cooperative matrix 'N' mismatch: CooperativeMatrixMulAddKHR"));
}
TEST_F(ValidateArithmetics, CoopMat2ReduceSuccess) {
const std::string extra_types = R"(
%f16matC8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %useC
%f16matC16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %useC
%f16matC8x16 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_16 %useC
%functy = OpTypeFunction %f16 %f16 %f16
%reducefunc = OpFunction %f16 None %functy
%x = OpFunctionParameter %f16
%y = OpFunctionParameter %f16
%entry2 = OpLabel
%sum = OpFAdd %f16 %x %y
OpReturnValue %sum
OpFunctionEnd
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixReduceNV %f16matC8 %f16mat_C_1 2x2 %reducefunc
%val2 = OpCooperativeMatrixReduceNV %f16matC16x8 %f16mat_C_1 Row %reducefunc
%val3 = OpCooperativeMatrixReduceNV %f16matC8x16 %f16mat_C_1 Column %reducefunc
%val4 = OpCooperativeMatrixReduceNV %f16matC %f16mat_C_1 Row|Column %reducefunc
%val5 = OpCooperativeMatrixReduceNV %f16matC8 %f16mat_C_1 Row|Column %reducefunc
)";
CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, CoopMat2Reduce2x2DimFail) {
const std::string extra_types = R"(
%functy = OpTypeFunction %f16 %f16 %f16
%reducefunc = OpFunction %f16 None %functy
%x = OpFunctionParameter %f16
%y = OpFunctionParameter %f16
%entry2 = OpLabel
%sum = OpFAdd %f16 %x %y
OpReturnValue %sum
OpFunctionEnd
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixReduceNV %f16matC %f16mat_C_1 2x2 %reducefunc
)";
CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("For Reduce2x2, result rows/cols must be half of "
"matrix rows/cols: CooperativeMatrixReduceNV"));
}
TEST_F(ValidateArithmetics, CoopMat2ReduceRowDimFail) {
const std::string extra_types = R"(
%f16matC8x16 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_16 %useC
%functy = OpTypeFunction %f16 %f16 %f16
%reducefunc = OpFunction %f16 None %functy
%x = OpFunctionParameter %f16
%y = OpFunctionParameter %f16
%entry2 = OpLabel
%sum = OpFAdd %f16 %x %y
OpReturnValue %sum
OpFunctionEnd
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixReduceNV %f16matC8x16 %f16mat_C_1 Row %reducefunc
)";
CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("For ReduceRow, result rows must match matrix rows: "
"CooperativeMatrixReduceNV"));
}
TEST_F(ValidateArithmetics, CoopMat2ReduceColDimFail) {
const std::string extra_types = R"(
%f16matC16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %useC
%functy = OpTypeFunction %f16 %f16 %f16
%reducefunc = OpFunction %f16 None %functy
%x = OpFunctionParameter %f16
%y = OpFunctionParameter %f16
%entry2 = OpLabel
%sum = OpFAdd %f16 %x %y
OpReturnValue %sum
OpFunctionEnd
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixReduceNV %f16matC16x8 %f16mat_C_1 Column %reducefunc
)";
CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("For ReduceColumn, result cols must match matrix cols: "
"CooperativeMatrixReduceNV"));
}
TEST_F(ValidateArithmetics, CoopMat2ReduceMaskFail) {
const std::string extra_types = R"(
%f16matC8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %useC
%functy = OpTypeFunction %f16 %f16 %f16
%reducefunc = OpFunction %f16 None %functy
%x = OpFunctionParameter %f16
%y = OpFunctionParameter %f16
%entry2 = OpLabel
%sum = OpFAdd %f16 %x %y
OpReturnValue %sum
OpFunctionEnd
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixReduceNV %f16matC8 %f16mat_C_1 Row|Column|2x2 %reducefunc
)";
CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Reduce 2x2 must not be used with Row/Column: "
"CooperativeMatrixReduceNV"));
}
TEST_F(ValidateArithmetics, CoopMat2ReduceFuncTypeFail) {
const std::string extra_types = R"(
%functy = OpTypeFunction %f32 %f32 %f32
%reducefunc = OpFunction %f32 None %functy
%x = OpFunctionParameter %f32
%y = OpFunctionParameter %f32
%entry2 = OpLabel
%sum = OpFAdd %f32 %x %y
OpReturnValue %sum
OpFunctionEnd
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixReduceNV %f16matC %f16mat_C_1 Row|Column %reducefunc
)";
CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("CombineFunc return type and parameters must match "
"matrix component type: CooperativeMatrixReduceNV"));
}
TEST_F(ValidateArithmetics, CoopMat2PerElementOpSuccess) {
const std::string extra_types = R"(
%functy = OpTypeFunction %f16 %u32 %u32 %f16
%functy2 = OpTypeFunction %f16 %u32 %u32 %f16 %u32
%elemfunc = OpFunction %f16 None %functy
%row = OpFunctionParameter %u32
%col = OpFunctionParameter %u32
%el = OpFunctionParameter %f16
%entry2 = OpLabel
OpReturnValue %el
OpFunctionEnd
%elemfunc2 = OpFunction %f16 None %functy2
%row2 = OpFunctionParameter %u32
%col2 = OpFunctionParameter %u32
%el2 = OpFunctionParameter %f16
%x = OpFunctionParameter %u32
%entry3 = OpLabel
OpReturnValue %el2
OpFunctionEnd
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc
%val2 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc2 %f16_1
)";
CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, CoopMat2PerElementOpElemTyFail) {
const std::string extra_types = R"(
%functy = OpTypeFunction %f32 %u32 %u32 %f32
%elemfunc = OpFunction %f32 None %functy
%row = OpFunctionParameter %u32
%col = OpFunctionParameter %u32
%el = OpFunctionParameter %f32
%entry2 = OpLabel
OpReturnValue %el
OpFunctionEnd
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc
)";
CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("must match matrix component type"));
}
TEST_F(ValidateArithmetics, CoopMat2PerElementOpRowTyFail) {
const std::string extra_types = R"(
%functy = OpTypeFunction %f16 %f16 %u32 %f16
%elemfunc = OpFunction %f16 None %functy
%row = OpFunctionParameter %f16
%col = OpFunctionParameter %u32
%el = OpFunctionParameter %f16
%entry2 = OpLabel
OpReturnValue %el
OpFunctionEnd
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc
)";
CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr("must be a 32-bit integer"));
}
} // namespace
} // namespace val
} // namespace spvtools

View File

@ -1338,11 +1338,11 @@ OpEntryPoint GLCompute %main "main"
%u32_8 = OpConstant %u32 8
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3
%workgroup = OpConstant %u32 2
%device = OpConstant %u32 1
%use_A = OpConstant %u32 0
%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
%f32mat = OpTypeCooperativeMatrixKHR %f32 %workgroup %u32_8 %u32_8 %use_A
%f32mat = OpTypeCooperativeMatrixKHR %f32 %device %u32_8 %u32_8 %use_A
%f16_1 = OpConstant %f16 1
@ -2140,6 +2140,192 @@ INSTANTIATE_TEST_SUITE_P(SmallConversionInstructions, ValidateSmallConversions,
"%inst = OpBitcast %short %ld_half",
"%inst = OpBitcast %short2 %ld_half2"));
TEST_F(ValidateConversion, CoopMat2ConversionSuccess) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixConversionsNV
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_NV_cooperative_matrix2"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u16 = OpTypeInt 16 0
%u32 = OpTypeInt 32 0
%s16 = OpTypeInt 16 1
%s32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%u32_16 = OpConstant %u32 16
%use_A = OpConstant %u32 0
%use_B = OpConstant %u32 1
%use_Acc = OpConstant %u32 2
%subgroup = OpConstant %u32 3
%f16matA = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
%f32matA = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_A
%u16matA = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_A
%u32matA = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_A
%s16matA = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_A
%s32matA = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_A
%f16matB = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_B
%f32matB = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_B
%u16matB = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_B
%u32matB = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_B
%s16matB = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_B
%s32matB = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_B
%f16matAcc = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_Acc
%f32matAcc = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_Acc
%u16matAcc = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_Acc
%u32matAcc = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_Acc
%s16matAcc = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_Acc
%s32matAcc = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_Acc
%f16matAcc16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %use_Acc
%f16matB8x16 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_16 %use_B
%f16_1 = OpConstant %f16 1
%f32_1 = OpConstant %f32 1
%u16_1 = OpConstant %u16 1
%u32_1 = OpConstant %u32 1
%s16_1 = OpConstant %s16 1
%s32_1 = OpConstant %s32 1
%f16matAcc_1 = OpConstantComposite %f16matAcc %f16_1
%f32matAcc_1 = OpConstantComposite %f32matAcc %f32_1
%u16matAcc_1 = OpConstantComposite %u16matAcc %u16_1
%u32matAcc_1 = OpConstantComposite %u32matAcc %u32_1
%s16matAcc_1 = OpConstantComposite %s16matAcc %s16_1
%s32matAcc_1 = OpConstantComposite %s32matAcc %s32_1
%f16matAcc16x8_1 = OpConstantComposite %f16matAcc16x8 %f16_1
%main = OpFunction %void None %func
%main_entry = OpLabel
%val11A = OpConvertFToU %u16matA %f16matAcc_1
%val12A = OpConvertFToU %u32matA %f16matAcc_1
%val13A = OpConvertFToS %s16matA %f16matAcc_1
%val14A = OpConvertFToS %s32matA %f16matAcc_1
%val15A = OpFConvert %f32matA %f16matAcc_1
%val11B = OpConvertFToU %u16matB %f16matAcc_1
%val12B = OpConvertFToU %u32matB %f16matAcc_1
%val13B = OpConvertFToS %s16matB %f16matAcc_1
%val14B = OpConvertFToS %s32matB %f16matAcc_1
%val15B = OpFConvert %f32matB %f16matAcc_1
%val21A = OpConvertFToU %u16matA %f32matAcc_1
%val22A = OpConvertFToU %u32matA %f32matAcc_1
%val23A = OpConvertFToS %s16matA %f32matAcc_1
%val24A = OpConvertFToS %s32matA %f32matAcc_1
%val25A = OpFConvert %f16matA %f32matAcc_1
%val21B = OpConvertFToU %u16matB %f32matAcc_1
%val22B = OpConvertFToU %u32matB %f32matAcc_1
%val23B = OpConvertFToS %s16matB %f32matAcc_1
%val24B = OpConvertFToS %s32matB %f32matAcc_1
%val25B = OpFConvert %f16matB %f32matAcc_1
%val31A = OpConvertUToF %f16matA %u16matAcc_1
%val32A = OpConvertUToF %f32matA %u16matAcc_1
%val33A = OpUConvert %u32matA %u16matAcc_1
%val34A = OpSConvert %s32matA %u16matAcc_1
%val31B = OpConvertUToF %f16matB %u16matAcc_1
%val32B = OpConvertUToF %f32matB %u16matAcc_1
%val33B = OpUConvert %u32matB %u16matAcc_1
%val34B = OpSConvert %s32matB %u16matAcc_1
%val41A = OpConvertSToF %f16matA %s16matAcc_1
%val42A = OpConvertSToF %f32matA %s16matAcc_1
%val43A = OpUConvert %u32matA %s16matAcc_1
%val44A = OpSConvert %s32matA %s16matAcc_1
%val41B = OpConvertSToF %f16matB %s16matAcc_1
%val42B = OpConvertSToF %f32matB %s16matAcc_1
%val43B = OpUConvert %u32matB %s16matAcc_1
%val44B = OpSConvert %s32matB %s16matAcc_1
%val51A = OpCooperativeMatrixConvertNV %f16matA %f16matAcc_1
%val52A = OpCooperativeMatrixConvertNV %f32matA %f32matAcc_1
%val53A = OpCooperativeMatrixConvertNV %u16matA %u16matAcc_1
%val54A = OpCooperativeMatrixConvertNV %s16matA %s16matAcc_1
%val51B = OpCooperativeMatrixConvertNV %f16matB %f16matAcc_1
%val52B = OpCooperativeMatrixConvertNV %f32matB %f32matAcc_1
%val53B = OpCooperativeMatrixConvertNV %u16matB %u16matAcc_1
%val54B = OpCooperativeMatrixConvertNV %s16matB %s16matAcc_1
%val61B = OpCooperativeMatrixTransposeNV %f16matB %f16matAcc_1
%val62B = OpCooperativeMatrixTransposeNV %f32matB %f32matAcc_1
%val63B = OpCooperativeMatrixTransposeNV %u16matB %u16matAcc_1
%val64B = OpCooperativeMatrixTransposeNV %s16matB %s16matAcc_1
%val71B = OpCooperativeMatrixTransposeNV %f16matB8x16 %f16matAcc16x8_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateConversion, CoopMat2TransposeShapeFail) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixConversionsNV
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_NV_cooperative_matrix2"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%u32_8 = OpConstant %u32 8
%u32_16 = OpConstant %u32 16
%use_B = OpConstant %u32 1
%use_Acc = OpConstant %u32 2
%subgroup = OpConstant %u32 3
%f16matAcc16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %use_Acc
%f16matB16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %use_B
%f16_1 = OpConstant %f16 1
%f16matAcc16x8_1 = OpConstantComposite %f16matAcc16x8 %f16_1
%main = OpFunction %void None %func
%main_entry = OpLabel
%val71B = OpCooperativeMatrixTransposeNV %f16matB16x8 %f16matAcc16x8_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Expected rows of Matrix type and Result Type to be "
"swapped with columns"));
}
} // namespace
} // namespace val
} // namespace spvtools

View File

@ -7171,6 +7171,639 @@ OpFunctionEnd
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_3));
}
std::string GenCoopMat2Shader(const std::string& extra_types,
const std::string& main_body,
const std::string& after_main = "",
const std::string& extra_decorations = "") {
const std::string prefix = R"(
OpCapability Shader
OpCapability Float16
OpCapability PhysicalStorageBufferAddresses
OpCapability VulkanMemoryModel
OpCapability CooperativeMatrixKHR
OpCapability TensorAddressingNV
OpCapability CooperativeMatrixTensorAddressingNV
OpCapability CooperativeMatrixBlockLoadsNV
OpExtension "SPV_KHR_physical_storage_buffer"
OpExtension "SPV_KHR_storage_buffer_storage_class"
OpExtension "SPV_NV_tensor_addressing"
OpExtension "SPV_NV_cooperative_matrix2"
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical VulkanKHR
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %f16_arr ArrayStride 2
OpDecorate %46 Block
OpMemberDecorate %46 0 Offset 0
OpDecorate %48 Binding 0
OpDecorate %48 DescriptorSet 0
OpDecorate %psb Restrict
)" + extra_decorations + R"(
%void = OpTypeVoid
%bool = OpTypeBool
%func = OpTypeFunction %void
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%s32_0 = OpConstant %s32 0
%f16_0 = OpConstant %f16 0
%u32_2 = OpConstant %u32 2
%u32_8 = OpConstant %u32 8
%use_A = OpConstant %u32 0
%workgroup = OpConstant %u32 2
%subgroup = OpConstant %u32 3
%f16_arr = OpTypeRuntimeArray %f16
%46 = OpTypeStruct %f16_arr
%47 = OpTypePointer StorageBuffer %46
%48 = OpVariable %47 StorageBuffer
%51 = OpTypePointer StorageBuffer %f16_arr
%psbptr = OpTypePointer PhysicalStorageBuffer %f16_arr
%f16mat = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_8 %u32_8 %use_A
%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_A
%arr2 = OpTypeArray %u32 %u32_2
%functy = OpTypeFunction %f16 %psbptr %arr2 %arr2
)";
const std::string decode_func =
R"(
%decodefunc = OpFunction %f16 None %functy
%psb = OpFunctionParameter %psbptr
%c0 = OpFunctionParameter %arr2
%c1 = OpFunctionParameter %arr2
%entry2 = OpLabel
OpReturnValue %f16_0
OpFunctionEnd
)";
const std::string func_begin =
R"(
%main = OpFunction %void None %func
%main_entry = OpLabel
%array_ptr = OpAccessChain %51 %48 %s32_0
)";
const std::string suffix =
R"(
OpReturn
OpFunctionEnd)";
return prefix + extra_types + func_begin + main_body + suffix + decode_func +
after_main;
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutAndViewSuccess) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 2
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%layout = OpTypeTensorLayoutNV %dim %clamp
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1
)",
R"(
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutInvalidDimFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 6
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr("must be between 1 and 5"));
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutInvalidClampFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 6
%dim = OpConstant %u32 2
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("must be a valid TensorClampMode"));
}
TEST_F(ValidateMemory, CoopMat2TensorViewInvalidDimFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%dim = OpConstant %u32 6
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1
)",
R"(
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr("must be between 1 and 5"));
}
TEST_F(ValidateMemory, CoopMat2TensorViewInvalidPermutationFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%dim = OpConstant %u32 3
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p1
)",
R"(
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Permutation values don't form a valid permutation"));
}
TEST_F(ValidateMemory, CoopMat2TensorViewInvalidPermutation2Fail) {
std::string spirv = GenCoopMat2Shader(
R"(
%dim = OpConstant %u32 3
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1
)",
R"(
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Incorrect number of permutation values."));
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutBlockSizePass) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 3
%b = OpConstant %u32 1
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
%tl = OpCreateTensorLayoutNV %layout
%tl2 = OpTensorLayoutSetBlockSizeNV %layout %tl %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutBlockSizeFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 3
%b = OpConstant %u32 1
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
%tl = OpCreateTensorLayoutNV %layout
%tl2 = OpTensorLayoutSetBlockSizeNV %layout %tl %b %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("unexpected number of operands"));
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutDimensionPass) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 3
%b = OpConstant %u32 1
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
%tl = OpCreateTensorLayoutNV %layout
%tl2 = OpTensorLayoutSetDimensionNV %layout %tl %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutDimensionFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 3
%b = OpConstant %u32 1
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
%tl = OpCreateTensorLayoutNV %layout
%tl2 = OpTensorLayoutSetDimensionNV %layout %tl %b %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("unexpected number of operands"));
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutStridePass) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 3
%b = OpConstant %u32 1
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
%tl = OpCreateTensorLayoutNV %layout
%tl2 = OpTensorLayoutSetStrideNV %layout %tl %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutStrideFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 3
%b = OpConstant %u32 1
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
%tl = OpCreateTensorLayoutNV %layout
%tl2 = OpTensorLayoutSetStrideNV %layout %tl %b %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("unexpected number of operands"));
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutSlicePass) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 3
%b = OpConstant %u32 1
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
%tl = OpCreateTensorLayoutNV %layout
%tl2 = OpTensorLayoutSliceNV %layout %tl %b %b %b %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutSliceFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 3
%b = OpConstant %u32 1
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
%tl = OpCreateTensorLayoutNV %layout
%tl2 = OpTensorLayoutSliceNV %layout %tl %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("unexpected number of operands"));
}
TEST_F(ValidateMemory, CoopMat2TensorLayoutSetClampValuePass) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 3
%b = OpConstant %u32 1
%layout = OpTypeTensorLayoutNV %dim %clamp
)",
R"(
%tl = OpCreateTensorLayoutNV %layout
%tl2 = OpTensorLayoutSetClampValueNV %layout %tl %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2TensorViewDimensionPass) {
std::string spirv = GenCoopMat2Shader(
R"(
%dim = OpConstant %u32 3
%hasdim = OpConstantFalse %bool
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%p2 = OpConstant %u32 2
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2
%b = OpConstant %u32 1
)",
R"(
%tv = OpCreateTensorViewNV %view
%tv2 = OpTensorViewSetDimensionNV %view %tv %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2TensorViewDimensionFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%dim = OpConstant %u32 3
%hasdim = OpConstantFalse %bool
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%p2 = OpConstant %u32 2
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2
%b = OpConstant %u32 1
)",
R"(
%tv = OpCreateTensorViewNV %view
%tv2 = OpTensorViewSetDimensionNV %view %tv %b %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("unexpected number of operands"));
}
TEST_F(ValidateMemory, CoopMat2TensorViewStridePass) {
std::string spirv = GenCoopMat2Shader(
R"(
%dim = OpConstant %u32 3
%hasdim = OpConstantFalse %bool
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%p2 = OpConstant %u32 2
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2
%b = OpConstant %u32 1
)",
R"(
%tv = OpCreateTensorViewNV %view
%tv2 = OpTensorViewSetStrideNV %view %tv %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2TensorViewStrideFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%dim = OpConstant %u32 3
%hasdim = OpConstantFalse %bool
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%p2 = OpConstant %u32 2
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2
%b = OpConstant %u32 1
)",
R"(
%tv = OpCreateTensorViewNV %view
%tv2 = OpTensorViewSetStrideNV %view %tv %b %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("unexpected number of operands"));
}
TEST_F(ValidateMemory, CoopMat2TensorViewClipPass) {
std::string spirv = GenCoopMat2Shader(
R"(
%dim = OpConstant %u32 3
%hasdim = OpConstantFalse %bool
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%p2 = OpConstant %u32 2
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2
%b = OpConstant %u32 1
)",
R"(
%tv = OpCreateTensorViewNV %view
%tv2 = OpTensorViewSetClipNV %view %tv %b %b %b %b
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2LoadStoreTensorPass) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 2
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%layout = OpTypeTensorLayoutNV %dim %clamp
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1
)",
R"(
%mat = OpUndef %f16mat
%tl = OpCreateTensorLayoutNV %layout
%tv = OpCreateTensorViewNV %view
%mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None None
%mat3 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl Aligned 4 None
%mat4 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None TensorView %tv
%mat5 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None DecodeFunc %decodefunc
%mat6 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None TensorView|DecodeFunc %tv %decodefunc
%mat7 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl Aligned 4 TensorView|DecodeFunc %tv %decodefunc
OpCooperativeMatrixStoreTensorNV %array_ptr %mat %tl None None
OpCooperativeMatrixStoreTensorNV %array_ptr %mat %tl Aligned 4 None
OpCooperativeMatrixStoreTensorNV %array_ptr %mat %tl None TensorView %tv
OpCooperativeMatrixStoreTensorNV %array_ptr %mat %tl Aligned 4 TensorView %tv
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, CoopMat2LoadTensorWrongLayoutTypeFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 2
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%layout = OpTypeTensorLayoutNV %dim %clamp
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1
)",
R"(
%mat = OpUndef %f16mat
%tl = OpCreateTensorLayoutNV %layout
%tv = OpCreateTensorViewNV %view
%mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tv None None
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("does not have a tensor layout type"));
}
TEST_F(ValidateMemory, CoopMat2LoadTensorWrongObjectTypeFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 2
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%layout = OpTypeTensorLayoutNV %dim %clamp
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1
)",
R"(
%mat = OpUndef %f32mat
%tl = OpCreateTensorLayoutNV %layout
%tv = OpCreateTensorViewNV %view
%mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None None
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("type does not match Result Type"));
}
TEST_F(ValidateMemory, CoopMat2LoadTensorDecodeFuncTypeFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 2
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%layout = OpTypeTensorLayoutNV %dim %clamp
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1
)",
R"(
%mat = OpUndef %f32mat
%tl = OpCreateTensorLayoutNV %layout
%tv = OpCreateTensorViewNV %view
%mat2 = OpCooperativeMatrixLoadTensorNV %f32mat %array_ptr %mat %tl None DecodeFunc %decodefunc
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("return type must match matrix component type"));
}
TEST_F(ValidateMemory, CoopMat2LoadTensorDecodeFuncArrayTypeFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 2
%u32_3 = OpConstant %u32 3
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%layout = OpTypeTensorLayoutNV %dim %clamp
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1
%arr3 = OpTypeArray %u32 %u32_3
%functy2 = OpTypeFunction %f16 %psbptr %arr3 %arr3
)",
R"(
%mat = OpUndef %f16mat
%tl = OpCreateTensorLayoutNV %layout
%tv = OpCreateTensorViewNV %view
%mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None DecodeFunc %decodefunc2
)",
R"(
%decodefunc2 = OpFunction %f16 None %functy2
%psb2 = OpFunctionParameter %psbptr
%c02 = OpFunctionParameter %arr3
%c12 = OpFunctionParameter %arr3
%entry3 = OpLabel
OpReturnValue %f16_0
OpFunctionEnd
)",
R"(
OpDecorate %psb2 Restrict
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("dimension equal to the tensor dimension"));
}
TEST_F(ValidateMemory, CoopMat2LoadTensorDecodeFuncPointerTypeFail) {
std::string spirv = GenCoopMat2Shader(
R"(
%clamp = OpConstant %u32 0
%dim = OpConstant %u32 2
%p0 = OpConstant %u32 0
%p1 = OpConstant %u32 1
%hasdim = OpConstantFalse %bool
%layout = OpTypeTensorLayoutNV %dim %clamp
%view = OpTypeTensorViewNV %dim %hasdim %p0 %p1
%sbptr = OpTypePointer StorageBuffer %f16_arr
%functy2 = OpTypeFunction %f16 %sbptr %arr2 %arr2
)",
R"(
%mat = OpUndef %f16mat
%tl = OpCreateTensorLayoutNV %layout
%tv = OpCreateTensorViewNV %view
%mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None DecodeFunc %decodefunc2
)",
R"(
%decodefunc2 = OpFunction %f16 None %functy2
%sb = OpFunctionParameter %sbptr
%c02 = OpFunctionParameter %arr2
%c12 = OpFunctionParameter %arr2
%entry3 = OpLabel
OpReturnValue %f16_0
OpFunctionEnd
)");
CompileSuccessfully(spirv.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("first parameter must be pointer to PhysicalStorageBuffer"));
}
} // namespace
} // namespace val
} // namespace spvtools

View File

@ -350,6 +350,77 @@ OpEntryPoint Vertex %func "shader"
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Invalid storage class for target environment"));
}
TEST_F(ValidateMisc, CoopMat2WorkgroupLocalSizeIdPass) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionModeId %main LocalSizeId %u32_16 %u32_16 %u32_16
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%u32_16 = OpConstant %u32 16
%use_Acc = OpConstant %u32 2
%workgroup = OpConstant %u32 2
%f16mat = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %use_Acc
%main = OpFunction %void None %func
%main_entry = OpLabel
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str(), SPV_ENV_UNIVERSAL_1_3);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3));
}
TEST_F(ValidateMisc, CoopMat2WorkgroupLocalSizeIdConstantNotDeclaredYetFail) {
const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionModeId %main LocalSizeId %u32_16 %u32_8 %u32_16
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%u32_16 = OpConstant %u32 16
%use_Acc = OpConstant %u32 2
%workgroup = OpConstant %u32 2
%f16mat = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %use_Acc
%u32_8 = OpConstant %u32 8
%main = OpFunction %void None %func
%main_entry = OpLabel
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str(), SPV_ENV_UNIVERSAL_1_3);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("OpTypeCooperativeMatrixKHR with ScopeWorkgroup used "
"before LocalSizeId constant value"));
}
} // namespace
} // namespace val
} // namespace spvtools

View File

@ -43,7 +43,8 @@ AUTHORS = ['The Khronos Group Inc.',
'Shiyu Liu',
'ZHOU He',
'Nintendo',
'Epic Games, Inc.']
'Epic Games, Inc.',
'NVIDIA Corporation']
CURRENT_YEAR = 2023
FIRST_YEAR = 2014