Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added CIFAR tutorials #79

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

khaledsaab
Copy link

No description provided.

Copy link
Contributor

@ajratner ajratner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @khaledsaab sorry for the delay reviewing this! I think it looks great overall, see comments for requested changes. Also, to pass tests, need to run make dev, then make fix and make check (see Developer guidelines in README)

# The following identity module is to essentially replace the last FC layer
# in the resnet model by the FC in MeTal

class IdentityModule(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class already exists in metal.modules.identity_module


# Here we create a dataloader that transforms CIFAR labels from 0-9, to 1-10,
# We do this because MeTal treats a 0 label as abstain
class MetalDataset(Dataset):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we already have a MetalDataset, I would rename this OneIndexedDataset? Also clean up commented out stuff

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this can go in metal.utils

class BasicBlock(nn.Module):
expansion = 1

def __init__(self, in_planes, planes, stride=1):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, but in_channels etc. is more standard; in_planes might be confusing?

class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is defined but not used? Would be ideal to make this a kwarg also?

@@ -0,0 +1,117 @@
'''ResNet in PyTorch.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should put this in metal.contrib.modules

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants