
jax.distributed.initialize() crash #24399

Open
demon2036 opened this issue Oct 19, 2024 · 1 comment
Labels
bug Something isn't working

@demon2036

Description

Hi JAX team,

For the past two days, I've been using GCP queued resources to create spot TPU v4-256 and v4-64 slices and then running the following Python script:

import jax
jax.distributed.initialize()
print(1)

However, the script gets stuck at the jax.distributed.initialize() call and then fails with the traceback below. This is strange because on an on-demand TPU v4-64 I created two weeks ago, jax.distributed.initialize() ran without any issues, and it still works on that machine. Only the newly created instances show this problem, so I'd appreciate any help from the JAX team!

BUG


    jax.distributed.initialize()
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/distributed.py", line 231, in initialize
    global_state.initialize(coordinator_address, num_processes, process_id,
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/distributed.py", line 55, in initialize
    clusters.ClusterEnv.auto_detect_unset_distributed_params(
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cluster.py", line 82, in auto_detect_unset_distributed_params
    process_id = env.get_process_id()
                 ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 144, in get_process_id
    slice_id = cls._get_slice_id()
               ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 159, in _get_slice_id
    if has_megascale_address():
       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 74, in has_megascale_address
    return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 71, in get_tpu_env_value
    return value if value is not None else get_tpu_env_value_from_metadata(key)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 59, in get_tpu_env_value_from_metadata
    tpu_env_data = get_metadata('tpu-env')[0]
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 45, in get_metadata
    api_resp = requests.get(
               ^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/api.py", line 73, in get
    return request("get", url, params=params, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/api.py", line 59, in request
    return session.request(method=method, url=url, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/sessions.py", line 589, in request
    resp = self.send(prep, **send_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/sessions.py", line 703, in send
    r = adapter.send(request, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/adapters.py", line 700, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/attributes/tpu-env (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f9ac8f53c50>: Failed to establish a new connection: [Errno 111] Connection refused'))
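The traceback shows where the call fails: `get_process_id` asks `get_tpu_env_value` for `MEGASCALE_COORDINATOR_ADDRESS`, no value is found locally (the frame falls through to `get_tpu_env_value_from_metadata`), so JAX requests the `tpu-env` attribute from `metadata.google.internal`, and that connection is refused. The following minimal standard-library sketch can check, independently of JAX, whether a new worker can reach that endpoint. The names `probe_metadata` and `tpu_env_value` are hypothetical, the environment-first lookup order is inferred from the traceback frames rather than copied from JAX, and the `KEY: 'value'` layout of `tpu-env` is an assumption.

```python
import os
import re
import urllib.error
import urllib.request
from typing import Optional, Tuple

# Host and path copied from the traceback above.
METADATA_HOST = "metadata.google.internal"
ATTR_PATH = "/computeMetadata/v1/instance/attributes/"


def probe_metadata(key: str, host: str = METADATA_HOST,
                   timeout: float = 5.0) -> Tuple[bool, str]:
    """Fetch one instance attribute; return (ok, body_or_error)."""
    url = f"http://{host}{ATTR_PATH}{key}"
    # The GCE metadata server rejects requests without this header.
    req = urllib.request.Request(url, headers={"Metadata-Flavor": "Google"})
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return True, resp.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as e:
        # The server answered with a non-2xx status (e.g. 404 for a missing attribute).
        return False, f"HTTP {e.code} from {url}"
    except OSError as e:
        # URLError (DNS failure, connection refused) and timeouts are OSErrors.
        return False, f"cannot reach {url}: {e}"


def tpu_env_value(key: str, host: str = METADATA_HOST) -> Optional[str]:
    """Hypothetical model of the env-first, metadata-second lookup order.

    Assumes the tpu-env attribute holds lines of the form KEY: 'value'.
    """
    value = os.environ.get(key)
    if value is not None:
        return value  # found locally; no network access needed
    ok, body = probe_metadata("tpu-env", host=host)
    if not ok:
        # Mirrors the failure in the report: no local value and no metadata.
        raise ConnectionError(body)
    match = re.search(rf"^{re.escape(key)}:\s*'(.*)'\s*$", body, re.MULTILINE)
    return match.group(1) if match else None
```

On each worker, a result of `(False, "cannot reach ...")` from `probe_metadata("tpu-env")` would suggest the failure lies in the instance's network or metadata-server setup rather than in JAX; a `True` result containing the expected keys would point back at JAX's detection logic.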

pip list

Package                 Version
----------------------- -----------------------
anaconda-anon-usage     0.4.4
archspec                0.2.3
boltons                 23.0.0
Brotli                  1.0.9
certifi                 2024.7.4
cffi                    1.16.0
charset-normalizer      3.3.2
conda                   24.7.1
conda-content-trust     0.2.0
conda-libmamba-solver   24.7.0
conda-package-handling  2.3.0
conda_package_streaming 0.10.0
cryptography            42.0.5
distro                  1.9.0
frozendict              2.4.2
idna                    3.7
jax                     0.4.34
jaxlib                  0.4.34
jsonpatch               1.33
jsonpointer             2.1
libmambapy              1.5.8
libtpu-nightly          0.1.dev20241002+nightly
menuinst                2.1.2
ml_dtypes               0.5.0
numpy                   2.1.2
opt_einsum              3.4.0
packaging               24.1
pip                     24.2
platformdirs            3.10.0
pluggy                  1.0.0
pycosat                 0.6.6
pycparser               2.21
PySocks                 1.7.1
requests                2.32.3
ruamel.yaml             0.17.21
scipy                   1.14.1
setuptools              72.1.0
tqdm                    4.66.4
truststore              0.8.0
urllib3                 2.2.2
wheel                   0.43.0
zstandard               0.22.0

Setup (bash)

# 1. Install Miniconda.
rm -rf ~/miniconda3

wget https://repo.anaconda.com/miniconda/Miniconda3-py311_24.7.1-0-Linux-x86_64.sh
bash Miniconda3-py311_24.7.1-0-Linux-x86_64.sh -b -u
rm Miniconda3-py311_24.7.1-0-Linux-x86_64.sh


~/miniconda3/bin/conda init bash
eval "$(~/miniconda3/bin/conda shell.bash hook)"


# 2. Install requirements.
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.1.2
python: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]
jax.devices (128 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=126, process_index=31, coords=(2,3,7), core_on_chip=0) TpuDevice(id=127, process_index=31, coords=(3,3,7), core_on_chip=0)]
process_count: 32
platform: uname_result(system='Linux', node='t1v-n-db3292ae-w-17', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')

demon2036 added the bug label on Oct 19, 2024
@thiagolaitz

I'm facing the same problem.
