embedding_pipeline.py

"""A simple example to demonstrate the usage of `BasicEmbeddingPipeline`."""
import logging
import uuid
import dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from r2r.core import DatasetConfig, LoggingDatabaseConnection
from r2r.datasets import HuggingFaceDataProvider
from r2r.embeddings import OpenAIEmbeddingProvider
from r2r.main import load_config
from r2r.pipelines import BasicDocument, BasicEmbeddingPipeline
from r2r.vector_dbs import PGVectorDB, QdrantDB
if __name__ == "__main__":
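    # Load environment variables (e.g. the OpenAI API key used by the
    # embedding provider) from a local .env file.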
    dotenv.load_dotenv()
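
    # Unpack the per-component settings from the project configuration file.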
    (
        api_config,
        logging_config,
        embedding_config,
        database_config,
        llm_config,
        text_splitter_config,
    ) = load_config()
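
    # Configure logging with the name and level from the logging config.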
    logger = logging.getLogger(logging_config["name"])
    logging.basicConfig(level=logging_config["level"])
    logger.info("Starting the embedding pipeline")

    # Specify the embedding provider
    embeddings_provider = OpenAIEmbeddingProvider()
    embedding_model = embedding_config["model"]
    embedding_dimension = embedding_config["dimension"]
    embedding_batch_size = embedding_config["batch_size"]

    # Specify the vector database provider
    db = (
        QdrantDB()
        if database_config["vector_db_provider"] == "qdrant"
        else PGVectorDB()
    )
    collection_name = database_config["collection_name"]
    db.initialize_collection(collection_name, embedding_dimension)

    # Specify the dataset providers
    dataset_provider = HuggingFaceDataProvider()
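    # Stream a small sample (10 entries each) from two HuggingFace datasets,
    # embedding the "message_2" field of every entry.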
    dataset_provider.load_datasets(
        [
            DatasetConfig("camel-ai/physics", None, 10, "message_2"),
            DatasetConfig("camel-ai/chemistry", None, 10, "message_2"),
        ],
    )

    # Specify the chunking strategy
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=text_splitter_config["chunk_size"],
        chunk_overlap=text_splitter_config["chunk_overlap"],
        length_function=len,
        is_separator_regex=False,
    )
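
    # Open the database connection used to log pipeline runs.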
    logging_database = LoggingDatabaseConnection(logging_config["database"])
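
    # Assemble the pipeline: the embedding model and provider, the vector
    # store to write into, and the splitter that chunks each document.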
    pipeline = BasicEmbeddingPipeline(
        embedding_model,
        embeddings_provider,
        db,
        text_splitter=text_splitter,
        embedding_batch_size=embedding_batch_size,
        logging_database=logging_database,
    )
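
    # Stream entries from the datasets, accumulate them into batches, and
    # run each full batch through the pipeline.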
    entry_id = 0
    document_batch = []
    for entry in dataset_provider.stream_text():
        if entry is None:
            break
        text, config = entry
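        # Derive deterministic IDs with uuid5 so the same content always
        # maps to the same document ID across runs.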
        document_id = str(uuid.uuid5(uuid.NAMESPACE_URL, config.name))
        if text is None:
            break
        document_batch.append(
            BasicDocument(
                id=str(
                    uuid.uuid5(uuid.NAMESPACE_URL, f"{config.name}_{text}")
                ),
                text=text,
                metadata={"document_id": document_id},
            )
        )
        entry_id += 1
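        # Flush once a full batch has accumulated, then start a new one.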
        if len(document_batch) >= embedding_batch_size:
            logger.info(
                f"Processing batch of {len(document_batch)} documents."
            )
            pipeline.run(document_batch)
            document_batch = []
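
    # Embed whatever remains in the final, possibly partial batch, then
    # release the pipeline's resources.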
logging.info(f"Processing final {len(document_batch)} documents.")
pipeline.run(document_batch)
pipeline.close()