From ec1f15153e9391ebf5016f02d10364bcaeb2ea3c Mon Sep 17 00:00:00 2001 From: Jan-Matthis Lueckmann Date: Tue, 8 Oct 2024 08:20:01 -0700 Subject: [PATCH] Add jax/flax utilities. PiperOrigin-RevId: 683630717 --- connectomics/jax/checkpoint.py | 139 ++++++ connectomics/jax/config_util.py | 138 ++++++ connectomics/jax/config_util_test.py | 81 +++ connectomics/jax/grain_util.py | 321 ++++++++++++ connectomics/jax/grain_util_test.py | 97 ++++ connectomics/jax/metrics.py | 468 ++++++++++++++++++ connectomics/jax/metrics_test.py | 184 +++++++ connectomics/jax/models/activation.py | 29 ++ connectomics/jax/models/initializer.py | 73 +++ connectomics/jax/models/normalization.py | 215 ++++++++ connectomics/jax/models/normalization_test.py | 72 +++ 11 files changed, 1817 insertions(+) create mode 100644 connectomics/jax/checkpoint.py create mode 100644 connectomics/jax/config_util.py create mode 100644 connectomics/jax/config_util_test.py create mode 100644 connectomics/jax/grain_util.py create mode 100644 connectomics/jax/grain_util_test.py create mode 100644 connectomics/jax/metrics.py create mode 100644 connectomics/jax/metrics_test.py create mode 100644 connectomics/jax/models/activation.py create mode 100644 connectomics/jax/models/initializer.py create mode 100644 connectomics/jax/models/normalization.py create mode 100644 connectomics/jax/models/normalization_test.py diff --git a/connectomics/jax/checkpoint.py b/connectomics/jax/checkpoint.py new file mode 100644 index 0000000..d3f5383 --- /dev/null +++ b/connectomics/jax/checkpoint.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for model checkpointing.""" + +import dataclasses +import re +from typing import Any, Optional, Sequence, TypeVar + +from clu import checkpoint as checkpoint_lib +from etils import epath +import flax +import grain.python as grain +import grain.tensorflow as tfgrain +from orbax import checkpoint as ocp +import tensorflow as tf + + +T = TypeVar('T') + + +class MixedMultihostCheckpoint(checkpoint_lib.MultihostCheckpoint): + """Like MultihostCheckpoint, but with a single source of FLAX weights. + + TF settings are restored per-host as in the base class. + + This prevents the model from loading potentially inconsistent weights + saved by other hosts. Weights might be inconsistent when they are saved + based on wall-clock time instead of step count. 
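+
+  Concretely, `load_state` rewrites the FLAX checkpoint path so that every
+  host reads weights from the `checkpoints-0` directory, while all other
+  settings are restored per host as in the base class.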
+ """ + + def load_state( + self, state: Optional[T], checkpoint: Optional[str] = None + ) -> T: + flax_path = self._flax_path(self._checkpoint_or_latest(checkpoint)) + flax_path = re.sub('checkpoints-[0-9]*', 'checkpoints-0', flax_path) + if not tf.io.gfile.exists(flax_path): + raise FileNotFoundError(f'Checkpoint {checkpoint} does not exist') + with tf.io.gfile.GFile(flax_path, 'rb') as f: + return flax.serialization.from_bytes(state, f.read()) + + +def get_checkpoint_manager( + workdir: epath.PathLike, + item_names: Sequence[str], +) -> ocp.CheckpointManager: + """Returns a checkpoint manager.""" + checkpoint_dir = epath.Path(workdir) / 'checkpoints' + return ocp.CheckpointManager( + checkpoint_dir, + item_names=item_names, + options=ocp.CheckpointManagerOptions( + create=True, cleanup_tmp_directories=True), + ) + + +def save_checkpoint( + manager: ocp.CheckpointManager, + state: Any, + step: int, + pygrain_checkpointers: Sequence[str] = ('train_iter',), + wait_until_finished: bool = True, +): + """Saves a checkpoint. + + Args: + manager: Checkpoint manager to use. + state: Data to be saved. + step: Step at which to save the data. + pygrain_checkpointers: Names of items for which to use pygrain checkpointer. + wait_until_finished: If True, blocks until checkpoint is written. + """ + save_args_dict = {} + for k, v in state.items(): + if k in pygrain_checkpointers: + save_args_dict[k] = grain.PyGrainCheckpointSave(v) + else: + save_args_dict[k] = ocp.args.StandardSave(v) + manager.save(step, args=ocp.args.Composite(**save_args_dict)) + if wait_until_finished: + manager.wait_until_finished() + + +def restore_checkpoint( + manager: ocp.CheckpointManager, + state: Any, + step: int | None = None, + pygrain_checkpointers: Sequence[str] = ('train_iter',), +) -> Any: + """Restores a checkpoint. + + Args: + manager: Checkpoint manager to use. + state: Data to be restored. + step: Step at which to save the data. If None, uses latest step. + pygrain_checkpointers: Names of items for which to use pygrain checkpointer. + + Returns: + Restored data. + """ + restore_args_dict = {} + for k, v in state.items(): + if k in pygrain_checkpointers: + restore_args_dict[k] = grain.PyGrainCheckpointRestore(v) + else: + restore_args_dict[k] = ocp.args.StandardRestore(v) + return manager.restore( + manager.latest_step() if step is None else step, + args=ocp.args.Composite(**restore_args_dict)) + + +class TfGrainCheckpointHandler(tfgrain.OrbaxCheckpointHandler): + + def save(self, directory: epath.Path, args: 'TfGrainCheckpointArgs') -> None: + return super().save(directory, args.item) + + def restore( + self, directory: epath.Path, args: 'TfGrainCheckpointArgs' + ) -> tfgrain.TfGrainDatasetIterator: + return super().restore(directory, args.item) + + +@ocp.args.register_with_handler( # pytype:disable=wrong-arg-types + TfGrainCheckpointHandler, for_save=True, for_restore=True +) +@dataclasses.dataclass +class TfGrainCheckpointArgs(ocp.args.CheckpointArgs): + item: Any diff --git a/connectomics/jax/config_util.py b/connectomics/jax/config_util.py new file mode 100644 index 0000000..c46a296 --- /dev/null +++ b/connectomics/jax/config_util.py @@ -0,0 +1,138 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper tools for config files. + +While configs should remain simple and self-explanatory, it can also be very +useful to augment the configs with a bit of logic that helps organizing +complicated sweeps. + +This module contains shared code that allows for powerful uncluttered configs. +""" + +from typing import Any, Sequence + +import ml_collections as mlc + + +def parse_arg(arg, lazy=False, **spec): + """Makes ConfigDict's get_config single-string argument more usable. + + Example use in the config file: + + import big_vision.configs.common as bvcc + def get_config(arg): + arg = bvcc.parse_arg(arg, + res=(224, int), + runlocal=False, + schedule='short', + ) + + # ... + + config.shuffle_buffer = 250_000 if not arg.runlocal else 50 + + Ways that values can be passed when launching: + + --config amazing.py:runlocal,schedule=long,res=128 + --config amazing.py:res=128 + --config amazing.py:runlocal # A boolean needs no value for "true". + --config amazing.py:runlocal=False # Explicit false boolean. + --config amazing.py:128 # The first spec entry may be passed unnamed alone. + + Uses strict bool conversion (converting 'True', 'true' to True, and 'False', + 'false', '' to False). + + Args: + arg: the string argument that's passed to get_config. + lazy: allow lazy parsing of arguments, which are not in spec. For these, + the type is auto-extracted in dependence of most complex possible type. + **spec: the name and default values of the expected options. + If the value is a tuple, the value's first element is the default value, + and the second element is a function called to convert the string. + Otherwise the type is automatically extracted from the default value. + + Returns: + ConfigDict object with extracted type-converted values. + """ + # Normalize arg and spec layout. + arg = arg or '' # Normalize None to empty string + spec = {k: (v if isinstance(v, tuple) else (v, _get_type(v))) + for k, v in spec.items()} + + result = mlc.ConfigDict(type_safe=False) # For convenient dot-access only. + + # Expand convenience-cases for a single parameter without = sign. + if arg and ',' not in arg and '=' not in arg: + # (think :runlocal) If it's the name of sth in the spec (or there is no + # spec), it's that in bool. + if arg in spec or not spec: + arg = f'{arg}=True' + # Otherwise, it is the value for the first entry in the spec. + else: + arg = f'{list(spec.keys())[0]}={arg}' + # Yes, we rely on Py3.7 insertion order! + + # Now, expand the `arg` string into a dict of keys and values: + raw_kv = {raw_arg.split('=')[0]: + raw_arg.split('=', 1)[-1] if '=' in raw_arg else 'True' + for raw_arg in arg.split(',') if raw_arg} + + # And go through the spec, using provided or default value for each: + for name, (default, type_fn) in spec.items(): + val = raw_kv.pop(name, None) + result[name] = type_fn(val) if val is not None else default + + if raw_kv: + if lazy: # Process args which are not in spec. 
+ for k, v in raw_kv.items(): + result[k] = _autotype(v) + else: + raise ValueError(f'Unhandled config args remain: {raw_kv}') + + return result + + +def _get_type(v): + """Returns type of v and for boolean returns a strict bool function.""" + if isinstance(v, bool): + def strict_bool(x): + assert x.lower() in {'true', 'false', ''} + return x.lower() == 'true' + return strict_bool + return type(v) + + +def _autotype(x): + """Auto-converts string to bool/int/float if possible.""" + assert isinstance(x, str) + if x.lower() in {'true', 'false'}: + return x.lower() == 'true' # Returns as bool. + try: + return int(x) # Returns as int. + except ValueError: + try: + return float(x) # Returns as float. + except ValueError: + return x # Returns as str. + + +def sequence_to_string(x: Sequence[Any], separator: str = ',') -> str: + """Converts sequence of str/bool/int/float to a concatenated string.""" + return separator.join([str(i) for i in x]) + + +def string_to_sequence(x: str, separator: str = ',') -> Sequence[Any]: + """Converts string to sequence of str/bool/int/float with auto-conversion.""" + return [_autotype(i) for i in x.split(separator)] diff --git a/connectomics/jax/config_util_test.py b/connectomics/jax/config_util_test.py new file mode 100644 index 0000000..092a964 --- /dev/null +++ b/connectomics/jax/config_util_test.py @@ -0,0 +1,81 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for config_util.""" + +from absl.testing import absltest +from absl.testing import parameterized +from connectomics.jax import config_util as cutil + + +class ConfigUtilTest(parameterized.TestCase): + + @parameterized.parameters(False, True) + def test_parse_arg_works(self, lazy): + spec = dict( + res=224, + lr=0.1, + runlocal=False, + schedule='short', + ) + + def check(result, runlocal, schedule, res, lr): + self.assertEqual(result.runlocal, runlocal) + self.assertEqual(result.schedule, schedule) + self.assertEqual(result.res, res) + self.assertEqual(result.lr, lr) + self.assertIsInstance(result.runlocal, bool) + self.assertIsInstance(result.schedule, str) + self.assertIsInstance(result.res, int) + self.assertIsInstance(result.lr, float) + + check(cutil.parse_arg(None, lazy=lazy, **spec), False, 'short', 224, 0.1) + check(cutil.parse_arg('', lazy=lazy, **spec), False, 'short', 224, 0.1) + check(cutil.parse_arg('runlocal=True', lazy=lazy, **spec), True, 'short', + 224, 0.1) + check(cutil.parse_arg('runlocal=False', lazy=lazy, **spec), False, 'short', + 224, 0.1) + check(cutil.parse_arg('runlocal=', lazy=lazy, **spec), False, 'short', 224, + 0.1) + check(cutil.parse_arg('runlocal', lazy=lazy, **spec), True, 'short', 224, + 0.1) + check(cutil.parse_arg('res=128', lazy=lazy, **spec), False, 'short', 128, + 0.1) + check(cutil.parse_arg('128', lazy=lazy, **spec), False, 'short', 128, 0.1) + check(cutil.parse_arg('schedule=long', lazy=lazy, **spec), False, 'long', + 224, 0.1) + check(cutil.parse_arg('runlocal,schedule=long,res=128', lazy=lazy, **spec), + True, 'long', 128, 0.1) + + @parameterized.parameters( + (None, {}, {}), + (None, {'res': 224}, {'res': 224}), + ('640', {'res': 224}, {'res': 640}), + ('runlocal', {}, {'runlocal': True}), + ('res=640,lr=0.1,runlocal=false,schedule=long', {}, + {'res': 640, 'lr': 0.1, 'runlocal': False, 'schedule': 'long'}), + ) + def test_lazy_parse_arg_works(self, arg, spec, expected): + self.assertEqual(dict(cutil.parse_arg(arg, lazy=True, **spec)), expected) + + def test_sequence_to_string(self): + seq = ['a', True, 1, 1.0] + self.assertEqual(cutil.sequence_to_string(seq), 'a,True,1,1.0') + + def test_string_to_sequence(self): + self.assertEqual( + cutil.string_to_sequence('a,True,1,1.0'), ['a', True, 1, 1.0]) + +if __name__ == '__main__': + absltest.main() diff --git a/connectomics/jax/grain_util.py b/connectomics/jax/grain_util.py new file mode 100644 index 0000000..694f97b --- /dev/null +++ b/connectomics/jax/grain_util.py @@ -0,0 +1,321 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilties for grain. + +Code for `all_ops` and `parse` are forked from scenic -- +this implementation uses `grain.python` rather than `grain.tensorflow`. 
+""" + +import ast +import dataclasses +import inspect +import re +import sys +from typing import Any, Optional, Sequence, Type + +import grain.python as grain +import numpy as np + +FlatFeatures = dict[str, Any] + +# Regex that finds upper case characters. +_CAMEL_CASE_RGX = re.compile(r'(? FlatFeatures: + for k in self.keys: + if k not in features: continue + arr = features[k].astype(np.float32) + features[k] = np.clip(arr, self.min_value, self.max_value) + return features + + +@dataclasses.dataclass(frozen=False) +class ExpandDims(grain.MapTransform): + """Expands the shape of an array. + + Attr: + keys: Keys to apply the transformation to. + axis: Position for placement. + """ + + keys: str | Sequence[str] = ('x',) + axis: int | Sequence[int] = 0 + + def __post_init__(self): + self.keys = (self.keys,) if isinstance(self.keys, str) else self.keys + + def map(self, features: FlatFeatures) -> FlatFeatures: + for k in self.keys: + if k not in features: continue + features[k] = np.expand_dims(features[k], axis=self.axis) + return features + + +@dataclasses.dataclass(frozen=False) +class PadValues(grain.MapTransform): + """Pads values. + + Attr: + keys: Keys to apply the transformation to. + pad_width: Padding width. + mode: Padding mode. + """ + + keys: str | Sequence[str] = ('x',) + pad_width: int | Sequence[int] = 0 + mode: str = 'constant' + + def __post_init__(self): + self.keys = (self.keys,) if isinstance(self.keys, str) else self.keys + + def map(self, features: FlatFeatures) -> FlatFeatures: + for k in self.keys: + if k not in features: continue + features[k] = np.pad( + features[k], pad_width=self.pad_width, mode=self.mode) + return features + + +@dataclasses.dataclass(frozen=False) +class RescaleValues(grain.MapTransform): + """Rescales values from `min/max_input` to `min/max_output`. + + Attr: + keys: Keys to apply the transformation to. + min_input: The minimum value of the input. + max_input: The maximum value of the input. + min_output: The minimum value of the output. + max_output: The maximum value of the output. + """ + + keys: str | Sequence[str] = ('x',) + min_output: float = 0.0 + max_output: float = 1.0 + min_input: float = 0.0 + max_input: float = 255.0 + + def __post_init__(self): + assert self.min_output < self.max_output + assert self.min_input < self.max_input + self.keys = (self.keys,) if isinstance(self.keys, str) else self.keys + + def map(self, features: FlatFeatures) -> FlatFeatures: + for k in self.keys: + if k not in features: continue + arr = features[k].astype(np.float32) + arr = (arr - self.min_input) / (self.max_input - self.min_input) + arr = self.min_output + arr * (self.max_output - self.min_output) + features[k] = arr + return features + + +@dataclasses.dataclass(frozen=False) +class ReshapeValues(grain.MapTransform): + """Reshapes values. + + Attr: + keys: Keys to apply the transformation to. + newshape: New shape. + """ + + keys: str | Sequence[str] = ('x',) + newshape: int | Sequence[int] = -1 + + def __post_init__(self): + self.keys = (self.keys,) if isinstance(self.keys, str) else self.keys + + def map(self, features: FlatFeatures) -> FlatFeatures: + for k in self.keys: + if k not in features: continue + features[k] = features[k].reshape(self.newshape) + return features + + +@dataclasses.dataclass(frozen=False) +class ShiftAndDivideValues(grain.MapTransform): + """Subtracts shift from values and divides by scaling factor. + + Attr: + keys: Keys to apply the transformation to. + shift: Shift to subtract. + divisor: Scale factor to divide by. 
+ """ + + keys: str | Sequence[str] = ('x',) + shift: float = 0.0 + divisor: float = 1.0 + + def __post_init__(self): + assert self.divisor != 0.0 + self.keys = (self.keys,) if isinstance(self.keys, str) else self.keys + + def map(self, features: FlatFeatures) -> FlatFeatures: + for k in self.keys: + if k not in features: continue + arr = features[k].astype(np.float32) + arr = (arr - self.shift) / self.divisor + features[k] = arr + return features + + +@dataclasses.dataclass(frozen=False) +class TransposeValues(grain.MapTransform): + """Transposes values. + + Attr: + keys: Keys to apply the transformation to. + axis: If specified, it must be a tuple or list which contains a permutation. + If not specified, which reverses order of axes. + """ + + keys: str | Sequence[str] = ('x',) + axis: Optional[Sequence[int]] = None + + def __post_init__(self): + self.keys = (self.keys,) if isinstance(self.keys, str) else self.keys + + def map(self, features: FlatFeatures) -> FlatFeatures: + for k in self.keys: + if k not in features: continue + features[k] = features[k].transpose(self.axis) + return features + + +def get_all_ops(module_name: str) -> list[ + tuple[str, Type[grain.Transformation]]]: + """Helper to return all preprocess ops in a module. + + Modules that define processing ops can simply define: + all_ops = lambda: process_spec.get_all_ops(__name__) + all_ops() will then return a list with all dataclasses being + grain.Transformation. + + Args: + module_name: Name of the module. The module must already be imported. + + Returns: + List of tuples of process ops. The first tuple element is the class name + converted to snake case (MyAwesomeTransform => my_awesome_transform) and + the second element is the class. + """ + transforms = [grain.MapTransform, grain.RandomMapTransform, + grain.FilterTransform] + + def is_op(x) -> bool: + return (inspect.isclass(x) + and dataclasses.is_dataclass(x) + and any(issubclass(x, t) for t in transforms)) + + op_name = lambda n: _CAMEL_CASE_RGX.sub('_', n).lower() + members = inspect.getmembers(sys.modules[module_name]) + return [(op_name(name), op) for name, op in members if is_op(op)] + + +def _get_op_class( + expr: list[ast.stmt], + available_ops: dict[str, type[grain.Transformation]] + ) -> Type[grain.Transformation]: + """Gets the process op fn from the given expression.""" + if isinstance(expr, ast.Call): + fn_name = expr.func.id + elif isinstance(expr, ast.Name): + fn_name = expr.id + else: + raise ValueError( + f'Could not parse function name from expression: {expr!r}.') + if fn_name in available_ops: + return available_ops[fn_name] + raise ValueError( + f'"{fn_name}" is not available (available ops: {list(available_ops)}).') + + +def _parse_single_preprocess_op( + spec: str, + available_ops: dict[str, Type[grain.Transformation]] + ) -> grain.Transformation: + """Parsing the spec for a single preprocess op. + + The op can just be the method name or the method name followed by any + arguments (both positional and keyword) to the method. + See the test cases for some valid examples. + + Args: + spec: String specifying a single processing operations. + available_ops: Available preprocessing ops. + + Returns: + The Transformation corresponding to the spec. 
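+
+  For example (illustrative), 'clip_values', 'clip_values()' and
+  'shift_and_divide_values(divisor=2.)' are all valid specs; positional
+  arguments are matched to the dataclass fields of the op in order.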
+ """ + try: + expr = ast.parse(spec, mode='eval').body # pytype: disable=attribute-error + except SyntaxError as e: + raise ValueError(f'{spec!r} is not a valid preprocess op spec.') from e + op_class = _get_op_class(expr, available_ops) # pytype: disable=wrong-arg-types + + # Simple case without arguments. + if isinstance(expr, ast.Name): + return op_class() + + assert isinstance(expr, ast.Call) + args = [ast.literal_eval(arg) for arg in expr.args] + kwargs = {kv.arg: ast.literal_eval(kv.value) for kv in expr.keywords} + if not args: + return op_class(**kwargs) + + # Translate positional arguments into keyword arguments. + available_arg_names = [f.name for f in dataclasses.fields(op_class)] + for i, arg in enumerate(args): + name = available_arg_names[i] + if name in kwargs: + raise ValueError( + f'Argument {name} to {op_class} given both as positional argument ' + f'(value: {arg}) and keyword argument (value: {kwargs[name]}).') + kwargs[name] = arg + + return op_class(**kwargs) + + +def parse(spec: str, available_ops: list[tuple[str, Any]] + ) -> grain.Transformations: + """Parses a preprocess spec; a '|' separated list of preprocess ops.""" + available_ops = dict(available_ops) + if not spec.strip(): + transformations = [] + else: + transformations = [ + _parse_single_preprocess_op(s, available_ops) + for s in spec.split('|') + ] + + return transformations diff --git a/connectomics/jax/grain_util_test.py b/connectomics/jax/grain_util_test.py new file mode 100644 index 0000000..de0e7f2 --- /dev/null +++ b/connectomics/jax/grain_util_test.py @@ -0,0 +1,97 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +from connectomics.jax import grain_util +import numpy as np + + +class GrainUtilTest(absltest.TestCase): + + def test_clip_values(self): + array = np.array([[0, 0, 0], [255, 255, 255]]) + expected = np.array([[0, 0, 0], [1, 1, 1]], dtype=np.float32) + np.testing.assert_allclose( + grain_util.ClipValues().map({'x': array})['x'], expected + ) + + def test_expand_dims(self): + array = np.array([[0, 0, 0], [255, 255, 255]]) + expected = array[np.newaxis, ...] 
+ np.testing.assert_allclose( + grain_util.ExpandDims(axis=0).map({'x': array})['x'], expected + ) + + def test_pad_values(self): + array = np.array([[0, 0, 0], [255, 255, 255]]) + expected = np.array([[0, 0, 0, 0, 0], [0, 255, 255, 255, 0]]) + np.testing.assert_allclose( + grain_util.PadValues( + pad_width=((0, 0), (1, 1))).map({'x': array})['x'], expected + ) + + def test_rescale_values(self): + array = np.array([[0, 0, 0], [255, 255, 255]]) + expected = np.array([[0, 0, 0], [1, 1, 1]], dtype=np.float32) + np.testing.assert_allclose( + grain_util.RescaleValues().map({'x': array})['x'], expected + ) + + def test_reshape_values(self): + array = np.zeros((2, 4, 8)) + shape = (-1, 8) + np.testing.assert_allclose( + grain_util.ReshapeValues(newshape=shape).map({'x': array})['x'], + array.reshape(shape) + ) + + def test_shift_and_divide_values(self): + array = np.array([[0, 0, 0], [255, 255, 255]]) + expected = np.array([[0, 0, 0], [1, 1, 1]], dtype=np.float32) + np.testing.assert_allclose( + grain_util.ShiftAndDivideValues( + divisor=255.0).map({'x': array})['x'], expected + ) + + def test_transpose_values(self): + array = np.zeros((2, 4, 8)) + axis = (2, 1, 0) + np.testing.assert_allclose( + grain_util.TransposeValues(axis=axis).map({'x': array})['x'], + array.transpose(axis) + ) + + def test_all_ops(self): + all_ops = sum(map(grain_util.get_all_ops, + ['connectomics.jax.grain_util']), []) + assert len(all_ops) >= 3 + + def test_parse(self): + array = np.array([[0, 0, 0], [255, 255, 255]]) + expected = np.array([[0, 0, 0], [0.5, 0.5, 0.5]], dtype=np.float32) + + all_ops = sum(map(grain_util.get_all_ops, + ['connectomics.jax.grain_util']), []) + transformations = grain_util.parse( + 'clip_values()|shift_and_divide_values(divisor=2.)', all_ops) + assert len(transformations) == 2 + + res = {'x': np.copy(array)} + for op in transformations: + res = op.map(res) + np.testing.assert_allclose(res['x'], expected) + + +if __name__ == '__main__': + absltest.main() diff --git a/connectomics/jax/metrics.py b/connectomics/jax/metrics.py new file mode 100644 index 0000000..1413f55 --- /dev/null +++ b/connectomics/jax/metrics.py @@ -0,0 +1,468 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Metrics. + +All metrics assume a leading batch dimension that is preserved and assume +inputs `predictions`, `targets`. All functions are compatible with clu.metrics +constructed from functions. Constructed relative metrics additionally require +the keyword argument `baseline`. 
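+
+Example (illustrative), wrapping these functions into `clu.metrics` (imported
+below as `metrics`):
+
+  collection = get_metrics_collection_from_dict({
+      'mse': metrics.Average.from_fun(mse),
+      'mse_rel': metrics.Average.from_fun(make_relative_metric(mse)),
+  })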
+""" + +from collections.abc import Callable +from typing import Any, Sequence + +from clu import metric_writers +from clu import metrics +import flax +import jax +import jax.numpy as jnp +import numpy as np +import scipy.special +import sklearn.metrics + +Array = metric_writers.interface.Array +Scalar = metric_writers.interface.Scalar + + +def get_metrics_collection_from_dict( + metrics_dict: dict[str, Any], prefix: str = '' +) -> type[metrics.Collection]: + """Gets metrics collection from dict with optional prefix.""" + return metrics.Collection.create( + **{f'{prefix}{k}': v for k, v in metrics_dict.items()} + ) + + +def make_dict_of_scalars( + metrics_dict: dict[str, Scalar | Array], + prefix_keys: str = '', + prefix_vector: str = '/', +) -> dict[str, Scalar]: + """Converts vectors to scalars in metrics dict.""" + metrics_dict_compat = dict() + for k, v in metrics_dict.items(): + if isinstance(v, int) or isinstance(v, float): + metrics_dict_compat[f'{prefix_keys}{k}'] = v + elif isinstance(v, np.ndarray) or isinstance(v, jnp.ndarray): + if v.ndim == 0: + metrics_dict_compat[f'{prefix_keys}{k}'] = v + elif v.ndim == 1: + for i, v_el in enumerate(v): + metrics_dict_compat[f'{prefix_keys}{k}{prefix_vector}{i+1}'] = v_el + else: + raise ValueError('Only scalars or vectors are allowed.') + else: + raise ValueError('Unsupported type.') + return metrics_dict_compat + + +def make_relative_metric( + metric: Callable[..., jnp.ndarray], +) -> Callable[..., jnp.ndarray]: + """Construct relative metric to a baseline given base metric.""" + + def _relative_metric( + predictions: jnp.ndarray, + targets: jnp.ndarray, + baseline: jnp.ndarray, + **kwargs, + ) -> jnp.ndarray: + metric_model = metric(predictions=predictions, targets=targets, **kwargs) + metric_baseline = metric(predictions=baseline, targets=targets, **kwargs) + return metric_model / metric_baseline + + return _relative_metric + + +def make_per_step_metric( + metric: Callable[..., jnp.ndarray], +) -> Callable[..., jnp.ndarray]: + """Construct per-step metric.""" + + def _per_step_metric( + predictions: jnp.ndarray, targets: jnp.ndarray, **kwargs + ) -> jnp.ndarray: + assert predictions.shape == targets.shape + assert len(targets.shape) >= 2 + kwargs['video'] = False # Only needed for video_forecasting.metrics.ssim + batch, timesteps = targets.shape[:2] + predictions = predictions.reshape(batch * timesteps, *targets.shape[2:]) + targets = targets.reshape(batch * timesteps, *targets.shape[2:]) + score = metric(predictions=predictions, targets=targets, **kwargs) + return score.reshape(batch, timesteps) + + return _per_step_metric + + +@flax.struct.dataclass +class PerStepAverage(metrics.Metric): + """Average metric with additional kept leading dimension (e.g. steps). + + Assumes inputs of shape of shape (batch, steps) and averages to (steps,). 
+ """ + + total: jnp.ndarray + count: jnp.ndarray + + @classmethod + def empty(cls) -> Any: + return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32)) + + @classmethod + def from_model_output(cls, values: jnp.ndarray, mask: Any = None, **_) -> Any: + assert values.ndim >= 2, 'Vector Average requires per sample steps' + assert mask is None, 'Mask not supported' + batch, timesteps = values.shape[:2] + total = values.reshape(batch, timesteps, -1).sum(axis=(0, 2)) + return cls(total=total, count=batch) # pytype: disable=wrong-arg-types # jnp-array + + def merge(self, other: Any) -> Any: + return type(self)( + total=self.total + other.total, + count=self.count + other.count, + ) + + def compute(self) -> Any: + return self.total / self.count + + +def make_metric_with_threshold( + metric: Callable[..., jnp.ndarray], threshold: float +) -> Callable[..., jnp.ndarray]: + """Construct metric that is applied after thresholding to boolean array.""" + + def _metric_with_threshold( + predictions: jnp.ndarray, targets: jnp.ndarray, **kwargs + ) -> jnp.ndarray: + return metric( + predictions=predictions > threshold, + targets=targets > threshold, + **kwargs, + ) + + return _metric_with_threshold + + +def mse(predictions: jnp.ndarray, targets: jnp.ndarray, **_) -> jnp.ndarray: + """Compute mean squared error per example.""" + assert predictions.shape == targets.shape + axes = tuple(range(1, targets.ndim)) + return jnp.mean(jnp.square(targets - predictions), axis=axes) + + +def mae(predictions: jnp.ndarray, targets: jnp.ndarray, **_) -> jnp.ndarray: + """Compute mean absolute error per example.""" + assert predictions.shape == targets.shape + axes = tuple(range(1, targets.ndim)) + return jnp.mean(jnp.abs(targets - predictions), axis=axes) + + +def mape(predictions: jnp.ndarray, targets: jnp.ndarray, **_) -> jnp.ndarray: + """Compute mean absolute percentage error per example.""" + assert predictions.shape == targets.shape + eps = jnp.finfo(targets.dtype).eps + axes = tuple(range(1, targets.ndim)) + return jnp.mean( + jnp.abs(predictions - targets) / jnp.maximum(jnp.abs(targets), eps), + axis=axes, + ) + + +@jax.jit +def _confusion_matrix_bool_1d( + y_true: jnp.ndarray, y_pred: jnp.ndarray, **_ +) -> jnp.ndarray: + """Calculates confusion matrix for boolean 1d-arrays.""" + return jnp.bincount(2 * y_true + y_pred, minlength=4, length=4).reshape(2, 2) + + +def confusion_matrix_bool( + predictions: jnp.ndarray, targets: jnp.ndarray, **_ +) -> jnp.ndarray: + """Calculates confusion matrix for boolean arrays. + + Args: + predictions: Array of boolean predictions with leading batch dimension. + targets: Array of boolean targets with leading batch dimension. + + Returns: + Confusion matrix, with True values as positive class, laid out as follows: + tp fp + fn tn, + where tp = true pos., fp = false pos., fn = false neg., and tn = true neg. + """ + assert predictions.dtype == targets.dtype == bool + assert predictions.shape == targets.shape + shape = (targets.shape[0], -1) + + predictions = predictions.reshape(*shape) + targets = targets.reshape(*shape) + + cm_batched = jax.vmap(_confusion_matrix_bool_1d, 0, 0) + return cm_batched(~targets, ~predictions).transpose((0, 2, 1)) + + +def confusion_matrix_sklearn( + predictions: jnp.ndarray, targets: jnp.ndarray, **kwargs +) -> jnp.ndarray: + """Calculates confusion matrix with sklearn. + + In the case of boolean arrays, the implementation in `confusion_matrix_bool` + can be significantly faster, see also [1]. 
+ + To match the return format of `confusion_matrix_bool` for boolean arrays, use: + confusion_matrix_sklearn( + ~predictions, ~targets, labels=[False, True]).transpose((0, 2, 1)) + + Args: + predictions: Array of boolean predictions with leading batch dimension. + targets: Array of boolean targets with leading batch dimension. + **kwargs: Passed to sklearn.metrics.confusion_matrix. + + Returns: + Confusion matrix with leading batch dimension whose i-th row and j-th + column entry indicates the number of samples with true label being i-th + class and predicted label being j-th class. + + References: + [1]: https://github.com/scikit-learn/scikit-learn/issues/15388 + """ + assert predictions.shape == targets.shape + shape = (targets.shape[0], -1) + + predictions = predictions.reshape(*shape) + targets = targets.reshape(*shape) + + res = [] + for batch in range(shape[0]): + res.append( + sklearn.metrics.confusion_matrix( + y_true=targets[batch, :], y_pred=predictions[batch, :], **kwargs + ) + ) + return jnp.array(res) + + +def precision_bool( + predictions: jnp.ndarray, + targets: jnp.ndarray, + zero_division: float = jnp.nan, + **_, +) -> jnp.ndarray: + """Compute precision for boolean arrays.""" + assert predictions.dtype == targets.dtype == bool + assert predictions.shape == targets.shape + cm = confusion_matrix_bool(predictions=predictions, targets=targets) + # precision: tp / (tp + fp) + numerator = cm[:, 0, 0] + denominator = cm[:, 0, 0] + cm[:, 0, 1] + return jnp.where(denominator > 0, numerator / denominator, zero_division) + + +def precision_sklearn( + predictions: jnp.ndarray, targets: jnp.ndarray, **kwargs +) -> jnp.ndarray: + """Compute precision with sklearn.""" + assert predictions.shape == targets.shape + shape = (targets.shape[0], -1) + return jnp.array([ + sklearn.metrics.precision_score( + y_true=targets.reshape(*shape)[b, :], + y_pred=predictions.reshape(*shape)[b, :], + **kwargs, + ) + for b in range(shape[0]) + ]) + + +def recall_bool( + predictions: jnp.ndarray, + targets: jnp.ndarray, + zero_division: float = jnp.nan, + **_, +) -> jnp.ndarray: + """Compute recall for boolean arrays.""" + assert predictions.dtype == targets.dtype == bool + assert predictions.shape == targets.shape + cm = confusion_matrix_bool(predictions=predictions, targets=targets) + # recall: tp / (tp + fn) + numerator = cm[:, 0, 0] + denominator = cm[:, 0, 0] + cm[:, 1, 0] + return jnp.where(denominator > 0, numerator / denominator, zero_division) + + +def recall_sklearn( + predictions: jnp.ndarray, targets: jnp.ndarray, **kwargs +) -> jnp.ndarray: + """Compute recall with sklearn.""" + assert predictions.shape == targets.shape + shape = (targets.shape[0], -1) + return jnp.array([ + sklearn.metrics.recall_score( + y_true=targets.reshape(*shape)[b, :], + y_pred=predictions.reshape(*shape)[b, :], + **kwargs, + ) + for b in range(shape[0]) + ]) + + +def precision_recall_f1_bool( + predictions: jnp.ndarray, + targets: jnp.ndarray, + zero_division: float = jnp.nan, + **_, +) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Compute precision, recall, f1 score for boolean arrays.""" + assert predictions.dtype == targets.dtype == bool + assert predictions.shape == targets.shape + cm = confusion_matrix_bool(predictions=predictions, targets=targets) + + # precision: tp / (tp + fp) + numerator = cm[:, 0, 0] + denominator = cm[:, 0, 0] + cm[:, 0, 1] + p = jnp.where(denominator > 0, numerator / denominator, zero_division) + + # recall: tp / (tp + fn) + denominator = cm[:, 0, 0] + cm[:, 1, 0] + r = 
jnp.where(denominator > 0, numerator / denominator, zero_division) + + # f1: 2 * (precision * recall) / (precision + recall) + numerator = 2 * p * r + denominator = p + r + f1 = jnp.where(denominator > 0, numerator / denominator, zero_division) + + return p, r, f1 + + +def f1_bool( + predictions: jnp.ndarray, + targets: jnp.ndarray, + zero_division: float = jnp.nan, + **_, +) -> jnp.ndarray: + """Compute f1 score for boolean arrays.""" + _, _, f1 = precision_recall_f1_bool(predictions, targets, zero_division) + return f1 + + +def f1_sklearn( + predictions: jnp.ndarray, targets: jnp.ndarray, **kwargs +) -> jnp.ndarray: + """Compute f1 score with sklearn.""" + assert predictions.shape == targets.shape + shape = (targets.shape[0], -1) + return jnp.array([ + sklearn.metrics.f1_score( + y_true=targets.reshape(*shape)[b, :], + y_pred=predictions.reshape(*shape)[b, :], + **kwargs, + ) + for b in range(shape[0]) + ]) + + +def create_vpt_metric(metric_fn: Any, threshold: float) -> type[metrics.Metric]: + """Creates metric to compute valid prediction time (VPT). + + Assumes inputs of shape (batch, steps) and returns VPT as an integer, + computing argmin over steps for `metric_fn(predictions, targets) > threshold`. + + Args: + metric_fn: Metric function. + threshold: Threshold. + + Returns: + VPT metric. + """ + + @flax.struct.dataclass + class _ValidPredictionTime(PerStepAverage): + """Valid Prediction Time metric.""" + + def compute(self) -> Any: + return jnp.min(jnp.argwhere(self.total / self.count > threshold)) + + return _ValidPredictionTime.from_fun(make_per_step_metric(metric_fn)) + + +def create_classification_metrics( + class_names: Sequence[str], +) -> type[metrics.CollectingMetric]: + """Creates classification metrics for N classes.""" + + @flax.struct.dataclass + class ClassificationMetrics( + metrics.CollectingMetric.from_outputs(('labels', 'logits')) + ): + """Computes precision, recall, F1, auc_pr and roc_auc per class. + + Data (labels, logits) is collected on the host, and summarized into + metrics when compute() is called. + """ + + classes: Sequence[str] = class_names + + def compute(self) -> dict[str, float]: + """Computes the metrics.""" + values = super().compute() + labels = np.array(values['labels']) + logits = np.array(values['logits']) + + labels = labels.ravel() + logits = logits.reshape([-1, logits.shape[-1]]) + + if logits.shape[-1] != len(self.classes): + raise ValueError( + f'Number of classes {len(self.classes)} does not match logits' + f' dimension {logits.shape[-1]}.' 
+ ) + + with jax.default_device(jax.local_devices(backend='cpu')[0]): + labels_1hot = np.asarray( + jax.nn.one_hot(labels, num_classes=len(self.classes)) + ) + + prob = scipy.special.softmax(logits, axis=-1) + pred = prob.argmax(axis=-1) + + if prob.shape[-1] == 2: + roc_prob = prob[:, 0] + else: + roc_prob = prob + + precision, recall, f1, _ = ( + sklearn.metrics.precision_recall_fscore_support(labels, pred) + ) + roc_auc = sklearn.metrics.roc_auc_score( + labels, roc_prob, multi_class='ovr' + ) + auc_pr = sklearn.metrics.average_precision_score( + labels_1hot, prob, average=None + ) + + ret = { + 'roc_auc': roc_auc, + } + + for i, name in enumerate(self.classes): + ret[f'precision__{name}'] = precision[i] + ret[f'recall__{name}'] = recall[i] + ret[f'f1__{name}'] = f1[i] + ret[f'auc_pr__{name}'] = auc_pr[i] + + return ret + + return ClassificationMetrics diff --git a/connectomics/jax/metrics_test.py b/connectomics/jax/metrics_test.py new file mode 100644 index 0000000..420babc --- /dev/null +++ b/connectomics/jax/metrics_test.py @@ -0,0 +1,184 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for metrics. + +Depending on the metric, tested against dm_pix, sklearn, or manual results. +dm_pix is limited to [batch, x, y, channel]. 
+""" + +from absl.testing import absltest +from connectomics.jax import metrics +import dm_pix +import jax.numpy as jnp +import numpy as np +import scipy.special +import sklearn.metrics + + +class MetricsTest(absltest.TestCase): + + def setUp(self): + super().setUp() + rng = np.random.default_rng(42) + + # (batch, z, y, x, channel) + self.vol1 = rng.uniform(0, 1, (4, 28, 32, 36, 2)) + self.vol2 = rng.uniform(0, 1, (4, 28, 32, 36, 2)) + + # (batch, x) + self.boolean_true = np.array([[0, 1, 1, 0, 0, 0, 1, 1, 1, 1]]).astype(bool) + self.boolean_pred = np.array([[0, 0, 1, 0, 0, 1, 1, 1, 1, 0]]).astype(bool) + tp, fp, fn, tn = 4, 1, 2, 3 + self.boolean_confusion_matrix = [[tp, fp], [fn, tn]] + self.boolean_precision = tp / (tp + fp) + self.boolean_recall = tp / (tp + fn) + self.boolean_f1 = ( + 2 + * (self.boolean_precision * self.boolean_recall) + / (self.boolean_precision + self.boolean_recall) + ) + + def test_mae_integration_against_pix_2d(self): + mae_pix = dm_pix.mae(self.vol1[:, 0], self.vol2[:, 0]) + mae_neuro = metrics.mae(self.vol1[:, 0], self.vol2[:, 0]) + np.testing.assert_allclose(mae_pix, mae_neuro, atol=1e-6, rtol=1e-6) + + def test_mse_integration_against_pix_2d(self): + mse_pix = dm_pix.mse(self.vol1[:, 0], self.vol2[:, 0]) + mse_neuro = metrics.mse(self.vol1[:, 0], self.vol2[:, 0]) + np.testing.assert_allclose(mse_pix, mse_neuro, atol=1e-6, rtol=1e-6) + + def test_mape_integration_against_sklearn(self): + mape_sklearn = sklearn.metrics.mean_absolute_percentage_error( + y_pred=self.vol1.reshape(-1, 1), y_true=self.vol2.reshape(-1, 1) + ) + mape_neuro = metrics.mape( + self.vol1.reshape(1, -1), self.vol2.reshape(1, -1) + ) + np.testing.assert_allclose(mape_sklearn, mape_neuro, atol=1e-6, rtol=1e-6) + + def test_confusion_matrix_bool_against_manual_result(self): + cm = metrics.confusion_matrix_bool(self.boolean_pred, self.boolean_true) + np.testing.assert_array_equal( + jnp.array([self.boolean_confusion_matrix]), cm + ) + + def test_confusion_matrix_sklearn_against_manual_result(self): + cm = metrics.confusion_matrix_sklearn( + ~self.boolean_pred, ~self.boolean_true, labels=[False, True] + ).transpose((0, 2, 1)) + np.testing.assert_array_equal( + jnp.array([self.boolean_confusion_matrix]), cm + ) + + def test_precision_bool_against_manual_result(self): + precision_neuro = metrics.precision_bool( + self.boolean_pred, self.boolean_true + ) + np.testing.assert_array_equal( + jnp.array([self.boolean_precision]), precision_neuro + ) + + def test_precision_bool_against_sklearn(self): + precision_ref = metrics.make_metric_with_threshold( + metrics.precision_sklearn, 0.5 + )(self.vol1, self.vol2, zero_division=0.0) + precision_neuro = metrics.make_metric_with_threshold( + metrics.precision_bool, 0.5 + )(self.vol1, self.vol2, zero_division=0.0) + np.testing.assert_allclose( + precision_ref, precision_neuro, atol=1e-6, rtol=1e-6 + ) + + def test_recall_bool_against_manual_result(self): + recall_neuro = metrics.recall_bool(self.boolean_pred, self.boolean_true) + np.testing.assert_array_equal( + jnp.array([self.boolean_recall]), recall_neuro + ) + + def test_recall_bool_against_sklearn(self): + recall_ref = metrics.make_metric_with_threshold( + metrics.recall_sklearn, 0.5 + )(self.vol1, self.vol2, zero_division=0.0) + recall_neuro = metrics.make_metric_with_threshold(metrics.recall_bool, 0.5)( + self.vol1, self.vol2, zero_division=0.0 + ) + np.testing.assert_allclose(recall_ref, recall_neuro, atol=1e-6, rtol=1e-6) + + def test_f1_bool_against_manual_result(self): + f1_neuro = 
metrics.f1_bool(self.boolean_pred, self.boolean_true) + np.testing.assert_array_equal(jnp.array([self.boolean_f1]), f1_neuro) + + def test_f1_bool_against_sklearn(self): + f1_ref = metrics.make_metric_with_threshold(metrics.f1_sklearn, 0.5)( + self.vol1, self.vol2, zero_division=0.0 + ) + f1_neuro = metrics.make_metric_with_threshold(metrics.f1_bool, 0.5)( + self.vol1, self.vol2, zero_division=0.0 + ) + np.testing.assert_allclose(f1_ref, f1_neuro, atol=1e-6, rtol=1e-6) + + def test_valid_prediction_time(self): + targets = np.array([[1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0]]) + predictions = np.array( + [[1.0, 1.0, 1.0, 10.0, 10.0], [1.0, 1.0, 1.0, 10.0, 10.0]] + ) + vpt = metrics.create_vpt_metric(metric_fn=metrics.mse, threshold=0.5) + metric = vpt.from_model_output(predictions=predictions, targets=targets) + np.testing.assert_equal(metric.compute(), np.array(3, dtype='int')) + + def test_classification_metrics_binary(self): + cls = metrics.create_classification_metrics(('neuron', 'glia')) + m = cls.from_model_output( + logits=np.array([[-1, 1], [-1, 1], [1, -1]]), labels=np.array([0, 1, 1]) + ) + actual = m.compute() + + self.assertEqual(actual['precision__glia'], 0.5) + self.assertEqual(actual['precision__neuron'], 0) + self.assertEqual(actual['recall__neuron'], 0) + self.assertEqual(actual['recall__glia'], 0.5) + self.assertEqual(actual['f1__neuron'], 0) + self.assertEqual(actual['f1__glia'], 0.5) + + def test_classification_metrics_multiclass(self): + cls = metrics.create_classification_metrics(('axon', 'dend', 'glia')) + l1, l2, l3 = ( + scipy.special.logit(0.1), + scipy.special.logit(0.2), + scipy.special.logit(0.7), + ) + + m = cls.from_model_output( + logits=np.array([[l1, l2, l3], [l2, l3, l1], [l2, l1, l3]]), # gdg + labels=np.array([0, 1, 2]), # adg + ) + actual = m.compute() + + self.assertEqual(actual['precision__axon'], 0) + self.assertEqual(actual['recall__axon'], 0) + self.assertEqual(actual['f1__axon'], 0) + + self.assertEqual(actual['precision__glia'], 0.5) + self.assertEqual(actual['recall__glia'], 1.) + self.assertEqual(actual['f1__glia'], 2/3.) + + self.assertEqual(actual['precision__dend'], 1) + self.assertEqual(actual['recall__dend'], 1) + self.assertEqual(actual['f1__dend'], 1) + + +if __name__ == '__main__': + absltest.main() diff --git a/connectomics/jax/models/activation.py b/connectomics/jax/models/activation.py new file mode 100644 index 0000000..de3ec78 --- /dev/null +++ b/connectomics/jax/models/activation.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Activation functions.""" + +from typing import Any + +import flax.linen as nn + + +def activation_fn_from_str(input_string: str) -> Any: + """Gets activation function from string.""" + if input_string == 'linear': + return lambda x: x + elif hasattr(nn, input_string): + return getattr(nn, input_string) + else: + raise ValueError('activation function not found as part of flax.linen.') diff --git a/connectomics/jax/models/initializer.py b/connectomics/jax/models/initializer.py new file mode 100644 index 0000000..125666f --- /dev/null +++ b/connectomics/jax/models/initializer.py @@ -0,0 +1,73 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Initializer functions.""" + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np + + +def constant_init(dim, dtype=jnp.float_): + """Initializes weights to `1 / shape[dim]`.""" + + def init(unused_key, shape, dtype=dtype): + dtype = jax.dtypes.canonicalize_dtype(dtype) + return 1. / shape[dim] * jnp.full(shape, 1., dtype=dtype) + + return init + + +def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0): + """1D Sinusoidal Position Embedding Initializer. + + Args: + max_len: maximum possible length for the input. + min_scale: float: minimum frequency-scale in sine grating. + max_scale: float: maximum frequency-scale in sine grating. + + Returns: + output: init function returning `(1, max_len, d_feature)` + """ + + def init(key, shape, dtype=np.float32): + """Sinusoidal init.""" + del key, dtype + d_feature = shape[-1] + pe = np.zeros((max_len, d_feature), dtype=np.float32) + position = np.arange(0, max_len)[:, np.newaxis] + scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) + div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) + pe[:, : d_feature // 2] = np.sin(position * div_term) + pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) + pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] + return jnp.array(pe) + + return init + + +def init_fn_from_str(input_string: str): + """Gets init function from string.""" + if input_string == 'constant': + return constant_init + elif input_string.startswith('normal('): + std = input_string.replace('normal(', '').replace(')', '') + return nn.initializers.normal(float(std)) + elif input_string == 'sinusoidal': + return sinusoidal_init + elif hasattr(nn, input_string): + return getattr(nn, input_string) + else: + raise ValueError('init function not found as part of flax.linen.') diff --git a/connectomics/jax/models/normalization.py b/connectomics/jax/models/normalization.py new file mode 100644 index 0000000..62872d1 --- /dev/null +++ b/connectomics/jax/models/normalization.py @@ -0,0 +1,215 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Normalization modules for Flax.""" + +import functools +from typing import Any, Callable, Optional + +import flax.linen as nn +from flax.linen.dtypes import canonicalize_dtype # pylint: disable=g-importing-member +from flax.linen.module import Module, compact # pylint: disable=g-importing-member,g-multiple-import +from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize # pylint: disable=g-importing-member,g-multiple-import +from flax.linen.normalization import Array, Axes, Dtype, PRNGKey, Shape # pylint: disable=g-importing-member,g-multiple-import +from jax import lax +from jax.nn import initializers +import jax.numpy as jnp + + +class NoOp(Module): + """NoOp.""" + + @compact + def __call__(self, x): + return x + + +def norm_layer_from_str(input_string: str, train: Optional[bool] = None) -> Any: + """Gets normalization layer from string.""" + kwargs = {} + if not input_string or input_string == 'NoOp': + layer = NoOp + elif input_string.startswith('BatchNorm'): + layer = nn.BatchNorm + if train is not None: + kwargs['use_running_average'] = not train + if '(' in input_string: + kwargs['momentum'] = float(input_string.replace( + 'BatchNorm(', '').replace(')', '')) + elif input_string == 'InstanceNorm': + layer = nn.GroupNorm + kwargs['group_size'] = 1 + elif input_string == 'ReversibleInstanceNorm': + layer = ReversibleInstanceNorm + elif hasattr(nn, input_string): + layer = getattr(nn, input_string) + else: + raise ValueError('normalization layer not found as part of flax.linen.') + return functools.partial(layer, **kwargs) + + +def _denormalize( + mdl: Module, + x: Array, + mean: Array, + var: Array, + reduction_axes: Axes, + feature_axes: Axes, + dtype: Dtype, + param_dtype: Dtype, + epsilon: float, + use_bias: bool, + use_scale: bool, + bias_init: Callable[[PRNGKey, Shape, Dtype], Array], + scale_init: Callable[[PRNGKey, Shape, Dtype], Array], +): + """Denormalizes the input of a normalization layer with optional learned scale and bias. + + Arguments: + mdl: Module to apply the denormalization in (normalization params will + reside in this module). + x: The input. + mean: Mean to use for denormalization. + var: Variance to use for denormalization. + reduction_axes: The axes in ``x`` to reduce. + feature_axes: Axes containing features. A separate bias and scale is learned + for each specified feature. + dtype: The dtype of the result (default: infer from input and params). + param_dtype: The dtype of the parameters. + epsilon: Denormalization epsilon. + use_bias: If true, add a bias term to the output. + use_scale: If true, scale the output. + bias_init: Initialization function for the bias term. + scale_init: Initialization function for the scaling function. + + Returns: + The denormalized input. 
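+
+  In terms of the arguments above, the result is
+  ((x - bias) / scale) * sqrt(var + epsilon) + mean, i.e. the inverse of the
+  corresponding normalization.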
+ """ + reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) + feature_axes = _canonicalize_axes(x.ndim, feature_axes) + feature_shape = [1] * x.ndim + reduced_feature_shape = [] + for ax in feature_axes: + feature_shape[ax] = x.shape[ax] + reduced_feature_shape.append(x.shape[ax]) + + y = x + args = [x] + if use_bias: + bias = mdl.param( + 'bias', bias_init, reduced_feature_shape, param_dtype + ).reshape(feature_shape) + y -= bias + args.append(bias) + var = jnp.expand_dims(var, reduction_axes) + mul = lax.sqrt(var + epsilon) + if use_scale: + scale = mdl.param( + 'scale', scale_init, reduced_feature_shape, param_dtype + ).reshape(feature_shape) + mul /= scale + args.append(scale) + y *= mul + y += jnp.expand_dims(mean, reduction_axes) + dtype = canonicalize_dtype(*args, dtype=dtype) + return jnp.asarray(y, dtype) + + +class ReversibleInstanceNorm(Module): + """Reversible instance normalization (https://openreview.net/forum?id=cGDAkQo1C0p). + + Usage example: + rev_in = ReversibleInstanceNorm() + x, stats = rev_in(x) # x is normalized + # ... + y, _ = rev_in(x, stats) # x is denormalized using stats + return y + + Attributes: + epsilon: A small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: If True, bias (beta) is added. + use_scale: If True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. + bias_init: Initializer for bias, by default, zero. + scale_init: Initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + This is only needed if the model is subdivided across devices, i.e. the + array being normalized is sharded across devices within a pmap. + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over the + examples on the first two and last two devices. See `jax.lax.psum` for + more details. + use_fast_variance: If true, use a faster, but less numerically stable, + calculation for the variance. + """ + + epsilon: float = 1e-6 + dtype: Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + use_bias: bool = True + use_scale: bool = True + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros + scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + axis_name: Optional[str] = None + axis_index_groups: Any = None + use_fast_variance: bool = True + + @compact + def __call__(self, x, stats=None): + """Applies (reversible) instance normalization on the input. + + Args: + x: the inputs + stats: statistics, if passed, inputs are denormalized. + + Returns: + (De)normalized inputs (the same shape as inputs) and stats. 
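+
+    Statistics are computed per example and per feature channel over all
+    spatial axes, so for inputs of shape (batch, ..., features) the returned
+    stats have shape (batch, features).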
+ """ + reduction_axes = list(range(1, x.ndim - 1)) + [-1] + feature_axes = (-1,) + + if stats is None: + transform_fn = _normalize + mean, var = _compute_stats( + x.reshape(x.shape + (1,)), + reduction_axes, + self.dtype, + self.axis_name, + self.axis_index_groups, + use_fast_variance=self.use_fast_variance, + ) + stats = {'mean': mean, 'var': var} + else: + transform_fn = _denormalize + + return transform_fn( + self, + x, + stats['mean'], + stats['var'], + reduction_axes[:-1], + feature_axes, + self.dtype, + self.param_dtype, + self.epsilon, + self.use_bias, + self.use_scale, + self.bias_init, + self.scale_init, + ), stats diff --git a/connectomics/jax/models/normalization_test.py b/connectomics/jax/models/normalization_test.py new file mode 100644 index 0000000..179bc7c --- /dev/null +++ b/connectomics/jax/models/normalization_test.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for normalization.""" + +from absl.testing import absltest, parameterized # pylint: disable=g-multiple-import + +from connectomics.jax.models import normalization +from flax import linen as nn + +import jax +from jax import random + +import numpy as np + +# Parse absl flags test_srcdir and test_tmpdir. +jax.config.parse_flags_with_absl() + + +class NormalizationTest(parameterized.TestCase): + + def test_reversible_instance_norm(self): + e = 1e-5 + + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + norm = normalization.ReversibleInstanceNorm( + name='norm', + use_bias=False, + use_scale=False, + epsilon=e, + ) + x_norm, stats = norm(x) + y, _ = norm(x_norm, stats) + return y, x_norm, stats + + rng = random.PRNGKey(0) + key1, key2 = random.split(rng) + x = random.normal(key1, (2, 5, 4, 4, 32)) + (y, x_norm, stats), _ = Foo().init_with_output(key2, x) + + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.shape, y.shape) + np.testing.assert_allclose(y, x, atol=1e-6) + + self.assertEqual(x.dtype, x_norm.dtype) + self.assertEqual(x.shape, x_norm.shape) + x_gr = x.reshape([2, 5, 4, 4, 32, 1]) + x_norm_test = ( + x_gr - x_gr.mean(axis=[1, 2, 3, 5], keepdims=True) + ) * jax.lax.rsqrt(x_gr.var(axis=[1, 2, 3, 5], keepdims=True) + e) + x_norm_test = x_norm_test.reshape([2, 5, 4, 4, 32]) + np.testing.assert_allclose(x_norm_test, x_norm, atol=1e-4) + + self.assertEqual(stats['mean'].shape, (2, 32)) + self.assertEqual(stats['var'].shape, (2, 32)) + + +if __name__ == '__main__': + absltest.main()