ENH: Add API to yield notebooks in bulk. (#30)
Added functions under `pgcontents.query`: one to generate current files
and one to generate remote checkpoints. Also added tests for these new
functions.
nathanwolfe authored Jun 30, 2017
1 parent 488507f commit 5382165
Showing 2 changed files with 322 additions and 2 deletions.
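For orientation, a minimal usage sketch of the new API (a hedged illustration, not part of the diff: the connection URL and password are placeholders, and the printed fields are the keys yielded by `_generate_notebooks` below):

from datetime import datetime, timedelta

from sqlalchemy import create_engine

from pgcontents.crypto import single_password_crypto_factory
from pgcontents.query import generate_checkpoints, generate_files

# Placeholder database URL and password, for illustration only.
engine = create_engine('postgresql://localhost/pgcontents')
crypto_factory = single_password_crypto_factory(u'hunter2')

# Stream every current notebook saved within the last week.
week_ago = datetime.utcnow() - timedelta(days=7)
for nb in generate_files(engine, crypto_factory, min_dt=week_ago):
    print(nb['user_id'], nb['path'], nb['last_modified'])

# Stream all remote checkpoints, regardless of age.
for cp in generate_checkpoints(engine, crypto_factory):
    print(cp['user_id'], cp['path'], cp['last_modified'])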
106 changes: 105 additions & 1 deletion pgcontents/query.py
@@ -18,7 +18,9 @@
from .api_utils import (
    from_api_dirname,
    from_api_filename,
    reads_base64,
    split_api_filepath,
    to_api_path,
)
from .constants import UNLIMITED
from .db_utils import (
@@ -547,6 +549,36 @@ def save_file(db, user_id, path, content, encrypt_func, max_size_bytes):
    return res


def generate_files(engine, crypto_factory, min_dt=None, max_dt=None):
    """
    Create a generator of decrypted files.

    This function selects all current notebooks (optionally restricted to a
    datetime range), decrypts them, and returns a generator yielding dicts,
    each containing a decoded notebook and metadata including the user,
    filepath, and timestamp.

    Parameters
    ----------
    engine : SQLAlchemy.engine
        Engine encapsulating database connections.
    crypto_factory : function[str -> Any]
        A function from user_id to an object providing the interface required
        by PostgresContentsManager.crypto. Its results are used to decrypt
        the selected notebooks.
    min_dt : datetime.datetime, optional
        Minimum last-modified datetime at which a file will be included.
    max_dt : datetime.datetime, optional
        Last-modified datetime at or after which a file will be excluded.
    """
    where_conds = []
    if min_dt is not None:
        where_conds.append(files.c.created_at >= min_dt)
    if max_dt is not None:
        where_conds.append(files.c.created_at < max_dt)
    return _generate_notebooks(files, engine, where_conds, crypto_factory)
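Since the filters are `created_at >= min_dt` and `created_at < max_dt`, the bounds form a half-open interval [min_dt, max_dt), so adjacent date ranges select disjoint sets of rows. A small sketch, reusing the `engine` and `crypto_factory` placeholders from the example above (the month boundaries are arbitrary illustration values):

from datetime import datetime

# Half-open intervals: a row created exactly at the boundary falls in the
# second range only, so the two selections never overlap.
june = list(generate_files(engine, crypto_factory,
                           min_dt=datetime(2017, 6, 1),
                           max_dt=datetime(2017, 7, 1)))
july = list(generate_files(engine, crypto_factory,
                           min_dt=datetime(2017, 7, 1),
                           max_dt=datetime(2017, 8, 1)))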


# =======================================
# Checkpoints (PostgresCheckpoints)
# =======================================
@@ -700,6 +732,79 @@ def purge_remote_checkpoints(db, user_id):
    )


def generate_checkpoints(engine, crypto_factory, min_dt=None, max_dt=None):
    """
    Create a generator of decrypted remote checkpoints.

    This function selects all notebook checkpoints (optionally restricted to
    a datetime range), decrypts them, and returns a generator yielding dicts,
    each containing a decoded notebook and metadata including the user,
    filepath, and timestamp.

    Parameters
    ----------
    engine : SQLAlchemy.engine
        Engine encapsulating database connections.
    crypto_factory : function[str -> Any]
        A function from user_id to an object providing the interface required
        by PostgresContentsManager.crypto. Its results are used to decrypt
        the selected notebooks.
    min_dt : datetime.datetime, optional
        Minimum last-modified datetime at which a checkpoint will be
        included.
    max_dt : datetime.datetime, optional
        Last-modified datetime at or after which a checkpoint will be
        excluded.
    """
    where_conds = []
    if min_dt is not None:
        where_conds.append(remote_checkpoints.c.last_modified >= min_dt)
    if max_dt is not None:
        where_conds.append(remote_checkpoints.c.last_modified < max_dt)
    return _generate_notebooks(remote_checkpoints,
                               engine, where_conds, crypto_factory)


# ====================
# Files or Checkpoints
# ====================
def _generate_notebooks(table, engine, where_conds, crypto_factory):
    """
    See docstrings for `generate_files` and `generate_checkpoints`.

    `where_conds` should be a list of SQLAlchemy expressions, which are used
    as the conditions for WHERE clauses on the SELECT query to the database.
    """
    # Query for notebooks satisfying the conditions.
    query = select([table]).order_by(table.c.user_id)
    for cond in where_conds:
        query = query.where(cond)
    result = engine.execute(query)

    # Decrypt each notebook and yield the result.
    last_user_id = None
    for nb_row in result:
        # Decryption depends on the user. Rows are ordered by user_id, so we
        # only need to build a new decrypt function when the user changes.
        if nb_row['user_id'] != last_user_id:
            decrypt_func = crypto_factory(nb_row['user_id']).decrypt
            last_user_id = nb_row['user_id']

        nb_dict = to_dict_with_content(table.c, nb_row, decrypt_func)
        if table is files:
            # Correct for the files schema differing somewhat from
            # checkpoints: files rows have no last_modified column, and the
            # path is stored split across parent_name and name.
            nb_dict['path'] = nb_dict['parent_name'] + nb_dict['name']
            nb_dict['last_modified'] = nb_dict['created_at']

        # For 'content', we use `reads_base64` directly. If the db content
        # format is changed from base64, the decoding should be changed
        # here as well.
        yield {
            'id': nb_dict['id'],
            'user_id': nb_dict['user_id'],
            'path': to_api_path(nb_dict['path']),
            'last_modified': nb_dict['last_modified'],
            'content': reads_base64(nb_dict['content']),
        }
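An aside on the query construction: successive `.where()` calls on a SQLAlchemy select AND their clauses together, so the loop above builds the same statement as this hypothetical helper (`bounded_query` is illustrative only, not part of the change):

from sqlalchemy import and_, select

def bounded_query(table, column, min_dt, max_dt):
    # Same statement as the .where() loop in _generate_notebooks when both
    # bounds are supplied: one AND-ed WHERE clause, ordered by user.
    return select([table]).where(
        and_(column >= min_dt, column < max_dt)
    ).order_by(table.c.user_id)

# e.g. bounded_query(files, files.c.created_at, a, b) or
# bounded_query(remote_checkpoints, remote_checkpoints.c.last_modified, a, b)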


##########################
# Reencryption Utilities #
##########################
@@ -776,7 +881,6 @@ def reencrypt_user_content(engine,
        # file-reencryption process, but we might not see that checkpoint here,
        # which means that we would never update the content of that checkpoint
        # to the new encryption key.

        logger.info("Re-encrypting files for %s", user_id)
        for (file_id,) in select_file_ids(db, user_id):
            reencrypt_row_content(
218 changes: 217 additions & 1 deletion pgcontents/tests/test_synchronization.py
@@ -10,7 +10,12 @@
from sqlalchemy import create_engine

from pgcontents import PostgresContentsManager
from pgcontents.crypto import FernetEncryption, NoEncryption
from pgcontents.crypto import (
    FernetEncryption,
    NoEncryption,
    single_password_crypto_factory,
)
from pgcontents.query import generate_files, generate_checkpoints
from pgcontents.utils.ipycompat import new_markdown_cell

from .utils import (
@@ -177,3 +182,214 @@ def check_reencryption(old, new):
    # crypto manager.
    unencrypt_all_users(engine, crypto2_factory, logger)
    check_reencryption(manager2, no_crypto_manager)


class TestGenerateNotebooks(TestCase):

    def setUp(self):
        remigrate_test_schema()
        self.db_url = TEST_DB_URL
        self.engine = create_engine(self.db_url)
        encryption_pw = u'foobar'
        self.crypto_factory = single_password_crypto_factory(encryption_pw)

    def tearDown(self):
        clear_test_db()

    def populate_users(self, user_ids):
        """
        Create a `PostgresContentsManager` and notebooks for each user.
        """
        def encrypted_pgmanager(user_id):
            return PostgresContentsManager(
                user_id=user_id,
                db_url=self.db_url,
                crypto=self.crypto_factory(user_id),
                create_user_on_startup=True,
            )
        managers = {user_id: encrypted_pgmanager(user_id)
                    for user_id in user_ids}
        paths = {user_id: populate(managers[user_id]) for user_id in user_ids}
        return (managers, paths)

    def test_generate_files(self):
        """
        Create files for three users; try fetching them using
        `generate_files`.
        """
        user_ids = ['test_generate_files0',
                    'test_generate_files1',
                    'test_generate_files2']
        (managers, paths) = self.populate_users(user_ids)

        def get_file_dt(user_id, idx):
            path = paths[user_id][idx]
            return managers[user_id].get(path, content=False)['last_modified']

        # Find a split datetime midway through each user's list of files.
        split_idx = len(paths[user_ids[0]]) // 2
        split_dts = [get_file_dt(user_id, split_idx) for user_id in user_ids]

        def check_call(kwargs, expect_files_by_user):
            """
            Call `generate_files`; check that all expected files are found,
            with the correct content.
            """
            file_record = {user_id: [] for user_id in expect_files_by_user}
            for result in generate_files(self.engine, self.crypto_factory,
                                         **kwargs):
                manager = managers[result['user_id']]

                # This recreates functionality from
                # `manager._notebook_model_from_db` to match the model
                # returned by `manager.get`.
                nb = result['content']
                manager.mark_trusted_cells(nb, result['path'])

                # Check that the content returned by the pgcontents manager
                # matches that returned by `generate_files`.
                self.assertEqual(nb, manager.get(result['path'])['content'])

                file_record[result['user_id']].append(result['path'])

            # Make sure all files were found.
            for user_id in expect_files_by_user:
                self.assertEqual(sorted(file_record[user_id]),
                                 sorted(expect_files_by_user[user_id]))

        # Expect all files given no `min_dt`/`max_dt`.
        check_call({}, paths)

        # `min_dt` is in the middle of 1's files; we get the latter half of
        # 1's and all of 2's.
        check_call({'min_dt': split_dts[1]},
                   {
                       user_ids[0]: [],
                       user_ids[1]: paths[user_ids[1]][split_idx:],
                       user_ids[2]: paths[user_ids[2]],
                   })

        # `max_dt` is in the middle of 1's files; we get all of 0's and the
        # first half of 1's.
        check_call({'max_dt': split_dts[1]},
                   {
                       user_ids[0]: paths[user_ids[0]],
                       user_ids[1]: paths[user_ids[1]][:split_idx],
                       user_ids[2]: [],
                   })

        # `min_dt` is in the middle of 0's files, cutting off 0's first half;
        # `max_dt` is in the middle of 2's files, cutting off 2's latter half.
        check_call({'min_dt': split_dts[0], 'max_dt': split_dts[2]},
                   {
                       user_ids[0]: paths[user_ids[0]][split_idx:],
                       user_ids[1]: paths[user_ids[1]],
                       user_ids[2]: paths[user_ids[2]][:split_idx],
                   })

    def test_generate_checkpoints(self):
        """
        Create checkpoints in three stages; try fetching them with
        `generate_checkpoints`.
        """
        user_ids = ['test_generate_checkpoints0',
                    'test_generate_checkpoints1',
                    'test_generate_checkpoints2']
        (managers, paths) = self.populate_users(user_ids)

        def update_content(user_id, path, text):
            """
            Add a Markdown cell and save the notebook.

            Returns the new notebook content.
            """
            manager = managers[user_id]
            model = manager.get(path)
            model['content'].cells.append(
                new_markdown_cell(text + ' on path: ' + path)
            )
            manager.save(model, path)
            return manager.get(path)['content']

        # Each of the next three steps creates a checkpoint for each notebook
        # and stores the notebook content in a dict, keyed by the user id,
        # the path, and the datetime of the new checkpoint.

        # Begin by making a checkpoint for the original notebook content.
        beginning_content = {}
        for user_id in user_ids:
            for path in paths[user_id]:
                content = managers[user_id].get(path)['content']
                dt = managers[user_id].create_checkpoint(path)['last_modified']
                beginning_content[user_id, path, dt] = content

        # Update each notebook and make a new checkpoint.
        middle_content = {}
        middle_min_dt = None
        for user_id in user_ids:
            for path in paths[user_id]:
                content = update_content(user_id, path, '1st addition')
                dt = managers[user_id].create_checkpoint(path)['last_modified']
                middle_content[user_id, path, dt] = content
                if middle_min_dt is None:
                    middle_min_dt = dt

        # Update each notebook again and make another checkpoint.
        end_content = {}
        end_min_dt = None
        for user_id in user_ids:
            for path in paths[user_id]:
                content = update_content(user_id, path, '2nd addition')
                dt = managers[user_id].create_checkpoint(path)['last_modified']
                end_content[user_id, path, dt] = content
                if end_min_dt is None:
                    end_min_dt = dt

        def merge_dicts(*args):
            result = {}
            for d in args:
                result.update(d)
            return result

        def check_call(kwargs, expect_checkpoints_content):
            """
            Call `generate_checkpoints`; check that all expected checkpoints
            are found, with the correct content.
            """
            expect_checkpoints = expect_checkpoints_content.keys()
            checkpoint_record = []
            for result in generate_checkpoints(self.engine,
                                               self.crypto_factory, **kwargs):
                manager = managers[result['user_id']]

                # This recreates functionality from
                # `manager._notebook_model_from_db` to match the model
                # returned by `manager.get`.
                nb = result['content']
                manager.mark_trusted_cells(nb, result['path'])

                # Check that the checkpoint content matches what's expected.
                key = (result['user_id'], result['path'],
                       result['last_modified'])
                self.assertEqual(nb, expect_checkpoints_content[key])

                checkpoint_record.append(key)

            # Make sure all checkpoints were found.
            self.assertEqual(sorted(checkpoint_record),
                             sorted(expect_checkpoints))

        # No `min_dt`/`max_dt`.
        check_call({}, merge_dicts(beginning_content,
                                   middle_content, end_content))

        # `min_dt` cuts off `beginning_content` checkpoints.
        check_call({'min_dt': middle_min_dt},
                   merge_dicts(middle_content, end_content))

        # `max_dt` cuts off `end_content` checkpoints.
        check_call({'max_dt': end_min_dt},
                   merge_dicts(beginning_content, middle_content))

        # `min_dt` and `max_dt` together isolate `middle_content`.
        check_call({'min_dt': middle_min_dt, 'max_dt': end_min_dt},
                   middle_content)
