Skip to content

Commit

Permalink
Updated with openai embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Oct 18, 2024
1 parent e52e8ab commit 8cef7c8
Show file tree
Hide file tree
Showing 27 changed files with 898 additions and 520 deletions.
3 changes: 0 additions & 3 deletions aligned/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from aligned.compiler.feature_factory import (
UUID,
Bool,
Entity,
EventTimestamp,
ValidFrom,
Float,
Expand Down Expand Up @@ -58,10 +57,8 @@
'KafkaConfig',
# Types
'ExposedModel',
'Entity',
'String',
'Bool',
'Entity',
'UUID',
'UInt8',
'UInt16',
Expand Down
42 changes: 11 additions & 31 deletions aligned/compiler/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,11 +655,8 @@ def as_model_version(self) -> ModelVersion:


class CouldBeEntityFeature:
def as_entity(self) -> Entity:
if isinstance(self, FeatureFactory):
return Entity(self).with_tag(StaticFeatureTags.is_entity)

raise ValueError(f'{self} is not a feature factory, and can therefore not be an entity')
def as_entity(self: T) -> T:
return self.with_tag(StaticFeatureTags.is_entity)


class EquatableFeature(FeatureFactory):
Expand Down Expand Up @@ -1345,39 +1342,15 @@ def aggregate(self) -> CategoricalAggregation:
return CategoricalAggregation(self)


class Entity(FeatureFactory):
    """Feature factory that wraps another feature factory.

    The data type and the constraints are read from the wrapped factory,
    while the name and description come from this instance.
    """

    # The wrapped factory that supplies the data type and constraints.
    _dtype: FeatureFactory

    def __init__(self, dtype: FeatureFactory):
        self._dtype = dtype

    @property
    def dtype(self) -> FeatureType:
        """Return the data type reported by the wrapped factory."""
        wrapped = self._dtype
        return wrapped.dtype

    def aggregate(self) -> CategoricalAggregation:
        """Return a `CategoricalAggregation` built around this factory."""
        aggregation = CategoricalAggregation(self)
        return aggregation

    def feature(self) -> Feature:
        """Build the `Feature` schema for this factory.

        The tags are always passed as `None`; the constraints are taken
        from the wrapped factory.
        """
        # Keys are listed in the order the values must be evaluated.
        spec = {
            'name': self.name,
            'dtype': self.dtype,
            'description': self._description,
            'tags': None,
            'constraints': self._dtype.constraints,
        }
        return Feature(**spec)


class Timestamp(DateFeature, ArithmeticFeature):

time_zone: str | None

def __init__(self, time_zone: str | None = 'UTC') -> None:
self.time_zone = time_zone

def defines_freshness(self) -> Timestamp:
return self.with_tag('freshness_timestamp')
def as_freshness(self) -> Timestamp:
return self.with_tag(StaticFeatureTags.is_freshness)

@property
def dtype(self) -> FeatureType:
Expand Down Expand Up @@ -1430,6 +1403,13 @@ def copy_type(self) -> Embedding:
def dtype(self) -> FeatureType:
return FeatureType.embedding(self.embedding_size or 0)

def dot_product(self, embedding: Embedding) -> Float:
    """Return a `Float` feature derived from this embedding and another.

    Args:
        embedding: The other embedding handed to `ListDotProduct`.

    Returns:
        A new `Float` whose `transformation` is
        `ListDotProduct(self, embedding)`.
    """
    # Imported inside the method, as the original code does.
    from aligned.compiler.transformation_factory import ListDotProduct

    result = Float()
    result.transformation = ListDotProduct(self, embedding)
    return result

def indexed(
self,
storage: VectorStorage,
Expand Down
15 changes: 10 additions & 5 deletions aligned/compiler/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from aligned.compiler.feature_factory import (
CanBeClassificationLabel,
Entity,
Bool,
EventTimestamp,
FeatureFactory,
Expand All @@ -31,7 +30,7 @@
from aligned.request.retrival_request import RetrivalRequest
from aligned.retrival_job import ConvertableToRetrivalJob, PredictionJob, RetrivalJob
from aligned.schemas.derivied_feature import DerivedFeature
from aligned.schemas.feature import Feature, FeatureLocation, FeatureReference, FeatureType
from aligned.schemas.feature import Feature, FeatureLocation, FeatureReference, FeatureType, StaticFeatureTags
from aligned.schemas.feature_view import CompiledFeatureView
from aligned.schemas.literal_value import LiteralValue
from aligned.schemas.model import Model as ModelSchema
Expand Down Expand Up @@ -500,9 +499,12 @@ class MyModel(ModelContract):
)
elif isinstance(feature, RecommendationTarget):
inference_view.recommendation_targets.add(feature.compile())
elif isinstance(feature, Entity):
inference_view.entities.add(feature.feature())
elif isinstance(feature, FeatureFactory):

if feature.tags and StaticFeatureTags.is_entity in feature.tags:
inference_view.entities.add(feature.feature())
continue

if feature.transformation:
# Adding features that is not stored in the view
# e.g:
Expand Down Expand Up @@ -588,7 +590,7 @@ def sort_key(x: tuple[int, FeatureFactory]) -> int:
if inference_view.source:
inference_view.source = inference_view.source.with_view(view)

return ModelSchema(
schema = ModelSchema(
name=metadata.name,
features=features,
predictions_view=inference_view,
Expand All @@ -599,3 +601,6 @@ def sort_key(x: tuple[int, FeatureFactory]) -> int:
exposed_at_url=metadata.exposed_at_url,
exposed_model=metadata.exposed_model,
)
if schema.exposed_model:
schema.exposed_model = schema.exposed_model.with_contract(schema)
return schema
16 changes: 16 additions & 0 deletions aligned/compiler/transformation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,3 +1046,19 @@ def compile(self) -> Transformation:
from aligned.schemas.transformation import FormatStringTransformation

return FormatStringTransformation(self.format, [feature.name for feature in self.features])


@dataclass
class ListDotProduct(TransformationFactory):
    """Transformation factory pairing a left and a right feature.

    Compiles to the schema-level `ListDotProduct` transformation, which
    is given the names of the two features.
    """

    # The two operand features; both are reported by `using_features`.
    left: FeatureFactory
    right: FeatureFactory

    @property
    def using_features(self) -> list[FeatureFactory]:
        """Return the left and right features, in that order."""
        operands = [self.left, self.right]
        return operands

    def compile(self) -> Transformation:
        """Return the schema transformation built from both feature names."""
        # Aliased so the schema class is not confused with this factory,
        # which has the same name.
        from aligned.schemas.transformation import ListDotProduct as SchemaListDotProduct

        left_name = self.left.name
        right_name = self.right.name
        return SchemaListDotProduct(left_name, right_name)
24 changes: 12 additions & 12 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from aligned.schemas.codable import Codable
from aligned.schemas.derivied_feature import DerivedFeature
from aligned.schemas.feature import EventTimestamp, Feature, FeatureLocation, FeatureType
from aligned.schemas.feature import Feature, FeatureLocation, FeatureType
from aligned.request.retrival_request import RequestResult, RetrivalRequest
from aligned.compiler.feature_factory import FeatureFactory
from polars.type_aliases import TimeUnit
Expand Down Expand Up @@ -243,7 +243,7 @@ class MyView(FeatureView):
feature_types = {name: feature_type.feature_factory for name, feature_type in schema.items()}
return FeatureView.feature_view_code_template(feature_types, f'{self}', view_name)

async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
async def freshness(self, feature: Feature) -> datetime | None:
"""
my_table_freshenss = await (PostgreSQLConfig("DB_URL")
.table("my_table")
Expand All @@ -254,7 +254,7 @@ async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
from aligned.sources.local import data_file_freshness

if isinstance(self, DataFileReference):
return await data_file_freshness(self, event_timestamp.name)
return await data_file_freshness(self, feature.name)

raise NotImplementedError(f'Freshness is not implemented for {type(self)}.')

Expand Down Expand Up @@ -504,8 +504,8 @@ def multi_source_features_for( # type: ignore

return source.source.features_for(facts, request).filter(condition)

async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
return await self.source.freshness(event_timestamp)
async def freshness(self, feature: Feature) -> datetime | None:
return await self.source.freshness(feature)

def all_between_dates(
self, request: RetrivalRequest, start_date: datetime, end_date: datetime
Expand Down Expand Up @@ -780,9 +780,9 @@ def all_between_dates(
.derive_features([request])
)

async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
left_freshness = await self.source.freshness(event_timestamp)
right_frehsness = await self.right_source.freshness(event_timestamp)
async def freshness(self, feature: Feature) -> datetime | None:
left_freshness = await self.source.freshness(feature)
right_frehsness = await self.right_source.freshness(feature)

if left_freshness is None:
return None
Expand Down Expand Up @@ -990,7 +990,7 @@ def all_between_dates(
def depends_on(self) -> set[FeatureLocation]:
return self.source.depends_on()

async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
async def freshness(self, feature: Feature) -> datetime | None:
return None

@classmethod
Expand Down Expand Up @@ -1077,9 +1077,9 @@ def all_between_dates(
.derive_features([request])
)

async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
left_freshness = await self.source.freshness(event_timestamp)
right_frehsness = await self.right_source.freshness(event_timestamp)
async def freshness(self, feature: Feature) -> datetime | None:
left_freshness = await self.source.freshness(feature)
right_frehsness = await self.right_source.freshness(feature)

if left_freshness is None:
return None
Expand Down
10 changes: 10 additions & 0 deletions aligned/exposed_model/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

if TYPE_CHECKING:
from aligned.feature_store import ModelFeatureStore
from aligned.schemas.model import Model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,6 +89,9 @@ async def depends_on(self) -> list[FeatureLocation]:
"""
return []

def with_contract(self, model: Model) -> ExposedModel:
    """Return the exposed model to use for the given model contract.

    This implementation ignores `model` and returns the instance
    unchanged.

    Args:
        model: The model schema; unused here.

    Returns:
        `self`, unmodified.
    """
    return self

async def needed_features(self, store: ModelFeatureStore) -> list[FeatureReference]:
raise NotImplementedError(type(self))

Expand Down Expand Up @@ -428,3 +432,9 @@ async def function_wrapper(values: RetrivalJob, store: ModelFeatureStore) -> pl.
return result

return DillFunction(function=dill.dumps(function_wrapper))


def openai_embedding(model: str, prompt_template: str | None = None) -> ExposedModel:
    """Create an `OpenAiEmbeddingPredictor` for the given model.

    Args:
        model: Forwarded as the predictor's `model` argument.
        prompt_template: Forwarded as the predictor's `prompt_template`.
            Any falsy value (`None` or an empty string) is replaced by ''.

    Returns:
        The constructed `OpenAiEmbeddingPredictor`.
    """
    # Imported inside the function, as the original code does.
    from aligned.exposed_model.openai import OpenAiEmbeddingPredictor

    template = prompt_template if prompt_template else ''
    return OpenAiEmbeddingPredictor(model=model, prompt_template=template)
Loading

0 comments on commit 8cef7c8

Please sign in to comment.