diff --git a/source/opt/types.h b/source/opt/types.h index 2da425b6b..c085b1e1a 100644 --- a/source/opt/types.h +++ b/source/opt/types.h @@ -162,6 +162,7 @@ class Vector : public Type { bool IsSame(Type* that) const override; std::string str() const override; const Type* element_type() const { return element_type_; } + uint32_t element_count() const { return count_; } Vector* AsVector() override { return this; } const Vector* AsVector() const override { return this; } @@ -179,6 +180,7 @@ class Matrix : public Type { bool IsSame(Type* that) const override; std::string str() const override; const Type* element_type() const { return element_type_; } + uint32_t element_count() const { return count_; } Matrix* AsMatrix() override { return this; } const Matrix* AsMatrix() const override { return this; } diff --git a/test/opt/test_types.cpp b/test/opt/test_types.cpp index 4ebe319b1..b9b336e53 100644 --- a/test/opt/test_types.cpp +++ b/test/opt/test_types.cpp @@ -240,4 +240,21 @@ TEST(Types, FloatWidth) { } } +TEST(Types, VectorElementCount) { + auto s32 = std::unique_ptr(new Integer(32, true)); + for (uint32_t c : {2, 3, 4}) { + auto s32v = std::unique_ptr(new Vector(s32.get(), c)); + EXPECT_EQ(c, s32v->element_count()); + } +} + +TEST(Types, MatrixElementCount) { + auto s32 = std::unique_ptr(new Integer(32, true)); + auto s32v4 = std::unique_ptr(new Vector(s32.get(), 4)); + for (uint32_t c : {1, 2, 3, 4, 10, 100}) { + auto s32m = std::unique_ptr(new Matrix(s32v4.get(), c)); + EXPECT_EQ(c, s32m->element_count()); + } +} + } // anonymous namespace