[Type Transformer] Pydantic guess python type #2976

Draft: wants to merge 2 commits into base: master
103 changes: 102 additions & 1 deletion flytekit/extras/pydantic_transformer/transformer.py
@@ -1,11 +1,14 @@
import json
import os
from typing import Type
from functools import lru_cache
from typing import Any, Dict, List, Type, Union

import msgpack
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from pydantic import BaseModel
from pydantic import BaseModel, Field, create_model


from flytekit import FlyteContext
from flytekit.core.constants import FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK
@@ -16,6 +19,97 @@
from flytekit.models.literals import Binary, Literal, Scalar
from flytekit.models.types import LiteralType, TypeStructure

DEFINITIONS = "definitions"
TITLE = "title"

# Helper function to convert JSON schema to Pydantic BaseModel
# Reference: https://github.com/pydantic/pydantic/issues/643#issuecomment-1999755873
def json_schema_to_model(json_schema: Dict[str, Any]) -> Type[BaseModel]:
    """
    Converts a JSON schema to a Pydantic BaseModel class.

    Args:
        json_schema: The JSON schema to convert.

    Returns:
        A Pydantic BaseModel class.
    """
    # Extract the model name from the schema title.
    model_name = json_schema.get('title', 'DynamicModel')

    # Extract the field definitions from the schema properties.
    field_definitions = {
        name: json_schema_to_pydantic_field(name, prop, json_schema.get('required', []))
        for name, prop in json_schema.get('properties', {}).items()
    }

    # Create the BaseModel class using create_model().
    return create_model(model_name, **field_definitions)


def json_schema_to_pydantic_field(name: str, json_schema: Dict[str, Any], required: List[str]) -> Any:
    """
    Converts a JSON schema property to a Pydantic field definition.

    Args:
        name: The field name.
        json_schema: The JSON schema property.
        required: The names of the fields that the enclosing schema lists as required.

    Returns:
        A (type, Field) tuple that can be passed to create_model().
    """
    # Get the field type.
    type_ = json_schema_to_pydantic_type(json_schema)

    # Get the field description.
    description = json_schema.get('description', None)

    # Get the field examples.
    examples = json_schema.get('examples', None)

    # Create a Field object with the description and examples.
    # A default of Ellipsis marks the field as required; any other field defaults to None.
    return (type_, Field(description=description, examples=examples, default=... if name in required else None))


def json_schema_to_pydantic_type(json_schema: Dict[str, Any]) -> Any:
    """
    Converts a JSON schema type to a Pydantic type.

    Args:
        json_schema: The JSON schema to convert.

    Returns:
        A Pydantic type.
    """
    type_ = json_schema.get('type')

    if type_ == 'string':
        return str
    elif type_ == 'integer':
        return int
    elif type_ == 'number':
        return float
    elif type_ == 'boolean':
        return bool
    elif type_ == 'array':
        items_schema = json_schema.get('items')
        if items_schema:
            item_type = json_schema_to_pydantic_type(items_schema)
            return List[item_type]
        else:
            return List
    elif type_ == 'object':
        # Handle nested models.
        properties = json_schema.get('properties')
        if properties:
            nested_model = json_schema_to_model(json_schema)
            return nested_model
        else:
            return Dict
    elif type_ == 'null':
        return Union[None, Any]  # Use Union[None, Any] for nullable fields
    else:
        raise ValueError(f'Unsupported JSON schema type: {type_}')

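A note on `json_schema_to_pydantic_type` above: it dispatches only on the `type` key, while the schemas Pydantic v2 emits for `Optional[...]` fields use `anyOf` with no top-level `type`, so those fields would fall through to the `ValueError` branch. The sketch below shows one way the mapping could be extended. It is an illustration only, uses nothing but the standard library, and `schema_to_type` and `_SCALARS` are hypothetical names that are not part of this PR.

```python
from typing import Any, Dict, List, Optional, Union

# Hypothetical lookup table for the scalar JSON schema types.
_SCALARS = {
    "string": str,
    "integer": int,
    "number": float,
    "boolean": bool,
    "null": type(None),
}


def schema_to_type(schema: Dict[str, Any]) -> Any:
    """Map a JSON schema fragment to a Python type (illustrative only)."""
    if "anyOf" in schema:
        # Build a Union from every alternative; Union[int, None] is Optional[int].
        members = tuple(schema_to_type(sub) for sub in schema["anyOf"])
        return Union[members]
    type_ = schema.get("type")
    if type_ in _SCALARS:
        return _SCALARS[type_]
    if type_ == "array":
        items = schema.get("items")
        return List[schema_to_type(items)] if items else List[Any]
    raise ValueError(f"Unsupported JSON schema fragment: {schema}")


optional_int = schema_to_type({"anyOf": [{"type": "integer"}, {"type": "null"}]})
list_of_str = schema_to_type({"type": "array", "items": {"type": "string"}})

assert optional_int == Optional[int]
assert list_of_str == List[str]
```

Nested objects are left out of the sketch on purpose; in the PR they are handled by recursing into `json_schema_to_model`.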

class PydanticTransformer(TypeTransformer[BaseModel]):
    def __init__(self):
@@ -39,6 +133,13 @@

        return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts)

    @lru_cache(typed=True)
    def guess_python_type(self, literal_type: LiteralType) -> Type[BaseModel]:  # type: ignore
        if literal_type.simple == types.SimpleType.STRUCT:
            if literal_type.metadata is not None:
                return json_schema_to_model(literal_type.metadata)
        raise ValueError(f"Pydantic transformer cannot reverse {literal_type}")

    def to_generic_literal(
        self,
        ctx: FlyteContext,
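
On the `@lru_cache(typed=True)` decorator applied to `guess_python_type` above: a plausible reason for it (an assumption, the PR does not say) is that `create_model()` returns a brand-new class on every call, so without caching the same literal type would be guessed as a different Python class each time. The sketch below illustrates that effect with the standard library only. `SchemaKey`, `build_class` and `Guesser` are hypothetical stand-ins, not flytekit or Pydantic APIs, and `type()` plays the role of `create_model()`.

```python
from dataclasses import dataclass
from functools import lru_cache
from typing import Tuple


@dataclass(frozen=True)
class SchemaKey:
    """Hypothetical hashable stand-in for a literal type that carries a schema."""

    title: str
    fields: Tuple[str, ...]


def build_class(key: SchemaKey) -> type:
    # Like create_model(), type() produces a new class object on every call.
    return type(key.title, (), {name: None for name in key.fields})


class Guesser:
    @lru_cache(typed=True)
    def guess(self, key: SchemaKey) -> type:
        # The cache key is built from (self, key), so both must be hashable.
        return build_class(key)


guesser = Guesser()
key = SchemaKey(title="Point", fields=("x", "y"))

# Two uncached builds yield two distinct classes.
uncached_differs = build_class(key) is not build_class(key)

# Equal keys hit the cache, so the very same class object comes back.
cached_same = guesser.guess(key) is guesser.guess(SchemaKey("Point", ("x", "y")))

assert uncached_differs and cached_same
```

One design point to keep in mind: `lru_cache` on a method also keeps a reference to `self` for as long as the entry stays in the cache, which is usually acceptable for a long-lived singleton such as a registered transformer.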