
Add multi-worker support for JAX training. #18654

Merged: 4 commits merged into keras-team:master from multiworker on Oct 24, 2023

Conversation

@qlzh727 (Member) commented Oct 19, 2023

I have tried to add some backend-agnostic tests to the dataset builder, and the JAX-specific multi-worker test will be hosted internally.

@qlzh727 (Member, Author) commented Oct 19, 2023

This will update #18561 and #18560

@codecov-commenter commented Oct 19, 2023

Codecov Report

Attention: 28 lines in your changes are missing coverage. Please review.

Comparison is base (2ad8e07) 78.57% compared to head (174b193) 78.51%.
Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18654      +/-   ##
==========================================
- Coverage   78.57%   78.51%   -0.06%     
==========================================
  Files         335      335              
  Lines       32979    33020      +41     
  Branches     6455     6467      +12     
==========================================
+ Hits        25913    25927      +14     
- Misses       5510     5532      +22     
- Partials     1556     1561       +5     
Flag Coverage Δ
keras 78.41% <41.66%> (-0.06%) ⬇️
keras-jax 63.39% <41.66%> (-0.04%) ⬇️
keras-numpy 57.67% <12.50%> (-0.06%) ⬇️
keras-tensorflow 64.51% <14.58%> (-0.07%) ⬇️
keras-torch 65.16% <14.58%> (-0.07%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
keras/backend/jax/core.py 90.32% <100.00%> (+0.10%) ⬆️
keras/backend/jax/trainer.py 95.54% <100.00%> (+0.02%) ⬆️
keras/trainers/data_adapters/tf_dataset_adapter.py 96.15% <100.00%> (+0.15%) ⬆️
keras/trainers/data_adapters/__init__.py 68.96% <60.00%> (-1.41%) ⬇️
keras/distribution/distribution_lib.py 91.26% <12.50%> (-2.71%) ⬇️
keras/backend/jax/distribution_lib.py 63.15% <24.00%> (-19.54%) ⬇️

☔ View full report in Codecov by Sentry.

@fchollet (Collaborator) left a comment

Thanks for the PR!

@@ -7,7 +7,7 @@
 class TFDatasetAdapter(DataAdapter):
     """Adapter that handles `tf.data.Dataset`."""

-    def __init__(self, dataset, class_weight=None):
+    def __init__(self, dataset, class_weight=None, distribution=None):
Collaborator:

Please add an `Args:` section to document the type of each argument (to avoid confusion with `tf.distribute`).

Member Author:

Done.
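For illustration, a documented constructor along the lines the reviewer asked for might look like this. This is a minimal sketch: `DataAdapter` is stubbed out so the snippet runs without Keras, and the exact `Args:` wording is hypothetical, not the text merged in the PR.

```python
class DataAdapter:
    """Stand-in for Keras' DataAdapter base class in this sketch."""


class TFDatasetAdapter(DataAdapter):
    """Adapter that handles `tf.data.Dataset`."""

    def __init__(self, dataset, class_weight=None, distribution=None):
        """Initialize the adapter.

        Args:
            dataset: The input `tf.data.Dataset` instance.
            class_weight: Optional dict mapping class indices to weights,
                applied to the sample weights of each batch.
            distribution: Optional `keras.distribution.Distribution`
                instance (not a `tf.distribute` strategy), used to shard
                the dataset across workers in a multi-worker setting.
        """
        self._dataset = dataset
        self._class_weight = class_weight
        self._distribution = distribution
```

Documenting the `distribution` type directly in the docstring addresses the reviewer's concern about confusion with `tf.distribute`.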

        ):
            raise ValueError(
                "Only `tf.data.Dataset` is supported for multi worker "
                f"distribution, received input types is {type(x)}"
Collaborator:

When using multi-worker distribution, the data must be provided 
as a `tf.data.Dataset` instance. Received: type(x)={type(x)}

Member Author:

Done.
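A dependency-free sketch of that check with the reworded error message. The function name is hypothetical, and `FakeDataset` stands in for `tf.data.Dataset` so the snippet runs without TensorFlow; the real code would use `isinstance(x, tf.data.Dataset)`.

```python
class FakeDataset:
    """Stand-in for `tf.data.Dataset` so this sketch runs without TensorFlow."""


def validate_multiworker_input(x, dataset_cls=FakeDataset):
    """Reject inputs that are not dataset instances, with an actionable message."""
    if not isinstance(x, dataset_cls):
        raise ValueError(
            "When using multi-worker distribution, the data must be "
            "provided as a `tf.data.Dataset` instance. "
            f"Received: type(x)={type(x)}"
        )
    return x
```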

        inputs: `jax.Array` that is already sharded to a local process size.
        layout: `TensorLayout` for the distribution information, or a
            `jax.sharding.Sharding` instance.
    Returns:
Collaborator:

Add a blank line above `Returns:`.

Member Author:

Done.
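As background on what this docstring describes, here is a NumPy sketch of splitting a per-process batch evenly across local devices. The function name and the bare `np.split` are illustrative assumptions; the real JAX backend places shards on devices according to a `jax.sharding.Sharding` (e.g. via `jax.device_put`) rather than returning plain arrays.

```python
import numpy as np


def split_batch_across_local_devices(inputs, num_local_devices):
    """Split a per-process batch along the batch axis, one shard per local device."""
    if inputs.shape[0] % num_local_devices != 0:
        raise ValueError(
            f"Batch size {inputs.shape[0]} is not divisible by the "
            f"local device count {num_local_devices}."
        )
    # One equal shard per device, split along the leading (batch) axis.
    return np.split(inputs, num_local_devices, axis=0)
```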

        ).prefetch(tf.data.AUTOTUNE)
        batch_size = tf_data_distribute.compute_batch_size(dataset)
        if batch_size.numpy() < 0:
            raise ValueError(
Collaborator:

In what cases does this happen? Unbatched dataset? The error message should make explicit what user action is required (e.g. calling .batch()).

Member Author:

Yeah, most likely because the dataset is not batched. Updated the error message.
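A sketch of the updated check with an actionable message. A plain int stands in for the tensor that `compute_batch_size` returns, and the exact wording is hypothetical; the key point is that -1 signals the batch size could not be determined, typically because the dataset was never batched.

```python
def check_dataset_is_batched(batch_size):
    """Raise with a fix-it hint when the computed batch size is the -1 sentinel."""
    if batch_size < 0:
        raise ValueError(
            "The batch size of the input dataset could not be determined. "
            "When using multi-worker distribution, the dataset must be "
            "batched; please call `.batch(batch_size)` on it before fitting."
        )
```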

@qlzh727 qlzh727 requested a review from fchollet October 23, 2023 17:51
@fchollet (Collaborator) left a comment

LGTM

@google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels Oct 24, 2023
@fchollet fchollet merged commit ee8b1ea into keras-team:master Oct 24, 2023
6 checks passed
@google-ml-butler bot removed the awaiting review, ready to pull (Ready to be merged into the codebase), and kokoro:force-run labels Oct 24, 2023
@qlzh727 qlzh727 deleted the multiworker branch October 24, 2023 18:33
4 participants