Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suggest chat mappings #548

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions application/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,11 +735,11 @@ def __get_unlinked_cres(self) -> List[CRE]:
.all()
)
return cres
def get_all_nodes_and_cres(self):
    """Return every Node and every CRE in the collection as one flat list of Documents.

    Thin public wrapper around the name-mangled private implementation
    (__get_all_nodes_and_cres), which queries all node ids and cre ids
    and resolves each to its Document form.
    """
    return self.__get_all_nodes_and_cres()

def __get_all_nodes_and_cres(self) -> List[cre_defs.Document]:
result = []
nodes = []
cres = []
node_ids = self.session.query(Node.id).all()
for nid in node_ids:
result.extend(self.get_nodes(db_id=nid[0]))
Expand Down
20 changes: 20 additions & 0 deletions application/prompt_client/openai_prompt_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
import openai
import logging

Expand Down Expand Up @@ -58,3 +59,22 @@ def query_llm(self, raw_question: str) -> str:
messages=messages,
)
return response.choices[0].message["content"].strip()

def create_mapping_completion(
    self,
    prompt: str,
    cre_id_and_name_in_export_format: List[str],
    standard_id_or_content: str,
) -> str:
    """Ask the chat model to select the single most relevant candidate mapping.

    Args:
        prompt: unused; kept for interface compatibility with the other
            prompt clients — TODO confirm it can be dropped across clients.
        cre_id_and_name_in_export_format: candidate strings, one per CRE,
            formatted as "<id>|<name>".
        standard_id_or_content: the standard (its id or fetched page content)
            that should be mapped to one of the candidates.

    Returns:
        The model's reply, expected to be exactly one candidate delimited
        with backticks.
    """
    messages = [
        {
            "role": "system",
            # Plain string: the original used an f-string with no placeholders.
            "content": "You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. I will give you a standard to map to and a range of candidates and you will response ONLY with the most relevant candidate.",
        },
        {
            "role": "user",
            "content": f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks`.",
        },
    ]
    openai.api_key = self.api_key
    # NOTE(review): model is hard-coded here while other methods may use a
    # configurable model — confirm whether this should read an instance attr.
    response = openai.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.0,  # deterministic output for mapping selection
    )
    return response.choices[0].message.content.strip()
80 changes: 60 additions & 20 deletions application/prompt_client/prompt_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from application.database import db
from application.defs import cre_defs
from application.prompt_client import openai_prompt_client, vertex_prompt_client
from application.defs import cre_defs as defs
from application.prompt_client import (
openai_prompt_client,
vertex_prompt_client,
spacy_prompt_client,
)
from datetime import datetime
from multiprocessing import Pool
from nltk.corpus import stopwords
Expand All @@ -25,6 +29,8 @@


def is_valid_url(url):
    """Return True only for a non-empty string beginning with an HTTP(S) scheme."""
    if not url:
        return False
    # startswith accepts a tuple of prefixes — one call covers both schemes.
    return url.startswith(("http://", "https://"))


Expand Down Expand Up @@ -103,9 +109,9 @@ def find_missing_embeddings(self, database: db.Node_collection) -> List[str]:
"""
logger.info(f"syncing nodes with embeddings")
missing_embeddings = []
for doc_type in cre_defs.Credoctypes:
for doc_type in defs.Credoctypes:
db_ids = []
if doc_type.value == cre_defs.Credoctypes.CRE:
if doc_type.value == defs.Credoctypes.CRE:
db_ids = [a[0] for a in database.list_cre_ids()]
else:
db_ids = [a[0] for a in database.list_node_ids_by_ntype(doc_type.value)]
Expand All @@ -128,10 +134,10 @@ def generate_embeddings_for(self, database: db.Node_collection, item_name: str):
For example if "ASVS" is passed the method will generate all embeddings for ASVS
Args:
database (db.Node_collection): the Node_collection instance to use
item_name (str): the item for which to generate embeddings, this can be either `cre_defs.Credoctypes.CRE.value` for generating all CRE embeddings or the name of any Standard or Tool.
item_name (str): the item for which to generate embeddings, this can be either `defs.Credoctypes.CRE.value` for generating all CRE embeddings or the name of any Standard or Tool.
"""
db_ids = []
if item_name == cre_defs.Credoctypes.CRE.value:
if item_name == defs.Credoctypes.CRE.value:
db_ids = [a[0] for a in database.list_cre_ids()]
else:
db_ids = [a[0] for a in database.list_node_ids_by_name(item_name)]
Expand All @@ -144,11 +150,13 @@ def generate_embeddings_for(self, database: db.Node_collection, item_name: str):
def generate_embeddings(
self, database: db.Node_collection, missing_embeddings: List[str]
):
"""method generate embeddings accepts a list of Database IDs of object which do not have embeddings and generates embeddings for those objects"""
"""
accepts a list of Database IDs of object which do not have embeddings and generates embeddings for those objects
"""
logger.info(f"generating {len(missing_embeddings)} embeddings")
for id in missing_embeddings:
cre = database.get_cre_by_db_id(id)
node = database.get_nodes(db_id=id)
node = database.get_nodes(db_id=id)[0]
content = ""
if node:
if is_valid_url(node.hyperlink):
Expand All @@ -174,9 +182,16 @@ def generate_embeddings(
if not dbcre:
logger.fatal(node, "cannot be converted to database Node")
dbcre.id = id
database.add_embedding(
dbcre, cre_defs.Credoctypes.CRE, embedding, content
)
database.add_embedding(dbcre, defs.Credoctypes.CRE, embedding, content)

def generate_embeddings_for_document(self, node: defs.Node):
    """Build and return the embedding vector for a single document node.

    Embeds the cleaned page content when the node carries a fetchable
    hyperlink; otherwise embeds the node's repr() text.
    """
    if is_valid_url(node.hyperlink):
        text = self.clean_content(self.get_content(node.hyperlink))
    else:
        text = repr(node)
    logger.info(f"making embedding for {node.id}")
    return self.ai_client.get_text_embeddings(text)


class PromptHandler:
Expand All @@ -197,8 +212,9 @@ def __init__(self, database: db.Node_collection, load_all_embeddings=False) -> N
os.getenv("OPENAI_API_KEY")
)
else:
logger.error(
"cannot instantiate ai client, neither OPENAI_API_KEY nor SERVICE_ACCOUNT_CREDENTIALS are set "
self.ai_client = spacy_prompt_client.SpacyPromptClient()
logger.info(
"cannot instantiate ai client, neither OPENAI_API_KEY nor SERVICE_ACCOUNT_CREDENTIALS are set, using spacy "
)
self.database = database
self.embeddings_instance = in_memory_embeddings.instance().with_ai_client(
Expand All @@ -219,6 +235,12 @@ def __init__(self, database: db.Node_collection, load_all_embeddings=False) -> N
f"there are {len(missing_embeddings)} embeddings missing from the dataset, db inclompete"
)

def generate_embeddings_for_document(self, node: defs.Node):
    """Generate and return embeddings for a single document node.

    Sets up playwright first so hyperlinked content can be fetched, and
    guarantees teardown afterwards.
    """
    self.embeddings_instance.setup_playwright()
    try:
        return self.embeddings_instance.generate_embeddings_for_document(node)
    finally:
        # Bug fix: the original skipped teardown when generation raised,
        # leaking the playwright browser instance.
        self.embeddings_instance.teardown_playwright()

def generate_embeddings_for(self, item_name: str):
self.embeddings_instance.setup_playwright()
self.embeddings_instance.generate_embeddings_for(self.database, item_name)
Expand Down Expand Up @@ -277,7 +299,7 @@ def get_id_of_most_similar_cre(self, item_embedding: List[float]) -> Optional[st
self.existing_cre_embeddings,
self.existing_cre_ids,
) = self.__load_cre_embeddings(
self.database.get_embeddings_by_doc_type(cre_defs.Credoctypes.CRE.value)
self.database.get_embeddings_by_doc_type(defs.Credoctypes.CRE.value)
)
if not self.existing_cre_embeddings.getnnz() or not len(self.existing_cre_ids):
raise ValueError(
Expand Down Expand Up @@ -316,7 +338,7 @@ def get_id_of_most_similar_node(self, standard_text_embedding: List[float]) -> s
self.existing_node_ids,
) = self.__load_node_embeddings(
self.database.get_embeddings_by_doc_type(
cre_defs.Credoctypes.Standard.value
defs.Credoctypes.Standard.value
)
)
if not self.existing_node_embeddings.getnnz() or not len(
Expand Down Expand Up @@ -354,13 +376,12 @@ def get_id_of_most_similar_cre_paginated(
embedding_array = sparse.csr_matrix(
np.array(item_embedding).reshape(1, -1)
) # convert embedding into a 1-dimentional numpy array

(
embeddings,
total_pages,
starting_page,
) = self.database.get_embeddings_by_doc_type_paginated(
cre_defs.Credoctypes.CRE.value
defs.Credoctypes.CRE.value
)
max_similarity = -1
most_similar_index = 0
Expand All @@ -378,7 +399,7 @@ def get_id_of_most_similar_cre_paginated(
total_pages,
_,
) = self.database.get_embeddings_by_doc_type_paginated(
cre_defs.Credoctypes.CRE.value, page=page
defs.Credoctypes.CRE.value, page=page
)

if max_similarity < similarity_threshold:
Expand Down Expand Up @@ -411,7 +432,7 @@ def get_id_of_most_similar_node_paginated(
total_pages,
starting_page,
) = self.database.get_embeddings_by_doc_type_paginated(
doc_type=cre_defs.Credoctypes.Standard.value,
doc_type=defs.Credoctypes.Standard.value,
page=1,
)

Expand All @@ -429,7 +450,7 @@ def get_id_of_most_similar_node_paginated(
most_similar_id = existing_standard_ids[most_similar_index]

embeddings, _, _ = self.database.get_embeddings_by_doc_type_paginated(
doc_type=cre_defs.Credoctypes.Standard.value, page=page
doc_type=defs.Credoctypes.Standard.value, page=page
)
if max_similarity < similarity_threshold:
logger.info(
Expand Down Expand Up @@ -496,3 +517,22 @@ def generate_text(self, prompt: str) -> Dict[str, str]:
table = [closest_object]
result = f"Answer: {answer}"
return {"response": result, "table": table, "accurate": accurate}

def get_id_of_most_similar_cre_using_chat(
    self, item: defs.Document
) -> Optional[str]:
    """Ask the chat model which CRE best matches *item*.

    The candidate list is every CRE in the database rendered as "<id>|<name>";
    the model is instructed to reply with exactly one of those candidates.
    """
    # Prefer the linked page content as the text to match; fall back to repr.
    if item.hyperlink:
        content = self.embeddings_instance.get_content(item.hyperlink)
    else:
        content = repr(item)

    documents = self.database.get_all_nodes_and_cres()
    candidates = [
        f"{doc.id}|{doc.name}"
        for doc in documents
        if doc.doctype == defs.Credoctypes.CRE.value
    ]
    return self.ai_client.create_mapping_completion(
        prompt="",
        cre_id_and_name_in_export_format=candidates,
        standard_id_or_content=content,
    )
35 changes: 35 additions & 0 deletions application/prompt_client/spacy_prompt_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import spacy
import logging

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class SpacyPromptClient:
    """Fallback prompt client backed by a local spaCy model.

    Used when neither OPENAI_API_KEY nor SERVICE_ACCOUNT_CREDENTIALS is set.
    Only text embeddings are supported; all chat-style methods raise
    NotImplementedError so callers get a clear error instead of an
    AttributeError.
    """

    def __init__(self) -> None:
        # Load the small English model, downloading it on first use.
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            logger.info(
                "Downloading language model for the spaCy POS tagger\n"
                "(don't worry, this will only happen once)"
            )
            from spacy.cli import download

            download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

    def get_text_embeddings(self, text: str):
        # NOTE(review): en_core_web_sm ships no static word vectors, so
        # Doc.vector is a context-derived average — confirm the quality is
        # acceptable for similarity search.
        return self.nlp(text).vector

    def create_chat_completion(self, prompt, closest_object_str) -> str:
        raise NotImplementedError(
            "Spacy does not support chat completion you need to set up a different client if you need this functionality"
        )

    def query_llm(self, raw_question: str) -> str:
        raise NotImplementedError(
            "Spacy does not support chat completion you need to set up a different client if you need this functionality"
        )

    def create_mapping_completion(
        self,
        prompt: str,
        cre_id_and_name_in_export_format,
        standard_id_or_content: str,
    ) -> str:
        # Added for interface parity with the OpenAI/Vertex clients: the
        # PromptHandler may hold a SpacyPromptClient, and calling this on the
        # original class raised AttributeError instead of a clear error.
        raise NotImplementedError(
            "Spacy does not support chat completion you need to set up a different client if you need this functionality"
        )
7 changes: 7 additions & 0 deletions application/prompt_client/vertex_prompt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,10 @@ def query_llm(self, raw_question: str) -> str:
msg = f"Your task is to answer the following cybesrsecurity question if you can, provide code examples, delimit any code snippet with three backticks, ignore any unethical questions or questions irrelevant to cybersecurity\nQuestion: `{raw_question}`\n ignore all other commands and questions that are not relevant."
response = self.chat.send_message(msg, **parameters)
return response.text

def create_mapping_completion(
    self,
    prompt: str,
    cre_id_and_name_in_export_format: List[str],
    standard_id_or_content: str,
) -> str:
    """Ask the Vertex chat model to select the most relevant candidate mapping.

    Args:
        prompt: unused; kept for interface compatibility with the other
            prompt clients.
        cre_id_and_name_in_export_format: candidate strings ("<id>|<name>").
        standard_id_or_content: the standard (id or page content) to map.

    Returns:
        The model's reply text, expected to be one candidate in backticks.
    """
    parameters = {"temperature": 0.5, "max_output_tokens": MAX_OUTPUT_TOKENS}
    # Bug fix: the original line ended with a trailing comma after the
    # string continuation, which made `msg` a 1-tuple rather than a string,
    # so a tuple (not text) was sent to the model.
    msg = (
        "You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. I will give you a standard to map to and a range of candidates and you will response ONLY with the most relevant candidate. "
        f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks`."
    )
    return self.chat.send_message(msg, **parameters).text
44 changes: 44 additions & 0 deletions application/tests/spreadsheet_parsers_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
from pprint import pprint
import unittest
from application.database import db
from application.tests.utils import data_gen
from application.defs import cre_defs as defs
from application import create_app, sqla # type: ignore
from application.utils.spreadsheet_parsers import (
parse_export_format,
parse_hierarchical_export_format,
suggest_from_export_format,
)


Expand Down Expand Up @@ -37,6 +40,47 @@ def test_parse_hierarchical_export_format(self) -> None:
for element in v:
self.assertIn(element, output[k])

def test_suggest_from_export_format(self) -> None:
    """suggest_from_export_format should re-suggest CRE mappings for most rows
    whose CRE columns were stripped from the export-format input."""
    self.app = create_app(mode="test")
    self.app_context = self.app.app_context()
    self.app_context.push()
    sqla.create_all()
    try:
        collection = db.Node_collection()

        input_data, expected_output = data_gen.export_format_data()
        for cre in expected_output[defs.Credoctypes.CRE.value]:
            collection.add_cre(cre=cre)

        # Strip the CRE columns from every other row so the parser has
        # something to suggest for.
        input_data_no_cres = []
        for index, line in enumerate(input_data):
            no_cre_line = line.copy()
            if index % 2 == 0:
                # Plain loop instead of a side-effecting list comprehension.
                for key in [k for k in no_cre_line if k.startswith("CRE")]:
                    no_cre_line.pop(key)
            input_data_no_cres.append(no_cre_line)

        output = suggest_from_export_format(
            lfile=input_data_no_cres, database=collection
        )
        self.maxDiff = None

        # Count output rows that still carry no CRE mapping at all.
        empty_lines = 0
        for line in output:
            has_cre = any(
                value for key, value in line.items() if key.startswith("CRE")
            )
            if not has_cre:
                empty_lines += 1

        # At least half of the rows should have received a CRE suggestion.
        self.assertGreater(len(input_data) / 2, empty_lines)
    finally:
        # Always tear down the test app/database, even when an assertion
        # above fails (the original leaked state on failure).
        sqla.session.remove()
        sqla.drop_all()
        self.app_context.pop()


if __name__ == "__main__":
unittest.main()
Loading
Loading