Skip to content

Commit

Permalink
Adding items from todo list, including inheritance and options testing
Browse files Browse the repository at this point in the history
and topo sort
  • Loading branch information
coleifer committed Oct 3, 2012
1 parent e8a068d commit 54fcef2
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 13 deletions.
4 changes: 0 additions & 4 deletions TODO.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
todo
====

* Q() with django syntax
* inheritance test
* model options test
* topo sort
* backwards compat, esp places where existing api allows strings
* stronger input validation?
* docs
32 changes: 30 additions & 2 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def __get__(self, instance, instance_type=None):
class ForeignKeyField(Field):
def __init__(self, rel_model, null=False, related_name=None, cascade=False, extra=None, *args, **kwargs):
self.rel_model = rel_model
self.related_name = related_name
self._related_name = related_name
self.cascade = cascade
self.extra = extra

Expand All @@ -516,7 +516,7 @@ def add_to_class(self, model_class, name):
model_class._meta.fields[self.name] = self
model_class._meta.columns[self.db_column] = self

self.related_name = self.related_name or '%s_set' % (model_class._meta.name)
self.related_name = self._related_name or '%s_set' % (model_class._meta.name)

if self.rel_model == 'self':
self.rel_model = self.model_class
Expand Down Expand Up @@ -1976,3 +1976,31 @@ def __eq__(self, other):

def __ne__(self, other):
return not self == other


def create_model_tables(models, **create_table_kwargs):
"""Create tables for all given models (in the right order)."""
for m in sort_models_topologically(models):
m.create_table(**create_table_kwargs)

def drop_model_tables(models, **drop_table_kwargs):
"""Drop tables for all given models (in the right order)."""
for m in reversed(sort_models_topologically(models)):
m.drop_table(**drop_table_kwargs)

def sort_models_topologically(models):
"""Sort models topologically so that parents will precede children."""
models = set(models)
seen = set()
ordering = []
def dfs(model):
if model in models and model not in seen:
seen.add(model)
for child_model in model._meta.reverse_rel.values():
dfs(child_model)
ordering.append(model) # parent will follow descendants
# order models by name and table initially to guarantee a total ordering
names = lambda m: (m._meta.name, m._meta.db_table)
for m in sorted(models, key=names, reverse=True):
dfs(m)
return list(reversed(ordering)) # want parents first in output ordering
172 changes: 165 additions & 7 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Meta:
db_table = 'users'

class Blog(TestModel):
user = ForeignKeyField(User, related_name='blogs')
user = ForeignKeyField(User)
title = CharField(max_length=25)
content = TextField(default='')
pub_date = DateTimeField(null=True)
Expand Down Expand Up @@ -178,9 +178,13 @@ class Meta:
(('f2', 'f3'), False),
)

class BlogTwo(Blog):
title = TextField()
extra_field = CharField()


MODELS = [User, Blog, Comment, Relationship, NullModel, UniqueModel, OrderedModel, Category, UserCategory,
NonIntModel, NonIntRelModel, DBUser, DBBlog, SeqModelA, SeqModelB, MultiIndexModel]
NonIntModel, NonIntRelModel, DBUser, DBBlog, SeqModelA, SeqModelB, MultiIndexModel, BlogTwo]
INT = test_db.interpolation

def drop_tables(only=None):
Expand Down Expand Up @@ -762,8 +766,8 @@ def test_related_name(self):
b12 = Blog.create(user=u1, title='b12')
b2 = Blog.create(user=u2, title='b2')

self.assertEqual([b.title for b in u1.blogs], ['b11', 'b12'])
self.assertEqual([b.title for b in u2.blogs], ['b2'])
self.assertEqual([b.title for b in u1.blog_set], ['b11', 'b12'])
self.assertEqual([b.title for b in u2.blog_set], ['b2'])

def test_fk_exceptions(self):
c1 = Category.create(name='c1')
Expand Down Expand Up @@ -1408,9 +1412,116 @@ def reader_thread(q, num):
self.assertEqual(data_queue.qsize(), 100)


class ModelInheritanceTestCase(BasePeeweeTestCase):
# TODO
pass
class ModelOptionInheritanceTestCase(BasePeeweeTestCase):
def test_db_table(self):
self.assertEqual(User._meta.db_table, 'users')

class Foo(TestModel):
pass
self.assertEqual(Foo._meta.db_table, 'foo')

class Foo2(TestModel):
pass
self.assertEqual(Foo2._meta.db_table, 'foo2')

class Foo_3(TestModel):
pass
self.assertEqual(Foo_3._meta.db_table, 'foo_3')

def test_option_inheritance(self):
x_test_db = SqliteDatabase('testing.db')
child2_db = SqliteDatabase('child2.db')

class FakeUser(Model):
pass

class ParentModel(Model):
title = CharField()
user = ForeignKeyField(FakeUser)

class Meta:
database = x_test_db

class ChildModel(ParentModel):
pass

class ChildModel2(ParentModel):
special_field = CharField()

class Meta:
database = child2_db

class GrandChildModel(ChildModel):
pass

class GrandChildModel2(ChildModel2):
special_field = TextField()

self.assertEqual(ParentModel._meta.database.database, 'testing.db')
self.assertEqual(ParentModel._meta.model_class, ParentModel)

self.assertEqual(ChildModel._meta.database.database, 'testing.db')
self.assertEqual(ChildModel._meta.model_class, ChildModel)
self.assertEqual(sorted(ChildModel._meta.fields.keys()), [
'id', 'title', 'user'
])

self.assertEqual(ChildModel2._meta.database.database, 'child2.db')
self.assertEqual(ChildModel2._meta.model_class, ChildModel2)
self.assertEqual(sorted(ChildModel2._meta.fields.keys()), [
'id', 'special_field', 'title', 'user'
])

self.assertEqual(GrandChildModel._meta.database.database, 'testing.db')
self.assertEqual(GrandChildModel._meta.model_class, GrandChildModel)
self.assertEqual(sorted(GrandChildModel._meta.fields.keys()), [
'id', 'title', 'user'
])

self.assertEqual(GrandChildModel2._meta.database.database, 'child2.db')
self.assertEqual(GrandChildModel2._meta.model_class, GrandChildModel2)
self.assertEqual(sorted(GrandChildModel2._meta.fields.keys()), [
'id', 'special_field', 'title', 'user'
])
self.assertTrue(isinstance(GrandChildModel2._meta.fields['special_field'], TextField))


class ModelInheritanceTestCase(ModelTestCase):
requires = [Blog, BlogTwo, User]

def test_model_inheritance_attrs(self):
self.assertEqual(Blog._meta.get_field_names(), ['pk', 'user', 'title', 'content', 'pub_date'])
self.assertEqual(BlogTwo._meta.get_field_names(), ['id', 'user', 'content', 'pub_date', 'title', 'extra_field'])

self.assertEqual(Blog._meta.primary_key.name, 'pk')
self.assertEqual(BlogTwo._meta.primary_key.name, 'id')

self.assertEqual(Blog.user.related_name, 'blog_set')
self.assertEqual(BlogTwo.user.related_name, 'blogtwo_set')

self.assertEqual(User.blog_set.rel_model, Blog)
self.assertEqual(User.blogtwo_set.rel_model, BlogTwo)

self.assertFalse(BlogTwo._meta.db_table == Blog._meta.db_table)

def test_model_inheritance_flow(self):
u = User.create(username='u')

b = Blog.create(title='b', user=u)
b2 = BlogTwo.create(title='b2', extra_field='foo', user=u)

self.assertEqual(list(u.blog_set), [b])
self.assertEqual(list(u.blogtwo_set), [b2])

self.assertEqual(Blog.select().count(), 1)
self.assertEqual(BlogTwo.select().count(), 1)

b_from_db = Blog.get(pk=b.pk)
b2_from_db = BlogTwo.get(id=b2.id)

self.assertEqual(b_from_db.user, u)
self.assertEqual(b2_from_db.user, u)
self.assertEqual(b2_from_db.extra_field, 'foo')


class DatabaseTestCase(BasePeeweeTestCase):
Expand Down Expand Up @@ -1449,6 +1560,53 @@ def test_connection_state(self):
self.assertFalse(test_db.is_closed())


class TopologicalSortTestCase(unittest.TestCase):
def test_topological_sort_fundamentals(self):
FKF = ForeignKeyField
# we will be topo-sorting the following models
class A(Model): pass
class B(Model): a = FKF(A) # must follow A
class C(Model): a, b = FKF(A), FKF(B) # must follow A and B
class D(Model): c = FKF(C) # must follow A and B and C
class E(Model): e = FKF('self')
# but excluding this model, which is a child of E
class Excluded(Model): e = FKF(E)

# property 1: output ordering must not depend upon input order
repeatable_ordering = None
for input_ordering in permutations([A, B, C, D, E]):
output_ordering = sort_models_topologically(input_ordering)
repeatable_ordering = repeatable_ordering or output_ordering
self.assertEqual(repeatable_ordering, output_ordering)

# property 2: output ordering must have same models as input
self.assertEqual(len(output_ordering), 5)
self.assertFalse(Excluded in output_ordering)

# property 3: parents must precede children
def assert_precedes(X, Y):
lhs, rhs = map(output_ordering.index, [X, Y])
self.assertTrue(lhs < rhs)
assert_precedes(A, B)
assert_precedes(B, C) # if true, C follows A by transitivity
assert_precedes(C, D) # if true, D follows A and B by transitivity

# property 4: independent model hierarchies must be in name order
assert_precedes(A, E)

def permutations(xs):
if not xs:
yield []
else:
for y, ys in selections(xs):
for pys in permutations(ys):
yield [y] + pys

def selections(xs):
for i in xrange(len(xs)):
yield (xs[i], xs[:i] + xs[i + 1:])


if test_db.for_update:
class ForUpdateTestCase(ModelTestCase):
requires = [User]
Expand Down

0 comments on commit 54fcef2

Please sign in to comment.