Skip to content

Commit

Permalink
[Evals API][2/n] datasets / datasetio meta-reference implementation (#…
Browse files Browse the repository at this point in the history
…288)

* skeleton dataset / datasetio

* dataset datasetio

* config

* address comments

* delete dataset_utils

* address comments

* naming fix
  • Loading branch information
yanxi0830 authored Oct 22, 2024
1 parent 8a01b9e commit 8218106
Show file tree
Hide file tree
Showing 16 changed files with 452 additions and 8 deletions.
6 changes: 0 additions & 6 deletions llama_stack/apis/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,5 @@ async def get_dataset(
dataset_identifier: str,
) -> Optional[DatasetDefWithProvider]: ...

@webmethod(route="/datasets/delete")
async def delete_dataset(
self,
dataset_identifier: str,
) -> None: ...

@webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[DatasetDefWithProvider]: ...
6 changes: 5 additions & 1 deletion llama_stack/distribution/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety


LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"

Expand All @@ -30,18 +31,21 @@
ModelDef,
ShieldDef,
MemoryBankDef,
DatasetDef,
]

RoutableObjectWithProvider = Union[
ModelDefWithProvider,
ShieldDefWithProvider,
MemoryBankDefWithProvider,
DatasetDefWithProvider,
]

RoutedProtocol = Union[
Inference,
Safety,
Memory,
DatasetIO,
]


Expand Down
4 changes: 4 additions & 0 deletions llama_stack/distribution/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.memory_banks,
router_api=Api.memory,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,
),
]


Expand Down
4 changes: 4 additions & 0 deletions llama_stack/distribution/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from llama_stack.distribution.datatypes import * # noqa: F403

from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
Expand All @@ -38,6 +40,8 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
Api.datasets: Datasets,
Api.datasetio: DatasetIO,
}


Expand Down
5 changes: 4 additions & 1 deletion llama_stack/distribution/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from llama_stack.distribution.datatypes import * # noqa: F403
from .routing_tables import (
DatasetsRoutingTable,
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
Expand All @@ -23,6 +24,7 @@ async def get_routing_table_impl(
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
Expand All @@ -33,12 +35,13 @@ async def get_routing_table_impl(


async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter

api_to_routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
Expand Down
32 changes: 32 additions & 0 deletions llama_stack/distribution/routers/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

from typing import Any, AsyncGenerator, Dict, List

from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.distribution.datatypes import RoutingTable

from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403


class MemoryRouter(Memory):
Expand Down Expand Up @@ -160,3 +162,33 @@ async def run_shield(
messages=messages,
params=params,
)


class DatasetIORouter(DatasetIO):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table

async def initialize(self) -> None:
pass

async def shutdown(self) -> None:
pass

async def get_rows_paginated(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
return await self.routing_table.get_provider_impl(
dataset_id
).get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
filter_condition=filter_condition,
)
32 changes: 32 additions & 0 deletions llama_stack/distribution/routers/routing_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403

from llama_stack.distribution.datatypes import * # noqa: F403

Expand All @@ -27,6 +28,10 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(obj)
elif api == Api.datasetio:
await p.register_dataset(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")


Registry = Dict[str, List[RoutableObjectWithProvider]]
Expand Down Expand Up @@ -80,6 +85,16 @@ def add_objects(objs: List[RoutableObjectWithProvider]) -> None:

add_objects(memory_banks)

elif api == Api.datasetio:
p.dataset_store = self
datasets = await p.list_datasets()

# do in-memory updates due to pesky Annotated unions
for d in datasets:
d.provider_id = pid

add_objects(datasets)

async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
Expand Down Expand Up @@ -137,6 +152,7 @@ async def register_object(self, obj: RoutableObjectWithProvider):
raise ValueError(f"Provider `{obj.provider_id}` not found")

p = self.impls_by_provider_id[obj.provider_id]

await register_object_with_provider(obj, p)

if obj.identifier not in self.registry:
Expand Down Expand Up @@ -190,3 +206,19 @@ async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None:
await self.register_object(memory_bank)


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[DatasetDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects

async def get_dataset(
self, dataset_identifier: str
) -> Optional[DatasetDefWithProvider]:
return self.get_object_by_identifier(identifier)

async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
await self.register_object(dataset_def)
10 changes: 10 additions & 0 deletions llama_stack/providers/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field

from llama_stack.apis.datasets import DatasetDef

from llama_stack.apis.memory_banks import MemoryBankDef

from llama_stack.apis.models import ModelDef
Expand All @@ -22,12 +24,14 @@ class Api(Enum):
safety = "safety"
agents = "agents"
memory = "memory"
datasetio = "datasetio"

telemetry = "telemetry"

models = "models"
shields = "shields"
memory_banks = "memory_banks"
datasets = "datasets"

# built-in API
inspect = "inspect"
Expand All @@ -51,6 +55,12 @@ async def list_memory_banks(self) -> List[MemoryBankDef]: ...
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...


class DatasetsProtocolPrivate(Protocol):
async def list_datasets(self) -> List[DatasetDef]: ...

async def register_datasets(self, dataset_def: DatasetDef) -> None: ...


@json_schema_type
class ProviderSpec(BaseModel):
api: Api
Expand Down
18 changes: 18 additions & 0 deletions llama_stack/providers/impls/meta_reference/datasetio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import MetaReferenceDatasetIOConfig


async def get_provider_impl(
config: MetaReferenceDatasetIOConfig,
_deps,
):
from .datasetio import MetaReferenceDatasetIOImpl

impl = MetaReferenceDatasetIOImpl(config)
await impl.initialize()
return impl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import * # noqa: F401, F403


class MetaReferenceDatasetIOConfig(BaseModel): ...
Loading

0 comments on commit 8218106

Please sign in to comment.