Skip to content

Commit

Permalink
refactor: address redhog feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyashankar committed Oct 13, 2024
1 parent aad2c48 commit 658b59e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 18 deletions.
82 changes: 66 additions & 16 deletions docetl/operations/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@


class SampleOperation(BaseOperation):
"""
Params:
- method: "uniform", "stratify", "outliers", "custom"
- samples: int, float, or list
- method_kwargs: dict, optional
- embedding_model: str, optional
- embedding_keys: list, optional
- center: dict, optional
"""

def __init__(
self,
*args,
Expand All @@ -20,20 +30,39 @@ def syntax_check(self) -> None:
ValueError: If required keys are missing or invalid in the configuration.
TypeError: If configuration values have incorrect types.
"""
if "samples" not in self.config and "outliers" not in self.config:
raise ValueError(
"Must specify either 'samples' or 'outliers' in SampleOperation configuration"
)
if "method" not in self.config:
raise ValueError("Must specify 'method' in SampleOperation configuration")

valid_methods = ["uniform", "stratify", "outliers", "custom"]
if self.config["method"] not in valid_methods:
raise ValueError(f"'method' must be one of {valid_methods}")

if self.config["method"] == "custom":
# Samples must be a list
if not isinstance(self.config["samples"], list):
raise TypeError("'samples' must be a list for custom sampling")

if "samples" in self.config:
if self.config["method"] in ["random", "stratify"]:
if "samples" not in self.config:
raise ValueError(
"Must specify 'samples' for random or stratify sampling"
)
if not isinstance(self.config["samples"], (int, float, list)) or (
isinstance(self.config["samples"], (int, float))
and self.config["samples"] <= 0
):
raise TypeError("'samples' must be a positive integer, float, or list")

if "outliers" in self.config:
outliers_config = self.config["outliers"]
if self.config["method"] == "stratify":
if "stratify_key" not in self.config.get("method_kwargs", {}):
raise ValueError("Must specify 'stratify_key' for stratify sampling")
if not isinstance(
self.config.get("method_kwargs", {})["stratify_key"], str
):
raise TypeError("'stratify_key' must be a string")

if self.config["method"] == "outliers":
outliers_config = self.config.get("method_kwargs", {})
if "std" not in outliers_config and "samples" not in outliers_config:
raise ValueError(
"Must specify either 'std' or 'samples' in outliers configuration"
Expand Down Expand Up @@ -67,6 +96,15 @@ def syntax_check(self) -> None:
"'embedding_keys' in outliers must be a list of strings"
)

if "center" in self.config.get("method_kwargs", {}):
if not isinstance(self.config.get("method_kwargs", {})["center"], dict):
raise TypeError("'center' must be a dictionary")
for key, value in self.config.get("method_kwargs", {})["center"].items():
if not isinstance(value, (int, float)):
raise TypeError(
f"Values in 'center' must be numbers, got {type(value)} for key '{key}'"
)

def execute(
self, input_data: List[Dict], is_build: bool = False
) -> Tuple[List[Dict], float]:
Expand All @@ -86,26 +124,35 @@ def execute(
if not input_data:
return [], cost

if "outliers" in self.config:
if self.config["method"] == "outliers":
# Outlier functionality
outliers_config = self.config["outliers"]
outliers_config = self.config.get("method_kwargs", {})
embeddings, embedding_cost = get_embeddings_for_clustering(
input_data, outliers_config, self.runner.api
)
cost += embedding_cost
embeddings = np.array(embeddings)

center = embeddings.mean(axis=0)
if "center" in self.config:
center = np.array(
[
outliers_config["center"][key]
for key in outliers_config["embedding_keys"]
]
)
else:
center = embeddings.mean(axis=0)

distances = np.sqrt(((embeddings - center) ** 2).sum(axis=1))

if "std" in outliers_config:
cutoff = (
np.sqrt((embeddings.std(axis=0) ** 2).sum())
* outliers_config["std"]
)
else: # "samples" in outliers_config
else: # "samples" in config
distance_distribution = np.sort(distances)
samples = outliers_config["samples"]
samples = self.config["samples"]
if isinstance(samples, float):
samples = int(samples * (len(distance_distribution) - 1))
cutoff = distance_distribution[samples]
Expand All @@ -116,7 +163,7 @@ def execute(
output_data = [item for idx, item in enumerate(input_data) if include[idx]]
else:
samples = self.config["samples"]
if isinstance(samples, list):
if self.config["method"] == "custom":
keys = list(samples[0].keys())
key_to_doc = {
tuple([doc[key] for key in keys]): doc for doc in input_data
Expand All @@ -128,12 +175,15 @@ def execute(
]
else:
stratify = None
if "stratify" in self.config:
stratify = [data[self.config["stratify"]] for data in input_data]
if self.config["method"] == "stratify":
stratify = [
data[self.config.get("method_kwargs", {})["stratify_key"]]
for data in input_data
]

import sklearn.model_selection

output_data, dummy = sklearn.model_selection.train_test_split(
output_data, _ = sklearn.model_selection.train_test_split(
input_data,
train_size=samples,
random_state=self.config.get("random_state", None),
Expand Down
10 changes: 8 additions & 2 deletions tests/basic/test_cluster_and_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def test_sample_operation_with_count(
sample_config, sample_data, api_wrapper, default_model, max_threads
):
sample_config["samples"] = 5
sample_config["method"] = "uniform"
operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads)
results, cost = operation.execute(sample_data)

Expand All @@ -159,6 +160,7 @@ def test_sample_operation_with_fraction(
sample_config, sample_data, api_wrapper, default_model, max_threads
):
sample_config["samples"] = 0.5
sample_config["method"] = "uniform"
operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads)
results, cost = operation.execute(sample_data)

Expand All @@ -172,6 +174,7 @@ def test_sample_operation_with_list(
):
sample_list = [{"id": 1}, {"id": 3}, {"id": 5}]
sample_config["samples"] = sample_list
sample_config["method"] = "custom"
operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads)
results, cost = operation.execute(sample_data)

Expand All @@ -184,7 +187,8 @@ def test_sample_operation_with_stratify(
sample_config, sample_data, api_wrapper, default_model, max_threads
):
sample_config["samples"] = 5
sample_config["stratify"] = "group"
sample_config["method"] = "stratify"
sample_config["method_kwargs"] = {"stratify_key": "group"}
operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads)
results, cost = operation.execute(sample_data)

Expand All @@ -197,7 +201,8 @@ def test_sample_operation_with_stratify(
def test_sample_operation_with_outliers(
sample_config, sample_data, api_wrapper, default_model, max_threads
):
sample_config["outliers"] = {
sample_config["method"] = "outliers"
sample_config["method_kwargs"] = {
"std": 2,
"embedding_keys": ["concept", "description"],
"keep": True,
Expand All @@ -214,6 +219,7 @@ def test_sample_operation_empty_input(
sample_config, api_wrapper, default_model, max_threads
):
sample_config["samples"] = 3
sample_config["method"] = "uniform"
operation = SampleOperation(api_wrapper, sample_config, default_model, max_threads)
results, cost = operation.execute([])

Expand Down

0 comments on commit 658b59e

Please sign in to comment.