
Commit

Added create index API
vishalkc9565 committed Mar 2, 2024
1 parent 1c6bd00 commit 8b0bf49
Showing 2 changed files with 116 additions and 19 deletions.
119 changes: 107 additions & 12 deletions app.py
@@ -1,6 +1,6 @@
from dotenv import load_dotenv
load_dotenv()

import requests
from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
import os
@@ -14,13 +14,7 @@
client = MongoClient(os.getenv("MONGODB_URI"), server_api=ServerApi('1'))

# connect to Atlas as a vector store
store = MongoDBAtlasVectorSearch(
    client,
    db_name=os.getenv('MONGODB_DATABASE'), # this is the database where you stored your embeddings
    collection_name=os.getenv('MONGODB_VECTORS'), # this is where your embeddings were stored in 2_load_and_index.py
    index_name=os.getenv('MONGODB_VECTOR_INDEX') # this is the name of the index you created after loading your data
)
index = VectorStoreIndex.from_vector_store(store)


app = Flask(__name__)
cors = CORS(app)
@@ -34,20 +28,121 @@ def hello_world():
        "message": "hello world"
    })

@app.route('/create_index', methods=['POST'])
@cross_origin()
def create_index():
    user_id = request.json['user_id']
    database_name = "test"
    collection_name = user_id.split('@')[0] + '_invoice'
    vector_collection_name = user_id.split('@')[0] + '_invoice_vector'
    vector_index_name = (user_id.split('@')[0] + '_invoice_vector_index').replace('.', '_')
    uri = os.getenv("MONGODB_API_URI")
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/vnd.atlas.2023-02-01+json',
    }
    payload = {
        "collectionName": vector_collection_name,
        "database": database_name,
        "name": vector_index_name,
        "type": "search",
        "mappings": {
            "dynamic": True,
            "fields": {
                "embedding": {
                    "dimensions": 1536,
                    "similarity": "cosine",
                    "type": "knnVector"
                }
            }
        }
    }
    try:
        response = requests.post(
            f"{uri}/groups/{os.getenv('MONGODB_ATLAS_GROUP_ID')}/clusters/{os.getenv('MONGODB_ATLAS_CLUSTER_NAME')}/fts/indexes",
            headers=headers,
            auth=requests.auth.HTTPDigestAuth(os.getenv('MONGODB_ATLAS_USERNAME'), os.getenv('MONGODB_ATLAS_PASSWORD')),
            json=payload
        )
        response.raise_for_status()
        return jsonify({"status": "success", "data": response.json()})
    except requests.exceptions.HTTPError as e:
        return jsonify({"status": "fail", "error": e.response.json()}), e.response.status_code
    except requests.exceptions.RequestException as e:
        return jsonify({"status": "fail", "error": str(e)}), 500
    except Exception as e:
        return jsonify({"status": "fail", "error": str(e)}), 500

@app.route('/list_indexes/<database_name>/<collection_name>', methods=['GET'])
@cross_origin()
def list_indexes(database_name, collection_name):
    uri = os.getenv("MONGODB_API_URI")
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/vnd.atlas.2023-02-01+json',
    }
    try:
        response = requests.get(
            f"{uri}/groups/{os.getenv('MONGODB_ATLAS_GROUP_ID')}/clusters/{os.getenv('MONGODB_ATLAS_CLUSTER_NAME')}/fts/indexes/{database_name}/{collection_name}",
            headers=headers,
            auth=requests.auth.HTTPDigestAuth(os.getenv('MONGODB_ATLAS_USERNAME'), os.getenv('MONGODB_ATLAS_PASSWORD')),
        )
        response.raise_for_status()
        return jsonify({"status": "success", "data": response.json()})
    except requests.exceptions.HTTPError as e:
        return jsonify({"status": "fail", "error": e.response.json()}), e.response.status_code
    except requests.exceptions.RequestException as e:
        return jsonify({"status": "fail", "error": str(e)}), 500
    except Exception as e:
        return jsonify({"status": "fail", "error": str(e)}), 500




@app.route('/process', methods=['POST'])
@cross_origin()
def process():
    is_processed = process_entries(client = client)
    if not is_processed:
        return jsonify({"status": "failed", "error": "process failed"}), 400
    return jsonify({"status": "success", "message": "process successful"})
    try:
        user_id = request.json['user_id']
        database_name = "test"
        collection_name = user_id.split('@')[0] + '_invoice'
        vector_collection_name = user_id.split('@')[0] + '_invoice_vector'
        vector_index_name = (user_id.split('@')[0] + '_invoice_vector_index').replace('.', '_')
        is_processed = process_entries(
            client=client,
            database_name=database_name,
            collection_name=collection_name,
            vector_collection_name=vector_collection_name,
            vector_index_name=vector_index_name
        )

        if not is_processed:
            return jsonify({"status": "fail", "error": "process failed"}), 400

        return jsonify({"status": "success", "message": "process successful"}), 200
    except KeyError:
        return jsonify({"status": "fail", "error": "invalid request body"}), 400
    except Exception as e:
        return jsonify({"status": "fail", "error": f"{e}"}), 500


@app.route('/query', methods=['POST'])
@cross_origin()
def process_form():
    # get the query
    query = request.json["query"]
    user_id = request.json['user_id']
    database_name = "test"
    vector_collection_name = user_id.split('@')[0] + '_invoice_vector'
    vector_index_name = (user_id.split('@')[0] + '_invoice_vector_index').replace('.', '_')
    store = MongoDBAtlasVectorSearch(
        client,
        db_name=database_name, # this is the database where you stored your embeddings
        collection_name=vector_collection_name, # this is where your embeddings were stored in 2_load_and_index.py
        index_name=vector_index_name # this is the name of the index you created after loading your data
    )
    index = VectorStoreIndex.from_vector_store(store)


    if query is not None:
        # query your data!
16 changes: 9 additions & 7 deletions process.py
@@ -28,24 +28,26 @@
# llamaindex has a special class that does this for you
# it pulls every object in a given collection

def process_entries(client):
    collection = client[os.getenv("MONGODB_DATABASE")][os.getenv("MONGODB_COLLECTION")]
def process_entries(client, database_name, collection_name, vector_collection_name, vector_index_name):
    # connect to Atlas as a vector store

    query_dict = {"processed": False}
    collection = client[database_name][collection_name]
    unprocessed_entries = collection.find(query_dict)
    reader = SimpleMongoReader(uri=os.getenv("MONGODB_URI"))
    documents = reader.load_data(
        os.getenv("MONGODB_DATABASE"),
        os.getenv("MONGODB_COLLECTION"), # this is the collection where the objects you loaded in 1_import got stored
        database_name,
        collection_name, # this is the collection where the objects you loaded in 1_import got stored
        # field_names=["saleDate", "items", "storeLocation", "customer", "couponUsed", "purchaseMethod"], # this is a list of the top-level fields in your objects that will be indexed
        field_names=["text"], # make sure your objects have a field called "text" or that you change this value
        query_dict=query_dict # this is a mongo query dict that will filter your data if you don't want to index everything
    )

    store = MongoDBAtlasVectorSearch(
        client,
        db_name=os.getenv('MONGODB_DATABASE'),
        collection_name=os.getenv('MONGODB_VECTORS'), # this is where your embeddings will be stored
        index_name=os.getenv('MONGODB_VECTOR_INDEX') # this is the name of the index you will need to create
        db_name=database_name, # this is the database where you stored your embeddings
        collection_name=vector_collection_name, # this is where your embeddings will be stored
        index_name=vector_index_name # this is the name of the index you will need to create
    )
    # # create Atlas as a vector store
    # now create an index from all the Documents and store them in Atlas
