Skip to content

Commit

Permalink
Implement TryExtend for PyDict
Browse files Browse the repository at this point in the history
  • Loading branch information
bschoenmaeckers committed Oct 29, 2024
1 parent 6eea38c commit e4faf4e
Showing 1 changed file with 131 additions and 1 deletion.
132 changes: 131 additions & 1 deletion src/types/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::ffi_ptr_ext::FfiPtrExt;
use crate::instance::{Borrowed, Bound};
use crate::py_result_ext::PyResultExt;
use crate::types::{PyAny, PyAnyMethods, PyList, PyMapping};
use crate::{ffi, BoundObject, IntoPyObject, Python};
use crate::{ffi, BoundObject, IntoPyObject, Python, TryExtend};

/// Represents a Python `dict`.
///
Expand Down Expand Up @@ -408,6 +408,98 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> {
}
}

impl<'py, I> TryExtend<I, (Bound<'py, PyAny>, Bound<'py, PyAny>)> for Bound<'_, PyDict>
where
I: IntoIterator<Item = (Bound<'py, PyAny>, Bound<'py, PyAny>)>,
{
#[cfg(not(feature = "nightly"))]
fn try_extend(&mut self, iter: I) -> PyResult<()> {
iter.into_iter()
.try_for_each(|(key, value)| self.set_item(key, value))
}

#[cfg(feature = "nightly")]
default fn try_extend(&mut self, iter: I) -> PyResult<()> {
iter.into_iter()
.try_for_each(|(key, value)| self.set_item(key, value))
}
}

impl<'py, I> TryExtend<I, PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)>> for Bound<'_, PyDict>
where
I: IntoIterator<Item = PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)>>,
{
#[cfg(not(feature = "nightly"))]
fn try_extend(&mut self, iter: I) -> PyResult<()> {
iter.into_iter().try_for_each(|item| {
let (key, value) = item?;
self.set_item(key, value)
})
}

#[cfg(feature = "nightly")]
default fn try_extend(&mut self, iter: I) -> PyResult<()> {
iter.into_iter().try_for_each(|item| {
let (key, value) = item?;
self.set_item(key, value)
})
}
}

impl<'py, I> TryExtend<I, Bound<'py, PyAny>> for Bound<'_, PyDict>
where
I: IntoIterator<Item = Bound<'py, PyAny>>,
{
#[cfg(not(feature = "nightly"))]
fn try_extend(&mut self, iter: I) -> PyResult<()> {
iter.into_iter().try_for_each(|item| {
let (key, value): (Bound<'py, PyAny>, Bound<'py, PyAny>) = item.extract()?;
self.set_item(key, value)
})
}

#[cfg(feature = "nightly")]
default fn try_extend(&mut self, iter: I) -> PyResult<()> {
iter.into_iter().try_for_each(|item| {
let (key, value): (Bound<'py, PyAny>, Bound<'py, PyAny>) = item.extract()?;
self.set_item(key, value)
})
}
}

#[cfg(feature = "nightly")]
impl<'py> TryExtend<Bound<'py, PyDict>, (Bound<'py, PyAny>, Bound<'py, PyAny>)>
for Bound<'_, PyDict>
{
#[cfg(feature = "nightly")]
fn try_extend(&mut self, iter: Bound<'py, PyDict>) -> PyResult<()> {
err::error_on_minusone(iter.py(), unsafe {
ffi::PyDict_Merge(self.as_ptr(), iter.as_ptr(), 1)
})
}
}

macro_rules! impl_try_extend_specialization(
($i:ty, $t:ty) => {
#[cfg(feature = "nightly")]
impl<'py> TryExtend<$i, $t> for Bound<'_, PyDict> {
fn try_extend(&mut self, iter: $i) -> PyResult<()> {
err::error_on_minusone(iter.py(), unsafe {
ffi::PyDict_MergeFromSeq2(self.as_ptr(), iter.as_ptr(), 1)
})
}
}
}
);

impl_try_extend_specialization!(
Bound<'py, crate::types::PyIterator>,
PyResult<Bound<'py, PyAny>>
);
impl_try_extend_specialization!(Bound<'py, crate::types::PyList>, Bound<'py, PyAny>);
impl_try_extend_specialization!(Bound<'py, crate::types::PySet>, Bound<'py, PyAny>);
impl_try_extend_specialization!(Bound<'py, crate::types::PyTuple>, Bound<'py, PyAny>);

impl<'a, 'py> Borrowed<'a, 'py, PyDict> {
/// Iterates over the contents of this dictionary without incrementing reference counts.
///
Expand Down Expand Up @@ -1652,4 +1744,42 @@ mod tests {
.is_err());
});
}

#[test]
fn test_dict_extend() {
Python::with_gil::<_, PyResult<()>>(|py| {
let mut dict = PyDict::new(py);

let vec = vec![(
Bound::into_any(1.into_pyobject(py)?),
Bound::into_any(1.into_pyobject(py)?),
)];
dict.try_extend(vec)?;

let slice = [(
Bound::into_any(2.into_pyobject(py)?),
Bound::into_any(2.into_pyobject(py)?),
)];
dict.try_extend(slice)?;

let other_dict = [(3, 3)].into_py_dict(py)?;
dict.try_extend(other_dict)?;

let list = PyList::new(py, [(4, 4)])?;
dict.try_extend(list)?;

let tuple = PyTuple::new(py, [(5, 5)])?;
dict.try_extend(tuple)?;

assert_eq!(dict.len(), 5);
assert!(dict.iter().all(|(k, v)| {
let k = k.extract::<i32>().unwrap();
let v = v.extract::<i32>().unwrap();
k == v
}));

Ok(())
})
.unwrap();
}
}

0 comments on commit e4faf4e

Please sign in to comment.