Merge branch 'ucbepic:main' into test_branch
staru09 authored Jan 25, 2025
2 parents db857cf + 2a259a0 commit 7983ab4
Showing 38 changed files with 2,455 additions and 397 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -9,7 +9,7 @@ tests:
 
 tests-basic:
 	poetry run pytest tests/basic
-	poetry run pytest tests/test_api.py
+	poetry run pytest -s tests/test_api.py
 	poetry run pytest tests/test_runner_caching.py
 
 lint:
3 changes: 2 additions & 1 deletion docetl/__init__.py
@@ -2,5 +2,6 @@
 
 from docetl.runner import DSLRunner
 from docetl.optimizer import Optimizer
+from docetl.apis.pd_accessors import SemanticAccessor
 
-__all__ = ["DSLRunner", "Optimizer"]
+__all__ = ["DSLRunner", "Optimizer", "SemanticAccessor"]
Empty file added docetl/apis/__init__.py
633 changes: 633 additions & 0 deletions docetl/apis/pd_accessors.py

Large diffs are not rendered by default.
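
The new module is too large for the inline diff, but the SemanticAccessor import added to docetl/__init__.py above indicates it implements a pandas DataFrame accessor. A hedged sketch of the registration pattern such a module typically uses; the map method and its signature here are illustrative assumptions, not the actual contents of pd_accessors.py:

import pandas as pd

@pd.api.extensions.register_dataframe_accessor("semantic")
class SemanticAccessor:
    """Exposes LLM-powered DocETL operations as df.semantic.* methods."""

    def __init__(self, pandas_obj: pd.DataFrame):
        self._df = pandas_obj

    def map(self, prompt: str, output_schema: dict) -> pd.DataFrame:
        # Hypothetical method: run an LLM map over each row, returning
        # a copy of the DataFrame extended with the schema's output keys.
        raise NotImplementedError("illustrative sketch only")

With this registered, a call like df.semantic.map(prompt="...", output_schema={"summary": "str"}) would dispatch through the accessor.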

1 change: 1 addition & 0 deletions docetl/config_wrapper.py
@@ -56,6 +56,7 @@ def __init__(
         yaml_file_suffix: Optional[str] = None,
         max_threads: int = None,
         console: Optional[Console] = None,
+        **kwargs,
     ):
         self.config = config
         self.base_name = base_name
40 changes: 32 additions & 8 deletions docetl/containers.py
@@ -91,11 +91,28 @@ def optimize(self):
         sample_size_needed = self.runner.optimizer.sample_size_map.get(
             self.config["type"]
         )
+
+        # if type is equijoin, sample_size_needed may be a dictionary
+        if self.config["type"] == "equijoin":
+            if isinstance(sample_size_needed, dict):
+                sample_size_needed = [
+                    sample_size_needed["left"],
+                    sample_size_needed["right"],
+                ]
+            else:
+                sample_size_needed = [sample_size_needed, sample_size_needed]
+        else:
+            sample_size_needed = [sample_size_needed]
+
+        assert len(sample_size_needed) >= len(
+            self.children
+        ), f"Sample size list must be a list of at least the same length as the number of children. Current sample size list: {sample_size_needed}. Current number of children: {len(self.children)}"
+
         # run the children to get the input data for optimizing this operation
         input_data = []
-        for child in self.children:
+        for idx, child in enumerate(self.children):
             input_data.append(
-                child.next(is_build=True, sample_size_needed=sample_size_needed)[0]
+                child.next(is_build=True, sample_size_needed=sample_size_needed[idx])[0]
             )
 
         # Optimize this operation if it's eligible for optimization
@@ -367,6 +384,13 @@ def optimize(self):
         sample_size_needed = self.runner.optimizer.sample_size_map.get(
             new_head_pointer.config["type"]
         )
+        # if it's an equijoin, sample_size_needed may be a dictionary
+        if new_head_pointer.config["type"] == "equijoin":
+            if isinstance(sample_size_needed, dict):
+                sample_size_needed = min(
+                    sample_size_needed["left"], sample_size_needed["right"]
+                )
+
         # walk down the new head pointer and set the selectivities
         queue = [new_head_pointer] if new_head_pointer.parent else []
         while queue:
@@ -500,12 +524,12 @@ def next(
             # Track costs and log execution
             this_op_cost = self.runner.total_cost - cost_before_execution
             cost += this_op_cost
-            if this_op_cost > 0:
-                build_indicator = "[yellow](build)[/yellow] " if is_build else ""
-                curr_logs += f"[green]✓[/green] {build_indicator}{self.name} (Cost: [green]${this_op_cost:.2f}[/green])\n"
-            else:
-                build_indicator = "[yellow](build)[/yellow] " if is_build else ""
-                curr_logs += f"[green]✓[/green] {build_indicator}{self.name}\n"
+
+            build_indicator = "[yellow](build)[/yellow] " if is_build else ""
+            curr_logs += f"[green]✓[/green] {build_indicator}{self.name} (Cost: [green]${this_op_cost:.2f}[/green])\n"
+            self.runner.console.log(
+                f"[green]✓[/green] {build_indicator}{self.name} (Cost: [green]${this_op_cost:.2f}[/green])"
+            )
 
             # Save selectivity estimate
             output_size = len(output_data)
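
For reference, the sample-size normalization added in the first hunk above can be exercised standalone; this sketch mirrors the diff's logic exactly, with invented sizes:

def normalize_sample_size(op_type, needed):
    # Equijoins need one sample size per side; other ops get a singleton list.
    if op_type == "equijoin":
        if isinstance(needed, dict):
            return [needed["left"], needed["right"]]
        return [needed, needed]
    return [needed]

assert normalize_sample_size("equijoin", {"left": 50, "right": 100}) == [50, 100]
assert normalize_sample_size("equijoin", 25) == [25, 25]
assert normalize_sample_size("map", 10) == [10]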
2 changes: 1 addition & 1 deletion docetl/operations/base.py
@@ -61,7 +61,7 @@ def __init__(
         self.manually_fix_errors = self.config.get("manually_fix_errors", False)
         self.status = status
         self.num_retries_on_validate_failure = self.config.get(
-            "num_retries_on_validate_failure", 0
+            "num_retries_on_validate_failure", 2
         )
         self.is_build = is_build
         self.syntax_check()
4 changes: 4 additions & 0 deletions docetl/operations/filter.py
@@ -113,4 +113,8 @@ def execute(
         if not is_build:
             results = [result for result in results if result[filter_key]]
 
+        # Drop the filter_key from the results
+        for result in results:
+            result.pop(filter_key, None)
+
         return results, total_cost
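
The effect is that the operation's internal boolean flag no longer leaks into downstream documents. A small illustration with a hypothetical key name (in practice filter_key comes from the operation's configured schema):

results = [{"text": "keep me", "_passed_filter": True}]
filter_key = "_passed_filter"  # hypothetical key name for illustration
for result in results:
    result.pop(filter_key, None)
print(results)  # [{'text': 'keep me'}]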
6 changes: 6 additions & 0 deletions docetl/operations/map.py
@@ -177,6 +177,12 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
                 if isinstance(response, ModelResponse)
                 else response
             )
+
+            # Check that the output has all the keys in the schema
+            for key in self.config["output"]["schema"]:
+                if key not in output:
+                    return output, False
+
             for key, value in item.items():
                 if key not in self.config["output"]["schema"]:
                     output[key] = value
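
This makes validation fail as soon as the LLM response is missing any declared schema key, letting the retry path (whose default count was raised to 2 in base.py above) kick in. A minimal standalone equivalent of the new check, with invented data:

schema = {"summary": "str", "sentiment": "str"}  # invented output schema
output = {"summary": "A short recap."}           # invented LLM output, one key missing
is_valid = all(key in output for key in schema)
print(is_valid)  # False -> the operation retries the call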
61 changes: 40 additions & 21 deletions docetl/operations/resolve.py
@@ -27,7 +27,7 @@ class ResolveOperation(BaseOperation):
     class schema(BaseOperation.schema):
         type: str = "resolve"
         comparison_prompt: str
-        resolution_prompt: str
+        resolution_prompt: Optional[str] = None
         output: Optional[Dict[str, Any]] = None
         embedding_model: Optional[str] = None
         resolution_model: Optional[str] = None
@@ -119,14 +119,16 @@ def syntax_check(self) -> None:
                     f"Missing required key '{key}' in ResolveOperation configuration"
                 )
 
-        if "schema" not in self.config["output"]:
+        if "schema" not in self.config["output"] and not self.runner._from_df_accessors:
             raise ValueError("Missing 'schema' in 'output' configuration")
-
-        if not isinstance(self.config["output"]["schema"], dict):
-            raise TypeError("'schema' in 'output' configuration must be a dictionary")
-
-        if not self.config["output"]["schema"]:
-            raise ValueError("'schema' in 'output' configuration cannot be empty")
+        elif not self.runner._from_df_accessors:
+            if not isinstance(self.config["output"]["schema"], dict):
+                raise TypeError(
+                    "'schema' in 'output' configuration must be a dictionary"
+                )
+
+            if not self.config["output"]["schema"]:
+                raise ValueError("'schema' in 'output' configuration cannot be empty")
 
         # Check if the comparison_prompt is a valid Jinja2 template
         try:
@@ -140,7 +142,7 @@ def syntax_check(self) -> None:
             or "input2" not in comparison_var_names
         ):
             raise ValueError(
-                "'comparison_prompt' must contain both 'input1' and 'input2' variables"
+                f"'comparison_prompt' must contain both 'input1' and 'input2' variables. {self.config['comparison_prompt']}"
             )
 
         if "resolution_prompt" in self.config:
@@ -674,19 +676,36 @@ def process_cluster(cluster):
             f"Number of distinct keys after resolution: {num_clusters_after}"
         )
 
-        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
-            futures = [
-                executor.submit(process_cluster, cluster) for cluster in final_clusters
-            ]
-            for future in rich_as_completed(
-                futures,
-                total=len(futures),
-                desc="Determining resolved key for each group of equivalent keys",
-                console=self.console,
-            ):
-                cluster_results, cluster_cost = future.result()
-                results.extend(cluster_results)
-                total_cost += cluster_cost
+        # If no resolution prompt is provided, we can skip the resolution phase
+        # and simply select the most common value for each key
+        if not self.config.get("resolution_prompt", None):
+            for cluster in final_clusters:
+                if len(cluster) > 1:
+                    for key in self.config["output"]["keys"]:
+                        most_common_value = max(
+                            set(input_data[i][key] for i in cluster),
+                            key=lambda x: sum(
+                                1 for i in cluster if input_data[i][key] == x
+                            ),
+                        )
+                        for i in cluster:
+                            input_data[i][key] = most_common_value
+            results = input_data
+        else:
+            with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
+                futures = [
+                    executor.submit(process_cluster, cluster)
+                    for cluster in final_clusters
+                ]
+                for future in rich_as_completed(
+                    futures,
+                    total=len(futures),
+                    desc="Determining resolved key for each group of equivalent keys",
+                    console=self.console,
+                ):
+                    cluster_results, cluster_cost = future.result()
+                    results.extend(cluster_results)
+                    total_cost += cluster_cost
 
         total_pairs = len(input_data) * (len(input_data) - 1) // 2
         true_match_count = sum(
40 changes: 36 additions & 4 deletions docetl/operations/utils/api.py
@@ -6,7 +6,9 @@
 
 from litellm import ModelResponse, RateLimitError, completion, embedding
 from rich import print as rprint
-from rich.console import Console
+from rich.console import Console, Group
+from rich.panel import Panel
+from rich.text import Text
 
 from docetl.utils import completion_cost
 
@@ -403,7 +405,7 @@ def call_llm(
         rate_limited_attempt = 0
         while attempt <= max_retries:
             try:
-                return timeout(timeout_seconds)(self._cached_call_llm)(
+                output = timeout(timeout_seconds)(self._cached_call_llm)(
                     key,
                     model,
                     op_type,
@@ -418,6 +420,31 @@ def call_llm(
                     initial_result=initial_result,
                     litellm_completion_kwargs=litellm_completion_kwargs,
                 )
+                # Log input and output if verbose
+                if verbose:
+                    # Truncate messages to 500 chars
+                    messages_str = str(messages)
+                    truncated_messages = (
+                        messages_str[:500] + "..."
+                        if len(messages_str) > 500
+                        else messages_str
+                    )
+
+                    # Log with nice formatting
+                    self.runner.console.print(
+                        Panel(
+                            Group(
+                                Text("Input:", style="bold cyan"),
+                                Text(truncated_messages),
+                                Text("\nOutput:", style="bold cyan"),
+                                Text(str(output)),
+                            ),
+                            title="[bold green]LLM Call Details[/bold green]",
+                            border_style="green",
+                        )
+                    )
+
+                return output
             except RateLimitError:
                 # TODO: this is a really hacky way to handle rate limits
                 # we should implement a more robust retry mechanism
@@ -479,7 +506,7 @@ def _call_llm_with_cache(
             len(props) == 1
             and list(props.values())[0].get("type") == "string"
             and scratchpad is None
-            and ("ollama" in model or "sagemaker" in model)
+            and ("sagemaker" in model)
         ):
             use_tools = False
 
@@ -740,7 +767,12 @@ def _parse_llm_response_helper(
 
         try:
             output_dict = json.loads(tool_call.function.arguments)
-            if "ollama" in response.model:
+            # Augment output_dict with empty values for any keys in the schema that are not in output_dict
+            for key in schema:
+                if key not in output_dict:
+                    output_dict[key] = "Not found"
+
+            if "ollama" in response.model or "sagemaker" in response.model:
                 for key, value in output_dict.items():
                     if not isinstance(value, str):
                         continue
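
The verbose branch composes rich's Panel, Group, and Text primitives, all imported at the top of this file; this standalone snippet reproduces the same rendering with invented message content:

from rich.console import Console, Group
from rich.panel import Panel
from rich.text import Text

console = Console()
console.print(
    Panel(
        Group(
            Text("Input:", style="bold cyan"),
            Text("[{'role': 'user', 'content': 'Summarize this document...'}]"),
            Text("\nOutput:", style="bold cyan"),
            Text("{'summary': 'A short recap.'}"),
        ),
        title="[bold green]LLM Call Details[/bold green]",
        border_style="green",
    )
)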
5 changes: 3 additions & 2 deletions docetl/optimizer.py
@@ -119,7 +119,8 @@ def __init__(
         if self.config.get("optimizer_config", {}).get("sample_sizes", {}):
             self.sample_size_map.update(self.config["optimizer_config"]["sample_sizes"])
 
-        self.print_optimizer_config()
+        if not self.runner._from_df_accessors:
+            self.print_optimizer_config()
 
     def print_optimizer_config(self):
         """
@@ -372,7 +373,7 @@ def should_optimize(
         elif node_of_interest.config.get("type") == "reduce":
             reduce_optimizer = ReduceOptimizer(
                 self.runner,
-                self._run_operation,
+                self.runner._run_operation,
             )
             should_optimize_output, input_data, output_data = (
                 reduce_optimizer.should_optimize(node_of_interest.config, input_data[0])