diff --git a/morango/models/fields/uuids.py b/morango/models/fields/uuids.py index d268fc47..b16c098b 100644 --- a/morango/models/fields/uuids.py +++ b/morango/models/fields/uuids.py @@ -2,6 +2,7 @@ import uuid from django.db import models + from morango.utils import _assert @@ -9,60 +10,30 @@ def sha2_uuid(*args): return hashlib.sha256("::".join(args).encode("utf-8")).hexdigest()[:32] -class UUIDField(models.CharField): +class UUIDField(models.UUIDField): """ Adaptation of Django's UUIDField, but with 32-char hex representation as Python representation rather than a UUID instance. """ - def __init__(self, *args, **kwargs): - kwargs["max_length"] = 32 - super(UUIDField, self).__init__(*args, **kwargs) - - def prepare_value(self, value): - if isinstance(value, uuid.UUID): - return value.hex - return value - - def deconstruct(self): - name, path, args, kwargs = super(UUIDField, self).deconstruct() - del kwargs["max_length"] - return name, path, args, kwargs - - def get_internal_type(self): - return "UUIDField" - def get_db_prep_value(self, value, connection, prepared=False): if value is None: return None if not isinstance(value, uuid.UUID): - try: - value = uuid.UUID(value) - except AttributeError: - raise TypeError(self.error_messages["invalid"] % {"value": value}) + value = super(UUIDField, self).to_python(value) + + if connection.features.has_native_uuid_field: + return value return value.hex def from_db_value(self, value, expression, connection, context): return self.to_python(value) def to_python(self, value): - if isinstance(value, uuid.UUID): - return value.hex - return value - - def get_default(self): - """ - Returns the default value for this field. - """ - if self.has_default(): - if callable(self.default): - default = self.default() - if isinstance(default, uuid.UUID): - return default.hex - return default - if isinstance(self.default, uuid.UUID): - return self.default.hex - return self.default - return None + value = super(UUIDField, self).to_python(value) + return value.hex if isinstance(value, uuid.UUID) else value + + def value_from_object(self, obj): + return self.to_python(super(UUIDField, self).value_from_object(obj)) class UUIDModelMixin(models.Model): diff --git a/tests/testapp/tests/sync/test_operations.py b/tests/testapp/tests/sync/test_operations.py index d50b8649..c3fa015b 100644 --- a/tests/testapp/tests/sync/test_operations.py +++ b/tests/testapp/tests/sync/test_operations.py @@ -1063,9 +1063,11 @@ def setUp(self): "content_id": uuid.uuid4().hex, } - def serialize_to_store(self, Model, data): + def serialize_to_store(self, Model, data, post_serialization=None): instance = Model(**data) serialized = instance.serialize() + if post_serialization: + serialized.update(post_serialization) Store.objects.create( id=serialized["id"], serialized=json.dumps(serialized), @@ -1078,10 +1080,10 @@ def serialize_to_store(self, Model, data): model_name=instance.morango_model_name, ) - def serialize_all_to_store(self): - self.serialize_to_store(MyUser, self.serialized_user) - self.serialize_to_store(SummaryLog, self.serialized_log1) - self.serialize_to_store(SummaryLog, self.serialized_log2) + def serialize_all_to_store(self, post_serialization=None): + self.serialize_to_store(MyUser, self.serialized_user, post_serialization=getattr(post_serialization, "user", {})) + self.serialize_to_store(SummaryLog, self.serialized_log1, post_serialization=getattr(post_serialization, "log1", {})) + self.serialize_to_store(SummaryLog, self.serialized_log2, post_serialization=getattr(post_serialization, "log2", {})) def assert_deserialization(self, user_deserialized=True, log1_deserialized=True, log2_deserialized=True): assert MyUser.objects.filter(id=self.serialized_user["id"]).exists() == user_deserialized @@ -1121,15 +1123,13 @@ def test_deserialization_with_excessively_long_username(self): def test_deserialization_with_invalid_content_id(self): - self.serialized_log1["content_id"] = "invalid" - - self.serialize_all_to_store() + self.serialize_all_to_store({"log1": {"content_id": "invalid"}}) _deserialize_from_store(self.profile) self.assert_deserialization(log1_deserialized=False) - def test_deserialization_with_invalid_log_user_id(self): + def test_deserialization_with_log_non_existent_user_id(self): self.serialized_log1["user_id"] = uuid.uuid4().hex