Skip to content

Commit

Permalink
Merge pull request #18 from mlx-chat/MLC-17
Browse files Browse the repository at this point in the history
[MLC-17] server: added runner script to package server files
  • Loading branch information
stockeh authored Mar 3, 2024
2 parents b0ace87 + 449135e commit db13f00
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
44 changes: 44 additions & 0 deletions runner.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/bin/bash

collect_modules=(
"mlx"
"chromadb"
)

hidden_imports=(
"server.models"
"server.models.gemma"
"server.models.bert"
)

exclude_modules=(
"torch"
"torchaudio"
"torchvision"
"tensorflow"
"matplotlib"
"pandas"
"PIL"
"IPython"
)

misc_params=(
"--copy-metadata opentelemetry-sdk"
)

command="pyinstaller --onefile runner.py"

for module in "${collect_modules[@]}"; do
command+=" --collect-all $module"
done
for module in "${hidden_imports[@]}"; do
command+=" --hidden-import $module"
done
for module in "${exclude_modules[@]}"; do
command+=" --exclude-module $module"
done
for param in "${misc_params[@]}"; do
command+=" $param"
done

eval "$command"
19 changes: 7 additions & 12 deletions server/retriever/vectorstore.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from .document import Document
import uuid
import functools
import mlx.core as mx

import chromadb
import chromadb.config

from chromadb.utils.batch_utils import create_batches
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument
from typing import (
Any,
Expand All @@ -15,11 +18,7 @@
Tuple,
Type,
)

import chromadb
import chromadb.config
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument

from .document import Document
from .embeddings import Embeddings

Chroma = TypeVar('Chroma', bound='Chroma')
Expand Down Expand Up @@ -508,9 +507,7 @@ def update_documents(self, ids: List[str], documents: List[Document]) -> None:

if hasattr(
self._collection._client, "max_batch_size"
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches

):
for batch in create_batches(
api=self._collection._client,
ids=ids,
Expand Down Expand Up @@ -578,9 +575,7 @@ def from_texts(
ids = [str(uuid.uuid1()) for _ in texts]
if hasattr(
chroma_collection._client, "max_batch_size"
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches

):
for batch in create_batches(
api=chroma_collection._client,
ids=ids,
Expand Down

0 comments on commit db13f00

Please sign in to comment.