pmap support #645
Comments
I agree that a list of arguments as input and a list of outputs makes sense, especially for the reason you suggested about passing on the outputs. I also like the idea of a […]

I just want to clarify a few things for myself about terminology and capabilities, so I'm clear on where this would fit in. I'm not super familiar with XLA, and I'm thinking in terms of 'data parallel' approaches like PyTorch's and optimizer state sharding like Fairscale/ZeRO. Am I right in thinking that both of these would be enabled by […]? So, for example, 'data parallel' à la PyTorch could be achieved by putting model replicas on each device and then using […]. In that case, being able to easily chain […]
Correct, but we also want to add […]
The goal is to add `Nx.Defn.pmap`. This is a discussion of its API.

Input
When it comes to the input, the first option is to automatically shard the input. For example, `Nx.Defn.pmap` can shard the first argument based on the number of devices. We can make the sharding dimensions customizable too, with an option that shards, say, the first argument at axis 2 and the second argument at axis 0.
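To make the first option concrete, here is a plain-Elixir sketch that models a tensor as a nested list and splits it into `n` equal shards along a chosen axis. `ShardSketch.shard/3` and `ShardSketch.shard_args/3` are hypothetical helpers written for illustration, not part of Nx, and the sketch assumes every sharded axis divides evenly.

```elixir
defmodule ShardSketch do
  # Hypothetical helper (not Nx API): splits a nested list standing in for a
  # tensor into `n` equally sized shards along `axis`.
  def shard(data, n, 0) when is_list(data) and n > 0 do
    size = length(data)

    if size == 0 or rem(size, n) != 0 do
      raise ArgumentError, "an axis of size #{size} cannot be split into #{n} shards"
    end

    Enum.chunk_every(data, div(size, n))
  end

  def shard(data, n, axis) when is_list(data) and axis > 0 do
    data
    # Shard every entry one axis deeper...
    |> Enum.map(&shard(&1, n, axis - 1))
    # ...then regroup so that shard i collects the i-th piece of every entry.
    |> Enum.zip()
    |> Enum.map(&Tuple.to_list/1)
  end

  # Shards each argument along its own axis; axes = [2, 0] splits the first
  # argument at axis 2 and the second argument at axis 0.
  def shard_args(args, n, axes) do
    args
    |> Enum.zip(axes)
    |> Enum.map(fn {arg, axis} -> shard(arg, n, axis) end)
  end
end
```

For instance, `ShardSketch.shard_args([x, y], 2, [2, 0])` splits `x` at axis 2 and `y` at axis 0 into two shards each, mirroring the customizable-axes idea above.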
The second option is to allow a list of lists of already sharded tensors to be given. To convert the data to this format, one can use `Nx.to_batched_list/2`, but perhaps we can also add `Nx.shard`.
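For the second option, the caller hands over one argument list per device. The sketch below assumes that "list of lists" means `[[x1, y1], [x2, y2], ...]` and shows how batched arguments could be regrouped into that shape. The local `to_batched_list/2` only mimics, on nested lists, what `Nx.to_batched_list/2` does for tensors (splitting the first axis into batches of a given size), and `per_device_args/1` is a hypothetical name; neither is Nx API.

```elixir
defmodule ShardedInputSketch do
  # Mimics Nx.to_batched_list/2 on a nested list: splits axis 0 into batches
  # of `batch_size` entries. This sketch requires an exact split.
  def to_batched_list(data, batch_size)
      when is_list(data) and batch_size > 0 and rem(length(data), batch_size) == 0 do
    Enum.chunk_every(data, batch_size)
  end

  # Hypothetical helper: turns one batched list per argument into one
  # argument list per device, i.e. [[x1, y1], [x2, y2], ...].
  def per_device_args(batched_args) do
    batched_args
    |> Enum.zip()
    |> Enum.map(&Tuple.to_list/1)
  end
end

xs = ShardedInputSketch.to_batched_list([1, 2, 3, 4], 2)
ys = ShardedInputSketch.to_batched_list([10, 20, 30, 40], 2)

# Returns [[[1, 2], [10, 20]], [[3, 4], [30, 40]]], so the first device
# receives [x1, y1] and the second receives [x2, y2].
ShardedInputSketch.per_device_args([xs, ys])
```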
Output
When it comes to the output, we have two options. The most logical option, thinking about Elixir, is for `pmap` to return a list of results of the same size as the list of inputs. In the GPU/TPU case, the data will remain allocated on each GPU. This is great, especially with the second input API above, because you can easily call `Nx.Defn.pmap` again, with a separate list of inputs, to continue performing computations.

If your goal is to put the tensors back together into one, then you need to call `Nx.stack/1`. However, given each tensor belongs to a separate device, perhaps you will first need to transfer them all to the same device before stacking. Maybe we should make it so `Nx.stack/1` automatically performs the transfer across devices (this is something we can discuss separately).

The other approach is to handle it the same way as JAX: it creates a separate "tensor backend" that knows the tensor actually belongs to n other backends. The benefit is that it can still present the data as one tensor and perhaps encapsulate the backend transfer described above, at the cost of one additional abstraction.
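The list-in/list-out design can be modeled in plain Elixir. In the sketch below, `PmapSketch.pmap/2` is a hypothetical stand-in for the proposed `Nx.Defn.pmap`: each Task plays the role of one device, the function runs once per argument list, and the results come back as a list in input order, so they can be fed straight into another call. Nested lists stand in for tensors. The list of results already has the shape `Nx.stack/1` would produce (a leading device axis), while `Enum.concat/1` rejoins shards that were split along axis 0; with real devices either step would need the cross-device transfer discussed above.

```elixir
defmodule PmapSketch do
  # Hypothetical stand-in for the proposed Nx.Defn.pmap: runs `fun` once per
  # argument list, each in its own Task (one Task per "device"), and returns
  # the results as a list in the same order as the inputs.
  def pmap(fun, per_device_args) do
    per_device_args
    |> Enum.map(fn args -> Task.async(fn -> apply(fun, args) end) end)
    |> Enum.map(&Task.await/1)
  end
end

add = fn x, y -> Enum.zip(x, y) |> Enum.map(fn {a, b} -> a + b end) end
double = fn x -> Enum.map(x, &(&1 * 2)) end

# Two "devices", each holding one shard of x and the matching shard of y.
inputs = [[[1, 2], [10, 20]], [[3, 4], [30, 40]]]

# One result per device: [[11, 22], [33, 44]].
sums = PmapSketch.pmap(add, inputs)

# Chaining: each result becomes the single argument of the next call, with
# no gathering in between: [[22, 44], [66, 88]].
doubled = PmapSketch.pmap(double, Enum.map(sums, &[&1]))

# Gathering: rejoin the shards along axis 0: [22, 44, 66, 88].
Enum.concat(doubled)
```

In the proposed API, returning a list is what keeps chaining cheap: the intermediate results stay where they were computed, and only the final gathering step has to move data between devices.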
Personal thoughts section
My personal thought is that the most Elixir-like approach is to have a list of arguments as input and a list of outputs. We can add functions such as `Nx.shard/2` to make it easier to shard existing values. We can also add `Nx.Defn.num_shards(opts)` to return the number of shards that the current compiler supports.

If we want to support first-class sharding, then we can add `Nx.smap`, which stands for "shard map" and automatically does the sharding for you while being built on top of `pmap`.
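`smap` can likewise be sketched as a thin layer over `pmap`: shard every argument, run `pmap`, and put the results back together. Everything here is hypothetical: `num_shards/1` just reads a `:num_shards` option instead of asking the compiler, sharding is always along axis 0, and the final `Enum.concat/1` assumes `fun` keeps axis 0 intact (as elementwise operations do).

```elixir
defmodule SmapSketch do
  # Hypothetical stand-in for Nx.Defn.num_shards(opts): reads a :num_shards
  # option rather than asking the current compiler.
  def num_shards(opts), do: Keyword.get(opts, :num_shards, 1)

  # A list-in/list-out model of pmap: one Task per shard, results in order.
  def pmap(fun, per_device_args) do
    per_device_args
    |> Enum.map(fn args -> Task.async(fn -> apply(fun, args) end) end)
    |> Enum.map(&Task.await/1)
  end

  # "shard map": shard every argument along axis 0, regroup into one argument
  # list per device, run pmap, then concatenate the per-shard results.
  def smap(fun, args, opts \\ []) do
    n = num_shards(opts)

    per_device_args =
      args
      |> Enum.map(&shard(&1, n))
      |> Enum.zip()
      |> Enum.map(&Tuple.to_list/1)

    fun
    |> pmap(per_device_args)
    |> Enum.concat()
  end

  # Splits axis 0 into n equal, non-empty pieces (exact split required).
  defp shard(list, n) when n > 0 and length(list) >= n and rem(length(list), n) == 0 do
    Enum.chunk_every(list, div(length(list), n))
  end
end

multiply = fn x, y -> Enum.zip(x, y) |> Enum.map(fn {a, b} -> a * b end) end

# Returns [5, 12, 21, 32], the same as calling multiply directly.
SmapSketch.smap(multiply, [[1, 2, 3, 4], [5, 6, 7, 8]], num_shards: 2)
```

Because `smap` here only composes sharding, `pmap`, and gathering, `pmap` stays the primitive and `smap` remains an optional convenience on top of it.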
TODO
`Nx.Defn.pmap`
`Nx.Defn.num_shards(opts)` (any better names than shards?)
`Nx.shard`
`Nx.Defn.smap`