Over the past two days, I've been using GCP queued resources to create spot TPU v4-256 and v4-64 slices and then running the following Python script on them:
import jax
jax.distributed.initialize()
print(1)
However, it gets stuck at the jax.distributed.initialize() call. This is strange: when I created an on-demand TPU v4-64 two weeks ago, jax.distributed.initialize() ran without any issues, and it still works fine on that machine. Only the newly created instances show this problem, so I'd like to ask the JAX team for help.
BUG
jax.distributed.initialize()
File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/distributed.py", line 231, in initialize
global_state.initialize(coordinator_address, num_processes, process_id,
File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/distributed.py", line 55, in initialize
clusters.ClusterEnv.auto_detect_unset_distributed_params(
File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cluster.py", line 82, in auto_detect_unset_distributed_params
process_id = env.get_process_id()
^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 144, in get_process_id
slice_id = cls._get_slice_id()
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 159, in _get_slice_id
if has_megascale_address():
^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 74, in has_megascale_address
return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 71, in get_tpu_env_value
return value if value is not None else get_tpu_env_value_from_metadata(key)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 59, in get_tpu_env_value_from_metadata
tpu_env_data = get_metadata('tpu-env')[0]
^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 45, in get_metadata
api_resp = requests.get(
^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/requests/api.py", line 73, in get
return request("get", url, params=params, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/requests/api.py", line 59, in request
return session.request(method=method, url=url, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/requests/sessions.py", line 589, in request
resp = self.send(prep, **send_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/requests/sessions.py", line 703, in send
r = adapter.send(request, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.11/site-packages/requests/adapters.py", line 700, in send
raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/attributes/tpu-env (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f9ac8f53c50>: Failed to establish a new connection: [Errno 111] Connection refused'))
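The traceback shows that, when jax.distributed.initialize() is called with no arguments, JAX's Cloud TPU auto-detection reads the `tpu-env` instance attribute from the GCE metadata server at metadata.google.internal, and that this HTTP request is what fails with `[Errno 111] Connection refused`. Below is a minimal sketch for checking that endpoint independently of JAX. `probe_metadata` is a hypothetical helper (not part of JAX), the URL is copied from the traceback, and it assumes the standard `Metadata-Flavor: Google` header that the metadata server expects. It uses only the standard library and returns a status instead of raising.

```python
import urllib.error
import urllib.request

# Endpoint copied from the traceback above; `tpu-env` is the attribute JAX reads.
METADATA_BASE = (
    "http://metadata.google.internal/computeMetadata/v1/instance/attributes"
)


def probe_metadata(key="tpu-env", base=METADATA_BASE, timeout=2.0):
    """Hypothetical helper: fetch one instance attribute, return (ok, detail)."""
    req = urllib.request.Request(
        f"{base}/{key}",
        # The metadata server rejects /computeMetadata/v1/ requests without this header.
        headers={"Metadata-Flavor": "Google"},
    )
    # Talk to the server directly, ignoring any http_proxy settings.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    try:
        with opener.open(req, timeout=timeout) as resp:
            return True, resp.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as e:
        # Reachable, but returned an error status (e.g. 404 for a missing key).
        return False, f"HTTP {e.code}"
    except urllib.error.URLError as e:
        # DNS failure, connect timeout, or connection refused ([Errno 111] above).
        return False, f"unreachable: {e.reason}"
    except OSError as e:
        # Timeouts or resets while reading the response body.
        return False, f"socket error: {e}"
```

On a worker, `ok, detail = probe_metadata()` followed by `print(ok, detail)` shows whether the metadata server answers at all; running it on every worker of the slice would show whether the failure is limited to some hosts. This is only a diagnostic and does not change how JAX performs the lookup.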
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.1.2
python: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]
jax.devices (128 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=126, process_index=31, coords=(2,3,7), core_on_chip=0) TpuDevice(id=127, process_index=31, coords=(3,3,7), core_on_chip=0)]
process_count: 32
platform: uname_result(system='Linux', node='t1v-n-db3292ae-w-17', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')