Update the trace optimizer for easier use. #98

Open · wants to merge 4 commits into main
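This patch set bumps the black, isort, and flake8 pre-commit hooks; exposes the optimizers subpackage from the top-level znnl namespace; stores the MNIST one-hot targets as host NumPy arrays; adds a PartitionedTraceOptimizer that sums the empirical NTK trace over per-class partitions of the data; extends TraceOptimizer with subset and memory options; and lets JaxRecorder evaluate the recorded NTK on a random subset of the inputs.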
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -4,17 +4,17 @@ fail_fast: true

 repos:
   - repo: https://github.com/psf/black
-    rev: 22.8.0
+    rev: 23.3.0
    hooks:
       - id: black

   - repo: https://github.com/timothycrosley/isort
-    rev: 5.10.1
+    rev: 5.12.0
    hooks:
       - id: isort

   - repo: https://github.com/pycqa/flake8
-    rev: 5.0.4
+    rev: 6.0.0
    hooks:
       - id: flake8
         additional_dependencies: [flake8-isort]
2 changes: 2 additions & 0 deletions znnl/__init__.py
@@ -35,6 +35,7 @@
     distance_metrics,
     loss_functions,
     models,
+    optimizers,
     point_selection,
     similarity_measures,
     training_recording,
@@ -51,6 +52,7 @@
     distance_metrics.__name__,
     loss_functions.__name__,
     accuracy_functions.__name__,
+    optimizers.__name__,
     models.__name__,
     point_selection.__name__,
     similarity_measures.__name__,
9 changes: 5 additions & 4 deletions znnl/data/mnist.py
@@ -26,6 +26,7 @@
 """
 import jax.nn as nn
 import jax.numpy as np
+import numpy as onp
 import plotly.graph_objects as go
 import tensorflow_datasets as tfds
 from plotly.subplots import make_subplots
@@ -66,11 +67,11 @@ def __init__(self, ds_size: int = 500, one_hot_encoding: bool = True):
         self.test_ds.pop("label")
         self.data_pool = self.train_ds["inputs"].astype(float)
         if one_hot_encoding:
-            self.train_ds["targets"] = nn.one_hot(
-                self.train_ds["targets"], num_classes=10
+            self.train_ds["targets"] = onp.array(
+                nn.one_hot(self.train_ds["targets"], num_classes=10)
             )
-            self.test_ds["targets"] = nn.one_hot(
-                self.test_ds["targets"], num_classes=10
+            self.test_ds["targets"] = onp.array(
+                nn.one_hot(self.test_ds["targets"], num_classes=10)
             )

     def plot_image(self, indices: list = None, data_list: list = None):
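The change above wraps the jax.nn.one_hot output in onp.array, so the dataset stores its targets as host-side NumPy arrays rather than device arrays. A minimal standalone sketch of the same pattern (the label values are illustrative):

import jax.nn as nn
import numpy as onp

labels = onp.array([0, 2, 1])
# nn.one_hot returns a jax.numpy array on the device; onp.array copies it
# to a host NumPy array, matching how the dataset now stores its targets.
targets = onp.array(nn.one_hot(labels, num_classes=3))
print(targets.dtype, targets.shape)  # float32 (3, 3)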
13 changes: 7 additions & 6 deletions znnl/models/jax_model.py
@@ -33,6 +33,7 @@
 import optax
 from flax.training.train_state import TrainState

+from znnl.optimizers.partitioned_trace_optimizer import PartitionedTraceOptimizer
 from znnl.optimizers.trace_optimizer import TraceOptimizer
 from znnl.utils.prng import PRNGKey

@@ -80,11 +81,11 @@ def __init__(
         self.init_model(seed)

         # Prepare NTK calculation
-        self.empirical_ntk = nt.batch(
-            nt.empirical_ntk_fn(f=self._ntk_apply_fn, trace_axes=trace_axes),
-            batch_size=ntk_batch_size,
-        )
-        self.empirical_ntk_jit = jax.jit(self.empirical_ntk)
+        self.empirical_ntk = nt.batch(nt.empirical_ntk_fn(
+            f=self._ntk_apply_fn, trace_axes=trace_axes
+        ), batch_size=ntk_batch_size)
+
+        self.empirical_ntk_jit = self.empirical_ntk

     def init_model(
         self,
@@ -122,7 +123,7 @@ def _create_train_state(
         params = self._init_params(kernel_init, bias_init)

         # Set dummy optimizer for case of trace optimizer.
-        if isinstance(self.optimizer, TraceOptimizer):
+        if isinstance(self.optimizer, (TraceOptimizer, PartitionedTraceOptimizer)):
             optimizer = optax.sgd(1.0)
         else:
             optimizer = self.optimizer
3 changes: 2 additions & 1 deletion znnl/optimizers/__init__.py
@@ -24,6 +24,7 @@
 Summary
 -------
 """
+from znnl.optimizers.partitioned_trace_optimizer import PartitionedTraceOptimizer
 from znnl.optimizers.trace_optimizer import TraceOptimizer

-__all__ = [TraceOptimizer.__name__]
+__all__ = [TraceOptimizer.__name__, PartitionedTraceOptimizer.__name__]
159 changes: 159 additions & 0 deletions znnl/optimizers/partitioned_trace_optimizer.py
@@ -0,0 +1,159 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""
from dataclasses import dataclass
from typing import Callable

import jax.numpy as np
import numpy as onp
import optax
from flax.training.train_state import TrainState


@dataclass
class PartitionedTraceOptimizer:
    """
    Class implementation of the partitioned trace optimizer.

    Attributes
    ----------
    scale_factor : float
        Scale factor to apply to the optimizer.
    rescale_interval : int
        Number of epochs to wait before re-scaling the learning rate.
    subset : float
        Fraction of each class partition to use in the trace calculation.
    """

    scale_factor: float
    rescale_interval: int = 1
    subset: float = None

    _start_value = None

    @optax.inject_hyperparams
    def optimizer(self, learning_rate):
        return optax.sgd(learning_rate)

    def apply_optimizer(
        self,
        model_state: TrainState,
        data_set: np.ndarray,
        ntk_fn: Callable,
        epoch: int,
    ):
        """
        Apply the optimizer to a model state.

        Parameters
        ----------
        model_state : TrainState
            Current state of the model.
        data_set : np.ndarray
            Data-set to use in the computation.
        ntk_fn : Callable
            Function to use for the NTK computation.
        epoch : int
            Current epoch.

        Returns
        -------
        new_state : TrainState
            New state of the model.
        """
        eps = 1e-8

        # Partition the inputs by class, using the argmax of the one-hot
        # targets as the class index.
        partitions = {}
        unique_classes = np.unique(data_set["targets"], axis=0)

        for i in range(unique_classes.shape[0]):
            indices = np.where(data_set["targets"].argmax(-1) == i)[0]
            partitions[i] = np.take(data_set["inputs"], indices, axis=0)

        # On the first call, record the reference trace summed over all
        # class partitions.
        if self._start_value is None:
            if self.subset is not None:
                init_data_set = {}
                for ds in partitions:
                    subset_size = int(self.subset * partitions[ds].shape[0])
                    init_data_set[ds] = np.take(
                        partitions[ds],
                        onp.random.randint(
                            0, partitions[ds].shape[0] - 1, size=subset_size
                        ),
                        axis=0,
                    )
            else:
                init_data_set = partitions

            start_trace = 0.0
            for ds in init_data_set:
                ntk = ntk_fn(init_data_set[ds])["empirical"]
                start_trace += np.trace(ntk)

            self._start_value = start_trace

        # Check if the update should be performed.
        if epoch % self.rescale_interval == 0:
            # Select a subset of each class partition.
            if self.subset is not None:
                data_set = {}
                for ds in partitions:
                    subset_size = int(self.subset * partitions[ds].shape[0])
                    data_set[ds] = np.take(
                        partitions[ds],
                        onp.random.randint(
                            0, partitions[ds].shape[0] - 1, size=subset_size
                        ),
                        axis=0,
                    )
            else:
                data_set = partitions

            # Compute the ntk trace summed over the partitions.
            trace = 0.0
            for ds in data_set:
                ntk = ntk_fn(data_set[ds])["empirical"]
                trace += np.trace(ntk)

            # Create the new optimizer.
            new_optimizer = self.optimizer(
                (self.scale_factor * self._start_value) / (trace + eps)
            )

            # Create the new state
            new_state = TrainState.create(
                apply_fn=model_state.apply_fn,
                params=model_state.params,
                tx=new_optimizer,
            )
        else:
            # If no update is needed, return the old state.
            new_state = model_state

        return new_state
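A minimal construction sketch of the new optimizer, assuming only what the diff above defines; the hyperparameter values are illustrative:

from znnl.optimizers import PartitionedTraceOptimizer

# Every 10 epochs, rescale the SGD learning rate by
# scale_factor * Tr(K_0) / (Tr(K_t) + eps), where both traces are summed
# over per-class partitions and estimated on a random 10% of each class.
optimizer = PartitionedTraceOptimizer(
    scale_factor=0.05,
    rescale_interval=10,
    subset=0.1,
)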
47 changes: 45 additions & 2 deletions znnl/optimizers/trace_optimizer.py
@@ -25,9 +25,10 @@
 -------
 """
 from dataclasses import dataclass
-from typing import Callable
+from typing import Callable, Union

 import jax.numpy as np
+import numpy as onp
 import optax
 from flax.training.train_state import TrainState

@@ -43,10 +44,16 @@ class TraceOptimizer:
         Scale factor to apply to the optimizer.
     rescale_interval : int
         Number of epochs to wait before re-scaling the learning rate.
+    subset : Union[float, list]
+        Fraction of the data, or an explicit list of indices, to use in the
+        trace calculation.
+    memory : int
+        Number of rescaling steps to look back for the reference trace.
     """

     scale_factor: float
-    rescale_interval: float = 1
+    rescale_interval: int = 1
+    subset: Union[float, list] = None
+    memory: int = 1
+
+    _start_value = []

     @optax.inject_hyperparams
     def optimizer(self, learning_rate):
@@ -78,15 +85,51 @@ def apply_optimizer(
         new_state : TrainState
             New state of the model
         """
+        data_set = data_set["inputs"]
         eps = 1e-8

+        # On the first call, record the reference trace.
+        if self._start_value == []:
+            if self.subset is not None:
+                if isinstance(self.subset, float):
+                    subset_size = int(self.subset * data_set.shape[0])
+                    init_data_set = np.take(
+                        data_set,
+                        onp.random.randint(0, data_set.shape[0] - 1, size=subset_size),
+                        axis=0,
+                    )
+                else:
+                    init_data_set = np.take(data_set, self.subset, axis=0)
+            else:
+                init_data_set = data_set
+            ntk = ntk_fn(init_data_set)["empirical"]
+            self._start_value.append(np.trace(ntk))
+
         # Check if the update should be performed.
         if epoch % self.rescale_interval == 0:
+            # Select a subset of the data
+            if self.subset is not None:
+                if isinstance(self.subset, float):
+                    subset_size = int(self.subset * data_set.shape[0])
+                    data_set = np.take(
+                        data_set,
+                        onp.random.randint(0, data_set.shape[0] - 1, size=subset_size),
+                        axis=0,
+                    )
+                else:
+                    data_set = np.take(data_set, self.subset, axis=0)
+
             # Compute the ntk trace.
             ntk = ntk_fn(data_set)["empirical"]
             trace = np.trace(ntk)

+            # Look back `memory` recorded traces for the reference value,
+            # clipped to the range of values stored so far.
+            memory_index = int(
+                np.clip(len(self._start_value) - self.memory, 0, len(self._start_value) - 1)
+            )
+
             # Create the new optimizer.
-            new_optimizer = self.optimizer(self.scale_factor / (trace + eps))
+            new_optimizer = self.optimizer(
+                (self.scale_factor * self._start_value[memory_index]) / (trace + eps)
+            )
+            self._start_value.append(trace)

             # Create the new state
             new_state = TrainState.create(
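On each rescale, the optimizer now sets the learning rate to scale_factor * Tr(K_ref) / (Tr(K_t) + eps), where Tr(K_t) is the current empirical NTK trace and Tr(K_ref) is a trace recorded `memory` rescaling steps earlier. A minimal construction sketch under the same assumptions (values illustrative):

from znnl.optimizers import TraceOptimizer

optimizer = TraceOptimizer(
    scale_factor=0.1,    # overall learning-rate scale
    rescale_interval=5,  # recompute the NTK trace every 5 epochs
    subset=0.1,          # estimate the trace on a random 10% of the inputs
    memory=1,            # reference trace from one rescaling step back
)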
17 changes: 14 additions & 3 deletions znnl/training_recording/jax_recording.py
@@ -103,6 +103,7 @@ class JaxRecorder:
     name: str = "my_recorder"
     storage_path: str = "./"
     chunk_size: int = 100
+    subset: int = None

     # Model Loss
     loss: bool = True
@@ -294,9 +295,19 @@ def update_recorder(self, epoch: int, model: JaxModel):
             # Compute ntk here to avoid repeated computation.
             if self._compute_ntk:
                 try:
-                    ntk = self._model.compute_ntk(
-                        self._data_set["inputs"], infinite=False
-                    )
+                    if self.subset is not None:
+                        indices = onp.random.randint(
+                            0,
+                            self._data_set["inputs"].shape[0] - 1,
+                            size=self.subset,
+                        )
+
+                        my_ds = onp.take(self._data_set["inputs"], indices, axis=0)
+                        ntk = self._model.compute_ntk(my_ds, infinite=False)
+                    else:
+                        ntk = self._model.compute_ntk(
+                            self._data_set["inputs"], infinite=False
+                        )
                     parsed_data["ntk"] = ntk["empirical"]
                 except NotImplementedError:
                     logger.info(
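A hedged configuration sketch: `subset` is the new attribute added above; the import path follows the file location, and the `ntk=True` flag is assumed to be one of JaxRecorder's existing recording switches, by analogy with the `loss` flag visible in the diff:

from znnl.training_recording import JaxRecorder

recorder = JaxRecorder(
    name="train_recorder",
    ntk=True,    # assumed existing flag, analogous to `loss` above
    subset=100,  # evaluate the recorded NTK on 100 randomly drawn inputs
)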