Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added "Porting PyTorch model to JAX" tutorial #71

Merged
merged 1 commit into from
Oct 21, 2024

Conversation

vfdev-5
Copy link
Contributor

@vfdev-5 vfdev-5 commented Oct 17, 2024

No description provided.

@vfdev-5 vfdev-5 force-pushed the jax-for-pytorch-users-p2 branch 2 times, most recently from a6dc454 to f48fe33 Compare October 17, 2024 14:57
Copy link
Contributor

@melissawm melissawm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @vfdev-5 - found a few typos, left a few suggestions for rewording and for adding links. Some are very optional, feel free to take or leave 😄 Cheers!

docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really great – thanks! Just a couple small comments below, then I'd love to get this merged!

docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
docs/JAX_porting_PyTorch_model.md Outdated Show resolved Hide resolved
- Activations like `nn.ReLU` -> `lambda x: nnx.relu(x)`
- Pooling layers like `nn.AvgPool2d` -> `lambda x: nnx.avg_pool(x, ...)`
- `nn.AdaptiveAvgPool2d(1)` -> `lambda x: nnx.avg_pool(x, (x.shape[1], x.shape[2]))`, x is in NHWC format
- `nn.Flatten()` -> `lambda x: x.reshape(x.shape[0], -1)`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use jax.vmap(jnp.ravel) here? Or is that too magic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is too magic.

@vfdev-5
Copy link
Contributor Author

vfdev-5 commented Oct 21, 2024

@melissawm @jakevdp thanks for the review, I applied majority of your suggestions to this tutorial

Copy link
Contributor

@melissawm melissawm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@melissawm
Copy link
Contributor

The errors on ReadtheDocs are about the headings:

/home/docs/checkouts/readthedocs.org/user_builds/jax-ai-stack/checkouts/71/docs/JAX_porting_PyTorch_model.ipynb.rst:180002: WARNING: Non-consecutive header level increase; H2 to H4 [myst.header]

You might need to change them to be H3 instead.

@jakevdp jakevdp merged commit 2eae49e into jax-ml:main Oct 21, 2024
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants