diff --git a/pyproject.toml b/pyproject.toml index 6c3f509..d08a59f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,16 +40,24 @@ dependencies = [ ] [project.optional-dependencies] +fit_base = [ + "scipy", + "matplotlib", +] fit = [ + "crispio[fit_base]", + "jax", + "optax", +] +fit_cuda12 = [ + "crispio[fit_base]", "jax[cuda12]", "optax", - "scipy", - "matplotlib", ] -fit_cuda_local = [ +fit_cuda12_local = [ + "crispio[fit_base]", "jax[cuda12_local]", "optax", - "scipy", ] [project.urls]