Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load accepts Sequence rather than Iterable (rejects generators) #2795

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
ValidationError,
_FieldInstanceResolutionError,
)
from marshmallow.utils import is_aware, is_collection
from marshmallow.validate import And, Length

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -501,9 +500,9 @@ def __init__(
**kwargs: Unpack[_BaseFieldKwargs],
):
# Raise error if only or exclude is passed as string, not list of strings
if only is not None and not is_collection(only):
if only is not None and not utils.is_sequence_but_not_string(only):
raise StringNotCollectionError('"only" should be a collection of strings.')
if not is_collection(exclude):
if not utils.is_sequence_but_not_string(exclude):
raise StringNotCollectionError(
'"exclude" should be a collection of strings.'
)
Expand Down Expand Up @@ -818,7 +817,7 @@ def _deserialize(
data: typing.Mapping[str, typing.Any] | None,
**kwargs,
) -> tuple:
if not utils.is_collection(value):
if not utils.is_sequence_but_not_string(value):
raise self.make_error("invalid")

self.validate_length(value)
Expand Down Expand Up @@ -1322,7 +1321,7 @@ def __init__(

def _deserialize(self, value, attr, data, **kwargs) -> dt.datetime:
ret = super()._deserialize(value, attr, data, **kwargs)
if is_aware(ret):
if utils.is_aware(ret):
if self.timezone is None:
raise self.make_error(
"invalid_awareness",
Expand Down Expand Up @@ -1359,7 +1358,7 @@ def __init__(

def _deserialize(self, value, attr, data, **kwargs) -> dt.datetime:
ret = super()._deserialize(value, attr, data, **kwargs)
if not is_aware(ret):
if not utils.is_aware(ret):
if self.default_timezone is None:
raise self.make_error(
"invalid_awareness",
Expand Down
37 changes: 15 additions & 22 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import uuid
from abc import ABCMeta
from collections import defaultdict
from collections.abc import Mapping
from collections.abc import Mapping, Sequence

from marshmallow import class_registry, types
from marshmallow import fields as ma_fields
Expand All @@ -30,7 +30,12 @@
from marshmallow.error_store import ErrorStore
from marshmallow.exceptions import SCHEMA, StringNotCollectionError, ValidationError
from marshmallow.orderedset import OrderedSet
from marshmallow.utils import get_value, is_collection, set_value
from marshmallow.utils import (
get_value,
is_collection,
is_sequence_but_not_string,
set_value,
)

if typing.TYPE_CHECKING:
from marshmallow.fields import Field
Expand Down Expand Up @@ -582,10 +587,7 @@ def dumps(self, obj: typing.Any, *args, many: bool | None = None, **kwargs):

def _deserialize(
self,
data: (
typing.Mapping[str, typing.Any]
| typing.Iterable[typing.Mapping[str, typing.Any]]
),
data: Mapping[str, typing.Any] | Sequence[Mapping[str, typing.Any]],
*,
error_store: ErrorStore,
many: bool = False,
Expand All @@ -612,13 +614,13 @@ def _deserialize(
index_errors = self.opts.index_errors
index = index if index_errors else None
if many:
if not is_collection(data):
if not is_sequence_but_not_string(data):
error_store.store_error([self.error_messages["type"]], index=index)
ret_l = []
else:
ret_l = [
self._deserialize(
typing.cast(dict, d),
d,
error_store=error_store,
many=False,
partial=partial,
Expand Down Expand Up @@ -696,10 +698,7 @@ def getter(

def load(
self,
data: (
typing.Mapping[str, typing.Any]
| typing.Iterable[typing.Mapping[str, typing.Any]]
),
data: Mapping[str, typing.Any] | Sequence[Mapping[str, typing.Any]],
*,
many: bool | None = None,
partial: bool | types.StrSequenceOrSet | None = None,
Expand Down Expand Up @@ -807,10 +806,7 @@ def _run_validator(

def validate(
self,
data: (
typing.Mapping[str, typing.Any]
| typing.Iterable[typing.Mapping[str, typing.Any]]
),
data: Mapping[str, typing.Any] | Sequence[Mapping[str, typing.Any]],
*,
many: bool | None = None,
partial: bool | types.StrSequenceOrSet | None = None,
Expand All @@ -837,10 +833,7 @@ def validate(

def _do_load(
self,
data: (
typing.Mapping[str, typing.Any]
| typing.Iterable[typing.Mapping[str, typing.Any]]
),
data: (Mapping[str, typing.Any] | Sequence[Mapping[str, typing.Any]]),
*,
many: bool | None = None,
partial: bool | types.StrSequenceOrSet | None = None,
Expand Down Expand Up @@ -1092,7 +1085,7 @@ def _invoke_dump_processors(
def _invoke_load_processors(
self,
tag: str,
data,
data: Mapping[str, typing.Any] | Sequence[Mapping[str, typing.Any]],
*,
many: bool,
original_data,
Expand Down Expand Up @@ -1216,7 +1209,7 @@ def _invoke_processors(
tag: str,
*,
pass_collection: bool,
data,
data: Mapping[str, typing.Any] | Sequence[Mapping[str, typing.Any]],
many: bool,
original_data=None,
**kwargs,
Expand Down
19 changes: 15 additions & 4 deletions src/marshmallow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,33 @@
import datetime as dt
import inspect
import typing
from collections.abc import Mapping
from collections.abc import Mapping, Sequence

# Remove when we drop Python 3.9
try:
from typing import TypeGuard
except ImportError:
from typing_extensions import TypeGuard

from marshmallow.constants import missing


def is_generator(obj) -> bool:
def is_generator(obj) -> TypeGuard[typing.Generator]:
"""Return True if ``obj`` is a generator"""
return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)


def is_iterable_but_not_string(obj) -> bool:
def is_iterable_but_not_string(obj) -> TypeGuard[typing.Iterable]:
"""Return True if ``obj`` is an iterable object that isn't a string."""
return (hasattr(obj, "__iter__") and not hasattr(obj, "strip")) or is_generator(obj)


def is_collection(obj) -> bool:
def is_sequence_but_not_string(obj) -> TypeGuard[Sequence]:
"""Return True if ``obj`` is a sequence that isn't a string."""
return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))


def is_collection(obj) -> TypeGuard[typing.Iterable]:
"""Return True if ``obj`` is a collection type, e.g list, tuple, queryset."""
return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class Sch(Schema):
assert e.value.valid_data == []


@pytest.mark.parametrize("val", ([], set()))
@pytest.mark.parametrize("val", ([], tuple()))
def test_load_many_empty_collection(val):
class Sch(Schema):
name = fields.Str()
Expand All @@ -276,7 +276,7 @@ class Outer(Schema):
}


@pytest.mark.parametrize("val", ([], set()))
@pytest.mark.parametrize("val", ([], tuple()))
def test_load_many_in_nested_empty_collection(val):
class Inner(Schema):
name = fields.String()
Expand Down
Loading