
[WIP][RFC] TorchFT integration #806

Draft: fegin wants to merge 8 commits into base `gh/fegin/7/base`

Conversation

@fegin (Contributor) commented on Jan 27, 2025

Stack from ghstack (oldest at bottom):

**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This does not work at the moment: groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 will hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 will pass `should_commit`, but group 0 incorrectly decides that it needs healing, and the healing process causes another hang.~

Fixed with: pytorch/torchft#83
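
For context, the per-step commit flow these two issues exercise looks roughly like the sketch below. This is a minimal illustration modeled on the torchft README of this period, not the actual code in this PR; the `Manager`/`Optimizer` wiring and the `min_replica_size` value are assumptions.

```python
# Illustrative sketch only (modeled on the torchft README of this era),
# not the code in this PR.
import torch
import torch.nn as nn
from torchft import Manager, Optimizer, ProcessGroupGloo

model = nn.Linear(8, 8)

manager = Manager(
    pg=ProcessGroupGloo(),
    # A group that joins late or recovers is healed by fetching the pending
    # state_dict from a healthy group and applying it via this callback --
    # the hand-off after which Issue 1 observed the hang.
    load_state_dict=lambda sd: model.load_state_dict(sd),
    state_dict=lambda: model.state_dict(),
    min_replica_size=1,  # assumption for this sketch
)

# The wrapped optimizer ties each step to the manager: zero_grad() joins the
# quorum for the step, and step() consults manager.should_commit() -- the
# collective decision point where Issues 1 and 2 hung -- before applying
# the update.
optimizer = Optimizer(manager, torch.optim.AdamW(model.parameters()))

for _ in range(10):
    batch = torch.rand(4, 8)
    optimizer.zero_grad()
    model(batch).sum().backward()
    optimizer.step()
```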

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 will keep printing~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```
***How to reproduce?***
Follow the `Reproduce steps` to run 2 groups, then kill either group after both start training. Remember to apply pytorch/torchft#83.

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?***
Follow the `Reproduce steps` to run 2 groups, then add another group by modifying the command.

This appears to be fixed but needs more testing.

**Issue 5:**
A hang occurs when using functional collectives.
***How to reproduce?***
Pull the latest version of this PR, then comment out line 41 and uncomment line 42 in `torchtitan/utils.py`.
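
For reference, the toggle in `torchtitan/utils.py` amounts to choosing between a traditional c10d collective and a functional collective. A rough, hypothetical illustration of the two flavors (not the exact lines in `utils.py`):

```python
# Hypothetical illustration of the two all-reduce flavors; not the actual
# torchtitan/utils.py code.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def all_reduce_c10d(x: torch.Tensor) -> torch.Tensor:
    # Traditional c10d collective: reduces x in place. This path currently
    # works with torchft.
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    return x

def all_reduce_functional(x: torch.Tensor, group) -> torch.Tensor:
    # Functional collective: out of place, returns an async tensor. With
    # torchft's Python process-group wrapper this path hung (Issue 5) until
    # the ownership fix referenced later in this thread (pytorch#146376).
    return funcol.all_reduce(x, "sum", group)
```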

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82 ([WIP][RFC] Required changes for integration with TorchTitan).
2. Execute the lighthouse.
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

[ghstack-poisoned]
fegin added a commit that referenced this pull request Jan 27, 2025
Summary:
This is a WIP TorchFT integration PR.

Test Plan:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 91788fc1db5700ec50812469a253c868b499b824
Pull Request resolved: #806
@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Jan 27, 2025
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jan 27, 2025
Summary:
This is a WIP TorchFT integration PR.

Test Plan:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: ea405c81e19d3185b11d6bb6ae254c0913c9e503
Pull Request resolved: #806
@fegin marked this pull request as draft on January 27, 2025 at 21:31
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jan 28, 2025
Summary:
This is a WIP TorchFT integration PR.

Test Plan:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 4e56a9c0300d5ff58293863323bc2a66bb219229
Pull Request resolved: #806
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jan 29, 2025
Summary:
This is a WIP TorchFT integration PR.

Test Plan:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: ea405c81e19d3185b11d6bb6ae254c0913c9e503
Pull Request resolved: #806
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jan 31, 2025
Summary:
This is a WIP TorchFT integration PR.

Test Plan:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 07a02ffa43cbc1e16ff35ef9be820db52905d683
Pull Request resolved: #806
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jan 31, 2025
Summary:
This is a WIP TorchFT integration PR.

Test Plan:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 82fc6a4b105e7b0e9172766fb7706603f0baf653
Pull Request resolved: #806
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jan 31, 2025
Summary:
This is a WIP TorchFT integration PR.

Test Plan:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 514fd10bb87a82c856690d2099192bee8b901641
Pull Request resolved: #806
[ghstack-poisoned]
fegin added a commit that referenced this pull request Feb 3, 2025
Summary:
This is a WIP TorchFT integration PR.

Test Plan:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 9244a1078fa9a10e564d6c28001bb508d75a1434
Pull Request resolved: #806
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Feb 7, 2025
…llectives (#146376)

@fegin found an issue where torchft is not compatible with functional collectives.

Found in pytorch/torchtitan#806

The root cause is that PyProcessGroup/PyWork are not compatible with functional collectives due to a nasty ownership bug.

PyWork relies on a pybind trampoline to propagate requests to Python. Unfortunately, the way pybind works is that the Python object owns the C++ object rather than there being some form of shared ownership. Thus, the PyWork Python object gets collected when it is returned to C++ from the PyProcessGroup, while the C++ PyWork object still exists. When the PyWork object is then used, this causes a deadlock because the corresponding Python object no longer exists.

To solve this, we introduce a new `PyWorkHolder` class which holds a reference to the `py::object` as well as the trampoline class. This resolves the ownership issue, since we can now hold ownership in C++ of both the Python and C++ objects.

To make this cleaner, we introduce a `WORK_OVERRIDE` macro, a patched version of `PYBIND11_OVERRIDE` that returns a `PyWorkHolder` rather than just a `PyWork`, and use it for all collectives in PyProcessGroup.
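
To make the lifecycle problem concrete, here is a small Python analogy (purely illustrative; the real fix is in the C++ pybind bindings, not in Python): returning only a non-owning view lets the object die while it is still needed, whereas returning a holder that owns the object keeps it alive.

```python
# Python analogy of the ownership bug and the holder-based fix; this is not
# the actual PyProcessGroup/PyWork code.
import gc
import weakref

class Work:
    """Stand-in for a collective's work handle."""
    def wait(self) -> None:
        print("collective finished")

def old_path():
    # Like the old PyWork binding: the caller gets only a non-owning view,
    # so nothing keeps the object alive after this function returns.
    w = Work()
    return weakref.ref(w)

def holder_path():
    # Like PyWorkHolder: return an owner together with the view, so the
    # object stays alive as long as the holder does.
    w = Work()
    return w, weakref.ref(w)

view = old_path()
gc.collect()
print(view())      # None: collected while the caller would still want to wait() on it

holder, view = holder_path()
gc.collect()
view().wait()      # still alive, because the holder owns it
```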

Test plan:

```
cd pytorch
pytest test/distributed/test_c10d_functional_native.py
```

```
cd torchft
pytest torchft/process_group_test.py -k functional -v -x -s
```

Pull Request resolved: #146376
Approved by: https://github.com/yifuwang