Nlp cache (#1689)
* Add cache to build_noun_graph

* Semver
natoverse authored Feb 10, 2025
1 parent c02ab09 commit a6a78d5
Showing 10 changed files with 68 additions and 9 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250210180318886210.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add caching to NLP extractor."
}
4 changes: 4 additions & 0 deletions graphrag/config/models/extract_graph_nlp_config.py
@@ -64,3 +64,7 @@ class ExtractGraphNLPConfig(BaseModel):
text_analyzer: TextAnalyzerConfig = Field(
description="The text analyzer configuration.", default=TextAnalyzerConfig()
)
parallelization_num_threads: int = Field(
description="The number of threads to use for the extraction process.",
default=defs.PARALLELIZATION_NUM_THREADS,
)
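
The new field plugs into the existing Pydantic config model. A minimal sketch of overriding it, assuming the remaining ExtractGraphNLPConfig fields (such as text_analyzer) keep their defaults:

from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig

# Use 8 worker threads for noun-phrase extraction instead of the library default.
nlp_config = ExtractGraphNLPConfig(parallelization_num_threads=8)
print(nlp_config.parallelization_num_threads)  # 8
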
8 changes: 6 additions & 2 deletions graphrag/index/flows/extract_graph_nlp.py
@@ -5,6 +5,7 @@

import pandas as pd

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig
@@ -20,9 +21,10 @@
from graphrag.index.operations.prune_graph import prune_graph


def extract_graph_nlp(
async def extract_graph_nlp(
text_units: pd.DataFrame,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
extraction_config: ExtractGraphNLPConfig,
pruning_config: PruneGraphConfig,
embed_config: EmbedGraphConfig | None = None,
@@ -31,10 +33,12 @@ def extract_graph_nlp(
"""All the steps to create the base entity graph."""
text_analyzer_config = extraction_config.text_analyzer
text_analyzer = create_noun_phrase_extractor(text_analyzer_config)
extracted_nodes, extracted_edges = build_noun_graph(
extracted_nodes, extracted_edges = await build_noun_graph(
text_units,
text_analyzer=text_analyzer,
normalize_edge_weights=extraction_config.normalize_edge_weights,
num_threads=extraction_config.parallelization_num_threads,
cache=cache,
)

# create a temporary graph to prune, then turn it back into dataframes
38 changes: 33 additions & 5 deletions graphrag/index/operations/build_noun_graph/build_noun_graph.py
@@ -7,37 +7,65 @@

import pandas as pd

from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.enums import AsyncType
from graphrag.index.operations.build_noun_graph.np_extractors.base import (
BaseNounPhraseExtractor,
)
from graphrag.index.run.derive_from_rows import derive_from_rows
from graphrag.index.utils.hashing import gen_sha512_hash


def build_noun_graph(
async def build_noun_graph(
text_unit_df: pd.DataFrame,
text_analyzer: BaseNounPhraseExtractor,
normalize_edge_weights: bool,
num_threads: int = 4,
cache: PipelineCache | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Build a noun graph from text units."""
text_units = text_unit_df.loc[:, ["id", "text"]]
nodes_df = _extract_nodes(text_units, text_analyzer)
nodes_df = await _extract_nodes(
text_units, text_analyzer, num_threads=num_threads, cache=cache
)
edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights)

return (nodes_df, edges_df)


def _extract_nodes(
async def _extract_nodes(
text_unit_df: pd.DataFrame,
text_analyzer: BaseNounPhraseExtractor,
num_threads: int = 4,
cache: PipelineCache | None = None,
) -> pd.DataFrame:
"""
Extract initial nodes and edges from text units.
Input: text unit df with schema [id, text, document_id]
Returns a dataframe with schema [id, title, freq, text_unit_ids].
"""
text_unit_df["noun_phrases"] = text_unit_df["text"].apply(
lambda text: text_analyzer.extract(text)
cache = cache or NoopPipelineCache()
cache = cache.child("extract_noun_phrases")

async def extract(row):
text = row["text"]
attrs = {"text": text, "analyzer": str(text_analyzer)}
key = gen_sha512_hash(attrs, attrs.keys())
result = await cache.get(key)
if not result:
result = text_analyzer.extract(text)
await cache.set(key, result)
return result

text_unit_df["noun_phrases"] = await derive_from_rows(
text_unit_df,
extract,
num_threads=num_threads,
async_type=AsyncType.Threaded,
)

noun_node_df = text_unit_df.explode("noun_phrases")
noun_node_df = noun_node_df.rename(
columns={"noun_phrases": "title", "id": "text_unit_id"}
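
The caching pattern above keys each entry on a SHA-512 hash of the text unit plus the analyzer's string form, so a hit requires both the same input text and an identically configured extractor. A minimal standalone sketch of that keying scheme, using a plain dict in place of PipelineCache and a stub analyzer (both hypothetical, for illustration only):

from graphrag.index.utils.hashing import gen_sha512_hash

class StubAnalyzer:
    """Stand-in for a BaseNounPhraseExtractor; not part of graphrag."""

    def extract(self, text: str) -> list[str]:
        # Naive placeholder: treat capitalized tokens as noun phrases.
        return [token for token in text.split() if token.istitle()]

    def __str__(self) -> str:
        # Configuration fingerprint, mirroring the real extractors' __str__.
        return "stub_v1"

analyzer = StubAnalyzer()
cache: dict[str, list[str]] = {}

def extract_cached(text: str) -> list[str]:
    attrs = {"text": text, "analyzer": str(analyzer)}
    key = gen_sha512_hash(attrs, attrs.keys())
    if key not in cache:
        cache[key] = analyzer.extract(text)  # runs only on a miss
    return cache[key]

extract_cached("GraphRAG builds a Noun Graph")  # miss: analyzer runs
extract_cached("GraphRAG builds a Noun Graph")  # hit: served from the dict
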
@@ -33,3 +33,7 @@ def extract(self, text: str) -> list[str]:
Returns: List of noun phrases.
"""

@abstractmethod
def __str__(self) -> str:
"""Return string representation of the extractor, used for cache key generation."""
@@ -172,3 +172,7 @@ def _tag_noun_phrases(
cleaned_tokens, self.max_word_length
),
}

def __str__(self) -> str:
"""Return string representation of the extractor, used for cache key generation."""
return f"cfg_{self.model_name}_{self.max_word_length}_{self.include_named_entities}_{self.exclude_entity_tags}_{self.exclude_pos_tags}_{self.exclude_nouns}_{self.word_delimiter}_{self.noun_phrase_grammars}_{self.noun_phrase_tags}"
@@ -117,3 +117,7 @@ def _tag_noun_phrases(
"has_compound_words": has_compound_words,
"has_valid_tokens": has_valid_tokens,
}

def __str__(self) -> str:
"""Return string representation of the extractor, used for cache key generation."""
return f"regex_en_{self.exclude_nouns}_{self.max_word_length}_{self.word_delimiter}"
@@ -157,3 +157,7 @@ def _tag_noun_phrases(
cleaned_token_texts, self.max_word_length
),
}

def __str__(self) -> str:
"""Return string representation of the extractor, used for cache key generation."""
return f"syntactic_{self.model_name}_{self.max_word_length}_{self.include_named_entities}_{self.exclude_entity_tags}_{self.exclude_pos_tags}_{self.exclude_nouns}_{self.word_delimiter}"
4 changes: 3 additions & 1 deletion graphrag/index/run/derive_from_rows.py
@@ -12,6 +12,7 @@

import pandas as pd

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.logger.progress import progress_ticker
@@ -33,11 +34,12 @@ def __init__(self, num_errors: int, example: str | None = None):
async def derive_from_rows(
input: pd.DataFrame,
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: WorkflowCallbacks,
callbacks: WorkflowCallbacks | None = None,
num_threads: int = 4,
async_type: AsyncType = AsyncType.AsyncIO,
) -> list[ItemType | None]:
"""Apply a generic transform function to each row. Any errors will be reported and thrown."""
callbacks = callbacks or NoopWorkflowCallbacks()
match async_type:
case AsyncType.AsyncIO:
return await derive_from_rows_asyncio(
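
With callbacks now optional, call sites such as build_noun_graph can apply a per-row async transform without wiring up a WorkflowCallbacks instance. A minimal sketch of calling derive_from_rows this way (the DataFrame and transform are illustrative, not from the repo):

import asyncio

import pandas as pd

from graphrag.config.enums import AsyncType
from graphrag.index.run.derive_from_rows import derive_from_rows

async def upper_text(row: pd.Series) -> str:
    return row["text"].upper()

async def main() -> None:
    df = pd.DataFrame({"text": ["alpha", "beta"]})
    # No callbacks argument: a NoopWorkflowCallbacks is substituted internally.
    results = await derive_from_rows(
        df, upper_text, num_threads=2, async_type=AsyncType.Threaded
    )
    print(results)  # ['ALPHA', 'BETA']

asyncio.run(main())
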
3 changes: 2 additions & 1 deletion graphrag/index/workflows/extract_graph_nlp.py
@@ -26,9 +26,10 @@ async def run_workflow(
"""All the steps to create the base entity graph."""
text_units = await load_table_from_storage("text_units", context.storage)

entities, relationships = extract_graph_nlp(
entities, relationships = await extract_graph_nlp(
text_units,
callbacks,
context.cache,
extraction_config=config.extract_graph_nlp,
pruning_config=config.prune_graph,
embed_config=config.embed_graph,
