Support header hierarchy in splitting
shreyashankar committed Aug 27, 2024
1 parent 9e36175 commit e5729aa
Showing 11 changed files with 639 additions and 167 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -384,7 +384,7 @@ gather_operation:

Notes:

- The gather operation adds a new field to each item: {content_key}\_formatted, which contains the formatted chunk with added context.
- The gather operation adds a new field to each item: {content_key}\_rendered, which contains the formatted chunk with added context.
- The formatted content includes labels for previous context, main chunk, and next context.
- Skipped chunks are indicated with a "[... X characters skipped ...]" message.
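For illustration, a rendered chunk might look like the following (a hypothetical sketch; the context labels match the defaults in `gather.py`, and the header line appears only when `doc_header_keys` is configured):

```
--- Previous Context ---
[Chunk 3 (Summary)]
...summary of chunk 3...
[... 1240 characters skipped ...]
--- End Previous Context ---

_Current Header Hierarchy:_ # Guide > ## Setup
--- Begin Main Chunk ---
...full text of chunk 4...
--- End Main Chunk ---

--- Next Context ---
[Chunk 5]
...start of chunk 5...
```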

91 changes: 78 additions & 13 deletions motion/operations/gather.py
@@ -11,7 +11,8 @@ class GatherOperation(BaseOperation):
1. Group chunks by their document ID.
2. Order chunks within each group.
3. Add peripheral context to each chunk based on the configuration.
4. Return results containing the formatted chunks with added context, including information about skipped characters.
4. Include headers for each chunk and its upward hierarchy.
5. Return results containing the rendered chunks with added context, including information about skipped characters and headers.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -83,6 +84,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
"main_chunk_start", "--- Begin Main Chunk ---"
)
main_chunk_end = self.config.get("main_chunk_end", "--- End Main Chunk ---")
doc_header_keys = self.config.get("doc_header_keys", [])
results = []
cost = 0.0

@@ -99,25 +101,26 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
# Sort chunks by their order within the document
chunks.sort(key=lambda x: x[order_key])

# Process each chunk with its peripheral context
# Process each chunk with its peripheral context and headers
for i, chunk in enumerate(chunks):
formatted_chunk = self.format_chunk_with_context(
rendered_chunk = self.render_chunk_with_context(
chunks,
i,
peripheral_config,
content_key,
order_key,
main_chunk_start,
main_chunk_end,
doc_header_keys,
)

result = chunk.copy()
result[f"{content_key}_formatted"] = formatted_chunk
result[f"{content_key}_rendered"] = rendered_chunk
results.append(result)

return results, cost

def format_chunk_with_context(
def render_chunk_with_context(
self,
chunks: List[Dict],
current_index: int,
@@ -126,9 +129,10 @@ def format_chunk_with_context(
order_key: str,
main_chunk_start: str,
main_chunk_end: str,
doc_header_keys: List[Dict[str, Any]],
) -> str:
"""
Format a chunk with its peripheral context.
Render a chunk with its peripheral context and headers.
Args:
chunks (List[Dict]): List of all chunks in the document.
@@ -138,9 +142,10 @@ def format_chunk_with_context(
order_key (str): Key for the order of each chunk.
main_chunk_start (str): String to mark the start of the main chunk.
main_chunk_end (str): String to mark the end of the main chunk.
doc_header_keys (str): Key under which each chunk stores its extracted headers (a list of dicts with 'header' and 'level' entries).
Returns:
str: Formatted chunk with context.
str: Rendered chunk with context and headers.
"""
combined_parts = []

@@ -152,16 +157,20 @@ def format_chunk_with_context(
peripheral_config.get("previous", {}),
content_key,
order_key,
reverse=True,
)
)
combined_parts.append("--- End Previous Context ---\n")

# Process main chunk
main_chunk = chunks[current_index]
combined_parts.append(
f"{main_chunk_start}\n{main_chunk[content_key]}\n{main_chunk_end}"
headers = self.render_hierarchy_headers(
main_chunk, chunks[: current_index + 1], doc_header_keys
)
if headers:
combined_parts.append(headers)
combined_parts.append(f"{main_chunk_start}")
combined_parts.append(f"{main_chunk[content_key]}")
combined_parts.append(f"{main_chunk_end}")

# Process next chunks
combined_parts.append("\n--- Next Context ---")
@@ -222,9 +231,9 @@ def process_peripheral_chunks(
section = "middle"
else:
# Show number of characters skipped
skipped_chars = sum(len(c[content_key]) for c in chunks)
skipped_chars = len(chunk[content_key])
if not in_skip:
skip_char_count += skipped_chars
skip_char_count = skipped_chars
in_skip = True
else:
skip_char_count += skipped_chars
@@ -245,7 +254,8 @@ def process_peripheral_chunks(
summary_suffix = " (Summary)" if is_summary else ""

chunk_prefix = f"[Chunk {chunk[order_key]}{summary_suffix}]"
processed_parts.append(f"{chunk_prefix} {chunk[section_content_key]}")
processed_parts.append(chunk_prefix)
processed_parts.append(f"{chunk[section_content_key]}")
included_chunks.append(chunk)

if in_skip:
@@ -255,3 +265,58 @@ def process_peripheral_chunks(
processed_parts = list(reversed(processed_parts))

return processed_parts

def render_hierarchy_headers(
self,
current_chunk: Dict,
chunks: List[Dict],
doc_header_keys: List[Dict[str, Any]],
) -> str:
"""
Render headers for the current chunk's hierarchy.
Args:
current_chunk (Dict): The current chunk being processed.
chunks (List[Dict]): List of chunks up to and including the current chunk.
doc_header_keys (str): Key under which each chunk stores its extracted headers (a list of dicts with 'header' and 'level' entries).
Returns:
str: Rendered headers in the current chunk's hierarchy.
"""
if not doc_header_keys:
# No header key is configured for this document, so there is nothing to render.
return ""

rendered_headers = []
current_hierarchy = {}

# Find the highest-level (smallest level number) header in the current chunk
current_chunk_headers = current_chunk.get(doc_header_keys, [])
highest_level = float("inf") # Initialize with positive infinity
for header_info in current_chunk_headers:
level = header_info.get("level")
if level is not None and level < highest_level:
highest_level = level

# If no headers found in the current chunk, set highest_level to None
if highest_level == float("inf"):
highest_level = None

for chunk in chunks:
for header_info in chunk.get(doc_header_keys, []):
header = header_info["header"]
level = header_info["level"]
if header and level:
current_hierarchy[level] = header
# Clear deeper levels whenever a higher-level header is encountered
for existing_level in list(current_hierarchy.keys()):
if existing_level > level:
current_hierarchy[existing_level] = None

# Render every header above the current chunk's highest level; if the current chunk has no headers of its own, render the entire hierarchy
for level, header in sorted(current_hierarchy.items()):
if header is not None and (highest_level is None or level < highest_level):
rendered_headers.append(f"{'#' * level} {header}")

rendered_headers = " > ".join(rendered_headers)
return (
f"_Current Header Hierarchy:_ {rendered_headers}"
if rendered_headers
else ""
)
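A minimal usage sketch of the hierarchy rendering (hypothetical chunk data; assumes `doc_header_keys` is the `"headers"` key, as wired up in `operation_creators.py` below):

```python
# Two chunks: chunk 1 introduced a level-1 header; chunk 2, the current
# chunk, sits under a level-2 header.
chunks = [
    {"headers": [{"header": "Guide", "level": 1}]},
    {"headers": [{"header": "Setup", "level": 2}]},
]

# The current chunk's highest level is 2, so only headers above level 2 render:
# op.render_hierarchy_headers(chunks[1], chunks, "headers")
# -> '_Current Header Hierarchy:_ # Guide'
```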
44 changes: 42 additions & 2 deletions motion/optimizers/map_optimizer/operation_creators.py
@@ -50,6 +50,8 @@ def create_split_map_gather_operations(
content_key: str,
summary_prompt: Optional[str] = None,
summary_model: Optional[str] = None,
header_extraction_prompt: Optional[str] = "",
header_output_schema: Optional[Dict[str, Any]] = {},
) -> List[Dict[str, Any]]:
pipeline = []
chunk_size = int(chunk_info["chunk_size"] * 1.5)
@@ -63,8 +65,45 @@ def create_split_map_gather_operations(
}
pipeline.append(split_config)

# If there's a summary prompt, create a map config
if summary_prompt:
if header_extraction_prompt and summary_prompt:
# Create parallel map for summary and header extraction
pmap_output_schema = {
"schema": {
f"{split_key}_summary": "string",
**header_output_schema,
}
}
parallel_map_config = {
"type": "parallel_map",
"name": f"parallel_map_{split_key}_{op_config['name']}",
"prompts": [
{
"name": f"header_extraction_{split_key}_{op_config['name']}",
"prompt": header_extraction_prompt,
"model": self.config["default_model"],
"output_keys": list(header_output_schema.keys()),
},
{
"name": f"summary_{split_key}_{op_config['name']}",
"prompt": summary_prompt,
"model": summary_model,
"output_keys": [f"{split_key}_summary"],
},
],
"output": pmap_output_schema,
}
pipeline.append(parallel_map_config)
elif header_extraction_prompt:
pipeline.append(
{
"type": "map",
"name": f"header_extraction_{split_key}_{op_config['name']}",
"prompt": header_extraction_prompt,
"model": self.config["default_model"],
"output": {"schema": header_output_schema},
}
)
elif summary_prompt:
pipeline.append(
{
"type": "map",
Expand All @@ -81,6 +120,7 @@ def create_split_map_gather_operations(
"content_key": content_key,
"doc_id_key": f"{split_name}_id",
"order_key": f"{split_name}_chunk_num",
"doc_header_keys": ("headers" if header_output_schema else []),
"peripheral_chunks": {},
}
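# Resulting pipeline order (sketch): split -> header-extraction and/or summary
# map step(s), when the corresponding prompts are provided -> gather. Here
# "doc_header_keys" points at the "headers" field that the header-extraction
# step is expected to write into each chunk.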

19 changes: 19 additions & 0 deletions motion/optimizers/map_optimizer/plan_generators.py
@@ -117,6 +117,21 @@ def determine_metadata_with_retry():
)
self.console.log(f"Reason: {metadata_info.get('reason', 'N/A')}")

# Create header extraction prompt
header_extraction_prompt, header_output_schema = (
self.prompt_generator._get_header_extraction_prompt(
op_config, input_data, split_key
)
)
if header_extraction_prompt:
self.console.log(
f"Inferring headers from the documents. Will apply this prompt to find headers in chunks: {header_extraction_prompt}"
)
else:
self.console.log(
"Not inferring headers from the documents. Will not apply any header extraction prompt."
)

# Create base operations
# TODO: try with and without metadata
base_operations = []
@@ -181,6 +196,8 @@ def determine_metadata_with_retry():
content_key,
info_extraction_prompt if peripheral_configs[-1][1] else None,
"gpt-4o-mini",
header_extraction_prompt,
header_output_schema,
)
map_op = self.operation_creator.create_map_operation(
op_config, split_result["subprompt"] + " Only process the main chunk."
@@ -236,6 +253,8 @@ def task():
content_key,
info_extraction_prompt if peripheral_config[1] else None,
"gpt-4o-mini",
header_extraction_prompt,
header_output_schema,
)
map_op = self.operation_creator.create_map_operation(
op_config,