Skip to content

Commit

Permalink
Improve ergonomy of calling methods with pyo3::Bound
Browse files Browse the repository at this point in the history
Signed-off-by: Andrej Orsula <[email protected]>
  • Loading branch information
AndrejOrsula committed May 9, 2024
1 parent 77dc889 commit dd45bb0
Show file tree
Hide file tree
Showing 10 changed files with 450 additions and 267 deletions.
3 changes: 2 additions & 1 deletion examples/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
pyo3_bindgen::import_python!("random");

fn main() -> pyo3::PyResult<()> {
use ::pyo3::types::PyAnyMethods;

pyo3::Python::with_gil(|py| {
use ::pyo3::types::PyAnyMethods;
let rand_f64: f64 = random::random(py)?.extract()?;
assert!((0.0..=1.0).contains(&rand_f64));
println!("Random f64: {}", rand_f64);
Expand Down
72 changes: 55 additions & 17 deletions pyo3_bindgen_engine/src/syntax/class.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{
AttributeVariant, Function, FunctionType, Ident, MethodType, Path, Property, PropertyOwner,
AttributeVariant, Function, FunctionImplementation, FunctionType, Ident, MethodType, Path,
Property, PropertyOwner, TraitMethod,
};
use crate::{Config, Result};
use itertools::Itertools;
Expand Down Expand Up @@ -199,8 +200,6 @@ impl Class {
::pyo3::pyobject_native_static_type_object!(::pyo3::ffi::PyBaseObject_Type),
::std::option::Option::Some(#object_name)
);
// TODO: PRobably not necessary
::pyo3::pyobject_native_type_extract!(#struct_ident);
});

// Get the names of all methods to avoid name clashes
Expand All @@ -210,15 +209,26 @@ impl Class {
.map(|method| method.name.name())
.collect::<Vec<_>>();

// Generate the struct implementation block
// Generate the struct implementation blocks
let mut struct_impl = proc_macro2::TokenStream::new();
let mut method_defs = proc_macro2::TokenStream::new();
let mut method_impls = proc_macro2::TokenStream::new();
// Methods
struct_impl.extend(
self.methods
.iter()
.map(|method| method.generate(cfg, &scoped_function_idents, local_types))
.collect::<Result<proc_macro2::TokenStream>>()?,
);
self.methods
.iter()
.map(|method| method.generate(cfg, &scoped_function_idents, local_types))
.try_for_each(|def| {
match def? {
FunctionImplementation::Function(impl_fn) => {
struct_impl.extend(impl_fn);
}
FunctionImplementation::Method(TraitMethod { trait_fn, impl_fn }) => {
method_defs.extend(trait_fn);
method_impls.extend(impl_fn);
}
}
Result::Ok(())
})?;
// Properties
{
let mut scoped_function_idents_extra = Vec::with_capacity(2);
Expand All @@ -245,22 +255,50 @@ impl Class {
scoped_function_idents_extra.push(Ident::from_py("call"));
}
scoped_function_idents.extend(scoped_function_idents_extra.iter());
struct_impl.extend(
self.properties
.iter()
.map(|property| property.generate(cfg, &scoped_function_idents, local_types))
.collect::<Result<proc_macro2::TokenStream>>()?,
);
self.properties
.iter()
.map(|property| property.generate(cfg, &scoped_function_idents, local_types))
.try_for_each(|def| {
match def? {
FunctionImplementation::Function(impl_fn) => {
struct_impl.extend(impl_fn);
}
FunctionImplementation::Method(TraitMethod { trait_fn, impl_fn }) => {
method_defs.extend(trait_fn);
method_impls.extend(impl_fn);
}
}
Result::Ok(())
})?;
}

// Finalize the implementation block of the struct
// Add the implementation block for the struct
output.extend(quote::quote! {
#[automatically_derived]
impl #struct_ident {
#struct_impl
}
});

// Add the trait and implementation block for bounded struct
let trait_ident: syn::Ident =
Ident::from_py(&format!("{struct_ident}Methods")).try_into()?;
let struct_ident_str = struct_ident.to_string();
output.extend(quote::quote! {
/// These methods are defined for the `Bound<'py, T>` smart pointer, so to use
/// method call syntax these methods are separated into a trait, because stable
/// Rust does not yet support `arbitrary_self_types`.
#[doc(alias = #struct_ident_str)]
#[automatically_derived]
pub trait #trait_ident {
#method_defs
}
#[automatically_derived]
impl #trait_ident for ::pyo3::Bound<'_, #struct_ident> {
#method_impls
}
});

Ok(output)
}
}
28 changes: 28 additions & 0 deletions pyo3_bindgen_engine/src/syntax/common/function_definition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
pub enum FunctionImplementation {
Function(proc_macro2::TokenStream),
Method(TraitMethod),
}

impl FunctionImplementation {
pub fn empty_function() -> Self {
Self::Function(proc_macro2::TokenStream::new())
}

pub fn empty_method() -> Self {
Self::Method(TraitMethod::empty())
}
}

pub struct TraitMethod {
pub trait_fn: proc_macro2::TokenStream,
pub impl_fn: proc_macro2::TokenStream,
}

impl TraitMethod {
pub fn empty() -> Self {
Self {
trait_fn: proc_macro2::TokenStream::new(),
impl_fn: proc_macro2::TokenStream::new(),
}
}
}
2 changes: 2 additions & 0 deletions pyo3_bindgen_engine/src/syntax/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
pub(crate) mod attribute_variant;
pub(crate) mod function_definition;
pub(crate) mod ident;
pub(crate) mod path;

pub use attribute_variant::AttributeVariant;
pub use function_definition::{FunctionImplementation, TraitMethod};
pub use ident::Ident;
pub use path::Path;
9 changes: 6 additions & 3 deletions pyo3_bindgen_engine/src/syntax/common/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,12 @@ impl Path {
.collect_vec();

// Generate the import code
quote::quote! {
py.import_bound(::pyo3::intern!(py, #package_path))?#(.getattr(::pyo3::intern!(py, #remaining_path))?)*
}
remaining_path.into_iter().fold(
quote::quote! { py.import_bound(::pyo3::intern!(py, #package_path))? },
|acc, ident| {
quote::quote! { ::pyo3::types::PyAnyMethods::getattr(#acc.as_any(), ::pyo3::intern!(py, #ident))? }
},
)
}
}

Expand Down
82 changes: 52 additions & 30 deletions pyo3_bindgen_engine/src/syntax/function.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{Ident, Path};
use super::{FunctionImplementation, Ident, Path, TraitMethod};
use crate::{typing::Type, Config, Result};
use itertools::Itertools;
use proc_macro2::TokenStream;
use pyo3::{prelude::*, types::IntoPyDict, ToPyObject};
use rustc_hash::FxHashMap as HashMap;

Expand Down Expand Up @@ -321,14 +322,14 @@ impl Function {
cfg: &Config,
scoped_function_idents: &[&Ident],
local_types: &HashMap<Path, Path>,
) -> Result<proc_macro2::TokenStream> {
let mut output = proc_macro2::TokenStream::new();
) -> Result<FunctionImplementation> {
let mut impl_fn = proc_macro2::TokenStream::new();

// Documentation
if cfg.generate_docs {
if let Some(mut docstring) = self.docstring.clone() {
crate::utils::text::format_docstring(&mut docstring);
output.extend(quote::quote! {
impl_fn.extend(quote::quote! {
#[doc = #docstring]
});
}
Expand All @@ -339,7 +340,7 @@ impl Function {
let name = self.name.name();
if let Ok(ident) = name.try_into() {
if crate::config::FORBIDDEN_FUNCTION_NAMES.contains(&name.as_py()) {
return Ok(proc_macro2::TokenStream::new());
return Ok(FunctionImplementation::empty_function());
} else {
ident
}
Expand All @@ -360,7 +361,7 @@ impl Function {
"WARN: Function '{}' is an invalid Rust ident for a function name. Renaming failed. Bindings will not be generated.",
self.name
);
return Ok(proc_macro2::TokenStream::new());
return Ok(FunctionImplementation::empty_function());
}
}
};
Expand All @@ -386,15 +387,14 @@ impl Function {
.map(|param| Result::Ok(param.annotation.clone().into_rs_borrowed(local_types)))
.collect::<Result<Vec<_>>>()?;
let return_type = self.return_annotation.clone().into_rs_owned(local_types);
output.extend(match &self.typ {
let fn_contract = match &self.typ {
FunctionType::Method {
typ: MethodType::InstanceMethod,
..
} => {
quote::quote! {
pub fn #function_ident<'py>(
slf: &::pyo3::Bound<'py, Self>,
py: ::pyo3::marker::Python<'py>,
fn #function_ident<'py>(
&'py self,
#(#param_idents: #param_types),*
) -> ::pyo3::PyResult<#return_type>
}
Expand All @@ -418,9 +418,8 @@ impl Function {
}
.try_into()?;
quote::quote! {
pub fn #call_fn_ident<'py>(
slf: &::pyo3::Bound<'py, Self>,
py: ::pyo3::marker::Python<'py>,
fn #call_fn_ident<'py>(
&'py self,
#(#param_idents: #param_types),*
) -> ::pyo3::PyResult<#return_type>
}
Expand Down Expand Up @@ -458,7 +457,19 @@ impl Function {
) -> ::pyo3::PyResult<#return_type>
}
}
});
};
impl_fn.extend(fn_contract.clone());

// If the function is a method with `self` as a parameter, extract the Python marker from `self`
let maybe_extract_py = match &self.typ {
FunctionType::Method {
typ: MethodType::InstanceMethod | MethodType::Callable,
..
} => quote::quote! {
let py = self.py();
},
_ => TokenStream::new(),
};

// Function body (function dispatcher)
let function_dispatcher = match &self.typ {
Expand All @@ -477,7 +488,7 @@ impl Function {
..
} => {
quote::quote! {
slf
self
}
}
FunctionType::Method {
Expand All @@ -488,7 +499,7 @@ impl Function {
"WARN: Method '{}' has an unknown type. Bindings will not be generated.",
self.name
);
return Ok(proc_macro2::TokenStream::new());
return Ok(FunctionImplementation::empty_method());
}
};

Expand Down Expand Up @@ -521,9 +532,9 @@ impl Function {
let n_args_fixed = positional_args_idents.len();
quote::quote! {
{
let mut __internal__args = Vec::with_capacity(#n_args_fixed + #var_positional_args_ident.len()?);
let mut __internal__args = Vec::with_capacity(#n_args_fixed + ::pyo3::types::PyTupleMethods::len(#var_positional_args_ident));
__internal__args.extend([#(::pyo3::ToPyObject::to_object(&#positional_args_idents, py),)*]);
__internal__args.extend(#var_positional_args_ident.iter().as_ref().map(|__internal__arg| ::pyo3::ToPyObject::to_object(__internal__arg, py)));
__internal__args.extend(::pyo3::types::PyTupleMethods::iter(#var_positional_args_ident).map(|__internal__arg| ::pyo3::ToPyObject::to_object(&__internal__arg, py)));
::pyo3::types::PyTuple::new_bound(
py,
__internal__args,
Expand Down Expand Up @@ -573,7 +584,7 @@ impl Function {
{
let __internal__kwargs = #var_keyword_args_ident;
#(
__internal__kwargs.set_item(::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents);
::pyo3::types::PyDictMethods::set_item(&__internal__kwargs, ::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents);
)*
__internal__kwargs
}
Expand All @@ -588,7 +599,7 @@ impl Function {
{
let __internal__kwargs = ::pyo3::types::PyDict::new_bound(py);
#(
__internal__kwargs.set_item(::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents);
::pyo3::types::PyDictMethods::set_item(&__internal__kwargs, ::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents);
)*
__internal__kwargs
}
Expand All @@ -602,44 +613,55 @@ impl Function {
{
if has_keyword_args {
quote::quote! {
call(#positional_args, Some(&#keyword_args))
::pyo3::types::PyAnyMethods::call(#function_dispatcher.as_any(), #positional_args, Some(&#keyword_args))
}
} else if has_positional_args {
quote::quote! {
call1(#positional_args)
::pyo3::types::PyAnyMethods::call1(#function_dispatcher.as_any(), #positional_args)
}
} else {
quote::quote! {
call0()
::pyo3::types::PyAnyMethods::call0(#function_dispatcher.as_any())
}
}
} else {
let method_name = self.name.name().as_py();
if has_keyword_args {
quote::quote! {
call_method(::pyo3::intern!(py, #method_name), #positional_args, Some(&#keyword_args))
::pyo3::types::PyAnyMethods::call_method(#function_dispatcher.as_any(), ::pyo3::intern!(py, #method_name), #positional_args, Some(&#keyword_args))
}
} else if has_positional_args {
quote::quote! {
call_method1(::pyo3::intern!(py, #method_name), #positional_args)
::pyo3::types::PyAnyMethods::call_method1(#function_dispatcher.as_any(), ::pyo3::intern!(py, #method_name), #positional_args)
}
} else {
quote::quote! {
call_method0(::pyo3::intern!(py, #method_name))
::pyo3::types::PyAnyMethods::call_method0(#function_dispatcher.as_any(), ::pyo3::intern!(py, #method_name))
}
}
};

// Function body
output.extend(quote::quote! {
impl_fn.extend(quote::quote! {
{
use ::pyo3::types::PyAnyMethods;
#maybe_extract_py
#param_preprocessing
#function_dispatcher.#call?.extract()
::pyo3::types::PyAnyMethods::extract(
&#call?
)
}
});

Ok(output)
Ok(match &self.typ {
FunctionType::Method {
typ: MethodType::InstanceMethod | MethodType::Callable,
..
} => FunctionImplementation::Method(TraitMethod {
trait_fn: quote::quote! { #fn_contract ; },
impl_fn,
}),
_ => FunctionImplementation::Function(impl_fn),
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion pyo3_bindgen_engine/src/syntax/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub(crate) mod property;
pub(crate) mod type_var;

pub use class::Class;
pub use common::{AttributeVariant, Ident, Path};
pub use common::{AttributeVariant, FunctionImplementation, Ident, Path, TraitMethod};
pub use function::{Function, FunctionType, MethodType};
pub use import::Import;
pub use module::Module;
Expand Down
Loading

0 comments on commit dd45bb0

Please sign in to comment.