Skip to content

Commit

Permalink
Make historical records' m2m fields type-compatible with non-historical
Browse files Browse the repository at this point in the history
Fixes #1186
  • Loading branch information
mjsir911 committed Aug 18, 2023
1 parent b5eadd5 commit 9896652
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 15 deletions.
29 changes: 29 additions & 0 deletions simple_history/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,35 @@ def __get__(self, instance, owner):
return HistoryManager.from_queryset(HistoricalQuerySet)(self.model, instance)


class HistoryManyToManyDescriptor:
def __init__(self, model, rel):
self.rel = rel
self.model = model

def __get__(self, instance, owner):
return HistoryManyRelatedManager.from_queryset(QuerySet)(
self.model, self.rel, instance
)


class HistoryManyRelatedManager(models.Manager):
def __init__(self, through, rel, instance=None):
super().__init__()
self.model = rel.model
self.through = through
self.instance = instance
self._m2m_through_field_name = rel.field.m2m_reverse_field_name()

def get_queryset(self):
qs = super().get_queryset()
through_qs = HistoryManager.from_queryset(HistoricalQuerySet)(
self.through, self.instance
)
return qs.filter(
pk__in=through_qs.all().values_list(self._m2m_through_field_name, flat=True)
)


class HistoryManager(models.Manager):
def __init__(self, model, instance=None):
super().__init__()
Expand Down
8 changes: 6 additions & 2 deletions simple_history/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
from simple_history import utils

from . import exceptions
from .manager import SIMPLE_HISTORY_REVERSE_ATTR_NAME, HistoryDescriptor
from .manager import (
SIMPLE_HISTORY_REVERSE_ATTR_NAME,
HistoryDescriptor,
HistoryManyToManyDescriptor,
)
from .signals import (
post_create_historical_m2m_records,
post_create_historical_record,
Expand Down Expand Up @@ -227,7 +231,7 @@ def finalize(self, sender, **kwargs):

setattr(module, m2m_model.__name__, m2m_model)

m2m_descriptor = HistoryDescriptor(m2m_model)
m2m_descriptor = HistoryManyToManyDescriptor(m2m_model, field.remote_field)
setattr(history_model, field.name, m2m_descriptor)

def get_history_model_name(self, model):
Expand Down
26 changes: 13 additions & 13 deletions simple_history/tests/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1788,12 +1788,12 @@ def test_separation(self):
self.assertEqual(book.restaurants.all().count(), 0)
self.assertEqual(book.books.all().count(), 1)
self.assertEqual(book.places.all().count(), 1)
self.assertEqual(book.books.first().book, self.book)
self.assertEqual(book.books.first(), self.book)

self.assertEqual(place.restaurants.all().count(), 0)
self.assertEqual(place.books.all().count(), 0)
self.assertEqual(place.places.all().count(), 1)
self.assertEqual(place.places.first().place, self.place)
self.assertEqual(place.places.first(), self.place)

self.assertEqual(add.restaurants.all().count(), 0)
self.assertEqual(add.books.all().count(), 0)
Expand Down Expand Up @@ -1829,11 +1829,11 @@ def test_separation(self):

self.assertEqual(book.books.all().count(), 1)
self.assertEqual(book.places.all().count(), 1)
self.assertEqual(book.books.first().book, self.book)
self.assertEqual(book.books.first(), self.book)

self.assertEqual(place.books.all().count(), 0)
self.assertEqual(place.places.all().count(), 1)
self.assertEqual(place.places.first().place, self.place)
self.assertEqual(place.places.first(), self.place)

self.assertEqual(add.books.all().count(), 0)
self.assertEqual(add.places.all().count(), 0)
Expand All @@ -1842,11 +1842,11 @@ def test_separation(self):

self.assertEqual(restaurant.restaurants.all().count(), 1)
self.assertEqual(restaurant.places.all().count(), 1)
self.assertEqual(restaurant.restaurants.first().restaurant, self.restaurant)
self.assertEqual(restaurant.restaurants.first(), self.restaurant)

self.assertEqual(place.restaurants.all().count(), 0)
self.assertEqual(place.places.all().count(), 1)
self.assertEqual(place.places.first().place, self.place)
self.assertEqual(place.places.first(), self.place)

self.assertEqual(add.restaurants.all().count(), 0)
self.assertEqual(add.places.all().count(), 0)
Expand Down Expand Up @@ -1964,7 +1964,7 @@ def test_create(self):

# And the historical place is the correct one
historical_place = m2m_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_remove(self):
# Add and remove a many-to-many child
Expand All @@ -1984,7 +1984,7 @@ def test_remove(self):

# And the previous row still has the correct one
historical_place = previous_m2m_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_clear(self):
# Add some places
Expand Down Expand Up @@ -2036,7 +2036,7 @@ def test_delete_child(self):
# Place instance cannot be created...
historical_place = m2m_record.places.first()
with self.assertRaises(ObjectDoesNotExist):
historical_place.place.id
historical_place.id

# But the values persist
historical_place_values = m2m_record.places.all().values()[0]
Expand Down Expand Up @@ -2066,7 +2066,7 @@ def test_delete_parent(self):

# And it is the correct one
historical_place = prev_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_update_child(self):
self.poll.places.add(self.place)
Expand All @@ -2084,7 +2084,7 @@ def test_update_child(self):
m2m_record = self.poll.history.all()[0]
self.assertEqual(m2m_record.places.count(), 1)
historical_place = m2m_record.places.first()
self.assertEqual(historical_place.place.name, "Updated")
self.assertEqual(historical_place.name, "Updated")

def test_update_parent(self):
self.poll.places.add(self.place)
Expand All @@ -2102,7 +2102,7 @@ def test_update_parent(self):
m2m_record = self.poll.history.all()[0]
self.assertEqual(m2m_record.places.count(), 1)
historical_place = m2m_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_bulk_add_remove(self):
# Add some places
Expand Down Expand Up @@ -2134,7 +2134,7 @@ def test_bulk_add_remove(self):
self.assertEqual(m2m_record.places.all().count(), 1)

historical_place = m2m_record.places.first()
self.assertEqual(historical_place.place, self.place)
self.assertEqual(historical_place, self.place)

def test_m2m_relation(self):
# Ensure only the correct M2Ms are saved and returned for history objects
Expand Down

0 comments on commit 9896652

Please sign in to comment.