Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use retries for JWT fetching #886

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 84 additions & 51 deletions qcportal/qcportal/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,18 +282,84 @@ def encoding(self, encoding: str):
enc_headers = {"Content-Type": encoding, "Accept": encoding}
self._req_session.headers.update(enc_headers)

def _get_JWT_token(self) -> None:
def _send_request(self, req: requests.Request, allow_retries: bool = True) -> requests.Response:
"""
Sends a prepared request, optionally retrying on errors

Parameters
----------
prep_req
A prepared request to send
allow_retries
If true, attempts to retry on certain kinds of errors

Returns
-------
:
The response returned from the request
"""

prep_req = self._req_session.prepare_request(req)

if self.debug_requests:
pretty_print_request(prep_req)

if not allow_retries:
ret = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False)

if self.debug_requests:
pretty_print_response(ret)

if ret.is_redirect:
raise RuntimeError("Redirection is not allowed")
return ret

retry_count = 0

try:
ret = self._req_session.post(
self.address + "auth/v1/login",
json={"username": self._username, "password": self._password},
verify=self._verify,
)
while True:
try:
ret = self._req_session.send(
prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False
)
break
except requests.exceptions.SSLError:
raise ConnectionRefusedError(_ssl_error_msg) from None
except (requests.exceptions.ConnectionError, requests.exceptions.ConnectTimeout) as e:
if retry_count >= self.retry_max:
raise

# eg, if jitter fraction is 0.05, then multiply by something on the range 0.95 to 1.05
jitter = random.uniform(1.0 - self.retry_jitter_fraction, 1.0 + self.retry_jitter_fraction)
time_to_wait = self.retry_delay * (self.retry_backoff**retry_count) * jitter

retry_count += 1
self._logger.warning(
f"Connection error for {prep_req.url}: {str(e)} - retrying in {time_to_wait:.2f} seconds "
f"[{retry_count}/{self.retry_max}]"
)
time.sleep(time_to_wait)
except requests.exceptions.SSLError:
raise ConnectionRefusedError(_ssl_error_msg) from None
except requests.exceptions.ConnectionError:
raise ConnectionRefusedError(_connection_error_msg.format(self.address)) from None

if self.debug_requests:
pretty_print_response(ret)

if ret.is_redirect:
raise RuntimeError("Redirection is not allowed")

return ret

def _get_JWT_token(self) -> None:

full_uri = self.address + "auth/v1/login"
json = {"username": self._username, "password": self._password}

req = requests.Request(method="POST", url=full_uri, json=json)
ret = self._send_request(req)

if ret.status_code == 200:
ret_json = ret.json()
self._jwt_refresh_token = ret_json["refresh_token"]
Expand All @@ -318,11 +384,12 @@ def _get_JWT_token(self) -> None:
raise AuthenticationFailure(msg)

def _refresh_JWT_token(self) -> None:
ret = self._req_session.post(
self.address + "auth/v1/refresh",
headers={"Authorization": f"Bearer {self._jwt_refresh_token}"},
verify=self._verify,
)

full_uri = self.address + "auth/v1/refresh"
headers = {"Authorization": f"Bearer {self._jwt_refresh_token}"}

req = requests.Request(method="POST", url=full_uri, headers=headers)
ret = self._send_request(req)

if ret.status_code == 200:
ret_json = ret.json()
Expand Down Expand Up @@ -366,46 +433,8 @@ def _request(
self._refresh_JWT_token()

full_uri = self.address + endpoint

req = requests.Request(method=method.upper(), url=full_uri, data=body, params=url_params)
prep_req = self._req_session.prepare_request(req)

if self.debug_requests:
pretty_print_request(prep_req)

try:
if not allow_retries:
r = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False)
if r.is_redirect:
raise RuntimeError("Redirection is not allowed")
else:
current_retries = 0
while True:
try:
r = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout)
break
except (requests.exceptions.ConnectionError, requests.exceptions.ConnectTimeout) as e:
if current_retries >= self.retry_max:
raise

# eg, if jitter fraction is 0.05, then multiply by something on the range 0.95 to 1.05
jitter = random.uniform(1.0 - self.retry_jitter_fraction, 1.0 + self.retry_jitter_fraction)
time_to_wait = self.retry_delay * (self.retry_backoff**current_retries) * jitter

current_retries += 1
self._logger.warning(
f"Connection failed: {str(e)} - retrying in {time_to_wait:.2f} seconds "
f"[{current_retries}/{self.retry_max}]"
)
time.sleep(time_to_wait)

if self.debug_requests:
pretty_print_response(r)

except requests.exceptions.SSLError:
raise ConnectionRefusedError(_ssl_error_msg) from None
except requests.exceptions.ConnectionError:
raise ConnectionRefusedError(_connection_error_msg.format(self.address)) from None
r = self._send_request(req, allow_retries=allow_retries)

# If JWT token expired, automatically renew it and retry once. This should have been caught above,
# but can happen in rare instances where the token expires between the time we check it and the time
Expand Down Expand Up @@ -529,4 +558,8 @@ def get_server_information(self) -> Dict[str, Any]:
"""

# Request the info, and store here for later use
return self.make_request("get", self._information_endpoint, Dict[str, Any])
# TODO - this fallback is temporary - remove in a future version
try:
return self.make_request("get", self._information_endpoint, Dict[str, Any])
except PortalRequestError as e:
return self.make_request("get", "api/v1/information", Dict[str, Any])