Implement Reactant.batch function for better batching (vmap too!) #180

Open · mofeing opened this issue Oct 16, 2024 · 0 comments

Labels: enhancement (New feature or request), good first issue (Good for newcomers)

mofeing (Collaborator) commented Oct 16, 2024:

Currently, batching is implemented by replacing broadcasting with the `enzyme.batch` op and tracing over the broadcasted code. The following example should just work:

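```julia
using Reactant

X = [rand(4,4) for _ in 1:10]  # a vector of 10 random 4×4 matrices
f = @compile transpose.(X)     # the broadcast is traced as an `enzyme.batch` op
```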

One inconvenience of Julia's broadcasting is that there is no way to specify the dimension to broadcast over; it simply iterates over every element. Users therefore need `eachslice` to slice along the desired dimension.

```julia
X = rand(4,4,10)  # 10 matrices of size 4×4 stored along the 3rd dimension
f = @compile broadcast(transpose, eachslice(X, dims=3))
```
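
For reference, this is what the `eachslice` form computes in plain Julia, without Reactant: broadcasting over `eachslice(X, dims=3)` yields a `Vector` of 4×4 matrices rather than a 3-D array, so the batch dimension has to be reassembled explicitly (here with `stack`, available since Julia 1.9). It is only a sketch of the expected result, not of what Reactant actually traces.

```julia
# Plain-Julia reference semantics (no Reactant involved); `stack` needs Julia >= 1.9.
X = rand(4, 4, 10)

# Broadcasting `transpose` over the slices along dimension 3 gives a Vector of 10 matrices
Y = broadcast(transpose, eachslice(X, dims=3))

# Reassemble the batch dimension as the third axis of a 4×4×10 array
Z = stack(Y; dims=3)

# A batched transpose is the same as swapping the first two axes of the full array
@assert Z == permutedims(X, (2, 1, 3))
```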

I'm not sure whether we would then correctly batch along the desired dimension in this case, or whether tracing would emit some extra instructions; this needs checking.

It could be beneficial to provide similar functionality through a single `batch` function, which would be easier to trace correctly and more familiar to users coming from JAX (cf. `jax.vmap`). An example:

```julia
# proposed API (not implemented yet)
f = @compile Reactant.batch(transpose, X; dims=3)
```
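
To pin down what such a function should return, here is a plain-Julia reference sketch of the proposed semantics. `batch_reference` is a hypothetical name used only for illustration and is not part of Reactant. It applies `f` to each slice along `dims` and stacks the results back along the same dimension, roughly like `jax.vmap(f, in_axes=dims-1, out_axes=dims-1)` in JAX (whose axes are 0-based). A traced `Reactant.batch` would presumably lower this to a single batched op rather than a loop; that is an assumption about the intended design.

```julia
# Hypothetical reference semantics for the proposed `Reactant.batch(f, X; dims)`.
# Plain Julia only: this loops over slices, whereas a Reactant version would be
# expected to emit one batched op. Requires Julia >= 1.9 for `stack`.
function batch_reference(f, X::AbstractArray; dims::Integer)
    # eachslice yields the sub-arrays obtained by fixing index k along `dims`
    results = map(f, eachslice(X; dims=dims))
    # Put the batch dimension back where it came from
    return stack(results; dims=dims)
end

X = rand(4, 4, 10)
Y = batch_reference(transpose, X; dims=3)  # 4×4×10, each slice transposed
```

With `dims=3` this matches the `eachslice`/`broadcast` formulation, but the batch dimension is a keyword argument instead of being encoded in how the input is sliced.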

mofeing added the enhancement and good first issue labels on Oct 16, 2024.