Skip to content

Commit

Permalink
use dataclass.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Nov 14, 2024
1 parent c1f5137 commit 155e85f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 53 deletions.
39 changes: 1 addition & 38 deletions python-package/xgboost/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pickle
from dataclasses import dataclass
from enum import IntEnum, unique
from typing import Any, Dict, Optional, TypeAlias, TypeVar, Union
from typing import Any, Dict, Optional, TypeAlias, Union

import numpy as np

Expand Down Expand Up @@ -61,43 +61,6 @@ def get_comm_config(self, args: _Args) -> _Args:
args["dmlc_timeout"] = self.timeout
return args

def to_dict(self) -> _Args:
"Convert the configuration into a dictionary."
return {
k: getattr(self, k)
for k in (
"retry",
"timeout",
"tracker_host",
"tracker_port",
"tracker_timeout",
)
}

@staticmethod
def from_dict(cfg: _Args) -> "Config":
"Create a configuration from a dictionary."
T = TypeVar("T", str, int)

def to_t(key: str, typ: T) -> Optional[T]:
v = cfg.get(key, None)
if v is None:
return v
if not isinstance(v, type(typ)):
raise TypeError(
f"Invalid type for configuration `{key}`, "
f"expecting {type(typ).__name__}, got {type(v).__name__}."
)
return v

return Config(
retry=to_t("retry", int()),
timeout=to_t("timeout", int()),
tracker_host=to_t("tracker_host", str()),
tracker_port=to_t("tracker_port", int()),
tracker_timeout=to_t("tracker_timeout", int()),
)


def init(**args: _ArgVals) -> None:
"""Initialize the collective library with arguments.
Expand Down
2 changes: 2 additions & 0 deletions tests/ci_build/lint_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class LintersPaths:
# tests
"tests/python/test_config.py",
"tests/python/test_callback.py",
"tests/python/test_collective.py",
"tests/python/test_data_iterator.py",
"tests/python/test_dmatrix.py",
"tests/python/test_dt.py",
Expand Down Expand Up @@ -94,6 +95,7 @@ class LintersPaths:
# core
"python-package/",
# tests
"tests/python/test_collective.py",
"tests/python/test_dt.py",
"tests/python/test_demos.py",
"tests/python/test_data_iterator.py",
Expand Down
20 changes: 5 additions & 15 deletions tests/python/test_collective.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import socket
from dataclasses import asdict

import numpy as np
import pytest
Expand Down Expand Up @@ -58,13 +59,13 @@ def run_federated_worker(port: int, world_size: int, rank: int) -> int:

@pytest.mark.skipif(**tm.skip_win())
@pytest.mark.skipif(**tm.no_loky())
def test_federated_communicator():
def test_federated_communicator() -> None:
if not build_info()["USE_FEDERATED"]:
pytest.skip("XGBoost not built with federated learning enabled")

port = 9091
world_size = 2
with get_reusable_executor(max_workers=world_size+1) as pool:
with get_reusable_executor(max_workers=world_size + 1) as pool:
kwargs = {"port": port, "n_workers": world_size, "blocking": False}
tracker = pool.submit(federated.run_federated_server, **kwargs)
if not tracker.running():
Expand All @@ -81,17 +82,6 @@ def test_federated_communicator():


def test_config_serialization() -> None:
cfg = Config(
retry=1, timeout=2, tracker_host="127.0.0.1", tracker_port=None
)
cfg1 = Config.from_dict(cfg.to_dict())
cfg = Config(retry=1, timeout=2, tracker_host="127.0.0.1", tracker_port=None)
cfg1 = Config(**asdict(cfg))
assert cfg == cfg1

d = cfg.to_dict()
with pytest.raises(TypeError, match="retry"):
d.update({"retry": "2"})
cfg1 = Config.from_dict(d)
d = cfg.to_dict()
with pytest.raises(TypeError, match="tracker_host"):
d.update({"tracker_host": 123})
cfg1 = Config.from_dict(d)

0 comments on commit 155e85f

Please sign in to comment.