-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathquery.py
66 lines (51 loc) · 1.93 KB
/
query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from src.query_utils import build_model
class Query:
    """Callable that retrieves the dataset rows nearest to a text query.

    On construction the ``"embeddings"`` column of ``cfg.dataset`` is indexed
    with FAISS. Calling the instance embeds a query string with
    ``cfg.get_embeddings`` and asks the dataset for its nearest examples.
    """

    def __init__(self, cfg) -> None:
        """Read settings from ``cfg`` and index the dataset.

        Args:
            cfg: Configuration object. It must expose ``logger`` and
                ``dataset``. ``responses`` (number of results per query) is
                optional and defaults to 5. If ``get_embeddings`` is missing,
                ``build_model(cfg)`` is called first.
        """
        # Grab values from config
        self.cfg = cfg
        self.logger = cfg.logger
        self.responses = getattr(cfg, 'responses', 5)

        # Construct the model (if not already built).
        # NOTE(review): this presumes build_model() attaches `get_embeddings`
        # to cfg as a side effect -- confirm in src.query_utils.
        if not hasattr(cfg, "get_embeddings"):
            cfg.model = build_model(cfg)
        self.get_embeddings = cfg.get_embeddings

        # Compute the FAISS index on the dataset
        self.logger.info('Computing FAISS Index on dataset embeddings')
        self.dataset = cfg.dataset
        self.dataset.add_faiss_index(column="embeddings")

    def __call__(self, query: str):
        """Return the ``self.responses`` nearest dataset examples for ``query``.

        Args:
            query: Free-text query to embed and search for.

        Returns:
            The ``(scores, samples)`` pair produced by the dataset's
            ``get_nearest_examples`` call.

        Raises:
            TypeError: If ``query`` is not a string.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(query, str):
            raise TypeError(
                f"query must be a str, got {type(query).__name__}"
            )
        self.logger.debug(f'you asked me: {query}')

        # The embedding function receives a batch, hence the one-item list.
        emb = self.get_embeddings([query]).detach().cpu().numpy()
        scores, samples = self.dataset.get_nearest_examples(
            "embeddings", emb, k=self.responses
        )
        return scores, samples
if __name__ == "__main__":
    # Command-line entry point: run one query from the parsed configuration
    # and print the nearest dataset rows.
    import logging
    import pandas as pd
    from src.parser import parse_args
    from src.dataset_generation import get_dataset

    cfg, _ = parse_args()
    cfg.logger = logging.getLogger('HFGitHubIssuesLogger')
    cfg.dataset = get_dataset(cfg)

    # Building the model here is optional; Query.__init__ builds it when
    # cfg has no `get_embeddings` attribute.
    cfg.model = build_model(cfg)

    # Index the dataset with FAISS, then run the query from the config.
    searcher = Query(cfg)
    distances, hits = searcher(cfg.query)

    # Echo the query, surrounded by blank lines.
    print(f"\n> {cfg.query}\n")

    # Tabulate the hits, attach their scores, and list highest score first.
    results = pd.DataFrame.from_dict(hits)
    results["scores"] = distances
    results = results.sort_values("scores", ascending=False)

    for _, row in results.iterrows():
        print(
            f"COMMENT: {row.comments}",
            f"SCORE: {row.scores}",
            f"TITLE: {row.title}",
            f"URL: {row.html_url}",
            "=" * 50,
            "",
            sep="\n",
        )