Skip to content

Commit

Permalink
Kick out py_shared_ptr and instead keep the python instance alive man…
Browse files Browse the repository at this point in the history
…ually by holding a reference to it

See pybind/pybind11#1389 for why py_shared_ptr was needed in the first place, and the comment from May 27 why we may not want to use it (reference cycle)
  • Loading branch information
florianwechsung committed Jul 8, 2021
1 parent 1e412b8 commit 69c55ff
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 54 deletions.
5 changes: 3 additions & 2 deletions src/simsopt/field/magneticfieldclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,10 @@ class InterpolatedField(sopp.InterpolatedField, MagneticField):
This resulting interpolant can then be evaluated very quickly.
"""

def __init__(self, *args):
def __init__(self, underlying_field, *args):
self.__underlying_field = underlying_field
MagneticField.__init__(self)
sopp.InterpolatedField.__init__(self, *args)
sopp.InterpolatedField.__init__(self, underlying_field, *args)

def to_vtk(self, filename, h=0.1):
"""Export the field evaluated on a regular grid for visualisation with e.g. Paraview."""
Expand Down
30 changes: 0 additions & 30 deletions src/simsoptpp/py_shared_ptr.h

This file was deleted.

2 changes: 0 additions & 2 deletions src/simsoptpp/python.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/functional.h"
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
#define FORCE_IMPORT_ARRAY
#include "xtensor-python/pyarray.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
Expand Down
9 changes: 4 additions & 5 deletions src/simsoptpp/python_curves.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
#include "xtensor-python/pyarray.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
using std::shared_ptr;


Expand Down Expand Up @@ -89,16 +88,16 @@ template <typename T, typename S> void register_common_curve_methods(S &c) {
}

void init_curves(py::module_ &m) {
auto pycurve = py::class_<PyCurve, py_shared_ptr<PyCurve>, PyCurveTrampoline<PyCurve>>(m, "Curve")
auto pycurve = py::class_<PyCurve, PyCurveTrampoline<PyCurve>, shared_ptr<PyCurve>>(m, "Curve")
.def(py::init<vector<double>>());
register_common_curve_methods<PyCurve>(pycurve);

auto pycurvexyzfourier = py::class_<PyCurveXYZFourier, py_shared_ptr<PyCurveXYZFourier>, PyCurveXYZFourierTrampoline<PyCurveXYZFourier>, PyCurve>(m, "CurveXYZFourier")
auto pycurvexyzfourier = py::class_<PyCurveXYZFourier, PyCurveXYZFourierTrampoline<PyCurveXYZFourier>, shared_ptr<PyCurveXYZFourier>, PyCurve>(m, "CurveXYZFourier")
.def(py::init<vector<double>, int>())
.def_readonly("dofs", &PyCurveXYZFourier::dofs);
register_common_curve_methods<PyCurveXYZFourier>(pycurvexyzfourier);

auto pycurverzfourier = py::class_<PyCurveRZFourier, py_shared_ptr<PyCurveRZFourier>, PyCurveRZFourierTrampoline<PyCurveRZFourier>, PyCurve>(m, "CurveRZFourier")
auto pycurverzfourier = py::class_<PyCurveRZFourier, PyCurveRZFourierTrampoline<PyCurveRZFourier>, shared_ptr<PyCurveRZFourier>, PyCurve>(m, "CurveRZFourier")
//.def(py::init<int, int>())
.def(py::init<vector<double>, int, int, bool>())
.def_readwrite("rc", &PyCurveRZFourier::rc)
Expand Down
21 changes: 10 additions & 11 deletions src/simsoptpp/python_magneticfield.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/functional.h"
namespace py = pybind11;
#include "xtensor-python/pyarray.hpp" // Numpy bindings
#include "xtensor-python/pytensor.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
typedef xt::pytensor<double, 2, xt::layout_type::row_major> PyTensor;
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
using std::shared_ptr;
using std::vector;

Expand Down Expand Up @@ -53,17 +52,17 @@ template <typename T, typename S> void register_common_field_methods(S &c) {

void init_magneticfields(py::module_ &m){

py::class_<InterpolationRule, py_shared_ptr<InterpolationRule>>(m, "InterpolationRule", "Abstract class for interpolation rules on an interval.")
py::class_<InterpolationRule, shared_ptr<InterpolationRule>>(m, "InterpolationRule", "Abstract class for interpolation rules on an interval.")
.def_readonly("degree", &InterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`.");

py::class_<UniformInterpolationRule, py_shared_ptr<UniformInterpolationRule>, InterpolationRule>(m, "UniformInterpolationRule", "Polynomial interpolation using equispaced points.")
py::class_<UniformInterpolationRule, shared_ptr<UniformInterpolationRule>, InterpolationRule>(m, "UniformInterpolationRule", "Polynomial interpolation using equispaced points.")
.def(py::init<int>())
.def_readonly("degree", &UniformInterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`.");
py::class_<ChebyshevInterpolationRule, py_shared_ptr<ChebyshevInterpolationRule>, InterpolationRule>(m, "ChebyshevInterpolationRule", "Polynomial interpolation using chebychev points.")
py::class_<ChebyshevInterpolationRule, shared_ptr<ChebyshevInterpolationRule>, InterpolationRule>(m, "ChebyshevInterpolationRule", "Polynomial interpolation using chebychev points.")
.def(py::init<int>())
.def_readonly("degree", &ChebyshevInterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`.");

py::class_<RegularGridInterpolant3D<PyTensor>, py_shared_ptr<RegularGridInterpolant3D<PyTensor>>>(m, "RegularGridInterpolant3D",
py::class_<RegularGridInterpolant3D<PyTensor>, shared_ptr<RegularGridInterpolant3D<PyTensor>>>(m, "RegularGridInterpolant3D",
R"pbdoc(
Interpolates a (vector valued) function on a uniform grid.
This interpolant is optimized for fast function evaluation (at the cost of memory usage). The main purpose of this class is to be used to interpolate magnetic fields and then use the interpolant for tasks such as fieldline or particle tracing for which the field needs to be evaluated many many times.
Expand All @@ -74,32 +73,32 @@ void init_magneticfields(py::module_ &m){
.def("evaluate_batch", &RegularGridInterpolant3D<PyTensor>::evaluate_batch, "Evaluate the interpolant at multiple points (faster than `evaluate` as it uses prefetching).");


py::class_<Current<PyArray>, py_shared_ptr<Current<PyArray>>>(m, "Current", "Simple class that wraps around a single double representing a coil current.")
py::class_<Current<PyArray>, shared_ptr<Current<PyArray>>>(m, "Current", "Simple class that wraps around a single double representing a coil current.")
.def(py::init<double>())
.def("set_dofs", &Current<PyArray>::set_dofs, "Set the current.")
.def("get_dofs", &Current<PyArray>::get_dofs, "Get the current.")
.def("set_value", &Current<PyArray>::set_value, "Set the current.")
.def("get_value", &Current<PyArray>::get_value, "Get the current.");


py::class_<Coil<PyArray>, py_shared_ptr<Coil<PyArray>>>(m, "Coil", "Optimizable that represents a coil, consisting of a curve and a current.")
py::class_<Coil<PyArray>, shared_ptr<Coil<PyArray>>>(m, "Coil", "Optimizable that represents a coil, consisting of a curve and a current.")
.def(py::init<shared_ptr<Curve<PyArray>>, shared_ptr<Current<PyArray>>>())
.def_readonly("curve", &Coil<PyArray>::curve, "Get the underlying curve.")
.def_readonly("current", &Coil<PyArray>::current, "Get the underlying current.");

auto mf = py::class_<PyMagneticField, PyMagneticFieldTrampoline<PyMagneticField>, py_shared_ptr<PyMagneticField>>(m, "MagneticField", "Abstract class representing magnetic fields.")
auto mf = py::class_<PyMagneticField, PyMagneticFieldTrampoline<PyMagneticField>, shared_ptr<PyMagneticField>>(m, "MagneticField", "Abstract class representing magnetic fields.")
.def(py::init<>());
register_common_field_methods<PyMagneticField>(mf);
//.def("B", py::overload_cast<>(&PyMagneticField::B));

auto bs = py::class_<PyBiotSavart, PyMagneticFieldTrampoline<PyBiotSavart>, py_shared_ptr<PyBiotSavart>, PyMagneticField>(m, "BiotSavart")
auto bs = py::class_<PyBiotSavart, PyMagneticFieldTrampoline<PyBiotSavart>, shared_ptr<PyBiotSavart>, PyMagneticField>(m, "BiotSavart")
.def(py::init<vector<shared_ptr<Coil<PyArray>>>>())
.def("compute", &PyBiotSavart::compute)
.def("fieldcache_get_or_create", &PyBiotSavart::fieldcache_get_or_create)
.def("fieldcache_get_status", &PyBiotSavart::fieldcache_get_status);
register_common_field_methods<PyBiotSavart>(bs);

auto ifield = py::class_<PyInterpolatedField, py_shared_ptr<PyInterpolatedField>, PyMagneticField>(m, "InterpolatedField")
auto ifield = py::class_<PyInterpolatedField, shared_ptr<PyInterpolatedField>, PyMagneticField>(m, "InterpolatedField")
.def(py::init<shared_ptr<PyMagneticField>, InterpolationRule, RangeTriplet, RangeTriplet, RangeTriplet, bool>())
.def(py::init<shared_ptr<PyMagneticField>, int, RangeTriplet, RangeTriplet, RangeTriplet, bool>())
.def("estimate_error_B", &PyInterpolatedField::estimate_error_B)
Expand Down
3 changes: 1 addition & 2 deletions src/simsoptpp/python_surfaces.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
#include "xtensor-python/pyarray.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
using std::shared_ptr;
using std::vector;

Expand Down
3 changes: 1 addition & 2 deletions src/simsoptpp/python_tracing.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/functional.h"
namespace py = pybind11;
#include "xtensor-python/pyarray.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
#include "xtensor-python/pytensor.hpp" // Numpy bindings
typedef xt::pytensor<double, 2, xt::layout_type::row_major> PyTensor;
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
using std::shared_ptr;
using std::vector;
#include "tracing.h"
Expand Down

0 comments on commit 69c55ff

Please sign in to comment.