Skip to content

Commit

Permalink
Added arithmetic operators and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hosseinmoein committed Nov 12, 2024
1 parent 2b0fd33 commit 266e347
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 1 deletion.
202 changes: 201 additions & 1 deletion include/DataFrame/Utils/Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ enum class matrix_orient : unsigned char {

// ----------------------------------------------------------------------------

template<typename T, matrix_orient MO = matrix_orient::column_major>
template<typename T, matrix_orient MO = matrix_orient::column_major>
class Matrix {

public:
Expand Down Expand Up @@ -969,6 +969,206 @@ class Matrix {
}
};

// ----------------------------------------------------------------------------

template<typename T, matrix_orient MO1, matrix_orient MO2>
static inline bool
operator != (const Matrix<T, MO1> &lhs, const Matrix<T, MO2> &rhs) {

if (lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols()) {
if constexpr (MO1 == matrix_orient::column_major) {
for (long c = 0; c < lhs.cols(); ++c)
for (long r = 0; r < lhs.rows(); ++r)
if (lhs(r, c) != rhs(r, c))
return (true);
}
else {
for (long r = 0; r < lhs.rows(); ++r)
for (long c = 0; c < lhs.cols(); ++c)
if (lhs(r, c) != rhs(r, c))
return (true);
}
}
else return (true);

return (false);
}

// ----------------------------------------------------------------------------

template<typename T, matrix_orient MO1, matrix_orient MO2>
static inline bool
operator == (const Matrix<T, MO1> &lhs, const Matrix<T, MO2> &rhs) {

return (! (lhs != rhs));
}

// ----------------------------------------------------------------------------

template<typename T, matrix_orient MO1, matrix_orient MO2>
static inline Matrix<T, MO1>
operator + (const Matrix<T, MO1> &lhs, const Matrix<T, MO2> &rhs) {

assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());

auto result = lhs;

if constexpr (MO1 == matrix_orient::column_major) {
for (long c = 0; c < lhs.cols(); ++c)
for (long r = 0; r < lhs.rows(); ++r)
result(r, c) += rhs(r, c);
}
else {
for (long r = 0; r < lhs.rows(); ++r)
for (long c = 0; c < lhs.cols(); ++c)
result(r, c) += rhs(r, c);
}
return (result);
}

// ----------------------------------------------------------------------------

template<typename T, matrix_orient MO1, matrix_orient MO2>
static inline Matrix<T, MO1>
operator - (const Matrix<T, MO1> &lhs, const Matrix<T, MO2> &rhs) {

assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());

auto result = lhs;

if constexpr (MO1 == matrix_orient::column_major) {
for (long c = 0; c < lhs.cols(); ++c)
for (long r = 0; r < lhs.rows(); ++r)
result(r, c) -= rhs(r, c);
}
else {
for (long r = 0; r < lhs.rows(); ++r)
for (long c = 0; c < lhs.cols(); ++c)
result(r, c) -= rhs(r, c);
}
return (result);
}

// ----------------------------------------------------------------------------

template<typename T, matrix_orient MO1, matrix_orient MO2>
static inline Matrix<T, MO1> &
operator += (Matrix<T, MO1> &lhs, const Matrix<T, MO2> &rhs) {

assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());

if constexpr (MO1 == matrix_orient::column_major) {
for (long c = 0; c < lhs.cols(); ++c)
for (long r = 0; r < lhs.rows(); ++r)
lhs(r, c) += rhs(r, c);
}
else {
for (long r = 0; r < lhs.rows(); ++r)
for (long c = 0; c < lhs.cols(); ++c)
lhs(r, c) += rhs(r, c);
}
return (lhs);
}

// ----------------------------------------------------------------------------

template<typename T, matrix_orient MO1, matrix_orient MO2>
static inline Matrix<T, MO1> &
operator -= (Matrix<T, MO1> &lhs, const Matrix<T, MO2> &rhs) {

assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());

if constexpr (MO1 == matrix_orient::column_major) {
for (long c = 0; c < lhs.cols(); ++c)
for (long r = 0; r < lhs.rows(); ++r)
lhs(r, c) -= rhs(r, c);
}
else {
for (long r = 0; r < lhs.rows(); ++r)
for (long c = 0; c < lhs.cols(); ++c)
lhs(r, c) -= rhs(r, c);
}
return (lhs);
}

// ----------------------------------------------------------------------------

// Naïve but cache friendly O(n^3) algorithm
//
template<typename T, matrix_orient MO1, matrix_orient MO2>
static Matrix<T, MO1>
operator * (const Matrix<T, MO1> &lhs, const Matrix<T, MO2> &rhs) {

assert(lhs.cols() == rhs.rows());

const long lhs_rows { lhs.rows() };
const long lhs_cols { lhs.cols() };
const long rhs_cols { rhs.cols() };
Matrix<T, MO1> result { lhs_rows, rhs_cols };

constexpr long large_dim = 100;

// Using SIMD for large matrixes
//
if (lhs_cols >= large_dim && rhs_cols >= large_dim) {
constexpr long block = 8;

if constexpr (MO1 == matrix_orient::column_major) {
for (long c = 0; c < rhs_cols; ++c) {
for (long r = 0; r < lhs_rows; ++r) {
for (long k = 0; k < lhs_cols; k += block) {
const long min_s = std::min(block, lhs_cols);

#pragma unroll(block)
// This loop should be optimized/unrolled by the
// compiler and use SIMD to execute in parallel
//
for (long w = 0; w < min_s; ++w) {
result(r, c) += lhs(k + w, r) * rhs(c, k + w);
}
}
}
}
}
else { // matrix_orient::row_major
for (long r = 0; r < lhs_rows; ++r) {
for (long c = 0; c < rhs_cols; ++c) {
for (long k = 0; k < lhs_cols; k += block) {
const long min_s = std::min(block, lhs_cols);

#pragma unroll(block)
// This loop should be optimized/unrolled by the
// compiler and use SIMD to execute in parallel
//
for (long w = 0; w < min_s; ++w) {
result(r, c) += lhs(r, k + w) * rhs(k + w, c);
}
}
}
}
}
}

// Naïve but cache friendly O(n^3) algorithm for smaller matrixes
//
else {
if constexpr (MO1 == matrix_orient::column_major) {
for (long c = 0; c < rhs_cols; ++c)
for (long r = 0; r < lhs_rows; ++r)
for (long k = 0; k < lhs_cols; ++k)
result(r, c) += lhs(k, r) * rhs(c, k);
}
else { // matrix_orient::row_major
for (long r = 0; r < lhs_rows; ++r)
for (long c = 0; c < rhs_cols; ++c)
for (long k = 0; k < lhs_cols; ++k)
result(r, c) += lhs(r, k) * rhs(k, c);
}
}

return (result);
}

} // namespace hmdf

// ----------------------------------------------------------------------------
Expand Down
67 changes: 67 additions & 0 deletions test/matrix_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ int main(int, char *[]) {

// Print the stuff out
//
std::cout << "Row matrix\n";
for (long r = 0; r < row_mat.rows(); ++r) {
for (long c = 0; c < row_mat.cols(); ++c) {
std::cout << row_mat(r, c) << ", ";
}
std::cout << '\n';
}
std::cout << "\n\n";
std::cout << "Column matrix\n";
for (long r = 0; r < col_mat.rows(); ++r) {
for (long c = 0; c < col_mat.cols(); ++c) {
std::cout << col_mat(r, c) << ", ";
Expand Down Expand Up @@ -118,6 +120,71 @@ int main(int, char *[]) {
assert(((col_iter1 - col_iter2) == 7));
assert(((row_iter1 - row_iter2) == 7));

const auto col_mat2 = col_mat;

assert(col_mat != row_mat);
assert(col_mat == col_mat2);

auto tran_mat = col_mat.transpose();
auto tran_mat2 = col_mat.transpose2();

assert(tran_mat == tran_mat2);
for (long r = 0; r < tran_mat.rows(); ++r)
for (long c = 0; c < tran_mat.cols(); ++c)
assert(tran_mat(r, c) == col_mat(c, r));

//
// Test arithmetic functions
//

auto sum_mat = col_mat + row_mat;

assert(sum_mat(0, 0) == 0);
assert(sum_mat(4, 5) == 58);
assert(sum_mat(1, 1) == 13);
assert(sum_mat(3, 4) == 45);

sum_mat += col_mat;
assert(sum_mat(0, 0) == 0);
assert(sum_mat(4, 5) == 87);
assert(sum_mat(1, 1) == 19);
assert(sum_mat(3, 4) == 68);

row_mat_t lhs_mat { ROWS, COLS };
col_mat_t rhs_mat { COLS, COLS };

value = 0;
for (long r = 0; r < lhs_mat.rows(); ++r)
for (long c = 0; c < lhs_mat.cols(); ++c)
lhs_mat(r, c) = value++;
value = 0;
for (long c = 0; c < rhs_mat.cols(); ++c)
for (long r = 0; r < rhs_mat.rows(); ++r)
rhs_mat(r, c) = value++;

auto multi_mat = lhs_mat * rhs_mat;

assert(multi_mat(0, 0) == 55);
assert(multi_mat(4, 5) == 5185);
assert(multi_mat(1, 1) == 451);
assert(multi_mat(3, 4) == 3277);

col_mat_t big_lhs_mat { 100, 100 };
col_mat_t big_rhs_mat { 100, 100 };

for (long c = 0; c < 100; ++c)
for (long r = 0; r < 100; ++r) {
big_lhs_mat(r, c) = c + 1;
big_rhs_mat(r, c) = c + 1;
}

auto big_multi_mat = big_lhs_mat * big_rhs_mat;

assert(big_multi_mat(0, 0) == 5050);
assert(big_multi_mat(99, 99) == 505000);
assert(big_multi_mat(98, 2) == 499950);
assert(big_multi_mat(2, 5) == 15150);

return (0);
}

Expand Down

0 comments on commit 266e347

Please sign in to comment.