Skip to content

Commit

Permalink
nit
Browse files: browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Feb 17, 2025
1 parent bc7b955 commit 374244c
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 33 deletions.
4 changes: 3 additions & 1 deletion qdrant_client/embed/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def embed(
is_query: bool = False,
batch_size: int = 32,
) -> NumericVector:
task_id = options.get("task_id") if options else None

if (texts is None) is (images is None):
raise ValueError("Either documents or images should be provided")
if model_name in SUPPORTED_EMBEDDING_MODELS:
Expand All @@ -199,7 +201,7 @@ def embed(
embeddings = [
embedding.tolist()
for embedding in embedding_model_inst.embed(
documents=texts, batch_size=batch_size
documents=texts, batch_size=batch_size, task_id=task_id
)
]
else:
Expand Down
81 changes: 49 additions & 32 deletions tests/embed_tests/test_local_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,22 +1050,31 @@ def test_update_vectors(prefer_grpc):

@pytest.mark.parametrize("prefer_grpc", [True, False])
def test_propagate_options(prefer_grpc):
params = {
"lazy_load": True,
"cache_dir": "models",
"threads": None,
"device_ids": None,
"cuda": False,
"providers": ["CPUExecutionProvider"],
"local_files_only": True,
}
local_client = QdrantClient(":memory:")
if not local_client._FASTEMBED_INSTALLED:
pytest.skip("FastEmbed is not installed, skipping")
remote_client = QdrantClient(prefer_grpc=prefer_grpc)
dense_doc_1 = models.Document(
text="hello world", model=DENSE_MODEL_NAME, options={"lazy_load": True}
text="hello world", model=DENSE_MODEL_NAME, options=params
)
sparse_doc_1 = models.Document(
text="hello world", model=SPARSE_MODEL_NAME, options={"lazy_load": True}
text="hello world", model=SPARSE_MODEL_NAME, options=params
)
multi_doc_1 = models.Document(
text="hello world", model=COLBERT_MODEL_NAME, options={"lazy_load": True}
text="hello world", model=COLBERT_MODEL_NAME, options=params
)

dense_image_1 = models.Image(
image=TEST_IMAGE_PATH, model=DENSE_IMAGE_MODEL_NAME, options={"lazy_load": True}
image=TEST_IMAGE_PATH, model=DENSE_IMAGE_MODEL_NAME, options=params
)

points = [
Expand Down Expand Up @@ -1111,18 +1120,22 @@ def test_propagate_options(prefer_grpc):
local_client.upsert(COLLECTION_NAME, points)
remote_client.upsert(COLLECTION_NAME, points)

assert local_client._model_embedder.embedder.embedding_models[DENSE_MODEL_NAME][
0
].model.model.lazy_load
assert local_client._model_embedder.embedder.sparse_embedding_models[SPARSE_MODEL_NAME][
0
].model.model.lazy_load
assert local_client._model_embedder.embedder.late_interaction_embedding_models[
COLBERT_MODEL_NAME
][0].model.model.lazy_load
assert local_client._model_embedder.embedder.image_embedding_models[DENSE_IMAGE_MODEL_NAME][
0
].model.model.lazy_load
embedder = local_client._model_embedder.embedder
for model_type, model_name in [
(embedder.embedding_models, DENSE_MODEL_NAME),
(embedder.sparse_embedding_models, SPARSE_MODEL_NAME),
(embedder.late_interaction_embedding_models, COLBERT_MODEL_NAME),
(embedder.image_embedding_models, DENSE_IMAGE_MODEL_NAME),
]:
model = model_type[model_name][0].model.model

for key, value in params.items():
if key == "cache_dir":
assert str(model.cache_dir) == str(value), f"cache_dir was not propagated correctly for {model_name}"
elif key == "local_files_only":
assert model._local_files_only == value, f"local_files_only was not propagated correctly for {model_name}"
else:
assert getattr(model, key) == value, f"{key} was not propagated correctly for {model_name}"

local_client._model_embedder.embedder.embedding_models.clear()
local_client._model_embedder.embedder.sparse_embedding_models.clear()
Expand All @@ -1132,25 +1145,25 @@ def test_propagate_options(prefer_grpc):
inference_object_dense_doc_1 = models.InferenceObject(
object="hello world",
model=DENSE_MODEL_NAME,
options={"lazy_load": True},
options=params,
)

inference_object_sparse_doc_1 = models.InferenceObject(
object="hello world",
model=SPARSE_MODEL_NAME,
options={"lazy_load": True},
options=params,
)

inference_object_multi_doc_1 = models.InferenceObject(
object="hello world",
model=COLBERT_MODEL_NAME,
options={"lazy_load": True},
options=params,
)

inference_object_dense_image_1 = models.InferenceObject(
object=TEST_IMAGE_PATH,
model=DENSE_IMAGE_MODEL_NAME,
options={"lazy_load": True},
options=params,
)

points = [
Expand All @@ -1168,18 +1181,22 @@ def test_propagate_options(prefer_grpc):
local_client.upsert(COLLECTION_NAME, points)
remote_client.upsert(COLLECTION_NAME, points)

assert local_client._model_embedder.embedder.embedding_models[DENSE_MODEL_NAME][
0
].model.model.lazy_load
assert local_client._model_embedder.embedder.sparse_embedding_models[SPARSE_MODEL_NAME][
0
].model.model.lazy_load
assert local_client._model_embedder.embedder.late_interaction_embedding_models[
COLBERT_MODEL_NAME
][0].model.model.lazy_load
assert local_client._model_embedder.embedder.image_embedding_models[DENSE_IMAGE_MODEL_NAME][
0
].model.model.lazy_load
embedder = local_client._model_embedder.embedder
for model_type, model_name in [
(embedder.embedding_models, DENSE_MODEL_NAME),
(embedder.sparse_embedding_models, SPARSE_MODEL_NAME),
(embedder.late_interaction_embedding_models, COLBERT_MODEL_NAME),
(embedder.image_embedding_models, DENSE_IMAGE_MODEL_NAME),
]:
model = model_type[model_name][0].model.model

for key, value in params.items():
if key == "cache_dir":
assert str(model.cache_dir) == str(value), f"cache_dir was not propagated correctly for {model_name}"
elif key == "local_files_only":
assert model._local_files_only == value, f"local_files_only was not propagated correctly for {model_name}"
else:
assert getattr(model, key) == value, f"{key} was not propagated correctly for {model_name}"


@pytest.mark.parametrize("prefer_grpc", [True, False])
Expand Down

0 comments on commit 374244c

Please sign in to comment.