Add multi-worker support for JAX training. #18654
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master   #18654      +/-   ##
==========================================
- Coverage   78.57%   78.51%   -0.06%
==========================================
  Files         335      335
  Lines       32979    33020      +41
  Branches     6455     6467      +12
==========================================
+ Hits        25913    25927      +14
- Misses       5510     5532      +22
- Partials     1556     1561       +5
Thanks for the PR!
@@ -7,7 +7,7 @@
 class TFDatasetAdapter(DataAdapter):
     """Adapter that handles `tf.data.Dataset`."""

-    def __init__(self, dataset, class_weight=None):
+    def __init__(self, dataset, class_weight=None, distribution=None):
Please add an `Args:` section to document the type of each argument (to avoid confusion with `tf.distribute`).
Done.
):
    raise ValueError(
        "Only `tf.data.Dataset` is supported for multi worker "
        f"distribution, received input types is {type(x)}"
When using multi-worker distribution, the data must be provided
as a `tf.data.Dataset` instance. Received: type(x)={type(x)}
Done.
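For illustration, the resolved check could be sketched as follows. This is a hedged, pure-Python sketch, not the PR's actual code: `DatasetStub` stands in for `tf.data.Dataset` so the snippet runs without TensorFlow, and the helper name `validate_multi_worker_input` is hypothetical.

```python
class DatasetStub:
    """Stand-in for `tf.data.Dataset`, so this sketch runs without TensorFlow."""


def validate_multi_worker_input(x, distribution=None):
    # Sketch of the check under review: when a multi-worker distribution
    # is active, only a dataset instance is accepted as training input.
    if distribution is not None and not isinstance(x, DatasetStub):
        raise ValueError(
            "When using multi-worker distribution, the data must be "
            "provided as a `tf.data.Dataset` instance. "
            f"Received: type(x)={type(x)}"
        )


# A plain list is rejected once a distribution is set; a dataset passes.
try:
    validate_multi_worker_input([1, 2, 3], distribution=object())
except ValueError as e:
    print(e)
validate_multi_worker_input(DatasetStub(), distribution=object())
```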
    inputs: `jax.Array` that is already sharded to a local process size.
    layout: `TensorLayout` for the distribution information, or a
        `jax.sharding.Sharding` instance.

Returns:
Add a blank line above `Returns:`.
Done.
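As a rough sketch of the "already sharded to a local process size" contract from the docstring above, assuming a simple contiguous per-process split (the helper name `local_slice` is hypothetical, and a plain list stands in for a `jax.Array`):

```python
def local_slice(global_batch, process_index, process_count):
    # Each process holds a contiguous, equally sized slice of the global
    # batch; `global_batch` stands in for data that would be a `jax.Array`.
    per_process = len(global_batch) // process_count
    start = process_index * per_process
    return global_batch[start:start + per_process]


# With 8 samples and 4 processes, process 1 holds samples 2-3:
print(local_slice(list(range(8)), 1, 4))  # [2, 3]
```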
).prefetch(tf.data.AUTOTUNE)
batch_size = tf_data_distribute.compute_batch_size(dataset)
if batch_size.numpy() < 0:
    raise ValueError(
In what cases does this happen? An unbatched dataset? The error message should make explicit what user action is required (e.g. calling `.batch()`).
Yes, most likely because the dataset is not batched. Updated the error message.
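To illustrate the requested change, the unbatched-dataset guard could look like the sketch below. `compute_batch_size_stub` is a hypothetical pure-Python stand-in for `tf_data_distribute.compute_batch_size`, which per the thread above reports a negative value when the dataset has no static outer batch dimension; the error message is an assumed example of spelling out the user action.

```python
def compute_batch_size_stub(is_batched, batch_size=None):
    # Stand-in for `tf_data_distribute.compute_batch_size`: returns a
    # negative value when `.batch()` was never called on the dataset.
    return batch_size if is_batched else -1


def resolve_batch_size(is_batched, batch_size=None):
    bs = compute_batch_size_stub(is_batched, batch_size)
    if bs < 0:
        # Per the review, the message spells out the required user action.
        raise ValueError(
            "The batch size of the input dataset could not be determined. "
            "Make sure the dataset is batched, e.g. by calling "
            "`dataset.batch(batch_size)` before passing it to `fit()`."
        )
    return bs


print(resolve_batch_size(True, 32))  # 32
```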
LGTM
I have tried to add some backend-agnostic tests to the dataset builder; the JAX-specific multi-worker test will be hosted internally.