diff --git a/qcportal/qcportal/client_base.py b/qcportal/qcportal/client_base.py index 254e69562..487f231f6 100644 --- a/qcportal/qcportal/client_base.py +++ b/qcportal/qcportal/client_base.py @@ -282,18 +282,84 @@ def encoding(self, encoding: str): enc_headers = {"Content-Type": encoding, "Accept": encoding} self._req_session.headers.update(enc_headers) - def _get_JWT_token(self) -> None: + def _send_request(self, req: requests.Request, allow_retries: bool = True) -> requests.Response: + """ + Sends a prepared request, optionally retrying on errors + + Parameters + ---------- + prep_req + A prepared request to send + allow_retries + If true, attempts to retry on certain kinds of errors + + Returns + ------- + : + The response returned from the request + """ + + prep_req = self._req_session.prepare_request(req) + + if self.debug_requests: + pretty_print_request(prep_req) + + if not allow_retries: + ret = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False) + + if self.debug_requests: + pretty_print_response(ret) + + if ret.is_redirect: + raise RuntimeError("Redirection is not allowed") + return ret + + retry_count = 0 + try: - ret = self._req_session.post( - self.address + "auth/v1/login", - json={"username": self._username, "password": self._password}, - verify=self._verify, - ) + while True: + try: + ret = self._req_session.send( + prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False + ) + break + except requests.exceptions.SSLError: + raise ConnectionRefusedError(_ssl_error_msg) from None + except (requests.exceptions.ConnectionError, requests.exceptions.ConnectTimeout) as e: + if retry_count >= self.retry_max: + raise + + # eg, if jitter fraction is 0.05, then multiply by something on the range 0.95 to 1.05 + jitter = random.uniform(1.0 - self.retry_jitter_fraction, 1.0 + self.retry_jitter_fraction) + time_to_wait = self.retry_delay * (self.retry_backoff**retry_count) * jitter + + retry_count += 1 + self._logger.warning( + f"Connection error for {prep_req.url}: {str(e)} - retrying in {time_to_wait:.2f} seconds " + f"[{retry_count}/{self.retry_max}]" + ) + time.sleep(time_to_wait) except requests.exceptions.SSLError: raise ConnectionRefusedError(_ssl_error_msg) from None except requests.exceptions.ConnectionError: raise ConnectionRefusedError(_connection_error_msg.format(self.address)) from None + if self.debug_requests: + pretty_print_response(ret) + + if ret.is_redirect: + raise RuntimeError("Redirection is not allowed") + + return ret + + def _get_JWT_token(self) -> None: + + full_uri = self.address + "auth/v1/login" + json = {"username": self._username, "password": self._password} + + req = requests.Request(method="POST", url=full_uri, json=json) + ret = self._send_request(req) + if ret.status_code == 200: ret_json = ret.json() self._jwt_refresh_token = ret_json["refresh_token"] @@ -318,11 +384,12 @@ def _get_JWT_token(self) -> None: raise AuthenticationFailure(msg) def _refresh_JWT_token(self) -> None: - ret = self._req_session.post( - self.address + "auth/v1/refresh", - headers={"Authorization": f"Bearer {self._jwt_refresh_token}"}, - verify=self._verify, - ) + + full_uri = self.address + "auth/v1/refresh" + headers = {"Authorization": f"Bearer {self._jwt_refresh_token}"} + + req = requests.Request(method="POST", url=full_uri, headers=headers) + ret = self._send_request(req) if ret.status_code == 200: ret_json = ret.json() @@ -366,46 +433,8 @@ def _request( self._refresh_JWT_token() full_uri = self.address + endpoint - req = requests.Request(method=method.upper(), url=full_uri, data=body, params=url_params) - prep_req = self._req_session.prepare_request(req) - - if self.debug_requests: - pretty_print_request(prep_req) - - try: - if not allow_retries: - r = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False) - if r.is_redirect: - raise RuntimeError("Redirection is not allowed") - else: - current_retries = 0 - while True: - try: - r = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout) - break - except (requests.exceptions.ConnectionError, requests.exceptions.ConnectTimeout) as e: - if current_retries >= self.retry_max: - raise - - # eg, if jitter fraction is 0.05, then multiply by something on the range 0.95 to 1.05 - jitter = random.uniform(1.0 - self.retry_jitter_fraction, 1.0 + self.retry_jitter_fraction) - time_to_wait = self.retry_delay * (self.retry_backoff**current_retries) * jitter - - current_retries += 1 - self._logger.warning( - f"Connection failed: {str(e)} - retrying in {time_to_wait:.2f} seconds " - f"[{current_retries}/{self.retry_max}]" - ) - time.sleep(time_to_wait) - - if self.debug_requests: - pretty_print_response(r) - - except requests.exceptions.SSLError: - raise ConnectionRefusedError(_ssl_error_msg) from None - except requests.exceptions.ConnectionError: - raise ConnectionRefusedError(_connection_error_msg.format(self.address)) from None + r = self._send_request(req, allow_retries=allow_retries) # If JWT token expired, automatically renew it and retry once. This should have been caught above, # but can happen in rare instances where the token expires between the time we check it and the time @@ -529,4 +558,8 @@ def get_server_information(self) -> Dict[str, Any]: """ # Request the info, and store here for later use - return self.make_request("get", self._information_endpoint, Dict[str, Any]) + # TODO - this fallback is temporary - remove in a future version + try: + return self.make_request("get", self._information_endpoint, Dict[str, Any]) + except PortalRequestError as e: + return self.make_request("get", "api/v1/information", Dict[str, Any])