Implemented the loss NTK calculation #109
base: main
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
""" | ||
ZnNL: A Zincwarecode package. | ||
|
||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
|
||
SPDX-License-Identifier: EPL-2.0 | ||
|
||
Copyright Contributors to the Zincwarecode Project. | ||
|
||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
|
||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
|
||
Summary | ||
------- | ||
""" | ||
|
||
import os | ||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" | ||
Review comment: Remove from the test please. |
||
|
||
import numpy as np | ||
import optax | ||
from neural_tangents import stax | ||
from numpy.testing import assert_array_almost_equal | ||
|
||
from znnl.loss_functions import LPNormLoss | ||
from znnl.models import NTModel | ||
from znnl.training_recording import JaxRecorder | ||
from znnl.training_strategies import SimpleTraining | ||
|
||
|
||
class TestLossNTKRecorderDeployment: | ||
""" | ||
Test suite for the loss and NTK recorder. | ||
""" | ||
|
||
@classmethod | ||
def setup_class(cls): | ||
""" | ||
Create a model and data for the tests. | ||
""" | ||
|
||
network = stax.serial( | ||
stax.Dense(10), stax.Relu(), stax.Dense(10), stax.Relu(), stax.Dense(1) | ||
) | ||
cls.model = NTModel( | ||
nt_module=network, input_shape=(5,), optimizer=optax.adam(1e-3) | ||
) | ||
|
||
cls.data_set = { | ||
"inputs": np.random.rand(10, 5), | ||
"targets": np.random.randint(0, 2, (10, 1)), | ||
} | ||
|
||
cls.ntk_recorder = JaxRecorder( | ||
name="ntk_recorder", | ||
ntk=True, | ||
update_rate=1, | ||
) | ||
cls.loss_ntk_recorder = JaxRecorder( | ||
name="loss_ntk_recorder", | ||
ntk=True, | ||
use_loss_ntk=True, | ||
update_rate=1, | ||
) | ||
|
||
cls.ntk_recorder.instantiate_recorder(data_set=cls.data_set) | ||
cls.loss_ntk_recorder.instantiate_recorder(data_set=cls.data_set) | ||
|
||
cls.trainer = SimpleTraining( | ||
model=cls.model, | ||
loss_fn=LPNormLoss(order=2), | ||
recorders=[cls.ntk_recorder, cls.loss_ntk_recorder], | ||
) | ||
|
||
def test_loss_ntk_deployment(self): | ||
""" | ||
Test the deployment of the loss_NTK recorder. | ||
""" | ||
|
||
# train the model | ||
training_metrics = self.trainer.train_model( | ||
train_ds=self.data_set, | ||
test_ds=self.data_set, | ||
epochs=10, | ||
batch_size=2, | ||
) | ||
|
||
# gather the recording | ||
ntk_recording = self.ntk_recorder.gather_recording() | ||
loss_ntk_recording = self.loss_ntk_recorder.gather_recording() | ||
|
||
# For LPNormLoss of order 2 and a 1D output Network, the NTK and the loss NTK | ||
# should be the same up to a factor of +1 or -1. | ||
assert_array_almost_equal( | ||
Review comment: As this is an integration test, you will also want to check that the deployment has worked. You can check things like the shape of the stored values. |
||
np.abs(ntk_recording.ntk), np.abs(loss_ntk_recording.ntk), decimal=5 | ||
) |
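The sign relation that this final assertion relies on can be derived in a few lines. This is a sketch under the assumption that the per-sample loss for `LPNormLoss(order=2)` with a scalar output reduces to $\ell_i = |f_\theta(x_i) - y_i|$. The symbols $\ell_i$, $s_i$, $\Theta$ and $\Theta^{\text{loss}}$ are notation introduced here, not taken from the PR, and the result holds wherever $f_\theta(x_i) \neq y_i$.

```latex
\begin{align}
\ell_i(\theta) &= \lVert f_\theta(x_i) - y_i \rVert_2 = \lvert f_\theta(x_i) - y_i \rvert, \\
\nabla_\theta \ell_i &= s_i \, \nabla_\theta f_\theta(x_i),
  \qquad s_i = \operatorname{sign}\bigl(f_\theta(x_i) - y_i\bigr) \in \{-1, +1\}, \\
\Theta^{\text{loss}}_{ij} &= \nabla_\theta \ell_i \cdot \nabla_\theta \ell_j
  = s_i s_j \, \Theta_{ij}
  \quad\Longrightarrow\quad
  \lvert \Theta^{\text{loss}}_{ij} \rvert = \lvert \Theta_{ij} \rvert .
\end{align}
```

Since $s_i s_j = \pm 1$, comparing absolute values entry by entry, as the test does, removes the sign ambiguity.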
Review comment: Please rename this to be in line with the package. |
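The reviewer's request in the integration test, to check that the deployment itself worked, could look roughly like the sketch below. The arrays are hypothetical stand-ins for the values returned by `gather_recording()`, and the `(records, n_data, n_data)` layout is an assumption about the recorder, not something confirmed by the PR. `check_deployment` is an illustrative helper name.

```python
import numpy as np

# Hypothetical stand-ins for the recorded NTK and loss NTK stacks.
# The (records, n_data, n_data) layout is an assumption.
epochs, n_data = 10, 10
rng = np.random.default_rng(0)
signs = rng.choice([-1.0, 1.0], size=(epochs, n_data))
features = rng.normal(size=(epochs, n_data, 4))
ntk_record = np.einsum("eia,eja->eij", features, features)
loss_ntk_record = signs[:, :, None] * signs[:, None, :] * ntk_record


def check_deployment(record, n_records, n_points):
    """Check that a recorded kernel stack has the expected shape and is finite."""
    assert record.shape == (n_records, n_points, n_points)
    assert np.all(np.isfinite(record))
    # Each recorded kernel is a Gram matrix, so it should be symmetric.
    assert np.allclose(record, np.swapaxes(record, 1, 2))


check_deployment(ntk_record, epochs, n_data)
check_deployment(loss_ntk_record, epochs, n_data)

# Same comparison as in the test: equal up to a sign per entry.
np.testing.assert_array_almost_equal(
    np.abs(ntk_record), np.abs(loss_ntk_record), decimal=5
)
```

In the real test the two stand-in arrays would be replaced by the recorder outputs, with the expected number of records following from `epochs` and `update_rate`.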
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
""" | ||
ZnNL: A Zincwarecode package. | ||
|
||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
|
||
SPDX-License-Identifier: EPL-2.0 | ||
|
||
Copyright Contributors to the Zincwarecode Project. | ||
|
||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
|
||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
|
||
Summary | ||
------- | ||
""" | ||
|
||
import os | ||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" | ||
Review comment: Please remove this from the tests. |
||
|
||
import jax.numpy as np | ||
import numpy as onp | ||
import optax | ||
from neural_tangents import stax | ||
from numpy.testing import assert_array_almost_equal | ||
|
||
from znnl.analysis import EigenSpaceAnalysis, LossNTKCalculation | ||
from znnl.data import MNISTGenerator | ||
from znnl.distance_metrics import LPNorm | ||
from znnl.loss_functions import LPNormLoss | ||
from znnl.models import FlaxModel, NTModel | ||
|
||
|
||
class TestLossNTKCalculation: | ||
Review comment: Same with the class names. Just |
||
""" | ||
Test Suite for the LossNTKCalculation module. | ||
""" | ||
|
||
def test_reshaping_methods(self): | ||
""" | ||
Test the _reshape_dataset and _unshape_data methods. | ||
These are functions used in the loss NTK calculation to join inputs and | ||
targets into a single array and to split them apart again. | ||
""" | ||
# Define a dummy model and dataset to be able to define a | ||
# LossNTKCalculation class | ||
feed_forward_model = stax.serial( | ||
stax.Dense(5), | ||
stax.Relu(), | ||
stax.Dense(2), | ||
stax.Relu(), | ||
) | ||
|
||
# Initialize the model | ||
test_model = NTModel( | ||
optimizer=optax.adam(learning_rate=0.01), | ||
input_shape=(1, 5), | ||
trace_axes=(), | ||
nt_module=feed_forward_model, | ||
) | ||
|
||
data_generator = MNISTGenerator(ds_size=20) | ||
data_set = { | ||
"inputs": data_generator.train_ds["inputs"], | ||
"targets": data_generator.train_ds["targets"], | ||
} | ||
|
||
# Initialize the loss NTK calculation | ||
loss_ntk_calculator = LossNTKCalculation( | ||
metric_fn=LPNorm(order=2), | ||
model=test_model, | ||
dataset=data_set, | ||
) | ||
|
||
# Setup a test dataset for reshaping | ||
test_data_set = { | ||
"inputs": np.array([[1, 2, 3], [4, 5, 6]]), | ||
"targets": np.array([[7, 9], [10, 12]]), | ||
} | ||
|
||
# Test the reshaping | ||
reshaped_test_data_set = loss_ntk_calculator._reshape_dataset(test_data_set) | ||
|
||
assert_array_almost_equal( | ||
Review comment: These should probably be split into different tests rather than having one large one. Something like:
def test_reshaping
...
def test_unshaping
... |
||
reshaped_test_data_set, np.array([[1, 2, 3, 7, 9], [4, 5, 6, 10, 12]]) | ||
) | ||
|
||
# Test the unshaping | ||
input_0, target_0 = loss_ntk_calculator._unshape_data( | ||
reshaped_test_data_set, | ||
input_dimension=3, | ||
input_shape=(2, 3), | ||
target_shape=(2, 2), | ||
batch_length=reshaped_test_data_set.shape[0], | ||
) | ||
assert_array_almost_equal(input_0, test_data_set["inputs"]) | ||
assert_array_almost_equal(target_0, test_data_set["targets"]) | ||
|
||
def test_function_for_loss_ntk(self): | ||
|
||
""" | ||
This method tests the function that calculates the loss for single | ||
datapoints. | ||
""" | ||
# Define a simple feed forward test model | ||
feed_forward_model = stax.serial( | ||
Review comment: If you have to repeat this code, it would be better to use a setup method at the start. |
||
stax.Dense(5), | ||
stax.Relu(), | ||
stax.Dense(2), | ||
stax.Relu(), | ||
) | ||
|
||
# Initialize the model | ||
model = NTModel( | ||
optimizer=optax.adam(learning_rate=0.01), | ||
input_shape=(1, 5), | ||
trace_axes=(), | ||
nt_module=feed_forward_model, | ||
) | ||
|
||
# Define a test dataset with only two datapoints | ||
test_data_set = { | ||
"inputs": np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 8]]), | ||
"targets": np.array([[1, 3], [2, 5]]), | ||
} | ||
|
||
# Initialize loss | ||
loss = LPNormLoss(order=2) | ||
# Initialize the loss NTK calculation | ||
loss_ntk_calculator = LossNTKCalculation( | ||
metric_fn=loss.metric, | ||
model=model, | ||
dataset=test_data_set, | ||
) | ||
|
||
# Calculate the subloss from the NTK first | ||
datapoint = loss_ntk_calculator._reshape_dataset(test_data_set)[0:1] | ||
subloss_from_NTK = loss_ntk_calculator._function_for_loss_ntk( | ||
{ | ||
"params": model.model_state.params, | ||
"batch_stats": model.model_state.batch_stats, | ||
}, | ||
datapoint=datapoint, | ||
) | ||
|
||
# Now calculate subloss manually | ||
applied_model = model.apply( | ||
{ | ||
"params": model.model_state.params, | ||
"batch_stats": model.model_state.batch_stats, | ||
}, | ||
test_data_set["inputs"][0], | ||
) | ||
subloss = np.linalg.norm(applied_model - test_data_set["targets"][0], ord=2) | ||
|
||
# Check that the two losses are the same | ||
assert np.abs(subloss - subloss_from_NTK) < 1e-5 | ||
Review comment: You can use |
||
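A tolerance check on a signed difference passes for any negative difference, however large, so closeness assertions of this kind are usually written with an absolute value or a dedicated helper. The sketch below illustrates the pitfall with the standard library only. The function names are illustrative, and this is not necessarily the helper the reviewer had in mind.

```python
import math


def one_sided_close(a, b, tol=1e-5):
    """Signed check: passes for ANY negative difference, however large."""
    return a - b < tol


def two_sided_close(a, b, tol=1e-5):
    """Absolute check: passes only when a and b are within tol of each other."""
    return math.isclose(a, b, rel_tol=0.0, abs_tol=tol)


# The signed check wrongly accepts values that are far apart.
far_apart_signed = one_sided_close(1.0, 50.0)
far_apart_absolute = two_sided_close(1.0, 50.0)
nearly_equal = two_sided_close(1.0, 1.0 + 1e-7)
```

For arrays, `numpy.testing.assert_allclose` gives the same two-sided behaviour along with a readable failure message.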
|
||
def test_loss_NTK_calculation(self): | ||
Review comment: I would prefer to have a real hard-coded or analytical result here. You can seed the network initialisation and force the same output. It would be even better to have an analytically validated result over an ensemble or something. Here you rely on your implementation via einsum to check what is probably your implementation via einsum just wrapped up in a different function name. This won't validate your implementation and it won't identify if something changes in that library that yields a false result. |
||
""" | ||
Test the Loss NTK calculation with a manually hardcoded dataset | ||
and a simple feed forward network. | ||
""" | ||
# Define a simple feed forward test model | ||
feed_forward_model = stax.serial( | ||
stax.Dense(5), | ||
stax.Relu(), | ||
stax.Dense(1), | ||
stax.Relu(), | ||
) | ||
|
||
# Initialize the model | ||
model = NTModel( | ||
optimizer=optax.adam(learning_rate=0.01), | ||
input_shape=(2,), | ||
trace_axes=(), | ||
nt_module=feed_forward_model, | ||
) | ||
|
||
# Create a test dataset | ||
inputs = np.array(onp.random.rand(10, 2)) | ||
targets = np.array(100 * onp.random.rand(10, 1)) | ||
test_data_set = {"inputs": inputs, "targets": targets} | ||
|
||
# Initialize loss | ||
loss = LPNormLoss(order=2) | ||
# Initialize the loss NTK calculation | ||
loss_ntk_calculator = LossNTKCalculation( | ||
metric_fn=loss.metric, | ||
model=model, | ||
dataset=test_data_set, | ||
) | ||
|
||
# Compute the loss NTK | ||
loss_ntk = loss_ntk_calculator.compute_loss_ntk(x_i=test_data_set, model=model)[ | ||
"empirical" | ||
] | ||
|
||
# Now for comparison calculate regular ntk | ||
ntk = model.compute_ntk(test_data_set["inputs"], infinite=False)["empirical"] | ||
|
||
# predictions calculation analogous to the one in jax recording | ||
predictions = model(test_data_set["inputs"]) | ||
if isinstance(predictions, tuple): | ||
predictions = predictions[0] | ||
|
||
# calculation of loss derivatives | ||
# note: here we need the derivatives of the subloss, not the regular loss fn | ||
loss_derivatives = onp.empty(shape=(len(predictions))) | ||
for i in range(len(loss_derivatives)): | ||
loss_derivatives[i] = ( | ||
predictions[i, 0] / np.abs(predictions[i, 0]) | ||
if predictions[i, 0] != 0 | ||
else 0 | ||
) | ||
|
||
# Calculate the loss NTK from the loss derivatives and the ntk | ||
loss_ntk_2 = np.einsum( | ||
"i, j, ijkl-> ij", loss_derivatives, loss_derivatives, ntk | ||
) | ||
|
||
# Assert that the loss NTKs are the same | ||
assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=6) | ||
|
||
calculator1 = EigenSpaceAnalysis(matrix=loss_ntk) | ||
Review comment: Why do the spectrum analysis if you have element by element compared the matrix? |
||
calculator2 = EigenSpaceAnalysis(matrix=loss_ntk_2) | ||
|
||
eigenvalues1 = calculator1.compute_eigenvalues(normalize=False) | ||
eigenvalues2 = calculator2.compute_eigenvalues(normalize=False) | ||
|
||
assert_array_almost_equal(eigenvalues1, eigenvalues2, decimal=6) |
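One way to obtain the analytically validated reference the reviewer asks for is a model whose NTK is known in closed form. The sketch below is not ZnNL code. It assumes a linear model $f_w(x) = w \cdot x$ with scalar output, for which the empirical NTK is exactly $x_i \cdot x_j$, and checks the same einsum contraction against an independent finite-difference Jacobian. All names (`sub_losses`, `weights`, and so on) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the check is reproducible
n_points, n_features = 6, 3
inputs = rng.normal(size=(n_points, n_features))
targets = rng.normal(size=(n_points, 1))
weights = rng.normal(size=(n_features,))


def sub_losses(w):
    """Per-sample L2 loss |w . x_i - y_i| of a linear model with scalar output."""
    return np.abs(inputs @ w - targets[:, 0])


# Analytic NTK of the linear model: Theta_ij = x_i . x_j, stored with two
# trailing output axes of size 1 to mimic an (n, n, 1, 1) layout.
ntk = (inputs @ inputs.T)[:, :, None, None]

# Analytic loss derivatives: sign(f(x_i) - y_i).
loss_derivatives = np.sign(inputs @ weights - targets[:, 0])
loss_ntk_analytic = np.einsum(
    "i,j,ijkl->ij", loss_derivatives, loss_derivatives, ntk
)

# Independent reference: Jacobian of the sub-losses by central differences.
eps = 1e-6
jacobian = np.empty((n_points, n_features))
for k in range(n_features):
    step = np.zeros(n_features)
    step[k] = eps
    jacobian[:, k] = (
        sub_losses(weights + step) - sub_losses(weights - step)
    ) / (2 * eps)
loss_ntk_numeric = jacobian @ jacobian.T

np.testing.assert_allclose(loss_ntk_analytic, loss_ntk_numeric, atol=1e-6)
```

Because the reference Jacobian never goes through the einsum path, a change in either the contraction or the underlying library would make the two results disagree.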
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ plotly | |
flax | ||
tqdm | ||
pandas | ||
- neural-tangents==0.6.4 | ||
+ neural-tangents | ||
Review comment: Can we use a |
||
tensorflow-datasets | ||
isort | ||
tensorflow | ||
|
Review comment: Can you just name this test_loss_ntk? The naming of the tests should mirror the main Python package, just with test in front. All integration tests using the loss NTK should be in this one module.
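Two of the review comments on the unit tests, sharing construction code through a setup method and splitting the reshaping check into separate tests, can be sketched together. The class below is a hypothetical layout using plain lists. `reshape` and `unshape` are stand-ins that only mimic the behaviour the test expects from `_reshape_dataset` and `_unshape_data`, and they are not the package's implementation.

```python
class TestReshaping:
    """Hypothetical layout: shared setup, one behaviour per test method."""

    def setup_method(self):
        # Runs before every test method, so the data is built in one place.
        self.data_set = {
            "inputs": [[1, 2, 3], [4, 5, 6]],
            "targets": [[7, 9], [10, 12]],
        }

    @staticmethod
    def reshape(data_set):
        """Stand-in for `_reshape_dataset`: join inputs and targets per sample."""
        return [x + y for x, y in zip(data_set["inputs"], data_set["targets"])]

    @staticmethod
    def unshape(rows, input_dimension):
        """Stand-in for `_unshape_data`: split each row into input and target."""
        inputs = [row[:input_dimension] for row in rows]
        targets = [row[input_dimension:] for row in rows]
        return inputs, targets

    def test_reshaping(self):
        reshaped = self.reshape(self.data_set)
        assert reshaped == [[1, 2, 3, 7, 9], [4, 5, 6, 10, 12]]

    def test_unshaping(self):
        reshaped = self.reshape(self.data_set)
        inputs, targets = self.unshape(reshaped, input_dimension=3)
        assert inputs == self.data_set["inputs"]
        assert targets == self.data_set["targets"]
```

With this layout a failure in the unshaping step is reported under its own test name instead of being hidden behind an earlier assertion.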