diff --git a/CMakeLists.txt b/CMakeLists.txt index 02fd6f76..4adf50d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,7 +78,7 @@ endif() include(Flake8) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_POSITION_INDEPENDENT_CODE ON) diff --git a/contrib/standalone_buffer/Makefile b/contrib/standalone_buffer/Makefile index f0e7cbe5..ef39cbbc 100644 --- a/contrib/standalone_buffer/Makefile +++ b/contrib/standalone_buffer/Makefile @@ -77,5 +77,7 @@ copy: cp $(MMROOT)/cpp/modmesh/python/common.hpp $(SETUPROOT)/modmesh/python/common.hpp cp -a $(MMROOT)/cpp/modmesh/toggle $(SETUPROOT)/modmesh/toggle rm -rf $(SETUPROOT)/modmesh/toggle/pymod + cp -a $(MMROOT)/cpp/modmesh/math $(SETUPROOT)/modmesh/math + rm -rf $(SETUPROOT)/modmesh/math/pymod cp $(MMROOT)/cpp/modmesh/base.hpp $(SETUPROOT)/modmesh/ find $(SETUPROOT) -name CMakeLists.txt -delete \ No newline at end of file diff --git a/cpp/modmesh/buffer/SimpleArray.cpp b/cpp/modmesh/buffer/SimpleArray.cpp index ec785a97..8720b40a 100644 --- a/cpp/modmesh/buffer/SimpleArray.cpp +++ b/cpp/modmesh/buffer/SimpleArray.cpp @@ -27,6 +27,7 @@ */ #include +#include #include @@ -65,7 +66,9 @@ static std::unordered_map string_data_typ {"uint32", DataType::Uint32}, {"uint64", DataType::Uint64}, {"float32", DataType::Float32}, - {"float64", DataType::Float64}}; + {"float64", DataType::Float64}, + {"ComplexFloat32", DataType::ComplexFloat32}, + {"ComplexFloat64", DataType::ComplexFloat64}}; } /* end namespace detail */ @@ -145,6 +148,18 @@ DataType DataType::from() return DataType::Float64; } +template <> +DataType DataType::from>() +{ + return DataType::ComplexFloat32; +} + +template <> +DataType DataType::from>() +{ + return DataType::ComplexFloat64; +} + // According to the `DataType`, create the corresponding `SimpleArray` instance // and assign it to `m_instance_ptr`. The `m_instance_ptr` is a void pointer, so // we need to use `reinterpret_cast` to convert the pointer of the array instance. @@ -171,6 +186,8 @@ SimpleArrayPlex::SimpleArrayPlex(const shape_type & shape, const DataType data_t DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Uint64, SimpleArrayUint64, shape) DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float32, SimpleArrayFloat32, shape) DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float64, SimpleArrayFloat64, shape) + DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32, shape) + DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64, shape) default: throw std::invalid_argument("Unsupported datatype"); } @@ -193,6 +210,8 @@ SimpleArrayPlex::SimpleArrayPlex(const shape_type & shape, const std::shared_ptr DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Uint64, SimpleArrayUint64, shape, buffer) DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float32, SimpleArrayFloat32, shape, buffer) DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float64, SimpleArrayFloat64, shape, buffer) + DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32, shape, buffer) + DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64, shape, buffer) default: throw std::invalid_argument("Unsupported datatype"); } @@ -289,6 +308,20 @@ SimpleArrayPlex::SimpleArrayPlex(SimpleArrayPlex const & other) m_instance_ptr = reinterpret_cast(new SimpleArrayFloat64(*array)); break; } + case DataType::ComplexFloat32: + { + const auto * array = static_cast(other.m_instance_ptr); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + m_instance_ptr = reinterpret_cast(new SimpleArrayComplexFloat32(*array)); + break; + } + case DataType::ComplexFloat64: + { + const auto * array = static_cast(other.m_instance_ptr); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + m_instance_ptr = reinterpret_cast(new SimpleArrayComplexFloat64(*array)); + break; + } default: { throw std::invalid_argument("Unsupported datatype"); @@ -406,6 +439,20 @@ SimpleArrayPlex & SimpleArrayPlex::operator=(SimpleArrayPlex const & other) m_instance_ptr = reinterpret_cast(new SimpleArrayFloat64(*array)); break; } + case DataType::ComplexFloat32: + { + const auto * array = static_cast(other.m_instance_ptr); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + m_instance_ptr = reinterpret_cast(new SimpleArrayComplexFloat32(*array)); + break; + } + case DataType::ComplexFloat64: + { + const auto * array = static_cast(other.m_instance_ptr); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + m_instance_ptr = reinterpret_cast(new SimpleArrayComplexFloat64(*array)); + break; + } default: { throw std::invalid_argument("Unsupported datatype"); @@ -506,6 +553,18 @@ SimpleArrayPlex::~SimpleArrayPlex() delete reinterpret_cast(m_instance_ptr); break; } + case DataType::ComplexFloat32: + { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + delete reinterpret_cast(m_instance_ptr); + break; + } + case DataType::ComplexFloat64: + { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + delete reinterpret_cast(m_instance_ptr); + break; + } default: break; } diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 7e398b05..dcfe4af5 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -29,6 +29,7 @@ */ #include +#include #include #include @@ -160,7 +161,16 @@ class SimpleArrayMixinCalculators value_type sum() const { - value_type initial = 0; + value_type initial; + if constexpr (is_complex_v) + { + initial = value_type(); + } + else + { + initial = 0; + } + auto athis = static_cast(this); if constexpr (!std::is_same_v>) { @@ -748,6 +758,8 @@ using SimpleArrayUint32 = SimpleArray; using SimpleArrayUint64 = SimpleArray; using SimpleArrayFloat32 = SimpleArray; using SimpleArrayFloat64 = SimpleArray; +using SimpleArrayComplexFloat32 = SimpleArray>; +using SimpleArrayComplexFloat64 = SimpleArray>; class DataType { @@ -766,6 +778,8 @@ class DataType Uint64, Float32, Float64, + ComplexFloat32, + ComplexFloat64 }; /* end enum enum_type */ DataType() = default; diff --git a/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp b/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp index 9bb105af..55ac52c2 100644 --- a/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp +++ b/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp @@ -29,6 +29,7 @@ #include #include +#include /** * The purpose of including this header is to facilitate implicit casting of @@ -97,6 +98,8 @@ DECL_MM_SIMPLE_ARRAY_CASTER(Uint32); DECL_MM_SIMPLE_ARRAY_CASTER(Uint64); DECL_MM_SIMPLE_ARRAY_CASTER(Float32); DECL_MM_SIMPLE_ARRAY_CASTER(Float64); +DECL_MM_SIMPLE_ARRAY_CASTER(ComplexFloat32); +DECL_MM_SIMPLE_ARRAY_CASTER(ComplexFloat64); #undef DECL_MM_SIMPLE_ARRAY_CASTER diff --git a/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp b/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp index 60cb7d8c..e2005c23 100644 --- a/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp +++ b/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp @@ -29,6 +29,7 @@ */ #include +#include #include #include // Must be the first include. @@ -71,8 +72,17 @@ struct TypeBroadcastImpl offset_out += arr_out.stride(it) * sidx[it] * step; } - // NOLINTNEXTLINE(bugprone-signed-char-misuse, cert-str34-c) - arr_out.at(offset_out) = static_cast(*ptr_in); + constexpr bool valid_conversion = (!is_complex_v && !is_complex_v) || (is_complex_v && is_complex_v && std::is_same_v); + + if constexpr (valid_conversion) + { + arr_out.at(offset_out) = static_cast(*ptr_in); + } + else + { + throw std::runtime_error("Cannot convert between complex and non-complex types"); + } + // recursion here copy_idx(arr_out, slices, arr_in, left_shape, sidx, dim - 1); } @@ -197,6 +207,14 @@ struct TypeBroadcast { TypeBroadcastImpl::broadcast(arr_out, slices, arr_in); } + else if (dtype_is_type>(arr_in)) + { + TypeBroadcastImpl>::broadcast(arr_out, slices, arr_in); + } + else if (dtype_is_type>(arr_in)) + { + TypeBroadcastImpl>::broadcast(arr_out, slices, arr_in); + } else { throw std::runtime_error("input array data type not support!"); diff --git a/cpp/modmesh/buffer/pymod/array_common.hpp b/cpp/modmesh/buffer/pymod/array_common.hpp index 0f75f3e7..f2e3172b 100644 --- a/cpp/modmesh/buffer/pymod/array_common.hpp +++ b/cpp/modmesh/buffer/pymod/array_common.hpp @@ -32,6 +32,7 @@ #include #include +#include // We faced an issue where the template specialization for the caster of // SimpleArray doesn't function correctly on both macOS and Windows. @@ -98,7 +99,7 @@ class ArrayPropertyHelper const py::object & py_key = args[0]; const py::object & py_value = args[1]; - const bool is_number = py::isinstance(py_value) || py::isinstance(py_value) || py::isinstance(py_value); + const bool is_number = py::isinstance(py_value) || py::isinstance(py_value) || py::isinstance(py_value) || is_complex_v; // sarr[K] = V if (py::isinstance(py_key) && is_number) diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index 24621801..2a5e8d36 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -255,6 +255,8 @@ void wrap_SimpleArray(pybind11::module & mod) WrapSimpleArray::commit(mod, "SimpleArrayUint64", "SimpleArrayUint64"); WrapSimpleArray::commit(mod, "SimpleArrayFloat32", "SimpleArrayFloat32"); WrapSimpleArray::commit(mod, "SimpleArrayFloat64", "SimpleArrayFloat64"); + WrapSimpleArray>::commit(mod, "SimpleArrayComplexFloat32", "SimpleArrayComplexFloat32"); + WrapSimpleArray>::commit(mod, "SimpleArrayComplexFloat64", "SimpleArrayComplexFloat64"); WrapSimpleCollector::commit(mod, "SimpleCollectorBool", "SimpleCollectorBool"); WrapSimpleCollector::commit(mod, "SimpleCollectorInt8", "SimpleCollectorInt8"); @@ -267,6 +269,8 @@ void wrap_SimpleArray(pybind11::module & mod) WrapSimpleCollector::commit(mod, "SimpleCollectorUint64", "SimpleCollectorUint64"); WrapSimpleCollector::commit(mod, "SimpleCollectorFloat32", "SimpleCollectorFloat32"); WrapSimpleCollector::commit(mod, "SimpleCollectorFloat64", "SimpleCollectorFloat64"); + WrapSimpleCollector>::commit(mod, "SimpleCollectorComplexFloat32", "SimpleCollectorComplexFloat32"); + WrapSimpleCollector>::commit(mod, "SimpleCollectorComplexFloat64", "SimpleCollectorComplexFloat64"); } } /* end namespace python */ diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArrayPlex.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArrayPlex.cpp index 25b0fbe9..a5a414d1 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArrayPlex.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArrayPlex.cpp @@ -70,6 +70,8 @@ static auto execute_callback_with_typed_array(A & arrayplex, C && callback) DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::Uint64, SimpleArrayUint64) DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::Float32, SimpleArrayFloat32) DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::Float64, SimpleArrayFloat64) + DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32) + DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64) default: { throw std::invalid_argument("Unsupported datatype"); @@ -117,6 +119,21 @@ static void verify_python_value_datatype(pybind11::object const & value, DataTyp } break; } + case DataType::ComplexFloat32: + { + if (!pybind11::isinstance>(value)) + { + throw pybind11::type_error("Data type mismatch, expected complex float"); + } + } + case DataType::ComplexFloat64: + { + if (!pybind11::isinstance>(value)) + { + throw pybind11::type_error("Data type mismatch, expected complex double"); + } + } + default: throw std::runtime_error("Unsupported datatype"); } @@ -148,6 +165,8 @@ static pybind11::object get_typed_array_value(const SimpleArrayPlex & array_plex DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::Uint64, SimpleArrayUint64) DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::Float32, SimpleArrayFloat32) DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::Float64, SimpleArrayFloat64) + DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::ComplexFloat32, SimpleArrayComplexFloat32) + DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::ComplexFloat64, SimpleArrayComplexFloat64) default: { throw std::runtime_error("Unsupported datatype"); @@ -180,6 +199,8 @@ static pybind11::object get_typed_array(const SimpleArrayPlex & array_plex) DECL_MM_GET_TYPED_ARRAY(DataType::Uint64, SimpleArrayUint64) DECL_MM_GET_TYPED_ARRAY(DataType::Float32, SimpleArrayFloat32) DECL_MM_GET_TYPED_ARRAY(DataType::Float64, SimpleArrayFloat64) + DECL_MM_GET_TYPED_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32) + DECL_MM_GET_TYPED_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64) default: { throw std::runtime_error("Unsupported datatype"); diff --git a/cpp/modmesh/math/Complex.hpp b/cpp/modmesh/math/Complex.hpp index f521950f..8c0aa8ee 100644 --- a/cpp/modmesh/math/Complex.hpp +++ b/cpp/modmesh/math/Complex.hpp @@ -28,6 +28,7 @@ #include #include +#include namespace modmesh { @@ -40,50 +41,68 @@ struct ComplexImpl T real_v; T imag_v; - explicit ComplexImpl() - : ComplexImpl(0.0, 0.0) + ComplexImpl operator+(const ComplexImpl & other) const { + ComplexImpl ret(*this); + return ret += other; } - explicit ComplexImpl(T r, T i) - : real_v(r) - , imag_v(i) + ComplexImpl operator-(const ComplexImpl & other) const { + ComplexImpl ret(*this); + return ret -= other; } - explicit ComplexImpl(const ComplexImpl & other) - : real_v(other.real_v) - , imag_v(other.imag_v) + ComplexImpl operator*(const ComplexImpl & other) const { + ComplexImpl ret(*this); + ret *= other; + return ret; } - ComplexImpl operator+(const ComplexImpl & other) const + ComplexImpl operator/(const ComplexImpl & other) const { - return ComplexImpl(real_v + other.real_v, imag_v + other.imag_v); + ComplexImpl ret(*this); + return ret /= other; } - ComplexImpl operator-(const ComplexImpl & other) const + ComplexImpl operator/(const T & other) const { - return ComplexImpl(real_v - other.real_v, imag_v - other.imag_v); + ComplexImpl ret(*this); + return ret /= other; } - ComplexImpl operator*(const ComplexImpl & other) const + ComplexImpl & operator*=(const ComplexImpl & rhs) { - return ComplexImpl(real_v * other.real_v - imag_v * other.imag_v, real_v * other.imag_v + imag_v * other.real_v); + T real_v_copy = real_v; + real_v = real_v * rhs.real_v - imag_v * rhs.imag_v; + imag_v = real_v_copy * rhs.imag_v + imag_v * rhs.real_v; + return *this; } - ComplexImpl operator/(const T & other) const + ComplexImpl & operator/=(const ComplexImpl & rhs) { - return ComplexImpl(real_v / other, imag_v / other); + T denominator = rhs.norm(); + T real_v_copy = real_v; + + if (denominator == 0.0) + { + throw std::runtime_error("Division by zero in complex number"); + } + + real_v = (real_v * rhs.real_v + imag_v * rhs.imag_v) / denominator; + imag_v = (imag_v * rhs.real_v - real_v_copy * rhs.imag_v) / denominator; + return *this; } - ComplexImpl & operator=(const ComplexImpl & other) + ComplexImpl & operator/=(const T & rhs) { - if (this != &other) // Check for self-assignment + if (rhs == 0.0) { - real_v = other.real_v; - imag_v = other.imag_v; + throw std::runtime_error("Division by zero in complex number"); } + real_v /= rhs; + imag_v /= rhs; return *this; } @@ -101,6 +120,16 @@ struct ComplexImpl return *this; } + bool operator<(const ComplexImpl & rhs) + { + return this->norm() < rhs.norm(); + } + + bool operator>(const ComplexImpl & rhs) + { + return this->norm() > rhs.norm(); + } + T real() const { return real_v; } T imag() const { return imag_v; } T norm() const { return real_v * real_v + imag_v * imag_v; } @@ -111,6 +140,29 @@ struct ComplexImpl template using Complex = detail::ComplexImpl; +template +bool operator<(const Complex & lhs, const Complex & rhs) +{ + return lhs.norm() < rhs.norm(); +} + +template +bool operator>(const Complex & lhs, const Complex & rhs) +{ + return lhs.norm() > rhs.norm(); +} + +// clang-format off +template +struct is_complex : std::false_type {}; + +template +struct is_complex> : std::true_type {}; +// clang-format on + +template +constexpr bool is_complex_v = is_complex::value; + } /* end namespace modmesh */ // vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: diff --git a/cpp/modmesh/math/pymod/wrap_Complex.cpp b/cpp/modmesh/math/pymod/wrap_Complex.cpp index 25548dfa..8e71e82b 100644 --- a/cpp/modmesh/math/pymod/wrap_Complex.cpp +++ b/cpp/modmesh/math/pymod/wrap_Complex.cpp @@ -30,6 +30,7 @@ #include +#include #include namespace modmesh @@ -52,31 +53,42 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapComplex { namespace py = pybind11; // NOLINT(misc-unused-alias-decls) + PYBIND11_NUMPY_DTYPE(wrapped_type, real_v, imag_v); + (*this) - .def( - py::init( - [](const value_type & real_v, const value_type & imag_v) - { return std::make_shared(real_v, imag_v); }), - py::arg("real_v"), - py::arg("imag_v")) .def( py::init( []() { return std::make_shared(); })) + .def( + py::init( + [](const value_type & real, const value_type & imag) + { return std::make_shared(wrapped_type{.real_v = real, .imag_v = imag}); }), + py::arg("real"), + py::arg("imag")) .def( py::init( [](const wrapped_type & other) - { return std::make_shared(other); }), + { return std::make_shared(wrapped_type{.real_v = other.real_v, .imag_v = other.imag_v}); }), py::arg("other")) .def(py::self + py::self) // NOLINT(misc-redundant-expression) .def(py::self - py::self) // NOLINT(misc-redundant-expression) .def(py::self * py::self) // NOLINT(misc-redundant-expression) + .def(py::self / py::self) // NOLINT(misc-redundant-expression) .def(py::self / value_type()) // NOLINT(misc-redundant-expression) .def(py::self += py::self) // NOLINT(misc-redundant-expression) .def(py::self -= py::self) // NOLINT(misc-redundant-expression) - .def_property_readonly("real", &wrapped_type::real) - .def_property_readonly("imag", &wrapped_type::imag) + .def(py::self *= py::self) // NOLINT(misc-redundant-expression) + .def(py::self /= py::self) // NOLINT(misc-redundant-expression) + .def(py::self /= value_type()) // NOLINT(misc-redundant-expression) + .def("__lt__", &wrapped_type::operator<) + .def("__gt__", &wrapped_type::operator>) + .def_readonly("real", &wrapped_type::real_v) + .def_readonly("imag", &wrapped_type::imag_v) .def("norm", &wrapped_type::norm) + .def("dtype", + []() + { return py::dtype::of(); }) .def("__complex__", [](wrapped_type const & self) { return std::complex(self.real(), self.imag()); }); diff --git a/cpp/modmesh/python/common.hpp b/cpp/modmesh/python/common.hpp index d2e9cff7..d5b55a47 100644 --- a/cpp/modmesh/python/common.hpp +++ b/cpp/modmesh/python/common.hpp @@ -60,7 +60,7 @@ std::string to_str(T const & self) { return Formatter() << self >> Formatter::to template bool dtype_is_type(pybind11::array const & arr) { - return pybind11::detail::npy_format_descriptor::dtype().is(arr.dtype()); + return pybind11::detail::npy_format_descriptor::dtype().equal(arr.dtype()); } class WrapperProfilerStatus diff --git a/cpp/modmesh/transform/fourier.hpp b/cpp/modmesh/transform/fourier.hpp index 6854c6ef..c1e37d7d 100644 --- a/cpp/modmesh/transform/fourier.hpp +++ b/cpp/modmesh/transform/fourier.hpp @@ -26,7 +26,7 @@ void dft(SimpleArray> const & in, SimpleArray> & out) for (size_t j = 0; j < N; ++j) { T2 tmp = -2.0 * pi * i * j / N; - T1 twiddle_factor(std::cos(tmp), std::sin(tmp)); + T1 twiddle_factor{.real_v = std::cos(tmp), .imag_v = std::sin(tmp)}; out[i] += in[j] * twiddle_factor; } @@ -59,7 +59,7 @@ void fft(SimpleArray> const & in, SimpleArray> & out) { // Twiddle factor = exp(-2 * pi * i * k / N) T2 angle = angle_inc * k; - T1 twiddle_factor(std::cos(angle), std::sin(angle)); + T1 twiddle_factor{.real_v = std::cos(angle), .imag_v = std::sin(angle)}; T1 even(out[i + k]); T1 odd(out[i + k + half_size] * twiddle_factor); diff --git a/gtests/test_nopython_transform.cpp b/gtests/test_nopython_transform.cpp index 0b161860..da9bb746 100644 --- a/gtests/test_nopython_transform.cpp +++ b/gtests/test_nopython_transform.cpp @@ -14,9 +14,9 @@ class ParsevalTest : public ::testing::Test const size_t VN = 1024; modmesh::SimpleArray> signal{ - modmesh::small_vector{VN}, modmesh::Complex(0.0, 0.0)}; + modmesh::small_vector{VN}, modmesh::Complex{.real_v = 0.0, .imag_v = 0.0}}; modmesh::SimpleArray> out{ - modmesh::small_vector{VN}, modmesh::Complex(0.0, 0.0)}; + modmesh::small_vector{VN}, modmesh::Complex{.real_v = 0.0, .imag_v = 0.0}}; // Set up the test fixture: generate the signal once void SetUp() override @@ -26,7 +26,7 @@ class ParsevalTest : public ::testing::Test for (unsigned int i = 0; i < VN; ++i) { T val = val_dist(rng); - signal[i] = modmesh::Complex(val, 0.0); + signal[i] = modmesh::Complex{.real_v = val, .imag_v = 0.0}; } } @@ -59,13 +59,13 @@ class DeltaFunctionTest : public ::testing::Test std::mt19937 rng{std::random_device{}()}; modmesh::SimpleArray> signal{ - modmesh::small_vector{VN}, modmesh::Complex(0.0, 0.0)}; + modmesh::small_vector{VN}, modmesh::Complex{.real_v = 0.0, .imag_v = 0.0}}; modmesh::SimpleArray> out{ - modmesh::small_vector{VN}, modmesh::Complex(0.0, 0.0)}; + modmesh::small_vector{VN}, modmesh::Complex{.real_v = 0.0, .imag_v = 0.0}}; void SetUp() override { - signal[0] = modmesh::Complex(1.0, 0.0); + signal[0] = modmesh::Complex{.real_v = 1.0, .imag_v = 0.0}; } void verify_delta_function() diff --git a/modmesh/core.py b/modmesh/core.py index 6adde8ea..3944f87e 100644 --- a/modmesh/core.py +++ b/modmesh/core.py @@ -65,6 +65,8 @@ 'SimpleArrayUint64', 'SimpleArrayFloat32', 'SimpleArrayFloat64', + 'SimpleArrayComplexFloat32', + 'SimpleArrayComplexFloat64', 'SimpleCollectorBool', 'SimpleCollectorInt8', 'SimpleCollectorInt16', diff --git a/tests/test_math.py b/tests/test_math.py index c20d61a4..af186db5 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -175,6 +175,58 @@ def test_operator_div_float64_scalar(self): self.assert_allclose64(result.real, expected_real) self.assert_allclose64(result.imag, expected_imag) + def test_operator_div_float32(self): + cplx1 = mm.ComplexFloat32(self.real1_32, self.imag1_32) + cplx2 = mm.ComplexFloat32(self.real2_32, self.imag2_32) + + result = cplx1 / cplx2 + + denominator = (self.real2_32 * self.real2_32 + self.imag2_32 * + self.imag2_32) + expected_real = (self.real1_32 * self.real2_32 + + self.imag1_32 * self.imag2_32) / denominator + expected_imag = (self.imag1_32 * self.real2_32 - + self.real1_32 * self.imag2_32) / denominator + + self.assert_allclose32(result.real, expected_real) + self.assert_allclose32(result.imag, expected_imag) + + def test_operator_div_float64(self): + cplx1 = mm.ComplexFloat64(self.real1_64, self.imag1_64) + cplx2 = mm.ComplexFloat64(self.real2_64, self.imag2_64) + + result = cplx1 / cplx2 + + denominator = (self.real2_64 * self.real2_64 + self.imag2_64 * + self.imag2_64) + expected_real = (self.real1_64 * self.real2_64 + self.imag1_64 * + self.imag2_64) / denominator + expected_imag = (self.imag1_64 * self.real2_64 - self.real1_64 * + self.imag2_64) / denominator + + self.assert_allclose64(result.real, expected_real) + self.assert_allclose64(result.imag, expected_imag) + + def test_operator_comparison_float32(self): + cplx1 = mm.ComplexFloat32(self.real1_32, self.imag1_32) + cplx2 = mm.ComplexFloat32(self.real2_32, self.imag2_32) + + norm1 = cplx1.norm() + norm2 = cplx2.norm() + + self.assertEqual(cplx1 < cplx2, norm1 < norm2) + self.assertEqual(cplx1 > cplx2, norm1 > norm2) + + def test_operator_comparison_float64(self): + cplx1 = mm.ComplexFloat64(self.real1_64, self.imag1_64) + cplx2 = mm.ComplexFloat64(self.real2_64, self.imag2_64) + + norm1 = cplx1.norm() + norm2 = cplx2.norm() + + self.assertEqual(cplx1 < cplx2, norm1 < norm2) + self.assertEqual(cplx1 > cplx2, norm1 > norm2) + def test_norm_float32(self): cplx = mm.ComplexFloat32(self.real1_32, self.imag1_32) @@ -192,3 +244,43 @@ def test_norm_float64(self): expected_val = self.real1_64 ** 2 + self.imag1_64 ** 2 self.assert_allclose64(result, expected_val) + + def test_dtype_verification_float32(self): + dtype = mm.ComplexFloat32.dtype() + expected_dtype = np.dtype([('real_v', np.float32), + ('imag_v', np.float32)]) + + self.assertEqual(dtype, expected_dtype) + + def test_dtype_verification_float64(self): + dtype = mm.ComplexFloat64.dtype() + expected_dtype = np.dtype([('real_v', np.float64), + ('imag_v', np.float64)]) + + self.assertEqual(dtype, expected_dtype) + + def test_complex_array_float32(self): + cplx = mm.ComplexFloat32(self.real1_32, self.imag1_32) + sarr = mm.SimpleArrayComplexFloat32(10) + sarr.fill(cplx) + ndarr = np.array(sarr, copy=False, dtype=mm.ComplexFloat32.dtype()) + + self.assertEqual(ndarr.dtype, mm.ComplexFloat32.dtype()) + + sarr = mm.SimpleArrayComplexFloat32(array=ndarr) + + self.assertEqual(sarr.ndarray.dtype, ndarr.dtype) + self.assertEqual(10 * 4 * 2, sarr.nbytes) + + def test_complex_array_float64(self): + cplx = mm.ComplexFloat64(self.real1_64, self.imag1_64) + sarr = mm.SimpleArrayComplexFloat64(10) + sarr.fill(cplx) + ndarr = np.array(sarr, copy=False, dtype=mm.ComplexFloat64.dtype()) + + self.assertEqual(ndarr.dtype, mm.ComplexFloat64.dtype()) + + sarr = mm.SimpleArrayComplexFloat64(array=ndarr) + + self.assertEqual(sarr.ndarray.dtype, ndarr.dtype) + self.assertEqual(10 * 8 * 2, sarr.nbytes)