From 002ef361cabc486a2f3567d646363334d50cc462 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 25 Feb 2019 16:43:11 -0600 Subject: [PATCH] Add validation for SPV_NV_cooperative_matrix (#2404) --- DEPS | 2 +- source/assembly_grammar.cpp | 5 +- source/opcode.cpp | 2 + source/opt/reflect.h | 3 +- source/val/validate_arithmetics.cpp | 107 +++++++- source/val/validate_composites.cpp | 24 ++ source/val/validate_constants.cpp | 35 ++- source/val/validate_conversion.cpp | 126 +++++++--- source/val/validate_id.cpp | 10 +- source/val/validate_memory.cpp | 212 +++++++++++++++- source/val/validate_scopes.cpp | 20 +- source/val/validate_type.cpp | 50 ++++ source/val/validation_state.cpp | 89 ++++++- source/val/validation_state.h | 12 +- test/val/val_arithmetics_test.cpp | 144 +++++++++++ test/val/val_composites_test.cpp | 77 ++++++ test/val/val_conversion_test.cpp | 166 +++++++++++++ test/val/val_memory_test.cpp | 366 ++++++++++++++++++++++++++++ 18 files changed, 1390 insertions(+), 60 deletions(-) diff --git a/DEPS b/DEPS index 5668c6602..3c5361fc2 100644 --- a/DEPS +++ b/DEPS @@ -11,7 +11,7 @@ vars = { 'googletest_revision': '98a0d007d7092b72eea0e501bb9ad17908a1a036', 'testing_revision': '340252637e2e7c72c0901dcbeeacfff419e19b59', 're2_revision': '6cf8ccd82dbaab2668e9b13596c68183c9ecd13f', - 'spirv_headers_revision': '79b6681aadcb53c27d1052e5f8a0e82a981dbf2f', + 'spirv_headers_revision': 'e74c389f81915d0a48d6df1af83c3862c5ad85ab', } deps = { diff --git a/source/assembly_grammar.cpp b/source/assembly_grammar.cpp index 4d98e3dab..79f18eee3 100644 --- a/source/assembly_grammar.cpp +++ b/source/assembly_grammar.cpp @@ -154,10 +154,11 @@ const SpecConstantOpcodeEntry kOpSpecConstantOpcodes[] = { CASE(InBoundsAccessChain), CASE(PtrAccessChain), CASE(InBoundsPtrAccessChain), + CASE(CooperativeMatrixLengthNV) }; -// The 59 is determined by counting the opcodes listed in the spec. 
-static_assert(59 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]), +// The 60 is determined by counting the opcodes listed in the spec. +static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]), "OpSpecConstantOp opcode table is incomplete"); #undef CASE // clang-format on diff --git a/source/opcode.cpp b/source/opcode.cpp index 78c238686..da096a404 100644 --- a/source/opcode.cpp +++ b/source/opcode.cpp @@ -260,6 +260,7 @@ int32_t spvOpcodeIsComposite(const SpvOp opcode) { case SpvOpTypeMatrix: case SpvOpTypeArray: case SpvOpTypeStruct: + case SpvOpTypeCooperativeMatrixNV: return true; default: return false; @@ -325,6 +326,7 @@ int32_t spvOpcodeGeneratesType(SpvOp op) { case SpvOpTypePipeStorage: case SpvOpTypeNamedBarrier: case SpvOpTypeAccelerationStructureNV: + case SpvOpTypeCooperativeMatrixNV: return true; default: // In particular, OpTypeForwardPointer does not generate a type, diff --git a/source/opt/reflect.h b/source/opt/reflect.h index 79d90bda4..810644288 100644 --- a/source/opt/reflect.h +++ b/source/opt/reflect.h @@ -45,7 +45,8 @@ inline bool IsAnnotationInst(SpvOp opcode) { inline bool IsTypeInst(SpvOp opcode) { return (opcode >= SpvOpTypeVoid && opcode <= SpvOpTypeForwardPointer) || opcode == SpvOpTypePipeStorage || opcode == SpvOpTypeNamedBarrier || - opcode == SpvOpTypeAccelerationStructureNV; + opcode == SpvOpTypeAccelerationStructureNV || + opcode == SpvOpTypeCooperativeMatrixNV; } inline bool IsConstantInst(SpvOp opcode) { return opcode >= SpvOpConstantTrue && opcode <= SpvOpSpecConstantOp; diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp index 2314e7dfc..433330d74 100644 --- a/source/val/validate_arithmetics.cpp +++ b/source/val/validate_arithmetics.cpp @@ -39,8 +39,11 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { case SpvOpFRem: case SpvOpFMod: case SpvOpFNegate: { + bool supportsCoopMat = + (opcode != SpvOpFMul && opcode != 
SpvOpFRem && opcode != SpvOpFMod); if (!_.IsFloatScalarType(result_type) && - !_.IsFloatVectorType(result_type)) + !_.IsFloatVectorType(result_type) && + !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected floating scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -58,8 +61,11 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { case SpvOpUDiv: case SpvOpUMod: { + bool supportsCoopMat = (opcode == SpvOpUDiv); if (!_.IsUnsignedIntScalarType(result_type) && - !_.IsUnsignedIntVectorType(result_type)) + !_.IsUnsignedIntVectorType(result_type) && + !(supportsCoopMat && + _.IsUnsignedIntCooperativeMatrixType(result_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -82,7 +88,10 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { case SpvOpSMod: case SpvOpSRem: case SpvOpSNegate: { - if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + bool supportsCoopMat = + (opcode != SpvOpIMul && opcode != SpvOpSRem && opcode != SpvOpSMod); + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && + !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -94,7 +103,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { ++operand_index) { const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); if (!type_id || - (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id))) + (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) && + !(supportsCoopMat && _.IsIntCooperativeMatrixType(type_id)))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as operand: " << spvOpcodeString(opcode) << " 
operand index " @@ -176,7 +186,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { } case SpvOpMatrixTimesScalar: { - if (!_.IsFloatMatrixType(result_type)) + if (!_.IsFloatMatrixType(result_type) && + !_.IsCooperativeMatrixType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float matrix type as Result Type: " << spvOpcodeString(opcode); @@ -442,6 +453,92 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { break; } + case SpvOpCooperativeMatrixMulAddNV: { + const uint32_t D_type_id = _.GetOperandTypeId(inst, 1); + const uint32_t A_type_id = _.GetOperandTypeId(inst, 2); + const uint32_t B_type_id = _.GetOperandTypeId(inst, 3); + const uint32_t C_type_id = _.GetOperandTypeId(inst, 4); + + if (!_.IsCooperativeMatrixType(A_type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected cooperative matrix type as A Type: " + << spvOpcodeString(opcode); + } + if (!_.IsCooperativeMatrixType(B_type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected cooperative matrix type as B Type: " + << spvOpcodeString(opcode); + } + if (!_.IsCooperativeMatrixType(C_type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected cooperative matrix type as C Type: " + << spvOpcodeString(opcode); + } + if (!_.IsCooperativeMatrixType(D_type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected cooperative matrix type as Result Type: " + << spvOpcodeString(opcode); + } + + const auto A = _.FindDef(A_type_id); + const auto B = _.FindDef(B_type_id); + const auto C = _.FindDef(C_type_id); + const auto D = _.FindDef(D_type_id); + + std::tuple A_scope, B_scope, C_scope, D_scope, + A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols; + + A_scope = _.EvalInt32IfConst(A->GetOperandAs(2)); + B_scope = _.EvalInt32IfConst(B->GetOperandAs(2)); + C_scope = _.EvalInt32IfConst(C->GetOperandAs(2)); + D_scope = _.EvalInt32IfConst(D->GetOperandAs(2)); + + A_rows = 
_.EvalInt32IfConst(A->GetOperandAs(3)); + B_rows = _.EvalInt32IfConst(B->GetOperandAs(3)); + C_rows = _.EvalInt32IfConst(C->GetOperandAs(3)); + D_rows = _.EvalInt32IfConst(D->GetOperandAs(3)); + + A_cols = _.EvalInt32IfConst(A->GetOperandAs(4)); + B_cols = _.EvalInt32IfConst(B->GetOperandAs(4)); + C_cols = _.EvalInt32IfConst(C->GetOperandAs(4)); + D_cols = _.EvalInt32IfConst(D->GetOperandAs(4)); + + const auto notEqual = [](std::tuple X, + std::tuple Y) { + return (std::get<1>(X) && std::get<1>(Y) && + std::get<2>(X) != std::get<2>(Y)); + }; + + if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) || + notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) || + notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Cooperative matrix scopes must match: " + << spvOpcodeString(opcode); + } + + if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) || + notEqual(C_rows, D_rows)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Cooperative matrix 'M' mismatch: " + << spvOpcodeString(opcode); + } + + if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) || + notEqual(C_cols, D_cols)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Cooperative matrix 'N' mismatch: " + << spvOpcodeString(opcode); + } + + if (notEqual(A_cols, B_rows)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Cooperative matrix 'K' mismatch: " + << spvOpcodeString(opcode); + } + break; + } + default: break; } diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp index ccc558773..de3210efb 100644 --- a/source/val/validate_composites.cpp +++ b/source/val/validate_composites.cpp @@ -118,6 +118,10 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _, *member_type = type_inst->word(component_index + 2); break; } + case SpvOpTypeCooperativeMatrixNV: { + *member_type = type_inst->word(2); + break; + } default: return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Reached non-composite 
type while indexes still remain to " @@ -315,6 +319,26 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _, break; } + case SpvOpTypeCooperativeMatrixNV: { + const auto result_type_inst = _.FindDef(result_type); + assert(result_type_inst); + const auto component_type_id = + result_type_inst->GetOperandAs(1); + + if (3 != num_operands) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected single constituent"; + } + + const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2); + + if (operand_type_id != component_type_id) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Constituent type to be equal to the component type"; + } + + break; + } default: { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to be a composite type"; diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp index e2f20f672..c413b4fba 100644 --- a/source/val/validate_constants.cpp +++ b/source/val/validate_constants.cpp @@ -247,6 +247,36 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _, } } } break; + case SpvOpTypeCooperativeMatrixNV: { + if (1 != constituent_count) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(inst->type_id()) << "' count must be one."; + } + const auto constituent_id = inst->GetOperandAs(2); + const auto constituent = _.FindDef(constituent_id); + if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' is not a constant or undef."; + } + const auto constituent_type = _.FindDef(constituent->type_id()); + if (!constituent_type) { + return _.diag(SPV_ERROR_INVALID_ID, constituent) + << "Result type is not defined."; + } + + const auto component_type_id = result_type->GetOperandAs(1); + const auto component_type = _.FindDef(component_type_id); + if (!component_type || component_type->id() != 
constituent_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' type does not match the Result Type '" + << _.getIdName(result_type->id()) << "'s component type."; + } + } break; default: break; } @@ -285,6 +315,7 @@ bool IsTypeNullable(const std::vector& instruction, return true; case SpvOpTypeArray: case SpvOpTypeMatrix: + case SpvOpTypeCooperativeMatrixNV: case SpvOpTypeVector: { auto base_type = _.FindDef(instruction[2]); return base_type && IsTypeNullable(base_type->words(), _); @@ -320,7 +351,7 @@ spv_result_t ValidateSpecConstantOp(ValidationState_t& _, // The binary parser already ensures that the op is valid for *some* // environment. Here we check restrictions. - switch(op) { + switch (op) { case SpvOpQuantizeToF16: if (!_.HasCapability(SpvCapabilityShader)) { return _.diag(SPV_ERROR_INVALID_ID, inst) @@ -365,7 +396,7 @@ spv_result_t ValidateSpecConstantOp(ValidationState_t& _, } break; - default: + default: break; } diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp index 73da58255..17af9f465 100644 --- a/source/val/validate_conversion.cpp +++ b/source/val/validate_conversion.cpp @@ -32,22 +32,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { switch (opcode) { case SpvOpConvertFToU: { if (!_.IsUnsignedIntScalarType(result_type) && - !_.IsUnsignedIntVectorType(result_type)) + !_.IsUnsignedIntVectorType(result_type) && + !_.IsUnsignedIntCooperativeMatrixType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsFloatScalarType(input_type) && - !_.IsFloatVectorType(input_type))) + !_.IsFloatVectorType(input_type) && + !_.IsFloatCooperativeMatrixType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be 
float scalar or vector: " << spvOpcodeString(opcode); - if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected input to have the same dimension as Result Type: " - << spvOpcodeString(opcode); + if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { + spv_result_t ret = + _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else { + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + } if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -58,22 +67,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { } case SpvOpConvertFToS: { - if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && + !_.IsIntCooperativeMatrixType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsFloatScalarType(input_type) && - !_.IsFloatVectorType(input_type))) + !_.IsFloatVectorType(input_type) && + !_.IsFloatCooperativeMatrixType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be float scalar or vector: " << spvOpcodeString(opcode); - if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected input to have the same dimension as Result Type: " - << spvOpcodeString(opcode); + if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { + spv_result_t ret = + _.CooperativeMatrixShapesMatch(inst, result_type, 
input_type); + if (ret != SPV_SUCCESS) return ret; + } else { + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + } if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -86,22 +104,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { case SpvOpConvertSToF: case SpvOpConvertUToF: { if (!_.IsFloatScalarType(result_type) && - !_.IsFloatVectorType(result_type)) + !_.IsFloatVectorType(result_type) && + !_.IsFloatCooperativeMatrixType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || - (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) + (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) && + !_.IsIntCooperativeMatrixType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be int scalar or vector: " << spvOpcodeString(opcode); - if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected input to have the same dimension as Result Type: " - << spvOpcodeString(opcode); + if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { + spv_result_t ret = + _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else { + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + } if (!_.features().use_int8_type && (8 == _.GetBitWidth(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -113,22 +140,31 @@ 
spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { case SpvOpUConvert: { if (!_.IsUnsignedIntScalarType(result_type) && - !_.IsUnsignedIntVectorType(result_type)) + !_.IsUnsignedIntVectorType(result_type) && + !_.IsUnsignedIntCooperativeMatrixType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || - (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) + (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) && + !_.IsIntCooperativeMatrixType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be int scalar or vector: " << spvOpcodeString(opcode); - if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected input to have the same dimension as Result Type: " - << spvOpcodeString(opcode); + if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { + spv_result_t ret = + _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else { + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + } if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -139,22 +175,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { } case SpvOpSConvert: { - if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && + !_.IsIntCooperativeMatrixType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << 
spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || - (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) + (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) && + !_.IsIntCooperativeMatrixType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be int scalar or vector: " << spvOpcodeString(opcode); - if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected input to have the same dimension as Result Type: " - << spvOpcodeString(opcode); + if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { + spv_result_t ret = + _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else { + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + } if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -166,22 +211,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { case SpvOpFConvert: { if (!_.IsFloatScalarType(result_type) && - !_.IsFloatVectorType(result_type)) + !_.IsFloatVectorType(result_type) && + !_.IsFloatCooperativeMatrixType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsFloatScalarType(input_type) && - !_.IsFloatVectorType(input_type))) + !_.IsFloatVectorType(input_type) && + !_.IsFloatCooperativeMatrixType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be float scalar or vector: " << spvOpcodeString(opcode); - if (_.GetDimension(result_type) != _.GetDimension(input_type)) 
- return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected input to have the same dimension as Result Type: " - << spvOpcodeString(opcode); + if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { + spv_result_t ret = + _.CooperativeMatrixShapesMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else { + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + } if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) diff --git a/source/val/validate_id.cpp b/source/val/validate_id.cpp index 21a04113d..cb18e131b 100644 --- a/source/val/validate_id.cpp +++ b/source/val/validate_id.cpp @@ -167,7 +167,10 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) { const auto opcode = inst->opcode(); if (spvOpcodeGeneratesType(def->opcode()) && !spvOpcodeGeneratesType(opcode) && !spvOpcodeIsDebug(opcode) && - !spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction) { + !spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction && + opcode != SpvOpCooperativeMatrixLengthNV && + !(opcode == SpvOpSpecConstantOp && + inst->word(3) == SpvOpCooperativeMatrixLengthNV)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Operand " << _.getIdName(operand_word) << " cannot be a type"; @@ -177,7 +180,10 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) { !spvOpcodeIsBranch(opcode) && opcode != SpvOpPhi && opcode != SpvOpExtInst && opcode != SpvOpExtInstImport && opcode != SpvOpSelectionMerge && - opcode != SpvOpLoopMerge && opcode != SpvOpFunction) { + opcode != SpvOpLoopMerge && opcode != SpvOpFunction && + opcode != SpvOpCooperativeMatrixLengthNV && + !(opcode == SpvOpSpecConstantOp && + inst->word(3) == SpvOpCooperativeMatrixLengthNV)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Operand " << 
_.getIdName(operand_word) << " requires a type"; diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index 9e93cf134..f6127a124 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp @@ -197,17 +197,49 @@ bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage, return false; } +bool ContainsCooperativeMatrix(ValidationState_t& _, + const Instruction* storage) { + const size_t elem_type_index = 1; + uint32_t elem_type_id; + Instruction* elem_type; + + switch (storage->opcode()) { + case SpvOpTypeCooperativeMatrixNV: + return true; + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + elem_type_id = storage->GetOperandAs(elem_type_index); + elem_type = _.FindDef(elem_type_id); + return ContainsCooperativeMatrix(_, elem_type); + case SpvOpTypeStruct: + for (size_t member_type_index = 1; + member_type_index < storage->operands().size(); + ++member_type_index) { + auto member_type_id = + storage->GetOperandAs(member_type_index); + auto member_type = _.FindDef(member_type_id); + if (ContainsCooperativeMatrix(_, member_type)) return true; + } + break; + default: + break; + } + return false; +} + std::pair GetStorageClass( ValidationState_t& _, const Instruction* inst) { SpvStorageClass dst_sc = SpvStorageClassMax; SpvStorageClass src_sc = SpvStorageClassMax; switch (inst->opcode()) { + case SpvOpCooperativeMatrixLoadNV: case SpvOpLoad: { auto load_pointer = _.FindDef(inst->GetOperandAs(2)); auto load_pointer_type = _.FindDef(load_pointer->type_id()); dst_sc = load_pointer_type->GetOperandAs(1); break; } + case SpvOpCooperativeMatrixStoreNV: case SpvOpStore: { auto store_pointer = _.FindDef(inst->GetOperandAs(0)); auto store_pointer_type = _.FindDef(store_pointer->type_id()); @@ -232,7 +264,8 @@ std::pair GetStorageClass( } // This function is only called for OpLoad, OpStore, OpCopyMemory and -// OpCopyMemorySized. +// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and +// OpCooperativeMatrixStoreNV. 
uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) { uint32_t offset = 1; if (mask & SpvMemoryAccessAlignedMask) ++offset; @@ -245,6 +278,10 @@ uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) { case SpvOpStore: case SpvOpCopyMemory: return inst->GetOperandAs(2 + offset); + case SpvOpCooperativeMatrixLoadNV: + return inst->GetOperandAs(5 + offset); + case SpvOpCooperativeMatrixStoreNV: + return inst->GetOperandAs(4 + offset); default: assert(false && "unexpected opcode"); break; @@ -253,8 +290,9 @@ uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) { return scope_id; } -// This function is only called for OpLoad, OpStore, OpCopyMemory and -// OpCopyMemorySized. +// This function is only called for OpLoad, OpStore, OpCopyMemory, +// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and +// OpCooperativeMatrixStoreNV. uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) { uint32_t offset = 1; if (mask & SpvMemoryAccessAlignedMask) ++offset; @@ -268,6 +306,10 @@ uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) { case SpvOpStore: case SpvOpCopyMemory: return inst->GetOperandAs(2 + offset); + case SpvOpCooperativeMatrixLoadNV: + return inst->GetOperandAs(5 + offset); + case SpvOpCooperativeMatrixStoreNV: + return inst->GetOperandAs(4 + offset); default: assert(false && "unexpected opcode"); break; @@ -302,7 +344,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst, uint32_t mask = inst->GetOperandAs(index); if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) { - if (inst->opcode() == SpvOpLoad) { + if (inst->opcode() == SpvOpLoad || + inst->opcode() == SpvOpCooperativeMatrixLoadNV) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "MakePointerAvailableKHR cannot be used with OpLoad."; } @@ -320,7 +363,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst, } if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) { - if 
(inst->opcode() == SpvOpStore) { + if (inst->opcode() == SpvOpStore || + inst->opcode() == SpvOpCooperativeMatrixStoreNV) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "MakePointerVisibleKHR cannot be used with OpStore."; } @@ -672,6 +716,17 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) { } } + // Cooperative matrix types can only be allocated in Function or Private + if ((storage_class != SpvStorageClassFunction && + storage_class != SpvStorageClassPrivate) && + ContainsCooperativeMatrix(_, pointee)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Cooperative matrix types (or types containing them) can only be " + "allocated " + << "in Function or Private storage classes or as function " + "parameters"; + } + return SPV_SUCCESS; } @@ -1003,10 +1058,11 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, switch (type_pointee->opcode()) { case SpvOpTypeMatrix: case SpvOpTypeVector: + case SpvOpTypeCooperativeMatrixNV: case SpvOpTypeArray: case SpvOpTypeRuntimeArray: { - // In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray, - // word 2 is the Element Type. + // In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV, + // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type. type_pointee = _.FindDef(type_pointee->word(2)); break; } @@ -1136,6 +1192,140 @@ spv_result_t ValidateArrayLength(ValidationState_t& state, return SPV_SUCCESS; } +spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state, + const Instruction* inst) { + std::string instr_name = + "Op" + std::string(spvOpcodeString(static_cast(inst->opcode()))); + + // Result type must be a 32-bit unsigned int. 
+  auto result_type = state.FindDef(inst->type_id());
+  if (result_type->opcode() != SpvOpTypeInt ||
+      result_type->GetOperandAs<uint32_t>(1) != 32 ||
+      result_type->GetOperandAs<uint32_t>(2) != 0) {
+    return state.diag(SPV_ERROR_INVALID_ID, inst)
+           << "The Result Type of " << instr_name << " '"
+           << state.getIdName(inst->id())
+           << "' must be OpTypeInt with width 32 and signedness 0.";
+  }
+
+  auto type_id = inst->GetOperandAs<uint32_t>(2);
+  auto type = state.FindDef(type_id);
+  if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
+    return state.diag(SPV_ERROR_INVALID_ID, inst)
+           << "The type in " << instr_name << " '"
+           << state.getIdName(type_id)
+           << "' must be OpTypeCooperativeMatrixNV.";
+  }
+  return SPV_SUCCESS;
+}
+
+// Validates the operands of OpCooperativeMatrixLoadNV/StoreNV instructions.
+spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
+                                                  const Instruction* inst) {
+  uint32_t type_id;
+  const char* opname;
+  if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
+    type_id = inst->type_id();
+    opname = "SpvOpCooperativeMatrixLoadNV";
+  } else {
+    // A store has no Result Type; use the type of its Object operand.
+    type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
+    opname = "SpvOpCooperativeMatrixStoreNV";
+  }
+
+  auto matrix_type = _.FindDef(type_id);
+
+  if (!matrix_type || matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
+    if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "SpvOpCooperativeMatrixLoadNV Result Type '"
+             << _.getIdName(type_id) << "' is not a cooperative matrix type.";
+    } else {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "SpvOpCooperativeMatrixStoreNV Object type '"
+             << _.getIdName(type_id) << "' is not a cooperative matrix type.";
+    }
+  }
+
+  const bool uses_variable_pointers =
+      _.features().variable_pointers ||
+      _.features().variable_pointers_storage_buffer;
+  const auto pointer_index =
+      (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
+  const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
+  const auto pointer = _.FindDef(pointer_id);
+  if (!pointer ||
+      ((_.addressing_model() == SpvAddressingModelLogical) &&
+       ((!uses_variable_pointers &&
+         !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
+        (uses_variable_pointers &&
+         !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer '" << _.getIdName(pointer_id)
+           << "' is not a logical pointer.";
+  }
+
+  const auto pointer_type_id = pointer->type_id();
+  const auto pointer_type = _.FindDef(pointer_type_id);
+  if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " type for pointer '" << _.getIdName(pointer_id)
+           << "' is not a pointer type.";
+  }
+
+  const auto storage_class_index = 1u;
+  const auto storage_class =
+      pointer_type->GetOperandAs<uint32_t>(storage_class_index);
+
+  if (storage_class != SpvStorageClassWorkgroup &&
+      storage_class != SpvStorageClassStorageBuffer &&
+      storage_class != SpvStorageClassPhysicalStorageBufferEXT) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " storage class for pointer type '"
+           << _.getIdName(pointer_type_id)
+           << "' is not Workgroup, StorageBuffer, or PhysicalStorageBufferEXT.";
+  }
+
+  const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
+  const auto pointee_type = _.FindDef(pointee_id);
+  if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
+                         _.IsFloatScalarOrVectorType(pointee_id))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer '" << _.getIdName(pointer->id())
+           << "'s Type must be a scalar or vector type.";
+  }
+
+  const auto stride_index =
+      (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
+  const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
+  const auto stride = _.FindDef(stride_id);
+  if (!stride || !_.IsIntScalarType(stride->type_id())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Stride operand '" << _.getIdName(stride_id)
+           << "' must be a scalar integer type.";
+  }
+
+  const auto colmajor_index =
+      (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
+  const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
+  const auto colmajor = _.FindDef(colmajor_id);
+  if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
+      !(spvOpcodeIsConstant(colmajor->opcode()) ||
+        spvOpcodeIsSpecConstant(colmajor->opcode()))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Column Major operand '" << _.getIdName(colmajor_id)
+           << "' must be a boolean constant instruction.";
+  }
+
+  const auto memory_access_index =
+      (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
+  if (inst->operands().size() > memory_access_index) {
+    if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
+      return error;
+  }
+  return SPV_SUCCESS;
+}
+
 }  // namespace
 
 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
@@ -1164,6 +1354,14 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
     case SpvOpArrayLength:
       if (auto error = ValidateArrayLength(_, inst)) return error;
       break;
+    case SpvOpCooperativeMatrixLoadNV:
+    case SpvOpCooperativeMatrixStoreNV:
+      if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
+        return error;
+      break;
+    case SpvOpCooperativeMatrixLengthNV:
+      if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
+      break;
     case SpvOpImageTexelPointer:
     case SpvOpGenericPtrMemSemantics:
     default:
diff --git a/source/val/validate_scopes.cpp b/source/val/validate_scopes.cpp
index b6401310d..2223a7786 100644
--- a/source/val/validate_scopes.cpp
+++ b/source/val/validate_scopes.cpp
@@ -36,11 +36,19 @@ spv_result_t ValidateExecutionScope(ValidationState_t& 
_,
   }
 
   if (!is_const_int32) {
-    if (_.HasCapability(SpvCapabilityShader)) {
+    if (_.HasCapability(SpvCapabilityShader) &&
+        !_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Scope ids must be OpConstant when Shader capability is "
              << "present";
     }
+    if (_.HasCapability(SpvCapabilityShader) &&
+        _.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
+        !spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Scope ids must be constant or specialization constant when "
+             << "CooperativeMatrixNV capability is present";
+    }
     return SPV_SUCCESS;
   }
 
@@ -130,11 +138,19 @@ spv_result_t ValidateMemoryScope(ValidationState_t& _, const Instruction* inst,
   }
 
   if (!is_const_int32) {
-    if (_.HasCapability(SpvCapabilityShader)) {
+    if (_.HasCapability(SpvCapabilityShader) &&
+        !_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Scope ids must be OpConstant when Shader capability is "
              << "present";
     }
+    if (_.HasCapability(SpvCapabilityShader) &&
+        _.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
+        !spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Scope ids must be constant or specialization constant when "
+             << "CooperativeMatrixNV capability is present";
+    }
     return SPV_SUCCESS;
   }
 
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index a5428d74c..ad72a37c3 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -381,6 +381,53 @@ spv_result_t ValidateTypeForwardPointer(ValidationState_t& _,
   return SPV_SUCCESS;
 }
 
+// Validates the operands of an OpTypeCooperativeMatrixNV type declaration.
+spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
+                                             const Instruction* inst) {
+  const auto component_type_index = 1;
+  const auto component_type_id =
+      inst->GetOperandAs<uint32_t>(component_type_index);
+  const auto component_type = _.FindDef(component_type_id);
+  if (!component_type || (SpvOpTypeFloat != component_type->opcode() &&
+                          SpvOpTypeInt != component_type->opcode())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpTypeCooperativeMatrixNV Component Type '"
+           << _.getIdName(component_type_id)
+           << "' is not a scalar numerical type.";
+  }
+
+  const auto scope_index = 2;
+  const auto scope_id = inst->GetOperandAs<uint32_t>(scope_index);
+  const auto scope = _.FindDef(scope_id);
+  if (!scope || !_.IsIntScalarType(scope->type_id()) ||
+      !spvOpcodeIsConstant(scope->opcode())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpTypeCooperativeMatrixNV Scope '" << _.getIdName(scope_id)
+           << "' is not a constant instruction with scalar integer type.";
+  }
+
+  const auto rows_index = 3;
+  const auto rows_id = inst->GetOperandAs<uint32_t>(rows_index);
+  const auto rows = _.FindDef(rows_id);
+  if (!rows || !_.IsIntScalarType(rows->type_id()) ||
+      !spvOpcodeIsConstant(rows->opcode())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpTypeCooperativeMatrixNV Rows '" << _.getIdName(rows_id)
+           << "' is not a constant instruction with scalar integer type.";
+  }
+
+  const auto cols_index = 4;
+  const auto cols_id = inst->GetOperandAs<uint32_t>(cols_index);
+  const auto cols = _.FindDef(cols_id);
+  if (!cols || !_.IsIntScalarType(cols->type_id()) ||
+      !spvOpcodeIsConstant(cols->opcode())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpTypeCooperativeMatrixNV Cols '" << _.getIdName(cols_id)
+           << "' is not a constant instruction with scalar integer type.";
+  }
+  return SPV_SUCCESS;
+}
+
 }  // namespace
 
 spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
@@ -416,6 +463,9 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
     case SpvOpTypeForwardPointer:
       if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
       break;
+    case SpvOpTypeCooperativeMatrixNV:
+      if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
+      break;
     default:
       break;
   }
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 2633963cb..e6e5e2622 100644
--- 
a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -610,6 +610,9 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const { case SpvOpTypeMatrix: return GetComponentType(inst->word(2)); + case SpvOpTypeCooperativeMatrixNV: + return inst->word(2); + default: break; } @@ -634,6 +637,10 @@ uint32_t ValidationState_t::GetDimension(uint32_t id) const { case SpvOpTypeMatrix: return inst->word(3); + case SpvOpTypeCooperativeMatrixNV: + // Actual dimension isn't known, return 0 + return 0; + default: break; } @@ -862,6 +869,86 @@ bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type, return true; } +bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + return inst->opcode() == SpvOpTypeCooperativeMatrixNV; +} + +bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const { + if (!IsCooperativeMatrixType(id)) return false; + return IsFloatScalarType(FindDef(id)->word(2)); +} + +bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const { + if (!IsCooperativeMatrixType(id)) return false; + return IsIntScalarType(FindDef(id)->word(2)); +} + +bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const { + if (!IsCooperativeMatrixType(id)) return false; + return IsUnsignedIntScalarType(FindDef(id)->word(2)); +} + +spv_result_t ValidationState_t::CooperativeMatrixShapesMatch( + const Instruction* inst, uint32_t m1, uint32_t m2) { + const auto m1_type = FindDef(m1); + const auto m2_type = FindDef(m2); + + if (m1_type->opcode() != SpvOpTypeCooperativeMatrixNV || + m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) { + return diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected cooperative matrix types"; + } + + uint32_t m1_scope_id = m1_type->GetOperandAs(2); + uint32_t m1_rows_id = m1_type->GetOperandAs(3); + uint32_t m1_cols_id = m1_type->GetOperandAs(4); + + uint32_t m2_scope_id = m2_type->GetOperandAs(2); + 
uint32_t m2_rows_id = m2_type->GetOperandAs(3); + uint32_t m2_cols_id = m2_type->GetOperandAs(4); + + bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false, + m2_is_const_int32 = false; + uint32_t m1_value = 0, m2_value = 0; + + std::tie(m1_is_int32, m1_is_const_int32, m1_value) = + EvalInt32IfConst(m1_scope_id); + std::tie(m2_is_int32, m2_is_const_int32, m2_value) = + EvalInt32IfConst(m2_scope_id); + + if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) { + return diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected scopes of Matrix and Result Type to be " + << "identical"; + } + + std::tie(m1_is_int32, m1_is_const_int32, m1_value) = + EvalInt32IfConst(m1_rows_id); + std::tie(m2_is_int32, m2_is_const_int32, m2_value) = + EvalInt32IfConst(m2_rows_id); + + if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) { + return diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected rows of Matrix type and Result Type to be " + << "identical"; + } + + std::tie(m1_is_int32, m1_is_const_int32, m1_value) = + EvalInt32IfConst(m1_cols_id); + std::tie(m2_is_int32, m2_is_const_int32, m2_value) = + EvalInt32IfConst(m2_cols_id); + + if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) { + return diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected columns of Matrix type and Result Type to be " + << "identical"; + } + + return SPV_SUCCESS; +} + uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst, size_t operand_index) const { return GetTypeId(inst->GetOperandAs(operand_index)); @@ -890,7 +977,7 @@ bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const { } std::tuple ValidationState_t::EvalInt32IfConst( - uint32_t id) { + uint32_t id) const { const Instruction* const inst = FindDef(id); assert(inst); const uint32_t type = inst->type_id(); diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 55005a605..94fa9456c 100644 --- a/source/val/validation_state.h +++ 
b/source/val/validation_state.h @@ -552,6 +552,10 @@ class ValidationState_t { bool IsBoolVectorType(uint32_t id) const; bool IsBoolScalarOrVectorType(uint32_t id) const; bool IsPointerType(uint32_t id) const; + bool IsCooperativeMatrixType(uint32_t id) const; + bool IsFloatCooperativeMatrixType(uint32_t id) const; + bool IsIntCooperativeMatrixType(uint32_t id) const; + bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const; // Gets value from OpConstant and OpSpecConstant as uint64. // Returns false on failure (no instruction, wrong instruction, not int). @@ -635,7 +639,7 @@ class ValidationState_t { // Returns tuple . // OpSpecConstant* return |is_const_int32| as false since their values cannot // be relied upon during validation. - std::tuple EvalInt32IfConst(uint32_t id); + std::tuple EvalInt32IfConst(uint32_t id) const; // Returns the disassembly string for the given instruction. std::string Disassemble(const Instruction& inst) const; @@ -643,6 +647,12 @@ class ValidationState_t { // Returns the disassembly string for the given instruction. std::string Disassemble(const uint32_t* words, uint16_t num_words) const; + // Returns whether type m1 and type m2 are cooperative matrices with + // the same "shape" (matching scope, rows, cols). If any are specialization + // constants, we assume they can match because we can't prove they don't. 
+ spv_result_t CooperativeMatrixShapesMatch(const Instruction* inst, + uint32_t m1, uint32_t m2); + private: ValidationState_t(const ValidationState_t&); diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp index 87e006c12..b82fc97e1 100644 --- a/test/val/val_arithmetics_test.cpp +++ b/test/val/val_arithmetics_test.cpp @@ -1165,6 +1165,150 @@ TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) { "vector size of the right operand: OuterProduct")); } +std::string GenerateCoopMatCode(const std::string& extra_types, + const std::string& main_body) { + const std::string prefix = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%u32_16 = OpConstant %u32 16 +%u32_4 = OpConstant %u32 4 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 +%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8 +%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8 + +%f16_1 = OpConstant %f16 1 +%f32_1 = OpConstant %f32 1 +%u32_1 = OpConstant %u32 1 +%s32_1 = OpConstant %s32 1 + +%f16mat_1 = OpConstantComposite %f16mat %f16_1 +%u32mat_1 = OpConstantComposite %u32mat %u32_1 +%s32mat_1 = OpConstantComposite %s32mat %s32_1 + +%u32_c1 = OpSpecConstant %u32 1 +%u32_c2 = OpSpecConstant %u32 2 + +%f16matc = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_c1 %u32_c2 +%f16matc_1 = OpConstantComposite %f16matc %f16_1 + +%mat16x4 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_4 +%mat4x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_4 %u32_16 +%mat16x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_16 +%f16mat_16x4_1 = OpConstantComposite 
%mat16x4 %f16_1 +%f16mat_4x16_1 = OpConstantComposite %mat4x16 %f16_1 +%f16mat_16x16_1 = OpConstantComposite %mat16x16 %f16_1)"; + + const std::string func_begin = + R"( +%main = OpFunction %void None %func +%main_entry = OpLabel)"; + + const std::string suffix = + R"( +OpReturn +OpFunctionEnd)"; + + return prefix + extra_types + func_begin + main_body + suffix; +} + +TEST_F(ValidateArithmetics, CoopMatSuccess) { + const std::string body = R"( +%val1 = OpFAdd %f16mat %f16mat_1 %f16mat_1 +%val2 = OpFSub %f16mat %f16mat_1 %f16mat_1 +%val3 = OpFDiv %f16mat %f16mat_1 %f16mat_1 +%val4 = OpFNegate %f16mat %f16mat_1 +%val5 = OpIAdd %u32mat %u32mat_1 %u32mat_1 +%val6 = OpISub %u32mat %u32mat_1 %u32mat_1 +%val7 = OpUDiv %u32mat %u32mat_1 %u32mat_1 +%val8 = OpIAdd %s32mat %s32mat_1 %s32mat_1 +%val9 = OpISub %s32mat %s32mat_1 %s32mat_1 +%val10 = OpSDiv %s32mat %s32mat_1 %s32mat_1 +%val11 = OpSNegate %s32mat %s32mat_1 +%val12 = OpMatrixTimesScalar %f16mat %f16mat_1 %f16_1 +%val13 = OpMatrixTimesScalar %u32mat %u32mat_1 %u32_1 +%val14 = OpMatrixTimesScalar %s32mat %s32mat_1 %s32_1 +%val15 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16mat_16x16_1 +%val16 = OpCooperativeMatrixMulAddNV %f16matc %f16matc_1 %f16matc_1 %f16matc_1 +)"; + + CompileSuccessfully(GenerateCoopMatCode("", body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, CoopMatFMulFail) { + const std::string body = R"( +%val1 = OpFMul %f16mat %f16mat_1 %f16mat_1 +)"; + + CompileSuccessfully(GenerateCoopMatCode("", body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected floating scalar or vector type as Result Type: FMul")); +} + +TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) { + const std::string body = R"( +%val1 = OpMatrixTimesScalar %f16mat %f16mat_1 %f32_1 +)"; + + CompileSuccessfully(GenerateCoopMatCode("", body).c_str()); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected scalar operand type to be equal to the component " + "type of the matrix operand: MatrixTimesScalar")); +} + +TEST_F(ValidateArithmetics, CoopMatScopeFail) { + const std::string types = R"( +%workgroup = OpConstant %u32 2 + +%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16 +%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1 +)"; + + const std::string body = R"( +%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1 +)"; + + CompileSuccessfully(GenerateCoopMatCode(types, body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Cooperative matrix scopes must match: CooperativeMatrixMulAddNV")); +} + +TEST_F(ValidateArithmetics, CoopMatDimFail) { + const std::string body = R"( +%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_4x16_1 %f16mat_16x4_1 %f16mat_16x16_1 +)"; + + CompileSuccessfully(GenerateCoopMatCode("", body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV")); +} + TEST_F(ValidateArithmetics, IAddCarrySuccess) { const std::string body = R"( %val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1 diff --git a/test/val/val_composites_test.cpp b/test/val/val_composites_test.cpp index bf7f15d51..db6ff5b19 100644 --- a/test/val/val_composites_test.cpp +++ b/test/val/val_composites_test.cpp @@ -1467,6 +1467,83 @@ OpFunctionEnd EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) { + const std::string body = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = 
OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 + +%u32_8 = OpConstant %u32 8 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 + +%f32_1 = OpConstant %f32 1 + +%f16mat_1 = OpConstantComposite %f16mat %f32_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpConstantComposite Constituent '11[%float_1]' type does " + "not match the Result Type '10[%10]'s component type.")); +} + +TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) { + const std::string body = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 + +%u32_8 = OpConstant %u32 8 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 + +%f32_1 = OpConstant %f32 1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%f16mat_1 = OpCompositeConstruct %f16mat %f32_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Constituent type to be equal to the component type")); +} + TEST_F(ValidateComposites, ExtractDynamicLabelIndex) { const std::string spirv = R"( OpCapability Shader diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp index 5e4ad4908..f905657bd 100644 --- a/test/val/val_conversion_test.cpp +++ b/test/val/val_conversion_test.cpp @@ -1184,6 +1184,172 @@ 
TEST_F(ValidateConversion, GenericCastToPtrExplicitPointToDifferentType) { "GenericCastToPtrExplicit")); } +TEST_F(ValidateConversion, CoopMatConversionSuccess) { + const std::string body = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u16 = OpTypeInt 16 0 +%u32 = OpTypeInt 32 0 +%s16 = OpTypeInt 16 1 +%s32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 +%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_8 %u32_8 +%u16mat = OpTypeCooperativeMatrixNV %u16 %subgroup %u32_8 %u32_8 +%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8 +%s16mat = OpTypeCooperativeMatrixNV %s16 %subgroup %u32_8 %u32_8 +%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8 + +%f16_1 = OpConstant %f16 1 +%f32_1 = OpConstant %f32 1 +%u16_1 = OpConstant %u16 1 +%u32_1 = OpConstant %u32 1 +%s16_1 = OpConstant %s16 1 +%s32_1 = OpConstant %s32 1 + +%f16mat_1 = OpConstantComposite %f16mat %f16_1 +%f32mat_1 = OpConstantComposite %f32mat %f32_1 +%u16mat_1 = OpConstantComposite %u16mat %u16_1 +%u32mat_1 = OpConstantComposite %u32mat %u32_1 +%s16mat_1 = OpConstantComposite %s16mat %s16_1 +%s32mat_1 = OpConstantComposite %s32mat %s32_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%val11 = OpConvertFToU %u16mat %f16mat_1 +%val12 = OpConvertFToU %u32mat %f16mat_1 +%val13 = OpConvertFToS %s16mat %f16mat_1 +%val14 = OpConvertFToS %s32mat %f16mat_1 +%val15 = OpFConvert %f32mat %f16mat_1 + +%val21 = OpConvertFToU %u16mat %f32mat_1 +%val22 = OpConvertFToU %u32mat %f32mat_1 +%val23 = OpConvertFToS %s16mat %f32mat_1 +%val24 = OpConvertFToS %s32mat %f32mat_1 +%val25 = OpFConvert %f16mat 
%f32mat_1 + +%val31 = OpConvertUToF %f16mat %u16mat_1 +%val32 = OpConvertUToF %f32mat %u16mat_1 +%val33 = OpUConvert %u32mat %u16mat_1 +%val34 = OpSConvert %s32mat %u16mat_1 + +%val41 = OpConvertSToF %f16mat %s16mat_1 +%val42 = OpConvertSToF %f32mat %s16mat_1 +%val43 = OpUConvert %u32mat %s16mat_1 +%val44 = OpSConvert %s32mat %s16mat_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, CoopMatConversionShapesMismatchFail) { + const std::string body = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u16 = OpTypeInt 16 0 +%u32 = OpTypeInt 32 0 +%s16 = OpTypeInt 16 1 +%s32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%u32_4 = OpConstant %u32 4 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 +%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4 + +%f16_1 = OpConstant %f16 1 + +%f16mat_1 = OpConstantComposite %f16mat %f16_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%val15 = OpFConvert %f32mat %f16mat_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected rows of Matrix type and Result Type to be identical")); +} + +TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) { + const std::string body = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void 
+%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u16 = OpTypeInt 16 0 +%u32 = OpTypeInt 32 0 +%s16 = OpTypeInt 16 1 +%s32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%u32_4 = OpSpecConstant %u32 4 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 +%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4 + +%f16_1 = OpConstant %f16 1 + +%f16mat_1 = OpConstantComposite %f16mat %f16_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%val15 = OpFConvert %f32mat %f16mat_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + TEST_F(ValidateConversion, BitcastSuccess) { const std::string body = R"( %ptr = OpVariable %f32ptr_func Function diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp index b567a7bba..246b85edb 100644 --- a/test/val/val_memory_test.cpp +++ b/test/val/val_memory_test.cpp @@ -1774,6 +1774,372 @@ OpFunctionEnd HasSubstr("PhysicalStorageBufferEXT must not be used with OpVariable")); } +std::string GenCoopMatLoadStoreShader(const std::string& storeMemoryAccess, + const std::string& loadMemoryAccess) { + std::string s = R"( +OpCapability Shader +OpCapability GroupNonUniform +OpCapability VulkanMemoryModelKHR +OpCapability CooperativeMatrixNV +OpExtension "SPV_KHR_vulkan_memory_model" +OpExtension "SPV_NV_cooperative_matrix" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical VulkanKHR +OpEntryPoint GLCompute %4 "main" %11 %21 +OpExecutionMode %4 LocalSize 1 1 1 +OpDecorate %11 BuiltIn SubgroupId +OpDecorate %21 BuiltIn WorkgroupId +OpDecorate %74 ArrayStride 4 +OpMemberDecorate %75 0 Offset 0 +OpDecorate %75 Block +OpDecorate %77 DescriptorSet 0 +OpDecorate %77 Binding 0 +OpDecorate %92 ArrayStride 4 +OpMemberDecorate %93 0 Offset 0 +OpDecorate %93 Block +OpDecorate %95 DescriptorSet 0 +OpDecorate %95 Binding 1 +OpDecorate %102 ArrayStride 4 +OpMemberDecorate %103 
0 Offset 0 +OpDecorate %103 Block +OpDecorate %105 DescriptorSet 0 +OpDecorate %105 Binding 2 +OpDecorate %117 ArrayStride 4 +OpMemberDecorate %118 0 Offset 0 +OpDecorate %118 Block +OpDecorate %120 DescriptorSet 0 +OpDecorate %120 Binding 3 +OpDecorate %123 SpecId 2 +OpDecorate %124 SpecId 3 +OpDecorate %125 SpecId 4 +OpDecorate %126 SpecId 5 +OpDecorate %127 SpecId 0 +OpDecorate %128 SpecId 1 +OpDecorate %129 BuiltIn WorkgroupSize +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%6 = OpTypeInt 32 0 +%7 = OpTypeVector %6 2 +%8 = OpTypePointer Function %7 +%10 = OpTypePointer Input %6 +%11 = OpVariable %10 Input +%13 = OpConstant %6 2 +%19 = OpTypeVector %6 3 +%20 = OpTypePointer Input %19 +%21 = OpVariable %20 Input +%27 = OpConstantComposite %7 %13 %13 +%31 = OpTypePointer Function %6 +%33 = OpConstant %6 1024 +%34 = OpConstant %6 1 +%38 = OpConstant %6 8 +%39 = OpConstant %6 0 +%68 = OpTypeFloat 32 +%69 = OpConstant %6 16 +%70 = OpConstant %6 3 +%71 = OpTypeCooperativeMatrixNV %68 %70 %69 %38 +%72 = OpTypePointer Function %71 +%74 = OpTypeRuntimeArray %68 +%75 = OpTypeStruct %74 +%76 = OpTypePointer StorageBuffer %75 +%77 = OpVariable %76 StorageBuffer +%78 = OpTypeInt 32 1 +%79 = OpConstant %78 0 +%81 = OpConstant %6 5 +%82 = OpTypePointer StorageBuffer %68 +%84 = OpConstant %6 64 +%85 = OpTypeBool +%86 = OpConstantFalse %85 +%88 = OpTypePointer Private %71 +%89 = OpVariable %88 Private +%92 = OpTypeRuntimeArray %68 +%93 = OpTypeStruct %92 +%94 = OpTypePointer StorageBuffer %93 +%95 = OpVariable %94 StorageBuffer +%99 = OpVariable %88 Private +%102 = OpTypeRuntimeArray %68 +%103 = OpTypeStruct %102 +%104 = OpTypePointer StorageBuffer %103 +%105 = OpVariable %104 StorageBuffer +%109 = OpVariable %88 Private +%111 = OpVariable %88 Private +%112 = OpSpecConstantOp %6 CooperativeMatrixLengthNV %71 +%113 = OpSpecConstantOp %78 IAdd %112 %79 +%117 = OpTypeRuntimeArray %68 +%118 = OpTypeStruct %117 +%119 = OpTypePointer StorageBuffer %118 +%120 = OpVariable %119 
StorageBuffer +%123 = OpSpecConstant %78 1 +%124 = OpSpecConstant %78 1 +%125 = OpSpecConstant %78 1 +%126 = OpSpecConstant %78 1 +%127 = OpSpecConstant %6 1 +%128 = OpSpecConstant %6 1 +%129 = OpSpecConstantComposite %19 %127 %128 %34 +%4 = OpFunction %2 None %3 +%5 = OpLabel +%9 = OpVariable %8 Function +%18 = OpVariable %8 Function +%32 = OpVariable %31 Function +%44 = OpVariable %31 Function +%52 = OpVariable %31 Function +%60 = OpVariable %31 Function +%73 = OpVariable %72 Function +%91 = OpVariable %72 Function +%101 = OpVariable %72 Function +%12 = OpLoad %6 %11 +%14 = OpUMod %6 %12 %13 +%15 = OpLoad %6 %11 +%16 = OpUDiv %6 %15 %13 +%17 = OpCompositeConstruct %7 %14 %16 +OpStore %9 %17 +%22 = OpLoad %19 %21 +%23 = OpVectorShuffle %7 %22 %22 0 1 +%24 = OpCompositeExtract %6 %23 0 +%25 = OpCompositeExtract %6 %23 1 +%26 = OpCompositeConstruct %7 %24 %25 +%28 = OpIMul %7 %26 %27 +%29 = OpLoad %7 %9 +%30 = OpIAdd %7 %28 %29 +OpStore %18 %30 +%35 = OpAccessChain %31 %18 %34 +%36 = OpLoad %6 %35 +%37 = OpIMul %6 %33 %36 +%40 = OpAccessChain %31 %18 %39 +%41 = OpLoad %6 %40 +%42 = OpIMul %6 %38 %41 +%43 = OpIAdd %6 %37 %42 +OpStore %32 %43 +%45 = OpAccessChain %31 %18 %34 +%46 = OpLoad %6 %45 +%47 = OpIMul %6 %33 %46 +%48 = OpAccessChain %31 %18 %39 +%49 = OpLoad %6 %48 +%50 = OpIMul %6 %38 %49 +%51 = OpIAdd %6 %47 %50 +OpStore %44 %51 +%53 = OpAccessChain %31 %18 %34 +%54 = OpLoad %6 %53 +%55 = OpIMul %6 %33 %54 +%56 = OpAccessChain %31 %18 %39 +%57 = OpLoad %6 %56 +%58 = OpIMul %6 %38 %57 +%59 = OpIAdd %6 %55 %58 +OpStore %52 %59 +%61 = OpAccessChain %31 %18 %34 +%62 = OpLoad %6 %61 +%63 = OpIMul %6 %33 %62 +%64 = OpAccessChain %31 %18 %39 +%65 = OpLoad %6 %64 +%66 = OpIMul %6 %38 %65 +%67 = OpIAdd %6 %63 %66 +OpStore %60 %67 +%80 = OpLoad %6 %32 +%83 = OpAccessChain %82 %77 %79 %80 +%87 = OpCooperativeMatrixLoadNV %71 %83 %84 %86 )" + + loadMemoryAccess + R"( %81 +OpStore %73 %87 +%90 = OpLoad %71 %73 +OpStore %89 %90 +%96 = OpLoad %6 %44 +%97 = OpAccessChain 
%82 %95 %79 %96 +%98 = OpCooperativeMatrixLoadNV %71 %97 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81 +OpStore %91 %98 +%100 = OpLoad %71 %91 +OpStore %99 %100 +%106 = OpLoad %6 %52 +%107 = OpAccessChain %82 %105 %79 %106 +%108 = OpCooperativeMatrixLoadNV %71 %107 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81 +OpStore %101 %108 +%110 = OpLoad %71 %101 +OpStore %109 %110 +%114 = OpConvertSToF %68 %113 +%115 = OpCompositeConstruct %71 %114 +OpStore %111 %115 +%116 = OpLoad %71 %111 +%121 = OpLoad %6 %60 +%122 = OpAccessChain %82 %120 %79 %121 +OpCooperativeMatrixStoreNV %122 %116 %84 %86 )" + storeMemoryAccess + R"( %81 +OpReturn +OpFunctionEnd +)"; + + return s; +} + +TEST_F(ValidateMemory, CoopMatLoadStoreSuccess) { + std::string spirv = + GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR", + "MakePointerVisibleKHR|NonPrivatePointerKHR"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, CoopMatStoreMemoryAccessFail) { + std::string spirv = + GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR", + "MakePointerVisibleKHR|NonPrivatePointerKHR"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MakePointerVisibleKHR cannot be used with OpStore")); +} + +TEST_F(ValidateMemory, CoopMatLoadMemoryAccessFail) { + std::string spirv = + GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR", + "MakePointerAvailableKHR|NonPrivatePointerKHR"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad")); +} + +TEST_F(ValidateMemory, CoopMatInvalidStorageClassFail) { + const 
std::string body = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%f16 = OpTypeFloat 16 +%u32 = OpTypeInt 32 0 + +%u32_8 = OpConstant %u32 8 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 + +%str = OpTypeStruct %f16mat +%str_ptr = OpTypePointer Workgroup %str +%sh = OpVariable %str_ptr Workgroup + +%main = OpFunction %void None %func +%main_entry = OpLabel + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Cooperative matrix types (or types containing them) can only be " + "allocated in Function or Private storage classes or as function " + "parameters")); +} + +TEST_F(ValidateMemory, CoopMatMatrixLengthResultTypeBad) { + const std::string body = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%f16 = OpTypeFloat 16 +%u32 = OpTypeInt 32 0 +%i32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%1 = OpCooperativeMatrixLengthNV %i32 %f16mat + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("The Result Type of OpCooperativeMatrixLengthNV " + "'11[%11]' must be OpTypeInt with width 32 and signedness 0")); +} + +TEST_F(ValidateMemory, CoopMatMatrixLengthOperandTypeBad) { + const std::string body = + R"( +OpCapability 
Shader +OpCapability Float16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%f16 = OpTypeFloat 16 +%u32 = OpTypeInt 32 0 +%i32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%1 = OpCooperativeMatrixLengthNV %u32 %u32 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("The type in OpCooperativeMatrixLengthNV '5[%uint]' " + "must be OpTypeCooperativeMatrixNV")); +} + +TEST_F(ValidateMemory, CoopMatMatrixLengthGood) { + const std::string body = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeMatrixNV +OpExtension "SPV_NV_cooperative_matrix" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%f16 = OpTypeFloat 16 +%u32 = OpTypeInt 32 0 +%i32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%subgroup = OpConstant %u32 3 + +%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%1 = OpCooperativeMatrixLengthNV %u32 %f16mat + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) { std::string spirv = R"( OpCapability Shader