From 55decceb7f15a7cbf56e351e7af3e191ea816880 Mon Sep 17 00:00:00 2001 From: Marco Sirabella Date: Fri, 9 Jun 2023 00:38:10 -0700 Subject: [PATCH] Make historical records' m2m fields type-compatible with non-historical Fixes #1186 --- simple_history/manager.py | 29 +++++++++++++++++++++++++++++ simple_history/models.py | 8 ++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/simple_history/manager.py b/simple_history/manager.py index e91b2491b..48549561d 100644 --- a/simple_history/manager.py +++ b/simple_history/manager.py @@ -127,6 +127,35 @@ def __get__(self, instance, owner): return HistoryManager.from_queryset(HistoricalQuerySet)(self.model, instance) +class HistoryManyToManyDescriptor: + def __init__(self, model, rel): + self.rel = rel + self.model = model + + def __get__(self, instance, owner): + return HistoryManyRelatedManager.from_queryset(QuerySet)( + self.model, self.rel, instance + ) + + +class HistoryManyRelatedManager(models.Manager): + def __init__(self, through, rel, instance=None): + super().__init__() + self.model = rel.model + self.through = through + self.instance = instance + self._m2m_through_field_name = rel.field.m2m_reverse_field_name() + + def get_queryset(self): + qs = super().get_queryset() + through_qs = HistoryManager.from_queryset(HistoricalQuerySet)( + self.through, self.instance + ) + return qs.filter( + pk__in=through_qs.all().values_list(self._m2m_through_field_name, flat=True) + ) + + class HistoryManager(models.Manager): def __init__(self, model, instance=None): super().__init__() diff --git a/simple_history/models.py b/simple_history/models.py index db19c66fa..6900e3c1c 100644 --- a/simple_history/models.py +++ b/simple_history/models.py @@ -31,7 +31,11 @@ from simple_history import utils from . import exceptions -from .manager import SIMPLE_HISTORY_REVERSE_ATTR_NAME, HistoryDescriptor +from .manager import ( + SIMPLE_HISTORY_REVERSE_ATTR_NAME, + HistoryDescriptor, + HistoryManyToManyDescriptor, +) from .signals import ( post_create_historical_m2m_records, post_create_historical_record, @@ -227,7 +231,7 @@ def finalize(self, sender, **kwargs): setattr(module, m2m_model.__name__, m2m_model) - m2m_descriptor = HistoryDescriptor(m2m_model) + m2m_descriptor = HistoryManyToManyDescriptor(m2m_model, field.remote_field) setattr(history_model, field.name, m2m_descriptor) def get_history_model_name(self, model):