Skip to content

Commit

Permalink
Use wandb.login instead of environment variable
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas J. Fan <[email protected]>
  • Loading branch information
thomasjpfan committed May 29, 2024
1 parent f54688f commit 8480849
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
secret: Optional[Union[Secret, Callable]] = None,
id: Optional[str] = None,
host: str = "https://wandb.ai",
api_host: str = "https://api.wandb.ai",
**init_kwargs: dict,
):
"""Weights and Biases plugin.
Expand All @@ -35,6 +36,7 @@ def __init__(
The callable takes no arguments and returns a string. (Required)
id (str, optional): A unique id for this wandb run.
host (str, optional): URL to your wandb service. The default is "https://wandb.ai".
api_host (str, optional): URL to your API Host, The default is "https://api.wandb.ai".
**init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see
[the `wandb.init` docs](https://docs.wandb.ai/ref/python/init) for details.
"""
Expand All @@ -51,6 +53,7 @@ def __init__(
self.init_kwargs = init_kwargs
self.secret = secret
self.host = host
self.api_host = api_host

# All kwargs need to be passed up so that the function wrapping works for both
# `@wandb_init` and `@wandb_init(...)`
Expand All @@ -61,6 +64,7 @@ def __init__(
secret=secret,
id=id,
host=host,
api_host=api_host,
**init_kwargs,
)

Expand All @@ -81,7 +85,7 @@ def execute(self, *args, **kwargs):
# Get API key with callable
wandb_api_key = self.secret()

os.environ["WANDB_API_KEY"] = wandb_api_key
wandb.login(key=wandb_api_key, host=self.api_host)

if self.id is None:
# The HOSTNAME is set to {.executionName}-{.nodeID}-{.taskRetryAttempt}
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-wandb/tests/test_wandb_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_non_local_execution(wandb_mock, manager_mock, os_mock):

wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="my_execution_id", tags=["my_tag"])
ctx_mock.user_space_params.secrets.get.assert_called_with(key="abc", group="xyz")
assert os_mock.environ["WANDB_API_KEY"] == "this_is_the_secret"
wandb_mock.login.assert_called_with(key="this_is_the_secret", host="https://api.wandb.ai")


def test_errors():
Expand Down Expand Up @@ -113,4 +113,4 @@ def test_secret_callable_remote(wandb_mock, manager_mock, os_mock):
train_model_with_id_callable_secret()

wandb_mock.init.assert_called_with(project="my_project", entity="my_entity", id="1234", tags=["my_tag"])
assert os_mock.environ["WANDB_API_KEY"] == get_secret()
wandb_mock.login.assert_called_with(key=get_secret(), host="https://api.wandb.ai")

0 comments on commit 8480849

Please sign in to comment.