Skip to content

Commit

Permalink
upgrade code for jax 0.8.5
Browse files Browse the repository at this point in the history
  • Loading branch information
wistuba committed Jun 29, 2024
1 parent 572128b commit 066eb68
Show file tree
Hide file tree
Showing 81 changed files with 398 additions and 438 deletions.
4 changes: 2 additions & 2 deletions examples/bring_in_your_own.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __call__(self, x):
from fortuna.utils.random import generate_rng_like_tree
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map
from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp


Expand All @@ -86,7 +86,7 @@ def log_joint_prob(self, params: Params) -> float:
v = jnp.mean((ravel_pytree(params)[0] <= 1) & (ravel_pytree(params)[0] >= 0))
return jnp.where(v == 1.0, jnp.array(0), -jnp.inf)

def sample(self, params_like: Params, rng: Optional[PRNGKeyArray] = None) -> Params:
def sample(self, params_like: Params, rng: Optional[jax.Array] = None) -> Params:
if rng is None:
rng = self.rng.get()
keys = generate_rng_like_tree(rng, params_like)
Expand Down
6 changes: 3 additions & 3 deletions fortuna/calib_model/calib_model_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)

from flax.core import FrozenDict
from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp
from optax._src.base import PyTree

Expand All @@ -35,7 +35,7 @@ def training_loss_step(
params: Params,
batch: Batch,
mutable: Mutable,
rng: PRNGKeyArray,
rng: jax.Array,
n_data: int,
unravel: Optional[Callable[[any], PyTree]] = None,
calib_params: Optional[CalibParams] = None,
Expand Down Expand Up @@ -71,7 +71,7 @@ def validation_step(
state: CalibState,
batch: Batch,
loss_fun: Callable[[Any], Union[float, Tuple[float, dict]]],
rng: PRNGKeyArray,
rng: jax.Array,
n_data: int,
metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None,
unravel: Optional[Callable[[any], PyTree]] = None,
Expand Down
2 changes: 1 addition & 1 deletion fortuna/calib_model/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _get_model_manager(
)
else:
model_manager = model_manager_cls(model, model_editor)
except ModuleNotFoundError as e:
except ModuleNotFoundError:
logging.warning(
"No module named 'transformer' is installed. "
"If you are not working with models from the `transformers` library ignore this warning, otherwise "
Expand Down
2 changes: 1 addition & 1 deletion fortuna/calib_model/config/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
f"All metrics in `metrics` must be callable objects, but {metric} is not."
)
if uncertainty_fn is not None and not callable(uncertainty_fn):
raise ValueError(f"`uncertainty_fn` must be a a callable function.")
raise ValueError("`uncertainty_fn` must be a a callable function.")

self.metrics = metrics
self.uncertainty_fn = uncertainty_fn
Expand Down
4 changes: 2 additions & 2 deletions fortuna/calib_model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Tuple,
)

from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp

from fortuna.likelihood.base import Likelihood
Expand Down Expand Up @@ -40,7 +40,7 @@ def __call__(
return_aux: Optional[List[str]] = None,
train: bool = False,
outputs: Optional[jnp.ndarray] = None,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
**kwargs,
) -> Tuple[jnp.ndarray, Any]:
if return_aux is None:
Expand Down
6 changes: 3 additions & 3 deletions fortuna/calib_model/predictive/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp

from fortuna.data.loader import (
Expand Down Expand Up @@ -60,7 +60,7 @@ def sample(
self,
inputs_loader: InputsLoader,
n_samples: int = 1,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Expand All @@ -80,7 +80,7 @@ def sample(
A loader of input data points.
n_samples : int
Number of target samples to sample for each input data point.
rng : Optional[PRNGKeyArray]
rng : Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand Down
14 changes: 7 additions & 7 deletions fortuna/calib_model/predictive/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Union,
)

from jax._src.prng import PRNGKeyArray
import jax
import jax.numpy as jnp

from fortuna.calib_model.predictive.base import Predictive
Expand All @@ -21,7 +21,7 @@ def entropy(
self,
inputs_loader: InputsLoader,
n_samples: int = 30,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Expand All @@ -42,7 +42,7 @@ def entropy(
A loader of input data points.
n_samples : int
Number of samples to draw for each input.
rng : Optional[PRNGKeyArray]
rng : Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand All @@ -67,7 +67,7 @@ def quantile(
q: Union[float, Array, List],
inputs_loader: InputsLoader,
n_samples: Optional[int] = 30,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> Union[float, jnp.ndarray]:
r"""
Expand All @@ -81,7 +81,7 @@ def quantile(
A loader of input data points.
n_samples : int
Number of target samples to sample for each input data point.
rng: Optional[PRNGKeyArray]
rng: Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand Down Expand Up @@ -109,7 +109,7 @@ def credible_interval(
n_samples: int = 30,
error: float = 0.05,
interval_type: str = "two-tailed",
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
) -> jnp.ndarray:
r"""
Expand All @@ -126,7 +126,7 @@ def credible_interval(
`error=0.05` corresponds to a 95% level of credibility.
interval_type: str
The interval type. We support "two-tailed" (default), "right-tailed" and "left-tailed".
rng : Optional[PRNGKeyArray]
rng : Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional

from jax import vmap
import jax.numpy as jnp

from fortuna.conformal.multivalid.mixins.multicalibrator import MulticalibratorMixin
Expand Down
2 changes: 0 additions & 2 deletions fortuna/conformal/multivalid/one_shot/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import abc
import logging
from typing import (
Dict,
Optional,
Tuple,
Union,
)

Expand Down
14 changes: 7 additions & 7 deletions fortuna/data/dataset/huggingface_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
Dataset,
DatasetDict,
)
import jax
from jax import numpy as jnp
import jax.random
from jax.random import PRNGKeyArray
from tqdm import tqdm
from transformers import (
BatchEncoding,
Expand Down Expand Up @@ -90,7 +90,7 @@ def get_data_loader(
self,
dataset: Dataset,
per_device_batch_size: int,
rng: PRNGKeyArray,
rng: jax.Array,
shuffle: bool = False,
drop_last: bool = False,
verbose: bool = False,
Expand All @@ -105,7 +105,7 @@ def get_data_loader(
A tokenizeed dataset (see :meth:`.HuggingFaceClassificationDatasetABC.get_tokenized_datasets`).
per_device_batch_size: bool
Batch size for each device.
rng: PRNGKeyArray
rng: jax.Array
Random number generator.
shuffle: bool
if True, shuffle the data so that each batch is a ranom sample from the dataset.
Expand Down Expand Up @@ -141,7 +141,7 @@ def _collate(self, batch: Dict[str, Array], batch_size: int) -> Dict[str, Array]

@staticmethod
def _get_batches_idxs(
rng: PRNGKeyArray,
rng: jax.Array,
dataset_size: int,
batch_size: int,
shuffle: bool = False,
Expand All @@ -167,7 +167,7 @@ def _get_data_loader(
batch_size: int,
shuffle: bool = False,
drop_last: bool = False,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
verbose: bool = False,
) -> Union[Iterable[Dict[str, Array]], Iterable[Tuple[Dict[str, Array], Array]]]:
batch_idxs_gen = self._get_batches_idxs(
Expand Down Expand Up @@ -375,7 +375,7 @@ def __init__(
super(HuggingFaceMaskedLMDataset, self).__init__(*args, **kwargs)
if not self.tokenizer.is_fast:
logger.warning(
f"You are not using a Fast Tokenizer, so whole words cannot be masked, only tokens."
"You are not using a Fast Tokenizer, so whole words cannot be masked, only tokens."
)
self.mlm = mlm
self.mlm_probability = mlm_probability
Expand Down Expand Up @@ -407,7 +407,7 @@ def get_tokenized_datasets(
), "Only one text column should be passed when the task is MaskedLM."

def _tokenize_fn(
batch: Dict[str, List[Union[str, int]]]
batch: Dict[str, List[Union[str, int]]],
) -> Dict[str, List[int]]:
tokenized_inputs = self.tokenizer(
*[batch[col] for col in text_columns],
Expand Down
10 changes: 5 additions & 5 deletions fortuna/distribution/gaussian.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Union

import jax
from jax import (
random,
vmap,
)
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
from jax.scipy.stats import (
multivariate_normal,
Expand All @@ -30,11 +30,11 @@ def __init__(self, mean: Union[float, Array], std: Union[float, Array]):
self.std = std
self.dim = 1 if type(mean) in [int, float] else len(mean)

def sample(self, rng: PRNGKeyArray, n_samples: int = 1) -> jnp.ndarray:
def sample(self, rng: jax.Array, n_samples: int = 1) -> jnp.ndarray:
"""
Sample from the diagonal Gaussian.
:param rng: PRNGKeyArray
:param rng: jax.Array
Random number generator.
:param n_samples: int
Number of samples.
Expand Down Expand Up @@ -72,11 +72,11 @@ def __init__(self, mean: Array, cov: Array):
self.cov = cov
self.dim = len(mean)

def sample(self, rng: PRNGKeyArray, n_samples: int = 1) -> jnp.ndarray:
def sample(self, rng: jax.Array, n_samples: int = 1) -> jnp.ndarray:
"""
Sample from the multivariate Gaussian.
:param rng: PRNGKeyArray
:param rng: jax.Array
Random number generator.
:param n_samples: int
Number of samples.
Expand Down
1 change: 0 additions & 1 deletion fortuna/hallucination/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
List,
Optional,
Tuple,
Union,
)

import numpy as np
Expand Down
12 changes: 6 additions & 6 deletions fortuna/likelihood/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
Union,
)

import jax
from jax import (
jit,
pmap,
)
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp

from fortuna.data.loader import (
Expand Down Expand Up @@ -133,7 +133,7 @@ def _batched_log_joint_prob(
return_aux: Optional[List[str]] = None,
train: bool = False,
outputs: Optional[jnp.ndarray] = None,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
**kwargs,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Any]]:
"""
Expand Down Expand Up @@ -161,7 +161,7 @@ def _batched_log_joint_prob(
Whether the method is called during training.
outputs : Optional[jnp.ndarray]
Pre-computed batch of outputs.
rng: Optional[PRNGKeyArray]
rng: Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
Returns
Expand Down Expand Up @@ -272,7 +272,7 @@ def sample(
calib_params: Optional[CalibParams] = None,
calib_mutable: Optional[CalibMutable] = None,
return_aux: Optional[List[str]] = None,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
**kwargs,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]:
Expand All @@ -296,7 +296,7 @@ def sample(
return_aux : Optional[List[str]]
The auxiliary objects to return. We support 'outputs'. If this argument is not given, no auxiliary object
is returned.
rng: Optional[PRNGKeyArray]
rng: Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand Down Expand Up @@ -345,7 +345,7 @@ def _batched_sample(
calib_params: Optional[CalibParams] = None,
calib_mutable: Optional[CalibMutable] = None,
return_aux: Optional[List[str]] = None,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
**kwargs,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict]]:
if return_aux is None:
Expand Down
8 changes: 4 additions & 4 deletions fortuna/likelihood/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
Union,
)

import jax
from jax import vmap
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -106,7 +106,7 @@ def entropy(
calib_params: Optional[CalibParams] = None,
calib_mutable: Optional[CalibMutable] = None,
n_target_samples: Optional[int] = 30,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
**kwargs,
) -> jnp.ndarray:
Expand Down Expand Up @@ -139,7 +139,7 @@ def quantile(
calib_mutable: Optional[CalibMutable] = None,
n_target_samples: Optional[int] = 30,
target_samples: Optional[jnp.ndarray] = None,
rng: Optional[PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
distribute: bool = True,
**kwargs,
) -> Union[float, jnp.ndarray]:
Expand All @@ -164,7 +164,7 @@ def quantile(
Number of target samples to sample for each input data point.
target_samples: Optional[jnp.ndarray] = None
Samples of the target variable for each input, used to estimate the quantiles.
rng: Optional[PRNGKeyArray]
rng: Optional[jax.Array]
A random number generator. If not passed, this will be taken from the attributes of this class.
distribute: bool
Whether to distribute computation over multiple devices, if available.
Expand Down
Loading

0 comments on commit 066eb68

Please sign in to comment.