Skip to content

Commit

Permalink
prevent the creation of embedded models
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Jan 19, 2025
1 parent 7dd117f commit 8df7cfe
Show file tree
Hide file tree
Showing 14 changed files with 234 additions and 15 deletions.
11 changes: 11 additions & 0 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,18 @@ def __init__(self, embedded_model, *args, **kwargs):
super().__init__(*args, **kwargs)

def check(self, **kwargs):
from ..models import EmbeddedModel

errors = super().check(**kwargs)
if not issubclass(self.embedded_model, EmbeddedModel):
return [
checks.Error(
"Embedded model must be a subclass of "
"django_mongodb_backend.models.EmbeddedModel.",
obj=self,
id="django_mongodb_backend.embedded_model.E002",
)
]
for field in self.embedded_model._meta.fields:
if field.remote_field:
errors.append(
Expand Down
41 changes: 41 additions & 0 deletions django_mongodb_backend/managers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,48 @@
from django.db import NotSupportedError
from django.db.models.manager import BaseManager

from .queryset import MongoQuerySet


class MongoManager(BaseManager.from_queryset(MongoQuerySet)):
pass


class EmbeddedModelManager(BaseManager):
"""
Prevent all queryset operations on embedded models since they don't have
their own collection.
"""

def get_queryset(self):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def all(self):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def get(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def get_or_create(self, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def filter(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def create(self, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be created.")

def bulk_create(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be created.")

def update(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be updated.")

def bulk_update(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be updated.")

def update_or_create(self, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be updated or created.")

def delete(self):
raise NotSupportedError("EmbeddedModels cannot be deleted.")
16 changes: 16 additions & 0 deletions django_mongodb_backend/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.db import NotSupportedError, models

from .managers import EmbeddedModelManager


class EmbeddedModel(models.Model):
objects = EmbeddedModelManager()

class Meta:
abstract = True

def delete(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be deleted.")

def save(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be saved.")
27 changes: 27 additions & 0 deletions django_mongodb_backend/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@
from .utils import OperationCollector


def ignore_embedded_models(func):
"""Make a SchemaEditor a no-op if model is an EmbeddedModel."""

def wrapper(self, model, *args, **kwargs):
# If parent_model isn't None, this is a valid recursive operation.
parent_model = kwargs.get("parent_model")
from .models import EmbeddedModel

if parent_model is None and issubclass(model, EmbeddedModel):
return
func(self, model, *args, **kwargs)

return wrapper


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def get_collection(self, name):
if self.collect_sql:
Expand All @@ -22,6 +37,7 @@ def get_database(self):
return self.connection.get_database()

@wrap_database_errors
@ignore_embedded_models
def create_model(self, model):
self.get_database().create_collection(model._meta.db_table)
self._create_model_indexes(model)
Expand Down Expand Up @@ -75,13 +91,15 @@ def _create_model_indexes(self, model, column_prefix="", parent_model=None):
for index in model._meta.indexes:
self.add_index(model, index, column_prefix=column_prefix, parent_model=parent_model)

@ignore_embedded_models
def delete_model(self, model):
# Delete implicit M2m tables.
for field in model._meta.local_many_to_many:
if field.remote_field.through._meta.auto_created:
self.delete_model(field.remote_field.through)
self.get_collection(model._meta.db_table).drop()

@ignore_embedded_models
def add_field(self, model, field):
# Create implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created:
Expand All @@ -103,6 +121,7 @@ def add_field(self, model, field):
elif self._field_should_have_unique(field):
self._add_field_unique(model, field)

@ignore_embedded_models
def _alter_field(
self,
model,
Expand Down Expand Up @@ -149,6 +168,7 @@ def _alter_field(
if not old_field_unique and new_field_unique:
self._add_field_unique(model, new_field)

@ignore_embedded_models
def remove_field(self, model, field):
# Remove implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created:
Expand Down Expand Up @@ -210,6 +230,7 @@ def _remove_model_indexes(self, model, column_prefix="", parent_model=None):
for index in model._meta.indexes:
self.remove_index(parent_model or model, index)

@ignore_embedded_models
def alter_index_together(self, model, old_index_together, new_index_together, column_prefix=""):
olds = {tuple(fields) for fields in old_index_together}
news = {tuple(fields) for fields in new_index_together}
Expand All @@ -222,6 +243,7 @@ def alter_index_together(self, model, old_index_together, new_index_together, co
for field_names in news.difference(olds):
self._add_composed_index(model, field_names, column_prefix=column_prefix)

@ignore_embedded_models
def alter_unique_together(
self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None
):
Expand Down Expand Up @@ -249,6 +271,7 @@ def alter_unique_together(
model, constraint, parent_model=parent_model, column_prefix=column_prefix
)

@ignore_embedded_models
def add_index(
self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None
):
Expand Down Expand Up @@ -302,6 +325,7 @@ def _add_field_index(self, model, field, *, column_prefix=""):
index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column])
self.add_index(model, index, field=field, column_prefix=column_prefix)

@ignore_embedded_models
def remove_index(self, model, index):
if index.contains_expressions:
return
Expand Down Expand Up @@ -355,6 +379,7 @@ def _remove_field_index(self, model, field, column_prefix=""):
)
collection.drop_index(index_names[0])

@ignore_embedded_models
def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
condition=constraint.condition,
Expand Down Expand Up @@ -384,6 +409,7 @@ def _add_field_unique(self, model, field, column_prefix=""):
constraint = UniqueConstraint(fields=[field.name], name=name)
self.add_constraint(model, constraint, field=field, column_prefix=column_prefix)

@ignore_embedded_models
def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
condition=constraint.condition,
Expand Down Expand Up @@ -417,6 +443,7 @@ def _remove_field_unique(self, model, field, column_prefix=""):
)
self.get_collection(model._meta.db_table).drop_index(constraint_names[0])

@ignore_embedded_models
def alter_db_table(self, model, old_db_table, new_db_table):
if old_db_table == new_db_table:
return
Expand Down
3 changes: 2 additions & 1 deletion docs/source/embedded-models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ The basics
Let's consider this example::

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

class Customer(models.Model):
name = models.CharField(...)
address = EmbeddedModelField("Address")
...

class Address(models.Model):
class Address(EmbeddedModel):
...
city = models.CharField(...)

Expand Down
7 changes: 5 additions & 2 deletions docs/source/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ Stores a model of type ``embedded_model``.

Specifies the model class to embed. It can be either a concrete model
class or a :ref:`lazy reference <lazy-relationships>` to a model class.
The target model must be a subclass of
``django_mongodb_backend.models.EmbeddedModel``.

The embedded model cannot have relational fields
(:class:`~django.db.models.ForeignKey`,
Expand All @@ -234,11 +236,12 @@ Stores a model of type ``embedded_model``.

from django.db import models
from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

class Address(models.Model):
class Address(EmbeddedModel):
...

class Author(models.Model):
class Author(EmbeddedModel):
address = EmbeddedModelField(Address)

class Book(models.Model):
Expand Down
Empty file added docs/source/models.rst
Empty file.
7 changes: 4 additions & 3 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db import models

from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField
from django_mongodb_backend.models import EmbeddedModel


# ObjectIdField
Expand Down Expand Up @@ -98,19 +99,19 @@ class Holder(models.Model):
data = EmbeddedModelField("Data", null=True, blank=True)


class Data(models.Model):
class Data(EmbeddedModel):
integer = models.IntegerField(db_column="custom_column")
auto_now = models.DateTimeField(auto_now=True)
auto_now_add = models.DateTimeField(auto_now_add=True)


class Address(models.Model):
class Address(EmbeddedModel):
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
zip_code = models.IntegerField(db_index=True)


class Author(models.Model):
class Author(EmbeddedModel):
name = models.CharField(max_length=10)
age = models.IntegerField()
address = EmbeddedModelField(Address)
Expand Down
18 changes: 17 additions & 1 deletion tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.test.utils import isolate_apps

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

from .models import (
Address,
Expand Down Expand Up @@ -108,7 +109,7 @@ def test_nested(self):
@isolate_apps("model_fields_")
class CheckTests(SimpleTestCase):
def test_no_relational_fields(self):
class Target(models.Model):
class Target(EmbeddedModel):
key = models.ForeignKey("MyModel", models.CASCADE)

class MyModel(models.Model):
Expand All @@ -121,3 +122,18 @@ class MyModel(models.Model):
self.assertEqual(
msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)."
)

def test_embedded_model_subclass(self):
class Target(models.Model):
pass

class MyModel(models.Model):
field = EmbeddedModelField(Target)

errors = MyModel().check()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E002")
msg = errors[0].msg
self.assertEqual(
msg, "Embedded model must be a subclass of django_mongodb_backend.models.EmbeddedModel."
)
Empty file added tests/models_/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions tests/models_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django_mongodb_backend.models import EmbeddedModel


class Embed(EmbeddedModel):
pass
59 changes: 59 additions & 0 deletions tests/models_/test_embedded_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from django.db import NotSupportedError
from django.test import SimpleTestCase

from .models import Embed


class TestMethods(SimpleTestCase):
def test_save(self):
e = Embed()
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be saved."):
e.save()

def test_delete(self):
e = Embed()
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be deleted."):
e.delete()


class TestManagerMethods(SimpleTestCase):
def test_all(self):
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."):
Embed.objects.all()

def test_get(self):
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."):
Embed.objects.get()

def test_get_or_create(self):
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."):
Embed.objects.get_or_create()

def test_filter(self):
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."):
Embed.objects.filter(foo="bar")

def test_create(self):
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be created."):
Embed.objects.create()

def test_bulk_create(self):
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be created."):
Embed.objects.bulk_create()

def test_update(self):
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be updated."):
Embed.objects.update(foo="bar")

def test_bulk_update(self):
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be updated."):
Embed.objects.bulk_update()

def test_update_or_create(self):
msg = "EmbeddedModels cannot be updated or created."
with self.assertRaisesMessage(NotSupportedError, msg):
Embed.objects.update_or_create()

def test_delete(self):
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be deleted."):
Embed.objects.delete()
Loading

0 comments on commit 8df7cfe

Please sign in to comment.