Skip to content

Commit

Permalink
Merge pull request #3843 from pleroy/Arrays
Browse files Browse the repository at this point in the history
Clean up the arrays code
  • Loading branch information
pleroy authored Jan 10, 2024
2 parents d102d14 + 6582420 commit ff51614
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 68 deletions.
14 changes: 11 additions & 3 deletions numerics/fixed_arrays.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ class FixedVector final {
constexpr FixedVector(
std::array<Scalar, size_>&& data); // NOLINT(runtime/explicit)

TransposedView<FixedVector> Transpose() const;

Scalar Norm() const;
Square<Scalar> Norm²() const;

Expand Down Expand Up @@ -104,11 +102,14 @@ class FixedMatrix final {
Scalar const* row() const;

FixedMatrix Transpose() const;

Scalar FrobeniusNorm() const;

bool operator==(FixedMatrix const& right) const;
bool operator!=(FixedMatrix const& right) const;

// Applies the matrix as a bilinear form. Present for compatibility with
// |SymmetricBilinearForm|. Prefer to use |TransposedView| and |operator*|.
template<typename LScalar, typename RScalar>
Product<Scalar, Product<LScalar, RScalar>>
operator()(FixedVector<LScalar, columns_> const& left,
Expand Down Expand Up @@ -214,6 +215,7 @@ class FixedUpperTriangularMatrix final {
std::array<Scalar, size()> data_;
};

// Prefer using the operator* that takes a TransposedView.
template<typename LScalar, typename RScalar, int size>
constexpr Product<LScalar, RScalar> InnerProduct(
FixedVector<LScalar, size> const& left,
Expand Down Expand Up @@ -363,7 +365,8 @@ constexpr FixedVector<Product<LScalar, RScalar>, rows> operator*(
FixedMatrix<LScalar, rows, columns> const& left,
FixedVector<RScalar, columns> const& right);

// Use this operator to multiply a row vector with a matrix.
// Use this operator to multiply a row vector with a matrix. We don't have an
// operator returning a TransposedView as that would cause dangling references.
template<typename LScalar, typename RScalar, int rows, int columns>
constexpr FixedVector<Product<LScalar, RScalar>, columns> operator*(
TransposedView<FixedMatrix<LScalar, rows, columns>> const& left,
Expand All @@ -379,6 +382,11 @@ template<typename Scalar, int rows, int columns>
std::ostream& operator<<(std::ostream& out,
FixedMatrix<Scalar, rows, columns> const& matrix);

template<typename Scalar, int rows>
std::ostream& operator<<(
std::ostream& out,
FixedStrictlyLowerTriangularMatrix<Scalar, rows> const& matrix);

template<typename Scalar, int rows>
std::ostream& operator<<(
std::ostream& out,
Expand Down
44 changes: 28 additions & 16 deletions numerics/fixed_arrays_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,6 @@ constexpr FixedVector<Scalar, size_>::FixedVector(
std::array<Scalar, size_>&& data)
: data_(std::move(data)) {}

template<typename Scalar, int size_>
TransposedView<FixedVector<Scalar, size_>>
FixedVector<Scalar, size_>::Transpose() const {
return {.transpose = *this};
}

template<typename Scalar, int size_>
Scalar FixedVector<Scalar, size_>::Norm() const {
return Sqrt(Norm²());
Expand Down Expand Up @@ -167,15 +161,6 @@ Scalar const* FixedMatrix<Scalar, rows_, columns_>::row() const {
return &data_[r * columns()];
}

template<typename Scalar, int rows_, int columns_>
template<typename LScalar, typename RScalar>
Product<Scalar, Product<LScalar, RScalar>>
FixedMatrix<Scalar, rows_, columns_>::operator()(
FixedVector<LScalar, columns_> const& left,
FixedVector<RScalar, rows_> const& right) const {
return left.Transpose() * (*this * right);
}

template<typename Scalar, int rows_, int columns_>
FixedMatrix<Scalar, rows_, columns_>
FixedMatrix<Scalar, rows_, columns_>::Transpose() const {
Expand Down Expand Up @@ -211,6 +196,15 @@ bool FixedMatrix<Scalar, rows_, columns_>::operator!=(
return data_ != right.data_;
}

template<typename Scalar, int rows_, int columns_>
template<typename LScalar, typename RScalar>
Product<Scalar, Product<LScalar, RScalar>>
FixedMatrix<Scalar, rows_, columns_>::operator()(
FixedVector<LScalar, columns_> const& left,
FixedVector<RScalar, rows_> const& right) const {
return TransposedView{left} * (*this * right); // NOLINT
}

template<typename Scalar, int rows_, int columns_>
FixedMatrix<Scalar, rows_, columns_>
FixedMatrix<Scalar, rows_, columns_>::Identity() {
Expand Down Expand Up @@ -435,7 +429,7 @@ constexpr FixedVector<Scalar, size> operator-(
for (int i = 0; i < size; ++i) {
result[i] = -right[i];
}
return FixedVector<Difference<Scalar>, size>(std::move(result));
return FixedVector<Scalar, size>(std::move(result));
}

template<typename Scalar, int rows, int columns>
Expand Down Expand Up @@ -733,6 +727,24 @@ std::ostream& operator<<(std::ostream& out,
return out;
}

template<typename Scalar, int rows>
std::ostream& operator<<(
std::ostream& out,
FixedStrictlyLowerTriangularMatrix<Scalar, rows> const& matrix) {
out << "rows: " << matrix.rows() << "\n";
for (int i = 0; i < matrix.rows(); ++i) {
out << "{";
for (int j = 0; j < i; ++j) {
out << matrix(i, j);
if (j < i - 1) {
out << ", ";
}
}
out << "}\n";
}
return out;
}

template<typename Scalar, int rows>
std::ostream& operator<<(
std::ostream& out,
Expand Down
6 changes: 3 additions & 3 deletions numerics/fixed_arrays_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ TEST_F(FixedArraysTest, Assignment) {
}

TEST_F(FixedArraysTest, Norm) {
EXPECT_EQ(35, v4_.Transpose() * v4_);
EXPECT_EQ(35, TransposedView{v4_} * v4_); // NOLINT
EXPECT_EQ(Sqrt(35.0), v4_.Norm());
EXPECT_EQ(35, v4_.Norm²());
EXPECT_EQ(Sqrt(517.0), m34_.FrobeniusNorm());
Expand Down Expand Up @@ -121,11 +121,11 @@ TEST_F(FixedArraysTest, VectorSpaces) {
}

TEST_F(FixedArraysTest, Algebra) {
EXPECT_EQ(-535, u3_.Transpose() * v3_);
EXPECT_EQ(-535, TransposedView{u3_} * v3_); // NOLINT
EXPECT_EQ((FixedMatrix<double, 3, 4>({-30, -30, 10, 40,
-93, -93, 31, 124,
141, 141, -47, -188})),
v3_ * v4_.Transpose());
v3_ * TransposedView{v4_});
EXPECT_EQ((FixedMatrix<double, 2, 4>({ 0, 14, -22, 3,
14, -63, 5, -92})),
m23_ * m34_);
Expand Down
2 changes: 1 addition & 1 deletion numerics/matrix_computations_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ template<typename Matrix, typename Vector>
typename RayleighQuotientGenerator<Matrix, Vector>::Result
RayleighQuotient(Matrix const& A, Vector const& x) {
// [GV13], section 8.2.3.
return x.Transpose() * (A * x) / (x.Transpose() * x);
return TransposedView{x} * (A * x) / (TransposedView{x} * x); // NOLINT
}

template<typename Matrix, typename Vector>
Expand Down
Loading

0 comments on commit ff51614

Please sign in to comment.