mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-26 21:30:07 +00:00
Add new checks to validate arithmetics pass
New operations: - OpDot - OpVectorTimesScalar - OpMatrixTimesScalar - OpVectorTimesMatrix - OpMatrixTimesVector - OpMatrixTimesMatrix - OpOuterProduct
This commit is contained in:
parent
4442102247
commit
c6dfc11880
@ -1290,8 +1290,10 @@ uint64_t MarkvCodecBase::GetRuleBasedMtf() {
|
||||
}
|
||||
|
||||
case SpvOpVectorTimesScalar: {
|
||||
if (operand_index_ == 0)
|
||||
if (operand_index_ == 0) {
|
||||
// TODO(atgoo@github.com) Could be narrowed to vector of floats.
|
||||
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
|
||||
}
|
||||
|
||||
assert(inst_.type_id);
|
||||
if (operand_index_ == 2)
|
||||
|
@ -591,4 +591,43 @@ bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
assert(inst);
|
||||
|
||||
if (inst->opcode() == SpvOpTypeMatrix) {
|
||||
return IsFloatScalarType(GetComponentType(id));
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ValidationState_t::GetMatrixTypeInfo(
|
||||
uint32_t id, uint32_t* num_rows, uint32_t* num_cols,
|
||||
uint32_t* column_type, uint32_t* component_type) const {
|
||||
if (!id)
|
||||
return false;
|
||||
|
||||
const Instruction* mat_inst = FindDef(id);
|
||||
assert(mat_inst);
|
||||
if (mat_inst->opcode() != SpvOpTypeMatrix)
|
||||
return false;
|
||||
|
||||
const uint32_t vec_type = mat_inst->word(2);
|
||||
const Instruction* vec_inst = FindDef(vec_type);
|
||||
assert(vec_inst);
|
||||
|
||||
if (vec_inst->opcode() != SpvOpTypeVector) {
|
||||
assert(0);
|
||||
return false;
|
||||
}
|
||||
|
||||
*num_cols = mat_inst->word(3);
|
||||
*num_rows = vec_inst->word(3);
|
||||
*column_type = mat_inst->word(2);
|
||||
*component_type = vec_inst->word(2);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} /// namespace libspirv
|
||||
|
@ -333,27 +333,36 @@ class ValidationState_t {
|
||||
|
||||
// Returns type_id of the scalar component of |id|.
|
||||
// |id| can be either
|
||||
// - vector type
|
||||
// - matrix type
|
||||
// - object of either vector or matrix type
|
||||
// - scalar, vector or matrix type
|
||||
// - object of either scalar, vector or matrix type
|
||||
uint32_t GetComponentType(uint32_t id) const;
|
||||
|
||||
// Returns dimension of scalar, vector or matrix type or object. Will invoke
|
||||
// assertion and return 0 if |id| is none of the above.
|
||||
// In case of matrix returns number of columns.
|
||||
// Returns
|
||||
// - 1 for scalar types or objects
|
||||
// - vector size for vector types or objects
|
||||
// - num columns for matrix types or objects
|
||||
// Should not be called with any other arguments (will return zero and invoke
|
||||
// assertion).
|
||||
uint32_t GetDimension(uint32_t id) const;
|
||||
|
||||
// Returns bit width of scalar or component.
|
||||
// |id| can be
|
||||
// - scalar type or object
|
||||
// - vector or matrix type or object
|
||||
// - scalar, vector or matrix type
|
||||
// - object of either scalar, vector or matrix type
|
||||
// Will invoke assertion and return 0 if |id| is none of the above.
|
||||
uint32_t GetBitWidth(uint32_t id) const;
|
||||
|
||||
// Provides detailed information on matrix type.
|
||||
// Returns false iff |id| is not matrix type.
|
||||
bool GetMatrixTypeInfo(
|
||||
uint32_t id, uint32_t* num_rows, uint32_t* num_cols,
|
||||
uint32_t* column_type, uint32_t* component_type) const;
|
||||
|
||||
// Returns true iff |id| is a type corresponding to the name of the function.
|
||||
// Only works for types not for objects.
|
||||
bool IsFloatScalarType(uint32_t id) const;
|
||||
bool IsFloatVectorType(uint32_t id) const;
|
||||
bool IsFloatMatrixType(uint32_t id) const;
|
||||
bool IsIntScalarType(uint32_t id) const;
|
||||
bool IsIntVectorType(uint32_t id) const;
|
||||
bool IsUnsignedIntScalarType(uint32_t id) const;
|
||||
|
@ -126,6 +126,287 @@ spv_result_t ArithmeticsPass(ValidationState_t& _,
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpDot: {
|
||||
if (!_.IsFloatScalarType(inst->type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float scalar type as type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
uint32_t first_vector_num_components = 0;
|
||||
|
||||
for (size_t operand_index = 2; operand_index < inst->num_operands;
|
||||
++operand_index) {
|
||||
const uint32_t type_id =
|
||||
_.GetTypeId(GetOperandWord(inst, operand_index));
|
||||
|
||||
if (!type_id || !_.IsFloatVectorType(type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float vector as operand: "
|
||||
<< spvOpcodeString(opcode) << " operand index " << operand_index;
|
||||
|
||||
|
||||
const uint32_t component_type = _.GetComponentType(type_id);
|
||||
if (component_type != inst->type_id)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected component type to be equal to type_id: "
|
||||
<< spvOpcodeString(opcode) << " operand index " << operand_index;
|
||||
|
||||
const uint32_t num_components = _.GetDimension(type_id);
|
||||
if (operand_index == 2) {
|
||||
first_vector_num_components = num_components;
|
||||
} else if (num_components != first_vector_num_components) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected operands to have the same number of componenets: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpVectorTimesScalar: {
|
||||
if (!_.IsFloatVectorType(inst->type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float vector type as type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2));
|
||||
if (inst->type_id != vector_type_id)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected vector operand type to be equal to type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t component_type = _.GetComponentType(vector_type_id);
|
||||
|
||||
const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3));
|
||||
if (component_type != scalar_type_id)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected scalar operand type to be equal to the component "
|
||||
<< "type of the vector operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpMatrixTimesScalar: {
|
||||
if (!_.IsFloatMatrixType(inst->type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float matrix type as type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2));
|
||||
if (inst->type_id != matrix_type_id)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected matrix operand type to be equal to type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t component_type = _.GetComponentType(matrix_type_id);
|
||||
|
||||
const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3));
|
||||
if (component_type != scalar_type_id)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected scalar operand type to be equal to the component "
|
||||
<< "type of the matrix operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpVectorTimesMatrix: {
|
||||
const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2));
|
||||
const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 3));
|
||||
|
||||
if (!_.IsFloatVectorType(inst->type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float vector type as type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
const uint32_t res_component_type = _.GetComponentType(inst->type_id);
|
||||
|
||||
if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float vector type as left operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (res_component_type != _.GetComponentType(vector_type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected component types of type_id and vector to be equal: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
uint32_t matrix_num_rows = 0;
|
||||
uint32_t matrix_num_cols = 0;
|
||||
uint32_t matrix_col_type = 0;
|
||||
uint32_t matrix_component_type = 0;
|
||||
if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
|
||||
&matrix_num_cols, &matrix_col_type,
|
||||
&matrix_component_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float matrix type as right operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (res_component_type != matrix_component_type)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected component types of type_id and matrix to be equal: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (matrix_num_cols != _.GetDimension(inst->type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected number of columns of the matrix to be equal to the "
|
||||
<< "type_id vector size: " << spvOpcodeString(opcode);
|
||||
|
||||
if (matrix_num_rows != _.GetDimension(vector_type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected number of rows of the matrix to be equal to the "
|
||||
<< "vector operand size: " << spvOpcodeString(opcode);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpMatrixTimesVector: {
|
||||
const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2));
|
||||
const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 3));
|
||||
|
||||
if (!_.IsFloatVectorType(inst->type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float vector type as type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
uint32_t matrix_num_rows = 0;
|
||||
uint32_t matrix_num_cols = 0;
|
||||
uint32_t matrix_col_type = 0;
|
||||
uint32_t matrix_component_type = 0;
|
||||
if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
|
||||
&matrix_num_cols, &matrix_col_type,
|
||||
&matrix_component_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float matrix type as left operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (inst->type_id != matrix_col_type)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected column type of the matrix to be equal to type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float vector type as right operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (matrix_component_type != _.GetComponentType(vector_type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected component types of the operands to be equal: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (matrix_num_cols != _.GetDimension(vector_type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected number of columns of the matrix to be equal to the "
|
||||
<< "vector size: " << spvOpcodeString(opcode);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpMatrixTimesMatrix: {
|
||||
const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2));
|
||||
const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3));
|
||||
|
||||
uint32_t res_num_rows = 0;
|
||||
uint32_t res_num_cols = 0;
|
||||
uint32_t res_col_type = 0;
|
||||
uint32_t res_component_type = 0;
|
||||
if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols,
|
||||
&res_col_type, &res_component_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float matrix type as type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
uint32_t left_num_rows = 0;
|
||||
uint32_t left_num_cols = 0;
|
||||
uint32_t left_col_type = 0;
|
||||
uint32_t left_component_type = 0;
|
||||
if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
|
||||
&left_col_type, &left_component_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float matrix type as left operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
uint32_t right_num_rows = 0;
|
||||
uint32_t right_num_cols = 0;
|
||||
uint32_t right_col_type = 0;
|
||||
uint32_t right_component_type = 0;
|
||||
if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
|
||||
&right_col_type, &right_component_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float matrix type as right operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (!_.IsFloatScalarType(res_component_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float matrix type as type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (res_col_type != left_col_type)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected column types of type_id and left matrix to be "
|
||||
<< "equal: " << spvOpcodeString(opcode);
|
||||
|
||||
if (res_component_type != right_component_type)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected component types of type_id and right matrix to be "
|
||||
<< "equal: " << spvOpcodeString(opcode);
|
||||
|
||||
if (res_num_cols != right_num_cols)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected number of columns of type_id and right matrix to be "
|
||||
<< "equal: " << spvOpcodeString(opcode);
|
||||
|
||||
if (left_num_cols != right_num_rows)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected number of columns of left matrix and number of rows "
|
||||
<< "of right matrix to be equal: " << spvOpcodeString(opcode);
|
||||
|
||||
assert(left_num_rows == res_num_rows);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpOuterProduct: {
|
||||
const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2));
|
||||
const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3));
|
||||
|
||||
uint32_t res_num_rows = 0;
|
||||
uint32_t res_num_cols = 0;
|
||||
uint32_t res_col_type = 0;
|
||||
uint32_t res_component_type = 0;
|
||||
if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols,
|
||||
&res_col_type, &res_component_type))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float matrix type as type_id: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (left_type_id != res_col_type)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected column type of the type_id to be equal to the type "
|
||||
<< "of the left operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (!right_type_id || !_.IsFloatVectorType(right_type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected float vector type as right operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (res_component_type != _.GetComponentType(right_type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected component types of the operands to be equal: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (res_num_cols != _.GetDimension(right_type_id))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA)
|
||||
<< "Expected number of columns of the matrix to be equal to the "
|
||||
<< "vector size of the right operand: " << spvOpcodeString(opcode);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
// TODO(atgoo@github.com): Support other operations.
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -682,8 +682,6 @@ TEST(Markv, VectorTimesScalar) {
|
||||
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
|
||||
%res1 = OpVectorTimesScalar %f32vec4 %f32vec4_0123 %f32_2
|
||||
%res2 = OpVectorTimesScalar %f32vec4 %f32vec4_3210 %f32_2
|
||||
%res3 = OpVectorTimesScalar %u32vec3 %u32vec3_012 %u32_2
|
||||
%res4 = OpVectorTimesScalar %s32vec2 %s32vec2_01 %s32_2
|
||||
)");
|
||||
}
|
||||
|
||||
|
@ -35,6 +35,7 @@ R"(
|
||||
OpCapability Shader
|
||||
OpCapability Int64
|
||||
OpCapability Float64
|
||||
OpCapability Matrix
|
||||
%ext_inst = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %main "main"
|
||||
@ -66,6 +67,12 @@ OpEntryPoint Fragment %main "main"
|
||||
%f32vec4 = OpTypeVector %f32 4
|
||||
%f64vec4 = OpTypeVector %f64 4
|
||||
|
||||
%f32mat22 = OpTypeMatrix %f32vec2 2
|
||||
%f32mat23 = OpTypeMatrix %f32vec2 3
|
||||
%f32mat32 = OpTypeMatrix %f32vec3 2
|
||||
%f32mat33 = OpTypeMatrix %f32vec3 3
|
||||
%f64mat22 = OpTypeMatrix %f64vec2 2
|
||||
|
||||
%f32_0 = OpConstant %f32 0
|
||||
%f32_1 = OpConstant %f32 1
|
||||
%f32_2 = OpConstant %f32 2
|
||||
@ -133,6 +140,13 @@ OpEntryPoint Fragment %main "main"
|
||||
%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3
|
||||
%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4
|
||||
|
||||
%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12
|
||||
%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12
|
||||
%f32mat32_123123 = OpConstantComposite %f32mat32 %f32vec3_123 %f32vec3_123
|
||||
%f32mat33_123123123 = OpConstantComposite %f32mat33 %f32vec3_123 %f32vec3_123 %f32vec3_123
|
||||
|
||||
%f64mat22_1212 = OpConstantComposite %f64mat22 %f64vec2_12 %f64vec2_12
|
||||
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel)";
|
||||
|
||||
@ -536,4 +550,554 @@ TEST_F(ValidateArithmetics, UDivWrongOperand2) {
|
||||
"UDiv operand index 3"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, DotSuccess) {
|
||||
const std::string body = R"(
|
||||
%val = OpDot %f32 %f32vec2_01 %f32vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, DotWrongTypeId) {
|
||||
const std::string body = R"(
|
||||
%val = OpDot %u32 %u32vec2_01 %u32vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float scalar type as type_id: Dot"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand1) {
|
||||
const std::string body = R"(
|
||||
%val = OpDot %f32 %f32 %f32vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float vector as operand: Dot operand index 2"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, DotNotVectorTypeOperand2) {
|
||||
const std::string body = R"(
|
||||
%val = OpDot %f32 %f32vec3_012 %f32_1
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float vector as operand: Dot operand index 3"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, DotWrongComponentOperand1) {
|
||||
const std::string body = R"(
|
||||
%val = OpDot %f64 %f32vec2_01 %f64vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected component type to be equal to type_id: Dot operand index 2"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, DotWrongComponentOperand2) {
|
||||
const std::string body = R"(
|
||||
%val = OpDot %f32 %f32vec2_01 %f64vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected component type to be equal to type_id: Dot operand index 3"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, DotDifferentVectorSize) {
|
||||
const std::string body = R"(
|
||||
%val = OpDot %f32 %f32vec2_01 %f32vec3_123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected operands to have the same number of componenets: Dot"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesScalarSuccess) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f32_2
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesScalarWrongTypeId) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesScalar %u32vec2 %f32vec2_01 %f32_2
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float vector type as type_id: "
|
||||
"VectorTimesScalar"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesScalarWrongVector) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesScalar %f32vec2 %f32vec3_012 %f32_2
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected vector operand type to be equal to type_id: "
|
||||
"VectorTimesScalar"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesScalarWrongScalar) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f64_2
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected scalar operand type to be equal to the component "
|
||||
"type of the vector operand: VectorTimesScalar"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesScalarSuccess) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f32_2
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongTypeId) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesScalar %f32vec2 %f32mat22_1212 %f32_2
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float matrix type as type_id: "
|
||||
"MatrixTimesScalar"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongMatrix) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesScalar %f32mat22 %f32vec2_01 %f32_2
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected matrix operand type to be equal to type_id: "
|
||||
"MatrixTimesScalar"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesScalarWrongScalar) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f64_2
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected scalar operand type to be equal to the component "
|
||||
"type of the matrix operand: MatrixTimesScalar"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesMatrix2x22Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesMatrix3x32Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesMatrix %f32vec2 %f32vec3_123 %f32mat32_123123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongTypeId) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float vector type as type_id: "
|
||||
"VectorTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesMatrixNotFloatVector) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesMatrix %f32vec2 %u32vec2_12 %f32mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float vector type as left operand: "
|
||||
"VectorTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongVectorComponent) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesMatrix %f32vec2 %f64vec2_12 %f32mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected component types of type_id and vector to be equal: "
|
||||
"VectorTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrix) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float matrix type as right operand: "
|
||||
"VectorTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrixComponent) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f64mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected component types of type_id and matrix to be equal: "
|
||||
"VectorTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesMatrix2eq2x23Fail) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat23_121212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected number of columns of the matrix to be equal to the type_id "
|
||||
"vector size: VectorTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, VectorTimesMatrix2x32Fail) {
|
||||
const std::string body = R"(
|
||||
%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat32_123123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected number of rows of the matrix to be equal to the vector "
|
||||
"operand size: VectorTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesVector22x2Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesVector23x3Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesVector %f32vec2 %f32mat23_121212 %f32vec3_123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongTypeId) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesVector %f32mat22 %f32mat22_1212 %f32vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float vector type as type_id: "
|
||||
"MatrixTimesVector"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrix) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesVector %f32vec3 %f32vec3_123 %f32vec3_123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float matrix type as left operand: "
|
||||
"MatrixTimesVector"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrixCol) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesVector %f32vec3 %f32mat23_121212 %f32vec3_123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected column type of the matrix to be equal to type_id: "
|
||||
"MatrixTimesVector"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesVectorWrongVector) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %u32vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float vector type as right operand: "
|
||||
"MatrixTimesVector"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesVectorDifferentComponents) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f64vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected component types of the operands to be equal: "
|
||||
"MatrixTimesVector"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesVector22x3Fail) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec3_123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected number of columns of the matrix to be equal to the vector "
|
||||
"size: MatrixTimesVector"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrix22x22Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x32Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat32_123123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrix33x33Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32mat33 %f32mat33_123123123 %f32mat33_123123123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongTypeId) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32vec2 %f32mat22_1212 %f32mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float matrix type as type_id: MatrixTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongLeftOperand) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float matrix type as left operand: MatrixTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongRightOperand) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32vec2_12
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float matrix type as right operand: MatrixTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrix32x23Fail) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32mat22 %f32mat32_123123 %f32mat23_121212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected column types of type_id and left matrix to be equal: "
|
||||
"MatrixTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrixDifferentComponents) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f64mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected component types of type_id and right matrix to be equal: "
|
||||
"MatrixTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x23Fail) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat23_121212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected number of columns of type_id and right matrix to be equal: "
|
||||
"MatrixTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, MatrixTimesMatrix23x22Fail) {
|
||||
const std::string body = R"(
|
||||
%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat22_1212
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected number of columns of left matrix and number of rows of right "
|
||||
"matrix to be equal: MatrixTimesMatrix"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, OuterProduct2x2Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec2_01
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, OuterProduct3x2Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpOuterProduct %f32mat32 %f32vec3_123 %f32vec2_01
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, OuterProduct2x3Success) {
|
||||
const std::string body = R"(
|
||||
%val = OpOuterProduct %f32mat23 %f32vec2_01 %f32vec3_123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, OuterProductWrongTypeId) {
|
||||
const std::string body = R"(
|
||||
%val = OpOuterProduct %f32vec2 %f32vec2_01 %f32vec3_123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float matrix type as type_id: "
|
||||
"OuterProduct"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, OuterProductWrongLeftOperand) {
|
||||
const std::string body = R"(
|
||||
%val = OpOuterProduct %f32mat22 %f32vec3_123 %f32vec2_01
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected column type of the type_id to be equal to the type "
|
||||
"of the left operand: OuterProduct"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, OuterProductRightOperandNotFloatVector) {
|
||||
const std::string body = R"(
|
||||
%val = OpOuterProduct %f32mat22 %f32vec2_12 %u32vec2_01
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected float vector type as right operand: OuterProduct"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongComponent) {
|
||||
const std::string body = R"(
|
||||
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f64vec2_01
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected component types of the operands to be equal: OuterProduct"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) {
|
||||
const std::string body = R"(
|
||||
%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec3_123
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateCode(body).c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(), HasSubstr(
|
||||
"Expected number of columns of the matrix to be equal to the "
|
||||
"vector size of the right operand: OuterProduct"));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
Loading…
Reference in New Issue
Block a user