Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a context variable to pass Schema context #2707

Merged
merged 25 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
07c9c18
Add CONTEXT context variable
lafrech Dec 30, 2024
4f4e048
Expose Context context manager
lafrech Dec 30, 2024
a5003b8
Remove Schema.context and Field.context
lafrech Dec 30, 2024
7bb901f
Expose context as field/schema property
lafrech Dec 31, 2024
a454a88
Make current_context a Context class attribute
lafrech Jan 1, 2025
0f285ed
Don't provide None as default context
lafrech Jan 1, 2025
feffc2e
Fix Function field docstring about context
lafrech Jan 1, 2025
6fac57d
Allow passing a default to Context.get
lafrech Jan 1, 2025
1a4eec7
Never pass context to functions in Function field
lafrech Jan 1, 2025
4447c07
Remove utils.get_func_args
lafrech Jan 2, 2025
72de755
Make _CURRENT_CONTEXT a module-level attribute
lafrech Jan 2, 2025
17bd038
Move Context into experimental
lafrech Jan 2, 2025
63abfc1
Add typing to context.py
sloria Jan 2, 2025
c6c4e88
Add tests for decorated processors with context
lafrech Jan 3, 2025
c7d0bca
Merge branch '4.0' into context
lafrech Jan 3, 2025
947de51
Update documentation about removal of context
lafrech Jan 4, 2025
5b10b84
Update versionchanged in docstrings
lafrech Jan 4, 2025
318cae0
Update changelog about Context
lafrech Jan 4, 2025
e638af3
Context: initialize token at __init__
lafrech Jan 4, 2025
5e485cb
Merge branch '4.0' into context
sloria Jan 5, 2025
5604959
Minor edit to upgrading guide
sloria Jan 5, 2025
c89c15a
Add more documentation for Context
sloria Jan 5, 2025
f1cbe27
More complete examples
sloria Jan 5, 2025
63d46aa
Exemplify using type aliases for Context
sloria Jan 5, 2025
5edd8b5
Merge branch '4.0' into context
lafrech Jan 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/marshmallow/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Experimental features.

The features in this subpackage are experimental. Breaking changes may be
introduced in minor marshmallow versions.
"""
25 changes: 25 additions & 0 deletions src/marshmallow/experimental/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Helper API for setting serialization/deserialization context."""

import contextlib
import contextvars
import typing

_T = typing.TypeVar("_T")
_CURRENT_CONTEXT: contextvars.ContextVar = contextvars.ContextVar("context")


class Context(contextlib.AbstractContextManager, typing.Generic[_T]):
def __init__(self, context: _T) -> None:
self.context = context

lafrech marked this conversation as resolved.
Show resolved Hide resolved
def __enter__(self) -> None:
self.token = _CURRENT_CONTEXT.set(self.context)

def __exit__(self, *args, **kwargs) -> None:
_CURRENT_CONTEXT.reset(self.token)

@classmethod
def get(cls, default=...) -> _T:
if default is not ...:
return _CURRENT_CONTEXT.get(default)
return _CURRENT_CONTEXT.get()
29 changes: 4 additions & 25 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,6 @@ def _deserialize(
"""
return value

# Properties

@property
def context(self):
"""The context dictionary for the parent :class:`Schema`."""
return self.parent.context


class Raw(Field):
"""Field that applies no formatting."""
Expand Down Expand Up @@ -498,8 +491,6 @@ def schema(self):
Renamed from `serializer` to `schema`.
"""
if not self._schema:
# Inherit context from parent.
context = getattr(self.parent, "context", {})
if callable(self.nested) and not isinstance(self.nested, type):
nested = self.nested()
else:
Expand All @@ -512,7 +503,6 @@ def schema(self):

if isinstance(nested, SchemaABC):
self._schema = copy.copy(nested)
self._schema.context.update(context)
# Respect only and exclude passed from parent and re-initialize fields
set_class = self._schema.set_class
if self.only is not None:
Expand All @@ -539,7 +529,6 @@ def schema(self):
many=self.many,
only=self.only,
exclude=self.exclude,
context=context,
load_only=self._nested_normalized_option("load_only"),
dump_only=self._nested_normalized_option("dump_only"),
)
Expand Down Expand Up @@ -1909,14 +1898,12 @@ class Function(Field):

:param serialize: A callable from which to retrieve the value.
The function must take a single argument ``obj`` which is the object
to be serialized. It can also optionally take a ``context`` argument,
which is a dictionary of context variables passed to the serializer.
to be serialized.
If no callable is provided then the ```load_only``` flag will be set
to True.
:param deserialize: A callable from which to retrieve the value.
The function must take a single argument ``value`` which is the value
to be deserialized. It can also optionally take a ``context`` argument,
which is a dictionary of context variables passed to the deserializer.
to be deserialized.
If no callable is provided then ```value``` will be passed through
unchanged.

Expand Down Expand Up @@ -1951,21 +1938,13 @@ def __init__(
self.deserialize_func = deserialize and utils.callable_or_raise(deserialize)

def _serialize(self, value, attr, obj, **kwargs):
return self._call_or_raise(self.serialize_func, obj, attr)
return self.serialize_func(obj)

def _deserialize(self, value, attr, data, **kwargs):
if self.deserialize_func:
return self._call_or_raise(self.deserialize_func, value, attr)
return self.deserialize_func(value)
return value

def _call_or_raise(self, func, value, attr):
if len(utils.get_func_args(func)) > 1:
lafrech marked this conversation as resolved.
Show resolved Hide resolved
if self.parent.context is None:
msg = f"No context available for Function field {attr!r}"
raise ValidationError(msg)
return func(value, self.parent.context)
return func(value)


class Constant(Field):
"""A field that (de)serializes to a preset constant. If you only want the
Expand Down
4 changes: 0 additions & 4 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,6 @@ class AlbumSchema(Schema):
delimiters.
:param many: Should be set to `True` if ``obj`` is a collection
so that the object will be serialized to a list.
:param context: Optional context passed to :class:`fields.Method` and
:class:`fields.Function` fields.
:param load_only: Fields to skip during serialization (write-only fields)
:param dump_only: Fields to skip during deserialization (read-only fields)
:param partial: Whether to ignore missing fields and not require
Expand Down Expand Up @@ -346,7 +344,6 @@ def __init__(
only: types.StrSequenceOrSet | None = None,
exclude: types.StrSequenceOrSet = (),
many: bool | None = None,
context: dict | None = None,
load_only: types.StrSequenceOrSet = (),
dump_only: types.StrSequenceOrSet = (),
partial: bool | types.StrSequenceOrSet | None = None,
Expand All @@ -373,7 +370,6 @@ def __init__(
if unknown is None
else validate_unknown_parameter_value(unknown)
)
self.context = context or {}
self._normalize_nested_options()
#: Dictionary mapping field_names -> :class:`Field` objects
self.fields = {} # type: dict[str, ma_fields.Field]
Expand Down
16 changes: 0 additions & 16 deletions src/marshmallow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import datetime as dt
import functools
import inspect
import typing
from collections.abc import Mapping
Expand Down Expand Up @@ -242,21 +241,6 @@ def _signature(func: typing.Callable) -> list[str]:
return list(inspect.signature(func).parameters.keys())


def get_func_args(func: typing.Callable) -> list[str]:
"""Given a callable, return a list of argument names. Handles
`functools.partial` objects and class-based callables.

.. versionchanged:: 3.0.0a1
Do not return bound arguments, eg. ``self``.
"""
if inspect.isfunction(func) or inspect.ismethod(func):
return _signature(func)
if isinstance(func, functools.partial):
return _signature(func.func)
# Callable class
return _signature(func)


def resolve_field_instance(cls_or_instance):
"""Return a Schema instance from a Schema class or instance.

Expand Down
Loading
Loading