Skip to content

Commit

Permalink
Update dockerfile
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jan 14, 2025
1 parent d1cb39a commit 4a7feea
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 157 deletions.
10 changes: 10 additions & 0 deletions docker.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@





podman build -f docker/Dockerfile-cuda -t milabenchdev .



podman pull ghcr.io/mila-iqia/milabench:cuda-nightly
2 changes: 1 addition & 1 deletion docker/Dockerfile-cuda
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# FROM ubuntu:22.04

# For cuda-gdb
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04

# Arguments
# ---------
Expand Down
6 changes: 3 additions & 3 deletions milabench/_version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""This file is generated, do not modify"""

__tag__ = "v1.0.0-7-g4dc0ba9"
__commit__ = "4dc0ba9f3b35119dc1947d6dc8f6502e5d34dd50"
__date__ = "2024-12-11 14:44:23 -0500"
__tag__ = "v1.0.0-9-gd1cb39a"
__commit__ = "d1cb39a317c32988d1dd1c486b6ed010aa52bef5"
__date__ = "2025-01-13 12:27:42 -0500"
19 changes: 16 additions & 3 deletions milabench/cli/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,22 @@
import subprocess
from coleo import tooled

from ..system import get_gpu_capacity, is_loopback, resolve_hostname, gethostname


from ..system import get_gpu_capacity
from ..network import resolve_hostname


def gethostname(host):
    """Return the hostname that ``host`` reports for itself, queried over ssh.

    Runs ``cat /etc/hostname`` on ``host`` through ``ssh`` with password
    authentication, host-IP checking and strict host-key checking disabled.

    Args:
        host: Name or address handed to ``ssh`` as the destination.

    Returns:
        The stripped output of the remote command, or ``host`` unchanged when
        the command could not be run or failed.
    """
    command = [
        "ssh",
        "-oCheckHostIP=no",
        "-oPasswordAuthentication=no",
        "-oStrictHostKeyChecking=no",
        host,
        "cat",
        "/etc/hostname",
    ]

    try:
        return subprocess.check_output(command, text=True).strip()
    except Exception:
        # Deliberate best effort: any failure falls back to the given name.
        # `Exception` (not a bare `except`) so that Ctrl-C and SystemExit
        # still propagate while ssh is running.
        print("Could not resolve hostname")
        return host


def make_node_list_from_slurm(node_list):
Expand Down
153 changes: 153 additions & 0 deletions milabench/network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import psutil
import socket
import ipaddress
from contextlib import contextmanager

# When true, IP addresses cannot be resolved, so lookups are skipped and
# resolution errors are ignored.
offline = False


@contextmanager
def enable_offline(enabled):
    """Temporarily set the module-level ``offline`` flag.

    Args:
        enabled: Value assigned to ``offline`` while the ``with`` body runs.

    The value ``offline`` had on entry is put back on exit, including when
    the body raises.
    """
    global offline
    previous = offline

    offline = enabled
    try:
        yield
    finally:
        # Without this the override would leak past the `with` block.
        offline = previous


def is_loopback(address: str) -> bool:
    """Tell whether ``address`` is a loopback IP address.

    Args:
        address: Textual IPv4 or IPv6 address.

    Returns:
        ``True`` for a loopback address; ``False`` for any other address and
        for text that does not parse as an IP address.
    """
    try:
        parsed = ipaddress.ip_address(address)
    except ValueError:
        # Not an IP address at all (for example a hostname).
        return False

    return parsed.is_loopback


def local_ips() -> set[str]:
    """Return the IPv4 and IPv6 addresses of this machine's interfaces.

    Addresses are read from ``psutil.net_if_addrs()``; entries whose family
    is neither ``AF_INET`` nor ``AF_INET6`` are left out.
    """
    wanted_families = ("AF_INET", "AF_INET6")

    return {
        entry.address
        for entries in psutil.net_if_addrs().values()
        for entry in entries
        if entry.family.name in wanted_families
    }


def gethostbyaddr(addr):
    """Look up ``addr`` and return ``(hostname, ip_list)``.

    Args:
        addr: Hostname or IP address passed to ``socket.gethostbyaddr``.

    Returns:
        The hostname and address list reported by the lookup. When the
        module is in offline mode, or the lookup fails with
        ``socket.herror`` / ``socket.gaierror``, ``(addr, [addr])`` is
        returned instead.
    """
    if not offline:
        try:
            name, _aliases, addresses = socket.gethostbyaddr(addr)
        except (socket.herror, socket.gaierror):
            # The DNS lookup could not be done; tell the user how to work
            # around it, then fall back to the raw value below.
            print("Could not resolve address with DNS")
            print("Use IP in your node configuration")
        else:
            return name, addresses

    return addr, [addr]


def resolve_ip(ip):
    """Resolve ``ip`` and decide whether it designates this machine.

    Args:
        ip: Hostname or IP address taken from a node configuration.

    Returns:
        ``(hostname, real_ip, is_local)``. ``real_ip`` is the single address
        returned by the lookup when there is exactly one and it is not a
        loopback address; otherwise it is ``ip`` as given.
    """
    own_addresses = local_ips()

    # Observed results when running on `cn-l092`:
    #   gethostbyaddr(172.16.9.192) -> cn-l092.server.mila.quebec ['172.16.9.192']
    #   gethostbyaddr(cn-l092)      -> cn-l092                    ['127.0.1.1']
    #   gethostbyaddr(cn-d003)      -> cn-d003.server.mila.quebec ['172.16.8.75']
    #   gethostbyaddr(172.16.8.75)  -> cn-d003.server.mila.quebec ['172.16.8.75']
    hostname, candidates = gethostbyaddr(ip)

    real_ip = ip
    if len(candidates) == 1 and not is_loopback(candidates[0]):
        real_ip = candidates[0]

    # The loopback test comes first because looking up the local short name
    # can yield ['127.0.1.1'], which is not among the interface addresses.
    is_local = (
        any(is_loopback(candidate) for candidate in candidates)
        or bool(own_addresses.intersection(candidates))
        or hostname == socket.gethostname()
    )

    return hostname, real_ip, is_local


def normalize_local(node):
    """Refresh the hostname (and possibly the ip) of the local node in place.

    The local node tends to resolve to a loopback address depending on how
    the machine is configured, so the lookup is redone from this machine's
    fully qualified domain name.

    Args:
        node: Node mapping whose ``"local"`` entry is ``True``. Its
            ``"hostname"`` is always overwritten; its ``"ip"`` is overwritten
            only when the current value contains no ``.``.
    """
    assert node["local"] is True

    fqdn = socket.getfqdn()
    resolved_name, resolved_ip, _ = resolve_ip(fqdn)

    node["hostname"] = resolved_name

    if "." in node["ip"]:
        return

    node["ip"] = resolved_ip


def resolve_node_address(node):
    """Fill in ``hostname``, ``ip`` and ``local`` for ``node`` in place.

    Args:
        node: Node mapping with an ``"ip"`` entry (an address or a hostname).

    Returns:
        Whether the node was identified as the machine running this code.
    """
    name, address, is_local = resolve_ip(node["ip"])

    node["hostname"] = name
    node["ip"] = address
    node["local"] = is_local

    if not is_local:
        return is_local

    try:
        # `gethostbyaddr` returns `cn-d003` but `cn-d003.server.mila.quebec`
        # is wanted, otherwise torchrun does not recognize the main node.
        node["hostname"] = socket.getfqdn()
        normalize_local(node)
    except Exception:
        print("Skipped local normalization")

    return is_local


def resolve_addresses(nodes):
    """Normalize the ip/hostname of every node and find the local one.

    The ``ip`` field accepts a range of values for convenience; each node is
    passed through ``resolve_node_address``, which resolves and normalizes
    its fields and reports whether it is the machine running this code.

    Args:
        nodes: Iterable of node mappings, updated in place.

    Returns:
        The last node identified as local, or ``None`` when there is none.
    """
    current = None

    for candidate in nodes:
        is_local = resolve_node_address(candidate)
        if is_local:
            current = candidate

    return current



151 changes: 1 addition & 150 deletions milabench/system.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import contextvars
from copy import deepcopy
import ipaddress
import os
import socket
import subprocess
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
Expand All @@ -14,6 +11,7 @@

from .fs import XPath
from .merge import merge
from .network import resolve_addresses

system_global = contextvars.ContextVar("system", default=None)
multirun_global = contextvars.ContextVar("multirun", default=None)
Expand Down Expand Up @@ -316,149 +314,6 @@ def check_node_config(nodes):
assert field in node, f"The `{field}` of the node `{name}` is missing"


def get_remote_ip():
    """Collect the addresses of every network interface that is up.

    Returns:
        A set with the address of each entry reported by
        ``psutil.net_if_addrs()`` for interfaces that
        ``psutil.net_if_stats()`` marks as up. No address-family filter is
        applied. An empty set is returned in offline mode.
    """
    if offline:
        return set()

    per_interface = psutil.net_if_addrs()
    link_state = psutil.net_if_stats()

    found = set()
    for name, entries in per_interface.items():
        if name not in link_state or not link_state[name].isup:
            continue
        found.update(entry.address for entry in entries)

    return found


def is_loopback(address: str) -> bool:
    """Tell whether ``address`` is a loopback IP address.

    Args:
        address: Textual IPv4 or IPv6 address.

    Returns:
        ``True`` for a loopback address; ``False`` for any other address and
        for text that does not parse as an IP address.
    """
    try:
        parsed = ipaddress.ip_address(address)
    except ValueError:
        # Not an IP address at all (for example a hostname).
        return False

    return parsed.is_loopback



# When true, IP addresses cannot be resolved, so lookups are skipped and
# resolution errors are ignored.
offline = True


@contextmanager
def enable_offline(enabled):
    """Temporarily set the module-level ``offline`` flag.

    Args:
        enabled: Value assigned to ``offline`` while the ``with`` body runs.

    The value ``offline`` had on entry is put back on exit, including when
    the body raises.
    """
    global offline
    previous = offline

    offline = enabled
    try:
        yield
    finally:
        # Restore even if the body raised, so a failure cannot leave the
        # override in place.
        offline = previous


def _resolve_addresses(nodes):
    """Mark each node as local or remote and return the first local node.

    A node is local when its ``"ip"`` is a loopback address or is one of the
    addresses returned by ``get_remote_ip()``. Every node gets a ``"local"``
    entry; local nodes also get ``"hostname"`` set to
    ``socket.gethostname()``, and the first local node receives the list of
    this machine's addresses under ``"ipaddrlist"``.

    Args:
        nodes: Iterable of node mappings with an ``"ip"`` entry, updated in
            place.

    Returns:
        The first local node, or ``None``. ``None`` is possible when
        milabench runs on a machine that is not part of the system (we might
        be outside the cluster); in that case the local machine is expected
        to ssh into the main node, which dispatches the work to the others.
    """
    current = None
    known_ips = get_remote_ip()

    for node in nodes:
        ip = node["ip"]

        is_local = is_loopback(ip) or ip in known_ips
        node["local"] = is_local

        if not is_local:
            continue

        node["hostname"] = socket.gethostname()

        if current is None:
            current = node
            node["ipaddrlist"] = list(known_ips)

    return current


def gethostname(host):
    """Return the hostname that ``host`` reports for itself, queried over ssh.

    Runs ``cat /etc/hostname`` on ``host`` through ``ssh`` with password
    authentication, host-IP checking and strict host-key checking disabled.

    Args:
        host: Name or address handed to ``ssh`` as the destination.

    Returns:
        The stripped output of the remote command, or ``host`` unchanged when
        the command could not be run or failed.
    """
    command = [
        "ssh",
        "-oCheckHostIP=no",
        "-oPasswordAuthentication=no",
        "-oStrictHostKeyChecking=no",
        host,
        "cat",
        "/etc/hostname",
    ]

    try:
        return subprocess.check_output(command, text=True).strip()
    except Exception:
        # Deliberate best effort: any failure falls back to the given name.
        # `Exception` (not a bare `except`) so that Ctrl-C and SystemExit
        # still propagate while ssh is running.
        print("Could not resolve hostname")
        return host


def resolve_hostname(ip):
    """Resolve ``ip`` to ``(hostname, is_local)``.

    Args:
        ip: Hostname or IP address passed to ``socket.gethostbyaddr``.

    Returns:
        ``(hostname, True)`` when the lookup yields a loopback address.
        Otherwise ``(socket.gethostname(), flag)`` where ``flag`` tells
        whether the resolved hostname starts with this machine's hostname.
        In offline mode a failed lookup yields ``(ip, False)``.

    Raises:
        Exception: whatever the lookup raised, when not in offline mode.
    """
    try:
        hostname, _, iplist = socket.gethostbyaddr(ip)

        # The loop variable must not reuse the name `ip`: the offline
        # fallback below returns the caller's original value.
        if any(is_loopback(entry) for entry in iplist):
            return hostname, True

        # FIXME: the first element is this machine's hostname even when the
        # resolved host is a different machine -- NOTE(review): confirm this
        # is intended before relying on it.
        local_name = socket.gethostname()
        return local_name, hostname.startswith(local_name)

    except Exception:
        if offline:
            return ip, False

        raise

def resolve_node_address(node):
    """Fill in ``hostname`` and ``local`` for ``node`` in place.

    Args:
        node: Node mapping with an ``"ip"`` entry.

    Returns:
        Whether ``resolve_hostname`` reported the node as local.
    """
    name, is_local = resolve_hostname(node["ip"])

    node["hostname"], node["local"] = name, is_local

    if is_local:
        # For the local node the name reported by this machine replaces the
        # looked-up one; the original note says torchrun otherwise does not
        # recognize the main node.
        node["hostname"] = socket.gethostname()

    return is_local


def resolve_addresses(nodes):
    """Resolve every node and return the one running this code.

    Args:
        nodes: Sequence of node mappings, updated in place.

    Returns:
        In offline mode, the first node, after each node's ``"hostname"`` has
        been set to its ``"ip"``. Otherwise the last node that
        ``resolve_node_address`` reports as local, or ``None``.
    """
    if not offline:
        current = None
        for candidate in nodes:
            if resolve_node_address(candidate):
                current = candidate
        return current

    # Nothing can be looked up offline: use the configured value as the name.
    for candidate in nodes:
        candidate["hostname"] = candidate["ip"]

    return nodes[0]


def build_system_config(config_file, defaults=None, gpu=True):
"""Load the system configuration, verify its validity and resolve ip addresses
Expand Down Expand Up @@ -537,7 +392,3 @@ def compact(d, depth):
print(json.dumps(config, indent=2))
else:
compact(config, 0)


if __name__ == "__main__":
show_overrides()
Loading

0 comments on commit 4a7feea

Please sign in to comment.