Skip to content

Commit

Permalink
Add ColPali embedding and support for processing PDFs as images (#543)
Browse files Browse the repository at this point in the history
Add ColPali embedding and support for processing PDFs as images
  • Loading branch information
NikolaosPapailiou authored Oct 14, 2024
1 parent f41bf58 commit a187208
Show file tree
Hide file tree
Showing 4 changed files with 435 additions and 10 deletions.
279 changes: 279 additions & 0 deletions apis/python/examples/object_api/multi_modal_pdf_search.ipynb

Large diffs are not rendered by default.

104 changes: 104 additions & 0 deletions apis/python/src/tiledb/vector_search/embeddings/colpali_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import Dict, OrderedDict, Tuple

import numpy as np

from tiledb.vector_search.embeddings import ObjectEmbedding

EMBED_DIM = 128


class ColpaliEmbedding(ObjectEmbedding):
def __init__(
self,
model_name: str = "vidore/colpali-v1.2",
device: str = None,
batch_size: int = 4,
):
self.model_name = model_name
self.device = device
self.batch_size = batch_size
self.model = None
self.processor = None

def init_kwargs(self) -> Dict:
return {
"model_name": self.model_name,
"device": self.device,
"batch_size": self.batch_size,
}

def dimensions(self) -> int:
return EMBED_DIM

def vector_type(self) -> np.dtype:
return np.float32

def load(self) -> None:
import torch
from colpali_engine.models import ColPali
from colpali_engine.models import ColPaliProcessor

if self.device is None:
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"

# Load model
self.model = ColPali.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16, device_map=self.device
).eval()
self.processor = ColPaliProcessor.from_pretrained(self.model_name)

def embed(
self, objects: OrderedDict, metadata: OrderedDict
) -> Tuple[np.ndarray, np.array]:
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

if "image" in objects:
images = []
for i in range(len(objects["image"])):
images.append(
Image.fromarray(
np.reshape(objects["image"][i], objects["shape"][i])
)
)
dataloader = DataLoader(
images,
batch_size=self.batch_size,
shuffle=False,
collate_fn=lambda x: self.processor.process_images(x),
)
elif "text" in objects:
dataloader = DataLoader(
objects["text"],
batch_size=self.batch_size,
shuffle=False,
collate_fn=lambda x: self.processor.process_queries(x),
)

embeddings = None
external_ids = None
id = 0
for batch in tqdm(dataloader):
with torch.no_grad():
batch = {k: v.to(self.model.device) for k, v in batch.items()}
batch_embeddings = list(torch.unbind(self.model(**batch).to("cpu")))
for object_embeddings in batch_embeddings:
object_embeddings_np = object_embeddings.to(torch.float32).cpu().numpy()
ext_ids = metadata["external_id"][id] * np.ones(
object_embeddings_np.shape[0], dtype=np.uint64
)
if embeddings is None:
external_ids = ext_ids
embeddings = object_embeddings_np
else:
external_ids = np.concatenate((external_ids, ext_ids))
embeddings = np.vstack((embeddings, object_embeddings_np))
id += 1
return (embeddings, external_ids)
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,7 @@ def create(
schema = tiledb.ArraySchema(
domain=external_ids_dom,
sparse=True,
capacity=metadata_tile_size,
attrs=object_reader.metadata_attributes(),
)
tiledb.Array.create(object_metadata_array_uri, schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,6 @@ def __init__(
handlers={
"application/pdf": PyMuPDFParser(),
"text/plain": TextParser(),
"text/markdown": TextParser(),
"text/html": BS4HTMLParser(),
"application/msword": MsWordParser(),
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": (
Expand Down Expand Up @@ -386,7 +385,7 @@ def lazy_load(
mime_type = mimetypes.guess_type(self.uri)[0]
f = vfs.open(self.uri)

if mime_type is None or mime_type.startswith("text"):
if mime_type is None:
mime_type = "text/plain"

if mime_type.startswith("image/"):
Expand Down Expand Up @@ -519,31 +518,71 @@ def init_kwargs(self) -> Dict:
}

def metadata_attributes(self) -> List[tiledb.Attr]:
image_attr = tiledb.Attr(
name="image",
dtype=np.uint8,
var=True,
filters=tiledb.FilterList([tiledb.ZstdFilter()]),
)
image_shape_attr = tiledb.Attr(name="shape", dtype=np.uint32, var=True)
file_path_attr = tiledb.Attr(name="file_path", dtype=str)
return [image_shape_attr, file_path_attr]
page_attr = tiledb.Attr(name="page", dtype=np.int32)
return [image_attr, image_shape_attr, file_path_attr, page_attr]

def read_objects(
self, partition: DirectoryPartition
) -> Tuple[OrderedDict, OrderedDict]:
import mimetypes

from PIL import Image

import tiledb

size = len(partition.paths)
images = np.empty(size, dtype="O")
shapes = np.empty(size, dtype="O")
file_paths = np.empty(size, dtype="O")
external_ids = np.zeros(size, dtype=np.uint64)
max_size = DirectoryTextReader.MAX_OBJECTS_PER_FILE * len(partition.paths)
images = np.empty(max_size, dtype="O")
shapes = np.empty(max_size, dtype="O")
file_paths = np.empty(max_size, dtype="O")
pages = np.zeros(max_size, dtype=np.int32)
external_ids = np.zeros(max_size, dtype=np.uint64)
write_id = 0
vfs = tiledb.VFS()
for path in partition.paths:
with vfs.open(path) as fp:
if tiledb.array_exists(path):
with tiledb.open(path, "r") as a:
mime_type = a.meta.get("mime_type", None)
fp = tiledb.filestore.Filestore(path)
else:
mime_type = mimetypes.guess_type(path)[0]
fp = vfs.open(path)
if path.endswith(".pdf") or mime_type == "application/pdf":
from io import BytesIO

import fitz

doc = fitz.open(stream=fp.read())
p = 0
for page in doc.pages():
zoom = 1
mat = fitz.Matrix(zoom, zoom)
pix = page.get_pixmap(matrix=mat)

image = np.array(
Image.open(BytesIO(pix.tobytes(output="png", jpg_quality=95)))
)[:, :, :3]
images[write_id] = image.flatten()
shapes[write_id] = np.array(image.shape, dtype=np.uint32)
external_ids[write_id] = abs(hash(f"{path}_{page}"))
file_paths[write_id] = str(path)
pages[write_id] = p
write_id += 1
p += 1
else:
image = np.array(Image.open(fp))[:, :, :3]
images[write_id] = image.flatten()
shapes[write_id] = np.array(image.shape, dtype=np.uint32)
external_ids[write_id] = partition.object_id_start + write_id
external_ids[write_id] = abs(hash(f"{path}_{0}"))
file_paths[write_id] = str(path)
pages[write_id] = 0
write_id += 1
return (
{
Expand All @@ -552,8 +591,10 @@ def read_objects(
"external_id": external_ids[0:write_id],
},
{
"image": images[0:write_id],
"shape": shapes[0:write_id],
"file_path": file_paths[0:write_id],
"page": pages[0:write_id],
"external_id": external_ids[0:write_id],
},
)
Expand Down

0 comments on commit a187208

Please sign in to comment.