-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added CIFAR tutorials #79
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @khaledsaab sorry for the delay reviewing this! I think it looks great overall, see comments for requested changes. Also, to pass tests, need to run make dev
, then make fix
and make check
(see Developer guidelines in README)
# The following identity module is to essentially replace the last FC layer | ||
# in the resnet model by the FC in MeTal | ||
|
||
class IdentityModule(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class already exists in metal.modules.identity_module
|
||
# Here we create a dataloader that transforms CIFAR labels from 0-9, to 1-10, | ||
# We do this because MeTal treats a 0 label as abstain | ||
class MetalDataset(Dataset): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we already have a MetalDataset
, I would rename this OneIndexedDataset
? Also clean up commented out stuff
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And this can go in metal.utils
class BasicBlock(nn.Module): | ||
expansion = 1 | ||
|
||
def __init__(self, in_planes, planes, stride=1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit, but in_channels
etc. is more standard; in_planes
might be confusing?
class ResNet(nn.Module): | ||
def __init__(self, block, num_blocks, num_classes=10): | ||
super(ResNet, self).__init__() | ||
self.in_planes = 64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is defined but not used? Would be ideal to make this a kwarg also?
@@ -0,0 +1,117 @@ | |||
'''ResNet in PyTorch. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should put this in metal.contrib.modules
No description provided.