Skip to content

Commit

Permalink
Merge pull request #40 from skystrife/develop
Browse files Browse the repository at this point in the history
fix(dev): Use requests.Session object for updating cookies
  • Loading branch information
hfaran authored Sep 6, 2018
2 parents 0976a43 + 47efca0 commit 908f194
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
7 changes: 4 additions & 3 deletions piazza_api/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ class Network(object):
"""Abstraction for a Piazza "Network" (or class)
:param network_id: ID of the network
:param cookies: RequestsCookieJar containing cookies used for authentication
:param session: requests.Session object containing cookies used for
authentication
"""
def __init__(self, network_id, cookies):
def __init__(self, network_id, session):
self._nid = network_id
self._rpc = PiazzaRPC(network_id=self._nid)
self._rpc.cookies = cookies
self._rpc.session = session

ff = namedtuple('FeedFilters', ['unread', 'following', 'folder'])
self._feed_filters = ff(UnreadFilter, FollowingFilter, FolderFilter)
Expand Down
2 changes: 1 addition & 1 deletion piazza_api/piazza.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def network(self, network_id):
https://piazza.com/class/{network_id}
"""
self._ensure_authenticated()
return Network(network_id, self._rpc_api.cookies)
return Network(network_id, self._rpc_api.session)

def get_user_profile(self):
"""Get profile of the current user
Expand Down
21 changes: 9 additions & 12 deletions piazza_api/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, network_id=None):
"logic": "https://piazza.com/logic/api",
"main": "https://piazza.com/main/api",
}
self.cookies = None
self.session = requests.Session()

def user_login(self, email=None, password=None):
"""Login with email, password and get back a session cookie
Expand All @@ -51,14 +51,13 @@ def user_login(self, email=None, password=None):
}
# If the user/password match, the server respond will contain a
# session cookie that you can use to authenticate future requests.
r = requests.post(
r = self.session.post(
self.base_api_urls["logic"],
data=json.dumps(login_data),
)
if r.json()["result"] not in ["OK"]:
raise AuthenticationError("Could not authenticate.\n{}"
.format(r.json()))
self.cookies = r.cookies

def demo_login(self, auth=None, url=None):
"""Authenticate with a "Share Your Class" URL using a demo user.
Expand All @@ -76,10 +75,9 @@ def demo_login(self, auth=None, url=None):
if url is None:
url = "https://piazza.com/demo_login"
params = dict(nid=self._nid, auth=auth)
res = requests.get(url, params=params)
res = self.session.get(url, params=params)
else:
res = requests.get(url)
self.cookies = res.cookies
res = self.session.get(url)

def content_get(self, cid, nid=None):
"""Get data from post `cid` in network `nid`
Expand Down Expand Up @@ -379,9 +377,9 @@ def request(self, method, data=None, nid=None, nid_key='nid',
data = {}

headers = {}
if "session_id" in self.cookies:
headers["CSRF-Token"] = self.cookies["session_id"]
if "session_id" in self.session.cookies:
headers["CSRF-Token"] = self.session.cookies["session_id"]

# Adding a nonce to the request
endpoint = self.base_api_urls[api_type]
if api_type == "logic":
Expand All @@ -390,13 +388,12 @@ def request(self, method, data=None, nid=None, nid_key='nid',
_piazza_nonce()
)

response = requests.post(
response = self.session.post(
endpoint,
data=json.dumps({
"method": method,
"params": dict({nid_key: nid}, **data)
}),
cookies=self.cookies,
headers=headers
)
return response if return_response else response.json()
Expand All @@ -410,7 +407,7 @@ def _check_authenticated(self):
:raises: NotAuthenticatedError
"""
if self.cookies is None:
if not self.session.cookies:
raise NotAuthenticatedError("You must authenticate before "
"making any other requests.")

Expand Down

0 comments on commit 908f194

Please sign in to comment.