diff --git a/neptune/internal/backends/hosted_neptune_backend.py b/neptune/internal/backends/hosted_neptune_backend.py index e0991b9ad..1fff2cf1a 100644 --- a/neptune/internal/backends/hosted_neptune_backend.py +++ b/neptune/internal/backends/hosted_neptune_backend.py @@ -46,7 +46,7 @@ from neptune.notebook import Notebook from neptune.oauth import NeptuneAuthenticator from neptune.projects import Project -from neptune.utils import is_float, with_api_exceptions_handler +from neptune.utils import is_float, with_api_exceptions_handler, update_session_proxies _logger = logging.getLogger(__name__) @@ -63,8 +63,8 @@ def __init__(self, api_token=None, proxies=None): ssl_verify = False self._http_client = RequestsClient(ssl_verify=ssl_verify) - if proxies is not None: - self._update_proxies(proxies) + + update_session_proxies(self._http_client.session, proxies) self.backend_swagger_client = self._get_swagger_client('{}/api/backend/swagger.json' .format(self.api_address)) @@ -72,7 +72,7 @@ def __init__(self, api_token=None, proxies=None): self.leaderboard_swagger_client = self._get_swagger_client('{}/api/leaderboard/swagger.json' .format(self.api_address)) - self.authenticator = self._create_authenticator(self.credentials.api_token, ssl_verify) + self.authenticator = self._create_authenticator(self.credentials.api_token, ssl_verify, proxies) self._http_client.authenticator = self.authenticator # This is not a top-level import because of circular dependencies @@ -872,12 +872,6 @@ def _upload_tar_data(self, experiment, api_method, data): response.raise_for_status() return response - def _update_proxies(self, proxies): - try: - self._http_client.session.proxies.update(proxies) - except (TypeError, ValueError): - raise ValueError("Wrong proxies format: {}".format(proxies)) - @with_api_exceptions_handler def _get_swagger_client(self, url): return SwaggerClient.from_url( @@ -892,10 +886,11 @@ def _get_swagger_client(self, url): ) @with_api_exceptions_handler - def _create_authenticator(self, api_token, ssl_verify): + def _create_authenticator(self, api_token, ssl_verify, proxies): return NeptuneAuthenticator( self.backend_swagger_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token).response().result, - ssl_verify + ssl_verify, + proxies ) diff --git a/neptune/oauth.py b/neptune/oauth.py index 120d43725..0afc6cf5c 100644 --- a/neptune/oauth.py +++ b/neptune/oauth.py @@ -21,7 +21,7 @@ from requests.auth import AuthBase from requests_oauthlib import OAuth2Session -from neptune.utils import with_api_exceptions_handler +from neptune.utils import with_api_exceptions_handler, update_session_proxies class NeptuneAuth(AuthBase): @@ -59,7 +59,7 @@ def _refresh_token(self): class NeptuneAuthenticator(Authenticator): - def __init__(self, auth_tokens, ssl_verify): + def __init__(self, auth_tokens, ssl_verify, proxies): super(NeptuneAuthenticator, self).__init__(host='') decoded_json_token = jwt.decode(auth_tokens.accessToken, verify=False) expires_at = decoded_json_token.get(u'exp') @@ -78,6 +78,9 @@ def __init__(self, auth_tokens, ssl_verify): token_updater=_no_token_updater ) session.verify = ssl_verify + + update_session_proxies(session, proxies) + self.auth = NeptuneAuth(session) def matches(self, url): diff --git a/neptune/utils.py b/neptune/utils.py index 2bc680adc..3c9efb120 100644 --- a/neptune/utils.py +++ b/neptune/utils.py @@ -139,6 +139,14 @@ def discover_git_repo_location(): return None +def update_session_proxies(session, proxies): + if proxies is not None: + try: + session.proxies.update(proxies) + except (TypeError, ValueError): + raise ValueError("Wrong proxies format: {}".format(proxies)) + + def get_git_info(repo_path=None): """Retrieve information about git repository. diff --git a/tests/neptune/test_oauth.py b/tests/neptune/test_oauth.py index ecd384bf6..9f2d73199 100644 --- a/tests/neptune/test_oauth.py +++ b/tests/neptune/test_oauth.py @@ -102,7 +102,7 @@ def test_apply_oauth2_session_to_request(self, time_mock, session_mock): session.token = dict() # and - neptune_authenticator = NeptuneAuthenticator(auth_tokens, False) + neptune_authenticator = NeptuneAuthenticator(auth_tokens, False, None) request = a_request() # when