From ce92630396c2fd2d6d04819369116af4fb141a28 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 24 Oct 2024 12:55:04 -0500 Subject: [PATCH] Add validation for SPV_NV_tensor_addressing and SPV_NV_cooperative_matrix2 (#5865) --- Android.mk | 1 + BUILD.gn | 1 + DEPS | 2 +- include/spirv-tools/libspirv.h | 6 + source/CMakeLists.txt | 1 + source/binary.cpp | 5 +- source/opcode.cpp | 2 + source/operand.cpp | 19 + source/opt/aggressive_dead_code_elim_pass.cpp | 6 + source/opt/eliminate_dead_members_pass.cpp | 15 + source/opt/ir_context.cpp | 28 +- source/opt/type_manager.cpp | 48 ++ source/opt/types.cpp | 54 +- source/opt/types.h | 53 ++ source/text.cpp | 4 +- source/val/validate.cpp | 1 + source/val/validate.h | 3 + source/val/validate_arithmetics.cpp | 128 +++- source/val/validate_conversion.cpp | 53 +- source/val/validate_function.cpp | 83 ++- source/val/validate_instruction.cpp | 6 + source/val/validate_memory.cpp | 227 ++++++- source/val/validate_tensor_layout.cpp | 184 +++++ source/val/validate_type.cpp | 147 ++++ source/val/validation_state.cpp | 20 +- source/val/validation_state.h | 30 +- test/opt/type_manager_test.cpp | 6 + test/val/val_arithmetics_test.cpp | 255 ++++++- test/val/val_conversion_test.cpp | 190 +++++- test/val/val_memory_test.cpp | 633 ++++++++++++++++++ test/val/val_misc_test.cpp | 71 ++ utils/check_copyright.py | 3 +- 32 files changed, 2247 insertions(+), 38 deletions(-) create mode 100644 source/val/validate_tensor_layout.cpp diff --git a/Android.mk b/Android.mk index abbd42b23..1414b5249 100644 --- a/Android.mk +++ b/Android.mk @@ -75,6 +75,7 @@ SPVTOOLS_SRC_FILES := \ source/val/validate_ray_tracing_reorder.cpp \ source/val/validate_scopes.cpp \ source/val/validate_small_type_uses.cpp \ + source/val/validate_tensor_layout.cpp \ source/val/validate_type.cpp SPVTOOLS_OPT_SRC_FILES := \ diff --git a/BUILD.gn b/BUILD.gn index 59a34d6f9..23610262e 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -557,6 +557,7 @@ static_library("spvtools_val") { "source/val/validate_scopes.cpp", "source/val/validate_scopes.h", "source/val/validate_small_type_uses.cpp", + "source/val/validate_tensor_layout.cpp", "source/val/validate_type.cpp", "source/val/validation_state.cpp", "source/val/validation_state.h", diff --git a/DEPS b/DEPS index 69ad61090..4543be7bc 100644 --- a/DEPS +++ b/DEPS @@ -14,7 +14,7 @@ vars = { 're2_revision': '6dcd83d60f7944926bfd308cc13979fc53dd69ca', - 'spirv_headers_revision': '50bc4debdc3eec5045edbeb8ce164090e29b91f3', + 'spirv_headers_revision': '22c4d1b1e9d1c7d9aa5086c93e6491f21080019b', } deps = { diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h index 1a21ccef0..70931cb6c 100644 --- a/include/spirv-tools/libspirv.h +++ b/include/spirv-tools/libspirv.h @@ -314,6 +314,12 @@ typedef enum spv_operand_type_t { SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS, // Optional enum type from SPV_NV_raw_access_chains SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS, + // Enum type from SPV_NV_tensor_addressing + SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE, + // Enum type from SPV_NV_cooperative_matrix2 + SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE, + // Enum type from SPV_NV_cooperative_matrix2 + SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS, // This is a sentinel value, and does not represent an operand type. // It should come last. diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index cb6026ad5..b20357bbc 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -334,6 +334,7 @@ set(SPIRV_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ray_tracing_reorder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_scopes.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_tensor_layout.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h ${CMAKE_CURRENT_SOURCE_DIR}/val/basic_block.cpp diff --git a/source/binary.cpp b/source/binary.cpp index 772e98c0a..ed5749857 100644 --- a/source/binary.cpp +++ b/source/binary.cpp @@ -717,13 +717,16 @@ spv_result_t Parser::parseOperand(size_t inst_offset, case SPV_OPERAND_TYPE_LOOP_CONTROL: case SPV_OPERAND_TYPE_IMAGE: case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: + case SPV_OPERAND_TYPE_MEMORY_ACCESS: case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS: case SPV_OPERAND_TYPE_SELECTION_CONTROL: case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS: - case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: { + case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: + case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE: + case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS: { // This operand is a mask. // Map an optional operand type to its corresponding concrete type. diff --git a/source/opcode.cpp b/source/opcode.cpp index ea03bd671..2b25fc3da 100644 --- a/source/opcode.cpp +++ b/source/opcode.cpp @@ -382,6 +382,8 @@ int32_t spvOpcodeGeneratesType(spv::Op op) { case spv::Op::OpTypeRayQueryKHR: case spv::Op::OpTypeHitObjectNV: case spv::Op::OpTypeUntypedPointerKHR: + case spv::Op::OpTypeTensorLayoutNV: + case spv::Op::OpTypeTensorViewNV: return true; default: // In particular, OpTypeForwardPointer does not generate a type, diff --git a/source/operand.cpp b/source/operand.cpp index 9bee647fc..548564662 100644 --- a/source/operand.cpp +++ b/source/operand.cpp @@ -231,6 +231,12 @@ const char* spvOperandTypeStr(spv_operand_type_t type) { return "cooperative matrix layout"; case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE: return "cooperative matrix use"; + case SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE: + return "tensor clamp mode"; + case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE: + return "cooperative matrix reduce"; + case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS: + return "tensor addressing operands"; case SPV_OPERAND_TYPE_INITIALIZATION_MODE_QUALIFIER: return "initialization mode qualifier"; case SPV_OPERAND_TYPE_HOST_ACCESS_QUALIFIER: @@ -389,6 +395,7 @@ bool spvOperandIsConcrete(spv_operand_type_t type) { case SPV_OPERAND_TYPE_STORE_CACHE_CONTROL: case SPV_OPERAND_TYPE_NAMED_MAXIMUM_NUMBER_OF_REGISTERS: case SPV_OPERAND_TYPE_FPENCODING: + case SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE: return true; default: break; @@ -409,6 +416,8 @@ bool spvOperandIsConcreteMask(spv_operand_type_t type) { case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS: case SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS: + case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE: + case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS: return true; default: break; @@ -598,6 +607,16 @@ std::function spvOperandCanBeForwardDeclaredFunction( case spv::Op::OpTypeArray: out = [](unsigned index) { return index == 1; }; break; + case spv::Op::OpCooperativeMatrixPerElementOpNV: + out = [](unsigned index) { return index == 3; }; + break; + case spv::Op::OpCooperativeMatrixReduceNV: + out = [](unsigned index) { return index == 4; }; + break; + case spv::Op::OpCooperativeMatrixLoadTensorNV: + // approximate, due to variable operands + out = [](unsigned index) { return index > 6; }; + break; default: out = [](unsigned) { return false; }; break; diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index 444b0dbc2..d78d63ca0 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -43,6 +43,7 @@ constexpr uint32_t kGlobalVariableVariableIndex = 12; constexpr uint32_t kExtInstSetInIdx = 0; constexpr uint32_t kExtInstOpInIdx = 1; constexpr uint32_t kInterpolantInIdx = 2; +constexpr uint32_t kCooperativeMatrixLoadSourceAddrInIdx = 0; // Sorting functor to present annotation instructions in an easy-to-process // order. The functor orders by opcode first and falls back on unique id @@ -438,6 +439,11 @@ uint32_t AggressiveDCEPass::GetLoadedVariableFromNonFunctionCalls( } break; } + case spv::Op::OpCooperativeMatrixLoadNV: + case spv::Op::OpCooperativeMatrixLoadKHR: + case spv::Op::OpCooperativeMatrixLoadTensorNV: + return GetVariableId( + inst->GetSingleWordInOperand(kCooperativeMatrixLoadSourceAddrInIdx)); default: break; } diff --git a/source/opt/eliminate_dead_members_pass.cpp b/source/opt/eliminate_dead_members_pass.cpp index 1c98502e2..e440296ff 100644 --- a/source/opt/eliminate_dead_members_pass.cpp +++ b/source/opt/eliminate_dead_members_pass.cpp @@ -70,6 +70,11 @@ void EliminateDeadMembersPass::FindLiveMembers() { MarkPointeeTypeAsFullUsed(inst.type_id()); break; } + } else if (inst.opcode() == spv::Op::OpTypePointer) { + uint32_t storage_class = inst.GetSingleWordInOperand(0); + if (storage_class == uint32_t(spv::StorageClass::PhysicalStorageBuffer)) { + MarkTypeAsFullyUsed(inst.GetSingleWordInOperand(1)); + } } } @@ -200,6 +205,8 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForExtract( case spv::Op::OpTypeRuntimeArray: case spv::Op::OpTypeVector: case spv::Op::OpTypeMatrix: + case spv::Op::OpTypeCooperativeMatrixNV: + case spv::Op::OpTypeCooperativeMatrixKHR: type_id = type_inst->GetSingleWordInOperand(0); break; default: @@ -246,6 +253,8 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForAccessChain( case spv::Op::OpTypeRuntimeArray: case spv::Op::OpTypeVector: case spv::Op::OpTypeMatrix: + case spv::Op::OpTypeCooperativeMatrixNV: + case spv::Op::OpTypeCooperativeMatrixKHR: type_id = type_inst->GetSingleWordInOperand(0); break; default: @@ -505,6 +514,8 @@ bool EliminateDeadMembersPass::UpdateAccessChain(Instruction* inst) { case spv::Op::OpTypeRuntimeArray: case spv::Op::OpTypeVector: case spv::Op::OpTypeMatrix: + case spv::Op::OpTypeCooperativeMatrixNV: + case spv::Op::OpTypeCooperativeMatrixKHR: new_operands.emplace_back(inst->GetInOperand(i)); type_id = type_inst->GetSingleWordInOperand(0); break; @@ -578,6 +589,8 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) { case spv::Op::OpTypeRuntimeArray: case spv::Op::OpTypeVector: case spv::Op::OpTypeMatrix: + case spv::Op::OpTypeCooperativeMatrixNV: + case spv::Op::OpTypeCooperativeMatrixKHR: type_id = type_inst->GetSingleWordInOperand(0); break; default: @@ -639,6 +652,8 @@ bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) { case spv::Op::OpTypeRuntimeArray: case spv::Op::OpTypeVector: case spv::Op::OpTypeMatrix: + case spv::Op::OpTypeCooperativeMatrixNV: + case spv::Op::OpTypeCooperativeMatrixKHR: type_id = type_inst->GetSingleWordInOperand(0); break; default: diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp index d864b7c02..1cf0d74bf 100644 --- a/source/opt/ir_context.cpp +++ b/source/opt/ir_context.cpp @@ -926,9 +926,35 @@ uint32_t IRContext::GetBuiltinInputVarId(uint32_t builtin) { void IRContext::AddCalls(const Function* func, std::queue* todo) { for (auto bi = func->begin(); bi != func->end(); ++bi) - for (auto ii = bi->begin(); ii != bi->end(); ++ii) + for (auto ii = bi->begin(); ii != bi->end(); ++ii) { if (ii->opcode() == spv::Op::OpFunctionCall) todo->push(ii->GetSingleWordInOperand(0)); + if (ii->opcode() == spv::Op::OpCooperativeMatrixPerElementOpNV) + todo->push(ii->GetSingleWordInOperand(1)); + if (ii->opcode() == spv::Op::OpCooperativeMatrixReduceNV) + todo->push(ii->GetSingleWordInOperand(2)); + if (ii->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) { + const auto memory_operands_index = 3; + auto mask = ii->GetSingleWordInOperand(memory_operands_index); + + uint32_t count = 1; + if (mask & uint32_t(spv::MemoryAccessMask::Aligned)) ++count; + if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) + ++count; + if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) + ++count; + + const auto tensor_operands_index = memory_operands_index + count; + mask = ii->GetSingleWordInOperand(tensor_operands_index); + count = 1; + if (mask & uint32_t(spv::TensorAddressingOperandsMask::TensorView)) + ++count; + + if (mask & uint32_t(spv::TensorAddressingOperandsMask::DecodeFunc)) { + todo->push(ii->GetSingleWordInOperand(tensor_operands_index + count)); + } + } + } } bool IRContext::ProcessEntryPointCallTree(ProcessFunction& pfn) { diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp index 79648ad49..88106b608 100644 --- a/source/opt/type_manager.cpp +++ b/source/opt/type_manager.cpp @@ -441,6 +441,28 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { {SPV_OPERAND_TYPE_ID, {coop_mat->use_id()}}}); break; } + case Type::kTensorLayoutNV: { + auto tensor_layout = type->AsTensorLayoutNV(); + typeInst = MakeUnique( + context(), spv::Op::OpTypeTensorLayoutNV, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {tensor_layout->dim_id()}}, + {SPV_OPERAND_TYPE_ID, {tensor_layout->clamp_mode_id()}}}); + break; + } + case Type::kTensorViewNV: { + auto tensor_view = type->AsTensorViewNV(); + std::vector operands; + operands.push_back(Operand{SPV_OPERAND_TYPE_ID, {tensor_view->dim_id()}}); + operands.push_back( + Operand{SPV_OPERAND_TYPE_ID, {tensor_view->has_dimensions_id()}}); + for (auto p : tensor_view->perm()) { + operands.push_back(Operand{SPV_OPERAND_TYPE_ID, {p}}); + } + typeInst = MakeUnique(context(), spv::Op::OpTypeTensorViewNV, + 0, id, operands); + break; + } default: assert(false && "Unexpected type"); break; @@ -667,6 +689,18 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) { cm_type->use_id()); break; } + case Type::kTensorLayoutNV: { + const TensorLayoutNV* tl_type = type.AsTensorLayoutNV(); + rebuilt_ty = MakeUnique(tl_type->dim_id(), + tl_type->clamp_mode_id()); + break; + } + case Type::kTensorViewNV: { + const TensorViewNV* tv_type = type.AsTensorViewNV(); + rebuilt_ty = MakeUnique( + tv_type->dim_id(), tv_type->has_dimensions_id(), tv_type->perm()); + break; + } default: assert(false && "Unhandled type"); return nullptr; @@ -914,6 +948,20 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) { case spv::Op::OpTypeHitObjectNV: type = new HitObjectNV(); break; + case spv::Op::OpTypeTensorLayoutNV: + type = new TensorLayoutNV(inst.GetSingleWordInOperand(0), + inst.GetSingleWordInOperand(1)); + break; + case spv::Op::OpTypeTensorViewNV: { + const auto count = inst.NumOperands(); + std::vector perm; + for (uint32_t i = 2; i < count; ++i) { + perm.push_back(inst.GetSingleWordOperand(i)); + } + type = new TensorViewNV(inst.GetSingleWordInOperand(0), + inst.GetSingleWordInOperand(1), perm); + break; + } default: assert(false && "Type not handled by the type manager."); break; diff --git a/source/opt/types.cpp b/source/opt/types.cpp index b18b8cb1a..8ccf6c9a2 100644 --- a/source/opt/types.cpp +++ b/source/opt/types.cpp @@ -179,6 +179,8 @@ bool Type::operator==(const Type& other) const { DeclareKindCase(CooperativeMatrixKHR); DeclareKindCase(RayQueryKHR); DeclareKindCase(HitObjectNV); + DeclareKindCase(TensorLayoutNV); + DeclareKindCase(TensorViewNV); #undef DeclareKindCase default: assert(false && "Unhandled type"); @@ -235,6 +237,8 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const { DeclareKindCase(CooperativeMatrixKHR); DeclareKindCase(RayQueryKHR); DeclareKindCase(HitObjectNV); + DeclareKindCase(TensorLayoutNV); + DeclareKindCase(TensorViewNV); #undef DeclareKindCase default: assert(false && "Unhandled type"); @@ -747,7 +751,55 @@ bool CooperativeMatrixKHR::IsSameImpl(const Type* that, if (!mt) return false; return component_type_->IsSameImpl(mt->component_type_, seen) && scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ && - columns_id_ == mt->columns_id_ && HasSameDecorations(that); + columns_id_ == mt->columns_id_ && use_id_ == mt->use_id_ && + HasSameDecorations(that); +} + +TensorLayoutNV::TensorLayoutNV(const uint32_t dim, const uint32_t clamp_mode) + : Type(kTensorLayoutNV), dim_id_(dim), clamp_mode_id_(clamp_mode) {} + +std::string TensorLayoutNV::str() const { + std::ostringstream oss; + oss << "<" << dim_id_ << ", " << clamp_mode_id_ << ">"; + return oss.str(); +} + +size_t TensorLayoutNV::ComputeExtraStateHash(size_t hash, SeenTypes*) const { + return hash_combine(hash, dim_id_, clamp_mode_id_); +} + +bool TensorLayoutNV::IsSameImpl(const Type* that, IsSameCache*) const { + const TensorLayoutNV* tl = that->AsTensorLayoutNV(); + if (!tl) return false; + return dim_id_ == tl->dim_id_ && clamp_mode_id_ == tl->clamp_mode_id_; +} + +TensorViewNV::TensorViewNV(const uint32_t dim, const uint32_t clamp_mode, + const std::vector& perm) + : Type(kTensorViewNV), + dim_id_(dim), + has_dimensions_id_(clamp_mode), + perm_(perm) {} + +std::string TensorViewNV::str() const { + std::ostringstream oss; + oss << "<" << dim_id_ << ", " << has_dimensions_id_; + for (auto p : perm_) { + oss << ", " << p; + } + oss << ">"; + return oss.str(); +} + +size_t TensorViewNV::ComputeExtraStateHash(size_t hash, SeenTypes*) const { + return hash_combine(hash, dim_id_, has_dimensions_id_, perm_); +} + +bool TensorViewNV::IsSameImpl(const Type* that, IsSameCache*) const { + const TensorViewNV* tv = that->AsTensorViewNV(); + if (!tv) return false; + return dim_id_ == tv->dim_id_ && + has_dimensions_id_ == tv->has_dimensions_id_ && perm_ == tv->perm_; } } // namespace analysis diff --git a/source/opt/types.h b/source/opt/types.h index 16a948cec..6092c3c92 100644 --- a/source/opt/types.h +++ b/source/opt/types.h @@ -63,6 +63,8 @@ class CooperativeMatrixNV; class CooperativeMatrixKHR; class RayQueryKHR; class HitObjectNV; +class TensorLayoutNV; +class TensorViewNV; // Abstract class for a SPIR-V type. It has a bunch of As() methods, // which is used as a way to probe the actual . @@ -104,6 +106,8 @@ class Type { kCooperativeMatrixKHR, kRayQueryKHR, kHitObjectNV, + kTensorLayoutNV, + kTensorViewNV, kLast }; @@ -206,6 +210,8 @@ class Type { DeclareCastMethod(CooperativeMatrixKHR) DeclareCastMethod(RayQueryKHR) DeclareCastMethod(HitObjectNV) + DeclareCastMethod(TensorLayoutNV) + DeclareCastMethod(TensorViewNV) #undef DeclareCastMethod protected: @@ -659,6 +665,53 @@ class CooperativeMatrixKHR : public Type { const uint32_t use_id_; }; +class TensorLayoutNV : public Type { + public: + TensorLayoutNV(const uint32_t dim, const uint32_t clamp_mode); + TensorLayoutNV(const TensorLayoutNV&) = default; + + std::string str() const override; + + TensorLayoutNV* AsTensorLayoutNV() override { return this; } + const TensorLayoutNV* AsTensorLayoutNV() const override { return this; } + + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; + + uint32_t dim_id() const { return dim_id_; } + uint32_t clamp_mode_id() const { return clamp_mode_id_; } + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const uint32_t dim_id_; + const uint32_t clamp_mode_id_; +}; + +class TensorViewNV : public Type { + public: + TensorViewNV(const uint32_t dim, const uint32_t clamp_mode, + const std::vector& perm); + TensorViewNV(const TensorViewNV&) = default; + + std::string str() const override; + + TensorViewNV* AsTensorViewNV() override { return this; } + const TensorViewNV* AsTensorViewNV() const override { return this; } + + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; + + uint32_t dim_id() const { return dim_id_; } + uint32_t has_dimensions_id() const { return has_dimensions_id_; } + const std::vector& perm() const { return perm_; } + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const uint32_t dim_id_; + const uint32_t has_dimensions_id_; + std::vector perm_; +}; + #define DefineParameterlessType(type, name) \ class type : public Type { \ public: \ diff --git a/source/text.cpp b/source/text.cpp index fda46ec2e..1723bb317 100644 --- a/source/text.cpp +++ b/source/text.cpp @@ -414,7 +414,9 @@ spv_result_t spvTextEncodeOperand(const spvtools::AssemblyGrammar& grammar, case SPV_OPERAND_TYPE_SELECTION_CONTROL: case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: - case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: { + case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: + case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS: + case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE: { uint32_t value; if (auto error = grammar.parseMaskOperand(type, textValue, &value)) { return context->diagnostic(error) diff --git a/source/val/validate.cpp b/source/val/validate.cpp index 32368075c..2d10347d8 100644 --- a/source/val/validate.cpp +++ b/source/val/validate.cpp @@ -366,6 +366,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState( if (auto error = RayTracingPass(*vstate, &instruction)) return error; if (auto error = RayReorderNVPass(*vstate, &instruction)) return error; if (auto error = MeshShadingPass(*vstate, &instruction)) return error; + if (auto error = TensorLayoutPass(*vstate, &instruction)) return error; } // Validate the preconditions involving adjacent instructions. e.g. diff --git a/source/val/validate.h b/source/val/validate.h index 78093ce5f..5514ff738 100644 --- a/source/val/validate.h +++ b/source/val/validate.h @@ -226,6 +226,9 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst); /// Calculates the reachability of basic blocks. void ReachabilityPass(ValidationState_t& _); +/// Validates tensor layout and view instructions. +spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst); + /// Validates execution limitations. /// /// Verifies execution models are allowed for all functionality they contain. diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp index b608a8595..8b0049c5b 100644 --- a/source/val/validate_arithmetics.cpp +++ b/source/val/validate_arithmetics.cpp @@ -62,7 +62,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { << operand_index; } spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, type_id, result_type); + _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false); if (ret != SPV_SUCCESS) return ret; } else if (_.GetOperandTypeId(inst, operand_index) != result_type) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -96,7 +96,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { << operand_index; } spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, type_id, result_type); + _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false); if (ret != SPV_SUCCESS) return ret; } else if (_.GetOperandTypeId(inst, operand_index) != result_type) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -142,7 +142,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { << operand_index; } spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, type_id, result_type); + _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false); if (ret != SPV_SUCCESS) return ret; } @@ -672,6 +672,128 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { break; } + case spv::Op::OpCooperativeMatrixReduceNV: { + if (!_.IsCooperativeMatrixKHRType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type must be a cooperative matrix type: " + << spvOpcodeString(opcode); + } + + const auto result_comp_type_id = + _.FindDef(result_type)->GetOperandAs(1); + + const auto matrix_id = inst->GetOperandAs(2); + const auto matrix = _.FindDef(matrix_id); + const auto matrix_type_id = matrix->type_id(); + if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Matrix must have a cooperative matrix type: " + << spvOpcodeString(opcode); + } + const auto matrix_type = _.FindDef(matrix_type_id); + const auto matrix_comp_type_id = matrix_type->GetOperandAs(1); + if (matrix_comp_type_id != result_comp_type_id) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type and Matrix type must have the same component " + "type: " + << spvOpcodeString(opcode); + } + if (_.FindDef(result_type)->GetOperandAs(2) != + matrix_type->GetOperandAs(2)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type and Matrix type must have the same scope: " + << spvOpcodeString(opcode); + } + + if (!_.IsCooperativeMatrixAccType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type must have UseAccumulator: " + << spvOpcodeString(opcode); + } + if (!_.IsCooperativeMatrixAccType(matrix_type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Matrix type must have UseAccumulator: " + << spvOpcodeString(opcode); + } + + const auto reduce_value = inst->GetOperandAs(3); + + if ((reduce_value & + uint32_t( + spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) && + (reduce_value & uint32_t(spv::CooperativeMatrixReduceMask::Row | + spv::CooperativeMatrixReduceMask::Column))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Reduce 2x2 must not be used with Row/Column: " + << spvOpcodeString(opcode); + } + + std::tuple result_rows, result_cols, matrix_rows, + matrix_cols; + result_rows = + _.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs(3)); + result_cols = + _.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs(4)); + matrix_rows = _.EvalInt32IfConst(matrix_type->GetOperandAs(3)); + matrix_cols = _.EvalInt32IfConst(matrix_type->GetOperandAs(4)); + + if (reduce_value & + uint32_t( + spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) { + if (std::get<1>(result_rows) && std::get<1>(result_cols) && + std::get<1>(matrix_rows) && std::get<1>(matrix_cols) && + (std::get<2>(result_rows) != std::get<2>(matrix_rows) / 2 || + std::get<2>(result_cols) != std::get<2>(matrix_cols) / 2)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "For Reduce2x2, result rows/cols must be half of matrix " + "rows/cols: " + << spvOpcodeString(opcode); + } + } + if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Row)) { + if (std::get<1>(result_rows) && std::get<1>(matrix_rows) && + std::get<2>(result_rows) != std::get<2>(matrix_rows)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "For ReduceRow, result rows must match matrix rows: " + << spvOpcodeString(opcode); + } + } + if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Column)) { + if (std::get<1>(result_cols) && std::get<1>(matrix_cols) && + std::get<2>(result_cols) != std::get<2>(matrix_cols)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "For ReduceColumn, result cols must match matrix cols: " + << spvOpcodeString(opcode); + } + } + + const auto combine_func_id = inst->GetOperandAs(4); + const auto combine_func = _.FindDef(combine_func_id); + if (!combine_func || combine_func->opcode() != spv::Op::OpFunction) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "CombineFunc must be a function: " << spvOpcodeString(opcode); + } + const auto function_type_id = combine_func->GetOperandAs(3); + const auto function_type = _.FindDef(function_type_id); + if (function_type->operands().size() != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "CombineFunc must have two parameters: " + << spvOpcodeString(opcode); + } + for (uint32_t i = 0; i < 3; ++i) { + // checks return type and two params + const auto param_type_id = function_type->GetOperandAs(i + 1); + if (param_type_id != matrix_comp_type_id) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "CombineFunc return type and parameters must match matrix " + "component type: " + << spvOpcodeString(opcode); + } + } + + break; + } + default: break; } diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp index b2892a863..770b8e2e3 100644 --- a/source/val/validate_conversion.cpp +++ b/source/val/validate_conversion.cpp @@ -49,7 +49,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { if (_.IsCooperativeMatrixType(result_type) || _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; } else { if (_.GetDimension(result_type) != _.GetDimension(input_type)) @@ -79,7 +79,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { if (_.IsCooperativeMatrixType(result_type) || _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; } else { if (_.GetDimension(result_type) != _.GetDimension(input_type)) @@ -111,7 +111,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { if (_.IsCooperativeMatrixType(result_type) || _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; } else { if (_.GetDimension(result_type) != _.GetDimension(input_type)) @@ -142,7 +142,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { if (_.IsCooperativeMatrixType(result_type) || _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; } else { if (_.GetDimension(result_type) != _.GetDimension(input_type)) @@ -177,7 +177,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { if (_.IsCooperativeMatrixType(result_type) || _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; } else { if (_.GetDimension(result_type) != _.GetDimension(input_type)) @@ -213,7 +213,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { if (_.IsCooperativeMatrixType(result_type) || _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; } else { if (_.GetDimension(result_type) != _.GetDimension(input_type)) @@ -497,8 +497,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { << "matrix: " << spvOpcodeString(opcode); if (result_is_coopmat) { - spv_result_t ret = - _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type, + input_type, false); if (ret != SPV_SUCCESS) return ret; } @@ -568,6 +568,43 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { break; } + case spv::Op::OpCooperativeMatrixConvertNV: + case spv::Op::OpCooperativeMatrixTransposeNV: { + if (!_.IsCooperativeMatrixType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected cooperative matrix Result Type: " + << spvOpcodeString(opcode); + } + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!_.IsCooperativeMatrixType(input_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected cooperative matrix type for Matrix input: " + << spvOpcodeString(opcode); + } + + bool swap_row_col = (opcode == spv::Op::OpCooperativeMatrixTransposeNV); + if (auto error = _.CooperativeMatrixShapesMatch( + inst, result_type, input_type, true, swap_row_col)) + return error; + + if (opcode == spv::Op::OpCooperativeMatrixConvertNV) { + if (_.FindDef(result_type)->GetOperandAs(1) != + _.FindDef(input_type)->GetOperandAs(1)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type and Matrix component types mismatch: " + << spvOpcodeString(opcode); + } + } + + if (opcode == spv::Op::OpCooperativeMatrixTransposeNV) { + if (!_.IsCooperativeMatrixBType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type must have UseB: " << spvOpcodeString(opcode); + } + } + break; + } + default: break; } diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp index 26b366828..67758c66f 100644 --- a/source/val/validate_function.cpp +++ b/source/val/validate_function.cpp @@ -86,7 +86,10 @@ spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple, spv::Op::OpGetKernelLocalSizeForSubgroupCount, spv::Op::OpGetKernelMaxNumSubgroups, - spv::Op::OpName}; + spv::Op::OpName, + spv::Op::OpCooperativeMatrixPerElementOpNV, + spv::Op::OpCooperativeMatrixReduceNV, + spv::Op::OpCooperativeMatrixLoadTensorNV}; for (auto& pair : inst->uses()) { const auto* use = pair.first; if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == @@ -341,6 +344,80 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _, return SPV_SUCCESS; } +spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _, + const Instruction* inst) { + const auto function_id = inst->GetOperandAs(3); + const auto function = _.FindDef(function_id); + if (!function || spv::Op::OpFunction != function->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpCooperativeMatrixPerElementOpNV Function " + << _.getIdName(function_id) << " is not a function."; + } + + const auto matrix_id = inst->GetOperandAs(2); + const auto matrix = _.FindDef(matrix_id); + const auto matrix_type_id = matrix->type_id(); + if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpCooperativeMatrixPerElementOpNV Matrix " + << _.getIdName(matrix_id) << " is not a cooperative matrix."; + } + + const auto result_type_id = inst->GetOperandAs(0); + if (matrix_type_id != result_type_id) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpCooperativeMatrixPerElementOpNV Result Type " + << _.getIdName(result_type_id) << " must match matrix type " + << _.getIdName(matrix_type_id) << "."; + } + + const auto matrix_comp_type_id = + _.FindDef(matrix_type_id)->GetOperandAs(1); + const auto function_type_id = function->GetOperandAs(3); + const auto function_type = _.FindDef(function_type_id); + auto return_type_id = function_type->GetOperandAs(1); + if (return_type_id != matrix_comp_type_id) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpCooperativeMatrixPerElementOpNV function return type " + << _.getIdName(return_type_id) + << " must match matrix component type " + << _.getIdName(matrix_comp_type_id) << "."; + } + + if (function_type->operands().size() < 5) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpCooperativeMatrixPerElementOpNV function type " + << _.getIdName(function_type_id) + << " must have a least three parameters."; + } + + const auto param0_id = function_type->GetOperandAs(2); + const auto param1_id = function_type->GetOperandAs(3); + const auto param2_id = function_type->GetOperandAs(4); + if (!_.IsIntScalarType(param0_id) || _.GetBitWidth(param0_id) != 32) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpCooperativeMatrixPerElementOpNV function type first parameter " + "type " + << _.getIdName(param0_id) << " must be a 32-bit integer."; + } + + if (!_.IsIntScalarType(param1_id) || _.GetBitWidth(param1_id) != 32) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpCooperativeMatrixPerElementOpNV function type second " + "parameter type " + << _.getIdName(param1_id) << " must be a 32-bit integer."; + } + + if (param2_id != matrix_comp_type_id) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpCooperativeMatrixPerElementOpNV function type third parameter " + "type " + << _.getIdName(param2_id) << " must match matrix component type."; + } + + return SPV_SUCCESS; +} + } // namespace spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { @@ -354,6 +431,10 @@ spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpFunctionCall: if (auto error = ValidateFunctionCall(_, inst)) return error; break; + case spv::Op::OpCooperativeMatrixPerElementOpNV: + if (auto error = ValidateCooperativeMatrixPerElementOp(_, inst)) + return error; + break; default: break; } diff --git a/source/val/validate_instruction.cpp b/source/val/validate_instruction.cpp index 5bc4d2cef..39b1c0216 100644 --- a/source/val/validate_instruction.cpp +++ b/source/val/validate_instruction.cpp @@ -475,6 +475,12 @@ spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) { const uint32_t entry_point = inst->word(1); _.RegisterExecutionModeForEntryPoint(entry_point, spv::ExecutionMode(inst->word(2))); + if (inst->GetOperandAs(1) == + spv::ExecutionMode::LocalSize || + inst->GetOperandAs(1) == + spv::ExecutionMode::LocalSizeId) { + _.RegisterEntryPointLocalSize(entry_point, inst); + } } else if (opcode == spv::Op::OpVariable) { const auto storage_class = inst->GetOperandAs(2); if (auto error = LimitCheckNumVars(_, inst->id(), storage_class)) { diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index 830c3ea11..ef05d6e6c 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp @@ -233,6 +233,7 @@ std::pair GetStorageClass( spv::StorageClass src_sc = spv::StorageClass::Max; switch (inst->opcode()) { case spv::Op::OpCooperativeMatrixLoadNV: + case spv::Op::OpCooperativeMatrixLoadTensorNV: case spv::Op::OpCooperativeMatrixLoadKHR: case spv::Op::OpLoad: { auto load_pointer = _.FindDef(inst->GetOperandAs(2)); @@ -241,6 +242,7 @@ std::pair GetStorageClass( break; } case spv::Op::OpCooperativeMatrixStoreNV: + case spv::Op::OpCooperativeMatrixStoreTensorNV: case spv::Op::OpCooperativeMatrixStoreKHR: case spv::Op::OpStore: { auto store_pointer = _.FindDef(inst->GetOperandAs(0)); @@ -330,6 +332,7 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst, if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) { if (inst->opcode() == spv::Op::OpLoad || inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV || + inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV || inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "MakePointerAvailableKHR cannot be used with OpLoad."; @@ -350,7 +353,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst, if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) { if (inst->opcode() == spv::Op::OpStore || inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV || - inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR) { + inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR || + inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "MakePointerVisibleKHR cannot be used with OpStore."; } @@ -2176,6 +2180,222 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _, return SPV_SUCCESS; } +// Returns the number of instruction words taken up by a tensor addressing +// operands argument and its implied operands. +int TensorAddressingOperandsNumWords(spv::TensorAddressingOperandsMask mask) { + int result = 1; // Count the mask + if ((mask & spv::TensorAddressingOperandsMask::TensorView) != + spv::TensorAddressingOperandsMask::MaskNone) + ++result; + if ((mask & spv::TensorAddressingOperandsMask::DecodeFunc) != + spv::TensorAddressingOperandsMask::MaskNone) + ++result; + return result; +} + +spv_result_t ValidateCooperativeMatrixLoadStoreTensorNV( + ValidationState_t& _, const Instruction* inst) { + uint32_t type_id; + const char* opname; + if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) { + type_id = inst->type_id(); + opname = "spv::Op::OpCooperativeMatrixLoadTensorNV"; + } else { + // get Object operand's type + type_id = _.FindDef(inst->GetOperandAs(1))->type_id(); + opname = "spv::Op::OpCooperativeMatrixStoreTensorNV"; + } + + auto matrix_type = _.FindDef(type_id); + + if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) { + if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "spv::Op::OpCooperativeMatrixLoadTensorNV Result Type " + << _.getIdName(type_id) << " is not a cooperative matrix type."; + } else { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "spv::Op::OpCooperativeMatrixStoreTensorNV Object type " + << _.getIdName(type_id) << " is not a cooperative matrix type."; + } + } + + const auto pointer_index = + (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 2u : 0u; + const auto pointer_id = inst->GetOperandAs(pointer_index); + const auto pointer = _.FindDef(pointer_id); + if (!pointer || + ((_.addressing_model() == spv::AddressingModel::Logical) && + ((!_.features().variable_pointers && + !spvOpcodeReturnsLogicalPointer(pointer->opcode())) || + (_.features().variable_pointers && + !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " Pointer " << _.getIdName(pointer_id) + << " is not a logical pointer."; + } + + const auto pointer_type_id = pointer->type_id(); + const auto pointer_type = _.FindDef(pointer_type_id); + if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " type for pointer " << _.getIdName(pointer_id) + << " is not a pointer type."; + } + + const auto storage_class_index = 1u; + const auto storage_class = + pointer_type->GetOperandAs(storage_class_index); + + if (storage_class != spv::StorageClass::Workgroup && + storage_class != spv::StorageClass::StorageBuffer && + storage_class != spv::StorageClass::PhysicalStorageBuffer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << _.VkErrorID(8973) << opname + << " storage class for pointer type " + << _.getIdName(pointer_type_id) + << " is not Workgroup, StorageBuffer, or PhysicalStorageBuffer."; + } + + if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) { + const auto object_index = 3; + const auto object_id = inst->GetOperandAs(object_index); + const auto object = _.FindDef(object_id); + if (!object || object->type_id() != type_id) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " Object " << _.getIdName(object_id) + << " type does not match Result Type."; + } + } + + const auto tensor_layout_index = + (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 4u : 2u; + const auto tensor_layout_id = + inst->GetOperandAs(tensor_layout_index); + const auto tensor_layout = _.FindDef(tensor_layout_id); + if (!tensor_layout || _.FindDef(tensor_layout->type_id())->opcode() != + spv::Op::OpTypeTensorLayoutNV) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " TensorLayout " << _.getIdName(tensor_layout_id) + << " does not have a tensor layout type."; + } + + const auto memory_access_index = + (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 5u : 3u; + if (inst->operands().size() > memory_access_index) { + if (auto error = CheckMemoryAccess(_, inst, memory_access_index)) + return error; + } + + const auto memory_access_mask = + inst->GetOperandAs(memory_access_index); + const auto tensor_operands_index = + memory_access_index + MemoryAccessNumWords(memory_access_mask); + const auto tensor_operands = + inst->GetOperandAs( + tensor_operands_index); + + if (inst->operands().size() < + tensor_operands_index + + TensorAddressingOperandsNumWords(tensor_operands)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " not enough tensor addressing operands."; + } + + uint32_t tensor_operand_index = tensor_operands_index + 1; + if ((tensor_operands & spv::TensorAddressingOperandsMask::TensorView) != + spv::TensorAddressingOperandsMask::MaskNone) { + const auto tensor_view_id = + inst->GetOperandAs(tensor_operand_index); + const auto tensor_view = _.FindDef(tensor_view_id); + if (!tensor_view || _.FindDef(tensor_view->type_id())->opcode() != + spv::Op::OpTypeTensorViewNV) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " TensorView " << _.getIdName(tensor_view_id) + << " does not have a tensor view type."; + } + + tensor_operand_index++; + } + + if ((tensor_operands & spv::TensorAddressingOperandsMask::DecodeFunc) != + spv::TensorAddressingOperandsMask::MaskNone) { + if (inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpCooperativeMatrixStoreTensorNV does not support DecodeFunc."; + } + const auto decode_func_id = + inst->GetOperandAs(tensor_operand_index); + const auto decode_func = _.FindDef(decode_func_id); + + if (!decode_func || decode_func->opcode() != spv::Op::OpFunction) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " DecodeFunc " << _.getIdName(decode_func_id) + << " is not a function."; + } + + const auto component_type_index = 1; + const auto component_type_id = + matrix_type->GetOperandAs(component_type_index); + + const auto function_type = + _.FindDef(decode_func->GetOperandAs(3)); + if (function_type->GetOperandAs(1) != component_type_id) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " DecodeFunc " << _.getIdName(decode_func_id) + << " return type must match matrix component type."; + } + + const auto decode_ptr_type_id = function_type->GetOperandAs(2); + const auto decode_ptr_type = _.FindDef(decode_ptr_type_id); + auto decode_storage_class = + decode_ptr_type->GetOperandAs(storage_class_index); + + if (decode_storage_class != spv::StorageClass::PhysicalStorageBuffer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " DecodeFunc " << _.getIdName(decode_func_id) + << " first parameter must be pointer to PhysicalStorageBuffer."; + } + + const auto tensor_layout_type = _.FindDef(tensor_layout->type_id()); + + for (uint32_t param = 3; param < 5; ++param) { + const auto param_type_id = function_type->GetOperandAs(param); + const auto param_type = _.FindDef(param_type_id); + if (param_type->opcode() != spv::Op::OpTypeArray) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " DecodeFunc " << _.getIdName(decode_func_id) + << " second/third parameter must be array of 32-bit integer " + "with " + << " dimension equal to the tensor dimension."; + } + const auto length_index = 2u; + uint64_t array_length; + if (_.EvalConstantValUint64( + param_type->GetOperandAs(length_index), + &array_length)) { + const auto tensor_layout_dim_id = + tensor_layout_type->GetOperandAs(1); + uint64_t dim_value; + if (_.EvalConstantValUint64(tensor_layout_dim_id, &dim_value)) { + if (array_length != dim_value) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " DecodeFunc " + << _.getIdName(decode_func_id) + << " second/third parameter must be array of 32-bit integer " + "with " + << " dimension equal to the tensor dimension."; + } + } + } + } + + tensor_operand_index++; + } + + return SPV_SUCCESS; +} + spv_result_t ValidatePtrComparison(ValidationState_t& _, const Instruction* inst) { if (_.addressing_model() == spv::AddressingModel::Logical && @@ -2284,6 +2504,11 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) { if (auto error = ValidateCooperativeMatrixLoadStoreKHR(_, inst)) return error; break; + case spv::Op::OpCooperativeMatrixLoadTensorNV: + case spv::Op::OpCooperativeMatrixStoreTensorNV: + if (auto error = ValidateCooperativeMatrixLoadStoreTensorNV(_, inst)) + return error; + break; case spv::Op::OpPtrEqual: case spv::Op::OpPtrNotEqual: case spv::Op::OpPtrDiff: diff --git a/source/val/validate_tensor_layout.cpp b/source/val/validate_tensor_layout.cpp new file mode 100644 index 000000000..35c766b83 --- /dev/null +++ b/source/val/validate_tensor_layout.cpp @@ -0,0 +1,184 @@ +// Copyright (c) 2024 NVIDIA Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validate instructions that manipulate tensor layout and view objects + +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/instruction.h" +#include "source/val/validate.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateTensorLayoutResultTypeNV(ValidationState_t& _, + const Instruction* inst) { + const auto result_type_index = 0; + const auto result_type_id = inst->GetOperandAs(result_type_index); + const auto result_type = _.FindDef(result_type_id); + + if (!result_type || spv::Op::OpTypeTensorLayoutNV != result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " Result Type " + << _.getIdName(result_type_id) << " is not a tensor layout type."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTensorViewResultTypeNV(ValidationState_t& _, + const Instruction* inst) { + const auto result_type_index = 0; + const auto result_type_id = inst->GetOperandAs(result_type_index); + const auto result_type = _.FindDef(result_type_id); + + if (!result_type || spv::Op::OpTypeTensorViewNV != result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " Result Type " + << _.getIdName(result_type_id) << " is not a tensor view type."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCreateTensorLayoutNV(ValidationState_t& _, + const Instruction* inst) { + if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error; + + return SPV_SUCCESS; +} + +spv_result_t ValidateCreateTensorViewNV(ValidationState_t& _, + const Instruction* inst) { + if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error; + + return SPV_SUCCESS; +} + +enum ExpectedNumValues { + DIM, + DIMx2, + ONE, + FOUR, +}; + +spv_result_t ValidateTensorTypeWithDimValuesNV(ValidationState_t& _, + const Instruction* inst, + ExpectedNumValues expected, + bool is_view) { + std::string type_str; + if (is_view) { + if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error; + type_str = "TensorView"; + } else { + if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error; + type_str = "TensorLayout"; + } + + const auto result_type_id = inst->GetOperandAs(0); + const auto tensor_id = inst->GetOperandAs(2); + const auto tensor = _.FindDef(tensor_id); + if (!tensor || result_type_id != tensor->type_id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " Result Type " + << _.getIdName(result_type_id) << " does not match " << type_str + << " type."; + } + + const auto num_values = inst->operands().size() - 3; + + const auto result_type = _.FindDef(result_type_id); + const auto dim_index = 1; + const auto dim_id = result_type->GetOperandAs(dim_index); + uint64_t dim_value; + if (_.EvalConstantValUint64(dim_id, &dim_value)) { + uint64_t expected_num_values = 0; + switch (expected) { + case DIM: + expected_num_values = dim_value; + break; + case DIMx2: + expected_num_values = dim_value * 2; + break; + case ONE: + expected_num_values = 1; + break; + case FOUR: + expected_num_values = 4; + break; + } + + if (num_values != expected_num_values) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) + << " unexpected number of operands."; + } + } + + for (uint32_t i = 0; i < num_values; ++i) { + const auto val_id = inst->GetOperandAs(i + 3); + const auto val = _.FindDef(val_id); + if (!val || !_.IsIntScalarType(val->type_id()) || + _.GetBitWidth(val->type_id()) != 32) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " operand " + << _.getIdName(val_id) << " is not a 32-bit integer."; + } + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case spv::Op::OpCreateTensorLayoutNV: + if (auto error = ValidateCreateTensorLayoutNV(_, inst)) return error; + break; + case spv::Op::OpCreateTensorViewNV: + if (auto error = ValidateCreateTensorViewNV(_, inst)) return error; + break; + case spv::Op::OpTensorLayoutSetBlockSizeNV: + case spv::Op::OpTensorLayoutSetDimensionNV: + case spv::Op::OpTensorLayoutSetStrideNV: + if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, false)) + return error; + break; + case spv::Op::OpTensorLayoutSliceNV: + if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIMx2, false)) + return error; + break; + case spv::Op::OpTensorLayoutSetClampValueNV: + if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, ONE, false)) + return error; + break; + case spv::Op::OpTensorViewSetDimensionNV: + case spv::Op::OpTensorViewSetStrideNV: + if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, true)) + return error; + break; + case spv::Op::OpTensorViewSetClipNV: + if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, FOUR, true)) + return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp index 32024b735..5101a40c4 100644 --- a/source/val/validate_type.cpp +++ b/source/val/validate_type.cpp @@ -1,4 +1,5 @@ // Copyright (c) 2018 Google LLC. +// Copyright (c) 2024 NVIDIA Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -582,6 +583,37 @@ spv_result_t ValidateTypeCooperativeMatrix(ValidationState_t& _, } } + uint64_t scope_value; + if (_.EvalConstantValUint64(scope_id, &scope_value)) { + if (scope_value == static_cast(spv::Scope::Workgroup)) { + for (auto entry_point_id : _.entry_points()) { + if (!_.EntryPointHasLocalSizeOrId(entry_point_id)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeCooperativeMatrixKHR with ScopeWorkgroup " + << "used without specifying LocalSize or LocalSizeId " + << "for entry point " << _.getIdName(entry_point_id); + } + const auto local_size = _.EntryPointLocalSizeOrId(entry_point_id); + const auto mode = local_size->GetOperandAs(1); + if (mode == spv::ExecutionMode::LocalSizeId) { + uint32_t local_size_ids[3] = { + local_size->GetOperandAs(2), + local_size->GetOperandAs(3), + local_size->GetOperandAs(4), + }; + for (auto id : local_size_ids) { + if (_.FindDef(id) > inst) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeCooperativeMatrixKHR with ScopeWorkgroup " + << "used before LocalSizeId constant value " + << _.getIdName(id) << " is defined."; + } + } + } + } + } + } + return SPV_SUCCESS; } @@ -611,6 +643,115 @@ spv_result_t ValidateTypeUntypedPointerKHR(ValidationState_t& _, } return SPV_SUCCESS; } + +spv_result_t ValidateTensorDim(ValidationState_t& _, const Instruction* inst) { + const auto dim_index = 1; + const auto dim_id = inst->GetOperandAs(dim_index); + const auto dim = _.FindDef(dim_id); + if (!dim || !_.IsIntScalarType(dim->type_id()) || + _.GetBitWidth(dim->type_id()) != 32) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " Dim " + << _.getIdName(dim_id) << " is not a 32-bit integer."; + } + + constexpr uint32_t max_tensor_dim = 5; + + uint64_t dim_value; + if (_.EvalConstantValUint64(dim_id, &dim_value)) { + if (dim_value == 0 || dim_value > max_tensor_dim) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " Dim " + << _.getIdName(dim_id) << " must be between 1 and " + << max_tensor_dim << "."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeTensorLayoutNV(ValidationState_t& _, + const Instruction* inst) { + if (auto error = ValidateTensorDim(_, inst)) return error; + + const auto clamp_index = 2; + const auto clamp_id = inst->GetOperandAs(clamp_index); + const auto clamp = _.FindDef(clamp_id); + if (!clamp || !_.IsIntScalarType(clamp->type_id()) || + _.GetBitWidth(clamp->type_id()) != 32) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " ClampMode " + << _.getIdName(clamp_id) << " is not a 32-bit integer."; + } + + uint64_t clamp_value; + if (_.EvalConstantValUint64(clamp_id, &clamp_value)) { + if (clamp_value > + static_cast(spv::TensorClampMode::RepeatMirrored)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " ClampMode " + << _.getIdName(clamp_id) << " must be a valid TensorClampMode."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeTensorViewNV(ValidationState_t& _, + const Instruction* inst) { + if (auto error = ValidateTensorDim(_, inst)) return error; + + const auto has_dim_index = 2; + const auto has_dim_id = inst->GetOperandAs(has_dim_index); + const auto has_dim = _.FindDef(has_dim_id); + if (!has_dim || !_.IsBoolScalarType(has_dim->type_id())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " HasDimensions " + << _.getIdName(has_dim_id) << " is not a boolean value."; + } + + uint32_t permutation_mask = 0; + bool all_constant = true; + const auto num_dim = inst->operands().size() - 3; + for (size_t p_index = 3; p_index < inst->operands().size(); ++p_index) { + auto p_id = inst->GetOperandAs(p_index); + const auto p = _.FindDef(p_id); + if (!p || !_.IsIntScalarType(p->type_id()) || + _.GetBitWidth(p->type_id()) != 32) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " Permutation " + << _.getIdName(p_id) << " is not a 32-bit integer."; + } + + uint64_t p_value; + if (_.EvalConstantValUint64(p_id, &p_value)) { + if (p_value >= num_dim) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) << " Permutation " + << _.getIdName(p_id) << " must be a valid dimension."; + } + permutation_mask |= 1 << p_value; + } else { + all_constant = false; + } + } + if (all_constant && permutation_mask != (1U << num_dim) - 1U) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) + << " Permutation values don't form a valid permutation."; + } + + uint64_t dim_value; + if (_.EvalConstantValUint64(inst->GetOperandAs(1), &dim_value)) { + if (dim_value != num_dim) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << spvOpcodeString(inst->opcode()) + << " Incorrect number of permutation values."; + } + } + + return SPV_SUCCESS; +} } // namespace spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) { @@ -659,6 +800,12 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpTypeUntypedPointerKHR: if (auto error = ValidateTypeUntypedPointerKHR(_, inst)) return error; break; + case spv::Op::OpTypeTensorLayoutNV: + if (auto error = ValidateTypeTensorLayoutNV(_, inst)) return error; + break; + case spv::Op::OpTypeTensorViewNV: + if (auto error = ValidateTypeTensorViewNV(_, inst)) return error; + break; default: break; } diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 2afcacc33..da9174fa5 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -1290,8 +1290,9 @@ bool ValidationState_t::IsUnsigned64BitHandle(uint32_t id) const { } spv_result_t ValidationState_t::CooperativeMatrixShapesMatch( - const Instruction* inst, uint32_t m1, uint32_t m2) { - const auto m1_type = FindDef(m1); + const Instruction* inst, uint32_t result_type_id, uint32_t m2, + bool is_conversion, bool swap_row_col) { + const auto m1_type = FindDef(result_type_id); const auto m2_type = FindDef(m2); if (m1_type->opcode() != m2_type->opcode()) { @@ -1307,6 +1308,10 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch( uint32_t m2_rows_id = m2_type->GetOperandAs(3); uint32_t m2_cols_id = m2_type->GetOperandAs(4); + if (swap_row_col) { + std::swap(m1_rows_id, m1_cols_id); + } + bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false, m2_is_const_int32 = false; uint32_t m1_value = 0, m2_value = 0; @@ -1330,7 +1335,7 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch( if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) { return diag(SPV_ERROR_INVALID_DATA, inst) << "Expected rows of Matrix type and Result Type to be " - << "identical"; + << (swap_row_col ? "swapped with columns" : "identical"); } std::tie(m1_is_int32, m1_is_const_int32, m1_value) = @@ -1341,7 +1346,7 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch( if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) { return diag(SPV_ERROR_INVALID_DATA, inst) << "Expected columns of Matrix type and Result Type to be " - << "identical"; + << (swap_row_col ? "swapped with rows" : "identical"); } if (m1_type->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) { @@ -1352,7 +1357,12 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch( std::tie(m2_is_int32, m2_is_const_int32, m2_value) = EvalInt32IfConst(m2_use_id); - if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) { + if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value && + // CooperativeMatrixConversionsNV allows conversions from Acc->A/B + !(is_conversion && + HasCapability(spv::Capability::CooperativeMatrixConversionsNV) && + m2_value == + (uint32_t)spv::CooperativeMatrixUse::MatrixAccumulatorKHR)) { return diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Use of Matrix type and Result Type to be " << "identical"; diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 372b5b7b9..44551add3 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -240,6 +240,21 @@ class ValidationState_t { entry_point_to_execution_modes_[entry_point].insert(execution_mode); } + /// Registers that the entry point declares its local size + void RegisterEntryPointLocalSize(uint32_t entry_point, + const Instruction* inst) { + entry_point_to_local_size_or_id_[entry_point] = inst; + } + /// Returns whether the entry point declares its local size + bool EntryPointHasLocalSizeOrId(uint32_t entry_point) const { + return entry_point_to_local_size_or_id_.find(entry_point) != + entry_point_to_local_size_or_id_.end(); + } + /// Returns the id of the local size + const Instruction* EntryPointLocalSizeOrId(uint32_t entry_point) const { + return entry_point_to_local_size_or_id_.find(entry_point)->second; + } + /// Returns the interface descriptions of a given entry point. const std::vector& entry_point_descriptions( uint32_t entry_point) { @@ -759,11 +774,14 @@ class ValidationState_t { return SpvDecorationString(uint32_t(decoration)); } - // Returns whether type m1 and type m2 are cooperative matrices with - // the same "shape" (matching scope, rows, cols). If any are specialization - // constants, we assume they can match because we can't prove they don't. + // Returns whether type result_type_id and type m2 are cooperative matrices + // with the same "shape" (matching scope, rows, cols). If any are + // specialization constants, we assume they can match because we can't prove + // they don't. spv_result_t CooperativeMatrixShapesMatch(const Instruction* inst, - uint32_t m1, uint32_t m2); + uint32_t result_type_id, + uint32_t m2, bool is_conversion, + bool swap_row_col = false); // Returns true if |lhs| and |rhs| logically match and, if the decorations of // |rhs| are a subset of |lhs|. @@ -949,6 +967,10 @@ class ValidationState_t { std::unordered_map> entry_point_to_execution_modes_; + // Mapping entry point -> local size execution mode instruction + std::unordered_map + entry_point_to_local_size_or_id_; + /// Mapping function -> array of entry points inside this /// module which can (indirectly) call the function. std::unordered_map> function_to_entry_points_; diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp index d4d0fef52..865bfbb07 100644 --- a/test/opt/type_manager_test.cpp +++ b/test/opt/type_manager_test.cpp @@ -175,6 +175,9 @@ std::vector> GenerateAllTypes() { types.emplace_back(new RayQueryKHR()); types.emplace_back(new HitObjectNV()); + types.emplace_back(new TensorLayoutNV(1002, 1000)); + types.emplace_back(new TensorViewNV(1002, 1003, {1000, 1001})); + return types; } @@ -1102,11 +1105,14 @@ OpMemoryModel Logical GLSL450 %uint = OpTypeInt 32 0 %1 = OpTypePointer Input %uint %2 = OpTypePointer Uniform %uint +%1000 = OpConstant %uint 0 +%1001 = OpConstant %uint 1 %1002 = OpConstant %uint 2 %8 = OpConstant %uint 8 %24 = OpConstant %uint 24 %42 = OpConstant %uint 42 %100 = OpConstant %uint 100 +%1003 = OpConstantFalse %bool )"; std::unique_ptr context = diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp index 58ac4423e..8b2a8d0b7 100644 --- a/test/val/val_arithmetics_test.cpp +++ b/test/val/val_arithmetics_test.cpp @@ -1280,14 +1280,14 @@ TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) { TEST_F(ValidateArithmetics, CoopMatScopeFail) { const std::string types = R"( -%workgroup = OpConstant %u32 2 +%device = OpConstant %u32 1 -%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16 -%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1 +%mat16x16_dv = OpTypeCooperativeMatrixNV %f16 %device %u32_16 %u32_16 +%f16matdv_16x16_1 = OpConstantComposite %mat16x16_dv %f16_1 )"; const std::string body = R"( -%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1 +%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matdv_16x16_1 )"; CompileSuccessfully(GenerateCoopMatCode(types, body).c_str()); @@ -1475,7 +1475,10 @@ std::string GenerateCoopMatKHRCode(const std::string& extra_types, OpCapability Shader OpCapability Float16 OpCapability CooperativeMatrixKHR +OpCapability CooperativeMatrixReductionsNV +OpCapability CooperativeMatrixPerElementOperationsNV OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_NV_cooperative_matrix2" OpExtension "SPV_KHR_vulkan_memory_model" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %main "main" @@ -1487,6 +1490,7 @@ OpEntryPoint GLCompute %main "main" %u32 = OpTypeInt 32 0 %s32 = OpTypeInt 32 1 +%u32_8 = OpConstant %u32 8 %u32_16 = OpConstant %u32 16 %u32_4 = OpConstant %u32 4 %subgroup = OpConstant %u32 3 @@ -1579,13 +1583,13 @@ TEST_F(ValidateArithmetics, CoopMatMatrixKHRTimesScalarMismatchFail) { TEST_F(ValidateArithmetics, CoopMatKHRScopeFail) { const std::string types = R"( -%workgroup = OpConstant %u32 2 -%mat16x16_wg = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %useC -%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1 +%device = OpConstant %u32 1 +%mat16x16_dv = OpTypeCooperativeMatrixKHR %f16 %device %u32_16 %u32_16 %useC +%f16matdv_16x16_1 = OpConstantComposite %mat16x16_dv %f16_1 )"; const std::string body = R"( -%val1 = OpFAdd %f16matA %f16matwg_16x16_1 %f16mat_A_1 +%val1 = OpFAdd %f16matA %f16matdv_16x16_1 %f16mat_A_1 )"; CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str()); @@ -1612,6 +1616,241 @@ TEST_F(ValidateArithmetics, CoopMatKHRDimFail) { HasSubstr("Cooperative matrix 'N' mismatch: CooperativeMatrixMulAddKHR")); } +TEST_F(ValidateArithmetics, CoopMat2ReduceSuccess) { + const std::string extra_types = R"( + +%f16matC8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %useC +%f16matC16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %useC +%f16matC8x16 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_16 %useC + +%functy = OpTypeFunction %f16 %f16 %f16 +%reducefunc = OpFunction %f16 None %functy +%x = OpFunctionParameter %f16 +%y = OpFunctionParameter %f16 +%entry2 = OpLabel +%sum = OpFAdd %f16 %x %y +OpReturnValue %sum +OpFunctionEnd + + )"; + const std::string body = R"( +%val1 = OpCooperativeMatrixReduceNV %f16matC8 %f16mat_C_1 2x2 %reducefunc +%val2 = OpCooperativeMatrixReduceNV %f16matC16x8 %f16mat_C_1 Row %reducefunc +%val3 = OpCooperativeMatrixReduceNV %f16matC8x16 %f16mat_C_1 Column %reducefunc +%val4 = OpCooperativeMatrixReduceNV %f16matC %f16mat_C_1 Row|Column %reducefunc +%val5 = OpCooperativeMatrixReduceNV %f16matC8 %f16mat_C_1 Row|Column %reducefunc +)"; + + CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, CoopMat2Reduce2x2DimFail) { + const std::string extra_types = R"( + +%functy = OpTypeFunction %f16 %f16 %f16 +%reducefunc = OpFunction %f16 None %functy +%x = OpFunctionParameter %f16 +%y = OpFunctionParameter %f16 +%entry2 = OpLabel +%sum = OpFAdd %f16 %x %y +OpReturnValue %sum +OpFunctionEnd + + )"; + const std::string body = R"( +%val1 = OpCooperativeMatrixReduceNV %f16matC %f16mat_C_1 2x2 %reducefunc +)"; + + CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For Reduce2x2, result rows/cols must be half of " + "matrix rows/cols: CooperativeMatrixReduceNV")); +} + +TEST_F(ValidateArithmetics, CoopMat2ReduceRowDimFail) { + const std::string extra_types = R"( + +%f16matC8x16 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_16 %useC + +%functy = OpTypeFunction %f16 %f16 %f16 +%reducefunc = OpFunction %f16 None %functy +%x = OpFunctionParameter %f16 +%y = OpFunctionParameter %f16 +%entry2 = OpLabel +%sum = OpFAdd %f16 %x %y +OpReturnValue %sum +OpFunctionEnd + + )"; + const std::string body = R"( +%val1 = OpCooperativeMatrixReduceNV %f16matC8x16 %f16mat_C_1 Row %reducefunc +)"; + + CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For ReduceRow, result rows must match matrix rows: " + "CooperativeMatrixReduceNV")); +} + +TEST_F(ValidateArithmetics, CoopMat2ReduceColDimFail) { + const std::string extra_types = R"( + +%f16matC16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %useC + +%functy = OpTypeFunction %f16 %f16 %f16 +%reducefunc = OpFunction %f16 None %functy +%x = OpFunctionParameter %f16 +%y = OpFunctionParameter %f16 +%entry2 = OpLabel +%sum = OpFAdd %f16 %x %y +OpReturnValue %sum +OpFunctionEnd + + )"; + const std::string body = R"( +%val1 = OpCooperativeMatrixReduceNV %f16matC16x8 %f16mat_C_1 Column %reducefunc +)"; + + CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For ReduceColumn, result cols must match matrix cols: " + "CooperativeMatrixReduceNV")); +} + +TEST_F(ValidateArithmetics, CoopMat2ReduceMaskFail) { + const std::string extra_types = R"( + +%f16matC8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %useC + +%functy = OpTypeFunction %f16 %f16 %f16 +%reducefunc = OpFunction %f16 None %functy +%x = OpFunctionParameter %f16 +%y = OpFunctionParameter %f16 +%entry2 = OpLabel +%sum = OpFAdd %f16 %x %y +OpReturnValue %sum +OpFunctionEnd + + )"; + const std::string body = R"( +%val1 = OpCooperativeMatrixReduceNV %f16matC8 %f16mat_C_1 Row|Column|2x2 %reducefunc +)"; + + CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Reduce 2x2 must not be used with Row/Column: " + "CooperativeMatrixReduceNV")); +} + +TEST_F(ValidateArithmetics, CoopMat2ReduceFuncTypeFail) { + const std::string extra_types = R"( + +%functy = OpTypeFunction %f32 %f32 %f32 +%reducefunc = OpFunction %f32 None %functy +%x = OpFunctionParameter %f32 +%y = OpFunctionParameter %f32 +%entry2 = OpLabel +%sum = OpFAdd %f32 %x %y +OpReturnValue %sum +OpFunctionEnd + + )"; + const std::string body = R"( +%val1 = OpCooperativeMatrixReduceNV %f16matC %f16mat_C_1 Row|Column %reducefunc +)"; + + CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("CombineFunc return type and parameters must match " + "matrix component type: CooperativeMatrixReduceNV")); +} + +TEST_F(ValidateArithmetics, CoopMat2PerElementOpSuccess) { + const std::string extra_types = R"( + +%functy = OpTypeFunction %f16 %u32 %u32 %f16 +%functy2 = OpTypeFunction %f16 %u32 %u32 %f16 %u32 + +%elemfunc = OpFunction %f16 None %functy +%row = OpFunctionParameter %u32 +%col = OpFunctionParameter %u32 +%el = OpFunctionParameter %f16 +%entry2 = OpLabel +OpReturnValue %el +OpFunctionEnd + +%elemfunc2 = OpFunction %f16 None %functy2 +%row2 = OpFunctionParameter %u32 +%col2 = OpFunctionParameter %u32 +%el2 = OpFunctionParameter %f16 +%x = OpFunctionParameter %u32 +%entry3 = OpLabel +OpReturnValue %el2 +OpFunctionEnd + + )"; + const std::string body = R"( +%val1 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc +%val2 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc2 %f16_1 +)"; + + CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, CoopMat2PerElementOpElemTyFail) { + const std::string extra_types = R"( + +%functy = OpTypeFunction %f32 %u32 %u32 %f32 + +%elemfunc = OpFunction %f32 None %functy +%row = OpFunctionParameter %u32 +%col = OpFunctionParameter %u32 +%el = OpFunctionParameter %f32 +%entry2 = OpLabel +OpReturnValue %el +OpFunctionEnd + + )"; + const std::string body = R"( +%val1 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc +)"; + + CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must match matrix component type")); +} + +TEST_F(ValidateArithmetics, CoopMat2PerElementOpRowTyFail) { + const std::string extra_types = R"( + +%functy = OpTypeFunction %f16 %f16 %u32 %f16 + +%elemfunc = OpFunction %f16 None %functy +%row = OpFunctionParameter %f16 +%col = OpFunctionParameter %u32 +%el = OpFunctionParameter %f16 +%entry2 = OpLabel +OpReturnValue %el +OpFunctionEnd + + )"; + const std::string body = R"( +%val1 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc +)"; + + CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("must be a 32-bit integer")); +} + } // namespace } // namespace val } // namespace spvtools diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp index 748ad64f8..3869626ef 100644 --- a/test/val/val_conversion_test.cpp +++ b/test/val/val_conversion_test.cpp @@ -1338,11 +1338,11 @@ OpEntryPoint GLCompute %main "main" %u32_8 = OpConstant %u32 8 %u32_4 = OpConstant %u32 4 %subgroup = OpConstant %u32 3 -%workgroup = OpConstant %u32 2 +%device = OpConstant %u32 1 %use_A = OpConstant %u32 0 %f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A -%f32mat = OpTypeCooperativeMatrixKHR %f32 %workgroup %u32_8 %u32_8 %use_A +%f32mat = OpTypeCooperativeMatrixKHR %f32 %device %u32_8 %u32_8 %use_A %f16_1 = OpConstant %f16 1 @@ -2140,6 +2140,192 @@ INSTANTIATE_TEST_SUITE_P(SmallConversionInstructions, ValidateSmallConversions, "%inst = OpBitcast %short %ld_half", "%inst = OpBitcast %short2 %ld_half2")); +TEST_F(ValidateConversion, CoopMat2ConversionSuccess) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int16 +OpCapability CooperativeMatrixConversionsNV +OpCapability CooperativeMatrixKHR +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_NV_cooperative_matrix2" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u16 = OpTypeInt 16 0 +%u32 = OpTypeInt 32 0 +%s16 = OpTypeInt 16 1 +%s32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%u32_16 = OpConstant %u32 16 +%use_A = OpConstant %u32 0 +%use_B = OpConstant %u32 1 +%use_Acc = OpConstant %u32 2 +%subgroup = OpConstant %u32 3 + +%f16matA = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A +%f32matA = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_A +%u16matA = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_A +%u32matA = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_A +%s16matA = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_A +%s32matA = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_A + +%f16matB = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_B +%f32matB = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_B +%u16matB = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_B +%u32matB = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_B +%s16matB = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_B +%s32matB = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_B + +%f16matAcc = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_Acc +%f32matAcc = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_Acc +%u16matAcc = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_Acc +%u32matAcc = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_Acc +%s16matAcc = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_Acc +%s32matAcc = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_Acc + +%f16matAcc16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %use_Acc +%f16matB8x16 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_16 %use_B + +%f16_1 = OpConstant %f16 1 +%f32_1 = OpConstant %f32 1 +%u16_1 = OpConstant %u16 1 +%u32_1 = OpConstant %u32 1 +%s16_1 = OpConstant %s16 1 +%s32_1 = OpConstant %s32 1 + +%f16matAcc_1 = OpConstantComposite %f16matAcc %f16_1 +%f32matAcc_1 = OpConstantComposite %f32matAcc %f32_1 +%u16matAcc_1 = OpConstantComposite %u16matAcc %u16_1 +%u32matAcc_1 = OpConstantComposite %u32matAcc %u32_1 +%s16matAcc_1 = OpConstantComposite %s16matAcc %s16_1 +%s32matAcc_1 = OpConstantComposite %s32matAcc %s32_1 + +%f16matAcc16x8_1 = OpConstantComposite %f16matAcc16x8 %f16_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%val11A = OpConvertFToU %u16matA %f16matAcc_1 +%val12A = OpConvertFToU %u32matA %f16matAcc_1 +%val13A = OpConvertFToS %s16matA %f16matAcc_1 +%val14A = OpConvertFToS %s32matA %f16matAcc_1 +%val15A = OpFConvert %f32matA %f16matAcc_1 + +%val11B = OpConvertFToU %u16matB %f16matAcc_1 +%val12B = OpConvertFToU %u32matB %f16matAcc_1 +%val13B = OpConvertFToS %s16matB %f16matAcc_1 +%val14B = OpConvertFToS %s32matB %f16matAcc_1 +%val15B = OpFConvert %f32matB %f16matAcc_1 + +%val21A = OpConvertFToU %u16matA %f32matAcc_1 +%val22A = OpConvertFToU %u32matA %f32matAcc_1 +%val23A = OpConvertFToS %s16matA %f32matAcc_1 +%val24A = OpConvertFToS %s32matA %f32matAcc_1 +%val25A = OpFConvert %f16matA %f32matAcc_1 + +%val21B = OpConvertFToU %u16matB %f32matAcc_1 +%val22B = OpConvertFToU %u32matB %f32matAcc_1 +%val23B = OpConvertFToS %s16matB %f32matAcc_1 +%val24B = OpConvertFToS %s32matB %f32matAcc_1 +%val25B = OpFConvert %f16matB %f32matAcc_1 + +%val31A = OpConvertUToF %f16matA %u16matAcc_1 +%val32A = OpConvertUToF %f32matA %u16matAcc_1 +%val33A = OpUConvert %u32matA %u16matAcc_1 +%val34A = OpSConvert %s32matA %u16matAcc_1 + +%val31B = OpConvertUToF %f16matB %u16matAcc_1 +%val32B = OpConvertUToF %f32matB %u16matAcc_1 +%val33B = OpUConvert %u32matB %u16matAcc_1 +%val34B = OpSConvert %s32matB %u16matAcc_1 + +%val41A = OpConvertSToF %f16matA %s16matAcc_1 +%val42A = OpConvertSToF %f32matA %s16matAcc_1 +%val43A = OpUConvert %u32matA %s16matAcc_1 +%val44A = OpSConvert %s32matA %s16matAcc_1 + +%val41B = OpConvertSToF %f16matB %s16matAcc_1 +%val42B = OpConvertSToF %f32matB %s16matAcc_1 +%val43B = OpUConvert %u32matB %s16matAcc_1 +%val44B = OpSConvert %s32matB %s16matAcc_1 + +%val51A = OpCooperativeMatrixConvertNV %f16matA %f16matAcc_1 +%val52A = OpCooperativeMatrixConvertNV %f32matA %f32matAcc_1 +%val53A = OpCooperativeMatrixConvertNV %u16matA %u16matAcc_1 +%val54A = OpCooperativeMatrixConvertNV %s16matA %s16matAcc_1 + +%val51B = OpCooperativeMatrixConvertNV %f16matB %f16matAcc_1 +%val52B = OpCooperativeMatrixConvertNV %f32matB %f32matAcc_1 +%val53B = OpCooperativeMatrixConvertNV %u16matB %u16matAcc_1 +%val54B = OpCooperativeMatrixConvertNV %s16matB %s16matAcc_1 + +%val61B = OpCooperativeMatrixTransposeNV %f16matB %f16matAcc_1 +%val62B = OpCooperativeMatrixTransposeNV %f32matB %f32matAcc_1 +%val63B = OpCooperativeMatrixTransposeNV %u16matB %u16matAcc_1 +%val64B = OpCooperativeMatrixTransposeNV %s16matB %s16matAcc_1 + +%val71B = OpCooperativeMatrixTransposeNV %f16matB8x16 %f16matAcc16x8_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, CoopMat2TransposeShapeFail) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int16 +OpCapability CooperativeMatrixConversionsNV +OpCapability CooperativeMatrixKHR +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_NV_cooperative_matrix2" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%u32 = OpTypeInt 32 0 + +%u32_8 = OpConstant %u32 8 +%u32_16 = OpConstant %u32 16 +%use_B = OpConstant %u32 1 +%use_Acc = OpConstant %u32 2 +%subgroup = OpConstant %u32 3 + +%f16matAcc16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %use_Acc +%f16matB16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %use_B + +%f16_1 = OpConstant %f16 1 + +%f16matAcc16x8_1 = OpConstantComposite %f16matAcc16x8 %f16_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%val71B = OpCooperativeMatrixTransposeNV %f16matB16x8 %f16matAcc16x8_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected rows of Matrix type and Result Type to be " + "swapped with columns")); +} + } // namespace } // namespace val } // namespace spvtools diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp index ae78b1450..0a918c989 100644 --- a/test/val/val_memory_test.cpp +++ b/test/val/val_memory_test.cpp @@ -7171,6 +7171,639 @@ OpFunctionEnd EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_3)); } +std::string GenCoopMat2Shader(const std::string& extra_types, + const std::string& main_body, + const std::string& after_main = "", + const std::string& extra_decorations = "") { + const std::string prefix = R"( +OpCapability Shader +OpCapability Float16 +OpCapability PhysicalStorageBufferAddresses +OpCapability VulkanMemoryModel +OpCapability CooperativeMatrixKHR +OpCapability TensorAddressingNV +OpCapability CooperativeMatrixTensorAddressingNV +OpCapability CooperativeMatrixBlockLoadsNV +OpExtension "SPV_KHR_physical_storage_buffer" +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_NV_tensor_addressing" +OpExtension "SPV_NV_cooperative_matrix2" +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 + +OpDecorate %f16_arr ArrayStride 2 +OpDecorate %46 Block +OpMemberDecorate %46 0 Offset 0 +OpDecorate %48 Binding 0 +OpDecorate %48 DescriptorSet 0 +OpDecorate %psb Restrict +)" + extra_decorations + R"( +%void = OpTypeVoid +%bool = OpTypeBool +%func = OpTypeFunction %void +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 + +%s32_0 = OpConstant %s32 0 +%f16_0 = OpConstant %f16 0 +%u32_2 = OpConstant %u32 2 +%u32_8 = OpConstant %u32 8 +%use_A = OpConstant %u32 0 +%workgroup = OpConstant %u32 2 +%subgroup = OpConstant %u32 3 + +%f16_arr = OpTypeRuntimeArray %f16 +%46 = OpTypeStruct %f16_arr +%47 = OpTypePointer StorageBuffer %46 +%48 = OpVariable %47 StorageBuffer +%51 = OpTypePointer StorageBuffer %f16_arr +%psbptr = OpTypePointer PhysicalStorageBuffer %f16_arr + +%f16mat = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_8 %u32_8 %use_A +%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_A + +%arr2 = OpTypeArray %u32 %u32_2 +%functy = OpTypeFunction %f16 %psbptr %arr2 %arr2 +)"; + + const std::string decode_func = + R"( +%decodefunc = OpFunction %f16 None %functy +%psb = OpFunctionParameter %psbptr +%c0 = OpFunctionParameter %arr2 +%c1 = OpFunctionParameter %arr2 +%entry2 = OpLabel +OpReturnValue %f16_0 +OpFunctionEnd +)"; + + const std::string func_begin = + R"( +%main = OpFunction %void None %func +%main_entry = OpLabel + +%array_ptr = OpAccessChain %51 %48 %s32_0 +)"; + + const std::string suffix = + R"( +OpReturn +OpFunctionEnd)"; + + return prefix + extra_types + func_begin + main_body + suffix + decode_func + + after_main; +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutAndViewSuccess) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 2 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %layout = OpTypeTensorLayoutNV %dim %clamp + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 + )", + R"( + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutInvalidDimFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 6 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("must be between 1 and 5")); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutInvalidClampFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 6 + %dim = OpConstant %u32 2 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be a valid TensorClampMode")); +} + +TEST_F(ValidateMemory, CoopMat2TensorViewInvalidDimFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %dim = OpConstant %u32 6 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 + )", + R"( + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("must be between 1 and 5")); +} + +TEST_F(ValidateMemory, CoopMat2TensorViewInvalidPermutationFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %dim = OpConstant %u32 3 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p1 + )", + R"( + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Permutation values don't form a valid permutation")); +} + +TEST_F(ValidateMemory, CoopMat2TensorViewInvalidPermutation2Fail) { + std::string spirv = GenCoopMat2Shader( + R"( + %dim = OpConstant %u32 3 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 + )", + R"( + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Incorrect number of permutation values.")); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutBlockSizePass) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 3 + %b = OpConstant %u32 1 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + %tl = OpCreateTensorLayoutNV %layout + %tl2 = OpTensorLayoutSetBlockSizeNV %layout %tl %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutBlockSizeFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 3 + %b = OpConstant %u32 1 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + %tl = OpCreateTensorLayoutNV %layout + %tl2 = OpTensorLayoutSetBlockSizeNV %layout %tl %b %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("unexpected number of operands")); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutDimensionPass) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 3 + %b = OpConstant %u32 1 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + %tl = OpCreateTensorLayoutNV %layout + %tl2 = OpTensorLayoutSetDimensionNV %layout %tl %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutDimensionFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 3 + %b = OpConstant %u32 1 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + %tl = OpCreateTensorLayoutNV %layout + %tl2 = OpTensorLayoutSetDimensionNV %layout %tl %b %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("unexpected number of operands")); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutStridePass) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 3 + %b = OpConstant %u32 1 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + %tl = OpCreateTensorLayoutNV %layout + %tl2 = OpTensorLayoutSetStrideNV %layout %tl %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutStrideFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 3 + %b = OpConstant %u32 1 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + %tl = OpCreateTensorLayoutNV %layout + %tl2 = OpTensorLayoutSetStrideNV %layout %tl %b %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("unexpected number of operands")); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutSlicePass) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 3 + %b = OpConstant %u32 1 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + %tl = OpCreateTensorLayoutNV %layout + %tl2 = OpTensorLayoutSliceNV %layout %tl %b %b %b %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutSliceFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 3 + %b = OpConstant %u32 1 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + %tl = OpCreateTensorLayoutNV %layout + %tl2 = OpTensorLayoutSliceNV %layout %tl %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("unexpected number of operands")); +} + +TEST_F(ValidateMemory, CoopMat2TensorLayoutSetClampValuePass) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 3 + %b = OpConstant %u32 1 + %layout = OpTypeTensorLayoutNV %dim %clamp + )", + R"( + %tl = OpCreateTensorLayoutNV %layout + %tl2 = OpTensorLayoutSetClampValueNV %layout %tl %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2TensorViewDimensionPass) { + std::string spirv = GenCoopMat2Shader( + R"( + %dim = OpConstant %u32 3 + %hasdim = OpConstantFalse %bool + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %p2 = OpConstant %u32 2 + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2 + %b = OpConstant %u32 1 + )", + R"( + %tv = OpCreateTensorViewNV %view + %tv2 = OpTensorViewSetDimensionNV %view %tv %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2TensorViewDimensionFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %dim = OpConstant %u32 3 + %hasdim = OpConstantFalse %bool + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %p2 = OpConstant %u32 2 + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2 + %b = OpConstant %u32 1 + )", + R"( + %tv = OpCreateTensorViewNV %view + %tv2 = OpTensorViewSetDimensionNV %view %tv %b %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("unexpected number of operands")); +} + +TEST_F(ValidateMemory, CoopMat2TensorViewStridePass) { + std::string spirv = GenCoopMat2Shader( + R"( + %dim = OpConstant %u32 3 + %hasdim = OpConstantFalse %bool + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %p2 = OpConstant %u32 2 + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2 + %b = OpConstant %u32 1 + )", + R"( + %tv = OpCreateTensorViewNV %view + %tv2 = OpTensorViewSetStrideNV %view %tv %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2TensorViewStrideFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %dim = OpConstant %u32 3 + %hasdim = OpConstantFalse %bool + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %p2 = OpConstant %u32 2 + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2 + %b = OpConstant %u32 1 + )", + R"( + %tv = OpCreateTensorViewNV %view + %tv2 = OpTensorViewSetStrideNV %view %tv %b %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("unexpected number of operands")); +} + +TEST_F(ValidateMemory, CoopMat2TensorViewClipPass) { + std::string spirv = GenCoopMat2Shader( + R"( + %dim = OpConstant %u32 3 + %hasdim = OpConstantFalse %bool + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %p2 = OpConstant %u32 2 + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 %p2 + %b = OpConstant %u32 1 + )", + R"( + %tv = OpCreateTensorViewNV %view + %tv2 = OpTensorViewSetClipNV %view %tv %b %b %b %b + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2LoadStoreTensorPass) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 2 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %layout = OpTypeTensorLayoutNV %dim %clamp + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 + )", + R"( + %mat = OpUndef %f16mat + %tl = OpCreateTensorLayoutNV %layout + %tv = OpCreateTensorViewNV %view + %mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None None + %mat3 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl Aligned 4 None + %mat4 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None TensorView %tv + %mat5 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None DecodeFunc %decodefunc + %mat6 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None TensorView|DecodeFunc %tv %decodefunc + %mat7 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl Aligned 4 TensorView|DecodeFunc %tv %decodefunc + OpCooperativeMatrixStoreTensorNV %array_ptr %mat %tl None None + OpCooperativeMatrixStoreTensorNV %array_ptr %mat %tl Aligned 4 None + OpCooperativeMatrixStoreTensorNV %array_ptr %mat %tl None TensorView %tv + OpCooperativeMatrixStoreTensorNV %array_ptr %mat %tl Aligned 4 TensorView %tv + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, CoopMat2LoadTensorWrongLayoutTypeFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 2 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %layout = OpTypeTensorLayoutNV %dim %clamp + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 + )", + R"( + %mat = OpUndef %f16mat + %tl = OpCreateTensorLayoutNV %layout + %tv = OpCreateTensorViewNV %view + %mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tv None None + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("does not have a tensor layout type")); +} + +TEST_F(ValidateMemory, CoopMat2LoadTensorWrongObjectTypeFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 2 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %layout = OpTypeTensorLayoutNV %dim %clamp + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 + )", + R"( + %mat = OpUndef %f32mat + %tl = OpCreateTensorLayoutNV %layout + %tv = OpCreateTensorViewNV %view + %mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None None + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("type does not match Result Type")); +} + +TEST_F(ValidateMemory, CoopMat2LoadTensorDecodeFuncTypeFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 2 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %layout = OpTypeTensorLayoutNV %dim %clamp + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 + )", + R"( + %mat = OpUndef %f32mat + %tl = OpCreateTensorLayoutNV %layout + %tv = OpCreateTensorViewNV %view + %mat2 = OpCooperativeMatrixLoadTensorNV %f32mat %array_ptr %mat %tl None DecodeFunc %decodefunc + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("return type must match matrix component type")); +} + +TEST_F(ValidateMemory, CoopMat2LoadTensorDecodeFuncArrayTypeFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 2 + %u32_3 = OpConstant %u32 3 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %layout = OpTypeTensorLayoutNV %dim %clamp + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 + %arr3 = OpTypeArray %u32 %u32_3 + %functy2 = OpTypeFunction %f16 %psbptr %arr3 %arr3 + )", + R"( + %mat = OpUndef %f16mat + %tl = OpCreateTensorLayoutNV %layout + %tv = OpCreateTensorViewNV %view + %mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None DecodeFunc %decodefunc2 + )", + R"( + %decodefunc2 = OpFunction %f16 None %functy2 + %psb2 = OpFunctionParameter %psbptr + %c02 = OpFunctionParameter %arr3 + %c12 = OpFunctionParameter %arr3 + %entry3 = OpLabel + OpReturnValue %f16_0 + OpFunctionEnd + )", + R"( + OpDecorate %psb2 Restrict + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("dimension equal to the tensor dimension")); +} + +TEST_F(ValidateMemory, CoopMat2LoadTensorDecodeFuncPointerTypeFail) { + std::string spirv = GenCoopMat2Shader( + R"( + %clamp = OpConstant %u32 0 + %dim = OpConstant %u32 2 + %p0 = OpConstant %u32 0 + %p1 = OpConstant %u32 1 + %hasdim = OpConstantFalse %bool + %layout = OpTypeTensorLayoutNV %dim %clamp + %view = OpTypeTensorViewNV %dim %hasdim %p0 %p1 + %sbptr = OpTypePointer StorageBuffer %f16_arr + %functy2 = OpTypeFunction %f16 %sbptr %arr2 %arr2 + )", + R"( + %mat = OpUndef %f16mat + %tl = OpCreateTensorLayoutNV %layout + %tv = OpCreateTensorViewNV %view + %mat2 = OpCooperativeMatrixLoadTensorNV %f16mat %array_ptr %mat %tl None DecodeFunc %decodefunc2 + )", + R"( + %decodefunc2 = OpFunction %f16 None %functy2 + %sb = OpFunctionParameter %sbptr + %c02 = OpFunctionParameter %arr2 + %c12 = OpFunctionParameter %arr2 + %entry3 = OpLabel + OpReturnValue %f16_0 + OpFunctionEnd + )"); + + CompileSuccessfully(spirv.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("first parameter must be pointer to PhysicalStorageBuffer")); +} + } // namespace } // namespace val } // namespace spvtools diff --git a/test/val/val_misc_test.cpp b/test/val/val_misc_test.cpp index 6d5897c12..5304c59e0 100644 --- a/test/val/val_misc_test.cpp +++ b/test/val/val_misc_test.cpp @@ -350,6 +350,77 @@ OpEntryPoint Vertex %func "shader" EXPECT_THAT(getDiagnosticString(), HasSubstr("Invalid storage class for target environment")); } + +TEST_F(ValidateMisc, CoopMat2WorkgroupLocalSizeIdPass) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int16 +OpCapability CooperativeMatrixKHR +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionModeId %main LocalSizeId %u32_16 %u32_16 %u32_16 +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%u32 = OpTypeInt 32 0 + +%u32_16 = OpConstant %u32 16 +%use_Acc = OpConstant %u32 2 +%workgroup = OpConstant %u32 2 + +%f16mat = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %use_Acc + +%main = OpFunction %void None %func +%main_entry = OpLabel + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str(), SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMisc, CoopMat2WorkgroupLocalSizeIdConstantNotDeclaredYetFail) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int16 +OpCapability CooperativeMatrixKHR +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionModeId %main LocalSizeId %u32_16 %u32_8 %u32_16 +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%u32 = OpTypeInt 32 0 + +%u32_16 = OpConstant %u32 16 +%use_Acc = OpConstant %u32 2 +%workgroup = OpConstant %u32 2 + +%f16mat = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %use_Acc +%u32_8 = OpConstant %u32 8 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str(), SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeCooperativeMatrixKHR with ScopeWorkgroup used " + "before LocalSizeId constant value")); +} + } // namespace } // namespace val } // namespace spvtools diff --git a/utils/check_copyright.py b/utils/check_copyright.py index a6459233a..c477ece9a 100755 --- a/utils/check_copyright.py +++ b/utils/check_copyright.py @@ -43,7 +43,8 @@ AUTHORS = ['The Khronos Group Inc.', 'Shiyu Liu', 'ZHOU He', 'Nintendo', - 'Epic Games, Inc.'] + 'Epic Games, Inc.', + 'NVIDIA Corporation'] CURRENT_YEAR = 2023 FIRST_YEAR = 2014