Skip to content

Commit

Permalink
Fix bindings for __init__() and __call__()
Browse files Browse the repository at this point in the history
Signed-off-by: Andrej Orsula <[email protected]>
  • Loading branch information
AndrejOrsula committed Jan 24, 2024
1 parent 32b99f6 commit 2acac48
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 95 deletions.
4 changes: 2 additions & 2 deletions pyo3_bindgen_cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ fn main() {
let args = Args::parse();

// Generate the bindings for the module specified by the `--module-name` argument
let bindings = pyo3_bindgen::generate_bindings(&args.module_name).unwrap_or_else(|_| {
let bindings = pyo3_bindgen::generate_bindings(&args.module_name).unwrap_or_else(|err| {
panic!(
"Failed to generate bindings for module: {}",
"Failed to generate bindings for module: {}\n{err}",

Check warning on line 12 in pyo3_bindgen_cli/src/main.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_cli/src/main.rs#L12

Added line #L12 was not covered by tests
args.module_name
)
});
Expand Down
115 changes: 68 additions & 47 deletions pyo3_bindgen_engine/src/bindgen/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,7 @@ pub fn bind_class<S: ::std::hash::BuildHasher + Default>(
let root_module_name = root_module.name()?;
let class_full_name = class.name()?;
let class_name = class_full_name.split('.').last().unwrap();
let class_module_name = format!(
"{}{}{}",
class.getattr("__module__")?,
if class_full_name.contains('.') {
"."
} else {
""
},
class_full_name.trim_end_matches(&format!(".{class_name}"))
);
let class_module_name = class.getattr("__module__")?.to_string();

// Create the Rust class identifier (raw string if it is a keyword)
let class_ident = if syn::parse_str::<syn::Ident>(class_name).is_ok() {
Expand All @@ -32,18 +23,65 @@ pub fn bind_class<S: ::std::hash::BuildHasher + Default>(
quote::format_ident!("r#{class_name}")
};

let mut fn_names = Vec::new();
// let mut fn_names = Vec::new();

// Iterate over all attributes of the module while updating the token stream
let mut impl_token_stream = proc_macro2::TokenStream::new();

// Implement new()
if class.hasattr("__init__")? {
for i in 0.. {
let new_fn_name = if i == 0 {
"new".to_string()
} else {
format!("new{i}")

Check warning on line 36 in pyo3_bindgen_engine/src/bindgen/class.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/class.rs#L36

Added line #L36 was not covered by tests
};
if !class.hasattr(new_fn_name.as_str())? {
impl_token_stream.extend(bind_function(
py,
&class_module_name,
&new_fn_name,
class.getattr("__init__")?,
all_types,
Some(class),
));
break;
}

Check warning on line 48 in pyo3_bindgen_engine/src/bindgen/class.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/class.rs#L48

Added line #L48 was not covered by tests
}
}

Check warning on line 50 in pyo3_bindgen_engine/src/bindgen/class.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/class.rs#L50

Added line #L50 was not covered by tests
// Implement call() method
if class.hasattr("__call__")? {
for i in 0.. {
let call_fn_name = if i == 0 {
"call".to_string()
} else {
format!("call{i}")

Check warning on line 57 in pyo3_bindgen_engine/src/bindgen/class.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/class.rs#L57

Added line #L57 was not covered by tests
};
if !class.hasattr(call_fn_name.as_str())? {
impl_token_stream.extend(bind_function(
py,
&class_module_name,
&call_fn_name,
class.getattr("__call__")?,
all_types,
Some(class),
));
break;
}

Check warning on line 69 in pyo3_bindgen_engine/src/bindgen/class.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/class.rs#L69

Added line #L69 was not covered by tests
}
}

Check warning on line 71 in pyo3_bindgen_engine/src/bindgen/class.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/class.rs#L71

Added line #L71 was not covered by tests

// Iterate over all attributes of the module while updating the token stream
class
.dir()
.iter()
.map(|name| {
.filter_map(|name| {
let name = name.str().unwrap().to_str().unwrap();
let attr = class.getattr(name).unwrap();
let attr_type = attr.get_type();
(name, attr, attr_type)
if let Ok(attr) = class.getattr(name) {
let attr_type = attr.get_type();
Some((name, attr, attr_type))
} else {
None

Check warning on line 83 in pyo3_bindgen_engine/src/bindgen/class.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/class.rs#L83

Added line #L83 was not covered by tests
}
})
.filter(|&(_, _, attr_type)| {
// Skip builtin functions
Expand All @@ -52,8 +90,8 @@ pub fn bind_class<S: ::std::hash::BuildHasher + Default>(
.unwrap_or(false)
})
.filter(|&(name, _, _)| {
// Skip private attributes (except for __init__ and __call__)
!name.starts_with('_') || name == "__init__" || name == "__call__"
// Skip private attributes
!name.starts_with('_')
})
.filter(|(_, attr, attr_type)| {
// Skip typing attributes
Expand Down Expand Up @@ -136,20 +174,22 @@ pub fn bind_class<S: ::std::hash::BuildHasher + Default>(
debug_assert!(![is_class, is_function].iter().all(|&v| v));

if is_class && !is_reexport {
impl_token_stream.extend(bind_class(
py,
root_module,
attr.downcast().unwrap(),
all_types,
));
// TODO: Properly handle nested classes
// impl_token_stream.extend(bind_class(
// py,
// root_module,
// attr.downcast().unwrap(),
// all_types,
// ));

Check warning on line 183 in pyo3_bindgen_engine/src/bindgen/class.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/class.rs#L177-L183

Added lines #L177 - L183 were not covered by tests
} else if is_function {
fn_names.push(name.to_string());
// fn_names.push(name.to_string());
impl_token_stream.extend(bind_function(
py,
&class_module_name,
name,
attr,
all_types,
Some(class),
));
} else if !name.starts_with('_') {
impl_token_stream.extend(bind_attribute(
Expand All @@ -164,39 +204,20 @@ pub fn bind_class<S: ::std::hash::BuildHasher + Default>(
}
});

// Add new and call aliases (currently a reimplemented versions of the function)
// TODO: Call the Rust `self.__init__()` and `self.__call__()` functions directly instead of reimplementing it
if fn_names.contains(&"__init__".to_string()) && !fn_names.contains(&"new".to_string()) {
impl_token_stream.extend(bind_function(
py,
&class_module_name,
"new",
class.getattr("__init__")?,
all_types,
));
}
if fn_names.contains(&"__call__".to_string()) && !fn_names.contains(&"call".to_string()) {
impl_token_stream.extend(bind_function(
py,
&class_module_name,
"call",
class.getattr("__call__")?,
all_types,
));
}

let mut doc = class.getattr("__doc__")?.to_string();
if doc == "None" {
doc = String::new();
};

let object_name = format!("{class_module_name}.{class_name}");

Ok(quote::quote! {
#[doc = #doc]
#[repr(transparent)]
pub struct #class_ident(::pyo3::PyAny);
// Note: Using these macros is probably not the best idea, but it makes possible wrapping around ::pyo3::PyAny instead of ::pyo3::PyObject, which improves usability
::pyo3::pyobject_native_type_named!(#class_ident);
::pyo3::pyobject_native_type_info!(#class_ident, ::pyo3::pyobject_native_static_type_object!(::pyo3::ffi::PyBaseObject_Type), ::std::option::Option::Some(#class_module_name));
::pyo3::pyobject_native_type_info!(#class_ident, ::pyo3::pyobject_native_static_type_object!(::pyo3::ffi::PyBaseObject_Type), ::std::option::Option::Some(#object_name));
::pyo3::pyobject_native_type_extract!(#class_ident);
#[automatically_derived]
impl #class_ident {
Expand Down
51 changes: 43 additions & 8 deletions pyo3_bindgen_engine/src/bindgen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub fn bind_function<S: ::std::hash::BuildHasher + Default>(
name: &str,
function: &pyo3::PyAny,
all_types: &std::collections::HashSet<String, S>,
method_of_class: Option<&pyo3::types::PyType>,
) -> Result<proc_macro2::TokenStream, pyo3::PyErr> {
let inspect = py.import("inspect")?;

Expand Down Expand Up @@ -136,6 +137,8 @@ pub fn bind_function<S: ::std::hash::BuildHasher + Default>(
let has_self_param = parameters
.iter()
.any(|(param_name, _, _, _)| param_name == "self");
let is_class_method =
method_of_class.is_some() && (!has_self_param || function_name == "__init__");

let param_idents = parameters
.iter()
Expand Down Expand Up @@ -167,13 +170,25 @@ pub fn bind_function<S: ::std::hash::BuildHasher + Default>(
doc = String::new();
};

let (maybe_ref_self, callable_object) = if has_self_param {
(quote::quote! { &'py self, }, quote::quote! { self })
let (has_self_param, is_class_method) = if function_name == "__call__" {
(true, false)
} else {
(
(has_self_param, is_class_method)
};

let (maybe_ref_self, callable_object) = match (has_self_param, is_class_method) {
(true, false) => (quote::quote! { &'py self, }, quote::quote! { self }),
(_, true) => {
let class_name = method_of_class.unwrap().name().unwrap();
(
quote::quote! {},
quote::quote! { py.import(::pyo3::intern!(py, #module_name))?.getattr(::pyo3::intern!(py, #class_name))?},
)
}
_ => (
quote::quote! {},
quote::quote! { py.import(::pyo3::intern!(py, #module_name))? },
)
),
};

let has_positional_args = !positional_args_idents.is_empty();
Expand Down Expand Up @@ -211,21 +226,41 @@ pub fn bind_function<S: ::std::hash::BuildHasher + Default>(
#(__internal_kwargs.set_item(::pyo3::intern!(py, #keyword_args_names), #keyword_args_idents)?;)*
};

let call_method = match (has_positional_args, has_kwargs) {
(_, true) => {
let is_init_fn = function_name == "__init__";

let call_method = match (is_init_fn, has_positional_args, has_kwargs) {
(true, _, true) => {
quote::quote! {
#set_args
#set_kwargs
#callable_object.call(__internal_args, Some(__internal_kwargs))?
}
}
(true, true, false) => {
quote::quote! {
#set_args
#callable_object.call1(__internal_args)?
}
}
(true, false, false) => {
quote::quote! {
#callable_object.call0()?
}

Check warning on line 248 in pyo3_bindgen_engine/src/bindgen/function.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/function.rs#L246-L248

Added lines #L246 - L248 were not covered by tests
}
(false, _, true) => {
quote::quote! {
#set_args
#set_kwargs
#callable_object.call_method(::pyo3::intern!(py, #function_name), __internal_args, Some(__internal_kwargs))?
}
}
(true, false) => {
(false, true, false) => {
quote::quote! {
#set_args
#callable_object.call_method1(::pyo3::intern!(py, #function_name), __internal_args)?
}
}
(false, false) => {
(false, false, false) => {
quote::quote! {
#callable_object.call_method0(::pyo3::intern!(py, #function_name))?
}
Expand Down
22 changes: 16 additions & 6 deletions pyo3_bindgen_engine/src/bindgen/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub fn bind_module<S: ::std::hash::BuildHasher + Default>(
})
.filter(|&(name, _, _)| {
// Skip private attributes
!name.starts_with('_') || name == "__init__" || name == "__call__"
!name.starts_with('_')
})
.filter(|(_, attr, attr_type)| {
// Skip typing attributes
Expand Down Expand Up @@ -153,7 +153,7 @@ pub fn bind_module<S: ::std::hash::BuildHasher + Default>(
let content = if is_class {
bind_class(py, root_module, attr.downcast().unwrap(), all_types).unwrap()
} else if is_function {
bind_function(py, full_module_name, name, attr, all_types).unwrap()
bind_function(py, full_module_name, name, attr, all_types, None).unwrap()

Check warning on line 156 in pyo3_bindgen_engine/src/bindgen/module.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/module.rs#L156

Added line #L156 was not covered by tests
} else {
unreachable!()
};
Expand Down Expand Up @@ -201,7 +201,10 @@ pub fn bind_module<S: ::std::hash::BuildHasher + Default>(
attr,
));
}
} else if is_reexport {
} else if is_reexport
&& (is_function
|| (is_class && all_types.contains(&format!("{full_module_name}.{name}"))))
{

Check warning on line 207 in pyo3_bindgen_engine/src/bindgen/module.rs

View check run for this annotation

Codecov / codecov/patch

pyo3_bindgen_engine/src/bindgen/module.rs#L205-L207

Added lines #L205 - L207 were not covered by tests
mod_token_stream.extend(bind_reexport(
root_module_name,
full_module_name,
Expand All @@ -216,7 +219,14 @@ pub fn bind_module<S: ::std::hash::BuildHasher + Default>(
all_types,
));
} else if is_function {
mod_token_stream.extend(bind_function(py, full_module_name, name, attr, all_types));
mod_token_stream.extend(bind_function(
py,
full_module_name,
name,
attr,
all_types,
None,
));
} else {
mod_token_stream.extend(bind_attribute(
py,
Expand Down Expand Up @@ -383,7 +393,7 @@ pub fn collect_types_of_module<S: ::std::hash::BuildHasher + Clone>(
})
.filter(|&(name, _, _)| {
// Skip private attributes
!name.starts_with('_') || name == "__init__" || name == "__call__"
!name.starts_with('_')
})
.filter(|(_, attr, attr_type)| {
// Skip typing attributes
Expand Down Expand Up @@ -502,7 +512,7 @@ pub fn collect_types_of_module<S: ::std::hash::BuildHasher + Clone>(
all_types,
);
}
} else if is_class {
} else if is_class && !attr.to_string().contains("<locals>") {
let full_class_name =
format!("{}.{}", full_module_name, attr.getattr("__name__").unwrap());
all_types.insert(full_class_name.clone());
Expand Down
Loading

0 comments on commit 2acac48

Please sign in to comment.