Add validation for SPV_NV_cooperative_matrix (#2404)

This commit is contained in:
Jeff Bolz 2019-02-25 16:43:11 -06:00 committed by alan-baker
parent fc3897b5f5
commit 002ef361ca
18 changed files with 1390 additions and 60 deletions

2
DEPS
View File

@ -11,7 +11,7 @@ vars = {
'googletest_revision': '98a0d007d7092b72eea0e501bb9ad17908a1a036',
'testing_revision': '340252637e2e7c72c0901dcbeeacfff419e19b59',
're2_revision': '6cf8ccd82dbaab2668e9b13596c68183c9ecd13f',
'spirv_headers_revision': '79b6681aadcb53c27d1052e5f8a0e82a981dbf2f',
'spirv_headers_revision': 'e74c389f81915d0a48d6df1af83c3862c5ad85ab',
}
deps = {

View File

@ -154,10 +154,11 @@ const SpecConstantOpcodeEntry kOpSpecConstantOpcodes[] = {
CASE(InBoundsAccessChain),
CASE(PtrAccessChain),
CASE(InBoundsPtrAccessChain),
CASE(CooperativeMatrixLengthNV)
};
// The 59 is determined by counting the opcodes listed in the spec.
static_assert(59 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
// The 60 is determined by counting the opcodes listed in the spec.
static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
"OpSpecConstantOp opcode table is incomplete");
#undef CASE
// clang-format on

View File

@ -260,6 +260,7 @@ int32_t spvOpcodeIsComposite(const SpvOp opcode) {
case SpvOpTypeMatrix:
case SpvOpTypeArray:
case SpvOpTypeStruct:
case SpvOpTypeCooperativeMatrixNV:
return true;
default:
return false;
@ -325,6 +326,7 @@ int32_t spvOpcodeGeneratesType(SpvOp op) {
case SpvOpTypePipeStorage:
case SpvOpTypeNamedBarrier:
case SpvOpTypeAccelerationStructureNV:
case SpvOpTypeCooperativeMatrixNV:
return true;
default:
// In particular, OpTypeForwardPointer does not generate a type,

View File

@ -45,7 +45,8 @@ inline bool IsAnnotationInst(SpvOp opcode) {
inline bool IsTypeInst(SpvOp opcode) {
return (opcode >= SpvOpTypeVoid && opcode <= SpvOpTypeForwardPointer) ||
opcode == SpvOpTypePipeStorage || opcode == SpvOpTypeNamedBarrier ||
opcode == SpvOpTypeAccelerationStructureNV;
opcode == SpvOpTypeAccelerationStructureNV ||
opcode == SpvOpTypeCooperativeMatrixNV;
}
inline bool IsConstantInst(SpvOp opcode) {
return opcode >= SpvOpConstantTrue && opcode <= SpvOpSpecConstantOp;

View File

@ -39,8 +39,11 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
case SpvOpFRem:
case SpvOpFMod:
case SpvOpFNegate: {
bool supportsCoopMat =
(opcode != SpvOpFMul && opcode != SpvOpFRem && opcode != SpvOpFMod);
if (!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type))
!_.IsFloatVectorType(result_type) &&
!(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected floating scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
@ -58,8 +61,11 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
case SpvOpUDiv:
case SpvOpUMod: {
bool supportsCoopMat = (opcode == SpvOpUDiv);
if (!_.IsUnsignedIntScalarType(result_type) &&
!_.IsUnsignedIntVectorType(result_type))
!_.IsUnsignedIntVectorType(result_type) &&
!(supportsCoopMat &&
_.IsUnsignedIntCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected unsigned int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
@ -82,7 +88,10 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
case SpvOpSMod:
case SpvOpSRem:
case SpvOpSNegate: {
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
bool supportsCoopMat =
(opcode != SpvOpIMul && opcode != SpvOpSRem && opcode != SpvOpSMod);
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
@ -94,7 +103,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
++operand_index) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!type_id ||
(!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
(!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as operand: "
<< spvOpcodeString(opcode) << " operand index "
@ -176,7 +186,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
}
case SpvOpMatrixTimesScalar: {
if (!_.IsFloatMatrixType(result_type))
if (!_.IsFloatMatrixType(result_type) &&
!_.IsCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float matrix type as Result Type: "
<< spvOpcodeString(opcode);
@ -442,6 +453,92 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
break;
}
case SpvOpCooperativeMatrixMulAddNV: {
const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
if (!_.IsCooperativeMatrixType(A_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as A Type: "
<< spvOpcodeString(opcode);
}
if (!_.IsCooperativeMatrixType(B_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as B Type: "
<< spvOpcodeString(opcode);
}
if (!_.IsCooperativeMatrixType(C_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as C Type: "
<< spvOpcodeString(opcode);
}
if (!_.IsCooperativeMatrixType(D_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as Result Type: "
<< spvOpcodeString(opcode);
}
const auto A = _.FindDef(A_type_id);
const auto B = _.FindDef(B_type_id);
const auto C = _.FindDef(C_type_id);
const auto D = _.FindDef(D_type_id);
std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
std::tuple<bool, bool, uint32_t> Y) {
return (std::get<1>(X) && std::get<1>(Y) &&
std::get<2>(X) != std::get<2>(Y));
};
if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix scopes must match: "
<< spvOpcodeString(opcode);
}
if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
notEqual(C_rows, D_rows)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix 'M' mismatch: "
<< spvOpcodeString(opcode);
}
if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
notEqual(C_cols, D_cols)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix 'N' mismatch: "
<< spvOpcodeString(opcode);
}
if (notEqual(A_cols, B_rows)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cooperative matrix 'K' mismatch: "
<< spvOpcodeString(opcode);
}
break;
}
default:
break;
}

View File

@ -118,6 +118,10 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
*member_type = type_inst->word(component_index + 2);
break;
}
case SpvOpTypeCooperativeMatrixNV: {
*member_type = type_inst->word(2);
break;
}
default:
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Reached non-composite type while indexes still remain to "
@ -315,6 +319,26 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
break;
}
case SpvOpTypeCooperativeMatrixNV: {
const auto result_type_inst = _.FindDef(result_type);
assert(result_type_inst);
const auto component_type_id =
result_type_inst->GetOperandAs<uint32_t>(1);
if (3 != num_operands) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected single constituent";
}
const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
if (operand_type_id != component_type_id) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Constituent type to be equal to the component type";
}
break;
}
default: {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Result Type to be a composite type";

View File

@ -247,6 +247,36 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
}
}
} break;
case SpvOpTypeCooperativeMatrixNV: {
if (1 != constituent_count) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opcode_name << " Constituent <id> '"
<< _.getIdName(inst->type_id()) << "' count must be one.";
}
const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
const auto constituent = _.FindDef(constituent_id);
if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opcode_name << " Constituent <id> '"
<< _.getIdName(constituent_id)
<< "' is not a constant or undef.";
}
const auto constituent_type = _.FindDef(constituent->type_id());
if (!constituent_type) {
return _.diag(SPV_ERROR_INVALID_ID, constituent)
<< "Result type is not defined.";
}
const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
const auto component_type = _.FindDef(component_type_id);
if (!component_type || component_type->id() != constituent_type->id()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opcode_name << " Constituent <id> '"
<< _.getIdName(constituent_id)
<< "' type does not match the Result Type <id> '"
<< _.getIdName(result_type->id()) << "'s component type.";
}
} break;
default:
break;
}
@ -285,6 +315,7 @@ bool IsTypeNullable(const std::vector<uint32_t>& instruction,
return true;
case SpvOpTypeArray:
case SpvOpTypeMatrix:
case SpvOpTypeCooperativeMatrixNV:
case SpvOpTypeVector: {
auto base_type = _.FindDef(instruction[2]);
return base_type && IsTypeNullable(base_type->words(), _);
@ -320,7 +351,7 @@ spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
// The binary parser already ensures that the op is valid for *some*
// environment. Here we check restrictions.
switch(op) {
switch (op) {
case SpvOpQuantizeToF16:
if (!_.HasCapability(SpvCapabilityShader)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
@ -365,7 +396,7 @@ spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
}
break;
default:
default:
break;
}

View File

@ -32,22 +32,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
switch (opcode) {
case SpvOpConvertFToU: {
if (!_.IsUnsignedIntScalarType(result_type) &&
!_.IsUnsignedIntVectorType(result_type))
!_.IsUnsignedIntVectorType(result_type) &&
!_.IsUnsignedIntCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected unsigned int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type || (!_.IsFloatScalarType(input_type) &&
!_.IsFloatVectorType(input_type)))
!_.IsFloatVectorType(input_type) &&
!_.IsFloatCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be float scalar or vector: "
<< spvOpcodeString(opcode);
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
}
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@ -58,22 +67,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
}
case SpvOpConvertFToS: {
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
!_.IsIntCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type || (!_.IsFloatScalarType(input_type) &&
!_.IsFloatVectorType(input_type)))
!_.IsFloatVectorType(input_type) &&
!_.IsFloatCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be float scalar or vector: "
<< spvOpcodeString(opcode);
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
}
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@ -86,22 +104,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
case SpvOpConvertSToF:
case SpvOpConvertUToF: {
if (!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type))
!_.IsFloatVectorType(result_type) &&
!_.IsFloatCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type ||
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
!_.IsIntCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be int scalar or vector: "
<< spvOpcodeString(opcode);
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
}
if (!_.features().use_int8_type && (8 == _.GetBitWidth(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@ -113,22 +140,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
case SpvOpUConvert: {
if (!_.IsUnsignedIntScalarType(result_type) &&
!_.IsUnsignedIntVectorType(result_type))
!_.IsUnsignedIntVectorType(result_type) &&
!_.IsUnsignedIntCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected unsigned int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type ||
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
!_.IsIntCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be int scalar or vector: "
<< spvOpcodeString(opcode);
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
}
if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@ -139,22 +175,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
}
case SpvOpSConvert: {
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
!_.IsIntCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type ||
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
!_.IsIntCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be int scalar or vector: "
<< spvOpcodeString(opcode);
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
}
if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@ -166,22 +211,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
case SpvOpFConvert: {
if (!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type))
!_.IsFloatVectorType(result_type) &&
!_.IsFloatCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type || (!_.IsFloatScalarType(input_type) &&
!_.IsFloatVectorType(input_type)))
!_.IsFloatVectorType(input_type) &&
!_.IsFloatCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be float scalar or vector: "
<< spvOpcodeString(opcode);
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
}
if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)

View File

@ -167,7 +167,10 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
const auto opcode = inst->opcode();
if (spvOpcodeGeneratesType(def->opcode()) &&
!spvOpcodeGeneratesType(opcode) && !spvOpcodeIsDebug(opcode) &&
!spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction) {
!spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction &&
opcode != SpvOpCooperativeMatrixLengthNV &&
!(opcode == SpvOpSpecConstantOp &&
inst->word(3) == SpvOpCooperativeMatrixLengthNV)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word)
<< " cannot be a type";
@ -177,7 +180,10 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
!spvOpcodeIsBranch(opcode) && opcode != SpvOpPhi &&
opcode != SpvOpExtInst && opcode != SpvOpExtInstImport &&
opcode != SpvOpSelectionMerge &&
opcode != SpvOpLoopMerge && opcode != SpvOpFunction) {
opcode != SpvOpLoopMerge && opcode != SpvOpFunction &&
opcode != SpvOpCooperativeMatrixLengthNV &&
!(opcode == SpvOpSpecConstantOp &&
inst->word(3) == SpvOpCooperativeMatrixLengthNV)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word)
<< " requires a type";

View File

@ -197,17 +197,49 @@ bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
return false;
}
bool ContainsCooperativeMatrix(ValidationState_t& _,
const Instruction* storage) {
const size_t elem_type_index = 1;
uint32_t elem_type_id;
Instruction* elem_type;
switch (storage->opcode()) {
case SpvOpTypeCooperativeMatrixNV:
return true;
case SpvOpTypeArray:
case SpvOpTypeRuntimeArray:
elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
elem_type = _.FindDef(elem_type_id);
return ContainsCooperativeMatrix(_, elem_type);
case SpvOpTypeStruct:
for (size_t member_type_index = 1;
member_type_index < storage->operands().size();
++member_type_index) {
auto member_type_id =
storage->GetOperandAs<uint32_t>(member_type_index);
auto member_type = _.FindDef(member_type_id);
if (ContainsCooperativeMatrix(_, member_type)) return true;
}
break;
default:
break;
}
return false;
}
std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
ValidationState_t& _, const Instruction* inst) {
SpvStorageClass dst_sc = SpvStorageClassMax;
SpvStorageClass src_sc = SpvStorageClassMax;
switch (inst->opcode()) {
case SpvOpCooperativeMatrixLoadNV:
case SpvOpLoad: {
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
auto load_pointer_type = _.FindDef(load_pointer->type_id());
dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
break;
}
case SpvOpCooperativeMatrixStoreNV:
case SpvOpStore: {
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
auto store_pointer_type = _.FindDef(store_pointer->type_id());
@ -232,7 +264,8 @@ std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
}
// This function is only called for OpLoad, OpStore, OpCopyMemory and
// OpCopyMemorySized.
// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
// OpCooperativeMatrixStoreNV.
uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
uint32_t offset = 1;
if (mask & SpvMemoryAccessAlignedMask) ++offset;
@ -245,6 +278,10 @@ uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
case SpvOpStore:
case SpvOpCopyMemory:
return inst->GetOperandAs<uint32_t>(2 + offset);
case SpvOpCooperativeMatrixLoadNV:
return inst->GetOperandAs<uint32_t>(5 + offset);
case SpvOpCooperativeMatrixStoreNV:
return inst->GetOperandAs<uint32_t>(4 + offset);
default:
assert(false && "unexpected opcode");
break;
@ -253,8 +290,9 @@ uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
return scope_id;
}
// This function is only called for OpLoad, OpStore, OpCopyMemory and
// OpCopyMemorySized.
// This function is only called for OpLoad, OpStore, OpCopyMemory,
// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
// OpCooperativeMatrixStoreNV.
uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) {
uint32_t offset = 1;
if (mask & SpvMemoryAccessAlignedMask) ++offset;
@ -268,6 +306,10 @@ uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) {
case SpvOpStore:
case SpvOpCopyMemory:
return inst->GetOperandAs<uint32_t>(2 + offset);
case SpvOpCooperativeMatrixLoadNV:
return inst->GetOperandAs<uint32_t>(5 + offset);
case SpvOpCooperativeMatrixStoreNV:
return inst->GetOperandAs<uint32_t>(4 + offset);
default:
assert(false && "unexpected opcode");
break;
@ -302,7 +344,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
uint32_t mask = inst->GetOperandAs<uint32_t>(index);
if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
if (inst->opcode() == SpvOpLoad) {
if (inst->opcode() == SpvOpLoad ||
inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MakePointerAvailableKHR cannot be used with OpLoad.";
}
@ -320,7 +363,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
}
if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
if (inst->opcode() == SpvOpStore) {
if (inst->opcode() == SpvOpStore ||
inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MakePointerVisibleKHR cannot be used with OpStore.";
}
@ -672,6 +716,17 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
}
}
// Cooperative matrix types can only be allocated in Function or Private
if ((storage_class != SpvStorageClassFunction &&
storage_class != SpvStorageClassPrivate) &&
ContainsCooperativeMatrix(_, pointee)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Cooperative matrix types (or types containing them) can only be "
"allocated "
<< "in Function or Private storage classes or as function "
"parameters";
}
return SPV_SUCCESS;
}
@ -1003,10 +1058,11 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
switch (type_pointee->opcode()) {
case SpvOpTypeMatrix:
case SpvOpTypeVector:
case SpvOpTypeCooperativeMatrixNV:
case SpvOpTypeArray:
case SpvOpTypeRuntimeArray: {
// In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray,
// word 2 is the Element Type.
// In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
// OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
type_pointee = _.FindDef(type_pointee->word(2));
break;
}
@ -1136,6 +1192,140 @@ spv_result_t ValidateArrayLength(ValidationState_t& state,
return SPV_SUCCESS;
}
spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
const Instruction* inst) {
std::string instr_name =
"Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
// Result type must be a 32-bit unsigned int.
auto result_type = state.FindDef(inst->type_id());
if (result_type->opcode() != SpvOpTypeInt ||
result_type->GetOperandAs<uint32_t>(1) != 32 ||
result_type->GetOperandAs<uint32_t>(2) != 0) {
return state.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Result Type of " << instr_name << " <id> '"
<< state.getIdName(inst->id())
<< "' must be OpTypeInt with width 32 and signedness 0.";
}
auto type_id = inst->GetOperandAs<uint32_t>(2);
auto type = state.FindDef(type_id);
if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
return state.diag(SPV_ERROR_INVALID_ID, inst)
<< "The type in " << instr_name << " <id> '"
<< state.getIdName(type_id)
<< "' must be OpTypeCooperativeMatrixNV.";
}
return SPV_SUCCESS;
}
spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
const Instruction* inst) {
uint32_t type_id;
const char* opname;
if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
type_id = inst->type_id();
opname = "SpvOpCooperativeMatrixLoadNV";
} else {
// get Object operand's type
type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
opname = "SpvOpCooperativeMatrixStoreNV";
}
auto matrix_type = _.FindDef(type_id);
if (matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "SpvOpCooperativeMatrixLoadNV Result Type <id> '"
<< _.getIdName(type_id) << "' is not a cooperative matrix type.";
} else {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "SpvOpCooperativeMatrixStoreNV Object type <id> '"
<< _.getIdName(type_id) << "' is not a cooperative matrix type.";
}
}
const bool uses_variable_pointers =
_.features().variable_pointers ||
_.features().variable_pointers_storage_buffer;
const auto pointer_index =
(inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
const auto pointer = _.FindDef(pointer_id);
if (!pointer ||
((_.addressing_model() == SpvAddressingModelLogical) &&
((!uses_variable_pointers &&
!spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
(uses_variable_pointers &&
!spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " Pointer <id> '" << _.getIdName(pointer_id)
<< "' is not a logical pointer.";
}
const auto pointer_type_id = pointer->type_id();
const auto pointer_type = _.FindDef(pointer_type_id);
if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " type for pointer <id> '" << _.getIdName(pointer_id)
<< "' is not a pointer type.";
}
const auto storage_class_index = 1u;
const auto storage_class =
pointer_type->GetOperandAs<uint32_t>(storage_class_index);
if (storage_class != SpvStorageClassWorkgroup &&
storage_class != SpvStorageClassStorageBuffer &&
storage_class != SpvStorageClassPhysicalStorageBufferEXT) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " storage class for pointer type <id> '"
<< _.getIdName(pointer_type_id)
<< "' is not Workgroup or StorageBuffer.";
}
const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
const auto pointee_type = _.FindDef(pointee_id);
if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
_.IsFloatScalarOrVectorType(pointee_id))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< opname << " Pointer <id> '" << _.getIdName(pointer->id())
<< "'s Type must be a scalar or vector type.";
}
const auto stride_index =
(inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
const auto stride = _.FindDef(stride_id);
if (!stride || !_.IsIntScalarType(stride->type_id())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Stride operand <id> '" << _.getIdName(stride_id)
<< "' must be a scalar integer type.";
}
const auto colmajor_index =
(inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
const auto colmajor = _.FindDef(colmajor_id);
if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
!(spvOpcodeIsConstant(colmajor->opcode()) ||
spvOpcodeIsSpecConstant(colmajor->opcode()))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Column Major operand <id> '" << _.getIdName(colmajor_id)
<< "' must be a boolean constant instruction.";
}
const auto memory_access_index =
(inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
if (inst->operands().size() > memory_access_index) {
if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
return error;
}
return SPV_SUCCESS;
}
} // namespace
spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
@ -1164,6 +1354,14 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
case SpvOpArrayLength:
if (auto error = ValidateArrayLength(_, inst)) return error;
break;
case SpvOpCooperativeMatrixLoadNV:
case SpvOpCooperativeMatrixStoreNV:
if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
return error;
break;
case SpvOpCooperativeMatrixLengthNV:
if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
break;
case SpvOpImageTexelPointer:
case SpvOpGenericPtrMemSemantics:
default:

View File

@ -36,11 +36,19 @@ spv_result_t ValidateExecutionScope(ValidationState_t& _,
}
if (!is_const_int32) {
if (_.HasCapability(SpvCapabilityShader)) {
if (_.HasCapability(SpvCapabilityShader) &&
!_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Scope ids must be OpConstant when Shader capability is "
<< "present";
}
if (_.HasCapability(SpvCapabilityShader) &&
_.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
!spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Scope ids must be constant or specialization constant when "
<< "CooperativeMatrixNV capability is present";
}
return SPV_SUCCESS;
}
@ -130,11 +138,19 @@ spv_result_t ValidateMemoryScope(ValidationState_t& _, const Instruction* inst,
}
if (!is_const_int32) {
if (_.HasCapability(SpvCapabilityShader)) {
if (_.HasCapability(SpvCapabilityShader) &&
!_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Scope ids must be OpConstant when Shader capability is "
<< "present";
}
if (_.HasCapability(SpvCapabilityShader) &&
_.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
!spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Scope ids must be constant or specialization constant when "
<< "CooperativeMatrixNV capability is present";
}
return SPV_SUCCESS;
}

View File

@ -381,6 +381,53 @@ spv_result_t ValidateTypeForwardPointer(ValidationState_t& _,
return SPV_SUCCESS;
}
spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
const Instruction* inst) {
const auto component_type_index = 1;
const auto component_type_id =
inst->GetOperandAs<uint32_t>(component_type_index);
const auto component_type = _.FindDef(component_type_id);
if (!component_type || (SpvOpTypeFloat != component_type->opcode() &&
SpvOpTypeInt != component_type->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixNV Component Type <id> '"
<< _.getIdName(component_type_id)
<< "' is not a scalar numerical type.";
}
const auto scope_index = 2;
const auto scope_id = inst->GetOperandAs<uint32_t>(scope_index);
const auto scope = _.FindDef(scope_id);
if (!scope || !_.IsIntScalarType(scope->type_id()) ||
!spvOpcodeIsConstant(scope->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixNV Scope <id> '" << _.getIdName(scope_id)
<< "' is not a constant instruction with scalar integer type.";
}
const auto rows_index = 3;
const auto rows_id = inst->GetOperandAs<uint32_t>(rows_index);
const auto rows = _.FindDef(rows_id);
if (!rows || !_.IsIntScalarType(rows->type_id()) ||
!spvOpcodeIsConstant(rows->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixNV Rows <id> '" << _.getIdName(rows_id)
<< "' is not a constant instruction with scalar integer type.";
}
const auto cols_index = 4;
const auto cols_id = inst->GetOperandAs<uint32_t>(cols_index);
const auto cols = _.FindDef(cols_id);
if (!cols || !_.IsIntScalarType(cols->type_id()) ||
!spvOpcodeIsConstant(cols->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixNV Cols <id> '" << _.getIdName(rows_id)
<< "' is not a constant instruction with scalar integer type.";
}
return SPV_SUCCESS;
}
} // namespace
spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
@ -416,6 +463,9 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
case SpvOpTypeForwardPointer:
if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
break;
case SpvOpTypeCooperativeMatrixNV:
if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
break;
default:
break;
}

View File

@ -610,6 +610,9 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
case SpvOpTypeMatrix:
return GetComponentType(inst->word(2));
case SpvOpTypeCooperativeMatrixNV:
return inst->word(2);
default:
break;
}
@ -634,6 +637,10 @@ uint32_t ValidationState_t::GetDimension(uint32_t id) const {
case SpvOpTypeMatrix:
return inst->word(3);
case SpvOpTypeCooperativeMatrixNV:
// Actual dimension isn't known, return 0
return 0;
default:
break;
}
@ -862,6 +869,86 @@ bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
return true;
}
bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
const Instruction* inst = FindDef(id);
assert(inst);
return inst->opcode() == SpvOpTypeCooperativeMatrixNV;
}
bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
if (!IsCooperativeMatrixType(id)) return false;
return IsFloatScalarType(FindDef(id)->word(2));
}
bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
if (!IsCooperativeMatrixType(id)) return false;
return IsIntScalarType(FindDef(id)->word(2));
}
bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
if (!IsCooperativeMatrixType(id)) return false;
return IsUnsignedIntScalarType(FindDef(id)->word(2));
}
spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
const Instruction* inst, uint32_t m1, uint32_t m2) {
const auto m1_type = FindDef(m1);
const auto m2_type = FindDef(m2);
if (m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix types";
}
uint32_t m1_scope_id = m1_type->GetOperandAs<uint32_t>(2);
uint32_t m1_rows_id = m1_type->GetOperandAs<uint32_t>(3);
uint32_t m1_cols_id = m1_type->GetOperandAs<uint32_t>(4);
uint32_t m2_scope_id = m2_type->GetOperandAs<uint32_t>(2);
uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
m2_is_const_int32 = false;
uint32_t m1_value = 0, m2_value = 0;
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
EvalInt32IfConst(m1_scope_id);
std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
EvalInt32IfConst(m2_scope_id);
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected scopes of Matrix and Result Type to be "
<< "identical";
}
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
EvalInt32IfConst(m1_rows_id);
std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
EvalInt32IfConst(m2_rows_id);
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected rows of Matrix type and Result Type to be "
<< "identical";
}
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
EvalInt32IfConst(m1_cols_id);
std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
EvalInt32IfConst(m2_cols_id);
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected columns of Matrix type and Result Type to be "
<< "identical";
}
return SPV_SUCCESS;
}
uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
size_t operand_index) const {
return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
@ -890,7 +977,7 @@ bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
}
std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
uint32_t id) {
uint32_t id) const {
const Instruction* const inst = FindDef(id);
assert(inst);
const uint32_t type = inst->type_id();

View File

@ -552,6 +552,10 @@ class ValidationState_t {
bool IsBoolVectorType(uint32_t id) const;
bool IsBoolScalarOrVectorType(uint32_t id) const;
bool IsPointerType(uint32_t id) const;
bool IsCooperativeMatrixType(uint32_t id) const;
bool IsFloatCooperativeMatrixType(uint32_t id) const;
bool IsIntCooperativeMatrixType(uint32_t id) const;
bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;
// Gets value from OpConstant and OpSpecConstant as uint64.
// Returns false on failure (no instruction, wrong instruction, not int).
@ -635,7 +639,7 @@ class ValidationState_t {
// Returns tuple <is_int32, is_const_int32, value>.
// OpSpecConstant* return |is_const_int32| as false since their values cannot
// be relied upon during validation.
std::tuple<bool, bool, uint32_t> EvalInt32IfConst(uint32_t id);
std::tuple<bool, bool, uint32_t> EvalInt32IfConst(uint32_t id) const;
// Returns the disassembly string for the given instruction.
std::string Disassemble(const Instruction& inst) const;
@ -643,6 +647,12 @@ class ValidationState_t {
// Returns the disassembly string for the given instruction.
std::string Disassemble(const uint32_t* words, uint16_t num_words) const;
// Returns whether type m1 and type m2 are cooperative matrices with
// the same "shape" (matching scope, rows, cols). If any are specialization
// constants, we assume they can match because we can't prove they don't.
spv_result_t CooperativeMatrixShapesMatch(const Instruction* inst,
uint32_t m1, uint32_t m2);
private:
ValidationState_t(const ValidationState_t&);

View File

@ -1165,6 +1165,150 @@ TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) {
"vector size of the right operand: OuterProduct"));
}
std::string GenerateCoopMatCode(const std::string& extra_types,
const std::string& main_body) {
const std::string prefix =
R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%u32_16 = OpConstant %u32 16
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8
%f16_1 = OpConstant %f16 1
%f32_1 = OpConstant %f32 1
%u32_1 = OpConstant %u32 1
%s32_1 = OpConstant %s32 1
%f16mat_1 = OpConstantComposite %f16mat %f16_1
%u32mat_1 = OpConstantComposite %u32mat %u32_1
%s32mat_1 = OpConstantComposite %s32mat %s32_1
%u32_c1 = OpSpecConstant %u32 1
%u32_c2 = OpSpecConstant %u32 2
%f16matc = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_c1 %u32_c2
%f16matc_1 = OpConstantComposite %f16matc %f16_1
%mat16x4 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_4
%mat4x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_4 %u32_16
%mat16x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_16
%f16mat_16x4_1 = OpConstantComposite %mat16x4 %f16_1
%f16mat_4x16_1 = OpConstantComposite %mat4x16 %f16_1
%f16mat_16x16_1 = OpConstantComposite %mat16x16 %f16_1)";
const std::string func_begin =
R"(
%main = OpFunction %void None %func
%main_entry = OpLabel)";
const std::string suffix =
R"(
OpReturn
OpFunctionEnd)";
return prefix + extra_types + func_begin + main_body + suffix;
}
TEST_F(ValidateArithmetics, CoopMatSuccess) {
const std::string body = R"(
%val1 = OpFAdd %f16mat %f16mat_1 %f16mat_1
%val2 = OpFSub %f16mat %f16mat_1 %f16mat_1
%val3 = OpFDiv %f16mat %f16mat_1 %f16mat_1
%val4 = OpFNegate %f16mat %f16mat_1
%val5 = OpIAdd %u32mat %u32mat_1 %u32mat_1
%val6 = OpISub %u32mat %u32mat_1 %u32mat_1
%val7 = OpUDiv %u32mat %u32mat_1 %u32mat_1
%val8 = OpIAdd %s32mat %s32mat_1 %s32mat_1
%val9 = OpISub %s32mat %s32mat_1 %s32mat_1
%val10 = OpSDiv %s32mat %s32mat_1 %s32mat_1
%val11 = OpSNegate %s32mat %s32mat_1
%val12 = OpMatrixTimesScalar %f16mat %f16mat_1 %f16_1
%val13 = OpMatrixTimesScalar %u32mat %u32mat_1 %u32_1
%val14 = OpMatrixTimesScalar %s32mat %s32mat_1 %s32_1
%val15 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16mat_16x16_1
%val16 = OpCooperativeMatrixMulAddNV %f16matc %f16matc_1 %f16matc_1 %f16matc_1
)";
CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, CoopMatFMulFail) {
const std::string body = R"(
%val1 = OpFMul %f16mat %f16mat_1 %f16mat_1
)";
CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Expected floating scalar or vector type as Result Type: FMul"));
}
TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) {
const std::string body = R"(
%val1 = OpMatrixTimesScalar %f16mat %f16mat_1 %f32_1
)";
CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Expected scalar operand type to be equal to the component "
"type of the matrix operand: MatrixTimesScalar"));
}
TEST_F(ValidateArithmetics, CoopMatScopeFail) {
const std::string types = R"(
%workgroup = OpConstant %u32 2
%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16
%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
)";
const std::string body = R"(
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1
)";
CompileSuccessfully(GenerateCoopMatCode(types, body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Cooperative matrix scopes must match: CooperativeMatrixMulAddNV"));
}
TEST_F(ValidateArithmetics, CoopMatDimFail) {
const std::string body = R"(
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_4x16_1 %f16mat_16x4_1 %f16mat_16x16_1
)";
CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV"));
}
TEST_F(ValidateArithmetics, IAddCarrySuccess) {
const std::string body = R"(
%val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1

View File

@ -1467,6 +1467,83 @@ OpFunctionEnd
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%u32_8 = OpConstant %u32 8
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%f32_1 = OpConstant %f32 1
%f16mat_1 = OpConstantComposite %f16mat %f32_1
%main = OpFunction %void None %func
%main_entry = OpLabel
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("OpConstantComposite Constituent <id> '11[%float_1]' type does "
"not match the Result Type <id> '10[%10]'s component type."));
}
TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u32 = OpTypeInt 32 0
%u32_8 = OpConstant %u32 8
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%f32_1 = OpConstant %f32 1
%main = OpFunction %void None %func
%main_entry = OpLabel
%f16mat_1 = OpCompositeConstruct %f16mat %f32_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Expected Constituent type to be equal to the component type"));
}
TEST_F(ValidateComposites, ExtractDynamicLabelIndex) {
const std::string spirv = R"(
OpCapability Shader

View File

@ -1184,6 +1184,172 @@ TEST_F(ValidateConversion, GenericCastToPtrExplicitPointToDifferentType) {
"GenericCastToPtrExplicit"));
}
TEST_F(ValidateConversion, CoopMatConversionSuccess) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u16 = OpTypeInt 16 0
%u32 = OpTypeInt 32 0
%s16 = OpTypeInt 16 1
%s32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_8 %u32_8
%u16mat = OpTypeCooperativeMatrixNV %u16 %subgroup %u32_8 %u32_8
%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
%s16mat = OpTypeCooperativeMatrixNV %s16 %subgroup %u32_8 %u32_8
%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8
%f16_1 = OpConstant %f16 1
%f32_1 = OpConstant %f32 1
%u16_1 = OpConstant %u16 1
%u32_1 = OpConstant %u32 1
%s16_1 = OpConstant %s16 1
%s32_1 = OpConstant %s32 1
%f16mat_1 = OpConstantComposite %f16mat %f16_1
%f32mat_1 = OpConstantComposite %f32mat %f32_1
%u16mat_1 = OpConstantComposite %u16mat %u16_1
%u32mat_1 = OpConstantComposite %u32mat %u32_1
%s16mat_1 = OpConstantComposite %s16mat %s16_1
%s32mat_1 = OpConstantComposite %s32mat %s32_1
%main = OpFunction %void None %func
%main_entry = OpLabel
%val11 = OpConvertFToU %u16mat %f16mat_1
%val12 = OpConvertFToU %u32mat %f16mat_1
%val13 = OpConvertFToS %s16mat %f16mat_1
%val14 = OpConvertFToS %s32mat %f16mat_1
%val15 = OpFConvert %f32mat %f16mat_1
%val21 = OpConvertFToU %u16mat %f32mat_1
%val22 = OpConvertFToU %u32mat %f32mat_1
%val23 = OpConvertFToS %s16mat %f32mat_1
%val24 = OpConvertFToS %s32mat %f32mat_1
%val25 = OpFConvert %f16mat %f32mat_1
%val31 = OpConvertUToF %f16mat %u16mat_1
%val32 = OpConvertUToF %f32mat %u16mat_1
%val33 = OpUConvert %u32mat %u16mat_1
%val34 = OpSConvert %s32mat %u16mat_1
%val41 = OpConvertSToF %f16mat %s16mat_1
%val42 = OpConvertSToF %f32mat %s16mat_1
%val43 = OpUConvert %u32mat %s16mat_1
%val44 = OpSConvert %s32mat %s16mat_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateConversion, CoopMatConversionShapesMismatchFail) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u16 = OpTypeInt 16 0
%u32 = OpTypeInt 32 0
%s16 = OpTypeInt 16 1
%s32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%u32_4 = OpConstant %u32 4
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4
%f16_1 = OpConstant %f16 1
%f16mat_1 = OpConstantComposite %f16mat %f16_1
%main = OpFunction %void None %func
%main_entry = OpLabel
%val15 = OpFConvert %f32mat %f16mat_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Expected rows of Matrix type and Result Type to be identical"));
}
TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f16 = OpTypeFloat 16
%f32 = OpTypeFloat 32
%u16 = OpTypeInt 16 0
%u32 = OpTypeInt 32 0
%s16 = OpTypeInt 16 1
%s32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%u32_4 = OpSpecConstant %u32 4
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4
%f16_1 = OpConstant %f16 1
%f16mat_1 = OpConstantComposite %f16mat %f16_1
%main = OpFunction %void None %func
%main_entry = OpLabel
%val15 = OpFConvert %f32mat %f16mat_1
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateConversion, BitcastSuccess) {
const std::string body = R"(
%ptr = OpVariable %f32ptr_func Function

View File

@ -1774,6 +1774,372 @@ OpFunctionEnd
HasSubstr("PhysicalStorageBufferEXT must not be used with OpVariable"));
}
std::string GenCoopMatLoadStoreShader(const std::string& storeMemoryAccess,
const std::string& loadMemoryAccess) {
std::string s = R"(
OpCapability Shader
OpCapability GroupNonUniform
OpCapability VulkanMemoryModelKHR
OpCapability CooperativeMatrixNV
OpExtension "SPV_KHR_vulkan_memory_model"
OpExtension "SPV_NV_cooperative_matrix"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical VulkanKHR
OpEntryPoint GLCompute %4 "main" %11 %21
OpExecutionMode %4 LocalSize 1 1 1
OpDecorate %11 BuiltIn SubgroupId
OpDecorate %21 BuiltIn WorkgroupId
OpDecorate %74 ArrayStride 4
OpMemberDecorate %75 0 Offset 0
OpDecorate %75 Block
OpDecorate %77 DescriptorSet 0
OpDecorate %77 Binding 0
OpDecorate %92 ArrayStride 4
OpMemberDecorate %93 0 Offset 0
OpDecorate %93 Block
OpDecorate %95 DescriptorSet 0
OpDecorate %95 Binding 1
OpDecorate %102 ArrayStride 4
OpMemberDecorate %103 0 Offset 0
OpDecorate %103 Block
OpDecorate %105 DescriptorSet 0
OpDecorate %105 Binding 2
OpDecorate %117 ArrayStride 4
OpMemberDecorate %118 0 Offset 0
OpDecorate %118 Block
OpDecorate %120 DescriptorSet 0
OpDecorate %120 Binding 3
OpDecorate %123 SpecId 2
OpDecorate %124 SpecId 3
OpDecorate %125 SpecId 4
OpDecorate %126 SpecId 5
OpDecorate %127 SpecId 0
OpDecorate %128 SpecId 1
OpDecorate %129 BuiltIn WorkgroupSize
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeInt 32 0
%7 = OpTypeVector %6 2
%8 = OpTypePointer Function %7
%10 = OpTypePointer Input %6
%11 = OpVariable %10 Input
%13 = OpConstant %6 2
%19 = OpTypeVector %6 3
%20 = OpTypePointer Input %19
%21 = OpVariable %20 Input
%27 = OpConstantComposite %7 %13 %13
%31 = OpTypePointer Function %6
%33 = OpConstant %6 1024
%34 = OpConstant %6 1
%38 = OpConstant %6 8
%39 = OpConstant %6 0
%68 = OpTypeFloat 32
%69 = OpConstant %6 16
%70 = OpConstant %6 3
%71 = OpTypeCooperativeMatrixNV %68 %70 %69 %38
%72 = OpTypePointer Function %71
%74 = OpTypeRuntimeArray %68
%75 = OpTypeStruct %74
%76 = OpTypePointer StorageBuffer %75
%77 = OpVariable %76 StorageBuffer
%78 = OpTypeInt 32 1
%79 = OpConstant %78 0
%81 = OpConstant %6 5
%82 = OpTypePointer StorageBuffer %68
%84 = OpConstant %6 64
%85 = OpTypeBool
%86 = OpConstantFalse %85
%88 = OpTypePointer Private %71
%89 = OpVariable %88 Private
%92 = OpTypeRuntimeArray %68
%93 = OpTypeStruct %92
%94 = OpTypePointer StorageBuffer %93
%95 = OpVariable %94 StorageBuffer
%99 = OpVariable %88 Private
%102 = OpTypeRuntimeArray %68
%103 = OpTypeStruct %102
%104 = OpTypePointer StorageBuffer %103
%105 = OpVariable %104 StorageBuffer
%109 = OpVariable %88 Private
%111 = OpVariable %88 Private
%112 = OpSpecConstantOp %6 CooperativeMatrixLengthNV %71
%113 = OpSpecConstantOp %78 IAdd %112 %79
%117 = OpTypeRuntimeArray %68
%118 = OpTypeStruct %117
%119 = OpTypePointer StorageBuffer %118
%120 = OpVariable %119 StorageBuffer
%123 = OpSpecConstant %78 1
%124 = OpSpecConstant %78 1
%125 = OpSpecConstant %78 1
%126 = OpSpecConstant %78 1
%127 = OpSpecConstant %6 1
%128 = OpSpecConstant %6 1
%129 = OpSpecConstantComposite %19 %127 %128 %34
%4 = OpFunction %2 None %3
%5 = OpLabel
%9 = OpVariable %8 Function
%18 = OpVariable %8 Function
%32 = OpVariable %31 Function
%44 = OpVariable %31 Function
%52 = OpVariable %31 Function
%60 = OpVariable %31 Function
%73 = OpVariable %72 Function
%91 = OpVariable %72 Function
%101 = OpVariable %72 Function
%12 = OpLoad %6 %11
%14 = OpUMod %6 %12 %13
%15 = OpLoad %6 %11
%16 = OpUDiv %6 %15 %13
%17 = OpCompositeConstruct %7 %14 %16
OpStore %9 %17
%22 = OpLoad %19 %21
%23 = OpVectorShuffle %7 %22 %22 0 1
%24 = OpCompositeExtract %6 %23 0
%25 = OpCompositeExtract %6 %23 1
%26 = OpCompositeConstruct %7 %24 %25
%28 = OpIMul %7 %26 %27
%29 = OpLoad %7 %9
%30 = OpIAdd %7 %28 %29
OpStore %18 %30
%35 = OpAccessChain %31 %18 %34
%36 = OpLoad %6 %35
%37 = OpIMul %6 %33 %36
%40 = OpAccessChain %31 %18 %39
%41 = OpLoad %6 %40
%42 = OpIMul %6 %38 %41
%43 = OpIAdd %6 %37 %42
OpStore %32 %43
%45 = OpAccessChain %31 %18 %34
%46 = OpLoad %6 %45
%47 = OpIMul %6 %33 %46
%48 = OpAccessChain %31 %18 %39
%49 = OpLoad %6 %48
%50 = OpIMul %6 %38 %49
%51 = OpIAdd %6 %47 %50
OpStore %44 %51
%53 = OpAccessChain %31 %18 %34
%54 = OpLoad %6 %53
%55 = OpIMul %6 %33 %54
%56 = OpAccessChain %31 %18 %39
%57 = OpLoad %6 %56
%58 = OpIMul %6 %38 %57
%59 = OpIAdd %6 %55 %58
OpStore %52 %59
%61 = OpAccessChain %31 %18 %34
%62 = OpLoad %6 %61
%63 = OpIMul %6 %33 %62
%64 = OpAccessChain %31 %18 %39
%65 = OpLoad %6 %64
%66 = OpIMul %6 %38 %65
%67 = OpIAdd %6 %63 %66
OpStore %60 %67
%80 = OpLoad %6 %32
%83 = OpAccessChain %82 %77 %79 %80
%87 = OpCooperativeMatrixLoadNV %71 %83 %84 %86 )" +
loadMemoryAccess + R"( %81
OpStore %73 %87
%90 = OpLoad %71 %73
OpStore %89 %90
%96 = OpLoad %6 %44
%97 = OpAccessChain %82 %95 %79 %96
%98 = OpCooperativeMatrixLoadNV %71 %97 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81
OpStore %91 %98
%100 = OpLoad %71 %91
OpStore %99 %100
%106 = OpLoad %6 %52
%107 = OpAccessChain %82 %105 %79 %106
%108 = OpCooperativeMatrixLoadNV %71 %107 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81
OpStore %101 %108
%110 = OpLoad %71 %101
OpStore %109 %110
%114 = OpConvertSToF %68 %113
%115 = OpCompositeConstruct %71 %114
OpStore %111 %115
%116 = OpLoad %71 %111
%121 = OpLoad %6 %60
%122 = OpAccessChain %82 %120 %79 %121
OpCooperativeMatrixStoreNV %122 %116 %84 %86 )" + storeMemoryAccess + R"( %81
OpReturn
OpFunctionEnd
)";
return s;
}
TEST_F(ValidateMemory, CoopMatLoadStoreSuccess) {
std::string spirv =
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR");
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
}
TEST_F(ValidateMemory, CoopMatStoreMemoryAccessFail) {
std::string spirv =
GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR");
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("MakePointerVisibleKHR cannot be used with OpStore"));
}
TEST_F(ValidateMemory, CoopMatLoadMemoryAccessFail) {
std::string spirv =
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerAvailableKHR|NonPrivatePointerKHR");
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad"));
}
TEST_F(ValidateMemory, CoopMatInvalidStorageClassFail) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%u32_8 = OpConstant %u32 8
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%str = OpTypeStruct %f16mat
%str_ptr = OpTypePointer Workgroup %str
%sh = OpVariable %str_ptr Workgroup
%main = OpFunction %void None %func
%main_entry = OpLabel
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Cooperative matrix types (or types containing them) can only be "
"allocated in Function or Private storage classes or as function "
"parameters"));
}
TEST_F(ValidateMemory, CoopMatMatrixLengthResultTypeBad) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%i32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%main = OpFunction %void None %func
%main_entry = OpLabel
%1 = OpCooperativeMatrixLengthNV %i32 %f16mat
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("The Result Type of OpCooperativeMatrixLengthNV <id> "
"'11[%11]' must be OpTypeInt with width 32 and signedness 0"));
}
TEST_F(ValidateMemory, CoopMatMatrixLengthOperandTypeBad) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%i32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%main = OpFunction %void None %func
%main_entry = OpLabel
%1 = OpCooperativeMatrixLengthNV %u32 %u32
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("The type in OpCooperativeMatrixLengthNV <id> '5[%uint]' "
"must be OpTypeCooperativeMatrixNV"));
}
TEST_F(ValidateMemory, CoopMatMatrixLengthGood) {
const std::string body =
R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
OpExtension "SPV_NV_cooperative_matrix"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%i32 = OpTypeInt 32 1
%u32_8 = OpConstant %u32 8
%subgroup = OpConstant %u32 3
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
%main = OpFunction %void None %func
%main_entry = OpLabel
%1 = OpCooperativeMatrixLengthNV %u32 %f16mat
OpReturn
OpFunctionEnd)";
CompileSuccessfully(body.c_str());
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) {
std::string spirv = R"(
OpCapability Shader