Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented the loss NTK calculation #109

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
076d0df
starting out
Jan 12, 2024
930cac8
Renaming file
ma-sauter Jan 18, 2024
580194a
analysis file should be done
ma-sauter Jan 18, 2024
768fc13
quick fix
ma-sauter Jan 18, 2024
c802010
Included loss ntk, eigval and entropy in jax recorder
ma-sauter Jan 19, 2024
17fd83a
Updated recorder instantiation test
ma-sauter Jan 19, 2024
e88292f
Bugfixing
ma-sauter Jan 19, 2024
4eb8362
Stand vor Kino
Jan 19, 2024
fb1da2e
First thing state that might be working
ma-sauter Jan 22, 2024
d22bb68
writing test
Jan 22, 2024
e141dd8
Included vmap_axes
ma-sauter Jan 30, 2024
22cfc90
Working on calculating loss derivatives to calculate loss ntk comparison
ma-sauter Jan 30, 2024
5f9b6bf
Calculation and test should work now
ma-sauter Feb 4, 2024
131d5cb
some linting updates
ma-sauter Feb 4, 2024
5b244ff
Fixed tests
ma-sauter Feb 4, 2024
3a06e75
Some modifications to simplify the loss ntk test code
ma-sauter Feb 6, 2024
001e46e
Class renaming to follow convention
ma-sauter Feb 6, 2024
fd99b77
added reshape and unshape methods in the loss_ntk_calculation
ma-sauter Feb 6, 2024
0ef059f
quicksave
Feb 20, 2024
98c6b30
Quick save
ma-sauter Feb 20, 2024
cd5644a
Added test for eigenvalues, precision is still only e-4
ma-sauter Feb 20, 2024
b4fe246
Added some docstrings
ma-sauter Feb 20, 2024
bf7a401
More docstrings
ma-sauter Feb 20, 2024
43cdcb9
Black formatting
ma-sauter Feb 20, 2024
fd8626f
More docstrings
ma-sauter Feb 20, 2024
b6ebb8d
Some PR modifications
ma-sauter Feb 21, 2024
d17e514
fixing PR comment
Feb 26, 2024
eb49043
requirements change
Feb 26, 2024
53c548d
Black formatter changes
Feb 26, 2024
1dac434
changed recorder to use the use_loss_ntk flag
Feb 26, 2024
422eceb
Change recorder test for new flag
Feb 26, 2024
b5e8170
removed unneccesary CNN model from loss_ntk calculation test
Feb 26, 2024
594f4cb
Started integration test for loss_ntk_calculation
Feb 26, 2024
b55fc30
Implemented integration test
Feb 27, 2024
080099e
isort
Feb 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you just name this test_loss_ntk. The naming of the tests should mirror the main python package just with test in front. All integration tests using the loss ntk should be in this one module.

Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove from the test please.


import numpy as np
import optax
from neural_tangents import stax
from numpy.testing import assert_array_almost_equal

from znnl.loss_functions import LPNormLoss
from znnl.models import NTModel
from znnl.training_recording import JaxRecorder
from znnl.training_strategies import SimpleTraining


class TestLossNTKRecorderDeployment:
"""
Test suite for the loss and NTK recorder.
"""

@classmethod
def setup_class(cls):
"""
Create a model and data for the tests.
"""

network = stax.serial(
stax.Dense(10), stax.Relu(), stax.Dense(10), stax.Relu(), stax.Dense(1)
)
cls.model = NTModel(
nt_module=network, input_shape=(5,), optimizer=optax.adam(1e-3)
)

cls.data_set = {
"inputs": np.random.rand(10, 5),
"targets": np.random.randint(0, 2, (10, 1)),
}

cls.ntk_recorder = JaxRecorder(
name="ntk_recorder",
ntk=True,
update_rate=1,
)
cls.loss_ntk_recorder = JaxRecorder(
name="loss_ntk_recorder",
ntk=True,
use_loss_ntk=True,
update_rate=1,
)

cls.ntk_recorder.instantiate_recorder(data_set=cls.data_set)
cls.loss_ntk_recorder.instantiate_recorder(data_set=cls.data_set)

cls.trainer = SimpleTraining(
model=cls.model,
loss_fn=LPNormLoss(order=2),
recorders=[cls.ntk_recorder, cls.loss_ntk_recorder],
)

def test_loss_ntk_deployment(self):
"""
Test the deployment of the loss_NTK recorder.
"""

# train the model
training_metrics = self.trainer.train_model(
train_ds=self.data_set,
test_ds=self.data_set,
epochs=10,
batch_size=2,
)

# gather the recording
ntk_recording = self.ntk_recorder.gather_recording()
loss_ntk_recording = self.loss_ntk_recorder.gather_recording()

# For LPNormLoss of order 2 and a 1D output Network, the NTK and the loss NTK
# should be the same up to a factor of +1 or -1.
assert_array_almost_equal(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is an integration test, you will also want to check that the deployment has worked. You can check things like the shape of the stored values.

np.abs(ntk_recording.ntk), np.abs(loss_ntk_recording.ntk), decimal=5
)
240 changes: 240 additions & 0 deletions CI/unit_tests/analysis/test_loss_ntk_calculation.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this to be inline with the package.

Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this from the tests


import jax.numpy as np
import numpy as onp
import optax
from neural_tangents import stax
from numpy.testing import assert_array_almost_equal

from znnl.analysis import EigenSpaceAnalysis, LossNTKCalculation
from znnl.data import MNISTGenerator
from znnl.distance_metrics import LPNorm
from znnl.loss_functions import LPNormLoss
from znnl.models import FlaxModel, NTModel


class TestLossNTKCalculation:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with the class names. Just TestLossNTK

"""
Test Suite for the LossNTKCalculation module.
"""

def test_reshaping_methods(self):
"""
Test the _reshape_dataset and _unshape_dataset methods.
These are functions used in the loss NTK calculation to
"""
# Define a dummy model and dataset to be able to define a
# LossNTKCalculation class
feed_forward_model = stax.serial(
stax.Dense(5),
stax.Relu(),
stax.Dense(2),
stax.Relu(),
)

# Initialize the model
test_model = NTModel(
optimizer=optax.adam(learning_rate=0.01),
input_shape=(1, 5),
trace_axes=(),
nt_module=feed_forward_model,
)

data_generator = MNISTGenerator(ds_size=20)
data_set = {
"inputs": data_generator.train_ds["inputs"],
"targets": data_generator.train_ds["targets"],
}

# Initialize the loss NTK calculation
loss_ntk_calculator = LossNTKCalculation(
metric_fn=LPNorm(order=2),
model=test_model,
dataset=data_set,
)

# Setup a test dataset for reshaping
KonstiNik marked this conversation as resolved.
Show resolved Hide resolved
test_data_set = {
"inputs": np.array([[1, 2, 3], [4, 5, 6]]),
"targets": np.array([[7, 9], [10, 12]]),
}

# Test the reshaping
reshaped_test_data_set = loss_ntk_calculator._reshape_dataset(test_data_set)

assert_array_almost_equal(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these should probably be split into differen tests rather than having one large one. Something like:

def test_reshaping
...

def test_unshaping
...

reshaped_test_data_set, np.array([[1, 2, 3, 7, 9], [4, 5, 6, 10, 12]])
)

# Test the unshaping
input_0, target_0 = loss_ntk_calculator._unshape_data(
reshaped_test_data_set,
input_dimension=3,
input_shape=(2, 3),
target_shape=(2, 2),
batch_length=reshaped_test_data_set.shape[0],
)
assert_array_almost_equal(input_0, test_data_set["inputs"])
assert_array_almost_equal(target_0, test_data_set["targets"])

def test_function_for_loss_ntk(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_loss_computation?

"""
This method tests the function that calculates the loss for single
datapoints.
"""
# Define a simple feed forward test model
feed_forward_model = stax.serial(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have to repeadt this code, it would be better to use a setup method at the start.

stax.Dense(5),
stax.Relu(),
stax.Dense(2),
stax.Relu(),
)

# Initialize the model
model = NTModel(
optimizer=optax.adam(learning_rate=0.01),
input_shape=(1, 5),
trace_axes=(),
nt_module=feed_forward_model,
)

# Define a test dataset with only two datapoints
test_data_set = {
"inputs": np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 8]]),
"targets": np.array([[1, 3], [2, 5]]),
}

# Initialize loss
loss = LPNormLoss(order=2)
# Initialize the loss NTK calculation
loss_ntk_calculator = LossNTKCalculation(
metric_fn=loss.metric,
model=model,
dataset=test_data_set,
)

# Calculate the subloss from the NTK first
datapoint = loss_ntk_calculator._reshape_dataset(test_data_set)[0:1]
subloss_from_NTK = loss_ntk_calculator._function_for_loss_ntk(
{
"params": model.model_state.params,
"batch_stats": model.model_state.batch_stats,
},
datapoint=datapoint,
)

# Now calculate subloss manually
applied_model = model.apply(
{
"params": model.model_state.params,
"batch_stats": model.model_state.batch_stats,
},
test_data_set["inputs"][0],
)
subloss = np.linalg.norm(applied_model - test_data_set["targets"][0], ord=2)

# Check that the two losses are the same
assert subloss - subloss_from_NTK < 1e-5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use pytest.approx for this kind of thing.


def test_loss_NTK_calculation(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to have a real hard-coded or analytical result here. You can seed the network initialisation and force the same output. It would be even better to have an analytically validated result over an ensemble or something. Here you rely on your implementation via einsum to check what is probably your implementation via einsum just wrapped up in a different function name. This won't validate your implementation and it won't identify if something changes in that library that yields a false result.

"""
Test the Loss NTK calculation with a manually hardcoded dataset
and a simple feed forward network.
"""
# Define a simple feed forward test model
feed_forward_model = stax.serial(
stax.Dense(5),
stax.Relu(),
stax.Dense(1),
stax.Relu(),
)

# Initialize the model
model = NTModel(
optimizer=optax.adam(learning_rate=0.01),
input_shape=(2,),
trace_axes=(),
nt_module=feed_forward_model,
)

# Create a test dataset
inputs = np.array(onp.random.rand(10, 2))
targets = np.array(100 * onp.random.rand(10, 1))
test_data_set = {"inputs": inputs, "targets": targets}

# Initialize loss
loss = LPNormLoss(order=2)
# Initialize the loss NTK calculation
loss_ntk_calculator = LossNTKCalculation(
metric_fn=loss.metric,
model=model,
dataset=test_data_set,
)

# Compute the loss NTK
loss_ntk = loss_ntk_calculator.compute_loss_ntk(x_i=test_data_set, model=model)[
"empirical"
]

# Now for comparison calculate regular ntk
ntk = model.compute_ntk(test_data_set["inputs"], infinite=False)["empirical"]

# predictions calculation analogous to the one in jax recording
predictions = model(test_data_set["inputs"])
if type(predictions) is tuple:
predictions = predictions[0]

# calculation of loss derivatives
# note: here we need the derivatives of the subloss, not the regular loss fn
loss_derivatives = onp.empty(shape=(len(predictions)))
for i in range(len(loss_derivatives)):
loss_derivatives[i] = (
predictions[i, 0] / np.abs(predictions[i, 0])
if predictions[i, 0] != 0
else 0
)

# Calculate the loss NTK from the loss derivatives and the ntk
loss_ntk_2 = np.einsum(
KonstiNik marked this conversation as resolved.
Show resolved Hide resolved
"i, j, ijkl-> ij", loss_derivatives, loss_derivatives, ntk
)

# Assert that the loss NTKs are the same
assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=6)

calculator1 = EigenSpaceAnalysis(matrix=loss_ntk)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do the spectrum analysis if you have element by element compared the matrix?

calculator2 = EigenSpaceAnalysis(matrix=loss_ntk_2)

eigenvalues1 = calculator1.compute_eigenvalues(normalize=False)
eigenvalue2 = calculator2.compute_eigenvalues(normalize=False)

assert_array_almost_equal(eigenvalues1, eigenvalue2, decimal=6)
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_instantiation(self):
"name",
"storage_path",
"chunk_size",
"use_loss_ntk",
]
for key, val in vars(recorder).items():
if key[0] != "_" and key not in _exclude_list:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ plotly
flax
tqdm
pandas
neural-tangents==0.6.4
neural-tangents
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a >= rather than removing this.

tensorflow-datasets
isort
tensorflow
Expand Down
2 changes: 2 additions & 0 deletions znnl/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
from znnl.analysis.eigensystem import EigenSpaceAnalysis
from znnl.analysis.entropy import EntropyAnalysis
from znnl.analysis.loss_fn_derivative import LossDerivative
from znnl.analysis.loss_ntk_calculation import LossNTKCalculation

__all__ = [
EntropyAnalysis.__name__,
EigenSpaceAnalysis.__name__,
LossDerivative.__name__,
LossNTKCalculation.__name__,
]
Loading
Loading