diff --git a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py index 2723b11a17..3d0a4ac894 100644 --- a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py +++ b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py @@ -24,6 +24,7 @@ def __init__( secret: Optional[Union[Secret, Callable]] = None, id: Optional[str] = None, host: str = "https://wandb.ai", + api_host: str = "https://api.wandb.ai", **init_kwargs: dict, ): """Weights and Biases plugin. @@ -35,6 +36,7 @@ def __init__( The callable takes no arguments and returns a string. (Required) id (str, optional): A unique id for this wandb run. host (str, optional): URL to your wandb service. The default is "https://wandb.ai". + api_host (str, optional): URL to your API Host, The default is "https://api.wandb.ai". **init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see [the `wandb.init` docs](https://docs.wandb.ai/ref/python/init) for details. """ @@ -51,6 +53,7 @@ def __init__( self.init_kwargs = init_kwargs self.secret = secret self.host = host + self.api_host = api_host # All kwargs need to be passed up so that the function wrapping works for both # `@wandb_init` and `@wandb_init(...)` @@ -61,6 +64,7 @@ def __init__( secret=secret, id=id, host=host, + api_host=api_host, **init_kwargs, ) @@ -81,7 +85,7 @@ def execute(self, *args, **kwargs): # Get API key with callable wandb_api_key = self.secret() - os.environ["WANDB_API_KEY"] = wandb_api_key + wandb.login(key=wandb_api_key, host=self.api_host) if self.id is None: # The HOSTNAME is set to {.executionName}-{.nodeID}-{.taskRetryAttempt} diff --git a/plugins/flytekit-wandb/tests/test_wandb_init.py b/plugins/flytekit-wandb/tests/test_wandb_init.py index 3520acf01b..664e4a77ac 100644 --- a/plugins/flytekit-wandb/tests/test_wandb_init.py +++ b/plugins/flytekit-wandb/tests/test_wandb_init.py @@ -75,7 +75,7 @@ def test_non_local_execution(wandb_mock, manager_mock, os_mock): wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="my_execution_id", tags=["my_tag"]) ctx_mock.user_space_params.secrets.get.assert_called_with(key="abc", group="xyz") - assert os_mock.environ["WANDB_API_KEY"] == "this_is_the_secret" + wandb_mock.login.assert_called_with(key="this_is_the_secret", host="https://api.wandb.ai") def test_errors(): @@ -113,4 +113,4 @@ def test_secret_callable_remote(wandb_mock, manager_mock, os_mock): train_model_with_id_callable_secret() wandb_mock.init.assert_called_with(project="my_project", entity="my_entity", id="1234", tags=["my_tag"]) - assert os_mock.environ["WANDB_API_KEY"] == get_secret() + wandb_mock.login.assert_called_with(key=get_secret(), host="https://api.wandb.ai")