diff --git a/.semversioner/next-release/patch-20250210180318886210.json b/.semversioner/next-release/patch-20250210180318886210.json
new file mode 100644
index 0000000000..3b19d5c8d6
--- /dev/null
+++ b/.semversioner/next-release/patch-20250210180318886210.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add caching to NLP extractor."
+}
diff --git a/graphrag/config/models/extract_graph_nlp_config.py b/graphrag/config/models/extract_graph_nlp_config.py
index d8793a7557..90f7b5fcc7 100644
--- a/graphrag/config/models/extract_graph_nlp_config.py
+++ b/graphrag/config/models/extract_graph_nlp_config.py
@@ -64,3 +64,7 @@ class ExtractGraphNLPConfig(BaseModel):
     text_analyzer: TextAnalyzerConfig = Field(
         description="The text analyzer configuration.", default=TextAnalyzerConfig()
     )
+    parallelization_num_threads: int = Field(
+        description="The number of threads to use for the extraction process.",
+        default=defs.PARALLELIZATION_NUM_THREADS,
+    )
diff --git a/graphrag/index/flows/extract_graph_nlp.py b/graphrag/index/flows/extract_graph_nlp.py
index f1b04626f7..ed0161a41e 100644
--- a/graphrag/index/flows/extract_graph_nlp.py
+++ b/graphrag/index/flows/extract_graph_nlp.py
@@ -5,6 +5,7 @@
 
 import pandas as pd
 
+from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.embed_graph_config import EmbedGraphConfig
 from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig
@@ -20,9 +21,10 @@
 from graphrag.index.operations.prune_graph import prune_graph
 
 
-def extract_graph_nlp(
+async def extract_graph_nlp(
     text_units: pd.DataFrame,
     callbacks: WorkflowCallbacks,
+    cache: PipelineCache,
     extraction_config: ExtractGraphNLPConfig,
     pruning_config: PruneGraphConfig,
     embed_config: EmbedGraphConfig | None = None,
@@ -31,10 +33,12 @@
     """All the steps to create the base entity graph."""
     text_analyzer_config = extraction_config.text_analyzer
     text_analyzer = create_noun_phrase_extractor(text_analyzer_config)
-    extracted_nodes, extracted_edges = build_noun_graph(
+    extracted_nodes, extracted_edges = await build_noun_graph(
         text_units,
         text_analyzer=text_analyzer,
         normalize_edge_weights=extraction_config.normalize_edge_weights,
+        num_threads=extraction_config.parallelization_num_threads,
+        cache=cache,
     )
 
     # create a temporary graph to prune, then turn it back into dataframes
diff --git a/graphrag/index/operations/build_noun_graph/build_noun_graph.py b/graphrag/index/operations/build_noun_graph/build_noun_graph.py
index af8a2fd9e9..ade9729d60 100644
--- a/graphrag/index/operations/build_noun_graph/build_noun_graph.py
+++ b/graphrag/index/operations/build_noun_graph/build_noun_graph.py
@@ -7,27 +7,38 @@
 
 import pandas as pd
 
+from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
+from graphrag.cache.pipeline_cache import PipelineCache
+from graphrag.config.enums import AsyncType
 from graphrag.index.operations.build_noun_graph.np_extractors.base import (
     BaseNounPhraseExtractor,
 )
+from graphrag.index.run.derive_from_rows import derive_from_rows
+from graphrag.index.utils.hashing import gen_sha512_hash
 
 
-def build_noun_graph(
+async def build_noun_graph(
     text_unit_df: pd.DataFrame,
     text_analyzer: BaseNounPhraseExtractor,
     normalize_edge_weights: bool,
+    num_threads: int = 4,
+    cache: PipelineCache | None = None,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Build a noun graph from text units."""
     text_units = text_unit_df.loc[:, ["id", "text"]]
-    nodes_df = _extract_nodes(text_units, text_analyzer)
+    nodes_df = await _extract_nodes(
+        text_units, text_analyzer, num_threads=num_threads, cache=cache
+    )
     edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights)
     return (nodes_df, edges_df)
 
 
-def _extract_nodes(
+async def _extract_nodes(
     text_unit_df: pd.DataFrame,
     text_analyzer: BaseNounPhraseExtractor,
+    num_threads: int = 4,
+    cache: PipelineCache | None = None,
 ) -> pd.DataFrame:
     """
     Extract initial nodes and edges from text units.
@@ -35,9 +46,26 @@
     Input: text unit df with schema [id, text, document_id]
     Returns a dataframe with schema [id, title, freq, text_unit_ids].
     """
-    text_unit_df["noun_phrases"] = text_unit_df["text"].apply(
-        lambda text: text_analyzer.extract(text)
+    cache = cache or NoopPipelineCache()
+    cache = cache.child("extract_noun_phrases")
+
+    async def extract(row):
+        text = row["text"]
+        attrs = {"text": text, "analyzer": str(text_analyzer)}
+        key = gen_sha512_hash(attrs, attrs.keys())
+        result = await cache.get(key)
+        if not result:
+            result = text_analyzer.extract(text)
+            await cache.set(key, result)
+        return result
+
+    text_unit_df["noun_phrases"] = await derive_from_rows(
+        text_unit_df,
+        extract,
+        num_threads=num_threads,
+        async_type=AsyncType.Threaded,
     )
+
     noun_node_df = text_unit_df.explode("noun_phrases")
     noun_node_df = noun_node_df.rename(
         columns={"noun_phrases": "title", "id": "text_unit_id"}
diff --git a/graphrag/index/operations/build_noun_graph/np_extractors/base.py b/graphrag/index/operations/build_noun_graph/np_extractors/base.py
index 20131ad367..44390e5d14 100644
--- a/graphrag/index/operations/build_noun_graph/np_extractors/base.py
+++ b/graphrag/index/operations/build_noun_graph/np_extractors/base.py
@@ -33,3 +33,7 @@ def extract(self, text: str) -> list[str]:
         Returns:
             List of noun phrases.
""" + + @abstractmethod + def __str__(self) -> str: + """Return string representation of the extractor, used for cache key generation.""" diff --git a/graphrag/index/operations/build_noun_graph/np_extractors/cfg_extractor.py b/graphrag/index/operations/build_noun_graph/np_extractors/cfg_extractor.py index b9ba88a9d0..ca9f4dcd69 100644 --- a/graphrag/index/operations/build_noun_graph/np_extractors/cfg_extractor.py +++ b/graphrag/index/operations/build_noun_graph/np_extractors/cfg_extractor.py @@ -172,3 +172,7 @@ def _tag_noun_phrases( cleaned_tokens, self.max_word_length ), } + + def __str__(self) -> str: + """Return string representation of the extractor, used for cache key generation.""" + return f"cfg_{self.model_name}_{self.max_word_length}_{self.include_named_entities}_{self.exclude_entity_tags}_{self.exclude_pos_tags}_{self.exclude_nouns}_{self.word_delimiter}_{self.noun_phrase_grammars}_{self.noun_phrase_tags}" diff --git a/graphrag/index/operations/build_noun_graph/np_extractors/regex_extractor.py b/graphrag/index/operations/build_noun_graph/np_extractors/regex_extractor.py index 948d770625..2f14b68e0d 100644 --- a/graphrag/index/operations/build_noun_graph/np_extractors/regex_extractor.py +++ b/graphrag/index/operations/build_noun_graph/np_extractors/regex_extractor.py @@ -117,3 +117,7 @@ def _tag_noun_phrases( "has_compound_words": has_compound_words, "has_valid_tokens": has_valid_tokens, } + + def __str__(self) -> str: + """Return string representation of the extractor, used for cache key generation.""" + return f"regex_en_{self.exclude_nouns}_{self.max_word_length}_{self.word_delimiter}" diff --git a/graphrag/index/operations/build_noun_graph/np_extractors/syntactic_parsing_extractor.py b/graphrag/index/operations/build_noun_graph/np_extractors/syntactic_parsing_extractor.py index c4bff1d6bf..2b7edf2645 100644 --- a/graphrag/index/operations/build_noun_graph/np_extractors/syntactic_parsing_extractor.py +++ b/graphrag/index/operations/build_noun_graph/np_extractors/syntactic_parsing_extractor.py @@ -157,3 +157,7 @@ def _tag_noun_phrases( cleaned_token_texts, self.max_word_length ), } + + def __str__(self) -> str: + """Return string representation of the extractor, used for cache key generation.""" + return f"syntactic_{self.model_name}_{self.max_word_length}_{self.include_named_entities}_{self.exclude_entity_tags}_{self.exclude_pos_tags}_{self.exclude_nouns}_{self.word_delimiter}" diff --git a/graphrag/index/run/derive_from_rows.py b/graphrag/index/run/derive_from_rows.py index e9648943d9..6a29d8848f 100644 --- a/graphrag/index/run/derive_from_rows.py +++ b/graphrag/index/run/derive_from_rows.py @@ -12,6 +12,7 @@ import pandas as pd +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.logger.progress import progress_ticker @@ -33,11 +34,12 @@ def __init__(self, num_errors: int, example: str | None = None): async def derive_from_rows( input: pd.DataFrame, transform: Callable[[pd.Series], Awaitable[ItemType]], - callbacks: WorkflowCallbacks, + callbacks: WorkflowCallbacks | None = None, num_threads: int = 4, async_type: AsyncType = AsyncType.AsyncIO, ) -> list[ItemType | None]: """Apply a generic transform function to each row. 
+    callbacks = callbacks or NoopWorkflowCallbacks()
     match async_type:
         case AsyncType.AsyncIO:
             return await derive_from_rows_asyncio(
diff --git a/graphrag/index/workflows/extract_graph_nlp.py b/graphrag/index/workflows/extract_graph_nlp.py
index 8efbdc1551..4be186392b 100644
--- a/graphrag/index/workflows/extract_graph_nlp.py
+++ b/graphrag/index/workflows/extract_graph_nlp.py
@@ -26,9 +26,10 @@
     """All the steps to create the base entity graph."""
     text_units = await load_table_from_storage("text_units", context.storage)
 
-    entities, relationships = extract_graph_nlp(
+    entities, relationships = await extract_graph_nlp(
         text_units,
         callbacks,
+        context.cache,
         extraction_config=config.extract_graph_nlp,
         pruning_config=config.prune_graph,
         embed_config=config.embed_graph,
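
Below is a minimal, standalone sketch of the caching pattern this patch introduces in _extract_nodes: each text unit is keyed by a SHA-512 hash of its text plus str(text_analyzer), and the expensive extraction only runs on a cache miss. DictCache, cache_key, toy_extractor and extract_with_cache are illustrative stand-ins written for this note, not graphrag APIs; the real code path uses PipelineCache.child("extract_noun_phrases"), gen_sha512_hash, and derive_from_rows with AsyncType.Threaded as shown in the diff above.

import asyncio
import hashlib
import json


class DictCache:
    """Illustrative stand-in for a pipeline cache: an async, dict-backed key/value store."""

    def __init__(self) -> None:
        self._store = {}

    async def get(self, key):
        return self._store.get(key)

    async def set(self, key, value) -> None:
        self._store[key] = value


def cache_key(text, analyzer_repr):
    # Mirror the diff: hash the text together with the analyzer's string representation
    # so a different analyzer configuration never reuses stale cached phrases.
    attrs = {"text": text, "analyzer": analyzer_repr}
    return hashlib.sha512(json.dumps(attrs, sort_keys=True).encode()).hexdigest()


def toy_extractor(text):
    # Stand-in for the real noun-phrase extractor: keep tokens that start uppercase.
    return [tok for tok in text.split() if tok[:1].isupper()]


async def extract_with_cache(text, analyzer_repr, cache):
    key = cache_key(text, analyzer_repr)
    result = await cache.get(key)
    if not result:
        result = toy_extractor(text)  # the expensive call, only on a cache miss
        await cache.set(key, result)
    return result


async def main() -> None:
    cache = DictCache()
    first = await extract_with_cache("Graph extraction in Microsoft GraphRAG", "regex_en_demo", cache)
    second = await extract_with_cache("Graph extraction in Microsoft GraphRAG", "regex_en_demo", cache)
    print(first, first == second)  # the second call is served from the cache


if __name__ == "__main__":
    asyncio.run(main())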