Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(weave): Make ref-getting more pleasant #3362

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
40 changes: 40 additions & 0 deletions tests/trace/test_uri_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest

import weave


@pytest.fixture(
params=[
"dataset",
# "evaluation",
# "string_prompt",
# "messages_prompt",
# "easy_prompt",
]
)
def obj(request):
examples = [
{"question": "What is 2+2?", "expected": "4"},
{"question": "What is 3+3?", "expected": "6"},
]

if request.param == "dataset":
return weave.Dataset(rows=examples)
elif request.param == "evaluation":
return weave.Evaluation(dataset=examples)
elif request.param == "string_prompt":
return weave.StringPrompt("Hello, world!")
elif request.param == "messages_prompt":
return weave.MessagesPrompt([{"role": "user", "content": "Hello, world!"}])
elif request.param == "easy_prompt":
return weave.EasyPrompt("Hello world!")


def test_uri_get(client, obj):
ref = weave.publish(obj)

obj_cls = type(obj)
obj2 = obj_cls.from_uri(ref.uri())

for field_name in obj.model_fields:
assert getattr(obj, field_name) == getattr(obj2, field_name)
9 changes: 8 additions & 1 deletion weave/flow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Any

from pydantic import field_validator
from typing_extensions import Self

import weave
from weave.flow.obj import Object
from weave.trace.vals import WeaveTable
from weave.trace import objectify
from weave.trace.vals import WeaveObject, WeaveTable


def short_str(obj: Any, limit: int = 25) -> str:
Expand All @@ -15,6 +17,7 @@ def short_str(obj: Any, limit: int = 25) -> str:
return str_val


@objectify.register
class Dataset(Object):
"""
Dataset object with easy saving and automatic versioning
Expand Down Expand Up @@ -42,6 +45,10 @@ class Dataset(Object):

rows: weave.Table

@classmethod
def from_obj(cls, obj: WeaveObject) -> Self:
return cls(rows=obj.rows)

@field_validator("rows", mode="before")
def convert_to_table(cls, rows: Any) -> weave.Table:
if not isinstance(rows, weave.Table):
Expand Down
13 changes: 13 additions & 0 deletions weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import PrivateAttr, model_validator
from rich import print
from rich.console import Console
from typing_extensions import Self

import weave
from weave.flow import util
Expand All @@ -27,6 +28,7 @@
get_scorer_attributes,
transpose,
)
from weave.trace import objectify
from weave.trace.env import get_weave_parallelism
from weave.trace.errors import OpCallError
from weave.trace.isinstance import weave_isinstance
Expand Down Expand Up @@ -57,6 +59,7 @@ class EvaluationResults(Object):
ScorerLike = Union[Callable, Op, Scorer]


@objectify.register
class Evaluation(Object):
"""
Sets up an evaluation which includes a set of scorers and a dataset.
Expand Down Expand Up @@ -114,6 +117,16 @@ def function_to_evaluate(question: str):
# internal attr to track whether to use the new `output` or old `model_output` key for outputs
_output_key: Literal["output", "model_output"] = PrivateAttr("output")

@classmethod
def from_obj(cls, obj: WeaveObject) -> Self:
return cls(
dataset=obj.dataset,
scorers=obj.scorers,
preprocess_model_input=obj.preprocess_model_input,
trials=obj.trials,
evaluation_name=obj.evaluation_name,
)

@model_validator(mode="after")
def _update_display_name(self) -> "Evaluation":
if self.evaluation_name:
Expand Down
8 changes: 8 additions & 0 deletions weave/flow/obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ValidatorFunctionWrapHandler,
model_validator,
)
from typing_extensions import Self

from weave.trace.op import ObjectRef, Op
from weave.trace.vals import WeaveObject, pydantic_getattribute
Expand Down Expand Up @@ -51,6 +52,13 @@ class Object(BaseModel):

__str__ = BaseModel.__repr__

@classmethod
def from_uri(cls, uri: str) -> Self:
"""It's up to the subclass to implement this!

Using the @objectify.register decorator will also implement this automatically."""
raise NotImplementedError

# This is a "wrap" validator meaning we can run our own logic before
# and after the standard pydantic validation.
@model_validator(mode="wrap")
Expand Down
16 changes: 11 additions & 5 deletions weave/flow/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@

from pydantic import Field
from rich.table import Table
from typing_extensions import Self

from weave.flow.obj import Object
from weave.flow.prompt.common import ROLE_COLORS, color_role
from weave.trace import objectify
from weave.trace.api import publish as weave_publish
from weave.trace.op import op
from weave.trace.refs import ObjectRef
from weave.trace.rich import pydantic_util
from weave.trace.vals import WeaveObject


class Message(TypedDict):
Expand Down Expand Up @@ -76,6 +79,7 @@ def format(self, **kwargs: Any) -> Any:
raise NotImplementedError("Subclasses must implement format()")


@objectify.register
class StringPrompt(Prompt):
content: str = ""

Expand All @@ -87,13 +91,14 @@ def format(self, **kwargs: Any) -> str:
return self.content.format(**kwargs)

@classmethod
def from_obj(cls, obj: Any) -> "StringPrompt":
def from_obj(cls, obj: WeaveObject) -> Self:
prompt = cls(content=obj.content)
prompt.name = obj.name
prompt.description = obj.description
return prompt


@objectify.register
class MessagesPrompt(Prompt):
messages: list[dict] = Field(default_factory=list)

Expand All @@ -114,13 +119,14 @@ def format(self, **kwargs: Any) -> list:
return [self.format_message(m, **kwargs) for m in self.messages]

@classmethod
def from_obj(cls, obj: Any) -> "MessagesPrompt":
def from_obj(cls, obj: WeaveObject) -> Self:
prompt = cls(messages=obj.messages)
prompt.name = obj.name
prompt.description = obj.description
return prompt


@objectify.register
class EasyPrompt(UserList, Prompt):
data: list = Field(default_factory=list)
config: dict = Field(default_factory=dict)
Expand Down Expand Up @@ -420,7 +426,7 @@ def as_dict(self) -> dict[str, Any]:
}

@classmethod
def from_obj(cls, obj: Any) -> "EasyPrompt":
def from_obj(cls, obj: WeaveObject) -> Self:
messages = obj.messages if hasattr(obj, "messages") else obj.data
messages = [dict(m) for m in messages]
config = dict(obj.config)
Expand All @@ -434,7 +440,7 @@ def from_obj(cls, obj: Any) -> "EasyPrompt":
)

@staticmethod
def load(fp: IO) -> "EasyPrompt":
def load(fp: IO) -> Self:
if isinstance(fp, str): # Common mistake
raise TypeError(
"Prompt.load() takes a file-like object, not a string. Did you mean Prompt.e()?"
Expand All @@ -444,7 +450,7 @@ def load(fp: IO) -> "EasyPrompt":
return prompt

@staticmethod
def load_file(filepath: Union[str, Path]) -> "Prompt":
def load_file(filepath: Union[str, Path]) -> Self:
expanded_path = os.path.expanduser(str(filepath))
with open(expanded_path) as f:
return EasyPrompt.load(f)
Expand Down
46 changes: 46 additions & 0 deletions weave/trace/objectify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, TypeVar

if TYPE_CHECKING:
from weave.trace.vals import WeaveObject

_registry: dict[str, type] = {}


T = TypeVar("T")


def register(cls: type[T]) -> type[T]:
"""Decorator to register a class with the objectify function.

Registered classes will be able to be deserialized directly into their base objects
instead of into a WeaveObject."""
_registry[cls.__name__] = cls

@classmethod
def from_uri(cls: type[T], uri: str) -> T:
import weave

obj = weave.ref(uri).get()
if isinstance(obj, cls):
return obj
return cls(**obj.model_dump())

cls.from_uri = from_uri

return cls


def objectify(obj: WeaveObject) -> Any:
if not (cls_name := getattr(obj, "_class_name", None)):
return obj

if cls_name not in _registry:
return obj

cls = _registry[cls_name]
if hasattr(cls, "from_uri"):
return cls.from_uri(obj.ref.uri())

return obj
33 changes: 4 additions & 29 deletions weave/trace/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,32 +162,7 @@ def uri(self) -> str:
u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
return u

def objectify(self, obj: Any) -> Any:
"""Convert back to higher level object."""
class_name = getattr(obj, "_class_name", None)
if "EasyPrompt" == class_name:
from weave.flow.prompt.prompt import EasyPrompt

prompt = EasyPrompt.from_obj(obj)
# We want to use the ref on the object (and not self) as it will have had
# version number or latest alias resolved to a specific digest.
prompt.__dict__["ref"] = obj.ref
return prompt
if "StringPrompt" == class_name:
from weave.flow.prompt.prompt import StringPrompt

prompt = StringPrompt.from_obj(obj)
prompt.__dict__["ref"] = obj.ref
return prompt
if "MessagesPrompt" == class_name:
from weave.flow.prompt.prompt import MessagesPrompt

prompt = MessagesPrompt.from_obj(obj)
prompt.__dict__["ref"] = obj.ref
return prompt
return obj

def get(self) -> Any:
def get(self, *, objectify: bool = True) -> Any:
# Move import here so that it only happens when the function is called.
# This import is invalid in the trace server and represents a dependency
# that should be removed.
Expand All @@ -196,7 +171,7 @@ def get(self) -> Any:

gc = get_weave_client()
if gc is not None:
return self.objectify(gc.get(self))
return gc.get(self, objectify=objectify)

# Special case: If the user is attempting to fetch an object but has not
# yet initialized the client, we can initialize a client to
Expand All @@ -206,10 +181,10 @@ def get(self) -> Any:
f"{self.entity}/{self.project}", ensure_project_exists=False
)
try:
res = init_client.client.get(self)
res = init_client.client.get(self, objectify=objectify)
finally:
init_client.reset()
return self.objectify(res)
return res

def is_descended_from(self, potential_ancestor: ObjectRef) -> bool:
if self.entity != potential_ancestor.entity:
Expand Down
9 changes: 6 additions & 3 deletions weave/trace/vals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pydantic import v1 as pydantic_v1

from weave.trace import box
from weave.trace import objectify as objectify_module
from weave.trace.context.tests_context import get_raise_on_captured_errors
from weave.trace.context.weave_client_context import get_weave_client
from weave.trace.errors import InternalError
Expand Down Expand Up @@ -618,6 +619,7 @@ def make_trace_obj(
server: TraceServerInterface,
root: Optional[Traceable],
parent: Any = None,
objectify: bool = True,
) -> Any:
if isinstance(val, Traceable):
# If val is a WeaveTable, we want to refer to it via the outer object
Expand Down Expand Up @@ -707,9 +709,10 @@ def make_trace_obj(

if not isinstance(val, Traceable):
if isinstance(val, ObjectRecord):
return WeaveObject(
val, ref=new_ref, server=server, root=root, parent=parent
)
obj = WeaveObject(val, ref=new_ref, server=server, root=root, parent=parent)
if objectify:
return objectify_module.objectify(obj)
return obj
elif isinstance(val, list):
return WeaveList(val, ref=new_ref, server=server, root=root, parent=parent)
elif isinstance(val, dict):
Expand Down
4 changes: 2 additions & 2 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def save(self, val: Any, name: str, branch: str = "latest") -> Any:
return self.get(ref)

@trace_sentry.global_trace_sentry.watch()
def get(self, ref: ObjectRef) -> Any:
def get(self, ref: ObjectRef, *, objectify: bool = True) -> Any:
project_id = f"{ref.entity}/{ref.project}"
try:
read_res = self.server.obj_read(
Expand Down Expand Up @@ -707,7 +707,7 @@ def get(self, ref: ObjectRef) -> Any:

val = from_json(data, project_id, self.server)

return make_trace_obj(val, ref, self.server, None)
return make_trace_obj(val, ref, self.server, None, objectify=objectify)

################ Query API ################

Expand Down
Loading