From 4a7feea364ce7f07c87ba26de229fca043a2eb65 Mon Sep 17 00:00:00 2001
From: "pierre.delaunay"
Date: Tue, 14 Jan 2025 11:24:51 -0500
Subject: [PATCH] Update dockerfile

- Bump the CUDA base image from 11.8.0 to 12.1.0.
- Move the hostname/IP resolution helpers from milabench/system.py to
  the new milabench/network.py and add tests/test_network.py.
- Move the ssh based `gethostname` helper to milabench/cli/slurm.py.
- Add docker.txt with the podman build/pull commands.
---
Review notes (this section is ignored by `git am`):
- network.py: `enable_offline` restores the previous flag on exit again
  (the restore was lost in the move) and defaults to `enabled=True`,
  which is how tests/test_network.py calls it.
- network.py: add `resolve_hostname`, which cli/slurm.py imports.
- cli/slurm.py: `gethostname` no longer uses a bare `except:`.
- The `index` blob ids of the amended files were not regenerated.

 docker.txt             |  10 +++
 docker/Dockerfile-cuda |   2 +-
 milabench/_version.py  |   6 +-
 milabench/cli/slurm.py |  20 ++++-
 milabench/network.py   | 171 +++++++++++++++++++++++++++++++++++++++++
 milabench/system.py    | 151 +-----------------------------------
 tests/test_network.py  |  68 ++++++++++++++++
 7 files changed, 271 insertions(+), 157 deletions(-)
 create mode 100644 docker.txt
 create mode 100644 milabench/network.py
 create mode 100644 tests/test_network.py

diff --git a/docker.txt b/docker.txt
new file mode 100644
index 000000000..5c6711d93
--- /dev/null
+++ b/docker.txt
@@ -0,0 +1,10 @@
+
+
+
+
+
+    podman build -f docker/Dockerfile-cuda -t milabenchdev .
+
+
+
+    podman pull ghcr.io/mila-iqia/milabench:cuda-nightly
\ No newline at end of file
diff --git a/docker/Dockerfile-cuda b/docker/Dockerfile-cuda
index 6e7641844..9647cd2a8 100644
--- a/docker/Dockerfile-cuda
+++ b/docker/Dockerfile-cuda
@@ -1,7 +1,7 @@
 # FROM ubuntu:22.04
 
 # For cuda-gdb
-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
+FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
 
 # Arguments
 # ---------
diff --git a/milabench/_version.py b/milabench/_version.py
index 8f0797634..12face4da 100644
--- a/milabench/_version.py
+++ b/milabench/_version.py
@@ -1,5 +1,5 @@
 """This file is generated, do not modify"""
 
-__tag__ = "v1.0.0-7-g4dc0ba9"
-__commit__ = "4dc0ba9f3b35119dc1947d6dc8f6502e5d34dd50"
-__date__ = "2024-12-11 14:44:23 -0500"
+__tag__ = "v1.0.0-9-gd1cb39a"
+__commit__ = "d1cb39a317c32988d1dd1c486b6ed010aa52bef5"
+__date__ = "2025-01-13 12:27:42 -0500"
diff --git a/milabench/cli/slurm.py b/milabench/cli/slurm.py
index 35f1fe94e..55470c73b 100644
--- a/milabench/cli/slurm.py
+++ b/milabench/cli/slurm.py
@@ -4,9 +4,23 @@ import subprocess
 
 from coleo import tooled
 
-from ..system import get_gpu_capacity, is_loopback, resolve_hostname, gethostname
-
-
+from ..system import get_gpu_capacity
+from ..network import resolve_hostname
+
+
+def gethostname(host):
+    """Return the content of `/etc/hostname` read on `host` over ssh, or `host` on failure."""
+    try:
+        # "-oCheckHostIP=no",
+        # "-oPasswordAuthentication=no",
+        return subprocess.check_output([
+            "ssh",
+            "-oCheckHostIP=no",
+            "-oPasswordAuthentication=no",
+            "-oStrictHostKeyChecking=no", host, "cat", "/etc/hostname"], text=True).strip()
+    except Exception:
+        print("Could not resolve hostname")
+        return host
 
 
 def make_node_list_from_slurm(node_list):
diff --git a/milabench/network.py b/milabench/network.py
new file mode 100644
index 000000000..90bfad200
--- /dev/null
+++ b/milabench/network.py
@@ -0,0 +1,171 @@
+import psutil
+import socket
+import ipaddress
+from contextlib import contextmanager
+
+# If true that means we cannot resolve the ip addresses
+# so we ignore errors
+offline = False
+
+
+@contextmanager
+def enable_offline(enabled=True):
+    """Temporarily override the module level `offline` flag.
+
+    The previous value is restored when the `with` block exits,
+    even if the body raises.
+    """
+    global offline
+    old = offline
+
+    offline = enabled
+    try:
+        yield
+    finally:
+        offline = old
+
+
+def is_loopback(address: str) -> bool:
+    """Return True if `address` is a valid loopback IP address."""
+    try:
+        # Create an IP address object
+        ip = ipaddress.ip_address(address)
+        # Check if the address is a loopback address
+        return ip.is_loopback
+    except ValueError:
+        # If the address is invalid, return False
+        return False
+
+
+def local_ips() -> set[str]:
+    """Return the IPv4 and IPv6 addresses of all local network interfaces."""
+    return {
+        addr.address
+        for addresses in psutil.net_if_addrs().values()
+        for addr in addresses
+        if addr.family.name in ("AF_INET", "AF_INET6")
+    }
+
+
+def gethostbyaddr(addr):
+    """Return `(hostname, iplist)` for `addr`.
+
+    Fall back to `(addr, [addr])` when offline or when the lookup fails.
+    """
+    if offline:
+        return addr, [addr]
+
+    try:
+        hostname, _, iplist = socket.gethostbyaddr(addr)
+
+        return hostname, iplist
+
+    except socket.herror:
+        pass
+    except socket.gaierror:
+        pass
+
+    print("Could not resolve address with DNS")
+    print("Use IP in your node configuration")
+    # This happens if we cannot do a DNS lookup for some reason
+    return addr, [addr]
+
+
+def resolve_ip(ip):
+    """Return `(hostname, ip, is_local)` for an IP address or hostname."""
+    local = local_ips()
+
+    # we are running code on `cn-l092`
+    # gethostbyaddr(172.16.9.192) cn-l092.server.mila.quebec ['172.16.9.192']
+    # gethostbyaddr(cn-l092     ) cn-l092                    ['127.0.1.1'] <=
+    # gethostbyaddr(cn-d003     ) cn-d003.server.mila.quebec ['172.16.8.75']
+    # gethostbyaddr(172.16.8.75 ) cn-d003.server.mila.quebec ['172.16.8.75']
+
+    hostname, iplist = gethostbyaddr(ip)
+    real_ip = ip
+
+    if len(iplist) == 1 and not is_loopback(iplist[0]):
+        real_ip = iplist[0]
+
+    # we need this because
+    #
+    #   hostname, _, ['127.0.1.1'] = socket.gethostbyaddr("cn-l092")
+    #
+    # and 127.0.1.1 is not included in our list of IPs
+    #
+    for ip_entry in iplist:
+        if is_loopback(ip_entry):
+            return hostname, real_ip, True
+
+    if local.intersection(iplist):
+        return hostname, real_ip, True
+
+    if hostname == socket.gethostname():
+        return hostname, real_ip, True
+
+    return hostname, real_ip, False
+
+
+def resolve_hostname(ip):
+    """Return `(hostname, is_local)` for an IP address or hostname.
+
+    Kept for `milabench.cli.slurm`, which still imports this name.
+    """
+    hostname, _, local = resolve_ip(ip)
+    return hostname, local
+
+
+def normalize_local(node):
+    """Fill `hostname` (and `ip` when it has no dot) of the local node in place."""
+    # Local node usually get stuck resolving local loopback
+    # depending on how it was configured
+    # this fetch the outbound ip and hostname
+    assert node["local"] is True
+
+    #
+    hostname, ip, local = resolve_ip(socket.getfqdn())
+
+    node["hostname"] = hostname
+    if '.' not in node["ip"]:
+        node["ip"] = ip
+
+    # assert local is True
+
+
+def resolve_node_address(node):
+    """Set `hostname`, `ip` and `local` on `node`; return the `local` flag."""
+    hostname, ip, local = resolve_ip(node["ip"])
+
+    node["hostname"] = hostname
+    node["ip"] = ip
+    node["local"] = local
+
+    if local:
+        try:
+            # `gethostbyaddr` returns `cn-d003` but we want `cn-d003.server.mila.quebec`
+            # else torchrun does not recognize the main node
+            node["hostname"] = socket.getfqdn()
+
+            normalize_local(node)
+        except Exception:
+            print("Skipped local normalization")
+
+    return local
+
+
+def resolve_addresses(nodes):
+    """Resolve every node in place; return the last node found local, or None."""
+    # This normalize the node ip/hostname
+    # for convenience we support a range of values in the IP field
+    # we use DNS lookup to resolve the IP/hostname and normalize the fields
+    #
+    # If DNS is not available then we just leave things as is
+    # we also try to find the node we are currently running code on
+    # we do that by simply checking all the available IP on this node
+    # and check which node has that IP
+    self = None
+    for node in nodes:
+        if resolve_node_address(node):
+            self = node
+
+    return self
diff --git a/milabench/system.py b/milabench/system.py
index 691d06bd9..e00f55857 100644
--- a/milabench/system.py
+++ b/milabench/system.py
@@ -1,9 +1,6 @@
 import contextvars
 from copy import deepcopy
-import ipaddress
 import os
-import socket
-import subprocess
 import sys
 from contextlib import contextmanager
 from dataclasses import dataclass, field
@@ -14,6 +11,7 @@
 
 from .fs import XPath
 from .merge import merge
+from .network import resolve_addresses
 
 system_global = contextvars.ContextVar("system", default=None)
 multirun_global = contextvars.ContextVar("multirun", default=None)
@@ -316,149 +314,6 @@ def check_node_config(nodes):
             assert field in node, f"The `{field}` of the node `{name}` is missing"
 
 
-def get_remote_ip():
-    """Get all the ip of all the network interfaces"""
-    if offline:
-        return set()
-
-    addresses = psutil.net_if_addrs()
-    stats = psutil.net_if_stats()
-
-    result = []
-
-    for interface, address_list in addresses.items():
-        for address in address_list:
-            # if address.family in (socket.AF_INET, socket.AF_INET6):
-            if interface in stats and getattr(stats[interface], "isup"):
-                result.append(address.address)
-
-    return set(result)
-
-
-def is_loopback(address: str) -> bool:
-    try:
-        # Create an IP address object
-        ip = ipaddress.ip_address(address)
-        # Check if the address is a loopback address
-        return ip.is_loopback
-    except ValueError:
-        # If the address is invalid, return False
-        return False
-
-
-
-# If true that means we cannot resolve the ip addresses
-# so we ignore errors
-offline = True
-
-
-@contextmanager
-def enable_offline(enabled):
-    global offline
-    old = offline
-
-    offline = enabled
-    yield
-    offline = old
-
-
-def _resolve_addresses(nodes):
-    # Note: it is possible for self to be none
-    # if we are running milabench on a node that is not part of the system
-    # in that case it should still work; the local is then going to
-    # ssh into the main node which will dispatch the work to the other nodes
-    self = None
-    lazy_raise = None
-    ip_list = get_remote_ip()
-
-    for node in nodes:
-        ip = node["ip"]
-
-        is_local = is_loopback(ip)
-
-        if ip in ip_list:
-            is_local = True
-
-        node["local"] = is_local
-
-        if is_local:
-            node["hostname"] = socket.gethostname()
-
-        if is_local and self is None:
-            self = node
-            node["ipaddrlist"] = list(set(list(ip_list)))
-
-    # if self is node we might be outisde the cluster
-    # which explains why we could not resolve the IP of the nodes
-    if not offline:
-        if self is not None and lazy_raise:
-            raise RuntimeError("Could not resolve node ip") from lazy_raise
-
-    return self
-
-
-def gethostname(host):
-    try:
-        # "-oCheckHostIP=no",
-        # "-oPasswordAuthentication=no",
-        return subprocess.check_output([
-            "ssh",
-            "-oCheckHostIP=no",
-            "-oPasswordAuthentication=no",
-            "-oStrictHostKeyChecking=no", host, "cat", "/etc/hostname"], text=True).strip()
-    except:
-        print("Could not resolve hostname")
-        return host
-
-
-def resolve_hostname(ip):
-    try:
-        hostname, _, iplist = socket.gethostbyaddr(ip)
-
-        for ip in iplist:
-            if is_loopback(ip):
-                return hostname, True
-
-        # FIXME
-        return socket.gethostname(), hostname.startswith(socket.gethostname())
-        return hostname, hostname == socket.gethostname()
-
-    except:
-        if offline:
-            return ip, False
-
-        raise
-
-def resolve_node_address(node):
-    hostname, local = resolve_hostname(node["ip"])
-
-    node["hostname"] = hostname
-    node["local"] = local
-
-    if local:
-        # `gethostbyaddr` returns `cn-d003` but we want `cn-d003.server.mila.quebec`
-        # else torchrun does not recognize the main node
-        node["hostname"] = socket.gethostname()
-
-    return local
-
-
-def resolve_addresses(nodes):
-    if offline:
-        for n in nodes:
-            n["hostname"] = n["ip"]
-
-        return nodes[0]
-
-    self = None
-
-    for node in nodes:
-        if resolve_node_address(node):
-            self = node
-
-    return self
-
-
 def build_system_config(config_file, defaults=None, gpu=True):
     """Load the system configuration, verify its validity and resolve ip addresses
 
@@ -537,7 +392,3 @@ def compact(d, depth):
         print(json.dumps(config, indent=2))
     else:
         compact(config, 0)
-
-
-if __name__ == "__main__":
-    show_overrides()
diff --git a/tests/test_network.py b/tests/test_network.py
new file mode 100644
index 000000000..4820c01da
--- /dev/null
+++ b/tests/test_network.py
@@ -0,0 +1,68 @@
+from milabench.network import resolve_addresses, enable_offline
+
+from copy import deepcopy
+import pytest
+
+
+# test that it works without DNS
+# podman run --rm -it --dns 0.0.0.0 -v "$(pwd):/mnt" python-with-psutil bash
+
+cases = [
+    {
+        "ip": "172.16.9.192"
+    },
+    {
+        "ip": "cn-l092"
+    },
+    {
+        "ip": "cn-d003"
+    },
+    {
+        "ip": "cn-l092.server.mila.quebec"
+    },
+    {
+        "ip": "cn-d003.server.mila.quebec"
+    },
+    {
+        "ip": "172.16.8.75"
+    }
+]
+
+
+@pytest.mark.skip(reason="those hostnames mean nothing in the CI")
+def test_network():
+    nodes = deepcopy(cases)
+    print(resolve_addresses(nodes))
+    print()
+    for n in nodes:
+        print(n)
+
+
+@pytest.mark.skip(reason="those hostnames mean nothing in the CI")
+def test_no_dns_network():
+    with enable_offline():
+        nodes = deepcopy(cases)
+        print(resolve_addresses(nodes))
+        print()
+        for n in nodes:
+            print(n)
+
+
+def check_dns():
+    nodes = deepcopy(cases)
+    print(resolve_addresses(nodes))
+    print()
+    for n in nodes:
+        print(n)
+
+    print("===")
+    with enable_offline():
+        nodes = deepcopy(cases)
+        print(resolve_addresses(nodes))
+        print()
+        for n in nodes:
+            print(n)
+
+
+if __name__ == "__main__":
+    check_dns()