Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add complex array support #468

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ endif()

include(Flake8)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I have used designate initialiser, this feature is supported by c++20, therefore change it to c++ standard 20.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason that you have to use designated initializer? Please let us know what it is.

If there is not, then just don't use designated initializer. I have no problem using it, but it alone is not a good reason to upgrade to C++20.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just want to mark the initial value for each structure member during structure initialization, but I did not consider what side effects upgrading to C++20 may bring.

set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
Expand Down
2 changes: 2 additions & 0 deletions contrib/standalone_buffer/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,7 @@ copy:
cp $(MMROOT)/cpp/modmesh/python/common.hpp $(SETUPROOT)/modmesh/python/common.hpp
cp -a $(MMROOT)/cpp/modmesh/toggle $(SETUPROOT)/modmesh/toggle
rm -rf $(SETUPROOT)/modmesh/toggle/pymod
cp -a $(MMROOT)/cpp/modmesh/math $(SETUPROOT)/modmesh/math
rm -rf $(SETUPROOT)/modmesh/math/pymod
cp $(MMROOT)/cpp/modmesh/base.hpp $(SETUPROOT)/modmesh/
find $(SETUPROOT) -name CMakeLists.txt -delete
61 changes: 60 additions & 1 deletion cpp/modmesh/buffer/SimpleArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
*/

#include <modmesh/buffer/SimpleArray.hpp>
#include <modmesh/math/math.hpp>

#include <unordered_map>

Expand Down Expand Up @@ -65,7 +66,9 @@ static std::unordered_map<std::string, DataType, DataTypeHasher> string_data_typ
{"uint32", DataType::Uint32},
{"uint64", DataType::Uint64},
{"float32", DataType::Float32},
{"float64", DataType::Float64}};
{"float64", DataType::Float64},
{"ComplexFloat32", DataType::ComplexFloat32},
Copy link
Collaborator

@tigercosmos tigercosmos Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just call it "complex32".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea to drop Float from the complex type name, but we should use the correct bit count. For a complex number using two 32-bit float, it should be called complex64. A complex using two 64-bit float, it should be called complex128.

It follows numpy convention: https://numpy.org/doc/stable/user/basics.types.html#relationship-between-numpy-data-types-and-c-data-types

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redarding this: #468 (comment)
@j8xixo12 I meant you should change the name here.

{"ComplexFloat64", DataType::ComplexFloat64}};

} /* end namespace detail */

Expand Down Expand Up @@ -145,6 +148,18 @@ DataType DataType::from<double>()
return DataType::Float64;
}

template <>
DataType DataType::from<Complex<float>>()
{
return DataType::ComplexFloat32;
}

template <>
DataType DataType::from<Complex<double>>()
{
return DataType::ComplexFloat64;
}

// According to the `DataType`, create the corresponding `SimpleArray<T>` instance
// and assign it to `m_instance_ptr`. The `m_instance_ptr` is a void pointer, so
// we need to use `reinterpret_cast` to convert the pointer of the array instance.
Expand All @@ -171,6 +186,8 @@ SimpleArrayPlex::SimpleArrayPlex(const shape_type & shape, const DataType data_t
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Uint64, SimpleArrayUint64, shape)
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float32, SimpleArrayFloat32, shape)
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float64, SimpleArrayFloat64, shape)
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32, shape)
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64, shape)
default:
throw std::invalid_argument("Unsupported datatype");
}
Expand All @@ -193,6 +210,8 @@ SimpleArrayPlex::SimpleArrayPlex(const shape_type & shape, const std::shared_ptr
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Uint64, SimpleArrayUint64, shape, buffer)
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float32, SimpleArrayFloat32, shape, buffer)
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float64, SimpleArrayFloat64, shape, buffer)
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32, shape, buffer)
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64, shape, buffer)
default:
throw std::invalid_argument("Unsupported datatype");
}
Expand Down Expand Up @@ -289,6 +308,20 @@ SimpleArrayPlex::SimpleArrayPlex(SimpleArrayPlex const & other)
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayFloat64(*array));
break;
}
case DataType::ComplexFloat32:
{
const auto * array = static_cast<SimpleArrayComplexFloat32 *>(other.m_instance_ptr);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayComplexFloat32(*array));
break;
}
case DataType::ComplexFloat64:
{
const auto * array = static_cast<SimpleArrayComplexFloat64 *>(other.m_instance_ptr);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayComplexFloat64(*array));
break;
}
default:
{
throw std::invalid_argument("Unsupported datatype");
Expand Down Expand Up @@ -406,6 +439,20 @@ SimpleArrayPlex & SimpleArrayPlex::operator=(SimpleArrayPlex const & other)
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayFloat64(*array));
break;
}
case DataType::ComplexFloat32:
{
const auto * array = static_cast<SimpleArrayComplexFloat32 *>(other.m_instance_ptr);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayComplexFloat32(*array));
break;
}
case DataType::ComplexFloat64:
{
const auto * array = static_cast<SimpleArrayComplexFloat64 *>(other.m_instance_ptr);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayComplexFloat64(*array));
break;
}
default:
{
throw std::invalid_argument("Unsupported datatype");
Expand Down Expand Up @@ -506,6 +553,18 @@ SimpleArrayPlex::~SimpleArrayPlex()
delete reinterpret_cast<SimpleArrayFloat64 *>(m_instance_ptr);
break;
}
case DataType::ComplexFloat32:
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
delete reinterpret_cast<SimpleArrayComplexFloat32 *>(m_instance_ptr);
break;
}
case DataType::ComplexFloat64:
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
delete reinterpret_cast<SimpleArrayComplexFloat64 *>(m_instance_ptr);
break;
}
default:
break;
}
Expand Down
16 changes: 15 additions & 1 deletion cpp/modmesh/buffer/SimpleArray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
*/

#include <modmesh/buffer/ConcreteBuffer.hpp>
#include <modmesh/math/math.hpp>

#include <limits>
#include <stdexcept>
Expand Down Expand Up @@ -160,7 +161,16 @@ class SimpleArrayMixinCalculators

value_type sum() const
{
value_type initial = 0;
value_type initial;
if constexpr (is_complex_v<value_type>)
{
initial = value_type();
}
else
{
initial = 0;
}
Comment on lines +165 to +172
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the element type is complex, different initialization should be performed here.


auto athis = static_cast<A const *>(this);
if constexpr (!std::is_same_v<bool, std::remove_const_t<value_type>>)
{
Expand Down Expand Up @@ -748,6 +758,8 @@ using SimpleArrayUint32 = SimpleArray<uint32_t>;
using SimpleArrayUint64 = SimpleArray<uint64_t>;
using SimpleArrayFloat32 = SimpleArray<float>;
using SimpleArrayFloat64 = SimpleArray<double>;
using SimpleArrayComplexFloat32 = SimpleArray<Complex<float>>;
using SimpleArrayComplexFloat64 = SimpleArray<Complex<double>>;

class DataType
{
Expand All @@ -766,6 +778,8 @@ class DataType
Uint64,
Float32,
Float64,
ComplexFloat32,
ComplexFloat64
}; /* end enum enum_type */

DataType() = default;
Expand Down
3 changes: 3 additions & 0 deletions cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <pybind11/pybind11.h>

#include <modmesh/buffer/buffer.hpp>
#include <modmesh/math/math.hpp>

/**
* The purpose of including this header is to facilitate implicit casting of
Expand Down Expand Up @@ -97,6 +98,8 @@ DECL_MM_SIMPLE_ARRAY_CASTER(Uint32);
DECL_MM_SIMPLE_ARRAY_CASTER(Uint64);
DECL_MM_SIMPLE_ARRAY_CASTER(Float32);
DECL_MM_SIMPLE_ARRAY_CASTER(Float64);
DECL_MM_SIMPLE_ARRAY_CASTER(ComplexFloat32);
DECL_MM_SIMPLE_ARRAY_CASTER(ComplexFloat64);

#undef DECL_MM_SIMPLE_ARRAY_CASTER

Expand Down
22 changes: 20 additions & 2 deletions cpp/modmesh/buffer/pymod/TypeBroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
*/

#include <modmesh/buffer/SimpleArray.hpp>
#include <modmesh/math/math.hpp>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h> // Must be the first include.

Expand Down Expand Up @@ -71,8 +72,17 @@ struct TypeBroadcastImpl
offset_out += arr_out.stride(it) * sidx[it] * step;
}

// NOLINTNEXTLINE(bugprone-signed-char-misuse, cert-str34-c)
arr_out.at(offset_out) = static_cast<out_type>(*ptr_in);
constexpr bool valid_conversion = (!is_complex_v<T> && !is_complex_v<D>) || (is_complex_v<T> && is_complex_v<D> && std::is_same_v<T, D>);

if constexpr (valid_conversion)
{
arr_out.at(offset_out) = static_cast<out_type>(*ptr_in);
}
else
{
throw std::runtime_error("Cannot convert between complex and non-complex types");
}
Comment on lines +75 to +84
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, only support float complex to float complex and double complex to double complex conversion.


// recursion here
copy_idx(arr_out, slices, arr_in, left_shape, sidx, dim - 1);
}
Expand Down Expand Up @@ -197,6 +207,14 @@ struct TypeBroadcast
{
TypeBroadcastImpl<T, double>::broadcast(arr_out, slices, arr_in);
}
else if (dtype_is_type<Complex<float>>(arr_in))
{
TypeBroadcastImpl<T, Complex<float>>::broadcast(arr_out, slices, arr_in);
}
else if (dtype_is_type<Complex<double>>(arr_in))
{
TypeBroadcastImpl<T, Complex<double>>::broadcast(arr_out, slices, arr_in);
}
else
{
throw std::runtime_error("input array data type not support!");
Expand Down
3 changes: 2 additions & 1 deletion cpp/modmesh/buffer/pymod/array_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include <modmesh/buffer/SimpleArray.hpp>
#include <modmesh/buffer/pymod/TypeBroadcast.hpp>
#include <modmesh/math/math.hpp>

// We faced an issue where the template specialization for the caster of
// SimpleArray<T> doesn't function correctly on both macOS and Windows.
Expand Down Expand Up @@ -98,7 +99,7 @@ class ArrayPropertyHelper
const py::object & py_key = args[0];
const py::object & py_value = args[1];

const bool is_number = py::isinstance<py::bool_>(py_value) || py::isinstance<py::int_>(py_value) || py::isinstance<py::float_>(py_value);
const bool is_number = py::isinstance<py::bool_>(py_value) || py::isinstance<py::int_>(py_value) || py::isinstance<py::float_>(py_value) || is_complex_v<T>;

// sarr[K] = V
if (py::isinstance<py::int_>(py_key) && is_number)
Expand Down
4 changes: 4 additions & 0 deletions cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ void wrap_SimpleArray(pybind11::module & mod)
WrapSimpleArray<uint64_t>::commit(mod, "SimpleArrayUint64", "SimpleArrayUint64");
WrapSimpleArray<float>::commit(mod, "SimpleArrayFloat32", "SimpleArrayFloat32");
WrapSimpleArray<double>::commit(mod, "SimpleArrayFloat64", "SimpleArrayFloat64");
WrapSimpleArray<Complex<float>>::commit(mod, "SimpleArrayComplexFloat32", "SimpleArrayComplexFloat32");
WrapSimpleArray<Complex<double>>::commit(mod, "SimpleArrayComplexFloat64", "SimpleArrayComplexFloat64");

WrapSimpleCollector<bool>::commit(mod, "SimpleCollectorBool", "SimpleCollectorBool");
WrapSimpleCollector<int8_t>::commit(mod, "SimpleCollectorInt8", "SimpleCollectorInt8");
Expand All @@ -267,6 +269,8 @@ void wrap_SimpleArray(pybind11::module & mod)
WrapSimpleCollector<uint64_t>::commit(mod, "SimpleCollectorUint64", "SimpleCollectorUint64");
WrapSimpleCollector<float>::commit(mod, "SimpleCollectorFloat32", "SimpleCollectorFloat32");
WrapSimpleCollector<double>::commit(mod, "SimpleCollectorFloat64", "SimpleCollectorFloat64");
WrapSimpleCollector<Complex<float>>::commit(mod, "SimpleCollectorComplexFloat32", "SimpleCollectorComplexFloat32");
WrapSimpleCollector<Complex<double>>::commit(mod, "SimpleCollectorComplexFloat64", "SimpleCollectorComplexFloat64");
}

} /* end namespace python */
Expand Down
21 changes: 21 additions & 0 deletions cpp/modmesh/buffer/pymod/wrap_SimpleArrayPlex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ static auto execute_callback_with_typed_array(A & arrayplex, C && callback)
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::Uint64, SimpleArrayUint64)
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::Float32, SimpleArrayFloat32)
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::Float64, SimpleArrayFloat64)
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32)
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64)
default:
{
throw std::invalid_argument("Unsupported datatype");
Expand Down Expand Up @@ -117,6 +119,21 @@ static void verify_python_value_datatype(pybind11::object const & value, DataTyp
}
break;
}
case DataType::ComplexFloat32:
{
if (!pybind11::isinstance<Complex<float>>(value))
{
throw pybind11::type_error("Data type mismatch, expected complex float");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message should use the type name.

}
}
case DataType::ComplexFloat64:
{
if (!pybind11::isinstance<Complex<double>>(value))
{
throw pybind11::type_error("Data type mismatch, expected complex double");
}
}

default:
throw std::runtime_error("Unsupported datatype");
}
Expand Down Expand Up @@ -148,6 +165,8 @@ static pybind11::object get_typed_array_value(const SimpleArrayPlex & array_plex
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::Uint64, SimpleArrayUint64)
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::Float32, SimpleArrayFloat32)
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::Float64, SimpleArrayFloat64)
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::ComplexFloat32, SimpleArrayComplexFloat32)
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::ComplexFloat64, SimpleArrayComplexFloat64)
default:
{
throw std::runtime_error("Unsupported datatype");
Expand Down Expand Up @@ -180,6 +199,8 @@ static pybind11::object get_typed_array(const SimpleArrayPlex & array_plex)
DECL_MM_GET_TYPED_ARRAY(DataType::Uint64, SimpleArrayUint64)
DECL_MM_GET_TYPED_ARRAY(DataType::Float32, SimpleArrayFloat32)
DECL_MM_GET_TYPED_ARRAY(DataType::Float64, SimpleArrayFloat64)
DECL_MM_GET_TYPED_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32)
DECL_MM_GET_TYPED_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64)
default:
{
throw std::runtime_error("Unsupported datatype");
Expand Down
Loading
Loading