Skip to content

Commit

Permalink
Improve PyArrayCapsule design by disaplying copy constructor & suppor…
Browse files Browse the repository at this point in the history
…ting different types via PyCapsule<T> template class
  • Loading branch information
Hamdi Sahloul committed Jul 30, 2021
1 parent b6c60e7 commit 0a0aeec
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions python/bindings/include/openravepy/openravepy_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class OPENRAVEPY_API PythonGILSaver
}
};

// TODO: Remove AutoPyArrayObjectDereferencer in favor of PyArrayCapsule
// TODO: Remove AutoPyArrayObjectDereferencer in favor of PyCapsule
class OPENRAVEPY_API AutoPyArrayObjectDereferencer
{
public:
Expand All @@ -231,51 +231,61 @@ class OPENRAVEPY_API AutoPyArrayObjectDereferencer
PyArrayObject* _pyarrobj;
};

class OPENRAVEPY_API PyArrayCapsule
template<typename T>
class OPENRAVEPY_API PyCapsule
{
public:
// \brief the only entrypoint that allows a null pointer. Expects reset to follow
PyArrayCapsule()
// \brief disables copy constructor to avoid calling Py_DECREF more than once
PyCapsule(const PyCapsule<T>&) = delete;

// \brief the only entrypoint that allows a nullptr. Expects a `reset` call to follow
PyCapsule()
: _ptr(nullptr)
{
}

// \param ptr a not-null pointer
PyArrayCapsule(PyArrayObject* ptr)
PyCapsule(T* ptr)
: _ptr(_PreventNullPtr(ptr))
{
reset(ptr); // Avoid initializer, and use a single entry-point
}

// \param ptr a not-null pointer
void reset(PyArrayObject* ptr)
void reset(T* ptr)
{
if (!ptr) {
throw OPENRAVE_EXCEPTION_FORMAT0(_("Invalid Numpy-array pointer. Failed to get contiguous array?"), ORE_InvalidArguments);
}
_ptr = ptr;
_ptr = _PreventNullPtr(ptr);
}

PyArrayObject* get() const
T* get() const
{
return _ptr;
}

operator PyArrayObject*() const
operator T*() const
{
return _ptr;
}

~PyArrayCapsule() {
if (!_ptr) {
return;
~PyCapsule() {
if (_ptr != nullptr) {
Py_DECREF(_ptr);
}
Py_DECREF(_ptr);
}

private:
PyArrayObject* _ptr;
T* _ptr;

static T* _PreventNullPtr(T* ptr)
{
if (ptr == nullptr) {
throw OPENRAVE_EXCEPTION_FORMAT0(_("Invalid Numpy-array pointer. Failed to get contiguous array?"), ORE_InvalidArguments);
}
return ptr;
}
};

typedef PyCapsule<PyArrayObject> PyArrayCapsule;

inline RaveVector<float> ExtractFloat3(const py::object& o)
{
return RaveVector<float>(py::extract<float>(o[0]), py::extract<float>(o[1]), py::extract<float>(o[2]));
Expand Down

0 comments on commit 0a0aeec

Please sign in to comment.