Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance prompt template formatting: support nested placeholders and unpaired braces handling #1653

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 4 additions & 58 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,58 +1,4 @@
# Python Artifacts
python/*/lib/
dist/

# Test Output
.coverage
coverage/
licenses.txt
examples_notebooks/*/data
tests/fixtures/cache
tests/fixtures/*/cache
tests/fixtures/*/output
output/lancedb


# Random
.DS_Store
*.log*
.venv
venv/
.conda
.tmp

.env
build.zip

.turbo

__pycache__

.pipeline

# Azurite
temp_azurite/
__azurite*.json
__blobstorage*.json
__blobstorage__/

# Getting started example
ragtest/
.ragtest/
.pipelines
.pipeline


# mkdocs
site/

# Docs migration
docsite/
.yarn/
.pnp*

# PyCharm
.idea/

# Jupyter notebook
.ipynb_checkpoints/
**/__pycache__/
workdir/
sac/
prompts/
90 changes: 60 additions & 30 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,67 @@
"version": "0.2.0",
"configurations": [
{
"name": "Indexer",
"type": "debugpy",
"request": "launch",
"module": "poetry",
"args": [
"poe", "index",
"--root", "<path_to_ragtest_root_demo>"
"name": "Debug Graphrag init",
"type": "python",
"request": "launch",
"module": "graphrag",
"args": [
"init",
"--root", "${workspaceFolder}/workdir"
],
},
"console": "integratedTerminal"
},
{
"name": "Query",
"type": "debugpy",
"request": "launch",
"module": "poetry",
"args": [
"poe", "query",
"--root", "<path_to_ragtest_root_demo>",
"--method", "global",
"--query", "What are the top themes in this story",
]
},
{
"name": "Prompt Tuning",
"type": "debugpy",
"request": "launch",
"module": "poetry",
"args": [
"poe", "prompt-tune",
"--config",
"<path_to_ragtest_root_demo>/settings.yaml",
]
}
"name": "Debug Graphrag index",
"type": "python",
"request": "launch",
"module": "graphrag",
"args": [
"index",
"--root", "${workspaceFolder}/workdir"
],
"justMyCode": false, // 设置为 false 以调试第三方库
"console": "integratedTerminal"
},
{
"name": "Run Graphrag index",
"type": "python",
"request": "launch",
"module": "graphrag",
"args": [
"index",
"--root", "${workspaceFolder}/workdir"
],
"justMyCode": false, // 设置为 false 以调试第三方库
"console": "integratedTerminal",
"noDebug": true
},
{
"name": "Debug Graphrag prompt-tune",
"type": "python",
"request": "launch",
"module": "graphrag",
"args": [
"prompt-tune",
"--root", "${workspaceFolder}/workdir",
"--config", "${workspaceFolder}/workdir/settings.yaml",
"--discover-entity-types"
],
"console": "integratedTerminal"
},
{
"name": "Run Graphrag prompt-tune",
"type": "python",
"request": "launch",
"module": "graphrag",
"args": [
"prompt-tune",
"--root", "${workspaceFolder}/workdir",
"--config", "${workspaceFolder}/workdir/settings.yaml",
"--discover-entity-types"
],
"console": "integratedTerminal",
"noDebug": true
}
]
}
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"node_modules{,/**}",
".vscode{,/**}"
],
"python.defaultInterpreterPath": "python/services/.venv/bin/python",
"python.defaultInterpreterPath": ".venv/bin/python",
"python.languageServer": "Pylance",
"cSpell.customDictionaries": {
"project-words": {
Expand Down
2 changes: 1 addition & 1 deletion graphrag/config/models/graph_rag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Parameterization settings for the default configuration."""

from devtools import pformat
from pprint import pformat
from pydantic import Field

import graphrag.config.defaults as defs
Expand Down
65 changes: 60 additions & 5 deletions graphrag/index/operations/extract_entities/graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import traceback
from collections.abc import Mapping
from dataclasses import dataclass
from string import Formatter
from typing import Any

import networkx as nx
Expand All @@ -30,7 +31,58 @@

log = logging.getLogger(__name__)


class SafeFormatter(Formatter):
    """Format prompt templates without failing on stray or unknown braces.

    Rules applied by :meth:`format`:

    * ``{name}`` is substituted when ``name`` is one of the keyword arguments
      (conversion and format spec, e.g. ``{name!r:>10}``, are honoured).
    * ``{{`` and ``}}`` are the usual escapes for literal braces.
    * A ``{`` or ``}`` that pairs with nothing is kept as literal text.
    * Placeholders containing ``.`` or ``[`` (attribute / index access) are
      kept as literal text and never evaluated.
    * Any other placeholder whose name is not a keyword argument (including
      positional ones such as ``{}`` and ``{0}``) is kept as literal text.
    """

    def __init__(self) -> None:
        super().__init__()
        # A single "{...}" group with no braces inside; group 1 is the field.
        self.nested_pattern = re.compile(r"{([^{}]+)}")
        # Brace tokens in the order str.format reads them: escaped pairs
        # first, then a complete placeholder, then a lone brace.
        self.unpaired_pattern = re.compile(r"\{\{|\}\}|\{[^{}]*\}|[{}]")

    def format(self, format_string, *args, **kwargs):
        """Return ``format_string`` with the known placeholders substituted.

        Positional ``args`` are accepted for signature compatibility with
        :class:`string.Formatter` but are not used for substitution.
        """
        # Each pass only escapes text, so the parent parser ends up seeing
        # nothing but "{{", "}}" and placeholders it can actually resolve.
        format_string = self._replace_unpaired_braces(format_string)
        format_string = self._replace_nested(format_string)
        format_string = self._escape_unresolved(format_string, kwargs)
        return super().format(format_string, *args, **kwargs)

    def get_value(self, key, args, kwargs) -> Any:
        """Return ``kwargs[key]`` or, for an unknown key, the literal ``{key}``."""
        # Only string keys are looked up; positional (int) keys fall through
        # to the literal form instead of raising IndexError/KeyError.
        if isinstance(key, str) and key in kwargs:
            return kwargs[key]
        return f"{{{key}}}"

    def _replace_unpaired_braces(self, format_string):
        """Double every brace that is neither an escape nor part of ``{...}``."""

        def escape_lone(match):
            token = match.group(0)
            # A one-character token is a brace with no partner. Doubling it
            # makes it literal without dropping any of the surrounding text.
            return token * 2 if len(token) == 1 else token

        return self.unpaired_pattern.sub(escape_lone, format_string)

    def _replace_nested(self, format_string):
        """Escape placeholders that use attribute or index access."""

        def escape_nested(match):
            token = match.group(0)
            field = self.nested_pattern.fullmatch(token)
            if field and ("[" in field.group(1) or "." in field.group(1)):
                # "{a.b}" -> "{{a.b}}" so the parser emits it verbatim.
                return "{" + token + "}"
            return token

        # Walking the brace tokens (rather than searching for "{...}"
        # anywhere) keeps existing "{{...}}" escapes intact.
        return self.unpaired_pattern.sub(escape_nested, format_string)

    def _escape_unresolved(self, format_string, kwargs):
        """Escape placeholders whose field name is not a key of ``kwargs``."""

        def escape_unknown(match):
            token = match.group(0)
            if token in ("{{", "}}"):
                return token
            if len(token) == 1:
                # Lone brace (only reachable when called on unprocessed text).
                return token * 2
            # The field name ends at the first conversion ("!") or spec (":").
            name = re.split(r"[!:]", token[1:-1], maxsplit=1)[0]
            if name in kwargs:
                return token
            # Escaping the whole token preserves its original text exactly,
            # including any conversion or format spec.
            return "{" + token + "}"

        return self.unpaired_pattern.sub(escape_unknown, format_string)


@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
Expand Down Expand Up @@ -152,11 +204,14 @@ async def __call__(
async def _process_document(
self, text: str, prompt_variables: dict[str, str]
) -> str:
formatter = SafeFormatter()
kwargs = {
**prompt_variables,
self._input_text_key: text
}
formated_prompt = formatter.format(self._extraction_prompt, **kwargs)
response = await self._llm(
self._extraction_prompt.format(**{
**prompt_variables,
self._input_text_key: text,
}),
formated_prompt
)
results = response.output.content or ""

Expand Down
Loading