Skip to content

Commit

Permalink
ENH: Order _generate_notebooks output by datetime (#33)
Browse files Browse the repository at this point in the history
* ENH: Order `generate_*` output by timestamp

* TST: Make `TestGenerateNotebooks` strict on yielded order

* ENH: Incorporate memoization into `single_password_crypto_factory`

* TST: Add `memoize_single_arg` test
  • Loading branch information
nathanwolfe authored Jul 25, 2017
1 parent 495ab66 commit 99b8455
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 143 deletions.
18 changes: 18 additions & 0 deletions pgcontents/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
import sys
import base64
from functools import wraps

from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -221,8 +222,25 @@ def single_password_crypto_factory(password):
The factory here returns a ``FernetEncryption`` that uses a key derived
from ``password`` and salted with the supplied user_id.
"""
@memoize_single_arg
def factory(user_id):
return FernetEncryption(
Fernet(derive_single_fernet_key(password, user_id))
)
return factory


def memoize_single_arg(f):
    """
    Decorator memoizing a single-argument function.

    The first call with a given argument invokes ``f`` and stores the result;
    later calls with an equal argument return the stored result without
    calling ``f`` again.

    Parameters
    ----------
    f : callable
        Function of exactly one positional argument. The argument is used as
        a dictionary key, so it must be hashable.

    Returns
    -------
    memoized_f : callable
        Wrapper around ``f`` (carrying its metadata via ``functools.wraps``)
        that caches results per argument. The cache lives as long as the
        wrapper does and is never evicted.

    Notes
    -----
    If ``f`` raises, nothing is cached for that argument and the exception
    propagates to the caller unchanged.
    """
    memo = {}

    @wraps(f)
    def memoized_f(arg):
        # Test membership instead of catching KeyError: calling ``f`` inside
        # an ``except KeyError`` handler would implicitly chain any exception
        # raised by ``f`` to the internal cache-miss KeyError, producing a
        # misleading "During handling of the above exception..." traceback.
        if arg not in memo:
            memo[arg] = f(arg)
        return memo[arg]

    return memoized_f
61 changes: 38 additions & 23 deletions pgcontents/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,8 @@ def generate_files(engine, crypto_factory, min_dt=None, max_dt=None):
"""
Create a generator of decrypted files.
Files are yielded in ascending order of their timestamp.
This function selects all current notebooks (optionally, falling within a
datetime range), decrypts them, and returns a generator yielding dicts,
each containing a decoded notebook and metadata including the user,
Expand All @@ -571,12 +573,8 @@ def generate_files(engine, crypto_factory, min_dt=None, max_dt=None):
max_dt : datetime.datetime, optional
Last modified datetime at and after which a file will be excluded.
"""
where_conds = []
if min_dt is not None:
where_conds.append(files.c.created_at >= min_dt)
if max_dt is not None:
where_conds.append(files.c.created_at < max_dt)
return _generate_notebooks(files, engine, where_conds, crypto_factory)
return _generate_notebooks(files, files.c.created_at,
engine, crypto_factory, min_dt, max_dt)


# =======================================
Expand Down Expand Up @@ -736,6 +734,8 @@ def generate_checkpoints(engine, crypto_factory, min_dt=None, max_dt=None):
"""
Create a generator of decrypted remote checkpoints.
Checkpoints are yielded in ascending order of their timestamp.
This function selects all notebook checkpoints (optionally, falling within
a datetime range), decrypts them, and returns a generator yielding dicts,
each containing a decoded notebook and metadata including the user,
Expand All @@ -754,38 +754,53 @@ def generate_checkpoints(engine, crypto_factory, min_dt=None, max_dt=None):
max_dt : datetime.datetime, optional
Last modified datetime at and after which a file will be excluded.
"""
where_conds = []
if min_dt is not None:
where_conds.append(remote_checkpoints.c.last_modified >= min_dt)
if max_dt is not None:
where_conds.append(remote_checkpoints.c.last_modified < max_dt)
return _generate_notebooks(remote_checkpoints,
engine, where_conds, crypto_factory)
remote_checkpoints.c.last_modified,
engine, crypto_factory, min_dt, max_dt)


# ====================
# Files or Checkpoints
# ====================
def _generate_notebooks(table, engine, where_conds, crypto_factory):
def _generate_notebooks(table, timestamp_column,
engine, crypto_factory, min_dt, max_dt):
"""
See docstrings for `generate_files` and `generate_checkpoints`.
`where_conds` should be a list of SQLAlchemy expressions, which are used as
the conditions for WHERE clauses on the SELECT queries to the database.
Parameters
----------
table : SQLAlchemy.Table
Table to fetch notebooks from, `files` or `remote_checkpoints.
timestamp_column : SQLAlchemy.Column
`table`'s column storing timestamps, `created_at` or `last_modified`.
engine : SQLAlchemy.engine
Engine encapsulating database connections.
crypto_factory : function[str -> Any]
A function from user_id to an object providing the interface required
by PostgresContentsManager.crypto. Results of this will be used for
decryption of the selected notebooks.
min_dt : datetime.datetime, optional
Minimum last modified datetime at which a file will be included.
max_dt : datetime.datetime, optional
Last modified datetime at and after which a file will be excluded.
"""
where_conds = []
if min_dt is not None:
where_conds.append(timestamp_column >= min_dt)
if max_dt is not None:
where_conds.append(timestamp_column < max_dt)

# Query for notebooks satisfying the conditions.
query = select([table]).order_by(table.c.user_id)
query = select([table]).order_by(timestamp_column)
for cond in where_conds:
query = query.where(cond)
result = engine.execute(query)

# Decrypt each notebook and yield the result.
last_user_id = None
for nb_row in result:
# The decrypt function depends on the user, so if the user is the same
# then the decrypt function carries over.
if nb_row['user_id'] != last_user_id:
decrypt_func = crypto_factory(nb_row['user_id']).decrypt
last_user_id = nb_row['user_id']
# The decrypt function depends on the user
user_id = nb_row['user_id']
decrypt_func = crypto_factory(user_id).decrypt

nb_dict = to_dict_with_content(table.c, nb_row, decrypt_func)
if table is files:
Expand All @@ -798,7 +813,7 @@ def _generate_notebooks(table, engine, where_conds, crypto_factory):
# here as well.
yield {
'id': nb_dict['id'],
'user_id': nb_dict['user_id'],
'user_id': user_id,
'path': to_api_path(nb_dict['path']),
'last_modified': nb_dict['last_modified'],
'content': reads_base64(nb_dict['content']),
Expand Down
89 changes: 58 additions & 31 deletions pgcontents/tests/test_encryption.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,83 @@
"""
Tests for notebook encryption utilities.
"""
from unittest import TestCase

from cryptography.fernet import Fernet

from ..crypto import (
derive_fallback_fernet_keys,
FallbackCrypto,
FernetEncryption,
memoize_single_arg,
NoEncryption,
single_password_crypto_factory,
)


def test_fernet_derivation():
pws = [u'currentpassword', u'oldpassword', None]
class TestEncryption(TestCase):

def test_fernet_derivation(self):
pws = [u'currentpassword', u'oldpassword', None]

# This must be Unicode, so we use the `u` prefix to support py2.
user_id = u'4e322fa200fffd0001000001'

current_crypto = single_password_crypto_factory(pws[0])(user_id)
old_crypto = single_password_crypto_factory(pws[1])(user_id)

def make_single_key_crypto(key):
if key is None:
return NoEncryption()
return FernetEncryption(Fernet(key.encode('ascii')))

multi_fernet_crypto = FallbackCrypto(
[make_single_key_crypto(k)
for k in derive_fallback_fernet_keys(pws, user_id)]
)

# This must be Unicode, so we use the `u` prefix to support py2.
user_id = u'4e322fa200fffd0001000001'
data = b'ayy lmao'

current_crypto = single_password_crypto_factory(pws[0])(user_id)
old_crypto = single_password_crypto_factory(pws[1])(user_id)
# Data encrypted with the current key.
encrypted_data_current = current_crypto.encrypt(data)
self.assertNotEqual(encrypted_data_current, data)
self.assertEqual(current_crypto.decrypt(encrypted_data_current), data)

def make_single_key_crypto(key):
if key is None:
return NoEncryption()
return FernetEncryption(Fernet(key.encode('ascii')))
# Data encrypted with the old key.
encrypted_data_old = old_crypto.encrypt(data)
self.assertNotEqual(encrypted_data_current, data)
self.assertEqual(old_crypto.decrypt(encrypted_data_old), data)

multi_fernet_crypto = FallbackCrypto(
[make_single_key_crypto(k)
for k in derive_fallback_fernet_keys(pws, user_id)]
)
# The single fernet with the first key should be able to decrypt the
# multi-fernet's encrypted data.
self.assertEqual(
current_crypto.decrypt(multi_fernet_crypto.encrypt(data)),
data
)

data = b'ayy lmao'
# Multi should be able decrypt anything encrypted with either key.
self.assertEqual(multi_fernet_crypto.decrypt(encrypted_data_current),
data)
self.assertEqual(multi_fernet_crypto.decrypt(encrypted_data_old), data)

# Data encrypted with the current key.
encrypted_data_current = current_crypto.encrypt(data)
assert encrypted_data_current != data
assert current_crypto.decrypt(encrypted_data_current) == data
# Unencrypted data should be returned unchanged.
self.assertEqual(multi_fernet_crypto.decrypt(data), data)

# Data encrypted with the old key.
encrypted_data_old = old_crypto.encrypt(data)
assert encrypted_data_current != data
assert old_crypto.decrypt(encrypted_data_old) == data
def test_memoize_single_arg(self):
full_calls = []

# The single fernet with the first key should be able to decrypt the
# multi-fernet's encrypted data.
@memoize_single_arg
def mock_factory(user_id):
full_calls.append(user_id)
return u'crypto' + user_id

assert current_crypto.decrypt(multi_fernet_crypto.encrypt(data)) == data
calls_to_make = [u'1', u'2', u'3', u'2', u'1']
expected_results = [u'crypto' + user_id for user_id in calls_to_make]
expected_full_calls = [u'1', u'2', u'3']

# Multi should be able decrypt anything encrypted with either key.
assert multi_fernet_crypto.decrypt(encrypted_data_current) == data
assert multi_fernet_crypto.decrypt(encrypted_data_old) == data
results = []
for user_id in calls_to_make:
results.append(mock_factory(user_id))

# Unencrypted data should be returned unchanged.
assert multi_fernet_crypto.decrypt(data) == data
self.assertEqual(results, expected_results)
self.assertEqual(full_calls, expected_full_calls)
Loading

0 comments on commit 99b8455

Please sign in to comment.