Skip to content

Commit

Permalink
Handle cyclic references in JSON encoding and schema (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmojaki authored Apr 30, 2024
1 parent 25e031c commit 7d76ed6
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 82 deletions.
111 changes: 64 additions & 47 deletions logfire/_internal/json_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"""The maximum size of a dimension of a numpy array."""


def _bytes_encoder(o: bytes) -> str:
def _bytes_encoder(o: bytes, _seen: set[int]) -> str:
"""Encode bytes using repr() to get a string representation of the bytes object.
We remove the leading 'b' and the quotes around the string representation.
Expand All @@ -36,22 +36,30 @@ def _bytes_encoder(o: bytes) -> str:
return repr(o)[2:-1]


def _bytearray_encoder(o: bytearray, seen: set[int]) -> str:
    """Encode a bytearray exactly like a bytes object.

    The bytearray is copied into an immutable ``bytes`` value and handed to
    ``_bytes_encoder``; ``seen`` is passed through untouched so this function
    matches the common ``(value, seen)`` encoder signature.
    """
    return _bytes_encoder(bytes(o), seen)


def _set_encoder(o: set[Any], seen: set[int]) -> JsonValue:
    """Encode a set as a JSON list.

    Items are emitted in sorted order when ``sorted()`` accepts them, and in
    the set's own iteration order when sorting raises ``TypeError`` (items
    that cannot be compared with each other). Every item is encoded with
    ``to_json_value`` using the shared ``seen`` set.
    """
    # Only the sort can legitimately fail here, so keep it alone in the try:
    # the items are then encoded exactly once against ``seen``.
    try:
        items = sorted(o)
    except TypeError:
        items = o
    return [to_json_value(item, seen) for item in items]


def _to_isoformat(o: Any) -> str:
def _to_isoformat(o: Any, _seen: set[int]) -> str:
return o.isoformat()


def _pandas_data_frame_encoder(o: Any) -> JsonValue:
def _to_str(o: Any, _seen: set[int]) -> str:
return str(o)


def _to_repr(o: Any, _seen: set[int]) -> str:
return repr(o)


def _pandas_data_frame_encoder(o: Any, seen: set[int]) -> JsonValue:
"""Encode pandas data frame by extracting important information.
It summarizes rows and columns if they are more than limit.
Expand Down Expand Up @@ -94,10 +102,10 @@ def _pandas_data_frame_encoder(o: Any) -> JsonValue:
else:
rows.append(list(row))

return to_json_value(rows)
return to_json_value(rows, seen)


def _numpy_array_encoder(o: Any) -> JsonValue:
def _numpy_array_encoder(o: Any, seen: set[int]) -> JsonValue:
"""Encode numpy array by extracting important information.
It summarizes rows and columns if they are more than limit.
Expand Down Expand Up @@ -138,17 +146,17 @@ def _numpy_array_encoder(o: Any) -> JsonValue:
end = o[tuple(slices)]
o = numpy.concatenate((front, end), axis=dimension)

return to_json_value(o.tolist())
return to_json_value(o.tolist(), seen)


def _pydantic_model_encoder(o: Any, seen: set[int]) -> JsonValue:
    """Encode a pydantic model by dumping it to a dict and encoding that dict.

    ``seen`` is forwarded to ``to_json_value`` so the dumped data takes part
    in the same circular-reference tracking as the rest of the value.
    """
    # Imported lazily: pydantic is only needed once a model is actually encoded.
    import pydantic

    assert isinstance(o, pydantic.BaseModel)
    return to_json_value(o.model_dump(), seen)


def _get_sqlalchemy_data(o: Any) -> JsonValue:
def _get_sqlalchemy_data(o: Any, seen: set[int]) -> JsonValue:
try:
from sqlalchemy import inspect as sa_inspect

Expand All @@ -158,11 +166,12 @@ def _get_sqlalchemy_data(o: Any) -> JsonValue:
deferred = set() # type: ignore

return to_json_value(
{field: getattr(o, field) if field not in deferred else '<deferred>' for field in o.__mapper__.attrs.keys()}
{field: getattr(o, field) if field not in deferred else '<deferred>' for field in o.__mapper__.attrs.keys()},
seen,
)


# An encoder takes the object to encode plus the set of ``id()`` values already
# visited during the current encoding pass, and returns a JSON-compatible value.
EncoderFunction = Callable[[Any, set[int]], JsonValue]


@lru_cache(maxsize=None)
Expand All @@ -175,30 +184,30 @@ def encoder_by_type() -> dict[type[Any], EncoderFunction]:
datetime.date: _to_isoformat,
datetime.datetime: _to_isoformat,
datetime.time: _to_isoformat,
datetime.timedelta: lambda o: o.total_seconds(),
Decimal: str,
Enum: lambda o: to_json_value(o.value),
GeneratorType: repr,
IPv4Address: str,
IPv4Interface: str,
IPv4Network: str,
IPv6Address: str,
IPv6Interface: str,
IPv6Network: str,
PosixPath: str,
Pattern: lambda o: to_json_value(o.pattern),
UUID: str,
Exception: str,
datetime.timedelta: lambda o, _: o.total_seconds(),
Decimal: _to_str,
Enum: lambda o, seen: to_json_value(o.value, seen),
GeneratorType: _to_repr,
IPv4Address: _to_str,
IPv4Interface: _to_str,
IPv4Network: _to_str,
IPv6Address: _to_str,
IPv6Interface: _to_str,
IPv6Network: _to_str,
PosixPath: _to_str,
Pattern: lambda o, seen: to_json_value(o.pattern, seen),
UUID: _to_str,
Exception: _to_str,
}
with contextlib.suppress(ModuleNotFoundError):
import pydantic

lookup.update(
{
pydantic.AnyUrl: str,
pydantic.NameEmail: str,
pydantic.SecretBytes: str,
pydantic.SecretStr: str,
pydantic.AnyUrl: _to_str,
pydantic.NameEmail: _to_str,
pydantic.SecretBytes: _to_str,
pydantic.SecretStr: _to_str,
pydantic.BaseModel: _pydantic_model_encoder,
}
)
Expand All @@ -215,21 +224,29 @@ def encoder_by_type() -> dict[type[Any], EncoderFunction]:
return lookup


def to_json_value(o: Any) -> JsonValue:
def to_json_value(o: Any, seen: set[int]) -> JsonValue:
try:
if isinstance(o, (int, float, str, bool, type(None))):
return o

if id(o) in seen:
return '<circular reference>'

seen.add(id(o))

if isinstance(o, (list, tuple)):
# we do list & tuple before Mapping as it's > twice as fast and just as common
return [to_json_value(item) for item in o] # type: ignore
return [to_json_value(item, seen) for item in o] # type: ignore
elif isinstance(o, Mapping):
return {key if isinstance(key, str) else safe_repr(key): to_json_value(value) for key, value in o.items()} # type: ignore
return {
key if isinstance(key, str) else safe_repr(key): to_json_value(value, seen) for key, value in o.items()
} # type: ignore
elif dataclasses.is_dataclass(o):
return {f.name: to_json_value(getattr(o, f.name)) for f in dataclasses.fields(o)}
return {f.name: to_json_value(getattr(o, f.name), seen) for f in dataclasses.fields(o)}
elif is_attrs(o):
return _get_attrs_data(o)
return _get_attrs_data(o, seen)
elif is_sqlalchemy(o):
return _get_sqlalchemy_data(o)
return _get_sqlalchemy_data(o, seen)

# Check the class type and its superclasses for a matching encoder
for base in o.__class__.__mro__[:-1]:
Expand All @@ -238,10 +255,10 @@ def to_json_value(o: Any) -> JsonValue:
except KeyError:
pass
else:
return encoder(o)
return encoder(o, seen)

if isinstance(o, Sequence):
return [to_json_value(item) for item in o] # type: ignore
return [to_json_value(item, seen) for item in o] # type: ignore
except Exception: # pragma: no cover
pass

Expand All @@ -251,11 +268,11 @@ def to_json_value(o: Any) -> JsonValue:

def logfire_json_dumps(obj: Any) -> str:
    """Serialize ``obj`` to a compact JSON string.

    The first attempt lets ``json.dumps`` walk the value itself and calls
    ``to_json_value`` (with a fresh ``seen`` set) only for objects it cannot
    serialize natively. If that attempt raises, the whole value is converted
    eagerly with ``to_json_value`` and the result is dumped instead.
    """
    try:
        return json.dumps(obj, default=lambda o: to_json_value(o, set()), separators=(',', ':'))
    except Exception:
        # fallback to eagerly calling to_json_value to take care of object keys which are not strings
        # see https://github.com/pydantic/platform/pull/2045
        # Catching Exception (not just TypeError) also covers the ValueError that
        # json.dumps raises for circular references in plain containers.
        return json.dumps(to_json_value(obj, set()), separators=(',', ':'))


def is_sqlalchemy(obj: Any) -> bool:
Expand All @@ -278,7 +295,7 @@ def is_attrs(obj: Any) -> bool:
return False


def _get_attrs_data(o: Any, seen: set[int]) -> JsonValue:
    """Encode an attrs instance as a dict of field name to encoded field value.

    Fields are taken from ``attrs.fields`` on the instance's class, and each
    value is encoded with ``to_json_value`` using the shared ``seen`` set.
    """
    # Imported lazily: attrs is only needed once an attrs instance is encoded.
    import attrs

    return {f.name: to_json_value(getattr(o, f.name), seen) for f in attrs.fields(o.__class__)}
Loading

0 comments on commit 7d76ed6

Please sign in to comment.