Sharded distributed sampler for cached dataloading in DDP #195

Open · ziw-liu wants to merge 20 commits into main
Conversation

@ziw-liu (Collaborator) commented Oct 21, 2024

Add a distributed sampler that only permutes indices within each rank, improving the cache hit rate in DDP.

See viscy/scripts/shared_dict.py for usage.
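
Below is a minimal sketch of the idea (an assumption about the approach, not necessarily this PR's exact implementation): each rank gets a contiguous shard of the dataset, and shuffling happens only inside that shard, so a per-rank cache keeps seeing the same indices every epoch.

```python
import torch
from torch.utils.data import DistributedSampler


class ShardedDistributedSampler(DistributedSampler):
    """Shuffle indices only within each rank's contiguous shard (sketch)."""

    def __iter__(self):
        # Pad the index list so every rank gets the same number of samples,
        # mirroring what torch's DistributedSampler does with drop_last=False.
        indices = list(range(len(self.dataset)))
        if self.total_size > len(indices):
            indices += indices[: self.total_size - len(indices)]
        # Contiguous shard for this rank instead of the default strided split.
        start = self.rank * self.num_samples
        shard = indices[start : start + self.num_samples]
        if self.shuffle:
            # Permute within the shard only; the seed changes via set_epoch().
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            shard = [shard[i] for i in torch.randperm(len(shard), generator=g).tolist()]
        return iter(shard)
```

For contrast, the stock DistributedSampler permutes the whole dataset and then takes a strided slice per rank, so each rank sees a different subset every epoch and a per-rank cache is repeatedly invalidated; keeping the shard fixed and permuting only inside it preserves shuffling while keeping the cache warm.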

@ziw-liu marked this pull request as ready for review October 21, 2024 21:17
@ziw-liu added the enhancement (New feature or request) and translation (Image translation (VS)) labels Oct 21, 2024
@ziw-liu (Collaborator, Author) commented Oct 21, 2024

Example output:

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/hpc/mydata/ziwen.liu/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/3
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/3
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/3
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 3 processes
----------------------------------------------------------------------------------------------------

=== Initializing cache pool for rank 0 ===
=== Initializing cache pool for rank 1 ===
=== Initializing cache pool for rank 2 ===

  | Name  | Type   | Params | Mode
--------------------------------------
0 | layer | Linear | 2      | train
--------------------------------------
2         Trainable params
0         Non-trainable params
2         Total params
0.000     Total estimated model params size (MB)
1         Modules in train mode
0         Modules in eval mode

Adding 31 to cache dict on rank 1
Adding 32 to cache dict on rank 2
Adding 38 to cache dict on rank 2
Adding 42 to cache dict on rank 0
Adding 30 to cache dict on rank 0
Adding 36 to cache dict on rank 0
Adding 37 to cache dict on rank 1
Adding 43 to cache dict on rank 1
Adding 48 to cache dict on rank 0
Adding 34 to cache dict on rank 1
Adding 49 to cache dict on rank 1
Adding 30 to cache dict on rank 2
Adding 44 to cache dict on rank 2
Adding 41 to cache dict on rank 2
Adding 35 to cache dict on rank 2
Adding 40 to cache dict on rank 1
Adding 46 to cache dict on rank 1
Adding 39 to cache dict on rank 0
Adding 33 to cache dict on rank 0
Adding 47 to cache dict on rank 2
Adding 45 to cache dict on rank 0
Adding 24 to cache dict on rank 2
Adding 13 to cache dict on rank 1
Adding 0 to cache dict on rank 0
Adding 20 to cache dict on rank 2
Adding 4 to cache dict on rank 0
Adding 29 to cache dict on rank 2
Adding 19 to cache dict on rank 1
Adding 26 to cache dict on rank 2
Adding 28 to cache dict on rank 2
=== Starting training ===
=== Starting training epoch 0 ===
Adding 8 to cache dict on rank 0
Adding 15 to cache dict on rank 1
Adding 3 to cache dict on rank 0
Adding 21 to cache dict on rank 2
Adding 11 to cache dict on rank 1
Adding 7 to cache dict on rank 0
Adding 23 to cache dict on rank 2
Adding 27 to cache dict on rank 2
Adding 22 to cache dict on rank 2
Adding 1 to cache dict on rank 0
Adding 9 to cache dict on rank 0
Adding 5 to cache dict on rank 0
Adding 17 to cache dict on rank 1
Adding 6 to cache dict on rank 0
Adding 18 to cache dict on rank 1
Adding 16 to cache dict on rank 1
Adding 14 to cache dict on rank 1
Adding 10 to cache dict on rank 1
Adding 25 to cache dict on rank 2
Adding 2 to cache dict on rank 0
Adding 12 to cache dict on rank 1
=== Starting training epoch 1 ===
=== Starting training epoch 2 ===
=== Starting training epoch 3 ===
=== Starting training epoch 4 ===
Trainer.fit stopped: max_epochs=5 reached.
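
For reference, a hypothetical launch sketch (not the actual viscy/scripts/shared_dict.py) of a run shaped like the log above: three CPU DDP processes on the gloo backend, five epochs, a toy module with a single Linear layer (matching the 2 trainable params in the summary), and a DataLoader built around the sharded sampler sketched in the PR description. The dataset here is a plain TensorDataset standing in for the cached dataset.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer


class ToyModel(LightningModule):
    """Single Linear layer -> the 2 trainable params shown in the summary."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 1)
        # Hypothetical stand-in for the cached dataset (50 toy samples).
        self.dataset = TensorDataset(torch.arange(50, dtype=torch.float32).unsqueeze(1))

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # Built here so rank/world size are available after DDP initialization.
        sampler = ShardedDistributedSampler(  # sketch from the PR description above
            self.dataset,
            num_replicas=self.trainer.world_size,
            rank=self.trainer.global_rank,
            shuffle=True,
        )
        return DataLoader(self.dataset, batch_size=4, sampler=sampler, shuffle=False)


if __name__ == "__main__":
    trainer = Trainer(
        accelerator="cpu",              # GPU present but unused, as in the log
        devices=3,                      # 3 processes -> gloo backend
        strategy="ddp",
        max_epochs=5,
        use_distributed_sampler=False,  # keep Lightning from replacing the custom sampler
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(ToyModel())
```

With use_distributed_sampler=False, Lightning leaves the custom sampler in place, so the per-rank sharding is what produces the "Adding N to cache dict on rank R" pattern above.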

@ziw-liu changed the base branch from ram_dataloader to main October 21, 2024 23:28
persistent_workers=bool(self.num_workers),
pin_memory=True,
shuffle=False,
timeout=self.timeout,
@ziw-liu (Collaborator, Author):
@edyoshikun why is this needed?

@edyoshikun (Contributor):
At the beginning I had to add this timeout because caching could take a long time. I don't think we need it anymore; in fact, it works fine when set to 0.
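
For context, torch's DataLoader already defaults to timeout=0, which means it waits indefinitely for workers rather than raising after a fixed interval, so dropping the argument (or passing 0 explicitly) avoids spurious worker timeouts while the cache is being filled. A small self-contained illustration with placeholder data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))  # placeholder data
loader = DataLoader(
    dataset,
    num_workers=2,
    persistent_workers=True,
    pin_memory=True,
    shuffle=False,
    timeout=0,  # 0 (the default) disables the per-batch worker timeout
)
```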
