add metadata into text_chunks #1671

Closed
wants to merge 16 commits
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20250129165958552473.json
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "add metadata to chunk texts when indexing"
}
6 changes: 1 addition & 5 deletions docs/config/env_vars.md
@@ -105,12 +105,8 @@ These settings control the data input used by the pipeline. Any settings with a
| ------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | -------------------- | ---------- |
| `GRAPHRAG_INPUT_TYPE` | The input storage type to use when reading files. (`file` or `blob`) | `str` | optional | `file` |
| `GRAPHRAG_INPUT_FILE_PATTERN` | The file pattern regexp to use when reading input files from the input directory. | `str` | optional | `.*\.txt$` |
| `GRAPHRAG_INPUT_SOURCE_COLUMN` | The 'source' column to use when reading CSV input files. | `str` | optional | `source` |
| `GRAPHRAG_INPUT_TIMESTAMP_COLUMN` | The 'timestamp' column to use when reading CSV input files. | `str` | optional | `None` |
| `GRAPHRAG_INPUT_TIMESTAMP_FORMAT` | The timestamp format to use when parsing timestamps in the timestamp column. | `str` | optional | `None` |
| `GRAPHRAG_INPUT_TEXT_COLUMN` | The 'text' column to use when reading CSV input files. | `str` | optional | `text` |
| `GRAPHRAG_INPUT_DOCUMENT_ATTRIBUTE_COLUMNS` | A list of CSV columns, comma-separated, to incorporate as document fields. | `str` | optional | `id` |
| `GRAPHRAG_INPUT_TITLE_COLUMN` | The 'title' column to use when reading CSV input files. | `str` | optional | `title` |
| `GRAPHRAG_INPUT_DOCUMENT_METADATA` | A list of CSV columns, comma-separated, to incorporate as metadata with each chunk. | `str` | optional | `None` |
| `GRAPHRAG_INPUT_STORAGE_ACCOUNT_BLOB_URL` | The Azure Storage blob endpoint to use when in `blob` mode and using managed identity. Will have the format `https://<storage_account_name>.blob.core.windows.net` | `str` | optional | `None` |
| `GRAPHRAG_INPUT_CONNECTION_STRING` | The connection string to use when reading CSV input files from Azure Blob Storage. | `str` | optional | `None` |
| `GRAPHRAG_INPUT_CONTAINER_NAME` | The container name to use when reading CSV input files from Azure Blob Storage. | `str` | optional | `None` |
6 changes: 1 addition & 5 deletions docs/config/yaml.md
@@ -92,12 +92,8 @@ This is the base LLM configuration section. Other steps may override this config
- `file_encoding` **str** - The encoding of the input file. Default is `utf-8`
- `file_pattern` **str** - A regex to match input files. Default is `.*\.csv$` if in csv mode and `.*\.txt$` if in text mode.
- `file_filter` **dict** - Key/value pairs to filter. Default is None.
- `source_column` **str** - (CSV Mode Only) The source column name.
- `timestamp_column` **str** - (CSV Mode Only) The timestamp column name.
- `timestamp_format` **str** - (CSV Mode Only) The source format.
- `text_column` **str** - (CSV Mode Only) The text column name.
- `title_column` **str** - (CSV Mode Only) The title column name.
- `document_attribute_columns` **list[str]** - (CSV Mode Only) The additional document attributes to include.
- `metadata` **list[str]** - The additional document attributes to include when chunking.

### chunks

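For illustration, a minimal sketch of the new setting from the docs above: `GRAPHRAG_INPUT_DOCUMENT_METADATA` takes a comma-separated column list, mirroring the `metadata` key described in yaml.md. The column names below are illustrative assumptions, not defaults.

```python
# Sketch: enabling chunk metadata via the new environment variable.
# "title" and "source" are illustrative column names.
import os

os.environ["GRAPHRAG_INPUT_DOCUMENT_METADATA"] = "title,source"
```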
16 changes: 2 additions & 14 deletions graphrag/config/models/input_config.py
@@ -40,21 +40,9 @@ class InputConfig(BaseModel):
file_filter: dict[str, str] | None = Field(
description="The optional file filter for the input files.", default=None
)
source_column: str | None = Field(
description="The input source column to use.", default=None
)
timestamp_column: str | None = Field(
description="The input timestamp column to use.", default=None
)
timestamp_format: str | None = Field(
description="The input timestamp format to use.", default=None
)
text_column: str = Field(
description="The input text column to use.", default=defs.INPUT_TEXT_COLUMN
)
title_column: str | None = Field(
description="The input title column to use.", default=None
)
document_attribute_columns: list[str] = Field(
description="The document attribute columns to use.", default=[]
metadata: list[str] | None = Field(
description="The document metadata to use with each chunk.", default=None
)
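For illustration, a minimal sketch of the updated model in use, assuming the remaining `InputConfig` fields keep their defaults (the column names are made up):

```python
# Sketch: the CSV-specific column options collapse into a single optional
# metadata list on InputConfig (column names are illustrative).
from graphrag.config.models.input_config import InputConfig

config = InputConfig(
    text_column="text",            # unchanged: which column holds the document text
    metadata=["title", "source"],  # new: columns to carry into every chunk
)
print(config.metadata)  # ['title', 'source']
```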
15 changes: 13 additions & 2 deletions graphrag/index/flows/create_base_text_units.py
@@ -22,6 +22,8 @@ def create_base_text_units(
overlap: int,
encoding_model: str,
strategy: ChunkStrategyType,
metadata: list[str] | None,
line_delimiter: str = ".\n",
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
sort = documents.sort_values(by=["id"], ascending=[True])
@@ -32,15 +34,22 @@

callbacks.progress(Progress(percent=0))

agg_dict = {"text_with_ids": list}
if metadata:
for meta in metadata:
agg_dict[meta] = "first" # type: ignore

aggregated = (
(
sort.groupby(group_by_columns, sort=False)
if len(group_by_columns) > 0
else sort.groupby(lambda _x: True)
)
.agg(texts=("text_with_ids", list))
.agg(agg_dict)
.reset_index()
)
) # metadata columns are carried into the aggregate via agg_dict above

aggregated.rename(columns={"text_with_ids": "texts"}, inplace=True)

aggregated["chunks"] = chunk_text(
aggregated,
@@ -50,6 +59,8 @@
encoding_model=encoding_model,
strategy=strategy,
callbacks=callbacks,
metadata=metadata,
line_delimiter=line_delimiter,
)

aggregated = cast("pd.DataFrame", aggregated[[*group_by_columns, "chunks"]])
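A standalone sketch of what the aggregation above produces: text rows are collected into a list per group while each metadata column keeps its first value (data and column names are illustrative):

```python
# Sketch of the agg_dict behavior in create_base_text_units, on toy data.
import pandas as pd

docs = pd.DataFrame({
    "id": ["d1", "d1", "d2"],
    "text_with_ids": ["chunk a", "chunk b", "chunk c"],
    "title": ["Doc One", "Doc One", "Doc Two"],  # a metadata column
})

metadata = ["title"]
agg_dict = {"text_with_ids": list}
for meta in metadata:
    agg_dict[meta] = "first"  # one metadata value per group

aggregated = docs.groupby("id", sort=False).agg(agg_dict).reset_index()
aggregated.rename(columns={"text_with_ids": "texts"}, inplace=True)
# One row per id: "texts" is a list of the grouped rows, "title" is the first
# value seen in the group, ready to be passed into chunk_text.
print(aggregated)
```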
24 changes: 5 additions & 19 deletions graphrag/index/flows/create_final_documents.py
@@ -9,7 +9,7 @@
def create_final_documents(
documents: pd.DataFrame,
text_units: pd.DataFrame,
document_attribute_columns: list[str] | None = None,
metadata: list[str] | None,
) -> pd.DataFrame:
"""All the steps to transform final documents."""
exploded = (
@@ -47,30 +47,16 @@ def create_final_documents(
rejoined["human_readable_id"] = rejoined.index + 1

# Convert attribute columns to strings and collapse them into a JSON object
if document_attribute_columns:
# Convert all specified columns to string at once
rejoined[document_attribute_columns] = rejoined[
document_attribute_columns
].astype(str)

# Collapse the document_attribute_columns into a single JSON object column
rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
orient="records"
)

# Drop the original attribute columns after collapsing them
rejoined.drop(columns=document_attribute_columns, inplace=True)

# set the final column order, but adjust for attributes
core_columns = [
"id",
"human_readable_id",
"title",
"text",
"text_unit_ids",
]
if metadata:
core_columns.extend(metadata)

final_columns = [column for column in core_columns if column in rejoined.columns]
if document_attribute_columns:
final_columns.append("attributes")

return rejoined.loc[:, final_columns]
return rejoined.loc[:, list(set(final_columns))]
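One side note on the last line above: `set()` makes the final column order nondeterministic. If a stable order matters downstream, an order-preserving dedup is a possible alternative; this is only a sketch, not part of the diff:

```python
# Sketch: dict.fromkeys removes duplicates while keeping first-seen order,
# unlike set(), whose iteration order is arbitrary.
final_columns = ["id", "human_readable_id", "title", "text", "text_unit_ids", "title"]
print(list(dict.fromkeys(final_columns)))
# ['id', 'human_readable_id', 'title', 'text', 'text_unit_ids']
```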
61 changes: 12 additions & 49 deletions graphrag/index/input/csv.py
@@ -41,15 +41,6 @@ async def load_file(path: str, group: dict | None) -> pd.DataFrame:
)
if "id" not in data.columns:
data["id"] = data.apply(lambda x: gen_sha512_hash(x, x.keys()), axis=1)
if config.source_column is not None and "source" not in data.columns:
if config.source_column not in data.columns:
log.warning(
"source_column %s not found in csv file %s",
config.source_column,
path,
)
else:
data["source"] = data.apply(lambda x: x[config.source_column], axis=1)
if config.text_column is not None and "text" not in data.columns:
if config.text_column not in data.columns:
log.warning(
@@ -59,47 +50,19 @@
)
else:
data["text"] = data.apply(lambda x: x[config.text_column], axis=1)
if config.title_column is not None and "title" not in data.columns:
if config.title_column not in data.columns:
log.warning(
"title_column %s not found in csv file %s",
config.title_column,
path,
)
else:
data["title"] = data.apply(lambda x: x[config.title_column], axis=1)

if config.timestamp_column is not None:
fmt = config.timestamp_format
if fmt is None:
msg = "Must specify timestamp_format if timestamp_column is specified"
raise ValueError(msg)

if config.timestamp_column not in data.columns:
log.warning(
"timestamp_column %s not found in csv file %s",
config.timestamp_column,
path,
)
else:
data["timestamp"] = pd.to_datetime(
data[config.timestamp_column], format=fmt
)

# TODO: Theres probably a less gross way to do this
if "year" not in data.columns:
data["year"] = data.apply(lambda x: x["timestamp"].year, axis=1)
if "month" not in data.columns:
data["month"] = data.apply(lambda x: x["timestamp"].month, axis=1)
if "day" not in data.columns:
data["day"] = data.apply(lambda x: x["timestamp"].day, axis=1)
if "hour" not in data.columns:
data["hour"] = data.apply(lambda x: x["timestamp"].hour, axis=1)
if "minute" not in data.columns:
data["minute"] = data.apply(lambda x: x["timestamp"].minute, axis=1)
if "second" not in data.columns:
data["second"] = data.apply(lambda x: x["timestamp"].second, axis=1)

if config.metadata is not None:
for metadata in config.metadata:
if metadata not in data.columns:
log.warning(
"metadata column %s not found in csv file %s",
metadata,
path,
)
else:
data[metadata] = data.apply(
lambda x, metadata=metadata: x[metadata], axis=1
)
return data

file_pattern = (
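A self-contained sketch of the check the new loop above performs: warn when a configured metadata column is missing from the CSV, and leave present columns untouched (file name and data are made up):

```python
# Sketch of the metadata-column check in the CSV loader, on toy data.
import logging

import pandas as pd

log = logging.getLogger(__name__)

data = pd.DataFrame({"text": ["hello"], "author": ["jdoe"]})
metadata_columns = ["author", "year"]  # "year" is deliberately missing

for metadata in metadata_columns:
    if metadata not in data.columns:
        log.warning("metadata column %s not found in csv file %s", metadata, "input.csv")
    else:
        data[metadata] = data[metadata]  # present columns pass through unchanged
```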
9 changes: 7 additions & 2 deletions graphrag/index/input/text.py
@@ -30,14 +30,19 @@ async def load(
"""Load text inputs from a directory."""

async def load_file(
path: str, group: dict | None = None, _encoding: str = "utf-8"
path: str,
group: dict | None = None,
_encoding: str = "utf-8",
metadata: list[str] | None = None,
) -> dict[str, Any]:
if group is None:
group = {}
text = await storage.get(path, encoding="utf-8")
new_item = {**group, "text": text}
new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
new_item["title"] = str(Path(path).name)
if metadata and "creation_date" in metadata:
new_item["creation_date"] = storage.get_creation_date(path)
return new_item

files = list(
@@ -57,7 +62,7 @@ async def load_file(

for file, group in files:
try:
files_loaded.append(await load_file(file, group))
files_loaded.append(await load_file(file, group, metadata=config.metadata))
except Exception: # noqa: BLE001 (catching Exception is fine here)
log.warning("Warning! Error loading file %s. Skipping...", file)

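A sketch of the record the text loader builds when `creation_date` is requested; the hash and storage calls from the real loader are replaced by stand-ins here:

```python
# Sketch of the document record built by the text loader (stand-in values).
from pathlib import Path


def build_item(path: str, text: str, creation_date: str, metadata: list[str] | None) -> dict:
    item = {"text": text, "title": Path(path).name}
    if metadata and "creation_date" in metadata:
        # In the real loader this value comes from storage.get_creation_date(path).
        item["creation_date"] = creation_date
    return item


print(build_item("input/a.txt", "hello world", "2025-01-29", ["creation_date"]))
# {'text': 'hello world', 'title': 'a.txt', 'creation_date': '2025-01-29'}
```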
20 changes: 17 additions & 3 deletions graphrag/index/operations/chunk_text/chunk_text.py
@@ -24,6 +24,8 @@ def chunk_text(
encoding_model: str,
strategy: ChunkStrategyType,
callbacks: WorkflowCallbacks,
metadata: list[str] | None = None,
line_delimiter: str = ".\n",
) -> pd.Series:
"""
Chunk a piece of text into smaller pieces.
Expand Down Expand Up @@ -65,7 +67,14 @@ def chunk_text(
input.apply(
cast(
"Any",
lambda x: run_strategy(strategy_exec, x[column], config, tick),
lambda x: run_strategy(
strategy_exec,
x[column],
config,
tick,
{v: x[v] for v in metadata or []},
line_delimiter,
),
),
axis=1,
),
@@ -77,10 +86,15 @@ def run_strategy(
input: ChunkInput,
config: ChunkingConfig,
tick: ProgressTicker,
metadata: dict[str, Any] | None = None,
line_delimiter: str = ".\n",
) -> list[str | tuple[list[str] | None, str, int]]:
"""Run strategy method definition."""
if isinstance(input, str):
return [item.text_chunk for item in strategy_exec([input], config, tick)]
return [
item.text_chunk
for item in strategy_exec([input], config, tick, metadata, line_delimiter)
]

# We can work with both just a list of text content
# or a list of tuples of (document_id, text content)
Expand All @@ -92,7 +106,7 @@ def run_strategy(
else:
texts.append(item[1])

strategy_results = strategy_exec(texts, config, tick)
strategy_results = strategy_exec(texts, config, tick, metadata, line_delimiter)

results = []
for strategy_result in strategy_results:
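For illustration, the per-row dictionary that `run_strategy` now receives is simply the configured metadata columns pulled off the row (column names and values are made up):

```python
# Sketch of the {v: x[v] for v in metadata or []} expression on a toy row.
import pandas as pd

row = pd.Series({"texts": ["some text"], "title": "Doc One", "source": "web"})
metadata = ["title", "source"]

print({v: row[v] for v in metadata or []})
# {'title': 'Doc One', 'source': 'web'}
```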
25 changes: 23 additions & 2 deletions graphrag/index/operations/chunk_text/strategies.py
@@ -4,6 +4,7 @@
"""A module containing chunk strategies."""

from collections.abc import Iterable
from typing import Any

import nltk
import tiktoken
@@ -18,7 +19,11 @@


def run_tokens(
input: list[str], config: ChunkingConfig, tick: ProgressTicker
input: list[str],
config: ChunkingConfig,
tick: ProgressTicker,
metadata: dict[str, Any] | None = None,
line_delimiter: str = ".\n",
) -> Iterable[TextChunk]:
"""Chunks text into chunks based on encoding tokens."""
tokens_per_chunk = config.size
@@ -43,16 +48,32 @@ def decode(tokens: list[int]) -> str:
decode=decode,
),
tick,
line_delimiter,
metadata,
)


def run_sentences(
input: list[str], _config: ChunkingConfig, tick: ProgressTicker
input: list[str],
_config: ChunkingConfig,
tick: ProgressTicker,
metadata: dict[str, Any] | None = None,
line_delimiter: str = ".\n",
) -> Iterable[TextChunk]:
"""Chunks text into multiple parts by sentence."""
for doc_idx, text in enumerate(input):
sentences = nltk.sent_tokenize(text)
for sentence in sentences:
metadata_str = ""
if metadata is not None and len(metadata) > 0:
metadata_str = line_delimiter.join([
f"{k}: {v}" for k, v in metadata.items()
])
sentence = (
f"{metadata_str}{line_delimiter}{sentence}"
if metadata_str
else sentence
)
yield TextChunk(
text_chunk=sentence,
source_doc_indices=[doc_idx],
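A standalone sketch of the prefixing that `run_sentences` now applies: the metadata pairs are joined with the line delimiter and prepended to each chunk (values are illustrative):

```python
# Sketch of the metadata prefix built in run_sentences, with the default ".\n" delimiter.
metadata = {"title": "Doc One", "source": "web"}
line_delimiter = ".\n"
sentence = "GraphRAG builds a knowledge graph from text."

metadata_str = line_delimiter.join(f"{k}: {v}" for k, v in metadata.items())
chunk = f"{metadata_str}{line_delimiter}{sentence}" if metadata_str else sentence
print(chunk)
# title: Doc One.
# source: web.
# GraphRAG builds a knowledge graph from text.
```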
4 changes: 3 additions & 1 deletion graphrag/index/operations/chunk_text/typing.py
@@ -5,6 +5,7 @@

from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any

from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.logger.progress import ProgressTicker
@@ -23,5 +24,6 @@ class TextChunk:
"""Input to a chunking strategy. Can be a string, a list of strings, or a list of tuples of (id, text)."""

ChunkStrategy = Callable[
[list[str], ChunkingConfig, ProgressTicker], Iterable[TextChunk]
[list[str], ChunkingConfig, ProgressTicker, dict[str, Any] | None, str],
Iterable[TextChunk],
]
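Because the `ChunkStrategy` callable now also takes the metadata dict and line delimiter, custom strategies need the wider signature. A hedged sketch of one that treats each document as a single chunk; the ticker and prefix handling here are illustrative, not part of the diff:

```python
# Sketch of a custom strategy matching the widened ChunkStrategy signature.
from collections.abc import Iterable
from typing import Any

from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.logger.progress import ProgressTicker


def run_whole_documents(
    input: list[str],
    _config: ChunkingConfig,
    tick: ProgressTicker,
    metadata: dict[str, Any] | None = None,
    line_delimiter: str = ".\n",
) -> Iterable[TextChunk]:
    """Treat each document as one chunk, prefixed with any metadata."""
    prefix = line_delimiter.join(f"{k}: {v}" for k, v in (metadata or {}).items())
    for doc_idx, text in enumerate(input):
        tick(1)  # assumed: the ticker is called once per processed document
        chunk = f"{prefix}{line_delimiter}{text}" if prefix else text
        yield TextChunk(text_chunk=chunk, source_doc_indices=[doc_idx])
```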