Add new checks to validate arithmetics pass

New operations:
- OpDot
- OpVectorTimesScalar
- OpMatrixTimesScalar
- OpVectorTimesMatrix
- OpMatrixTimesVector
- OpMatrixTimesMatrix
- OpOuterProduct
This commit is contained in:
Andrey Tuganov 2017-09-06 14:30:27 -04:00 committed by David Neto
parent 4442102247
commit c6dfc11880
6 changed files with 904 additions and 11 deletions

View File

@ -1290,8 +1290,10 @@ uint64_t MarkvCodecBase::GetRuleBasedMtf() {
}
case SpvOpVectorTimesScalar: {
if (operand_index_ == 0)
if (operand_index_ == 0) {
// TODO(atgoo@github.com) Could be narrowed to vector of floats.
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
}
assert(inst_.type_id);
if (operand_index_ == 2)

View File

@ -591,4 +591,43 @@ bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
return false;
}
bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
const Instruction* inst = FindDef(id);
assert(inst);
if (inst->opcode() == SpvOpTypeMatrix) {
return IsFloatScalarType(GetComponentType(id));
}
return false;
}
bool ValidationState_t::GetMatrixTypeInfo(
uint32_t id, uint32_t* num_rows, uint32_t* num_cols,
uint32_t* column_type, uint32_t* component_type) const {
if (!id)
return false;
const Instruction* mat_inst = FindDef(id);
assert(mat_inst);
if (mat_inst->opcode() != SpvOpTypeMatrix)
return false;
const uint32_t vec_type = mat_inst->word(2);
const Instruction* vec_inst = FindDef(vec_type);
assert(vec_inst);
if (vec_inst->opcode() != SpvOpTypeVector) {
assert(0);
return false;
}
*num_cols = mat_inst->word(3);
*num_rows = vec_inst->word(3);
*column_type = mat_inst->word(2);
*component_type = vec_inst->word(2);
return true;
}
} /// namespace libspirv

View File

@ -333,27 +333,36 @@ class ValidationState_t {
// Returns type_id of the scalar component of |id|.
// |id| can be either
// - vector type
// - matrix type
// - object of either vector or matrix type
// - scalar, vector or matrix type
// - object of either scalar, vector or matrix type
uint32_t GetComponentType(uint32_t id) const;
// Returns dimension of scalar, vector or matrix type or object. Will invoke
// assertion and return 0 if |id| is none of the above.
// In case of matrix returns number of columns.
// Returns
// - 1 for scalar types or objects
// - vector size for vector types or objects
// - num columns for matrix types or objects
// Should not be called with any other arguments (will return zero and invoke
// assertion).
uint32_t GetDimension(uint32_t id) const;
// Returns bit width of scalar or component.
// |id| can be
// - scalar type or object
// - vector or matrix type or object
// - scalar, vector or matrix type
// - object of either scalar, vector or matrix type
// Will invoke assertion and return 0 if |id| is none of the above.
uint32_t GetBitWidth(uint32_t id) const;
// Provides detailed information on matrix type.
// Returns false iff |id| is not matrix type.
bool GetMatrixTypeInfo(
uint32_t id, uint32_t* num_rows, uint32_t* num_cols,
uint32_t* column_type, uint32_t* component_type) const;
// Returns true iff |id| is a type corresponding to the name of the function.
// Only works for types not for objects.
bool IsFloatScalarType(uint32_t id) const;
bool IsFloatVectorType(uint32_t id) const;
bool IsFloatMatrixType(uint32_t id) const;
bool IsIntScalarType(uint32_t id) const;
bool IsIntVectorType(uint32_t id) const;
bool IsUnsignedIntScalarType(uint32_t id) const;

View File

@ -126,6 +126,287 @@ spv_result_t ArithmeticsPass(ValidationState_t& _,
break;
}
case SpvOpDot: {
if (!_.IsFloatScalarType(inst->type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float scalar type as type_id: "
<< spvOpcodeString(opcode);
uint32_t first_vector_num_components = 0;
for (size_t operand_index = 2; operand_index < inst->num_operands;
++operand_index) {
const uint32_t type_id =
_.GetTypeId(GetOperandWord(inst, operand_index));
if (!type_id || !_.IsFloatVectorType(type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float vector as operand: "
<< spvOpcodeString(opcode) << " operand index " << operand_index;
const uint32_t component_type = _.GetComponentType(type_id);
if (component_type != inst->type_id)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected component type to be equal to type_id: "
<< spvOpcodeString(opcode) << " operand index " << operand_index;
const uint32_t num_components = _.GetDimension(type_id);
if (operand_index == 2) {
first_vector_num_components = num_components;
} else if (num_components != first_vector_num_components) {
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected operands to have the same number of componenets: "
<< spvOpcodeString(opcode);
}
}
break;
}
case SpvOpVectorTimesScalar: {
if (!_.IsFloatVectorType(inst->type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float vector type as type_id: "
<< spvOpcodeString(opcode);
const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2));
if (inst->type_id != vector_type_id)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected vector operand type to be equal to type_id: "
<< spvOpcodeString(opcode);
const uint32_t component_type = _.GetComponentType(vector_type_id);
const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3));
if (component_type != scalar_type_id)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected scalar operand type to be equal to the component "
<< "type of the vector operand: "
<< spvOpcodeString(opcode);
break;
}
case SpvOpMatrixTimesScalar: {
if (!_.IsFloatMatrixType(inst->type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float matrix type as type_id: "
<< spvOpcodeString(opcode);
const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2));
if (inst->type_id != matrix_type_id)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected matrix operand type to be equal to type_id: "
<< spvOpcodeString(opcode);
const uint32_t component_type = _.GetComponentType(matrix_type_id);
const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3));
if (component_type != scalar_type_id)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected scalar operand type to be equal to the component "
<< "type of the matrix operand: "
<< spvOpcodeString(opcode);
break;
}
case SpvOpVectorTimesMatrix: {
const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2));
const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 3));
if (!_.IsFloatVectorType(inst->type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float vector type as type_id: "
<< spvOpcodeString(opcode);
const uint32_t res_component_type = _.GetComponentType(inst->type_id);
if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float vector type as left operand: "
<< spvOpcodeString(opcode);
if (res_component_type != _.GetComponentType(vector_type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected component types of type_id and vector to be equal: "
<< spvOpcodeString(opcode);
uint32_t matrix_num_rows = 0;
uint32_t matrix_num_cols = 0;
uint32_t matrix_col_type = 0;
uint32_t matrix_component_type = 0;
if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
&matrix_num_cols, &matrix_col_type,
&matrix_component_type))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float matrix type as right operand: "
<< spvOpcodeString(opcode);
if (res_component_type != matrix_component_type)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected component types of type_id and matrix to be equal: "
<< spvOpcodeString(opcode);
if (matrix_num_cols != _.GetDimension(inst->type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected number of columns of the matrix to be equal to the "
<< "type_id vector size: " << spvOpcodeString(opcode);
if (matrix_num_rows != _.GetDimension(vector_type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected number of rows of the matrix to be equal to the "
<< "vector operand size: " << spvOpcodeString(opcode);
break;
}
case SpvOpMatrixTimesVector: {
const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2));
const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 3));
if (!_.IsFloatVectorType(inst->type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float vector type as type_id: "
<< spvOpcodeString(opcode);
uint32_t matrix_num_rows = 0;
uint32_t matrix_num_cols = 0;
uint32_t matrix_col_type = 0;
uint32_t matrix_component_type = 0;
if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
&matrix_num_cols, &matrix_col_type,
&matrix_component_type))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float matrix type as left operand: "
<< spvOpcodeString(opcode);
if (inst->type_id != matrix_col_type)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected column type of the matrix to be equal to type_id: "
<< spvOpcodeString(opcode);
if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float vector type as right operand: "
<< spvOpcodeString(opcode);
if (matrix_component_type != _.GetComponentType(vector_type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected component types of the operands to be equal: "
<< spvOpcodeString(opcode);
if (matrix_num_cols != _.GetDimension(vector_type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected number of columns of the matrix to be equal to the "
<< "vector size: " << spvOpcodeString(opcode);
break;
}
case SpvOpMatrixTimesMatrix: {
const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2));
const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3));
uint32_t res_num_rows = 0;
uint32_t res_num_cols = 0;
uint32_t res_col_type = 0;
uint32_t res_component_type = 0;
if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols,
&res_col_type, &res_component_type))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float matrix type as type_id: "
<< spvOpcodeString(opcode);
uint32_t left_num_rows = 0;
uint32_t left_num_cols = 0;
uint32_t left_col_type = 0;
uint32_t left_component_type = 0;
if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
&left_col_type, &left_component_type))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float matrix type as left operand: "
<< spvOpcodeString(opcode);
uint32_t right_num_rows = 0;
uint32_t right_num_cols = 0;
uint32_t right_col_type = 0;
uint32_t right_component_type = 0;
if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
&right_col_type, &right_component_type))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float matrix type as right operand: "
<< spvOpcodeString(opcode);
if (!_.IsFloatScalarType(res_component_type))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float matrix type as type_id: "
<< spvOpcodeString(opcode);
if (res_col_type != left_col_type)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected column types of type_id and left matrix to be "
<< "equal: " << spvOpcodeString(opcode);
if (res_component_type != right_component_type)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected component types of type_id and right matrix to be "
<< "equal: " << spvOpcodeString(opcode);
if (res_num_cols != right_num_cols)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected number of columns of type_id and right matrix to be "
<< "equal: " << spvOpcodeString(opcode);
if (left_num_cols != right_num_rows)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected number of columns of left matrix and number of rows "
<< "of right matrix to be equal: " << spvOpcodeString(opcode);
assert(left_num_rows == res_num_rows);
break;
}
case SpvOpOuterProduct: {
const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2));
const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3));
uint32_t res_num_rows = 0;
uint32_t res_num_cols = 0;
uint32_t res_col_type = 0;
uint32_t res_component_type = 0;
if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols,
&res_col_type, &res_component_type))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float matrix type as type_id: "
<< spvOpcodeString(opcode);
if (left_type_id != res_col_type)
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected column type of the type_id to be equal to the type "
<< "of the left operand: "
<< spvOpcodeString(opcode);
if (!right_type_id || !_.IsFloatVectorType(right_type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected float vector type as right operand: "
<< spvOpcodeString(opcode);
if (res_component_type != _.GetComponentType(right_type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected component types of the operands to be equal: "
<< spvOpcodeString(opcode);
if (res_num_cols != _.GetDimension(right_type_id))
return _.diag(SPV_ERROR_INVALID_DATA)
<< "Expected number of columns of the matrix to be equal to the "
<< "vector size of the right operand: " << spvOpcodeString(opcode);
break;
}
// TODO(atgoo@github.com): Support other operations.
default:
break;
}

View File

@ -682,8 +682,6 @@ TEST(Markv, VectorTimesScalar) {
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
%res1 = OpVectorTimesScalar %f32vec4 %f32vec4_0123 %f32_2
%res2 = OpVectorTimesScalar %f32vec4 %f32vec4_3210 %f32_2
%res3 = OpVectorTimesScalar %u32vec3 %u32vec3_012 %u32_2
%res4 = OpVectorTimesScalar %s32vec2 %s32vec2_01 %s32_2
)");
}

View File

@ -35,6 +35,7 @@ R"(
OpCapability Shader
OpCapability Int64
OpCapability Float64
OpCapability Matrix
%ext_inst = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
@ -66,6 +67,12 @@ OpEntryPoint Fragment %main "main"
%f32vec4 = OpTypeVector %f32 4
%f64vec4 = OpTypeVector %f64 4
%f32mat22 = OpTypeMatrix %f32vec2 2
%f32mat23 = OpTypeMatrix %f32vec2 3
%f32mat32 = OpTypeMatrix %f32vec3 2
%f32mat33 = OpTypeMatrix %f32vec3 3
%f64mat22 = OpTypeMatrix %f64vec2 2
%f32_0 = OpConstant %f32 0
%f32_1 = OpConstant %f32 1
%f32_2 = OpConstant %f32 2
@ -133,6 +140,13 @@ OpEntryPoint Fragment %main "main"
%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3
%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4
%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12
%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12
%f32mat32_123123 = OpConstantComposite %f32mat32 %f32vec3_123 %f32vec3_123
%f32mat33_123123123 = OpConstantComposite %f32mat33 %f32vec3_123 %f32vec3_123 %f32vec3_123
%f64mat22_1212 = OpConstantComposite %f64mat22 %f64vec2_12 %f64vec2_12
%main = OpFunction %void None %func
%main_entry = OpLabel)";
@ -536,4 +550,554 @@ TEST_F(ValidateArithmetics, UDivWrongOperand2) {
"UDiv operand index 3"));
}
TEST_F(ValidateArithmetics, DotSuccess) {
const std::string body = R"(
%val = OpDot %f32 %f32vec2_01 %f32vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, DotWrongTypeId) {
const std::string body = R"(
%val = OpDot %u32 %u32vec2_01 %u32vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float scalar type as type_id: Dot"));
}
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand1) {
const std::string body = R"(
%val = OpDot %f32 %f32 %f32vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float vector as operand: Dot operand index 2"));
}
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand2) {
const std::string body = R"(
%val = OpDot %f32 %f32vec3_012 %f32_1
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float vector as operand: Dot operand index 3"));
}
TEST_F(ValidateArithmetics, DotWrongComponentOperand1) {
const std::string body = R"(
%val = OpDot %f64 %f32vec2_01 %f64vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected component type to be equal to type_id: Dot operand index 2"));
}
TEST_F(ValidateArithmetics, DotWrongComponentOperand2) {
const std::string body = R"(
%val = OpDot %f32 %f32vec2_01 %f64vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected component type to be equal to type_id: Dot operand index 3"));
}
TEST_F(ValidateArithmetics, DotDifferentVectorSize) {
const std::string body = R"(
%val = OpDot %f32 %f32vec2_01 %f32vec3_123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected operands to have the same number of componenets: Dot"));
}
TEST_F(ValidateArithmetics, VectorTimesScalarSuccess) {
const std::string body = R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f32_2
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, VectorTimesScalarWrongTypeId) {
const std::string body = R"(
%val = OpVectorTimesScalar %u32vec2 %f32vec2_01 %f32_2
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float vector type as type_id: "
"VectorTimesScalar"));
}
TEST_F(ValidateArithmetics, VectorTimesScalarWrongVector) {
const std::string body = R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec3_012 %f32_2
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected vector operand type to be equal to type_id: "
"VectorTimesScalar"));
}
TEST_F(ValidateArithmetics, VectorTimesScalarWrongScalar) {
const std::string body = R"(
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f64_2
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected scalar operand type to be equal to the component "
"type of the vector operand: VectorTimesScalar"));
}
TEST_F(ValidateArithmetics, MatrixTimesScalarSuccess) {
const std::string body = R"(
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f32_2
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongTypeId) {
const std::string body = R"(
%val = OpMatrixTimesScalar %f32vec2 %f32mat22_1212 %f32_2
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float matrix type as type_id: "
"MatrixTimesScalar"));
}
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongMatrix) {
const std::string body = R"(
%val = OpMatrixTimesScalar %f32mat22 %f32vec2_01 %f32_2
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected matrix operand type to be equal to type_id: "
"MatrixTimesScalar"));
}
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongScalar) {
const std::string body = R"(
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f64_2
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected scalar operand type to be equal to the component "
"type of the matrix operand: MatrixTimesScalar"));
}
TEST_F(ValidateArithmetics, VectorTimesMatrix2x22Success) {
const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, VectorTimesMatrix3x32Success) {
const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec3_123 %f32mat32_123123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongTypeId) {
const std::string body = R"(
%val = OpVectorTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float vector type as type_id: "
"VectorTimesMatrix"));
}
TEST_F(ValidateArithmetics, VectorTimesMatrixNotFloatVector) {
const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %u32vec2_12 %f32mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float vector type as left operand: "
"VectorTimesMatrix"));
}
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongVectorComponent) {
const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f64vec2_12 %f32mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected component types of type_id and vector to be equal: "
"VectorTimesMatrix"));
}
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrix) {
const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float matrix type as right operand: "
"VectorTimesMatrix"));
}
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrixComponent) {
const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f64mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected component types of type_id and matrix to be equal: "
"VectorTimesMatrix"));
}
TEST_F(ValidateArithmetics, VectorTimesMatrix2eq2x23Fail) {
const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat23_121212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected number of columns of the matrix to be equal to the type_id "
"vector size: VectorTimesMatrix"));
}
TEST_F(ValidateArithmetics, VectorTimesMatrix2x32Fail) {
const std::string body = R"(
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat32_123123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected number of rows of the matrix to be equal to the vector "
"operand size: VectorTimesMatrix"));
}
TEST_F(ValidateArithmetics, MatrixTimesVector22x2Success) {
const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, MatrixTimesVector23x3Success) {
const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat23_121212 %f32vec3_123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongTypeId) {
const std::string body = R"(
%val = OpMatrixTimesVector %f32mat22 %f32mat22_1212 %f32vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float vector type as type_id: "
"MatrixTimesVector"));
}
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrix) {
const std::string body = R"(
%val = OpMatrixTimesVector %f32vec3 %f32vec3_123 %f32vec3_123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float matrix type as left operand: "
"MatrixTimesVector"));
}
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrixCol) {
const std::string body = R"(
%val = OpMatrixTimesVector %f32vec3 %f32mat23_121212 %f32vec3_123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected column type of the matrix to be equal to type_id: "
"MatrixTimesVector"));
}
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongVector) {
const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %u32vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float vector type as right operand: "
"MatrixTimesVector"));
}
TEST_F(ValidateArithmetics, MatrixTimesVectorDifferentComponents) {
const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f64vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected component types of the operands to be equal: "
"MatrixTimesVector"));
}
TEST_F(ValidateArithmetics, MatrixTimesVector22x3Fail) {
const std::string body = R"(
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec3_123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected number of columns of the matrix to be equal to the vector "
"size: MatrixTimesVector"));
}
TEST_F(ValidateArithmetics, MatrixTimesMatrix22x22Success) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x32Success) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat32_123123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, MatrixTimesMatrix33x33Success) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat33 %f32mat33_123123123 %f32mat33_123123123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongTypeId) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32vec2 %f32mat22_1212 %f32mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float matrix type as type_id: MatrixTimesMatrix"));
}
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongLeftOperand) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float matrix type as left operand: MatrixTimesMatrix"));
}
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongRightOperand) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32vec2_12
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float matrix type as right operand: MatrixTimesMatrix"));
}
TEST_F(ValidateArithmetics, MatrixTimesMatrix32x23Fail) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat32_123123 %f32mat23_121212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected column types of type_id and left matrix to be equal: "
"MatrixTimesMatrix"));
}
TEST_F(ValidateArithmetics, MatrixTimesMatrixDifferentComponents) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f64mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected component types of type_id and right matrix to be equal: "
"MatrixTimesMatrix"));
}
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x23Fail) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat23_121212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected number of columns of type_id and right matrix to be equal: "
"MatrixTimesMatrix"));
}
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x22Fail) {
const std::string body = R"(
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat22_1212
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected number of columns of left matrix and number of rows of right "
"matrix to be equal: MatrixTimesMatrix"));
}
TEST_F(ValidateArithmetics, OuterProduct2x2Success) {
const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec2_01
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, OuterProduct3x2Success) {
const std::string body = R"(
%val = OpOuterProduct %f32mat32 %f32vec3_123 %f32vec2_01
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, OuterProduct2x3Success) {
const std::string body = R"(
%val = OpOuterProduct %f32mat23 %f32vec2_01 %f32vec3_123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
TEST_F(ValidateArithmetics, OuterProductWrongTypeId) {
const std::string body = R"(
%val = OpOuterProduct %f32vec2 %f32vec2_01 %f32vec3_123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float matrix type as type_id: "
"OuterProduct"));
}
TEST_F(ValidateArithmetics, OuterProductWrongLeftOperand) {
const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec3_123 %f32vec2_01
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected column type of the type_id to be equal to the type "
"of the left operand: OuterProduct"));
}
TEST_F(ValidateArithmetics, OuterProductRightOperandNotFloatVector) {
const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %u32vec2_01
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected float vector type as right operand: OuterProduct"));
}
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongComponent) {
const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f64vec2_01
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected component types of the operands to be equal: OuterProduct"));
}
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) {
const std::string body = R"(
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec3_123
)";
CompileSuccessfully(GenerateCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), HasSubstr(
"Expected number of columns of the matrix to be equal to the "
"vector size of the right operand: OuterProduct"));
}
} // anonymous namespace