Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented the loss NTK calculation #109

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
076d0df
starting out
Jan 12, 2024
930cac8
Renaming file
ma-sauter Jan 18, 2024
580194a
analysis file should be done
ma-sauter Jan 18, 2024
768fc13
quick fix
ma-sauter Jan 18, 2024
c802010
Included loss ntk, eigval and entropy in jax recorder
ma-sauter Jan 19, 2024
17fd83a
Updated recorder instantiation test
ma-sauter Jan 19, 2024
e88292f
Bugfixing
ma-sauter Jan 19, 2024
4eb8362
Stand vor Kino
Jan 19, 2024
fb1da2e
First thing state that might be working
ma-sauter Jan 22, 2024
d22bb68
writing test
Jan 22, 2024
e141dd8
Included vmap_axes
ma-sauter Jan 30, 2024
22cfc90
Working on calculating loss derivatives to calculate loss ntk comparison
ma-sauter Jan 30, 2024
5f9b6bf
Calculation and test should work now
ma-sauter Feb 4, 2024
131d5cb
some linting updates
ma-sauter Feb 4, 2024
5b244ff
Fixed tests
ma-sauter Feb 4, 2024
3a06e75
Some modifications to simplify the loss ntk test code
ma-sauter Feb 6, 2024
001e46e
Class renaming to follow convention
ma-sauter Feb 6, 2024
fd99b77
added reshape and unshape methods in the loss_ntk_calculation
ma-sauter Feb 6, 2024
0ef059f
quicksave
Feb 20, 2024
98c6b30
Quick save
ma-sauter Feb 20, 2024
cd5644a
Added test for eigenvalues, precision is still only e-4
ma-sauter Feb 20, 2024
b4fe246
Added some docstrings
ma-sauter Feb 20, 2024
bf7a401
More docstrings
ma-sauter Feb 20, 2024
43cdcb9
Black formatting
ma-sauter Feb 20, 2024
fd8626f
More docstrings
ma-sauter Feb 20, 2024
b6ebb8d
Some PR modifications
ma-sauter Feb 21, 2024
d17e514
fixing PR comment
Feb 26, 2024
eb49043
requirements change
Feb 26, 2024
53c548d
Black formatter changes
Feb 26, 2024
1dac434
changed recorder to use the use_loss_ntk flag
Feb 26, 2024
422eceb
Change recorder test for new flag
Feb 26, 2024
b5e8170
removed unneccesary CNN model from loss_ntk calculation test
Feb 26, 2024
594f4cb
Started integration test for loss_ntk_calculation
Feb 26, 2024
b55fc30
Implemented integration test
Feb 27, 2024
080099e
isort
Feb 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 259 additions & 0 deletions CI/unit_tests/analysis/test_loss_ntk_calculation.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this to be inline with the package.

Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this from the tests


import jax.numpy as np
import numpy as onp
import optax
from flax import linen as nn
from neural_tangents import stax
from numpy.testing import assert_array_almost_equal

from znnl.analysis import EigenSpaceAnalysis, LossDerivative, LossNTKCalculation
from znnl.data import MNISTGenerator
from znnl.distance_metrics import LPNorm
from znnl.loss_functions import LPNormLoss
from znnl.models import FlaxModel, NTModel


# Defines a simple CNN module
class ProductionModule(nn.Module):
"""
Simple CNN module.
"""

@nn.compact
def __call__(self, x):
x = nn.Conv(features=16, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
x = nn.Conv(features=16, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=10)(x)
x = nn.relu(x)
x = nn.Dense(10)(x)

return x


class TestLossNTKCalculation:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with the class names. Just TestLossNTK

"""
Test Suite for the LossNTKCalculation module.
"""

def test_reshaping_methods(self):
"""
Test the _reshape_dataset and _unshape_dataset methods.
These are functions used in the loss NTK calculation to
"""
# Define a dummy model and dataset to be able to define a
# LossNTKCalculation class
production_model = FlaxModel(
flax_module=ProductionModule(),
optimizer=optax.adam(learning_rate=0.01),
input_shape=(1, 28, 28, 1),
trace_axes=(),
)

data_generator = MNISTGenerator(ds_size=20)
data_set = {
"inputs": data_generator.train_ds["inputs"],
"targets": data_generator.train_ds["targets"],
}

# Initialize the loss NTK calculation
loss_ntk_calculator = LossNTKCalculation(
metric_fn=LPNorm(order=2),
model=production_model,
dataset=data_set,
)

# Setup a test dataset for reshaping
KonstiNik marked this conversation as resolved.
Show resolved Hide resolved
test_data_set = {
"inputs": np.array([[1, 2, 3], [4, 5, 6]]),
"targets": np.array([[7], [10]]),
}

# Test the reshaping
reshaped_test_data_set = loss_ntk_calculator._reshape_dataset(test_data_set)

assert_array_almost_equal(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these should probably be split into differen tests rather than having one large one. Something like:

def test_reshaping
...

def test_unshaping
...

reshaped_test_data_set, np.array([[1, 2, 3, 7], [4, 5, 6, 10]])
)

# Test the unshaping
input_0, target_0 = loss_ntk_calculator._unshape_data(
reshaped_test_data_set,
input_dimension=3,
input_shape=(2, 3),
target_shape=(2, 1),
batch_length=reshaped_test_data_set.shape[0],
)
assert_array_almost_equal(input_0, test_data_set["inputs"])
assert_array_almost_equal(target_0, test_data_set["targets"])

def test_function_for_loss_ntk(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_loss_computation?

"""
This method tests the function that is used for the correlation matrix
KonstiNik marked this conversation as resolved.
Show resolved Hide resolved
in the loss NTK calculation. It is supposed to yield the loss per single
datapoint.
"""
# Define a simple feed forward test model
feed_forward_model = stax.serial(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have to repeadt this code, it would be better to use a setup method at the start.

stax.Dense(5),
stax.Relu(),
stax.Dense(2),
stax.Relu(),
)

# Initialize the model
model = NTModel(
optimizer=optax.adam(learning_rate=0.01),
input_shape=(1, 5),
trace_axes=(),
nt_module=feed_forward_model,
)

# Define a test dataset with only two datapoints
test_data_set = {
"inputs": np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 8]]),
"targets": np.array([[1, 3], [2, 5]]),
}

# Initialize loss
loss = LPNormLoss(order=2)
# Initialize the loss NTK calculation
loss_ntk_calculator = LossNTKCalculation(
metric_fn=loss.metric,
model=model,
dataset=test_data_set,
)

# Calculate the subloss from the NTK first
datapoint = loss_ntk_calculator._reshape_dataset(test_data_set)[0:1]
subloss_from_NTK = loss_ntk_calculator._function_for_loss_ntk(
{
"params": model.model_state.params,
"batch_stats": model.model_state.batch_stats,
},
datapoint=datapoint,
)

# Now calculate subloss manually
applied_model = model.apply(
{
"params": model.model_state.params,
"batch_stats": model.model_state.batch_stats,
},
test_data_set["inputs"][0],
)
subloss = np.linalg.norm(applied_model - test_data_set["targets"][0], ord=2)

# Check that the two losses are the same
assert subloss - subloss_from_NTK < 1e-5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use pytest.approx for this kind of thing.


def test_loss_NTK_calculation(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to have a real hard-coded or analytical result here. You can seed the network initialisation and force the same output. It would be even better to have an analytically validated result over an ensemble or something. Here you rely on your implementation via einsum to check what is probably your implementation via einsum just wrapped up in a different function name. This won't validate your implementation and it won't identify if something changes in that library that yields a false result.

"""
Test the Loss NTK calculation.
KonstiNik marked this conversation as resolved.
Show resolved Hide resolved
Here we test if the Loss NTK calculated through the neural tangents module is
the same as the Loss NTK calculated with the already implemented NTK and loss
derivatives.
We do this for a small CNN model and the MNIST dataset.
We also check if the eigenvalues of the two Loss NTKs are the same.

The current implementation yields a precision of e-4. If these are numerical
errors or due to a mistake in the implementation is to be decided.
"""

# Define a test model
production_model = FlaxModel(
flax_module=ProductionModule(),
optimizer=optax.adam(learning_rate=0.01),
input_shape=(1, 28, 28, 1),
trace_axes=(),
)
# Initialize model parameters

data_generator = MNISTGenerator(ds_size=20)
data_set = {
"inputs": data_generator.train_ds["inputs"],
"targets": data_generator.train_ds["targets"],
}

# Initialize the loss NTK calculation
loss_ntk_calculator = LossNTKCalculation(
metric_fn=LPNorm(order=2),
model=production_model,
dataset=data_set,
)

# Compute the loss NTK
loss_ntk = loss_ntk_calculator.compute_loss_ntk(
x_i=data_set, model=production_model
)["empirical"]

# Now for comparison calculate regular ntk
ntk = production_model.compute_ntk(data_set["inputs"], infinite=False)[
"empirical"
]
# Calculate Loss derivative fn
loss_derivative_calculator = LossDerivative(LPNormLoss(order=2))
KonstiNik marked this conversation as resolved.
Show resolved Hide resolved

# predictions calculation analogous to the one in jax recording
predictions = production_model(data_set["inputs"])
if type(predictions) is tuple:
predictions = predictions[0]

# calculation of loss derivatives
# note: here we need the derivatives of the subloss, not the regular loss fn
loss_derivatives = onp.empty(shape=(len(predictions), len(predictions[0])))
for i in range(len(loss_derivatives)):
# The weird indexing here is because of axis constraints in LPNormLoss
loss_derivatives[i] = loss_derivative_calculator.calculate(
predictions[i : i + 1], data_set["targets"][i : i + 1]
)[0]

# Calculate the loss NTK from the loss derivatives and the ntk
loss_ntk_2 = np.einsum(
KonstiNik marked this conversation as resolved.
Show resolved Hide resolved
"ik, jl, ijkl-> ij", loss_derivatives, loss_derivatives, ntk
)

# Assert that the loss NTKs are the same
assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=4)

calculator1 = EigenSpaceAnalysis(matrix=loss_ntk)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do the spectrum analysis if you have element by element compared the matrix?

calculator2 = EigenSpaceAnalysis(matrix=loss_ntk_2)

eigenvalues1 = calculator1.compute_eigenvalues(normalize=False)
eigenvalue2 = calculator2.compute_eigenvalues(normalize=False)

assert_array_almost_equal(eigenvalues1, eigenvalue2, decimal=4)
3 changes: 3 additions & 0 deletions CI/unit_tests/training_recording/test_training_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def test_instantiation(self):
eigenvalues=True,
trace=True,
loss_derivative=True,
loss_ntk=True,
loss_ntk_eigenvalues=True,
loss_ntk_entropy=True,
)
recorder.instantiate_recorder(data_set=self.dummy_data_set)
_exclude_list = [
Expand Down
2 changes: 2 additions & 0 deletions znnl/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
from znnl.analysis.eigensystem import EigenSpaceAnalysis
from znnl.analysis.entropy import EntropyAnalysis
from znnl.analysis.loss_fn_derivative import LossDerivative
from znnl.analysis.loss_ntk_calculation import LossNTKCalculation

__all__ = [
EntropyAnalysis.__name__,
EigenSpaceAnalysis.__name__,
LossDerivative.__name__,
LossNTKCalculation.__name__,
]
Loading
Loading