Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make historical records' m2m fields type-compatible with non-historical #1187

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions simple_history/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,35 @@ def __get__(self, instance, owner):
return HistoryManager.from_queryset(HistoricalQuerySet)(self.model, instance)


class HistoryManyToManyDescriptor:
def __init__(self, model, rel):
self.rel = rel
self.model = model

def __get__(self, instance, owner):
return HistoryManyRelatedManager.from_queryset(QuerySet)(
self.model, self.rel, instance
)


class HistoryManyRelatedManager(models.Manager):
def __init__(self, through, rel, instance=None):
super().__init__()
self.model = rel.model
self.through = through
self.instance = instance
self._m2m_through_field_name = rel.field.m2m_reverse_field_name()

def get_queryset(self):
qs = super().get_queryset()
through_qs = HistoryManager.from_queryset(HistoricalQuerySet)(
self.through, self.instance
)
return qs.filter(
pk__in=through_qs.all().values_list(self._m2m_through_field_name, flat=True)
)


class HistoryManager(models.Manager):
def __init__(self, model, instance=None):
super().__init__()
Expand Down
8 changes: 6 additions & 2 deletions simple_history/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
from simple_history import utils

from . import exceptions
from .manager import SIMPLE_HISTORY_REVERSE_ATTR_NAME, HistoryDescriptor
from .manager import (
SIMPLE_HISTORY_REVERSE_ATTR_NAME,
HistoryDescriptor,
HistoryManyToManyDescriptor,
)
from .signals import (
post_create_historical_m2m_records,
post_create_historical_record,
Expand Down Expand Up @@ -227,7 +231,7 @@ def finalize(self, sender, **kwargs):

setattr(module, m2m_model.__name__, m2m_model)

m2m_descriptor = HistoryDescriptor(m2m_model)
m2m_descriptor = HistoryManyToManyDescriptor(m2m_model, field.remote_field)
setattr(history_model, field.name, m2m_descriptor)

def get_history_model_name(self, model):
Expand Down
26 changes: 13 additions & 13 deletions simple_history/tests/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1806,12 +1806,12 @@ def test_separation(self):
self.assertEqual(book.restaurants.all().count(), 0)
self.assertEqual(book.books.all().count(), 1)
self.assertEqual(book.places.all().count(), 1)
self.assertEqual(book.books.first().book, self.book)
self.assertEqual(book.books.first(), self.book)

self.assertEqual(place.restaurants.all().count(), 0)
self.assertEqual(place.books.all().count(), 0)
self.assertEqual(place.places.all().count(), 1)
self.assertEqual(place.places.first().place, self.place)
self.assertEqual(place.places.first(), self.place)

self.assertEqual(add.restaurants.all().count(), 0)
self.assertEqual(add.books.all().count(), 0)
Expand Down Expand Up @@ -1847,11 +1847,11 @@ def test_separation(self):

self.assertEqual(book.books.all().count(), 1)
self.assertEqual(book.places.all().count(), 1)
self.assertEqual(book.books.first().book, self.book)
self.assertEqual(book.books.first(), self.book)

self.assertEqual(place.books.all().count(), 0)
self.assertEqual(place.places.all().count(), 1)
self.assertEqual(place.places.first().place, self.place)
self.assertEqual(place.places.first(), self.place)

self.assertEqual(add.books.all().count(), 0)
self.assertEqual(add.places.all().count(), 0)
Expand All @@ -1860,11 +1860,11 @@ def test_separation(self):

self.assertEqual(restaurant.restaurants.all().count(), 1)
self.assertEqual(restaurant.places.all().count(), 1)
self.assertEqual(restaurant.restaurants.first().restaurant, self.restaurant)
self.assertEqual(restaurant.restaurants.first(), self.restaurant)

self.assertEqual(place.restaurants.all().count(), 0)
self.assertEqual(place.places.all().count(), 1)
self.assertEqual(place.places.first().place, self.place)
self.assertEqual(place.places.first(), self.place)

self.assertEqual(add.restaurants.all().count(), 0)
self.assertEqual(add.places.all().count(), 0)
Expand Down Expand Up @@ -1982,7 +1982,7 @@ def test_create(self):

# And the historical place is the correct one
historical_place = m2m_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_remove(self):
# Add and remove a many-to-many child
Expand All @@ -2002,7 +2002,7 @@ def test_remove(self):

# And the previous row still has the correct one
historical_place = previous_m2m_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_clear(self):
# Add some places
Expand Down Expand Up @@ -2054,7 +2054,7 @@ def test_delete_child(self):
# Place instance cannot be created...
historical_place = m2m_record.places.first()
with self.assertRaises(ObjectDoesNotExist):
historical_place.place.id
historical_place.id

# But the values persist
historical_place_values = m2m_record.places.all().values()[0]
Expand Down Expand Up @@ -2084,7 +2084,7 @@ def test_delete_parent(self):

# And it is the correct one
historical_place = prev_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_update_child(self):
self.poll.places.add(self.place)
Expand All @@ -2102,7 +2102,7 @@ def test_update_child(self):
m2m_record = self.poll.history.all()[0]
self.assertEqual(m2m_record.places.count(), 1)
historical_place = m2m_record.places.first()
self.assertEqual(historical_place.place.name, "Updated")
self.assertEqual(historical_place.name, "Updated")

def test_update_parent(self):
self.poll.places.add(self.place)
Expand All @@ -2120,7 +2120,7 @@ def test_update_parent(self):
m2m_record = self.poll.history.all()[0]
self.assertEqual(m2m_record.places.count(), 1)
historical_place = m2m_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_bulk_add_remove(self):
# Add some places
Expand Down Expand Up @@ -2152,7 +2152,7 @@ def test_bulk_add_remove(self):
self.assertEqual(m2m_record.places.all().count(), 1)

historical_place = m2m_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_m2m_relation(self):
# Ensure only the correct M2Ms are saved and returned for history objects
Expand Down