-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add complex array support #468
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
*/ | ||
|
||
#include <modmesh/buffer/SimpleArray.hpp> | ||
#include <modmesh/math/math.hpp> | ||
|
||
#include <unordered_map> | ||
|
||
|
@@ -65,7 +66,9 @@ static std::unordered_map<std::string, DataType, DataTypeHasher> string_data_typ | |
{"uint32", DataType::Uint32}, | ||
{"uint64", DataType::Uint64}, | ||
{"float32", DataType::Float32}, | ||
{"float64", DataType::Float64}}; | ||
{"float64", DataType::Float64}, | ||
{"ComplexFloat32", DataType::ComplexFloat32}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would just call it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like the idea to drop It follows numpy convention: https://numpy.org/doc/stable/user/basics.types.html#relationship-between-numpy-data-types-and-c-data-types There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Redarding this: #468 (comment) |
||
{"ComplexFloat64", DataType::ComplexFloat64}}; | ||
|
||
} /* end namespace detail */ | ||
|
||
|
@@ -145,6 +148,18 @@ DataType DataType::from<double>() | |
return DataType::Float64; | ||
} | ||
|
||
template <> | ||
DataType DataType::from<Complex<float>>() | ||
{ | ||
return DataType::ComplexFloat32; | ||
} | ||
|
||
template <> | ||
DataType DataType::from<Complex<double>>() | ||
{ | ||
return DataType::ComplexFloat64; | ||
} | ||
|
||
// According to the `DataType`, create the corresponding `SimpleArray<T>` instance | ||
// and assign it to `m_instance_ptr`. The `m_instance_ptr` is a void pointer, so | ||
// we need to use `reinterpret_cast` to convert the pointer of the array instance. | ||
|
@@ -171,6 +186,8 @@ SimpleArrayPlex::SimpleArrayPlex(const shape_type & shape, const DataType data_t | |
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Uint64, SimpleArrayUint64, shape) | ||
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float32, SimpleArrayFloat32, shape) | ||
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float64, SimpleArrayFloat64, shape) | ||
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32, shape) | ||
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64, shape) | ||
default: | ||
throw std::invalid_argument("Unsupported datatype"); | ||
} | ||
|
@@ -193,6 +210,8 @@ SimpleArrayPlex::SimpleArrayPlex(const shape_type & shape, const std::shared_ptr | |
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Uint64, SimpleArrayUint64, shape, buffer) | ||
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float32, SimpleArrayFloat32, shape, buffer) | ||
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::Float64, SimpleArrayFloat64, shape, buffer) | ||
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32, shape, buffer) | ||
DECL_MM_CREATE_SIMPLE_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64, shape, buffer) | ||
default: | ||
throw std::invalid_argument("Unsupported datatype"); | ||
} | ||
|
@@ -289,6 +308,20 @@ SimpleArrayPlex::SimpleArrayPlex(SimpleArrayPlex const & other) | |
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayFloat64(*array)); | ||
break; | ||
} | ||
case DataType::ComplexFloat32: | ||
{ | ||
const auto * array = static_cast<SimpleArrayComplexFloat32 *>(other.m_instance_ptr); | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) | ||
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayComplexFloat32(*array)); | ||
break; | ||
} | ||
case DataType::ComplexFloat64: | ||
{ | ||
const auto * array = static_cast<SimpleArrayComplexFloat64 *>(other.m_instance_ptr); | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) | ||
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayComplexFloat64(*array)); | ||
break; | ||
} | ||
default: | ||
{ | ||
throw std::invalid_argument("Unsupported datatype"); | ||
|
@@ -406,6 +439,20 @@ SimpleArrayPlex & SimpleArrayPlex::operator=(SimpleArrayPlex const & other) | |
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayFloat64(*array)); | ||
break; | ||
} | ||
case DataType::ComplexFloat32: | ||
{ | ||
const auto * array = static_cast<SimpleArrayComplexFloat32 *>(other.m_instance_ptr); | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) | ||
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayComplexFloat32(*array)); | ||
break; | ||
} | ||
case DataType::ComplexFloat64: | ||
{ | ||
const auto * array = static_cast<SimpleArrayComplexFloat64 *>(other.m_instance_ptr); | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) | ||
m_instance_ptr = reinterpret_cast<void *>(new SimpleArrayComplexFloat64(*array)); | ||
break; | ||
} | ||
default: | ||
{ | ||
throw std::invalid_argument("Unsupported datatype"); | ||
|
@@ -506,6 +553,18 @@ SimpleArrayPlex::~SimpleArrayPlex() | |
delete reinterpret_cast<SimpleArrayFloat64 *>(m_instance_ptr); | ||
break; | ||
} | ||
case DataType::ComplexFloat32: | ||
{ | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) | ||
delete reinterpret_cast<SimpleArrayComplexFloat32 *>(m_instance_ptr); | ||
break; | ||
} | ||
case DataType::ComplexFloat64: | ||
{ | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) | ||
delete reinterpret_cast<SimpleArrayComplexFloat64 *>(m_instance_ptr); | ||
break; | ||
} | ||
default: | ||
break; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
*/ | ||
|
||
#include <modmesh/buffer/ConcreteBuffer.hpp> | ||
#include <modmesh/math/math.hpp> | ||
|
||
#include <limits> | ||
#include <stdexcept> | ||
|
@@ -160,7 +161,16 @@ class SimpleArrayMixinCalculators | |
|
||
value_type sum() const | ||
{ | ||
value_type initial = 0; | ||
value_type initial; | ||
if constexpr (is_complex_v<value_type>) | ||
{ | ||
initial = value_type(); | ||
} | ||
else | ||
{ | ||
initial = 0; | ||
} | ||
Comment on lines
+165
to
+172
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When the element type is |
||
|
||
auto athis = static_cast<A const *>(this); | ||
if constexpr (!std::is_same_v<bool, std::remove_const_t<value_type>>) | ||
{ | ||
|
@@ -748,6 +758,8 @@ using SimpleArrayUint32 = SimpleArray<uint32_t>; | |
using SimpleArrayUint64 = SimpleArray<uint64_t>; | ||
using SimpleArrayFloat32 = SimpleArray<float>; | ||
using SimpleArrayFloat64 = SimpleArray<double>; | ||
using SimpleArrayComplexFloat32 = SimpleArray<Complex<float>>; | ||
using SimpleArrayComplexFloat64 = SimpleArray<Complex<double>>; | ||
|
||
class DataType | ||
{ | ||
|
@@ -766,6 +778,8 @@ class DataType | |
Uint64, | ||
Float32, | ||
Float64, | ||
ComplexFloat32, | ||
ComplexFloat64 | ||
}; /* end enum enum_type */ | ||
|
||
DataType() = default; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
*/ | ||
|
||
#include <modmesh/buffer/SimpleArray.hpp> | ||
#include <modmesh/math/math.hpp> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/pybind11.h> // Must be the first include. | ||
|
||
|
@@ -71,8 +72,17 @@ struct TypeBroadcastImpl | |
offset_out += arr_out.stride(it) * sidx[it] * step; | ||
} | ||
|
||
// NOLINTNEXTLINE(bugprone-signed-char-misuse, cert-str34-c) | ||
arr_out.at(offset_out) = static_cast<out_type>(*ptr_in); | ||
constexpr bool valid_conversion = (!is_complex_v<T> && !is_complex_v<D>) || (is_complex_v<T> && is_complex_v<D> && std::is_same_v<T, D>); | ||
|
||
if constexpr (valid_conversion) | ||
{ | ||
arr_out.at(offset_out) = static_cast<out_type>(*ptr_in); | ||
} | ||
else | ||
{ | ||
throw std::runtime_error("Cannot convert between complex and non-complex types"); | ||
} | ||
Comment on lines
+75
to
+84
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, only support |
||
|
||
// recursion here | ||
copy_idx(arr_out, slices, arr_in, left_shape, sidx, dim - 1); | ||
} | ||
|
@@ -197,6 +207,14 @@ struct TypeBroadcast | |
{ | ||
TypeBroadcastImpl<T, double>::broadcast(arr_out, slices, arr_in); | ||
} | ||
else if (dtype_is_type<Complex<float>>(arr_in)) | ||
{ | ||
TypeBroadcastImpl<T, Complex<float>>::broadcast(arr_out, slices, arr_in); | ||
} | ||
else if (dtype_is_type<Complex<double>>(arr_in)) | ||
{ | ||
TypeBroadcastImpl<T, Complex<double>>::broadcast(arr_out, slices, arr_in); | ||
} | ||
else | ||
{ | ||
throw std::runtime_error("input array data type not support!"); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,8 @@ static auto execute_callback_with_typed_array(A & arrayplex, C && callback) | |
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::Uint64, SimpleArrayUint64) | ||
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::Float32, SimpleArrayFloat32) | ||
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::Float64, SimpleArrayFloat64) | ||
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32) | ||
DECL_MM_RUN_CALLBACK_WITH_TYPED_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64) | ||
default: | ||
{ | ||
throw std::invalid_argument("Unsupported datatype"); | ||
|
@@ -117,6 +119,21 @@ static void verify_python_value_datatype(pybind11::object const & value, DataTyp | |
} | ||
break; | ||
} | ||
case DataType::ComplexFloat32: | ||
{ | ||
if (!pybind11::isinstance<Complex<float>>(value)) | ||
{ | ||
throw pybind11::type_error("Data type mismatch, expected complex float"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The error message should use the type name. |
||
} | ||
} | ||
case DataType::ComplexFloat64: | ||
{ | ||
if (!pybind11::isinstance<Complex<double>>(value)) | ||
{ | ||
throw pybind11::type_error("Data type mismatch, expected complex double"); | ||
} | ||
} | ||
|
||
default: | ||
throw std::runtime_error("Unsupported datatype"); | ||
} | ||
|
@@ -148,6 +165,8 @@ static pybind11::object get_typed_array_value(const SimpleArrayPlex & array_plex | |
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::Uint64, SimpleArrayUint64) | ||
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::Float32, SimpleArrayFloat32) | ||
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::Float64, SimpleArrayFloat64) | ||
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::ComplexFloat32, SimpleArrayComplexFloat32) | ||
DECL_MM_GET_TYPED_ARRAY_VALUE_BY_INDEX(DataType::ComplexFloat64, SimpleArrayComplexFloat64) | ||
default: | ||
{ | ||
throw std::runtime_error("Unsupported datatype"); | ||
|
@@ -180,6 +199,8 @@ static pybind11::object get_typed_array(const SimpleArrayPlex & array_plex) | |
DECL_MM_GET_TYPED_ARRAY(DataType::Uint64, SimpleArrayUint64) | ||
DECL_MM_GET_TYPED_ARRAY(DataType::Float32, SimpleArrayFloat32) | ||
DECL_MM_GET_TYPED_ARRAY(DataType::Float64, SimpleArrayFloat64) | ||
DECL_MM_GET_TYPED_ARRAY(DataType::ComplexFloat32, SimpleArrayComplexFloat32) | ||
DECL_MM_GET_TYPED_ARRAY(DataType::ComplexFloat64, SimpleArrayComplexFloat64) | ||
default: | ||
{ | ||
throw std::runtime_error("Unsupported datatype"); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because I have used designate initialiser, this feature is supported by c++20, therefore change it to c++ standard 20.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason that you have to use designated initializer? Please let us know what it is.
If there is not, then just don't use designated initializer. I have no problem using it, but it alone is not a good reason to upgrade to C++20.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just want to mark the initial value for each structure member during structure initialization, but I did not consider what side effects upgrading to C++20 may bring.