From 8221b0be8de96d8f9663873a31a85add4acf06fb Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 16 Jan 2024 11:08:36 +0100 Subject: [PATCH] refactor: Add support for auto-refreshing token without refresh token --- requests_oauthlib/oauth2_session.py | 22 +++++++- tests/test_oauth2_session.py | 85 ++++++++++++++++++++++++++++- 2 files changed, 101 insertions(+), 6 deletions(-) diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index c9566a21..0c927020 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -41,6 +41,7 @@ def __init__( client=None, auto_refresh_url=None, auto_refresh_kwargs=None, + auto_refresh_type="refresh_token", scope=None, redirect_uri=None, token=None, @@ -67,6 +68,8 @@ def __init__( your access tokens. :auto_refresh_kwargs: Extra arguments to pass to the refresh token endpoint. + :auto_refresh_type: Type of auto refresh method to use. Must be either + "refresh_token" (default) or "access_token". :token_updater: Method with one argument, token, to be used to update your token database on automatic token refresh. If not set a TokenUpdated warning will be raised when a token @@ -83,6 +86,7 @@ def __init__( self._state = state self.auto_refresh_url = auto_refresh_url self.auto_refresh_kwargs = auto_refresh_kwargs or {} + self.auto_refresh_type = auto_refresh_type self.token_updater = token_updater # Ensure that requests doesn't do any automatic auth. See #278. @@ -481,6 +485,20 @@ def refresh_token( self.token["refresh_token"] = refresh_token return self.token + def update_token(self, auth=None, **kwargs): + if self.auto_refresh_type == "refresh_token": + return self.refresh_token( + self.auto_refresh_url, auth=auth, **kwargs + ) + + if self.auto_refresh_type == "access_token": + return self.fetch_token( + self.auto_refresh_url, + auth=auth, + **dict(kwargs, **self.auto_refresh_kwargs), + ) + raise RuntimeError("Unknown auto_refresh_type: %s" % self.auto_refresh_type) + def request( self, method, @@ -526,9 +544,7 @@ def request( client_id, ) auth = requests.auth.HTTPBasicAuth(client_id, client_secret) - token = self.refresh_token( - self.auto_refresh_url, auth=auth, **kwargs - ) + token = self.update_token(auth=auth, **kwargs) if self.token_updater: log.debug( "Updating token to %s using %s.", token, self.token_updater diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index 2f7b227d..94c3cc20 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -1,12 +1,16 @@ from __future__ import unicode_literals + import json -import time -import tempfile -import shutil import os +import shutil +import tempfile +import time from base64 import b64encode from copy import deepcopy from unittest import TestCase +from unittest.mock import patch, Mock + +from requests import Response, Request try: from unittest import mock @@ -497,6 +501,81 @@ def fake_send(r, **kwargs): sess.fetch_token(url) self.assertTrue(sess.authorized) + @patch("requests.sessions.Session.request") + def test_request_when_auto_refresh_type_is_token_refresh(self, mock_request): + # Auto refresh and auto update + def token_updater(token): + self.assertEqual(token["token_type"], self.token["token_type"]) + self.assertEqual(token["access_token"], self.token["access_token"]) + self.assertEqual(token["refresh_token"], self.token["refresh_token"]) + self.assertEqual(token["expires_in"], self.token["expires_in"]) + self.assertIsNotNone(token["expires_at"]) + self.assertIsNotNone(self.token["expires_at"]) + + expired_token = dict(self.token) + expired_token["expires_at"] = time.time() - 7200 + + mock_response = Mock(spec=Response) + mock_response.status_code = 200 + mock_response.headers = {} + mock_response.text = json.dumps(dict(self.token)) + mock_response.json.return_value = dict(self.token) + mock_request.return_value = mock_response + + for client in self.clients: + sess = OAuth2Session( + client=client, + token=expired_token, + token_updater=token_updater, + auto_refresh_url="https://i.b/refresh" + ) + sess.request( + method="POST", + url="https://example.com", + client_id="someclientid", + client_secret="someclientsecret" + ) + + @patch("requests.sessions.Session.request") + def test_request_when_auto_refresh_type_is_access_token(self, mock_request): + # Auto refresh and auto update + def token_updater(token): + self.assertEqual(token["token_type"], self.token["token_type"]) + self.assertEqual(token["access_token"], self.token["access_token"]) + self.assertEqual(token["refresh_token"], self.token["refresh_token"]) + self.assertEqual(token["expires_in"], self.token["expires_in"]) + self.assertIsNotNone(token["expires_at"]) + self.assertIsNotNone(self.token["expires_at"]) + + expired_token = dict(self.token) + expired_token["expires_at"] = time.time() - 7200 + + mock_response = Mock(spec=Response) + mock_response.request = Mock(spec=Request) + mock_response.request.url = "https://i.b/access" + mock_response.request.headers = {} + mock_response.request.body = {} + mock_response.status_code = 200 + mock_response.headers = {} + mock_response.request = mock_request + mock_response.text = json.dumps(dict(self.token)) + mock_request.return_value = mock_response + + for client in self.clients: + sess = OAuth2Session( + client=client, + token=expired_token, + auto_refresh_type="access_token", + token_updater=token_updater, + auto_refresh_url="https://i.b/access" + ) + sess.request( + method="POST", + url="https://example.com", + username="someclientid", + password="someclientsecret" + ) + class OAuth2SessionNetrcTest(OAuth2SessionTest): """Ensure that there is no magic auth handling.