diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 2797adaf4050c..901e6811463e1 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -968,9 +968,6 @@ def _run_llm_or_chain( return result -## Public API - - def _prepare_eval_run( client: Client, dataset_name: str, @@ -978,10 +975,17 @@ def _prepare_eval_run( project_name: str, project_metadata: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, + dataset_version: Optional[Union[str, datetime]] = None, ) -> Tuple[MCF, TracerSession, Dataset, List[Example]]: wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name) dataset = client.read_dataset(dataset_name=dataset_name) - examples = list(client.list_examples(dataset_id=dataset.id)) + as_of = dataset_version if isinstance(dataset_version, datetime) else None + if isinstance(dataset_version, str): + raise NotImplementedError( + "Selecting dataset_version by tag is not yet supported." + " Please use a datetime object." + ) + examples = list(client.list_examples(dataset_id=dataset.id, as_of=as_of)) if not examples: raise ValueError(f"Dataset {dataset_name} has no example rows.") modified_at = [ex.modified_at for ex in examples if ex.modified_at] @@ -1173,6 +1177,7 @@ def prepare( concurrency_level: int = 5, project_metadata: Optional[Dict[str, Any]] = None, revision_id: Optional[str] = None, + dataset_version: Optional[Union[datetime, str]] = None, ) -> _DatasetRunContainer: project_name = project_name or name_generation.random_name() if revision_id: @@ -1186,6 +1191,7 @@ def prepare( project_name, project_metadata=project_metadata, tags=tags, + dataset_version=dataset_version, ) tags = tags or [] for k, v in (project.metadata.get("git") or {}).items(): @@ -1269,6 +1275,8 @@ def _display_aggregate_results(aggregate_results: pd.DataFrame) -> None: "langchain.schema.runnable.base.RunnableLambda.html)" ) +## Public API + async def arun_on_dataset( client: Optional[Client], @@ -1276,11 +1284,11 @@ async def arun_on_dataset( llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, *, evaluation: Optional[smith_eval.RunEvalConfig] = None, + dataset_version: Optional[Union[datetime, str]] = None, concurrency_level: int = 5, project_name: Optional[str] = None, project_metadata: Optional[Dict[str, Any]] = None, verbose: bool = False, - tags: Optional[List[str]] = None, revision_id: Optional[str] = None, **kwargs: Any, ) -> Dict[str, Any]: @@ -1289,6 +1297,13 @@ async def arun_on_dataset( warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True) if revision_id is None: revision_id = get_langchain_env_var_metadata().get("revision_id") + tags = kwargs.pop("tags", None) + if tags: + warn_deprecated( + "0.1.9", + message="The tags argument is deprecated and will be" + " removed in a future release. Please specify project_metadata instead.", + ) if kwargs: warn_deprecated( @@ -1310,6 +1325,7 @@ async def arun_on_dataset( concurrency_level, project_metadata=project_metadata, revision_id=revision_id, + dataset_version=dataset_version, ) batch_results = await runnable_utils.gather_with_concurrency( container.configs[0].get("max_concurrency"), @@ -1332,17 +1348,24 @@ def run_on_dataset( llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, *, evaluation: Optional[smith_eval.RunEvalConfig] = None, + dataset_version: Optional[Union[datetime, str]] = None, concurrency_level: int = 5, project_name: Optional[str] = None, project_metadata: Optional[Dict[str, Any]] = None, verbose: bool = False, - tags: Optional[List[str]] = None, revision_id: Optional[str] = None, **kwargs: Any, ) -> Dict[str, Any]: input_mapper = kwargs.pop("input_mapper", None) if input_mapper: warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True) + tags = kwargs.pop("tags", None) + if tags: + warn_deprecated( + "0.1.9", + message="The tags argument is deprecated and will be" + " removed in a future release. Please specify project_metadata instead.", + ) if revision_id is None: revision_id = get_langchain_env_var_metadata().get("revision_id") @@ -1366,6 +1389,7 @@ def run_on_dataset( concurrency_level, project_metadata=project_metadata, revision_id=revision_id, + dataset_version=dataset_version, ) if concurrency_level == 0: batch_results = [ @@ -1458,8 +1482,8 @@ def construct_chain(): client = Client() run_on_dataset( client, - "<my_dataset_name>", - construct_chain, + dataset_name="<my_dataset_name>", + llm_or_chain_factory=construct_chain, evaluation=evaluation_config, ) @@ -1496,8 +1520,8 @@ def _evaluate_strings(self, prediction, reference=None, input=None, **kwargs) -> run_on_dataset( client, - "<my_dataset_name>", - construct_chain, + dataset_name="<my_dataset_name>", + llm_or_chain_factory=construct_chain, evaluation=evaluation_config, ) """ # noqa: E501