Implemented the loss NTK calculation #109
base: main
@@ -0,0 +1,259 @@
""" | ||
ZnNL: A Zincwarecode package. | ||
|
||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
|
||
SPDX-License-Identifier: EPL-2.0 | ||
|
||
Copyright Contributors to the Zincwarecode Project. | ||
|
||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
|
||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
|
||
Summary | ||
------- | ||
""" | ||
|
||
import os | ||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" | ||

Review comment: Please remove this from the tests.

import jax.numpy as np
import numpy as onp
import optax
from flax import linen as nn
from neural_tangents import stax
from numpy.testing import assert_array_almost_equal

from znnl.analysis import EigenSpaceAnalysis, LossDerivative, LossNTKCalculation
from znnl.data import MNISTGenerator
from znnl.distance_metrics import LPNorm
from znnl.loss_functions import LPNormLoss
from znnl.models import FlaxModel, NTModel


# Defines a simple CNN module
class ProductionModule(nn.Module):
    """
    Simple CNN module.
    """

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=10)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)

        return x


class TestLossNTKCalculation:

Review comment: Same with the class names. Just

    """
    Test Suite for the LossNTKCalculation module.
    """

    def test_reshaping_methods(self):
        """
        Test the _reshape_dataset and _unshape_dataset methods.
        These are functions used in the loss NTK calculation to reshape the
        dataset into a single array and to restore it afterwards.
        """
        # Define a dummy model and dataset to be able to define a
        # LossNTKCalculation class
        production_model = FlaxModel(
            flax_module=ProductionModule(),
            optimizer=optax.adam(learning_rate=0.01),
            input_shape=(1, 28, 28, 1),
            trace_axes=(),
        )

        data_generator = MNISTGenerator(ds_size=20)
        data_set = {
            "inputs": data_generator.train_ds["inputs"],
            "targets": data_generator.train_ds["targets"],
        }

        # Initialize the loss NTK calculation
        loss_ntk_calculator = LossNTKCalculation(
            metric_fn=LPNorm(order=2),
            model=production_model,
            dataset=data_set,
        )

        # Set up a test dataset for reshaping
        test_data_set = {
            "inputs": np.array([[1, 2, 3], [4, 5, 6]]),
            "targets": np.array([[7], [10]]),
        }

        # Test the reshaping
        reshaped_test_data_set = loss_ntk_calculator._reshape_dataset(test_data_set)

        assert_array_almost_equal(
            reshaped_test_data_set, np.array([[1, 2, 3, 7], [4, 5, 6, 10]])
        )

Review comment: These should probably be split into different tests rather than having one large one. Something like:

def test_reshaping
    ...
def test_unshaping
    ...

        # Test the unshaping
        input_0, target_0 = loss_ntk_calculator._unshape_data(
            reshaped_test_data_set,
            input_dimension=3,
            input_shape=(2, 3),
            target_shape=(2, 1),
            batch_length=reshaped_test_data_set.shape[0],
        )
        assert_array_almost_equal(input_0, test_data_set["inputs"])
        assert_array_almost_equal(target_0, test_data_set["targets"])
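
A minimal sketch of the split suggested in the review comment above, assuming the shared setup is factored into a helper; the helper name _make_calculator is an assumption, not part of the PR:

    def _make_calculator(self):
        # Hypothetical shared helper mirroring the setup code in the test above.
        production_model = FlaxModel(
            flax_module=ProductionModule(),
            optimizer=optax.adam(learning_rate=0.01),
            input_shape=(1, 28, 28, 1),
            trace_axes=(),
        )
        data_generator = MNISTGenerator(ds_size=20)
        data_set = {
            "inputs": data_generator.train_ds["inputs"],
            "targets": data_generator.train_ds["targets"],
        }
        return LossNTKCalculation(
            metric_fn=LPNorm(order=2), model=production_model, dataset=data_set
        )

    def test_reshaping(self):
        # Covers only _reshape_dataset.
        calculator = self._make_calculator()
        test_data_set = {
            "inputs": np.array([[1, 2, 3], [4, 5, 6]]),
            "targets": np.array([[7], [10]]),
        }
        reshaped = calculator._reshape_dataset(test_data_set)
        assert_array_almost_equal(reshaped, np.array([[1, 2, 3, 7], [4, 5, 6, 10]]))

    def test_unshaping(self):
        # Covers only _unshape_data, starting from a known reshaped array.
        calculator = self._make_calculator()
        reshaped = np.array([[1, 2, 3, 7], [4, 5, 6, 10]])
        inputs, targets = calculator._unshape_data(
            reshaped,
            input_dimension=3,
            input_shape=(2, 3),
            target_shape=(2, 1),
            batch_length=reshaped.shape[0],
        )
        assert_array_almost_equal(inputs, np.array([[1, 2, 3], [4, 5, 6]]))
        assert_array_almost_equal(targets, np.array([[7], [10]]))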

    def test_function_for_loss_ntk(self):
        """
        This method tests the function that is used for the correlation matrix
        in the loss NTK calculation. It is supposed to yield the loss per single
        datapoint.
        """
        # Define a simple feed forward test model
        feed_forward_model = stax.serial(
            stax.Dense(5),
            stax.Relu(),
            stax.Dense(2),
            stax.Relu(),
        )

Review comment: If you have to repeat this code, it would be better to use a setup method at the start.

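As a sketch of that idea: pytest calls setup_method before each test in a class, so the repeated model construction could move there. The attribute names here are assumptions:

    def setup_method(self):
        # Built once per test; replaces the model setup repeated inline below.
        self.feed_forward_model = stax.serial(
            stax.Dense(5),
            stax.Relu(),
            stax.Dense(2),
            stax.Relu(),
        )
        self.model = NTModel(
            optimizer=optax.adam(learning_rate=0.01),
            input_shape=(1, 5),
            trace_axes=(),
            nt_module=self.feed_forward_model,
        )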
        # Initialize the model
        model = NTModel(
            optimizer=optax.adam(learning_rate=0.01),
            input_shape=(1, 5),
            trace_axes=(),
            nt_module=feed_forward_model,
        )

        # Define a test dataset with only two datapoints
        test_data_set = {
            "inputs": np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 8]]),
            "targets": np.array([[1, 3], [2, 5]]),
        }

        # Initialize loss
        loss = LPNormLoss(order=2)
        # Initialize the loss NTK calculation
        loss_ntk_calculator = LossNTKCalculation(
            metric_fn=loss.metric,
            model=model,
            dataset=test_data_set,
        )

        # Calculate the subloss from the NTK first
        datapoint = loss_ntk_calculator._reshape_dataset(test_data_set)[0:1]
        subloss_from_NTK = loss_ntk_calculator._function_for_loss_ntk(
            {
                "params": model.model_state.params,
                "batch_stats": model.model_state.batch_stats,
            },
            datapoint=datapoint,
        )

        # Now calculate subloss manually
        applied_model = model.apply(
            {
                "params": model.model_state.params,
                "batch_stats": model.model_state.batch_stats,
            },
            test_data_set["inputs"][0],
        )
        subloss = np.linalg.norm(applied_model - test_data_set["targets"][0], ord=2)

        # Check that the two losses are the same
        assert subloss - subloss_from_NTK < 1e-5

Review comment: You can use

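The comment above is cut off in the rendered page; one plausible completion, given the helpers already imported in this file, is a tolerance-aware assertion (an assumption, not the reviewer's confirmed suggestion). Note that the bare subtraction above also passes for arbitrarily large negative differences:

        # Symmetric, tolerance-aware check instead of `subloss - subloss_from_NTK < 1e-5`
        assert_array_almost_equal(subloss, subloss_from_NTK, decimal=5)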

    def test_loss_NTK_calculation(self):
        """
        Test the Loss NTK calculation.

        Here we test if the Loss NTK calculated through the neural tangents module
        is the same as the Loss NTK calculated with the already implemented NTK
        and loss derivatives.
        We do this for a small CNN model and the MNIST dataset.
        We also check if the eigenvalues of the two Loss NTKs are the same.

        The current implementation yields a precision of e-4. Whether this is due
        to numerical errors or to a mistake in the implementation remains to be
        decided.
        """

Review comment: I would prefer to have a real hard-coded or analytical result here. You can seed the network initialisation and force the same output. It would be even better to have an analytically validated result over an ensemble or something. Here you rely on your implementation via einsum to check what is probably your implementation via einsum just wrapped up in a different function name. This won't validate your implementation and it won't identify if something changes in that library that yields a false result.

        # Define a test model
        production_model = FlaxModel(
            flax_module=ProductionModule(),
            optimizer=optax.adam(learning_rate=0.01),
            input_shape=(1, 28, 28, 1),
            trace_axes=(),
        )
        # Initialize model parameters

        data_generator = MNISTGenerator(ds_size=20)
        data_set = {
            "inputs": data_generator.train_ds["inputs"],
            "targets": data_generator.train_ds["targets"],
        }

        # Initialize the loss NTK calculation
        loss_ntk_calculator = LossNTKCalculation(
            metric_fn=LPNorm(order=2),
            model=production_model,
            dataset=data_set,
        )

        # Compute the loss NTK
        loss_ntk = loss_ntk_calculator.compute_loss_ntk(
            x_i=data_set, model=production_model
        )["empirical"]

        # Now for comparison calculate the regular NTK
        ntk = production_model.compute_ntk(data_set["inputs"], infinite=False)[
            "empirical"
        ]
        # Calculate the loss derivative fn
        loss_derivative_calculator = LossDerivative(LPNormLoss(order=2))
        # Predictions calculation analogous to the one in jax recording
        predictions = production_model(data_set["inputs"])
        if type(predictions) is tuple:
            predictions = predictions[0]

        # Calculation of loss derivatives
        # Note: here we need the derivatives of the subloss, not the regular loss fn
        loss_derivatives = onp.empty(shape=(len(predictions), len(predictions[0])))
        for i in range(len(loss_derivatives)):
            # The weird indexing here is because of axis constraints in LPNormLoss
            loss_derivatives[i] = loss_derivative_calculator.calculate(
                predictions[i : i + 1], data_set["targets"][i : i + 1]
            )[0]

        # Calculate the loss NTK from the loss derivatives and the NTK
        loss_ntk_2 = np.einsum(
            "ik, jl, ijkl-> ij", loss_derivatives, loss_derivatives, ntk
        )
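
In index notation, the einsum string "ik, jl, ijkl -> ij" is the chain-rule contraction of the per-sample loss derivatives with the regular NTK:

$$\left(\Theta_{\ell}\right)_{ij} \;=\; \sum_{k,l} \frac{\partial \ell_i}{\partial f_k(x_i)} \,\Theta_{ijkl}\, \frac{\partial \ell_j}{\partial f_l(x_j)},$$

where $\ell_i$ is the subloss of datapoint $i$, $f_k(x_i)$ is the $k$-th network output, and $\Theta_{ijkl}$ is the empirical NTK with untraced output axes (trace_axes=()).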

        # Assert that the loss NTKs are the same
        assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=4)

        calculator1 = EigenSpaceAnalysis(matrix=loss_ntk)
        calculator2 = EigenSpaceAnalysis(matrix=loss_ntk_2)

        eigenvalues1 = calculator1.compute_eigenvalues(normalize=False)
        eigenvalues2 = calculator2.compute_eigenvalues(normalize=False)

        assert_array_almost_equal(eigenvalues1, eigenvalues2, decimal=4)

Review comment: Why do the spectrum analysis if you have element by element compared the matrix?

Review comment: Please rename this to be inline with the package.