diff --git a/include/DataFrame/Utils/Matrix.h b/include/DataFrame/Utils/Matrix.h index fff70908..9fe12974 100644 --- a/include/DataFrame/Utils/Matrix.h +++ b/include/DataFrame/Utils/Matrix.h @@ -48,7 +48,7 @@ enum class matrix_orient : unsigned char { // ---------------------------------------------------------------------------- -template +template class Matrix { public: @@ -969,6 +969,206 @@ class Matrix { } }; +// ---------------------------------------------------------------------------- + +template +static inline bool +operator != (const Matrix &lhs, const Matrix &rhs) { + + if (lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols()) { + if constexpr (MO1 == matrix_orient::column_major) { + for (long c = 0; c < lhs.cols(); ++c) + for (long r = 0; r < lhs.rows(); ++r) + if (lhs(r, c) != rhs(r, c)) + return (true); + } + else { + for (long r = 0; r < lhs.rows(); ++r) + for (long c = 0; c < lhs.cols(); ++c) + if (lhs(r, c) != rhs(r, c)) + return (true); + } + } + else return (true); + + return (false); +} + +// ---------------------------------------------------------------------------- + +template +static inline bool +operator == (const Matrix &lhs, const Matrix &rhs) { + + return (! (lhs != rhs)); +} + +// ---------------------------------------------------------------------------- + +template +static inline Matrix +operator + (const Matrix &lhs, const Matrix &rhs) { + + assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols()); + + auto result = lhs; + + if constexpr (MO1 == matrix_orient::column_major) { + for (long c = 0; c < lhs.cols(); ++c) + for (long r = 0; r < lhs.rows(); ++r) + result(r, c) += rhs(r, c); + } + else { + for (long r = 0; r < lhs.rows(); ++r) + for (long c = 0; c < lhs.cols(); ++c) + result(r, c) += rhs(r, c); + } + return (result); +} + +// ---------------------------------------------------------------------------- + +template +static inline Matrix +operator - (const Matrix &lhs, const Matrix &rhs) { + + assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols()); + + auto result = lhs; + + if constexpr (MO1 == matrix_orient::column_major) { + for (long c = 0; c < lhs.cols(); ++c) + for (long r = 0; r < lhs.rows(); ++r) + result(r, c) -= rhs(r, c); + } + else { + for (long r = 0; r < lhs.rows(); ++r) + for (long c = 0; c < lhs.cols(); ++c) + result(r, c) -= rhs(r, c); + } + return (result); +} + +// ---------------------------------------------------------------------------- + +template +static inline Matrix & +operator += (Matrix &lhs, const Matrix &rhs) { + + assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols()); + + if constexpr (MO1 == matrix_orient::column_major) { + for (long c = 0; c < lhs.cols(); ++c) + for (long r = 0; r < lhs.rows(); ++r) + lhs(r, c) += rhs(r, c); + } + else { + for (long r = 0; r < lhs.rows(); ++r) + for (long c = 0; c < lhs.cols(); ++c) + lhs(r, c) += rhs(r, c); + } + return (lhs); +} + +// ---------------------------------------------------------------------------- + +template +static inline Matrix & +operator -= (Matrix &lhs, const Matrix &rhs) { + + assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols()); + + if constexpr (MO1 == matrix_orient::column_major) { + for (long c = 0; c < lhs.cols(); ++c) + for (long r = 0; r < lhs.rows(); ++r) + lhs(r, c) -= rhs(r, c); + } + else { + for (long r = 0; r < lhs.rows(); ++r) + for (long c = 0; c < lhs.cols(); ++c) + lhs(r, c) -= rhs(r, c); + } + return (lhs); +} + +// ---------------------------------------------------------------------------- + +// Naïve but cache friendly O(n^3) algorithm +// +template +static Matrix +operator * (const Matrix &lhs, const Matrix &rhs) { + + assert(lhs.cols() == rhs.rows()); + + const long lhs_rows { lhs.rows() }; + const long lhs_cols { lhs.cols() }; + const long rhs_cols { rhs.cols() }; + Matrix result { lhs_rows, rhs_cols }; + + constexpr long large_dim = 100; + + // Using SIMD for large matrixes + // + if (lhs_cols >= large_dim && rhs_cols >= large_dim) { + constexpr long block = 8; + + if constexpr (MO1 == matrix_orient::column_major) { + for (long c = 0; c < rhs_cols; ++c) { + for (long r = 0; r < lhs_rows; ++r) { + for (long k = 0; k < lhs_cols; k += block) { + const long min_s = std::min(block, lhs_cols); + +#pragma unroll(block) + // This loop should be optimized/unrolled by the + // compiler and use SIMD to execute in parallel + // + for (long w = 0; w < min_s; ++w) { + result(r, c) += lhs(k + w, r) * rhs(c, k + w); + } + } + } + } + } + else { // matrix_orient::row_major + for (long r = 0; r < lhs_rows; ++r) { + for (long c = 0; c < rhs_cols; ++c) { + for (long k = 0; k < lhs_cols; k += block) { + const long min_s = std::min(block, lhs_cols); + +#pragma unroll(block) + // This loop should be optimized/unrolled by the + // compiler and use SIMD to execute in parallel + // + for (long w = 0; w < min_s; ++w) { + result(r, c) += lhs(r, k + w) * rhs(k + w, c); + } + } + } + } + } + } + + // Naïve but cache friendly O(n^3) algorithm for smaller matrixes + // + else { + if constexpr (MO1 == matrix_orient::column_major) { + for (long c = 0; c < rhs_cols; ++c) + for (long r = 0; r < lhs_rows; ++r) + for (long k = 0; k < lhs_cols; ++k) + result(r, c) += lhs(k, r) * rhs(c, k); + } + else { // matrix_orient::row_major + for (long r = 0; r < lhs_rows; ++r) + for (long c = 0; c < rhs_cols; ++c) + for (long k = 0; k < lhs_cols; ++k) + result(r, c) += lhs(r, k) * rhs(k, c); + } + } + + return (result); +} + } // namespace hmdf // ---------------------------------------------------------------------------- diff --git a/test/matrix_tester.cc b/test/matrix_tester.cc index b47f735c..aca12bcc 100644 --- a/test/matrix_tester.cc +++ b/test/matrix_tester.cc @@ -59,6 +59,7 @@ int main(int, char *[]) { // Print the stuff out // + std::cout << "Row matrix\n"; for (long r = 0; r < row_mat.rows(); ++r) { for (long c = 0; c < row_mat.cols(); ++c) { std::cout << row_mat(r, c) << ", "; @@ -66,6 +67,7 @@ int main(int, char *[]) { std::cout << '\n'; } std::cout << "\n\n"; + std::cout << "Column matrix\n"; for (long r = 0; r < col_mat.rows(); ++r) { for (long c = 0; c < col_mat.cols(); ++c) { std::cout << col_mat(r, c) << ", "; @@ -118,6 +120,71 @@ int main(int, char *[]) { assert(((col_iter1 - col_iter2) == 7)); assert(((row_iter1 - row_iter2) == 7)); + const auto col_mat2 = col_mat; + + assert(col_mat != row_mat); + assert(col_mat == col_mat2); + + auto tran_mat = col_mat.transpose(); + auto tran_mat2 = col_mat.transpose2(); + + assert(tran_mat == tran_mat2); + for (long r = 0; r < tran_mat.rows(); ++r) + for (long c = 0; c < tran_mat.cols(); ++c) + assert(tran_mat(r, c) == col_mat(c, r)); + + // + // Test arithmetic functions + // + + auto sum_mat = col_mat + row_mat; + + assert(sum_mat(0, 0) == 0); + assert(sum_mat(4, 5) == 58); + assert(sum_mat(1, 1) == 13); + assert(sum_mat(3, 4) == 45); + + sum_mat += col_mat; + assert(sum_mat(0, 0) == 0); + assert(sum_mat(4, 5) == 87); + assert(sum_mat(1, 1) == 19); + assert(sum_mat(3, 4) == 68); + + row_mat_t lhs_mat { ROWS, COLS }; + col_mat_t rhs_mat { COLS, COLS }; + + value = 0; + for (long r = 0; r < lhs_mat.rows(); ++r) + for (long c = 0; c < lhs_mat.cols(); ++c) + lhs_mat(r, c) = value++; + value = 0; + for (long c = 0; c < rhs_mat.cols(); ++c) + for (long r = 0; r < rhs_mat.rows(); ++r) + rhs_mat(r, c) = value++; + + auto multi_mat = lhs_mat * rhs_mat; + + assert(multi_mat(0, 0) == 55); + assert(multi_mat(4, 5) == 5185); + assert(multi_mat(1, 1) == 451); + assert(multi_mat(3, 4) == 3277); + + col_mat_t big_lhs_mat { 100, 100 }; + col_mat_t big_rhs_mat { 100, 100 }; + + for (long c = 0; c < 100; ++c) + for (long r = 0; r < 100; ++r) { + big_lhs_mat(r, c) = c + 1; + big_rhs_mat(r, c) = c + 1; + } + + auto big_multi_mat = big_lhs_mat * big_rhs_mat; + + assert(big_multi_mat(0, 0) == 5050); + assert(big_multi_mat(99, 99) == 505000); + assert(big_multi_mat(98, 2) == 499950); + assert(big_multi_mat(2, 5) == 15150); + return (0); }