diff --git a/tests/trace/test_dataset.py b/tests/trace/test_dataset.py index 0ba6a140853f..ff4056fdbd73 100644 --- a/tests/trace/test_dataset.py +++ b/tests/trace/test_dataset.py @@ -1,3 +1,5 @@ +import pytest + import weave @@ -21,3 +23,13 @@ def test_dataset_iteration(client): # Test that we can iterate multiple times rows2 = list(dataset) assert rows2 == rows + + +def test_pythonic_access(client): + rows = [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}] + ds = weave.Dataset(rows=rows) + assert len(ds) == 5 + assert ds[0] == {"a": 1} + + with pytest.raises(IndexError): + ds[-1] diff --git a/weave/flow/dataset.py b/weave/flow/dataset.py index 0bcd9c60b81d..76e3b03a9581 100644 --- a/weave/flow/dataset.py +++ b/weave/flow/dataset.py @@ -69,3 +69,12 @@ def convert_to_table(cls, rows: Any) -> weave.Table: def __iter__(self) -> Iterator[dict]: return iter(self.rows) + + def __len__(self) -> int: + # TODO: This can be slow for large datasets... + return len(list(self.rows)) + + def __getitem__(self, key: int) -> dict: + if key < 0: + raise IndexError("Negative indexing is not supported") + return self.rows[key]