mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2025-01-11 09:00:06 +00:00
Add validation for SPV_NV_cooperative_matrix (#2404)
This commit is contained in:
parent
fc3897b5f5
commit
002ef361ca
2
DEPS
2
DEPS
@ -11,7 +11,7 @@ vars = {
|
||||
'googletest_revision': '98a0d007d7092b72eea0e501bb9ad17908a1a036',
|
||||
'testing_revision': '340252637e2e7c72c0901dcbeeacfff419e19b59',
|
||||
're2_revision': '6cf8ccd82dbaab2668e9b13596c68183c9ecd13f',
|
||||
'spirv_headers_revision': '79b6681aadcb53c27d1052e5f8a0e82a981dbf2f',
|
||||
'spirv_headers_revision': 'e74c389f81915d0a48d6df1af83c3862c5ad85ab',
|
||||
}
|
||||
|
||||
deps = {
|
||||
|
@ -154,10 +154,11 @@ const SpecConstantOpcodeEntry kOpSpecConstantOpcodes[] = {
|
||||
CASE(InBoundsAccessChain),
|
||||
CASE(PtrAccessChain),
|
||||
CASE(InBoundsPtrAccessChain),
|
||||
CASE(CooperativeMatrixLengthNV)
|
||||
};
|
||||
|
||||
// The 59 is determined by counting the opcodes listed in the spec.
|
||||
static_assert(59 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
|
||||
// The 60 is determined by counting the opcodes listed in the spec.
|
||||
static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
|
||||
"OpSpecConstantOp opcode table is incomplete");
|
||||
#undef CASE
|
||||
// clang-format on
|
||||
|
@ -260,6 +260,7 @@ int32_t spvOpcodeIsComposite(const SpvOp opcode) {
|
||||
case SpvOpTypeMatrix:
|
||||
case SpvOpTypeArray:
|
||||
case SpvOpTypeStruct:
|
||||
case SpvOpTypeCooperativeMatrixNV:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@ -325,6 +326,7 @@ int32_t spvOpcodeGeneratesType(SpvOp op) {
|
||||
case SpvOpTypePipeStorage:
|
||||
case SpvOpTypeNamedBarrier:
|
||||
case SpvOpTypeAccelerationStructureNV:
|
||||
case SpvOpTypeCooperativeMatrixNV:
|
||||
return true;
|
||||
default:
|
||||
// In particular, OpTypeForwardPointer does not generate a type,
|
||||
|
@ -45,7 +45,8 @@ inline bool IsAnnotationInst(SpvOp opcode) {
|
||||
inline bool IsTypeInst(SpvOp opcode) {
|
||||
return (opcode >= SpvOpTypeVoid && opcode <= SpvOpTypeForwardPointer) ||
|
||||
opcode == SpvOpTypePipeStorage || opcode == SpvOpTypeNamedBarrier ||
|
||||
opcode == SpvOpTypeAccelerationStructureNV;
|
||||
opcode == SpvOpTypeAccelerationStructureNV ||
|
||||
opcode == SpvOpTypeCooperativeMatrixNV;
|
||||
}
|
||||
inline bool IsConstantInst(SpvOp opcode) {
|
||||
return opcode >= SpvOpConstantTrue && opcode <= SpvOpSpecConstantOp;
|
||||
|
@ -39,8 +39,11 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
||||
case SpvOpFRem:
|
||||
case SpvOpFMod:
|
||||
case SpvOpFNegate: {
|
||||
bool supportsCoopMat =
|
||||
(opcode != SpvOpFMul && opcode != SpvOpFRem && opcode != SpvOpFMod);
|
||||
if (!_.IsFloatScalarType(result_type) &&
|
||||
!_.IsFloatVectorType(result_type))
|
||||
!_.IsFloatVectorType(result_type) &&
|
||||
!(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected floating scalar or vector type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
@ -58,8 +61,11 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
||||
|
||||
case SpvOpUDiv:
|
||||
case SpvOpUMod: {
|
||||
bool supportsCoopMat = (opcode == SpvOpUDiv);
|
||||
if (!_.IsUnsignedIntScalarType(result_type) &&
|
||||
!_.IsUnsignedIntVectorType(result_type))
|
||||
!_.IsUnsignedIntVectorType(result_type) &&
|
||||
!(supportsCoopMat &&
|
||||
_.IsUnsignedIntCooperativeMatrixType(result_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected unsigned int scalar or vector type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
@ -82,7 +88,10 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
||||
case SpvOpSMod:
|
||||
case SpvOpSRem:
|
||||
case SpvOpSNegate: {
|
||||
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
|
||||
bool supportsCoopMat =
|
||||
(opcode != SpvOpIMul && opcode != SpvOpSRem && opcode != SpvOpSMod);
|
||||
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
|
||||
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected int scalar or vector type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
@ -94,7 +103,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
||||
++operand_index) {
|
||||
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
|
||||
if (!type_id ||
|
||||
(!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
|
||||
(!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
|
||||
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected int scalar or vector type as operand: "
|
||||
<< spvOpcodeString(opcode) << " operand index "
|
||||
@ -176,7 +186,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
|
||||
case SpvOpMatrixTimesScalar: {
|
||||
if (!_.IsFloatMatrixType(result_type))
|
||||
if (!_.IsFloatMatrixType(result_type) &&
|
||||
!_.IsCooperativeMatrixType(result_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected float matrix type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
@ -442,6 +453,92 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpCooperativeMatrixMulAddNV: {
|
||||
const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
|
||||
const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
|
||||
const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
|
||||
const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
|
||||
|
||||
if (!_.IsCooperativeMatrixType(A_type_id)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected cooperative matrix type as A Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
if (!_.IsCooperativeMatrixType(B_type_id)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected cooperative matrix type as B Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
if (!_.IsCooperativeMatrixType(C_type_id)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected cooperative matrix type as C Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
if (!_.IsCooperativeMatrixType(D_type_id)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected cooperative matrix type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
const auto A = _.FindDef(A_type_id);
|
||||
const auto B = _.FindDef(B_type_id);
|
||||
const auto C = _.FindDef(C_type_id);
|
||||
const auto D = _.FindDef(D_type_id);
|
||||
|
||||
std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
|
||||
A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
|
||||
|
||||
A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
|
||||
B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
|
||||
C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
|
||||
D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
|
||||
|
||||
A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
|
||||
B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
|
||||
C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
|
||||
D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
|
||||
|
||||
A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
|
||||
B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
|
||||
C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
|
||||
D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
|
||||
|
||||
const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
|
||||
std::tuple<bool, bool, uint32_t> Y) {
|
||||
return (std::get<1>(X) && std::get<1>(Y) &&
|
||||
std::get<2>(X) != std::get<2>(Y));
|
||||
};
|
||||
|
||||
if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
|
||||
notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
|
||||
notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Cooperative matrix scopes must match: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
|
||||
notEqual(C_rows, D_rows)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Cooperative matrix 'M' mismatch: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
|
||||
notEqual(C_cols, D_cols)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Cooperative matrix 'N' mismatch: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (notEqual(A_cols, B_rows)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Cooperative matrix 'K' mismatch: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -118,6 +118,10 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
|
||||
*member_type = type_inst->word(component_index + 2);
|
||||
break;
|
||||
}
|
||||
case SpvOpTypeCooperativeMatrixNV: {
|
||||
*member_type = type_inst->word(2);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Reached non-composite type while indexes still remain to "
|
||||
@ -315,6 +319,26 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
|
||||
|
||||
break;
|
||||
}
|
||||
case SpvOpTypeCooperativeMatrixNV: {
|
||||
const auto result_type_inst = _.FindDef(result_type);
|
||||
assert(result_type_inst);
|
||||
const auto component_type_id =
|
||||
result_type_inst->GetOperandAs<uint32_t>(1);
|
||||
|
||||
if (3 != num_operands) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected single constituent";
|
||||
}
|
||||
|
||||
const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
|
||||
|
||||
if (operand_type_id != component_type_id) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected Constituent type to be equal to the component type";
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected Result Type to be a composite type";
|
||||
|
@ -247,6 +247,36 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case SpvOpTypeCooperativeMatrixNV: {
|
||||
if (1 != constituent_count) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< opcode_name << " Constituent <id> '"
|
||||
<< _.getIdName(inst->type_id()) << "' count must be one.";
|
||||
}
|
||||
const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
|
||||
const auto constituent = _.FindDef(constituent_id);
|
||||
if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< opcode_name << " Constituent <id> '"
|
||||
<< _.getIdName(constituent_id)
|
||||
<< "' is not a constant or undef.";
|
||||
}
|
||||
const auto constituent_type = _.FindDef(constituent->type_id());
|
||||
if (!constituent_type) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, constituent)
|
||||
<< "Result type is not defined.";
|
||||
}
|
||||
|
||||
const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
|
||||
const auto component_type = _.FindDef(component_type_id);
|
||||
if (!component_type || component_type->id() != constituent_type->id()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< opcode_name << " Constituent <id> '"
|
||||
<< _.getIdName(constituent_id)
|
||||
<< "' type does not match the Result Type <id> '"
|
||||
<< _.getIdName(result_type->id()) << "'s component type.";
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -285,6 +315,7 @@ bool IsTypeNullable(const std::vector<uint32_t>& instruction,
|
||||
return true;
|
||||
case SpvOpTypeArray:
|
||||
case SpvOpTypeMatrix:
|
||||
case SpvOpTypeCooperativeMatrixNV:
|
||||
case SpvOpTypeVector: {
|
||||
auto base_type = _.FindDef(instruction[2]);
|
||||
return base_type && IsTypeNullable(base_type->words(), _);
|
||||
@ -320,7 +351,7 @@ spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
|
||||
|
||||
// The binary parser already ensures that the op is valid for *some*
|
||||
// environment. Here we check restrictions.
|
||||
switch(op) {
|
||||
switch (op) {
|
||||
case SpvOpQuantizeToF16:
|
||||
if (!_.HasCapability(SpvCapabilityShader)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
@ -365,7 +396,7 @@ spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -32,22 +32,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
switch (opcode) {
|
||||
case SpvOpConvertFToU: {
|
||||
if (!_.IsUnsignedIntScalarType(result_type) &&
|
||||
!_.IsUnsignedIntVectorType(result_type))
|
||||
!_.IsUnsignedIntVectorType(result_type) &&
|
||||
!_.IsUnsignedIntCooperativeMatrixType(result_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected unsigned int scalar or vector type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
|
||||
if (!input_type || (!_.IsFloatScalarType(input_type) &&
|
||||
!_.IsFloatVectorType(input_type)))
|
||||
!_.IsFloatVectorType(input_type) &&
|
||||
!_.IsFloatCooperativeMatrixType(input_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to be float scalar or vector: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
if (_.IsCooperativeMatrixType(result_type) ||
|
||||
_.IsCooperativeMatrixType(input_type)) {
|
||||
spv_result_t ret =
|
||||
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
|
||||
if (ret != SPV_SUCCESS) return ret;
|
||||
} else {
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
@ -58,22 +67,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
|
||||
case SpvOpConvertFToS: {
|
||||
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
|
||||
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
|
||||
!_.IsIntCooperativeMatrixType(result_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected int scalar or vector type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
|
||||
if (!input_type || (!_.IsFloatScalarType(input_type) &&
|
||||
!_.IsFloatVectorType(input_type)))
|
||||
!_.IsFloatVectorType(input_type) &&
|
||||
!_.IsFloatCooperativeMatrixType(input_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to be float scalar or vector: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
if (_.IsCooperativeMatrixType(result_type) ||
|
||||
_.IsCooperativeMatrixType(input_type)) {
|
||||
spv_result_t ret =
|
||||
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
|
||||
if (ret != SPV_SUCCESS) return ret;
|
||||
} else {
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
@ -86,22 +104,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
case SpvOpConvertSToF:
|
||||
case SpvOpConvertUToF: {
|
||||
if (!_.IsFloatScalarType(result_type) &&
|
||||
!_.IsFloatVectorType(result_type))
|
||||
!_.IsFloatVectorType(result_type) &&
|
||||
!_.IsFloatCooperativeMatrixType(result_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected float scalar or vector type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
|
||||
if (!input_type ||
|
||||
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
|
||||
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
|
||||
!_.IsIntCooperativeMatrixType(input_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to be int scalar or vector: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
if (_.IsCooperativeMatrixType(result_type) ||
|
||||
_.IsCooperativeMatrixType(input_type)) {
|
||||
spv_result_t ret =
|
||||
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
|
||||
if (ret != SPV_SUCCESS) return ret;
|
||||
} else {
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (!_.features().use_int8_type && (8 == _.GetBitWidth(input_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
@ -113,22 +140,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
|
||||
case SpvOpUConvert: {
|
||||
if (!_.IsUnsignedIntScalarType(result_type) &&
|
||||
!_.IsUnsignedIntVectorType(result_type))
|
||||
!_.IsUnsignedIntVectorType(result_type) &&
|
||||
!_.IsUnsignedIntCooperativeMatrixType(result_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected unsigned int scalar or vector type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
|
||||
if (!input_type ||
|
||||
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
|
||||
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
|
||||
!_.IsIntCooperativeMatrixType(input_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to be int scalar or vector: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
if (_.IsCooperativeMatrixType(result_type) ||
|
||||
_.IsCooperativeMatrixType(input_type)) {
|
||||
spv_result_t ret =
|
||||
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
|
||||
if (ret != SPV_SUCCESS) return ret;
|
||||
} else {
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
@ -139,22 +175,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
|
||||
case SpvOpSConvert: {
|
||||
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
|
||||
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
|
||||
!_.IsIntCooperativeMatrixType(result_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected int scalar or vector type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
|
||||
if (!input_type ||
|
||||
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
|
||||
(!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
|
||||
!_.IsIntCooperativeMatrixType(input_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to be int scalar or vector: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
if (_.IsCooperativeMatrixType(result_type) ||
|
||||
_.IsCooperativeMatrixType(input_type)) {
|
||||
spv_result_t ret =
|
||||
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
|
||||
if (ret != SPV_SUCCESS) return ret;
|
||||
} else {
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
@ -166,22 +211,31 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
|
||||
case SpvOpFConvert: {
|
||||
if (!_.IsFloatScalarType(result_type) &&
|
||||
!_.IsFloatVectorType(result_type))
|
||||
!_.IsFloatVectorType(result_type) &&
|
||||
!_.IsFloatCooperativeMatrixType(result_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected float scalar or vector type as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
|
||||
if (!input_type || (!_.IsFloatScalarType(input_type) &&
|
||||
!_.IsFloatVectorType(input_type)))
|
||||
!_.IsFloatVectorType(input_type) &&
|
||||
!_.IsFloatCooperativeMatrixType(input_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to be float scalar or vector: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
if (_.IsCooperativeMatrixType(result_type) ||
|
||||
_.IsCooperativeMatrixType(input_type)) {
|
||||
spv_result_t ret =
|
||||
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
|
||||
if (ret != SPV_SUCCESS) return ret;
|
||||
} else {
|
||||
if (_.GetDimension(result_type) != _.GetDimension(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
|
@ -167,7 +167,10 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
|
||||
const auto opcode = inst->opcode();
|
||||
if (spvOpcodeGeneratesType(def->opcode()) &&
|
||||
!spvOpcodeGeneratesType(opcode) && !spvOpcodeIsDebug(opcode) &&
|
||||
!spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction) {
|
||||
!spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction &&
|
||||
opcode != SpvOpCooperativeMatrixLengthNV &&
|
||||
!(opcode == SpvOpSpecConstantOp &&
|
||||
inst->word(3) == SpvOpCooperativeMatrixLengthNV)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Operand " << _.getIdName(operand_word)
|
||||
<< " cannot be a type";
|
||||
@ -177,7 +180,10 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
|
||||
!spvOpcodeIsBranch(opcode) && opcode != SpvOpPhi &&
|
||||
opcode != SpvOpExtInst && opcode != SpvOpExtInstImport &&
|
||||
opcode != SpvOpSelectionMerge &&
|
||||
opcode != SpvOpLoopMerge && opcode != SpvOpFunction) {
|
||||
opcode != SpvOpLoopMerge && opcode != SpvOpFunction &&
|
||||
opcode != SpvOpCooperativeMatrixLengthNV &&
|
||||
!(opcode == SpvOpSpecConstantOp &&
|
||||
inst->word(3) == SpvOpCooperativeMatrixLengthNV)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Operand " << _.getIdName(operand_word)
|
||||
<< " requires a type";
|
||||
|
@ -197,17 +197,49 @@ bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ContainsCooperativeMatrix(ValidationState_t& _,
|
||||
const Instruction* storage) {
|
||||
const size_t elem_type_index = 1;
|
||||
uint32_t elem_type_id;
|
||||
Instruction* elem_type;
|
||||
|
||||
switch (storage->opcode()) {
|
||||
case SpvOpTypeCooperativeMatrixNV:
|
||||
return true;
|
||||
case SpvOpTypeArray:
|
||||
case SpvOpTypeRuntimeArray:
|
||||
elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
|
||||
elem_type = _.FindDef(elem_type_id);
|
||||
return ContainsCooperativeMatrix(_, elem_type);
|
||||
case SpvOpTypeStruct:
|
||||
for (size_t member_type_index = 1;
|
||||
member_type_index < storage->operands().size();
|
||||
++member_type_index) {
|
||||
auto member_type_id =
|
||||
storage->GetOperandAs<uint32_t>(member_type_index);
|
||||
auto member_type = _.FindDef(member_type_id);
|
||||
if (ContainsCooperativeMatrix(_, member_type)) return true;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
|
||||
ValidationState_t& _, const Instruction* inst) {
|
||||
SpvStorageClass dst_sc = SpvStorageClassMax;
|
||||
SpvStorageClass src_sc = SpvStorageClassMax;
|
||||
switch (inst->opcode()) {
|
||||
case SpvOpCooperativeMatrixLoadNV:
|
||||
case SpvOpLoad: {
|
||||
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
|
||||
auto load_pointer_type = _.FindDef(load_pointer->type_id());
|
||||
dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
|
||||
break;
|
||||
}
|
||||
case SpvOpCooperativeMatrixStoreNV:
|
||||
case SpvOpStore: {
|
||||
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
|
||||
auto store_pointer_type = _.FindDef(store_pointer->type_id());
|
||||
@ -232,7 +264,8 @@ std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
|
||||
}
|
||||
|
||||
// This function is only called for OpLoad, OpStore, OpCopyMemory and
|
||||
// OpCopyMemorySized.
|
||||
// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
|
||||
// OpCooperativeMatrixStoreNV.
|
||||
uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
|
||||
uint32_t offset = 1;
|
||||
if (mask & SpvMemoryAccessAlignedMask) ++offset;
|
||||
@ -245,6 +278,10 @@ uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
|
||||
case SpvOpStore:
|
||||
case SpvOpCopyMemory:
|
||||
return inst->GetOperandAs<uint32_t>(2 + offset);
|
||||
case SpvOpCooperativeMatrixLoadNV:
|
||||
return inst->GetOperandAs<uint32_t>(5 + offset);
|
||||
case SpvOpCooperativeMatrixStoreNV:
|
||||
return inst->GetOperandAs<uint32_t>(4 + offset);
|
||||
default:
|
||||
assert(false && "unexpected opcode");
|
||||
break;
|
||||
@ -253,8 +290,9 @@ uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
|
||||
return scope_id;
|
||||
}
|
||||
|
||||
// This function is only called for OpLoad, OpStore, OpCopyMemory and
|
||||
// OpCopyMemorySized.
|
||||
// This function is only called for OpLoad, OpStore, OpCopyMemory,
|
||||
// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
|
||||
// OpCooperativeMatrixStoreNV.
|
||||
uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) {
|
||||
uint32_t offset = 1;
|
||||
if (mask & SpvMemoryAccessAlignedMask) ++offset;
|
||||
@ -268,6 +306,10 @@ uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) {
|
||||
case SpvOpStore:
|
||||
case SpvOpCopyMemory:
|
||||
return inst->GetOperandAs<uint32_t>(2 + offset);
|
||||
case SpvOpCooperativeMatrixLoadNV:
|
||||
return inst->GetOperandAs<uint32_t>(5 + offset);
|
||||
case SpvOpCooperativeMatrixStoreNV:
|
||||
return inst->GetOperandAs<uint32_t>(4 + offset);
|
||||
default:
|
||||
assert(false && "unexpected opcode");
|
||||
break;
|
||||
@ -302,7 +344,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
|
||||
|
||||
uint32_t mask = inst->GetOperandAs<uint32_t>(index);
|
||||
if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
|
||||
if (inst->opcode() == SpvOpLoad) {
|
||||
if (inst->opcode() == SpvOpLoad ||
|
||||
inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "MakePointerAvailableKHR cannot be used with OpLoad.";
|
||||
}
|
||||
@ -320,7 +363,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
|
||||
}
|
||||
|
||||
if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
|
||||
if (inst->opcode() == SpvOpStore) {
|
||||
if (inst->opcode() == SpvOpStore ||
|
||||
inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "MakePointerVisibleKHR cannot be used with OpStore.";
|
||||
}
|
||||
@ -672,6 +716,17 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
}
|
||||
|
||||
// Cooperative matrix types can only be allocated in Function or Private
|
||||
if ((storage_class != SpvStorageClassFunction &&
|
||||
storage_class != SpvStorageClassPrivate) &&
|
||||
ContainsCooperativeMatrix(_, pointee)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Cooperative matrix types (or types containing them) can only be "
|
||||
"allocated "
|
||||
<< "in Function or Private storage classes or as function "
|
||||
"parameters";
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
@ -1003,10 +1058,11 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
|
||||
switch (type_pointee->opcode()) {
|
||||
case SpvOpTypeMatrix:
|
||||
case SpvOpTypeVector:
|
||||
case SpvOpTypeCooperativeMatrixNV:
|
||||
case SpvOpTypeArray:
|
||||
case SpvOpTypeRuntimeArray: {
|
||||
// In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray,
|
||||
// word 2 is the Element Type.
|
||||
// In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
|
||||
// OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
|
||||
type_pointee = _.FindDef(type_pointee->word(2));
|
||||
break;
|
||||
}
|
||||
@ -1136,6 +1192,140 @@ spv_result_t ValidateArrayLength(ValidationState_t& state,
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
|
||||
const Instruction* inst) {
|
||||
std::string instr_name =
|
||||
"Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
|
||||
|
||||
// Result type must be a 32-bit unsigned int.
|
||||
auto result_type = state.FindDef(inst->type_id());
|
||||
if (result_type->opcode() != SpvOpTypeInt ||
|
||||
result_type->GetOperandAs<uint32_t>(1) != 32 ||
|
||||
result_type->GetOperandAs<uint32_t>(2) != 0) {
|
||||
return state.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "The Result Type of " << instr_name << " <id> '"
|
||||
<< state.getIdName(inst->id())
|
||||
<< "' must be OpTypeInt with width 32 and signedness 0.";
|
||||
}
|
||||
|
||||
auto type_id = inst->GetOperandAs<uint32_t>(2);
|
||||
auto type = state.FindDef(type_id);
|
||||
if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
|
||||
return state.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "The type in " << instr_name << " <id> '"
|
||||
<< state.getIdName(type_id)
|
||||
<< "' must be OpTypeCooperativeMatrixNV.";
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
|
||||
const Instruction* inst) {
|
||||
uint32_t type_id;
|
||||
const char* opname;
|
||||
if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
|
||||
type_id = inst->type_id();
|
||||
opname = "SpvOpCooperativeMatrixLoadNV";
|
||||
} else {
|
||||
// get Object operand's type
|
||||
type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
|
||||
opname = "SpvOpCooperativeMatrixStoreNV";
|
||||
}
|
||||
|
||||
auto matrix_type = _.FindDef(type_id);
|
||||
|
||||
if (matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
|
||||
if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "SpvOpCooperativeMatrixLoadNV Result Type <id> '"
|
||||
<< _.getIdName(type_id) << "' is not a cooperative matrix type.";
|
||||
} else {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "SpvOpCooperativeMatrixStoreNV Object type <id> '"
|
||||
<< _.getIdName(type_id) << "' is not a cooperative matrix type.";
|
||||
}
|
||||
}
|
||||
|
||||
const bool uses_variable_pointers =
|
||||
_.features().variable_pointers ||
|
||||
_.features().variable_pointers_storage_buffer;
|
||||
const auto pointer_index =
|
||||
(inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
|
||||
const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
|
||||
const auto pointer = _.FindDef(pointer_id);
|
||||
if (!pointer ||
|
||||
((_.addressing_model() == SpvAddressingModelLogical) &&
|
||||
((!uses_variable_pointers &&
|
||||
!spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
|
||||
(uses_variable_pointers &&
|
||||
!spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< opname << " Pointer <id> '" << _.getIdName(pointer_id)
|
||||
<< "' is not a logical pointer.";
|
||||
}
|
||||
|
||||
const auto pointer_type_id = pointer->type_id();
|
||||
const auto pointer_type = _.FindDef(pointer_type_id);
|
||||
if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< opname << " type for pointer <id> '" << _.getIdName(pointer_id)
|
||||
<< "' is not a pointer type.";
|
||||
}
|
||||
|
||||
const auto storage_class_index = 1u;
|
||||
const auto storage_class =
|
||||
pointer_type->GetOperandAs<uint32_t>(storage_class_index);
|
||||
|
||||
if (storage_class != SpvStorageClassWorkgroup &&
|
||||
storage_class != SpvStorageClassStorageBuffer &&
|
||||
storage_class != SpvStorageClassPhysicalStorageBufferEXT) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< opname << " storage class for pointer type <id> '"
|
||||
<< _.getIdName(pointer_type_id)
|
||||
<< "' is not Workgroup or StorageBuffer.";
|
||||
}
|
||||
|
||||
const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
|
||||
const auto pointee_type = _.FindDef(pointee_id);
|
||||
if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
|
||||
_.IsFloatScalarOrVectorType(pointee_id))) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< opname << " Pointer <id> '" << _.getIdName(pointer->id())
|
||||
<< "'s Type must be a scalar or vector type.";
|
||||
}
|
||||
|
||||
const auto stride_index =
|
||||
(inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
|
||||
const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
|
||||
const auto stride = _.FindDef(stride_id);
|
||||
if (!stride || !_.IsIntScalarType(stride->type_id())) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Stride operand <id> '" << _.getIdName(stride_id)
|
||||
<< "' must be a scalar integer type.";
|
||||
}
|
||||
|
||||
const auto colmajor_index =
|
||||
(inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
|
||||
const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
|
||||
const auto colmajor = _.FindDef(colmajor_id);
|
||||
if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
|
||||
!(spvOpcodeIsConstant(colmajor->opcode()) ||
|
||||
spvOpcodeIsSpecConstant(colmajor->opcode()))) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Column Major operand <id> '" << _.getIdName(colmajor_id)
|
||||
<< "' must be a boolean constant instruction.";
|
||||
}
|
||||
|
||||
const auto memory_access_index =
|
||||
(inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
|
||||
if (inst->operands().size() > memory_access_index) {
|
||||
if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
|
||||
return error;
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
|
||||
@ -1164,6 +1354,14 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
|
||||
case SpvOpArrayLength:
|
||||
if (auto error = ValidateArrayLength(_, inst)) return error;
|
||||
break;
|
||||
case SpvOpCooperativeMatrixLoadNV:
|
||||
case SpvOpCooperativeMatrixStoreNV:
|
||||
if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
|
||||
return error;
|
||||
break;
|
||||
case SpvOpCooperativeMatrixLengthNV:
|
||||
if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
|
||||
break;
|
||||
case SpvOpImageTexelPointer:
|
||||
case SpvOpGenericPtrMemSemantics:
|
||||
default:
|
||||
|
@ -36,11 +36,19 @@ spv_result_t ValidateExecutionScope(ValidationState_t& _,
|
||||
}
|
||||
|
||||
if (!is_const_int32) {
|
||||
if (_.HasCapability(SpvCapabilityShader)) {
|
||||
if (_.HasCapability(SpvCapabilityShader) &&
|
||||
!_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Scope ids must be OpConstant when Shader capability is "
|
||||
<< "present";
|
||||
}
|
||||
if (_.HasCapability(SpvCapabilityShader) &&
|
||||
_.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
|
||||
!spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Scope ids must be constant or specialization constant when "
|
||||
<< "CooperativeMatrixNV capability is present";
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
@ -130,11 +138,19 @@ spv_result_t ValidateMemoryScope(ValidationState_t& _, const Instruction* inst,
|
||||
}
|
||||
|
||||
if (!is_const_int32) {
|
||||
if (_.HasCapability(SpvCapabilityShader)) {
|
||||
if (_.HasCapability(SpvCapabilityShader) &&
|
||||
!_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Scope ids must be OpConstant when Shader capability is "
|
||||
<< "present";
|
||||
}
|
||||
if (_.HasCapability(SpvCapabilityShader) &&
|
||||
_.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
|
||||
!spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Scope ids must be constant or specialization constant when "
|
||||
<< "CooperativeMatrixNV capability is present";
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -381,6 +381,53 @@ spv_result_t ValidateTypeForwardPointer(ValidationState_t& _,
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
|
||||
const Instruction* inst) {
|
||||
const auto component_type_index = 1;
|
||||
const auto component_type_id =
|
||||
inst->GetOperandAs<uint32_t>(component_type_index);
|
||||
const auto component_type = _.FindDef(component_type_id);
|
||||
if (!component_type || (SpvOpTypeFloat != component_type->opcode() &&
|
||||
SpvOpTypeInt != component_type->opcode())) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpTypeCooperativeMatrixNV Component Type <id> '"
|
||||
<< _.getIdName(component_type_id)
|
||||
<< "' is not a scalar numerical type.";
|
||||
}
|
||||
|
||||
const auto scope_index = 2;
|
||||
const auto scope_id = inst->GetOperandAs<uint32_t>(scope_index);
|
||||
const auto scope = _.FindDef(scope_id);
|
||||
if (!scope || !_.IsIntScalarType(scope->type_id()) ||
|
||||
!spvOpcodeIsConstant(scope->opcode())) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpTypeCooperativeMatrixNV Scope <id> '" << _.getIdName(scope_id)
|
||||
<< "' is not a constant instruction with scalar integer type.";
|
||||
}
|
||||
|
||||
const auto rows_index = 3;
|
||||
const auto rows_id = inst->GetOperandAs<uint32_t>(rows_index);
|
||||
const auto rows = _.FindDef(rows_id);
|
||||
if (!rows || !_.IsIntScalarType(rows->type_id()) ||
|
||||
!spvOpcodeIsConstant(rows->opcode())) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpTypeCooperativeMatrixNV Rows <id> '" << _.getIdName(rows_id)
|
||||
<< "' is not a constant instruction with scalar integer type.";
|
||||
}
|
||||
|
||||
const auto cols_index = 4;
|
||||
const auto cols_id = inst->GetOperandAs<uint32_t>(cols_index);
|
||||
const auto cols = _.FindDef(cols_id);
|
||||
if (!cols || !_.IsIntScalarType(cols->type_id()) ||
|
||||
!spvOpcodeIsConstant(cols->opcode())) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpTypeCooperativeMatrixNV Cols <id> '" << _.getIdName(rows_id)
|
||||
<< "' is not a constant instruction with scalar integer type.";
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
|
||||
@ -416,6 +463,9 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
|
||||
case SpvOpTypeForwardPointer:
|
||||
if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
|
||||
break;
|
||||
case SpvOpTypeCooperativeMatrixNV:
|
||||
if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -610,6 +610,9 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
|
||||
case SpvOpTypeMatrix:
|
||||
return GetComponentType(inst->word(2));
|
||||
|
||||
case SpvOpTypeCooperativeMatrixNV:
|
||||
return inst->word(2);
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -634,6 +637,10 @@ uint32_t ValidationState_t::GetDimension(uint32_t id) const {
|
||||
case SpvOpTypeMatrix:
|
||||
return inst->word(3);
|
||||
|
||||
case SpvOpTypeCooperativeMatrixNV:
|
||||
// Actual dimension isn't known, return 0
|
||||
return 0;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -862,6 +869,86 @@ bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
assert(inst);
|
||||
return inst->opcode() == SpvOpTypeCooperativeMatrixNV;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
|
||||
if (!IsCooperativeMatrixType(id)) return false;
|
||||
return IsFloatScalarType(FindDef(id)->word(2));
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
|
||||
if (!IsCooperativeMatrixType(id)) return false;
|
||||
return IsIntScalarType(FindDef(id)->word(2));
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
|
||||
if (!IsCooperativeMatrixType(id)) return false;
|
||||
return IsUnsignedIntScalarType(FindDef(id)->word(2));
|
||||
}
|
||||
|
||||
spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
|
||||
const Instruction* inst, uint32_t m1, uint32_t m2) {
|
||||
const auto m1_type = FindDef(m1);
|
||||
const auto m2_type = FindDef(m2);
|
||||
|
||||
if (m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
|
||||
m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
|
||||
return diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected cooperative matrix types";
|
||||
}
|
||||
|
||||
uint32_t m1_scope_id = m1_type->GetOperandAs<uint32_t>(2);
|
||||
uint32_t m1_rows_id = m1_type->GetOperandAs<uint32_t>(3);
|
||||
uint32_t m1_cols_id = m1_type->GetOperandAs<uint32_t>(4);
|
||||
|
||||
uint32_t m2_scope_id = m2_type->GetOperandAs<uint32_t>(2);
|
||||
uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
|
||||
uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
|
||||
|
||||
bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
|
||||
m2_is_const_int32 = false;
|
||||
uint32_t m1_value = 0, m2_value = 0;
|
||||
|
||||
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
|
||||
EvalInt32IfConst(m1_scope_id);
|
||||
std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
|
||||
EvalInt32IfConst(m2_scope_id);
|
||||
|
||||
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
|
||||
return diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected scopes of Matrix and Result Type to be "
|
||||
<< "identical";
|
||||
}
|
||||
|
||||
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
|
||||
EvalInt32IfConst(m1_rows_id);
|
||||
std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
|
||||
EvalInt32IfConst(m2_rows_id);
|
||||
|
||||
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
|
||||
return diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected rows of Matrix type and Result Type to be "
|
||||
<< "identical";
|
||||
}
|
||||
|
||||
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
|
||||
EvalInt32IfConst(m1_cols_id);
|
||||
std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
|
||||
EvalInt32IfConst(m2_cols_id);
|
||||
|
||||
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
|
||||
return diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected columns of Matrix type and Result Type to be "
|
||||
<< "identical";
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
|
||||
size_t operand_index) const {
|
||||
return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
|
||||
@ -890,7 +977,7 @@ bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
|
||||
}
|
||||
|
||||
std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
|
||||
uint32_t id) {
|
||||
uint32_t id) const {
|
||||
const Instruction* const inst = FindDef(id);
|
||||
assert(inst);
|
||||
const uint32_t type = inst->type_id();
|
||||
|
@ -552,6 +552,10 @@ class ValidationState_t {
|
||||
bool IsBoolVectorType(uint32_t id) const;
|
||||
bool IsBoolScalarOrVectorType(uint32_t id) const;
|
||||
bool IsPointerType(uint32_t id) const;
|
||||
bool IsCooperativeMatrixType(uint32_t id) const;
|
||||
bool IsFloatCooperativeMatrixType(uint32_t id) const;
|
||||
bool IsIntCooperativeMatrixType(uint32_t id) const;
|
||||
bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;
|
||||
|
||||
// Gets value from OpConstant and OpSpecConstant as uint64.
|
||||
// Returns false on failure (no instruction, wrong instruction, not int).
|
||||
@ -635,7 +639,7 @@ class ValidationState_t {
|
||||
// Returns tuple <is_int32, is_const_int32, value>.
|
||||
// OpSpecConstant* return |is_const_int32| as false since their values cannot
|
||||
// be relied upon during validation.
|
||||
std::tuple<bool, bool, uint32_t> EvalInt32IfConst(uint32_t id);
|
||||
std::tuple<bool, bool, uint32_t> EvalInt32IfConst(uint32_t id) const;
|
||||
|
||||
// Returns the disassembly string for the given instruction.
|
||||
std::string Disassemble(const Instruction& inst) const;
|
||||
@ -643,6 +647,12 @@ class ValidationState_t {
|
||||
// Returns the disassembly string for the given instruction.
|
||||
std::string Disassemble(const uint32_t* words, uint16_t num_words) const;
|
||||
|
||||
// Returns whether type m1 and type m2 are cooperative matrices with
|
||||
// the same "shape" (matching scope, rows, cols). If any are specialization
|
||||
// constants, we assume they can match because we can't prove they don't.
|
||||
spv_result_t CooperativeMatrixShapesMatch(const Instruction* inst,
|
||||
uint32_t m1, uint32_t m2);
|
||||
|
||||
private:
|
||||
ValidationState_t(const ValidationState_t&);
|
||||
|
||||
|
@ -1165,6 +1165,150 @@ TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) {
|
||||
"vector size of the right operand: OuterProduct"));
|
||||
}
|
||||
|
||||
std::string GenerateCoopMatCode(const std::string& extra_types,
|
||||
const std::string& main_body) {
|
||||
const std::string prefix =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%bool = OpTypeBool
|
||||
%f16 = OpTypeFloat 16
|
||||
%f32 = OpTypeFloat 32
|
||||
%u32 = OpTypeInt 32 0
|
||||
%s32 = OpTypeInt 32 1
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%u32_16 = OpConstant %u32 16
|
||||
%u32_4 = OpConstant %u32 4
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
|
||||
%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8
|
||||
|
||||
%f16_1 = OpConstant %f16 1
|
||||
%f32_1 = OpConstant %f32 1
|
||||
%u32_1 = OpConstant %u32 1
|
||||
%s32_1 = OpConstant %s32 1
|
||||
|
||||
%f16mat_1 = OpConstantComposite %f16mat %f16_1
|
||||
%u32mat_1 = OpConstantComposite %u32mat %u32_1
|
||||
%s32mat_1 = OpConstantComposite %s32mat %s32_1
|
||||
|
||||
%u32_c1 = OpSpecConstant %u32 1
|
||||
%u32_c2 = OpSpecConstant %u32 2
|
||||
|
||||
%f16matc = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_c1 %u32_c2
|
||||
%f16matc_1 = OpConstantComposite %f16matc %f16_1
|
||||
|
||||
%mat16x4 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_4
|
||||
%mat4x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_4 %u32_16
|
||||
%mat16x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_16
|
||||
%f16mat_16x4_1 = OpConstantComposite %mat16x4 %f16_1
|
||||
%f16mat_4x16_1 = OpConstantComposite %mat4x16 %f16_1
|
||||
%f16mat_16x16_1 = OpConstantComposite %mat16x16 %f16_1)";
|
||||
|
||||
const std::string func_begin =
|
||||
R"(
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel)";
|
||||
|
||||
const std::string suffix =
|
||||
R"(
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
return prefix + extra_types + func_begin + main_body + suffix;
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, CoopMatSuccess) {
|
||||
const std::string body = R"(
|
||||
%val1 = OpFAdd %f16mat %f16mat_1 %f16mat_1
|
||||
%val2 = OpFSub %f16mat %f16mat_1 %f16mat_1
|
||||
%val3 = OpFDiv %f16mat %f16mat_1 %f16mat_1
|
||||
%val4 = OpFNegate %f16mat %f16mat_1
|
||||
%val5 = OpIAdd %u32mat %u32mat_1 %u32mat_1
|
||||
%val6 = OpISub %u32mat %u32mat_1 %u32mat_1
|
||||
%val7 = OpUDiv %u32mat %u32mat_1 %u32mat_1
|
||||
%val8 = OpIAdd %s32mat %s32mat_1 %s32mat_1
|
||||
%val9 = OpISub %s32mat %s32mat_1 %s32mat_1
|
||||
%val10 = OpSDiv %s32mat %s32mat_1 %s32mat_1
|
||||
%val11 = OpSNegate %s32mat %s32mat_1
|
||||
%val12 = OpMatrixTimesScalar %f16mat %f16mat_1 %f16_1
|
||||
%val13 = OpMatrixTimesScalar %u32mat %u32mat_1 %u32_1
|
||||
%val14 = OpMatrixTimesScalar %s32mat %s32mat_1 %s32_1
|
||||
%val15 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16mat_16x16_1
|
||||
%val16 = OpCooperativeMatrixMulAddNV %f16matc %f16matc_1 %f16matc_1 %f16matc_1
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, CoopMatFMulFail) {
|
||||
const std::string body = R"(
|
||||
%val1 = OpFMul %f16mat %f16mat_1 %f16mat_1
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Expected floating scalar or vector type as Result Type: FMul"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) {
|
||||
const std::string body = R"(
|
||||
%val1 = OpMatrixTimesScalar %f16mat %f16mat_1 %f32_1
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("Expected scalar operand type to be equal to the component "
|
||||
"type of the matrix operand: MatrixTimesScalar"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, CoopMatScopeFail) {
|
||||
const std::string types = R"(
|
||||
%workgroup = OpConstant %u32 2
|
||||
|
||||
%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16
|
||||
%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
|
||||
)";
|
||||
|
||||
const std::string body = R"(
|
||||
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCoopMatCode(types, body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Cooperative matrix scopes must match: CooperativeMatrixMulAddNV"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, CoopMatDimFail) {
|
||||
const std::string body = R"(
|
||||
%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_4x16_1 %f16mat_16x4_1 %f16mat_16x16_1
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, IAddCarrySuccess) {
|
||||
const std::string body = R"(
|
||||
%val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1
|
||||
|
@ -1467,6 +1467,83 @@ OpFunctionEnd
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) {
|
||||
const std::string body =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%bool = OpTypeBool
|
||||
%f16 = OpTypeFloat 16
|
||||
%f32 = OpTypeFloat 32
|
||||
%u32 = OpTypeInt 32 0
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
|
||||
%f32_1 = OpConstant %f32 1
|
||||
|
||||
%f16mat_1 = OpConstantComposite %f16mat %f32_1
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
CompileSuccessfully(body.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("OpConstantComposite Constituent <id> '11[%float_1]' type does "
|
||||
"not match the Result Type <id> '10[%10]'s component type."));
|
||||
}
|
||||
|
||||
TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) {
|
||||
const std::string body =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%bool = OpTypeBool
|
||||
%f16 = OpTypeFloat 16
|
||||
%f32 = OpTypeFloat 32
|
||||
%u32 = OpTypeInt 32 0
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
|
||||
%f32_1 = OpConstant %f32 1
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel
|
||||
|
||||
%f16mat_1 = OpCompositeConstruct %f16mat %f32_1
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
CompileSuccessfully(body.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("Expected Constituent type to be equal to the component type"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateComposites, ExtractDynamicLabelIndex) {
|
||||
const std::string spirv = R"(
|
||||
OpCapability Shader
|
||||
|
@ -1184,6 +1184,172 @@ TEST_F(ValidateConversion, GenericCastToPtrExplicitPointToDifferentType) {
|
||||
"GenericCastToPtrExplicit"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateConversion, CoopMatConversionSuccess) {
|
||||
const std::string body =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability Int16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%bool = OpTypeBool
|
||||
%f16 = OpTypeFloat 16
|
||||
%f32 = OpTypeFloat 32
|
||||
%u16 = OpTypeInt 16 0
|
||||
%u32 = OpTypeInt 32 0
|
||||
%s16 = OpTypeInt 16 1
|
||||
%s32 = OpTypeInt 32 1
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_8 %u32_8
|
||||
%u16mat = OpTypeCooperativeMatrixNV %u16 %subgroup %u32_8 %u32_8
|
||||
%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
|
||||
%s16mat = OpTypeCooperativeMatrixNV %s16 %subgroup %u32_8 %u32_8
|
||||
%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8
|
||||
|
||||
%f16_1 = OpConstant %f16 1
|
||||
%f32_1 = OpConstant %f32 1
|
||||
%u16_1 = OpConstant %u16 1
|
||||
%u32_1 = OpConstant %u32 1
|
||||
%s16_1 = OpConstant %s16 1
|
||||
%s32_1 = OpConstant %s32 1
|
||||
|
||||
%f16mat_1 = OpConstantComposite %f16mat %f16_1
|
||||
%f32mat_1 = OpConstantComposite %f32mat %f32_1
|
||||
%u16mat_1 = OpConstantComposite %u16mat %u16_1
|
||||
%u32mat_1 = OpConstantComposite %u32mat %u32_1
|
||||
%s16mat_1 = OpConstantComposite %s16mat %s16_1
|
||||
%s32mat_1 = OpConstantComposite %s32mat %s32_1
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel
|
||||
|
||||
%val11 = OpConvertFToU %u16mat %f16mat_1
|
||||
%val12 = OpConvertFToU %u32mat %f16mat_1
|
||||
%val13 = OpConvertFToS %s16mat %f16mat_1
|
||||
%val14 = OpConvertFToS %s32mat %f16mat_1
|
||||
%val15 = OpFConvert %f32mat %f16mat_1
|
||||
|
||||
%val21 = OpConvertFToU %u16mat %f32mat_1
|
||||
%val22 = OpConvertFToU %u32mat %f32mat_1
|
||||
%val23 = OpConvertFToS %s16mat %f32mat_1
|
||||
%val24 = OpConvertFToS %s32mat %f32mat_1
|
||||
%val25 = OpFConvert %f16mat %f32mat_1
|
||||
|
||||
%val31 = OpConvertUToF %f16mat %u16mat_1
|
||||
%val32 = OpConvertUToF %f32mat %u16mat_1
|
||||
%val33 = OpUConvert %u32mat %u16mat_1
|
||||
%val34 = OpSConvert %s32mat %u16mat_1
|
||||
|
||||
%val41 = OpConvertSToF %f16mat %s16mat_1
|
||||
%val42 = OpConvertSToF %f32mat %s16mat_1
|
||||
%val43 = OpUConvert %u32mat %s16mat_1
|
||||
%val44 = OpSConvert %s32mat %s16mat_1
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
CompileSuccessfully(body.c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateConversion, CoopMatConversionShapesMismatchFail) {
|
||||
const std::string body =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability Int16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%bool = OpTypeBool
|
||||
%f16 = OpTypeFloat 16
|
||||
%f32 = OpTypeFloat 32
|
||||
%u16 = OpTypeInt 16 0
|
||||
%u32 = OpTypeInt 32 0
|
||||
%s16 = OpTypeInt 16 1
|
||||
%s32 = OpTypeInt 32 1
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%u32_4 = OpConstant %u32 4
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4
|
||||
|
||||
%f16_1 = OpConstant %f16 1
|
||||
|
||||
%f16mat_1 = OpConstantComposite %f16mat %f16_1
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel
|
||||
|
||||
%val15 = OpFConvert %f32mat %f16mat_1
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
CompileSuccessfully(body.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Expected rows of Matrix type and Result Type to be identical"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) {
|
||||
const std::string body =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability Int16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%bool = OpTypeBool
|
||||
%f16 = OpTypeFloat 16
|
||||
%f32 = OpTypeFloat 32
|
||||
%u16 = OpTypeInt 16 0
|
||||
%u32 = OpTypeInt 32 0
|
||||
%s16 = OpTypeInt 16 1
|
||||
%s32 = OpTypeInt 32 1
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%u32_4 = OpSpecConstant %u32 4
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4
|
||||
|
||||
%f16_1 = OpConstant %f16 1
|
||||
|
||||
%f16mat_1 = OpConstantComposite %f16mat %f16_1
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel
|
||||
|
||||
%val15 = OpFConvert %f32mat %f16mat_1
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
CompileSuccessfully(body.c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateConversion, BitcastSuccess) {
|
||||
const std::string body = R"(
|
||||
%ptr = OpVariable %f32ptr_func Function
|
||||
|
@ -1774,6 +1774,372 @@ OpFunctionEnd
|
||||
HasSubstr("PhysicalStorageBufferEXT must not be used with OpVariable"));
|
||||
}
|
||||
|
||||
std::string GenCoopMatLoadStoreShader(const std::string& storeMemoryAccess,
|
||||
const std::string& loadMemoryAccess) {
|
||||
std::string s = R"(
|
||||
OpCapability Shader
|
||||
OpCapability GroupNonUniform
|
||||
OpCapability VulkanMemoryModelKHR
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical VulkanKHR
|
||||
OpEntryPoint GLCompute %4 "main" %11 %21
|
||||
OpExecutionMode %4 LocalSize 1 1 1
|
||||
OpDecorate %11 BuiltIn SubgroupId
|
||||
OpDecorate %21 BuiltIn WorkgroupId
|
||||
OpDecorate %74 ArrayStride 4
|
||||
OpMemberDecorate %75 0 Offset 0
|
||||
OpDecorate %75 Block
|
||||
OpDecorate %77 DescriptorSet 0
|
||||
OpDecorate %77 Binding 0
|
||||
OpDecorate %92 ArrayStride 4
|
||||
OpMemberDecorate %93 0 Offset 0
|
||||
OpDecorate %93 Block
|
||||
OpDecorate %95 DescriptorSet 0
|
||||
OpDecorate %95 Binding 1
|
||||
OpDecorate %102 ArrayStride 4
|
||||
OpMemberDecorate %103 0 Offset 0
|
||||
OpDecorate %103 Block
|
||||
OpDecorate %105 DescriptorSet 0
|
||||
OpDecorate %105 Binding 2
|
||||
OpDecorate %117 ArrayStride 4
|
||||
OpMemberDecorate %118 0 Offset 0
|
||||
OpDecorate %118 Block
|
||||
OpDecorate %120 DescriptorSet 0
|
||||
OpDecorate %120 Binding 3
|
||||
OpDecorate %123 SpecId 2
|
||||
OpDecorate %124 SpecId 3
|
||||
OpDecorate %125 SpecId 4
|
||||
OpDecorate %126 SpecId 5
|
||||
OpDecorate %127 SpecId 0
|
||||
OpDecorate %128 SpecId 1
|
||||
OpDecorate %129 BuiltIn WorkgroupSize
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeFunction %2
|
||||
%6 = OpTypeInt 32 0
|
||||
%7 = OpTypeVector %6 2
|
||||
%8 = OpTypePointer Function %7
|
||||
%10 = OpTypePointer Input %6
|
||||
%11 = OpVariable %10 Input
|
||||
%13 = OpConstant %6 2
|
||||
%19 = OpTypeVector %6 3
|
||||
%20 = OpTypePointer Input %19
|
||||
%21 = OpVariable %20 Input
|
||||
%27 = OpConstantComposite %7 %13 %13
|
||||
%31 = OpTypePointer Function %6
|
||||
%33 = OpConstant %6 1024
|
||||
%34 = OpConstant %6 1
|
||||
%38 = OpConstant %6 8
|
||||
%39 = OpConstant %6 0
|
||||
%68 = OpTypeFloat 32
|
||||
%69 = OpConstant %6 16
|
||||
%70 = OpConstant %6 3
|
||||
%71 = OpTypeCooperativeMatrixNV %68 %70 %69 %38
|
||||
%72 = OpTypePointer Function %71
|
||||
%74 = OpTypeRuntimeArray %68
|
||||
%75 = OpTypeStruct %74
|
||||
%76 = OpTypePointer StorageBuffer %75
|
||||
%77 = OpVariable %76 StorageBuffer
|
||||
%78 = OpTypeInt 32 1
|
||||
%79 = OpConstant %78 0
|
||||
%81 = OpConstant %6 5
|
||||
%82 = OpTypePointer StorageBuffer %68
|
||||
%84 = OpConstant %6 64
|
||||
%85 = OpTypeBool
|
||||
%86 = OpConstantFalse %85
|
||||
%88 = OpTypePointer Private %71
|
||||
%89 = OpVariable %88 Private
|
||||
%92 = OpTypeRuntimeArray %68
|
||||
%93 = OpTypeStruct %92
|
||||
%94 = OpTypePointer StorageBuffer %93
|
||||
%95 = OpVariable %94 StorageBuffer
|
||||
%99 = OpVariable %88 Private
|
||||
%102 = OpTypeRuntimeArray %68
|
||||
%103 = OpTypeStruct %102
|
||||
%104 = OpTypePointer StorageBuffer %103
|
||||
%105 = OpVariable %104 StorageBuffer
|
||||
%109 = OpVariable %88 Private
|
||||
%111 = OpVariable %88 Private
|
||||
%112 = OpSpecConstantOp %6 CooperativeMatrixLengthNV %71
|
||||
%113 = OpSpecConstantOp %78 IAdd %112 %79
|
||||
%117 = OpTypeRuntimeArray %68
|
||||
%118 = OpTypeStruct %117
|
||||
%119 = OpTypePointer StorageBuffer %118
|
||||
%120 = OpVariable %119 StorageBuffer
|
||||
%123 = OpSpecConstant %78 1
|
||||
%124 = OpSpecConstant %78 1
|
||||
%125 = OpSpecConstant %78 1
|
||||
%126 = OpSpecConstant %78 1
|
||||
%127 = OpSpecConstant %6 1
|
||||
%128 = OpSpecConstant %6 1
|
||||
%129 = OpSpecConstantComposite %19 %127 %128 %34
|
||||
%4 = OpFunction %2 None %3
|
||||
%5 = OpLabel
|
||||
%9 = OpVariable %8 Function
|
||||
%18 = OpVariable %8 Function
|
||||
%32 = OpVariable %31 Function
|
||||
%44 = OpVariable %31 Function
|
||||
%52 = OpVariable %31 Function
|
||||
%60 = OpVariable %31 Function
|
||||
%73 = OpVariable %72 Function
|
||||
%91 = OpVariable %72 Function
|
||||
%101 = OpVariable %72 Function
|
||||
%12 = OpLoad %6 %11
|
||||
%14 = OpUMod %6 %12 %13
|
||||
%15 = OpLoad %6 %11
|
||||
%16 = OpUDiv %6 %15 %13
|
||||
%17 = OpCompositeConstruct %7 %14 %16
|
||||
OpStore %9 %17
|
||||
%22 = OpLoad %19 %21
|
||||
%23 = OpVectorShuffle %7 %22 %22 0 1
|
||||
%24 = OpCompositeExtract %6 %23 0
|
||||
%25 = OpCompositeExtract %6 %23 1
|
||||
%26 = OpCompositeConstruct %7 %24 %25
|
||||
%28 = OpIMul %7 %26 %27
|
||||
%29 = OpLoad %7 %9
|
||||
%30 = OpIAdd %7 %28 %29
|
||||
OpStore %18 %30
|
||||
%35 = OpAccessChain %31 %18 %34
|
||||
%36 = OpLoad %6 %35
|
||||
%37 = OpIMul %6 %33 %36
|
||||
%40 = OpAccessChain %31 %18 %39
|
||||
%41 = OpLoad %6 %40
|
||||
%42 = OpIMul %6 %38 %41
|
||||
%43 = OpIAdd %6 %37 %42
|
||||
OpStore %32 %43
|
||||
%45 = OpAccessChain %31 %18 %34
|
||||
%46 = OpLoad %6 %45
|
||||
%47 = OpIMul %6 %33 %46
|
||||
%48 = OpAccessChain %31 %18 %39
|
||||
%49 = OpLoad %6 %48
|
||||
%50 = OpIMul %6 %38 %49
|
||||
%51 = OpIAdd %6 %47 %50
|
||||
OpStore %44 %51
|
||||
%53 = OpAccessChain %31 %18 %34
|
||||
%54 = OpLoad %6 %53
|
||||
%55 = OpIMul %6 %33 %54
|
||||
%56 = OpAccessChain %31 %18 %39
|
||||
%57 = OpLoad %6 %56
|
||||
%58 = OpIMul %6 %38 %57
|
||||
%59 = OpIAdd %6 %55 %58
|
||||
OpStore %52 %59
|
||||
%61 = OpAccessChain %31 %18 %34
|
||||
%62 = OpLoad %6 %61
|
||||
%63 = OpIMul %6 %33 %62
|
||||
%64 = OpAccessChain %31 %18 %39
|
||||
%65 = OpLoad %6 %64
|
||||
%66 = OpIMul %6 %38 %65
|
||||
%67 = OpIAdd %6 %63 %66
|
||||
OpStore %60 %67
|
||||
%80 = OpLoad %6 %32
|
||||
%83 = OpAccessChain %82 %77 %79 %80
|
||||
%87 = OpCooperativeMatrixLoadNV %71 %83 %84 %86 )" +
|
||||
loadMemoryAccess + R"( %81
|
||||
OpStore %73 %87
|
||||
%90 = OpLoad %71 %73
|
||||
OpStore %89 %90
|
||||
%96 = OpLoad %6 %44
|
||||
%97 = OpAccessChain %82 %95 %79 %96
|
||||
%98 = OpCooperativeMatrixLoadNV %71 %97 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81
|
||||
OpStore %91 %98
|
||||
%100 = OpLoad %71 %91
|
||||
OpStore %99 %100
|
||||
%106 = OpLoad %6 %52
|
||||
%107 = OpAccessChain %82 %105 %79 %106
|
||||
%108 = OpCooperativeMatrixLoadNV %71 %107 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81
|
||||
OpStore %101 %108
|
||||
%110 = OpLoad %71 %101
|
||||
OpStore %109 %110
|
||||
%114 = OpConvertSToF %68 %113
|
||||
%115 = OpCompositeConstruct %71 %114
|
||||
OpStore %111 %115
|
||||
%116 = OpLoad %71 %111
|
||||
%121 = OpLoad %6 %60
|
||||
%122 = OpAccessChain %82 %120 %79 %121
|
||||
OpCooperativeMatrixStoreNV %122 %116 %84 %86 )" + storeMemoryAccess + R"( %81
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
TEST_F(ValidateMemory, CoopMatLoadStoreSuccess) {
|
||||
std::string spirv =
|
||||
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
|
||||
"MakePointerVisibleKHR|NonPrivatePointerKHR");
|
||||
|
||||
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
|
||||
}
|
||||
|
||||
TEST_F(ValidateMemory, CoopMatStoreMemoryAccessFail) {
|
||||
std::string spirv =
|
||||
GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR",
|
||||
"MakePointerVisibleKHR|NonPrivatePointerKHR");
|
||||
|
||||
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
HasSubstr("MakePointerVisibleKHR cannot be used with OpStore"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateMemory, CoopMatLoadMemoryAccessFail) {
|
||||
std::string spirv =
|
||||
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
|
||||
"MakePointerAvailableKHR|NonPrivatePointerKHR");
|
||||
|
||||
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateMemory, CoopMatInvalidStorageClassFail) {
|
||||
const std::string body =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%f16 = OpTypeFloat 16
|
||||
%u32 = OpTypeInt 32 0
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
|
||||
%str = OpTypeStruct %f16mat
|
||||
%str_ptr = OpTypePointer Workgroup %str
|
||||
%sh = OpVariable %str_ptr Workgroup
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
CompileSuccessfully(body.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Cooperative matrix types (or types containing them) can only be "
|
||||
"allocated in Function or Private storage classes or as function "
|
||||
"parameters"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateMemory, CoopMatMatrixLengthResultTypeBad) {
|
||||
const std::string body =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%f16 = OpTypeFloat 16
|
||||
%u32 = OpTypeInt 32 0
|
||||
%i32 = OpTypeInt 32 1
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel
|
||||
|
||||
%1 = OpCooperativeMatrixLengthNV %i32 %f16mat
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
CompileSuccessfully(body.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("The Result Type of OpCooperativeMatrixLengthNV <id> "
|
||||
"'11[%11]' must be OpTypeInt with width 32 and signedness 0"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateMemory, CoopMatMatrixLengthOperandTypeBad) {
|
||||
const std::string body =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%f16 = OpTypeFloat 16
|
||||
%u32 = OpTypeInt 32 0
|
||||
%i32 = OpTypeInt 32 1
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel
|
||||
|
||||
%1 = OpCooperativeMatrixLengthNV %u32 %u32
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
CompileSuccessfully(body.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr("The type in OpCooperativeMatrixLengthNV <id> '5[%uint]' "
|
||||
"must be OpTypeCooperativeMatrixNV"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateMemory, CoopMatMatrixLengthGood) {
|
||||
const std::string body =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability CooperativeMatrixNV
|
||||
OpExtension "SPV_NV_cooperative_matrix"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%f16 = OpTypeFloat 16
|
||||
%u32 = OpTypeInt 32 0
|
||||
%i32 = OpTypeInt 32 1
|
||||
|
||||
%u32_8 = OpConstant %u32 8
|
||||
%subgroup = OpConstant %u32 3
|
||||
|
||||
%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel
|
||||
|
||||
%1 = OpCooperativeMatrixLengthNV %u32 %f16mat
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
CompileSuccessfully(body.c_str());
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) {
|
||||
std::string spirv = R"(
|
||||
OpCapability Shader
|
||||
|
Loading…
Reference in New Issue
Block a user