Commit

Updated openai and minor fixes
MatsMoll committed Oct 22, 2024
1 parent 8cef7c8 commit e1f30f4
Showing 12 changed files with 386 additions and 62 deletions.
8 changes: 7 additions & 1 deletion aligned/compiler/feature_factory.py
@@ -1403,9 +1403,15 @@ def copy_type(self) -> Embedding:
def dtype(self) -> FeatureType:
return FeatureType.embedding(self.embedding_size or 0)

def dot_product(self, embedding: Embedding) -> Float:
def dot_product(self, embedding: Embedding, check_embedding_size: bool = True) -> Float:
from aligned.compiler.transformation_factory import ListDotProduct

if check_embedding_size:
assert self.embedding_size == embedding.embedding_size, (
'Expected embeddings of the same size, but got two different sizes. '
f"Left: {self.embedding_size}, right: {embedding.embedding_size}"
)

feat = Float()
feat.transformation = ListDotProduct(self, embedding)
return feat
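For reference, a minimal sketch of the new `check_embedding_size` flag. The feature objects below are illustrative and not part of this diff:

```python
# Two embeddings declared with the same size, as in the docstring examples.
left = Embedding(1536)
right = Embedding(1536)

# The default now asserts that both embedding sizes match before building the transformation.
similarity = left.dot_product(right)

# Skip the assert when the sizes are intentionally left unset or known to differ.
similarity_unchecked = left.dot_product(right, check_embedding_size=False)
```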
14 changes: 0 additions & 14 deletions aligned/compiler/model.py
@@ -27,7 +27,6 @@
FeatureViewWrapper,
)
from aligned.exposed_model.interface import ExposedModel
from aligned.request.retrival_request import RetrivalRequest
from aligned.retrival_job import ConvertableToRetrivalJob, PredictionJob, RetrivalJob
from aligned.schemas.derivied_feature import DerivedFeature
from aligned.schemas.feature import Feature, FeatureLocation, FeatureReference, FeatureType, StaticFeatureTags
@@ -180,24 +179,11 @@ def predict_over(
values: ConvertableToRetrivalJob | RetrivalJob,
needed_views: list[FeatureViewWrapper | ModelContractWrapper] | None = None,
) -> PredictionJob:
from aligned.retrival_job import RetrivalJob

model = self.compile()

if not model.exposed_model:
raise ValueError(f"Model {model.name} does not have an `exposed_model` to use for predictions.")

if not isinstance(values, RetrivalJob):
features = {feat.as_feature() for feat in model.features.default_features}
request = RetrivalRequest(
name='default',
location=FeatureLocation.model(model.name),
entities=set(),
features=features,
derived_features=set(),
)
values = RetrivalJob.from_convertable(values, request)

return self.query(needed_views).predict_over(values)

def as_view(self) -> CompiledFeatureView | None:
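With the dict-to-job conversion now handled inside `self.query(needed_views).predict_over(values)`, a plain dict of values can still be passed straight in, as the updated tests below do. A rough sketch, using the hypothetical contracts from those tests:

```python
# Mirrors the updated test: plain dict input plus the views needed to derive features.
entities = {"entity_id": ["a", "b"], "x": [1, 2]}

predictions = await MyModelContract2.predict_over(
    entities, needed_views=[InputFeatureView, MyModelContract]
).to_polars()
```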
34 changes: 33 additions & 1 deletion aligned/exposed_model/interface.py
@@ -434,7 +434,39 @@ async def function_wrapper(values: RetrivalJob, store: ModelFeatureStore) -> pl.
return DillFunction(function=dill.dumps(function_wrapper))


def openai_embedding(model: str, prompt_template: str | None = None) -> ExposedModel:
def openai_embedding(
model: str, batch_on_n_chunks: int | None = 100, prompt_template: str | None = None
) -> ExposedModel:
"""
Returns an OpenAI embedding model.
```python
@model_contract(
input_features=[MyFeature().name],
exposed_model=openai_embedding("text-embedding-3-small"),
)
class MyEmbedding:
my_entity = Int32().as_entity()
name = String()
embedding = Embedding(1536)
predicted_at = EventTimestamp()

embeddings = await store.model(MyEmbedding).predict_over({
"my_entity": [1, 2, 3],
"name": ["Hello", "World", "foo"]
}).to_polars()
```
Args:
model (str): the model to use. See the OpenAI docs for the available embedding models.
batch_on_n_chunks (int): the number of chunks at which to switch to the batch API, for requests that are too large for the realtime endpoint.
prompt_template (str): an optional custom prompt template. The default is based on the input features.
Returns:
ExposedModel: a model that sends embedding requests to OpenAI
"""
from aligned.exposed_model.openai import OpenAiEmbeddingPredictor

return OpenAiEmbeddingPredictor(
model=model, batch_on_n_chunks=batch_on_n_chunks, prompt_template=prompt_template or ''
)
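A short usage sketch of the updated signature, reusing the hypothetical contract from the docstring, here with the batch fallback disabled by passing `batch_on_n_chunks=None`:

```python
@model_contract(
    input_features=[MyFeature().name],
    exposed_model=openai_embedding(
        "text-embedding-3-small",
        batch_on_n_chunks=None,  # never fall back to the batch API
        prompt_template="Name: {name}",  # optional; defaults to a template over the input features
    ),
)
class MyEmbedding:
    my_entity = Int32().as_entity()
    name = String()
    embedding = Embedding(1536)
    predicted_at = EventTimestamp()
```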
232 changes: 214 additions & 18 deletions aligned/exposed_model/openai.py
@@ -1,18 +1,187 @@
from __future__ import annotations

import asyncio
import json
import logging
from dataclasses import dataclass
import polars as pl
from datetime import datetime, timezone
from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4

from aligned.exposed_model.interface import ExposedModel
import polars as pl
from aligned.exposed_model.interface import (
ExposedModel,
Feature,
FeatureReference,
RetrivalJob,
VersionedModel,
)
from aligned.feature_store import ModelFeatureStore
from aligned.retrival_job import RetrivalJob
from aligned.schemas.feature import Feature, FeatureReference
from aligned.schemas.model import Model

if TYPE_CHECKING:
from openai import AsyncClient

logger = logging.getLogger(__name__)


def write_batch_request(texts: list[str], path: Path, model: str, url: str) -> None:
"""
Creates a .jsonl file for batch processing, with each line being a request to the embeddings API.
"""
with path.open('w') as f:
for i, text in enumerate(texts):
request = {
'custom_id': f"request-{i+1}",
'method': 'POST',
'url': url,
'body': {'model': model, 'input': text},
}
f.write(json.dumps(request) + '\n')
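Each line written by `write_batch_request` is a single embeddings request. With illustrative inputs, the first line of the file would look roughly like this:

```python
write_batch_request(
    texts=["Hello world"],
    path=Path("batch-input.jsonl"),
    model="text-embedding-3-small",
    url="/v1/embeddings",
)
# First line of batch-input.jsonl (formatted for readability):
# {"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings",
#  "body": {"model": "text-embedding-3-small", "input": "Hello world"}}
```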


async def chunk_batch_embedding_request(texts: list[str], model: str, client: AsyncClient) -> pl.DataFrame:

max_batch = 50_000
number_of_batches = ceil(len(texts) / max_batch)

batch_result: pl.DataFrame | None = None

for i in range(number_of_batches):
start = i * max_batch
end_batch = min((i + 1) * max_batch, len(texts))

if start == end_batch:
batch_prompts = [texts[start]]
else:
batch_prompts = texts[start:end_batch]

result = await make_batch_embedding_request(batch_prompts, model, client)

if batch_result is None:
batch_result = result
else:
batch_result = batch_result.vstack(result)

assert batch_result is not None
return batch_result


async def make_batch_embedding_request(texts: list[str], model: str, client: AsyncClient) -> pl.DataFrame:

id_path = str(uuid4())
batch_file = Path(id_path)
output_file = Path(id_path + '-output.jsonl')

write_batch_request(texts, batch_file, model, '/v1/embeddings')
request_file = await client.files.create(file=batch_file, purpose='batch')
response = await client.batches.create(
input_file_id=request_file.id,
endpoint='/v1/embeddings',
completion_window='24h',
metadata={'description': 'Embedding batch job'},
)
status_response = await client.batches.retrieve(response.id)

last_process = None
expected_duration_left = 60

while status_response.status not in ['completed', 'failed']:
await asyncio.sleep(expected_duration_left * 0.8)  # Wait ~80% of the estimated remaining time before polling again
status_response = await client.batches.retrieve(response.id)
logger.info(f"Status of batch request {status_response.status}")

processed_records = 0
leftover_records = 0

if status_response.request_counts:
processed_records = status_response.request_counts.completed
leftover_records = status_response.request_counts.total - processed_records

if status_response.in_progress_at:
last_process = datetime.fromtimestamp(status_response.in_progress_at, tz=timezone.utc)
now = datetime.now(tz=timezone.utc)

items_per_process = (now - last_process).total_seconds() / max(processed_records, 1)
expected_duration_left = max(items_per_process * leftover_records, 60)

batch_info = await client.batches.retrieve(response.id)
output_file_id = batch_info.output_file_id

if not output_file_id:
raise ValueError(f"No output file for request: {response.id}")

output_content = await client.files.retrieve_content(output_file_id)
output_file.write_text(output_content)
embeddings = pl.read_ndjson(output_file.as_posix())
expanded_emb = (
embeddings.unnest('response')
.unnest('body')
.explode('data')
.select(['custom_id', 'data'])
.unnest('data')
.select(['custom_id', 'embedding'])
.with_columns(pl.col('custom_id').str.split('-').list.get(1).alias('index'))
)
return expanded_emb
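The `unnest`/`explode` chain above assumes each line of the downloaded output file has roughly this shape (simplified from the OpenAI batch output format, values illustrative):

```python
# One output line, simplified:
line = {
    "custom_id": "request-1",
    "response": {
        "status_code": 200,
        "body": {"data": [{"index": 0, "embedding": [0.01, -0.02]}]},
    },
}
# unnest("response") -> unnest("body") -> explode("data") -> unnest("data")
# ends with one row per request holding its "embedding" list, keyed by "custom_id".
```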


async def embed_texts(
texts: list[str], model: str, skip_if_n_chunks: int | None, client: AsyncClient
) -> list[list[float]] | str:
import tiktoken

max_token_size = 8192
number_of_texts = len(texts)

chunks: list[int] = []
chunk_size = 0
encoder = tiktoken.encoding_for_model(model)

for index, text in enumerate(texts):
token_size = len(encoder.encode(text))

if chunk_size + token_size > max_token_size:
chunks.append(index)
chunk_size = 0

if skip_if_n_chunks and len(chunks) + 1 >= skip_if_n_chunks:
return f"At text nr: {index} did it go above {skip_if_n_chunks} with {len(chunks)}"

chunk_size += token_size

if not chunks or chunks[-1] < number_of_texts - 1:
chunks.append(number_of_texts - 1)

embeddings: list[list[float]] = []

last_chunk_index = 0

for chunk_index in chunks:
if last_chunk_index == 0 and chunk_index >= number_of_texts - 1:
chunk_texts = texts
elif last_chunk_index == 0:
chunk_texts = texts[:chunk_index]
elif chunk_index >= number_of_texts - 1:
chunk_texts = texts[last_chunk_index:]
else:
chunk_texts = texts[last_chunk_index:chunk_index]

res = await client.embeddings.create(input=chunk_texts, model=model)
embeddings.extend([emb.embedding for emb in res.data])
last_chunk_index = chunk_index

return embeddings
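A rough sketch of calling `embed_texts` directly: it either returns the embeddings, or a string explaining that the request should go to the batch API instead. The client setup is illustrative:

```python
from openai import AsyncClient


async def embed_or_defer(texts: list[str]) -> list[list[float]] | str:
    client = AsyncClient()  # reads OPENAI_API_KEY from the environment
    return await embed_texts(
        texts,
        model="text-embedding-3-small",
        skip_if_n_chunks=100,  # same default as OpenAiEmbeddingPredictor.batch_on_n_chunks
        client=client,
    )
```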


@dataclass
class OpenAiEmbeddingPredictor(ExposedModel):
class OpenAiEmbeddingPredictor(ExposedModel, VersionedModel):

model: str

batch_on_n_chunks: int | None = 100
feature_refs: list[FeatureReference] = None # type: ignore
output_name: str = ''
prompt_template: str = ''
@@ -27,6 +196,24 @@ def prompt_template_hash(self) -> str:

return sha256(self.prompt_template.encode(), usedforsecurity=False).hexdigest()

@property
def as_markdown(self) -> str:
return f"""Sending a `embedding` request to OpenAI's API.
This will use the model: `{self.model}` to generate the embeddings.
Will switch to the batch API if more then {self.batch_on_n_chunks} chunks are needed to fulfill the request.
And use the prompt template:
```
{self.prompt_template}
```"""

async def model_version(self) -> str:
if len(self.feature_refs) == 1:
return self.model
else:
return f"{self.model}-{self.prompt_template_hash()}"

async def needed_features(self, store: ModelFeatureStore) -> list[FeatureReference]:
return self.feature_refs

@@ -58,24 +245,33 @@ async def run_polars(self, values: RetrivalJob, store: ModelFeatureStore) -> pl.
missing_cols = expected_cols - set(values.loaded_columns)

if missing_cols:
logger.info(f"Missing cols: {missing_cols}")
df = await store.store.features_for(values, features=self.feature_refs).to_polars()
else:
df = await values.to_polars()

if len(expected_cols) == 1:
prompts = df[self.feature_refs[0].name].to_list()
texts = df[self.feature_refs[0].name].to_list()
else:
prompts: list[str] = []
texts: list[str] = []
for row in df.to_dicts():
prompts.append(self.prompt_template.format(**row))

embeddings = await client.embeddings.create(input=prompts, model=self.model)
return df.hstack(
[
pl.Series(
name=self.output_name,
values=[emb.embedding for emb in embeddings.data],
dtype=pl.List(pl.Float32),
)
]
texts.append(self.prompt_template.format(**row))

realtime_emb = await embed_texts(
texts, model=self.model, skip_if_n_chunks=self.batch_on_n_chunks, client=client
)

if isinstance(realtime_emb, list):
return df.hstack(
[
pl.Series(
name=self.output_name,
values=realtime_emb,
dtype=pl.List(pl.Float32),
)
]
)

batch_result = await chunk_batch_embedding_request(texts, self.model, client)

return df.hstack([batch_result['embedding'].alias(self.output_name)])
12 changes: 11 additions & 1 deletion aligned/exposed_model/tests/test_model.py
@@ -94,7 +94,7 @@ class MyModelContract2:

entities = {'entity_id': ['a', 'b'], 'x': [1, 2]}
pred_job = MyModelContract2.predict_over(entities, needed_views=[InputFeatureView, MyModelContract])
assert set(pred_job.request_result.feature_columns) == {'x', 'prediction', 'other_pred'}
assert set(pred_job.request_result.all_returned_columns) == {'x', 'entity_id', 'prediction', 'other_pred'}

preds = await pred_job.to_polars()
assert preds['other_pred'].to_list() == [6, 12]
@@ -226,3 +226,13 @@ class MyModelContract2:
)
assert preds['other_pred'].null_count() == 0
assert not first_preds['model_version'].series_equal(preds['model_version'])

preds = (
await without_cache.model(MyModelContract2)
.predict_over(without_cache.feature_view(InputFeatureView).all())
.to_polars()
)
input_features = InputFeatureView.query().request.all_returned_columns
assert set(input_features) - set(preds.columns) == set(), 'Missing some columns'
assert preds['other_pred'].null_count() == 0
assert not first_preds['model_version'].series_equal(preds['model_version'])