Skip to content

Commit

Permalink
updated readmes. updated pyo3 dev dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
kachark committed Mar 30, 2024
1 parent 8b002aa commit 8dc810b
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 87 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ ndarray-rand = "0.14"
cxx-build = "1.0"

[dev-dependencies]
pyo3 = { version = "0.15.1", features = ["auto-initialize"] }
numpy = "0.15"
pyo3 = { version = "0.20", features = ["auto-initialize"] }
numpy = "0.20"
rand = "0.8"
criterion = "0.3"

Expand Down
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,24 @@ let max_cost = cost.max().unwrap();
cost = &cost / *max_cost;

// Compute the optimal transport matrix
let ot_matrix = match EarthMovers::new(
let ot_matrix = EarthMovers::new(
&mut source_weights,
&mut target_weights,
&mut cost
).solve()?;

```

## Testing
```
cargo test
```

If using M1 mac and linking against Homebrew's OpenBLAS, add the following to `build.rs`:
```
println!("cargo:rustc-link-search=/opt/homebrew/opt/openblas/lib");
```

## Acknowledgements

This library is inspired by Python Optimal Transport. The original authors and contributors of that project are listed at [POT](https://github.com/PythonOT/POT#acknowledgements).
Expand Down
15 changes: 15 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,19 @@ pip install matplotlib
NOTE: sinkhorn_1D example must be run from sinkhorn_1D directory in order to use the included
python script for visualization. See below:

## M1 Mac + Homebrew OpenBLAS
If OpenBLAS is installed Homebrew on an M1 mac, you may need to add the following to `build.rs`:
```
println!("cargo:rustc-link-search=/opt/homebrew/opt/openblas/lib");
```

## Anaconda
To link against Anaconda python, you may need to add the following to `build.rs`:
```
println!(
"cargo:rustc-link-arg=-Wl,-rpath,/path/to/anaconda3/lib/"
);
```

![](https://github.com/kachark/rust-optimal-transport/blob/main/assets/sinkhorn_1D_gaussian.png)
2 changes: 1 addition & 1 deletion examples/emd_2D/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use ndarray::prelude::*;
use ndarray_stats::QuantileExt;

use ot::prelude::*;
use rust_optimal_transport as ot;
use ot::prelude::*;

mod plot;

Expand Down
90 changes: 45 additions & 45 deletions examples/emd_2D/plot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,63 +14,63 @@ pub fn plot_py(
let target_y = target_samples.slice(s![.., 1]);

// Start the python interpreter
let gil = Python::acquire_gil();
let py = gil.python();
Python::with_gil(|py| {

// Import matplotlib
let plt = py.import("matplotlib.pyplot")?;
// Import matplotlib
let plt = py.import("matplotlib.pyplot")?;

// Translate to numpy array
let source_x_py = source_x.to_pyarray(py);
let source_y_py = source_y.to_pyarray(py);
// Translate to numpy array
let source_x_py = source_x.to_pyarray(py);
let source_y_py = source_y.to_pyarray(py);

let target_x_py = target_x.to_pyarray(py);
let target_y_py = target_y.to_pyarray(py);
let target_x_py = target_x.to_pyarray(py);
let target_y_py = target_y.to_pyarray(py);

let ot_matrix_py = ot_matrix.to_pyarray(py);
let ot_matrix_py = ot_matrix.to_pyarray(py);

// Plot by calling into matplotlib
// Plot by calling into matplotlib

// plot ot matrix
plt.getattr("figure")?.call1((1,))?;
plt.call_method(
"imshow",
(ot_matrix_py,),
Some(vec![("interpolation", "nearest")].into_py_dict(py)),
)?;
plt.call_method1("title", ("OT matrix",))?;
// plot ot matrix
plt.getattr("figure")?.call1((1,))?;
plt.call_method(
"imshow",
(ot_matrix_py,),
Some(vec![("interpolation", "nearest")].into_py_dict(py)),
)?;
plt.call_method1("title", ("OT matrix",))?;

// plot data with coupling between source and target distributions
plt.getattr("figure")?.call1((2,))?;
// plot data with coupling between source and target distributions
plt.getattr("figure")?.call1((2,))?;

let threshold = 1E-8;
for i in 0..ot_matrix.shape()[0] {
for j in 0..ot_matrix.shape()[1] {
if ot_matrix[[i, j]] > threshold {
let args = (
array![source_x[i], target_x[j]].to_pyarray(py),
array![source_y[i], target_y[j]].to_pyarray(py),
);
let kwargs = Some(vec![("color", "0.8")].into_py_dict(py));
let threshold = 1E-8;
for i in 0..ot_matrix.shape()[0] {
for j in 0..ot_matrix.shape()[1] {
if ot_matrix[[i, j]] > threshold {
let args = (
array![source_x[i], target_x[j]].to_pyarray(py),
array![source_y[i], target_y[j]].to_pyarray(py),
);
let kwargs = Some(vec![("color", "0.8")].into_py_dict(py));

plt.call_method("plot", args, kwargs)?;
plt.call_method("plot", args, kwargs)?;
}
}
}
}

plt.call_method(
"plot",
(source_x_py, source_y_py, "+b"),
Some(vec![("label", "Source samples")].into_py_dict(py)),
)?;
plt.call_method(
"plot",
(target_x_py, target_y_py, "xr"),
Some(vec![("label", "Target samples")].into_py_dict(py)),
)?;
plt.getattr("legend")?.call0()?;
plt.call_method(
"plot",
(source_x_py, source_y_py, "+b"),
Some(vec![("label", "Source samples")].into_py_dict(py)),
)?;
plt.call_method(
"plot",
(target_x_py, target_y_py, "xr"),
Some(vec![("label", "Target samples")].into_py_dict(py)),
)?;
plt.getattr("legend")?.call0()?;

plt.getattr("show")?.call0()?;
plt.getattr("show")?.call0()?;

Ok(())
Ok(())
})
}
2 changes: 1 addition & 1 deletion examples/sinkhorn_1D/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use ndarray::{prelude::*, stack};
use ndarray_stats::QuantileExt;

use ot::prelude::*;
use rust_optimal_transport as ot;
use ot::prelude::*;

mod plot;

Expand Down
74 changes: 37 additions & 37 deletions examples/sinkhorn_1D/plot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,41 @@ pub fn plot_py(
let target_y = target_samples.slice(s![.., 1]);

// Start the python interpreter
let gil = Python::acquire_gil();
let py = gil.python();

// Import matplotlib
let plt = py.import("matplotlib.pyplot")?;

// Import plotting function by adding python script to path
let pwd = env::current_dir()?;
let syspath: &PyList = py
.import("sys")
.unwrap()
.getattr("path")
.unwrap()
.try_into()
.unwrap();

syspath.insert(0, pwd.display().to_string()).unwrap();
let plot_mod = py.import("plot_1d_mat")?;

// Translate to numpy array
let source_y_py = source_y.to_pyarray(py);
let target_y_py = target_y.to_pyarray(py);
let ot_matrix_py = ot_matrix.to_pyarray(py);

// Plot by calling into matplotlib via python script
plt.call_method(
"figure",
(4,),
Some(vec![("figsize", (5, 5))].into_py_dict(py)),
)?;
plot_mod
.getattr("plot1D_mat")?
.call1((source_y_py, target_y_py, ot_matrix_py, title))?;

plt.getattr("show")?.call0()?;

Ok(())
Python::with_gil(|py| {

// Import matplotlib
let plt = py.import("matplotlib.pyplot")?;

// Import plotting function by adding python script to path
let pwd = env::current_dir()?;
let syspath: &PyList = py
.import("sys")
.unwrap()
.getattr("path")
.unwrap()
.try_into()
.unwrap();

syspath.insert(0, pwd.display().to_string()).unwrap();
let plot_mod = py.import("plot_1d_mat")?;

// Translate to numpy array
let source_y_py = source_y.to_pyarray(py);
let target_y_py = target_y.to_pyarray(py);
let ot_matrix_py = ot_matrix.to_pyarray(py);

// Plot by calling into matplotlib via python script
plt.call_method(
"figure",
(4,),
Some(vec![("figsize", (5, 5))].into_py_dict(py)),
)?;
plot_mod
.getattr("plot1D_mat")?
.call1((source_y_py, target_y_py, ot_matrix_py, title))?;

plt.getattr("show")?.call0()?;

Ok(())
})
}

0 comments on commit 8dc810b

Please sign in to comment.