-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added "Porting PyTorch model to JAX" tutorial #71
Conversation
a6dc454
to
f48fe33
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @vfdev-5 - found a few typos, left a few suggestions for rewording and for adding links. Some are very optional, feel free to take or leave 😄 Cheers!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is really great – thanks! Just a couple small comments below, then I'd love to get this merged!
- Activations like `nn.ReLU` -> `lambda x: nnx.relu(x)` | ||
- Pooling layers like `nn.AvgPool2d` -> `lambda x: nnx.avg_pool(x, ...)` | ||
- `nn.AdaptiveAvgPool2d(1)` -> `lambda x: nnx.avg_pool(x, (x.shape[1], x.shape[2]))`, x is in NHWC format | ||
- `nn.Flatten()` -> `lambda x: x.reshape(x.shape[0], -1)` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe use jax.vmap(jnp.ravel)
here? Or is that too magic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, this is too magic.
f48fe33
to
dc9f67f
Compare
@melissawm @jakevdp thanks for the review, I applied majority of your suggestions to this tutorial |
dc9f67f
to
0c0c43e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
The errors on ReadtheDocs are about the headings:
You might need to change them to be H3 instead. |
0c0c43e
to
c1da57c
Compare
No description provided.