diff --git a/app.py b/app.py index 779c5a8..251ba37 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,6 @@ from dotenv import load_dotenv load_dotenv() - +import requests from flask import Flask, request, jsonify from flask_cors import CORS, cross_origin import os @@ -14,13 +14,7 @@ client = MongoClient(os.getenv("MONGODB_URI"), server_api=ServerApi('1')) # connect to Atlas as a vector store -store = MongoDBAtlasVectorSearch( - client, - db_name=os.getenv('MONGODB_DATABASE'), # this is the database where you stored your embeddings - collection_name=os.getenv('MONGODB_VECTORS'), # this is where your embeddings were stored in 2_load_and_index.py - index_name=os.getenv('MONGODB_VECTOR_INDEX') # this is the name of the index you created after loading your data -) -index = VectorStoreIndex.from_vector_store(store) + app = Flask(__name__) cors = CORS(app) @@ -34,13 +28,102 @@ def hello_world(): "message": "hello world" }) +@app.route('/create_index', methods=['POST']) +@cross_origin() +def create_index(): + user_id = request.json['user_id'] + database_name = "test" + collection_name = user_id.split('@')[0] + '_invoice' + vector_collection_name = user_id.split('@')[0] + '_invoice_vector' + vector_index_name = (user_id.split('@')[0] + '_invoice_vector_index').replace('.', '_') + uri = os.getenv("MONGODB_API_URI") + headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/vnd.atlas.2023-02-01+json', + } + payload = { + "collectionName": vector_collection_name, + "database": database_name, + "name": vector_index_name, + "type": "search", + "mappings": { + "dynamic": True, + "fields": { + "embedding": { + "dimensions": 1536, + "similarity": "cosine", + "type": "knnVector" + } + } + } + } + try: + response = requests.post( + f"{uri}/groups/{os.getenv('MONGODB_ATLAS_GROUP_ID')}/clusters/{os.getenv('MONGODB_ATLAS_CLUSTER_NAME')}/fts/indexes", + headers=headers, + auth=requests.auth.HTTPDigestAuth(os.getenv('MONGODB_ATLAS_USERNAME'), 
os.getenv('MONGODB_ATLAS_PASSWORD')), + json=payload + ) + response.raise_for_status() + return jsonify({"status": "success", "data": response.json()}) + except requests.exceptions.HTTPError as e: + return jsonify({"status": "fail", "error": e.response.json()}), e.response.status_code + except requests.exceptions.RequestException as e: + return jsonify({"status": "fail", "error": str(e)}), 500 + except Exception as e: + return jsonify({"status": "fail", "error": str(e)}), 500 + +@app.route('/list_indexes/<database_name>/<collection_name>', methods=['GET']) +@cross_origin() +def list_indexes(database_name, collection_name): + uri = os.getenv("MONGODB_API_URI") + headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/vnd.atlas.2023-02-01+json', + } + try: + response = requests.get( + f"{uri}/groups/{os.getenv('MONGODB_ATLAS_GROUP_ID')}/clusters/{os.getenv('MONGODB_ATLAS_CLUSTER_NAME')}/fts/indexes/{database_name}/{collection_name}", + headers=headers, + auth=requests.auth.HTTPDigestAuth(os.getenv('MONGODB_ATLAS_USERNAME'), os.getenv('MONGODB_ATLAS_PASSWORD')), + ) + response.raise_for_status() + return jsonify({"status": "success", "data": response.json()}) + except requests.exceptions.HTTPError as e: + return jsonify({"status": "fail", "error": e.response.json()}), e.response.status_code + except requests.exceptions.RequestException as e: + return jsonify({"status": "fail", "error": str(e)}), 500 + except Exception as e: + return jsonify({"status": "fail", "error": str(e)}), 500 + + + + @app.route('/process', methods=['POST']) @cross_origin() def process(): - is_processed = process_entries(client = client) - if not is_processed: - return jsonify({"status": "failed", "error": "process failed"}), 400 - return jsonify({"status": "success", "message": "process successful"}) + try: + user_id = request.json['user_id'] + database_name = "test" + collection_name = user_id.split('@')[0] + '_invoice' + vector_collection_name = user_id.split('@')[0] + '_invoice_vector' + vector_index_name
= (user_id.split('@')[0] + '_invoice_vector_index').replace('.', '_') + is_processed = process_entries( + client=client, + database_name=database_name, + collection_name=collection_name, + vector_collection_name=vector_collection_name, + vector_index_name=vector_index_name + ) + + if not is_processed: + return jsonify({"status": "fail", "error": "process failed"}), 400 + + return jsonify({"status": "success", "message": "process successful"}), 200 + except KeyError: + return jsonify({"status": "fail", "error": "invalid request body"}), 400 + except Exception as e: + return jsonify({"status": "fail", "error": f"{e}"}), 500 @app.route('/query', methods=['POST']) @@ -48,6 +131,18 @@ def process(): def process_form(): # get the query query = request.json["query"] + user_id = request.json['user_id'] + database_name = "test" + vector_collection_name = user_id.split('@')[0] + '_invoice_vector' + vector_index_name = (user_id.split('@')[0] + '_invoice_vector_index').replace('.', '_') + store = MongoDBAtlasVectorSearch( + client, + db_name=database_name, # this is the database where you stored your embeddings + collection_name=vector_collection_name, # this is where your embeddings were stored in 2_load_and_index.py + index_name=vector_index_name # this is the name of the index you created after loading your data + ) + index = VectorStoreIndex.from_vector_store(store) + if query is not None: # query your data! 
diff --git a/process.py b/process.py index cf00c98..ed4fb7c 100644 --- a/process.py +++ b/process.py @@ -28,14 +28,16 @@ # llamaindex has a special class that does this for you # it pulls every object in a given collection -def process_entries(client): - collection = client[os.getenv("MONGODB_DATABASE")][os.getenv("MONGODB_COLLECTION")] +def process_entries(client, database_name, collection_name, vector_collection_name, vector_index_name): + # connect to Atlas as a vector store + query_dict = {"processed": False} + collection = client[database_name][collection_name] unprocessed_entries = collection.find(query_dict) reader = SimpleMongoReader(uri=os.getenv("MONGODB_URI")) documents = reader.load_data( - os.getenv("MONGODB_DATABASE"), - os.getenv("MONGODB_COLLECTION"), # this is the collection where the objects you loaded in 1_import got stored + database_name, + collection_name, # this is the collection where the objects you loaded in 1_import got stored # field_names=["saleDate", "items", "storeLocation", "customer", "couponUsed", "purchaseMethod"], # these is a list of the top-level fields in your objects that will be indexed field_names=["text"], # make sure your objects have a field called "full_text" or that you change this value query_dict=query_dict # this is a mongo query dict that will filter your data if you don't want to index everything @@ -43,9 +45,9 @@ def process_entries(client): store = MongoDBAtlasVectorSearch( client, - db_name=os.getenv('MONGODB_DATABASE'), - collection_name=os.getenv('MONGODB_VECTORS'), # this is where your embeddings will be stored - index_name=os.getenv('MONGODB_VECTOR_INDEX') # this is the name of the index you will need to create + db_name=database_name, # this is the database where you stored your embeddings + collection_name=vector_collection_name, # this is where your embeddings will be stored + index_name=vector_index_name # this is the name of the index you will need to create ) # # create Atlas as a vector store # now 
create an index from all the Documents and store them in Atlas