torchfilter is a library for discrete-time Bayesian filtering in PyTorch.
By writing filters as standard PyTorch modules, we get:
- The ability to optimize for system models/parameters that directly minimize end-to-end state estimation error
- Automatic Jacobians with autograd
- GPU acceleration (particularly useful for particle filters)
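
To make the end-to-end idea concrete, here is a minimal, self-contained sketch in plain PyTorch (it does not use the torchfilter API; all names here are illustrative): a toy Kalman filter whose noise parameters are learned by backpropagating a state estimation loss through the full filtering loop.

```python
# Sketch only: a 1D random-walk Kalman filter with learnable noise parameters,
# trained end-to-end on state estimation error. Not the torchfilter API.
import torch
import torch.nn as nn


class ToyKalmanFilter(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_q = nn.Parameter(torch.zeros(()))  # learnable process noise (log-variance)
        self.log_r = nn.Parameter(torch.zeros(()))  # learnable measurement noise (log-variance)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # observations: (T,) scalar measurements of a scalar latent state
        mean, var = observations[0], torch.ones(())
        q, r = self.log_q.exp(), self.log_r.exp()
        means = [mean]
        for z in observations[1:]:
            # Predict (identity dynamics), then update -- every op is differentiable.
            var = var + q
            gain = var / (var + r)
            mean = mean + gain * (z - mean)
            var = (1.0 - gain) * var
            means.append(mean)
        return torch.stack(means)


# Training loop: minimize end-to-end state estimation error.
filter_model = ToyKalmanFilter()
optimizer = torch.optim.Adam(filter_model.parameters(), lr=1e-2)
true_states = torch.cumsum(0.1 * torch.randn(100), dim=0)
observations = true_states + 0.5 * torch.randn(100)
for _ in range(200):
    optimizer.zero_grad()
    loss = torch.mean((filter_model(observations) - true_states) ** 2)
    loss.backward()  # gradients flow through every predict/update step
    optimizer.step()
```
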
The package is broken down into six submodules:

| Submodule | Description |
| --- | --- |
| torchfilter.base | Base classes that define standard interfaces for implementing filter, dynamics, measurement, and virtual sensor models as PyTorch modules. |
| torchfilter.filters | Differentiable filters implemented as PyTorch modules, which can either be used directly or subclassed. Currently implemented: extended Kalman filter (EKF), extended information filter (EIF), unscented Kalman filter (UKF), square-root unscented Kalman filter (SR-UKF), and particle filter (PF). For our PF, we include both standard resampling and the soft/differentiable approach from [1] (see the sketch after this table). UKFs and SR-UKFs are implemented using equations from [2]; the approach for handling heteroscedastic noise in the former is based on [3]. For our EKF, EIF, UKF, and SR-UKF, we also provide "virtual sensor" implementations, which use a (raw observation => state prediction/uncertainty) mapping with an identity measurement model. This is similar to the discriminative strategy described in [4]. |
| torchfilter.train | Training loop helpers. These are currently coupled tightly with a custom experiment management framework, but may be useful as a reference. |
| torchfilter.data | Dataset interfaces used by our training loop helpers. |
| torchfilter.utils | General utilities; currently only contains helpers for performing unscented transforms. |
| torchfilter.types | Data structures and semantic type aliases. |
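
As a concrete illustration of the soft/differentiable resampling mentioned above, the following standalone sketch follows the formulation in [1] (plain PyTorch, not torchfilter's internal implementation; the function name and argument shapes are our own): particle indices are drawn from a mixture of the particle weights and a uniform distribution, and an importance correction keeps the new weights differentiable with respect to the old ones.

```python
# Sketch of soft resampling from [1]; alpha -> 1 recovers standard resampling.
import torch


def soft_resample(particles, log_weights, alpha=0.5):
    # particles: (N, M, state_dim); log_weights: (N, M), normalized over M
    N, M, _ = particles.shape
    weights = log_weights.exp()

    # Mixture proposal: alpha * w + (1 - alpha) * uniform.
    proposal = alpha * weights + (1.0 - alpha) / M
    indices = torch.multinomial(proposal, num_samples=M, replacement=True)  # (N, M)

    # Gather the selected particles.
    expanded = indices[:, :, None].expand(-1, -1, particles.shape[-1])
    resampled = torch.gather(particles, dim=1, index=expanded)

    # Importance correction: new weight = w / proposal, evaluated at each drawn index.
    picked_w = torch.gather(weights, 1, indices)
    picked_q = torch.gather(proposal, 1, indices)
    new_weights = picked_w / picked_q
    new_log_weights = new_weights.log() - new_weights.sum(dim=1, keepdim=True).log()
    return resampled, new_log_weights
```

Because the correction term depends on the original weights through differentiable gather operations, gradients can propagate through the resampling step.
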
For more details, see the API reference.
For a linear system example, see `tests/_linear_system_models.py`. A more complex application can be found in the code for our IROS 2020 work [5]: GitHub repository.
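
For orientation, here is a hypothetical example of the kind of linear system such test models describe; the matrices and function names below are illustrative and are not the actual contents of `tests/_linear_system_models.py`.

```python
# Illustrative linear system: x_{t+1} = A x_t + B u_t (+ noise), z_t = C x_t (+ noise).
import torch

A = torch.tensor([[1.0, 1.0], [0.0, 1.0]])  # constant-velocity dynamics
B = torch.tensor([[0.0], [1.0]])            # control enters through velocity
C = torch.tensor([[1.0, 0.0]])              # observe position only


def dynamics(states, controls):
    # states: (N, 2), controls: (N, 1)
    return states @ A.T + controls @ B.T


def measurement(states):
    # states: (N, 2) -> observations: (N, 1)
    return states @ C.T
```
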
To install from source:

```bash
$ git clone https://github.com/stanford-iprl-lab/torchfilter.git
$ cd torchfilter
$ pip install -e .
```
Tests can be run with `pytest`, and documentation can be built by running `make dirhtml` in the `docs/` directory.
Tooling: black and isort for formatting, flake8 for linting, and mypy for static type checking.
Until numpy 1.20.0 is released, static analysis works best with NumPy stubs installed manually:

```bash
pip install https://github.com/numpy/numpy-stubs/tarball/master
```
This library is based on code written for our IROS 2020 work [5].
[1] P. Karkus, D. Hsu, and W. S. Lee, "Particle filter networks with application to visual localization", in Conference on Robot Learning, 2018, pp. 169–178.
[2] R. Van der Merwe and E. A. Wan, "The square-root unscented Kalman filter for state and parameter-estimation", in IEEE International Conference on Acoustics, Speech, and Signal Processing, 2001, pp. 3461-3464 vol.6.
[3] A. Kloss, G. Martius, and J. Bohg, "How to Train Your Differentiable Filter", in Robotics: Science and Systems (RSS) Workshop on Structured Approaches to Robot Learning for Improved Generalization, 2020.
[4] T. Haarnoja, A. Ajay, S. Levine, and P. Abbeel, "Backprop KF: Learning discriminative deterministic state estimators", in Advances in Neural Information Processing Systems, 2016, pp. 4376–4384.
[5] M. Lee*, B. Yi*, R. Martín-Martín, S. Savarese, J. Bohg, "Multimodal Sensor Fusion with Differentiable Filters", in IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS), 2020.
Contact: brentyi (at) stanford (dot) edu