From 2494035b638d40ff9d99577f7116b77bf387410f Mon Sep 17 00:00:00 2001 From: Andrej Orsula Date: Thu, 9 May 2024 23:28:46 +0200 Subject: [PATCH] Add `numpy` feature for automatic mapping of numpy::ndarray Signed-off-by: Andrej Orsula --- README.md | 5 +++ pyo3_bindgen/Cargo.toml | 10 +++--- pyo3_bindgen_engine/Cargo.toml | 6 +--- pyo3_bindgen_engine/src/typing/into_rs.rs | 39 ++++++++++++++++++++--- 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index e01e062..492e8c8 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,11 @@ The workspace contains these packages: - **[pyo3_bindgen_engine](pyo3_bindgen_engine):** The underlying engine for generation of bindings - **[pyo3_bindgen_macros](pyo3_bindgen_macros):** Procedural macros for in-place generation +Features of `pyo3_bindgen`: + +- **`macros` \[experimental\]:** Enables `import_python!` macro from `pyo3_bindgen_macros` crate +- **`numpy` \[experimental\]:** Enables type mapping between Python [`numpy::ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) and Rust [`numpy::PyArray`](https://docs.rs/numpy/latest/numpy/array/struct.PyArray.html) + ## Instructions ### Option 1: Build script diff --git a/pyo3_bindgen/Cargo.toml b/pyo3_bindgen/Cargo.toml index 53f502e..70ec189 100644 --- a/pyo3_bindgen/Cargo.toml +++ b/pyo3_bindgen/Cargo.toml @@ -18,12 +18,10 @@ pyo3_bindgen_macros = { workspace = true, optional = true } [features] default = [] -macros = ["pyo3_bindgen_macros"] - -[lib] -name = "pyo3_bindgen" -path = "src/lib.rs" -crate-type = ["rlib"] +# Enables `import_python!` macro from `pyo3_bindgen_macros` crate +macros = ["dep:pyo3_bindgen_macros"] +# Enables type mapping between Python `numpy::ndarray` and Rust `numpy::PyArray` +numpy = ["pyo3_bindgen_engine/numpy"] [package.metadata.docs.rs] all-features = true diff --git a/pyo3_bindgen_engine/Cargo.toml b/pyo3_bindgen_engine/Cargo.toml index 90983d3..928f722 100644 --- a/pyo3_bindgen_engine/Cargo.toml +++ b/pyo3_bindgen_engine/Cargo.toml @@ -29,13 +29,9 @@ prettyplease = { workspace = true } [build-dependencies] pyo3-build-config = { workspace = true } -[lib] -name = "pyo3_bindgen_engine" -path = "src/lib.rs" -crate-type = ["rlib"] - [features] default = [] +numpy = [] [[bench]] name = "bindgen" diff --git a/pyo3_bindgen_engine/src/typing/into_rs.rs b/pyo3_bindgen_engine/src/typing/into_rs.rs index 5fb063a..d9f27f1 100644 --- a/pyo3_bindgen_engine/src/typing/into_rs.rs +++ b/pyo3_bindgen_engine/src/typing/into_rs.rs @@ -316,7 +316,9 @@ impl Type { } // Try to map the local types - if let Some(relative_path) = local_types.get(&Path::from_py(&type_name)) { + let type_name_without_delimiters = + type_name.split_once('[').map(|s| s.0).unwrap_or(&type_name); + if let Some(relative_path) = local_types.get(&Path::from_py(type_name_without_delimiters)) { let relative_path: syn::Path = relative_path.try_into().unwrap(); return OutputType::new( quote!(::pyo3::Bound<'py, #relative_path>), @@ -331,9 +333,38 @@ impl Type { ) } - fn try_map_external_type(_type_name: &str) -> Option { - // TODO: Handle types from other packages with Rust bindings here (e.g. NumPy) - None + fn try_map_external_type(type_name: &str) -> Option { + // TODO: Handle types from other packages with Rust bindings here + match type_name { + #[cfg(feature = "numpy")] + numpy_ndarray + if numpy_ndarray + .split_once('[') + .map(|s| s.0) + .unwrap_or(numpy_ndarray) + .split('.') + .last() + .unwrap_or(numpy_ndarray) + .to_lowercase() + == "ndarray" => + { + Some(OutputType::new( + quote!( + ::pyo3::Bound< + 'py, + ::numpy::PyArray<::pyo3::Py<::pyo3::types::PyAny>, ::numpy::IxDyn>, + > + ), + quote!( + &::pyo3::Bound< + 'py, + ::numpy::PyArray<::pyo3::Py<::pyo3::types::PyAny>, ::numpy::IxDyn>, + > + ), + )) + } + _ => None, + } } }