Fix validation message for cooperative matrix column type (#4502)

* Fix validation message for cooperative matrix column type

Fixes: #4497

* Add tests for cooperative matrix type validation
This commit is contained in:
David Neto 2021-09-10 11:28:00 -04:00 committed by GitHub
parent 2a938fcfa3
commit 013b1f3d6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 1 deletions

View File

@ -596,7 +596,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
if (!cols || !_.IsIntScalarType(cols->type_id()) ||
!spvOpcodeIsConstant(cols->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeCooperativeMatrixNV Cols <id> '" << _.getIdName(rows_id)
<< "OpTypeCooperativeMatrixNV Cols <id> '" << _.getIdName(cols_id)
<< "' is not a constant instruction with scalar integer type.";
}

View File

@ -1309,6 +1309,57 @@ TEST_F(ValidateArithmetics, CoopMatDimFail) {
HasSubstr("Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV"));
}
TEST_F(ValidateArithmetics, CoopMatComponentTypeNotScalarNumeric) {
const std::string types = R"(
%bad = OpTypeCooperativeMatrixNV %bool %subgroup %u32_8 %u32_8
)";
CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("OpTypeCooperativeMatrixNV Component Type <id> "
"'4[%bool]' is not a scalar numerical type."));
}
TEST_F(ValidateArithmetics, CoopMatScopeNotConstantInt) {
const std::string types = R"(
%bad = OpTypeCooperativeMatrixNV %f16 %f32_1 %u32_8 %u32_8
)";
CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("OpTypeCooperativeMatrixNV Scope <id> '17[%float_1]' is not a "
"constant instruction with scalar integer type."));
}
TEST_F(ValidateArithmetics, CoopMatRowsNotConstantInt) {
const std::string types = R"(
%bad = OpTypeCooperativeMatrixNV %f16 %subgroup %f32_1 %u32_8
)";
CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("OpTypeCooperativeMatrixNV Rows <id> '17[%float_1]' is not a "
"constant instruction with scalar integer type."));
}
TEST_F(ValidateArithmetics, CoopMatColumnsNotConstantInt) {
const std::string types = R"(
%bad = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %f32_1
)";
CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("OpTypeCooperativeMatrixNV Cols <id> '17[%float_1]' is not a "
"constant instruction with scalar integer type."));
}
TEST_F(ValidateArithmetics, IAddCarrySuccess) {
const std::string body = R"(
%val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1