Skip to content

Commit

Permalink
Add numpy feature for automatic mapping of numpy::ndarray
Browse files Browse the repository at this point in the history
Signed-off-by: Andrej Orsula <[email protected]>
  • Loading branch information
AndrejOrsula committed May 9, 2024
1 parent 8b0b82a commit 2494035
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 15 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ The workspace contains these packages:
- **[pyo3_bindgen_engine](pyo3_bindgen_engine):** The underlying engine for generation of bindings
- **[pyo3_bindgen_macros](pyo3_bindgen_macros):** Procedural macros for in-place generation

Features of `pyo3_bindgen`:

- **`macros` \[experimental\]:** Enables `import_python!` macro from `pyo3_bindgen_macros` crate
- **`numpy` \[experimental\]:** Enables type mapping between Python [`numpy::ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) and Rust [`numpy::PyArray`](https://docs.rs/numpy/latest/numpy/array/struct.PyArray.html)

## Instructions

### <a href="#-option-1-build-script"><img src="https://rustacean.net/assets/rustacean-flat-noshadow.svg" width="16" height="16"></a> Option 1: Build script
Expand Down
10 changes: 4 additions & 6 deletions pyo3_bindgen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ pyo3_bindgen_macros = { workspace = true, optional = true }

[features]
default = []
macros = ["pyo3_bindgen_macros"]

[lib]
name = "pyo3_bindgen"
path = "src/lib.rs"
crate-type = ["rlib"]
# Enables `import_python!` macro from `pyo3_bindgen_macros` crate
macros = ["dep:pyo3_bindgen_macros"]
# Enables type mapping between Python `numpy::ndarray` and Rust `numpy::PyArray`
numpy = ["pyo3_bindgen_engine/numpy"]

[package.metadata.docs.rs]
all-features = true
6 changes: 1 addition & 5 deletions pyo3_bindgen_engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,9 @@ prettyplease = { workspace = true }
[build-dependencies]
pyo3-build-config = { workspace = true }

[lib]
name = "pyo3_bindgen_engine"
path = "src/lib.rs"
crate-type = ["rlib"]

[features]
default = []
numpy = []

[[bench]]
name = "bindgen"
Expand Down
39 changes: 35 additions & 4 deletions pyo3_bindgen_engine/src/typing/into_rs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ impl Type {
}

// Try to map the local types
if let Some(relative_path) = local_types.get(&Path::from_py(&type_name)) {
let type_name_without_delimiters =
type_name.split_once('[').map(|s| s.0).unwrap_or(&type_name);
if let Some(relative_path) = local_types.get(&Path::from_py(type_name_without_delimiters)) {
let relative_path: syn::Path = relative_path.try_into().unwrap();
return OutputType::new(
quote!(::pyo3::Bound<'py, #relative_path>),
Expand All @@ -331,9 +333,38 @@ impl Type {
)
}

fn try_map_external_type(_type_name: &str) -> Option<OutputType> {
// TODO: Handle types from other packages with Rust bindings here (e.g. NumPy)
None
fn try_map_external_type(type_name: &str) -> Option<OutputType> {
// TODO: Handle types from other packages with Rust bindings here
match type_name {
#[cfg(feature = "numpy")]
numpy_ndarray

Check warning on line 340 in pyo3_bindgen_engine/src/typing/into_rs.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/typing/into_rs.rs#L339-L340

Added lines #L339 - L340 were not covered by tests
if numpy_ndarray
.split_once('[')
.map(|s| s.0)
.unwrap_or(numpy_ndarray)
.split('.')
.last()
.unwrap_or(numpy_ndarray)
.to_lowercase()
== "ndarray" =>
{
Some(OutputType::new(
quote!(
::pyo3::Bound<
'py,
::numpy::PyArray<::pyo3::Py<::pyo3::types::PyAny>, ::numpy::IxDyn>,
>
),
quote!(
&::pyo3::Bound<
'py,
::numpy::PyArray<::pyo3::Py<::pyo3::types::PyAny>, ::numpy::IxDyn>,
>
),
))
}

Check warning on line 365 in pyo3_bindgen_engine/src/typing/into_rs.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/typing/into_rs.rs#L350-L365

Added lines #L350 - L365 were not covered by tests
_ => None,
}

Check warning on line 367 in pyo3_bindgen_engine/src/typing/into_rs.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/typing/into_rs.rs#L367

Added line #L367 was not covered by tests
}
}

Expand Down

0 comments on commit 2494035

Please sign in to comment.