Implemented the loss NTK calculation #109
base: main
Conversation
Good job in this PR. The first draft looks pretty nice.
I have two major comments:
- In general, a test should cover as many aspects of the code as possible, and it should test the desired aspect as simply as possible. This includes relying as little as possible on existing methods: for example, creating some dummy test data is better practice than pulling in an existing data generator (see the sketch after this list).
- With your implementation, we would need to duplicate all observables for the loss NTK that we already have for the NTK. One can avoid this by having the option to use a recorder either for the loss NTK or for the regular NTK. This could be done with a single keyword at initialization.
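A minimal sketch of such dummy data, assuming a simple regression setup; all names and shapes here are illustrative, not taken from the PR:

import numpy as np

# Hypothetical dummy data for a loss-NTK test; a fixed seed keeps it deterministic.
rng = np.random.default_rng(42)
dummy_inputs = rng.normal(size=(10, 3))   # 10 points, 3 input features
dummy_targets = rng.normal(size=(10, 1))  # 1D network output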
# Check if we need a loss NTK computation and update the class accordingly
if any(
    [
        "loss_ntk" in self._selected_properties,
As far as I can see, right now we would have to implement the trace and all other properties again to use them with the loss NTK. I think it would be more reasonable to have one kwarg like use_loss_ntk with which all NTK calculations use the loss NTK, making the entire recorder a loss NTK recorder. With this, we could re-use all the properties we have already implemented (see the sketch below).
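A rough sketch of that idea, assuming the flag is passed at initialization; use_loss_ntk and all other names here are suggestions, not existing API:

import numpy as np

class Recorder:
    """Sketch: one recorder class serving both NTK flavours."""

    def __init__(self, use_loss_ntk: bool = False):
        # Single switch: every NTK-based observable operates on
        # whichever matrix _current_ntk returns.
        self.use_loss_ntk = use_loss_ntk

    def _current_ntk(self, ntk: np.ndarray, loss_ntk: np.ndarray) -> np.ndarray:
        return loss_ntk if self.use_loss_ntk else ntk

    def record_trace(self, ntk: np.ndarray, loss_ntk: np.ndarray) -> float:
        # Existing observables like the trace are reused unchanged.
        return float(np.trace(self._current_ntk(ntk, loss_ntk)))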
I agree that there's room for improvement, but I think if we introduce a flag like this here, we should maybe also discuss further changes to the recorder. We should talk about this in person or in a meeting, but I'd first like to make sure that the tests are working, because that's more urgent for the DPG, if that's fine.
The flag was introduced in commit 1dac434.
I have a few comments. If you go through and address them all, I can go back over it, but in general I like it and am happy to have it merged soon.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
Please remove this from the test.
# For LPNormLoss of order 2 and a 1D output network, the NTK and the loss NTK
# should be the same up to a factor of +1 or -1.
assert_array_almost_equal(
As this is an integration test, you will also want to check that the deployment has worked. You can check things like the shape of the stored values (see the sketch below).
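A hedged example of the kind of check meant here; the helper and its argument names are hypothetical, not the recorder's real API:

import numpy as np

def check_recorded_shapes(recorded_loss_ntk, n_samples: int):
    """Hypothetical helper: verify the stored loss NTK is one
    (n_samples, n_samples) matrix per recorded epoch."""
    arr = np.asarray(recorded_loss_ntk)
    assert arr.ndim == 3, "expected one matrix per recorded epoch"
    assert arr.shape[1:] == (n_samples, n_samples)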
Can you just name this test_loss_ntk? The naming of the tests should mirror the main Python package, just with test in front. All integration tests using the loss NTK should be in this one module.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
Please remove this from the tests.
Please rename this to be in line with the package.
)

@staticmethod
def _unshape_data(
Same here.
""" | ||
|
||
# Set the attributes | ||
self.ntk_batch_size = model.ntk_batch_size |
It might be better to make all of these arguments to this calculation (see the sketch below). Especially when we later move into the new measurement system, this will all need to be self-contained. Things like store_on_device are only pertinent to this calculator.
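A sketch of the suggested direction; the signature and default values are hypothetical, and the point is only that the calculation receives its configuration directly instead of reading it from attributes set elsewhere:

def compute_loss_ntk(self, dataset, ntk_batch_size: int = 10, store_on_device: bool = False):
    # Knobs that only matter to this calculator are passed in here,
    # so the method stays self-contained when it moves into the new
    # measurement system.
    ...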
Returns
-------
input: np.ndarray
Can you add some shape information here? It can be (batch * input size, ) or anything, but just some information about what I will get back (see the docstring sketch below). What you mean by unshape is also very unclear: is it flattening, is it reshaping? "Unshape" doesn't have a real meaning.
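Something along these lines would already help; the shapes and wording are illustrative:

Returns
-------
input : np.ndarray
    The input part of the flattened datapoint, restored to its
    original form, shape (batch_size, *input_shape).
target : np.ndarray
    The corresponding targets, shape (batch_size, *target_shape).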
batch_length, *input_shape[1:]
), datapoint[:, input_dimension:].reshape(batch_length, *target_shape[1:])

def _function_for_loss_ntk(self, params, datapoint) -> float:
I would prefer different naming here. Is it an apply function on flattened data? A loss function? What do you mean by subloss? The loss between two data points is just a loss. "Function for loss ntk" could be anything (see the sketch below).
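For instance, something like this would be more self-explanatory; the name and docstring are only a suggestion:

def _pointwise_loss(self, params, datapoint) -> float:
    """Loss of the model on a single flattened (input, target) pair."""
    ...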
If the notebook has not been cleared, can you clear it of outputs?