Skip to content
This repository has been archived by the owner on Jun 22, 2023. It is now read-only.

Commit

Permalink
Support redirect_state in OAuth1 backends too (enable twitter by defa…
Browse files Browse the repository at this point in the history
  • Loading branch information
omab committed Aug 2, 2014
1 parent b9b0250 commit 5aed7c0
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 64 deletions.
125 changes: 71 additions & 54 deletions social/backends/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class OAuthAuth(BaseAuth):
ACCESS_TOKEN_METHOD = 'GET'
REVOKE_TOKEN_URL = None
REVOKE_TOKEN_METHOD = 'POST'
REDIRECT_STATE = False
STATE_PARAMETER = False

def extra_data(self, user, uid, response, details=None):
"""Return access_token and extra defined names to store in
Expand All @@ -36,6 +38,59 @@ def extra_data(self, user, uid, response, details=None):
data['access_token'] = response.get('access_token', '')
return data

def state_token(self):
"""Generate csrf token to include as state parameter."""
return self.strategy.random_string(32)

def get_or_create_state(self):
if self.STATE_PARAMETER or self.REDIRECT_STATE:
# Store state in session for further request validation. The state
# value is passed as state parameter (as specified in OAuth2 spec),
# but also added to redirect, that way we can still verify the
# request if the provider doesn't implement the state parameter.
# Reuse token if any.
name = self.name + '_state'
state = self.strategy.session_get(name)
if state is None:
state = self.state_token()
self.strategy.session_set(name, state)
else:
state = None
return state

def get_session_state(self):
return self.strategy.session_get(self.name + '_state')

def get_request_state(self):
request_state = self.data.get('state') or \
self.data.get('redirect_state')
if request_state and isinstance(request_state, list):
request_state = request_state[0]
return request_state

def validate_state(self):
"""Validate state value. Raises exception on error, returns state
value if valid."""
if not self.STATE_PARAMETER and not self.REDIRECT_STATE:
return None
state = self.get_session_state()
request_state = self.get_request_state()
if not request_state:
raise AuthMissingParameter(self, 'state')
elif not state:
raise AuthStateMissing(self, 'state')
elif not request_state == state:
raise AuthStateForbidden(self)
else:
return state

def get_redirect_uri(self, state=None):
"""Build redirect with redirect_state parameter."""
uri = self.redirect_uri
if self.REDIRECT_STATE and state:
uri = url_add_parameters(uri, {'redirect_state': state})
return uri

def get_scope(self):
"""Return list with needed access scope"""
scope = self.setting('SCOPE', [])
Expand Down Expand Up @@ -109,6 +164,7 @@ def auth_complete(self, *args, **kwargs):
"""Return user, might be logged in"""
# Multiple unauthorized tokens are supported (see #521)
self.process_error(self.data)
self.validate_state()
token = self.get_unauthorized_token()
try:
access_token = self.access_token(token)
Expand Down Expand Up @@ -169,12 +225,14 @@ def unauthorized_token(self):
# decoding='utf-8' produces errors with python-requests on Python3
# since the final URL will be of type bytes
decoding = None if six.PY3 else 'utf-8'
response = self.request(self.REQUEST_TOKEN_URL,
params=params,
auth=OAuth1(key, secret,
callback_uri=self.redirect_uri,
decoding=decoding),
method=self.REQUEST_TOKEN_METHOD)
state = self.get_or_create_state()
response = self.request(
self.REQUEST_TOKEN_URL,
params=params,
auth=OAuth1(key, secret, callback_uri=self.get_redirect_uri(state),
decoding=decoding),
method=self.REQUEST_TOKEN_METHOD
)
content = response.content
if response.encoding or response.apparent_encoding:
content = content.decode(response.encoding or
Expand All @@ -192,7 +250,8 @@ def oauth_authorization_request(self, token):
params[self.OAUTH_TOKEN_PARAMETER_NAME] = token.get(
self.OAUTH_TOKEN_PARAMETER_NAME
)
params[self.REDIRECT_URI_PARAMETER_NAME] = self.redirect_uri
state = self.get_or_create_state()
params[self.REDIRECT_URI_PARAMETER_NAME] = self.get_redirect_uri(state)
return self.AUTHORIZATION_URL + '?' + urlencode(params)

def oauth_auth(self, token=None, oauth_verifier=None,
Expand All @@ -203,10 +262,11 @@ def oauth_auth(self, token=None, oauth_verifier=None,
# decoding='utf-8' produces errors with python-requests on Python3
# since the final URL will be of type bytes
decoding = None if six.PY3 else 'utf-8'
state = self.get_or_create_state()
return OAuth1(key, secret,
resource_owner_key=token.get('oauth_token'),
resource_owner_secret=token.get('oauth_token_secret'),
callback_uri=self.redirect_uri,
callback_uri=self.get_redirect_uri(state),
verifier=oauth_verifier,
signature_type=signature_type,
decoding=decoding)
Expand Down Expand Up @@ -241,17 +301,6 @@ class BaseOAuth2(OAuthAuth):
REDIRECT_STATE = True
STATE_PARAMETER = True

def state_token(self):
"""Generate csrf token to include as state parameter."""
return self.strategy.random_string(32)

def get_redirect_uri(self, state=None):
"""Build redirect with redirect_state parameter."""
uri = self.redirect_uri
if self.REDIRECT_STATE and state:
uri = url_add_parameters(uri, {'redirect_state': state})
return uri

def auth_params(self, state=None):
client_id, client_secret = self.get_key_and_secret()
params = {
Expand All @@ -266,20 +315,7 @@ def auth_params(self, state=None):

def auth_url(self):
"""Return redirect url"""
if self.STATE_PARAMETER or self.REDIRECT_STATE:
# Store state in session for further request validation. The state
# value is passed as state parameter (as specified in OAuth2 spec),
# but also added to redirect, that way we can still verify the
# request if the provider doesn't implement the state parameter.
# Reuse token if any.
name = self.name + '_state'
state = self.strategy.session_get(name)
if state is None:
state = self.state_token()
self.strategy.session_set(name, state)
else:
state = None

state = self.get_or_create_state()
params = self.auth_params(state)
params.update(self.get_scope_argument())
params.update(self.auth_extra_arguments())
Expand All @@ -290,26 +326,6 @@ def auth_url(self):
params = unquote(params)
return self.AUTHORIZATION_URL + '?' + params

def validate_state(self):
"""Validate state value. Raises exception on error, returns state
value if valid."""
if not self.STATE_PARAMETER and not self.REDIRECT_STATE:
return None
state = self.strategy.session_get(self.name + '_state')
request_state = self.data.get('state') or \
self.data.get('redirect_state')
if request_state and isinstance(request_state, list):
request_state = request_state[0]

if not request_state:
raise AuthMissingParameter(self, 'state')
elif not state:
raise AuthStateMissing(self, 'state')
elif not request_state == state:
raise AuthStateForbidden(self)
else:
return state

def auth_complete_params(self, state=None):
client_id, client_secret = self.get_key_and_secret()
return {
Expand Down Expand Up @@ -338,11 +354,12 @@ def process_error(self, data):

def auth_complete(self, *args, **kwargs):
"""Completes loging process, must return user instance"""
state = self.validate_state()
self.process_error(self.data)
try:
response = self.request_access_token(
self.ACCESS_TOKEN_URL,
data=self.auth_complete_params(self.validate_state()),
data=self.auth_complete_params(state),
headers=self.auth_headers(),
method=self.ACCESS_TOKEN_METHOD
)
Expand Down
1 change: 1 addition & 0 deletions social/backends/twitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TwitterOAuth(BaseOAuth1):
AUTHORIZATION_URL = 'https://api.twitter.com/oauth/authenticate'
REQUEST_TOKEN_URL = 'https://api.twitter.com/oauth/request_token'
ACCESS_TOKEN_URL = 'https://api.twitter.com/oauth/access_token'
REDIRECT_STATE = True

def process_error(self, data):
if 'denied' in data:
Expand Down
26 changes: 16 additions & 10 deletions social/tests/backends/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from httpretty import HTTPretty

from social.p3 import urlparse
from social.utils import parse_qs
from social.utils import parse_qs, url_add_parameters

from social.tests.models import User
from social.tests.backends.base import BaseBackendTest
Expand All @@ -29,15 +29,21 @@ def _method(self, method):
'POST': HTTPretty.POST}[method]

def handle_state(self, start_url, target_url):
try:
if self.backend.STATE_PARAMETER or self.backend.REDIRECT_STATE:
query = parse_qs(urlparse(start_url).query)
target_url = target_url + ('?' in target_url and '&' or '?')
if 'state' in query or 'redirect_state' in query:
name = 'state' in query and 'state' or 'redirect_state'
target_url += '{0}={1}'.format(name, query[name])
except AttributeError:
pass
start_query = parse_qs(urlparse(start_url).query)
redirect_uri = start_query.get('redirect_uri')

if getattr(self.backend, 'STATE_PARAMETER', False):
if start_query.get('state'):
target_url = url_add_parameters(target_url, {
'state': start_query['state']
})

if redirect_uri and getattr(self.backend, 'REDIRECT_STATE', False):
redirect_query = parse_qs(urlparse(redirect_uri).query)
if redirect_query.get('redirect_state'):
target_url = url_add_parameters(target_url, {
'redirect_state': redirect_query['redirect_state']
})
return target_url

def auth_handlers(self, start_url):
Expand Down

0 comments on commit 5aed7c0

Please sign in to comment.