Skip to content

Commit

Permalink
Support blob input
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardobl committed Apr 19, 2024
1 parent b5766c2 commit 6498bc8
Showing 1 changed file with 63 additions and 22 deletions.
85 changes: 63 additions & 22 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,69 @@ def log(rpcClient, message, jobId=None):
rpcClient.logForJob(rpc_pb2.RpcJobLog(jobId=jobId, log=message))


def deserializeFromHyperBlob(rpcClient, url, marker, index_vectors ,index_content, searches_vectors):
blobDiskId = rpcClient.openDisk(rpc_pb2.RpcOpenDiskRequest(url=url)).diskId
files = rpcClient.diskListFiles(rpc_pb2.RpcDiskListFilesRequest(diskId=blobDiskId, path="/"))
embeddings_files = [f for f in files.files if f.endswith(".embeddings")]
sentences = []
vectors = []
dtype = None
shape = None
for f in embeddings_files:
sentence_bytes=bytearray()
for r in rpcClient.diskReadFile(rpc_pb2.RpcDiskReadFileRequest(diskId=blobDiskId, path=f)):
sentence_bytes.extend(r.data)
vectors_bytes=bytearray()
for r in rpcClient.diskReadFile(rpc_pb2.RpcDiskReadFileRequest(diskId=blobDiskId, path=f+".vectors")):
vectors_bytes.extend(r.data)
shape_bytes=rpcClient.diskReadSmallFile(rpc_pb2.RpcDiskReadFileRequest(diskId=blobDiskId, path=f+".shape")).data
dtype_bytes=rpcClient.diskReadSmallFile(rpc_pb2.RpcDiskReadFileRequest(diskId=blobDiskId, path=f+".dtype")).data

# sentence_bytes to string
sentence = sentence_bytes.decode("utf-8")
dtype = dtype_bytes.decode("utf-8")
shape = json.loads(shape_bytes.decode("utf-8"))
embeddings = np.frombuffer(vectors_bytes, dtype=dtype).reshape(shape)

if marker == "query":
searches_vectors.append(embeddings)
else:
index_vectors.append(embeddings)
index_content.append(sentence)
return [dtype,shape]


def deserializeFromJSON(rpcClient, data, marker, index_vectors ,index_content, searches_vectors):
data=json.loads(data)
for part in data:
[text,embeddings_b64,_dtype,_shape] = part
if dtype is None: dtype = _dtype
elif dtype != _dtype: raise Exception("Data type mismatch")
if shape is None: shape = _shape
elif shape != _shape: raise Exception("Shape mismatch")
embeddings_bytes = base64.b64decode(embeddings_b64)
embeddings = np.frombuffer(embeddings_bytes, dtype=dtype).reshape(shape)
if marker == "query":
searches_vectors.append(embeddings)
else:
index_vectors.append(embeddings)
index_content.append(text)
return [dtype,shape]


def deserialize(rpcClient, jin,index_vectors ,index_content,searches_vectors):
dtype = None
shape = None
data = jin.data
dataType = jin.type
marker = jin.marker

if dataType == "application/hyperblob":
[dtype,shape] = deserializeFromHyperBlob(rpcClient, data, marker, index_vectors ,index_content, searches_vectors)
else:
#if dataType == "text" or dataType == "application/json":
[dtype,shape] = deserializeFromJSON(rpcClient, data, marker, index_vectors ,index_content, searches_vectors)
return [dtype,shape]


def completePendingJob(rpcClient ):
Expand Down Expand Up @@ -47,42 +109,21 @@ def getParamValue(key,default=None):
index_vectors = []
index_content = []
searches_vectors = []

dtype = None
shape = None

for jin in job.input:
data = json.loads(jin.data)
dataType = jin.type
marker = jin.marker

# every data might contain multiple vectors
for part in data:
[text,embeddings_b64,_dtype,_shape] = part
if dtype is None: dtype = _dtype
elif dtype != _dtype: raise Exception("Data type mismatch")
if shape is None: shape = _shape
elif shape != _shape: raise Exception("Shape mismatch")
embeddings_bytes = base64.b64decode(embeddings_b64)
embeddings = np.frombuffer(embeddings_bytes, dtype=dtype).reshape(shape)
if marker == "query":
searches_vectors.append(embeddings)
else:
index_vectors.append(embeddings)
index_content.append(text)
[dtype,shape] = deserialize(rpcClient,jin,index_vectors ,index_content,searches_vectors)

searches_vectors = np.array(searches_vectors)
index_vectors = np.array(index_vectors)
if normalize and dtype == "float32":
faiss.normalize_L2(searches_vectors)
faiss.normalize_L2(index_vectors)


print("Shape "+str(shape))
index = faiss.IndexFlatL2(shape[0])
index.add(index_vectors)


log(rpcClient, "Index prepared for job "+job.id+" in "+str(time.time()-t)+" seconds")

t=time.time()
Expand Down

0 comments on commit 6498bc8

Please sign in to comment.