
DistNetworkError when using multiprocessing_context parameter in PyTorch DataLoader #20516

Closed
forestbat opened this issue Dec 21, 2024 · 4 comments
Labels: ver: 2.4.x · waiting on author (Waiting on user action, correction, or update)

Comments


forestbat commented Dec 21, 2024

Bug description

For some specific reasons I want to use the spawn method to create DataLoader workers in PyTorch, but it crashes with the error in the title.

Port 55733 is already being listened on by the training processes, which is why it crashes. But I want to know: why is the port bound a second time when multiprocessing_context is 'spawn'?

Update: when I use plain PyTorch the problem disappears; it only occurs with lightning.Fabric.

Looking forward to your reply.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import lightning

fabric = lightning.Fabric(devices=[0, 2], num_nodes=1, strategy='ddp')
fabric.launch()

class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)  

    def forward(self, x):
        return self.linear(x)


if __name__ == '__main__':
    x = torch.randn(100, 10)
    y = torch.rand(100, 2)
    dataset = TensorDataset(x, y)
    # crashes because of multiprocessing_context='spawn'; 'forkserver' has the same problem
    train_loader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=10, shuffle=True, 
                   num_workers=1, multiprocessing_context='spawn'))
    model = LinearModel()
    crit = nn.MSELoss()
    model, optimizer = fabric.setup(model, optim.Adam(model.parameters(), lr=0.01))
    for epoch in range(0, 10):
        print(f'Epoch {epoch}')
        for xs, ys in train_loader:
            optimizer.zero_grad()  # clear gradients accumulated in the previous step
            output = model(xs)
            loss = crit(output, ys)
            fabric.backward(loss)
            optimizer.step()

Error messages and logs

# https://pastebin.com/BqA9mjiE
Epoch 0
Epoch 0
...
torch.distributed.DistNetworkError: The server socket has failed to listen on any local network address. 
The server socket has failed to bind to [::]:55733 (errno: 98 - Address already in use). 
The server socket has failed to bind to 0.0.0.0:55733 (errno: 98 - Address already in use).
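
errno 98 (EADDRINUSE) means another process is already bound to the port. As a quick diagnostic, here is a minimal sketch using only the Python standard library (the helper name port_is_free is mine, purely illustrative) to check whether a given port can still be bound:

import errno
import socket

# Try to bind the port; a failed bind with errno 98 is exactly
# what the traceback above reports.
def port_is_free(port: int, host: str = '0.0.0.0') -> bool:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
            return True
        except OSError as e:
            if e.errno == errno.EADDRINUSE:
                return False
            raise

if __name__ == '__main__':
    print(port_is_free(55733))  # False while a training process holds the port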

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA RTX 5000 Ada Generation
      • NVIDIA A40
      • NVIDIA A40
    • available: True
    • version: 12.1
  • Lightning:
    • lightning: 2.4.0
    • lightning-utilities: 0.11.9
    • pytorch-lightning: 2.4.0
    • torch: 2.2.2
    • torchaudio: 2.2.2
    • torchdata: 0.7.1
    • torchmetrics: 1.6.0
    • torchvision: 0.17.2
  • Packages:
    • absl-py: 2.1.0
    • affine: 2.4.0
    • aiobotocore: 2.13.2
    • aiodns: 3.2.0
    • aiohappyeyeballs: 2.3.7
    • aiohttp: 3.10.4
    • aiohttp-client-cache: 0.11.1
    • aioitertools: 0.11.0
    • aiosignal: 1.3.1
    • aiosqlite: 0.20.0
    • annotated-types: 0.7.0
    • appdirs: 1.4.4
    • argon2-cffi: 23.1.0
    • argon2-cffi-bindings: 21.2.0
    • asciitree: 0.3.3
    • async-retriever: 0.17.0
    • attrs: 24.2.0
    • autocommand: 2.2.2
    • backports.tarfile: 1.2.0
    • black: 24.8.0
    • bleach: 6.1.0
    • bokeh: 3.5.1
    • boto3: 1.34.131
    • botocore: 1.34.131
    • branca: 0.7.2
    • brotli: 1.1.0
    • bump2version: 1.0.1
    • cachetools: 5.5.0
    • cartopy: 0.23.0
    • cattrs: 23.2.3
    • certifi: 2024.8.30
    • cffi: 1.17.0
    • cfgrib: 0.9.14.0
    • cftime: 1.6.4
    • chardet: 5.2.0
    • charset-normalizer: 3.3.2
    • click: 8.1.7
    • click-plugins: 1.1.1
    • cligj: 0.7.2
    • cloudpickle: 3.0.0
    • codetiming: 1.4.0
    • colorama: 0.4.6
    • contourpy: 1.2.1
    • cryptography: 43.0.0
    • cupy: 13.3.0
    • cycler: 0.12.1
    • cytoolz: 0.12.3
    • dask: 2024.8.1
    • dask-expr: 1.1.11
    • dataretrieval: 1.0.10
    • deepspeed: 0.16.1
    • defusedxml: 0.7.1
    • dgl: 2.2.1+cu121
    • distributed: 2024.8.1
    • docutils: 0.21.2
    • eccodes: 1.7.1
    • einops: 0.8.0
    • et-xmlfile: 1.1.0
    • exceptiongroup: 1.2.2
    • fasteners: 0.19
    • fastrlock: 0.8.2
    • filelock: 3.15.4
    • findlibs: 0.0.5
    • flake8: 7.1.1
    • flexcache: 0.3
    • flexparser: 0.3.1
    • folium: 0.17.0
    • fonttools: 4.53.1
    • frozenlist: 1.4.1
    • fsspec: 2024.6.1
    • geopandas: 1.0.1
    • gmpy2: 2.1.5
    • greenlet: 3.0.3
    • grpcio: 1.62.2
    • h2: 4.1.0
    • h5netcdf: 1.3.0
    • h5py: 3.11.0
    • hjson: 3.1.0
    • hpack: 4.0.0
    • hydrodataset: 0.1.13
    • hydrodatasource: 0.0.8
    • hydroerr: 1.24
    • hydrosignatures: 0.17.0
    • hydrotopo: 0.0.6
    • hydroutils: 0.0.12
    • hyperframe: 6.0.1
    • idna: 3.7
    • igraph: 0.11.6
    • importlib-metadata: 8.2.0
    • importlib-resources: 6.4.0
    • inflect: 7.3.1
    • iniconfig: 2.0.0
    • intake: 2.0.6
    • itsdangerous: 2.2.0
    • jaraco.classes: 3.4.0
    • jaraco.context: 5.3.0
    • jaraco.functools: 4.0.2
    • jaraco.text: 3.12.1
    • jeepney: 0.8.0
    • jinja2: 3.1.4
    • jmespath: 1.0.1
    • joblib: 1.4.2
    • kaggle: 1.6.17
    • kerchunk: 0.2.6
    • keyring: 25.3.0
    • kiwisolver: 1.4.5
    • lightning: 2.4.0
    • lightning-utilities: 0.11.9
    • llvmlite: 0.43.0
    • locket: 1.0.0
    • loguru: 0.7.2
    • lxml: 5.3.0
    • lz4: 4.3.3
    • markdown: 3.6
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.5
    • matplotlib: 3.9.2
    • mccabe: 0.7.0
    • mdurl: 0.1.2
    • minio: 7.2.8
    • more-itertools: 10.4.0
    • mpmath: 1.3.0
    • msgpack: 1.0.8
    • multidict: 6.0.5
    • mypy-extensions: 1.0.0
    • netcdf4: 1.7.1.post2
    • networkx: 3.3
    • nh3: 0.2.18
    • ninja: 1.11.1.3
    • nuitka: 2.4.7
    • numba: 0.60.0
    • numcodecs: 0.13.0
    • numpy: 1.26.4
    • nvidia-ml-py: 12.535.161
    • nvitop: 1.3.2
    • openpyxl: 3.1.5
    • ordered-set: 4.1.0
    • owslib: 0.31.0
    • packaging: 24.1
    • pandas: 2.2.2
    • partd: 1.4.2
    • pathspec: 0.12.1
    • pillow: 10.4.0
    • pint: 0.24.3
    • pint-pandas: 0.6.2
    • pint-xarray: 0.4
    • pip: 24.2
    • pkginfo: 1.10.0
    • platformdirs: 4.2.2
    • pluggy: 1.5.0
    • polars: 1.17.1
    • protobuf: 4.25.3
    • psutil: 6.0.0
    • psycopg2-binary: 2.9.9
    • py-cpuinfo: 9.0.0
    • pyarrow: 17.0.0
    • pyarrow-hotfix: 0.6
    • pycairo: 1.27.0
    • pycares: 4.4.0
    • pycodestyle: 2.12.1
    • pycparser: 2.22
    • pycryptodome: 3.20.0
    • pydantic: 2.8.2
    • pydantic-core: 2.20.1
    • pyflakes: 3.2.0
    • pygeohydro: 0.17.0
    • pygeoogc: 0.17.0
    • pygeoutils: 0.17.0
    • pygments: 2.18.0
    • pykalman: 0.9.7
    • pynhd: 0.17.0
    • pyogrio: 0.9.0
    • pyparsing: 3.1.2
    • pyproj: 3.6.1
    • pyshp: 2.3.1
    • pysocks: 1.7.1
    • pytest: 8.3.2
    • python-dateutil: 2.9.0
    • python-slugify: 8.0.4
    • pytorch-lightning: 2.4.0
    • pytz: 2024.1
    • pyyaml: 6.0.2
    • rasterio: 1.3.10
    • readme-renderer: 44.0
    • requests: 2.32.3
    • requests-cache: 1.2.1
    • requests-toolbelt: 1.0.0
    • rfc3986: 2.0.0
    • rich: 13.7.1
    • rioxarray: 0.17.0
    • s3fs: 2024.6.1
    • s3transfer: 0.10.2
    • scikit-learn: 1.5.1
    • scipy: 1.14.0
    • seaborn: 0.13.2
    • secretstorage: 3.3.3
    • setuptools: 72.2.0
    • shap: 0.45.1
    • shapely: 2.0.1
    • six: 1.16.0
    • slicer: 0.0.8
    • snuggs: 1.4.7
    • sortedcontainers: 2.4.0
    • sqlalchemy: 2.0.32
    • sympy: 1.13.2
    • tblib: 3.0.0
    • tbparse: 0.0.9
    • tensorboard: 2.17.1
    • tensorboard-data-server: 0.7.0
    • termcolor: 2.5.0
    • text-unidecode: 1.3
    • texttable: 1.7.0
    • threadpoolctl: 3.5.0
    • tomli: 2.0.1
    • toolz: 0.12.1
    • torch: 2.2.2
    • torchaudio: 2.2.2
    • torchdata: 0.7.1
    • torchmetrics: 1.6.0
    • torchvision: 0.17.2
    • tornado: 6.4.1
    • tqdm: 4.66.5
    • triton: 2.2.0
    • twine: 5.1.1
    • typeguard: 4.3.0
    • typing-extensions: 4.12.2
    • tzdata: 2024.1
    • tzfpy: 0.15.5
    • ujson: 5.10.0
    • url-normalize: 1.4.3
    • urllib3: 2.2.2
    • webencodings: 0.5.1
    • werkzeug: 3.0.3
    • wget: 3.2
    • wheel: 0.44.0
    • wrapt: 1.16.0
    • xarray: 2024.7.0
    • xlrd: 2.0.1
    • xyzservices: 2024.6.0
    • yarl: 1.9.4
    • zarr: 2.18.2
    • zict: 3.0.0
    • zipp: 3.20.0
    • zstandard: 0.23.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.11.9
    • release: 5.4.0-195-generic
    • version: #215-Ubuntu SMP Fri Aug 2 18:28:05 UTC 2024

More info

No response

forestbat added the labels bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) on Dec 21, 2024
forestbat (Author) commented

A week has passed, where are you? @lantiga

lantiga (Collaborator) commented Dec 29, 2024

🎄

lantiga (Collaborator) commented Jan 6, 2025

Hey @forestbat, thanks for your patience.

The issue here is that calling fabric.launch at module level causes the TCPStore to be re-initialized when new processes are spawned, which leads to the port clash. The solution is to call fabric.launch inside the __main__ guard, so that it only runs in the main process:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import lightning

class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)  

    def forward(self, x):
        return self.linear(x)


if __name__ == '__main__':
    x = torch.randn(100, 10)
    y = torch.rand(100, 2)
    dataset = TensorDataset(x, y)

    fabric = lightning.Fabric(devices=[0, 2], num_nodes=1, strategy='ddp')
    fabric.launch()

    # with fabric.launch() now inside the __main__ guard, 'spawn' (and 'forkserver')
    # no longer clash on the port
    train_loader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=10, shuffle=True, 
                   num_workers=1, multiprocessing_context='spawn'))

    model = LinearModel()
    crit = nn.MSELoss()
    model, optimizer = fabric.setup(model, optim.Adam(model.parameters(), lr=0.01))

    for epoch in range(0, 10):
        print(f'Epoch {epoch}')
        for xs, ys in train_loader:
            optimizer.zero_grad()  # clear gradients accumulated in the previous step
            output = model(xs)
            loss = crit(output, ys)
            fabric.backward(loss)
            optimizer.step()
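
For reference on why the placement matters: with the spawn (and forkserver) start method, each DataLoader worker process re-imports the main module, so any module-level code, including a bare fabric.launch(), runs again in every worker and tries to bind the rendezvous port a second time. A minimal, Lightning-free sketch of that re-import behavior (file contents and prints are illustrative only):

import multiprocessing as mp
import os

# This module-level line runs once in the parent process AND once in each
# spawned worker, because 'spawn' re-imports the main module -- the same
# reason a module-level fabric.launch() executes more than once.
print(f'module imported in pid {os.getpid()}')

def square(i):
    return i * i

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    with ctx.Pool(2) as pool:
        print(pool.map(square, range(4)))

Running this prints the "module imported" line three times (parent plus two workers), which is why moving fabric.launch() under the __main__ guard avoids the duplicate bind.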

lantiga added the label waiting on author (Waiting on user action, correction, or update) and removed the labels bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) on Jan 6, 2025
forestbat (Author) commented

Thanks for your reply. Now I'm fighting a mysterious deadlock, but I will probably open a new issue.

forestbat added a commit to iHeadWater/torchhydro that referenced this issue Jan 7, 2025