diff --git a/aligned/schemas/feature.py b/aligned/schemas/feature.py
index ff4f5f94..a9e29c8e 100644
--- a/aligned/schemas/feature.py
+++ b/aligned/schemas/feature.py
@@ -57,6 +57,13 @@ def is_datetime(self) -> bool:
     def is_array(self) -> bool:
         return self.name.startswith('array')
 
+    def array_subtype(self) -> FeatureType | None:
+        if not self.is_array or '-' not in self.name:
+            return None
+
+        sub = str(self.name[len('array-') :])
+        return FeatureType(sub)
+
     @property
     def datetime_timezone(self) -> str | None:
         if not self.is_datetime:
@@ -115,10 +122,17 @@ def pandas_type(self) -> str | type:
 
     @property
     def polars_type(self) -> type:
-        if self.name.startswith('datetime-'):
-            time_zone = self.name.split('-')[1]
+        if self.is_datetime:
+            time_zone = self.datetime_timezone
             return pl.Datetime(time_zone=time_zone)  # type: ignore
 
+        if self.is_array:
+            sub_type = self.array_subtype()
+            if sub_type:
+                return pl.List(sub_type.polars_type)  # type: ignore
+            else:
+                return pl.List(pl.Utf8)  # type: ignore
+
         for name, dtype in NAME_POLARS_MAPPING:
             if name == self.name:
                 return dtype
diff --git a/aligned/sources/azure_blob_storage.py b/aligned/sources/azure_blob_storage.py
index 8f315611..e9582612 100644
--- a/aligned/sources/azure_blob_storage.py
+++ b/aligned/sources/azure_blob_storage.py
@@ -442,17 +442,14 @@ async def write_polars(self, df: pl.LazyFrame) -> None:
             mode='append',
         )
 
-    async def insert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
+    def df_to_deltalake_compatible(
+        self, df: pl.DataFrame, requests: list[RetrivalRequest]
+    ) -> tuple[pl.DataFrame, dict]:
         import pyarrow as pa
         from aligned.schemas.constraints import Optional
         from aligned.schemas.feature import Feature
 
-        df = await job.to_polars()
-        url = f"az://{self.path}"
-
-        def pa_field(feature: Feature) -> pa.Field:
-            is_nullable = Optional() in (feature.constraints or set())
-
+        def pa_dtype(dtype: FeatureType) -> pa.DataType:
             pa_types = {
                 'int8': pa.int8(),
                 'int16': pa.int16(),
@@ -462,16 +459,33 @@ def pa_field(feature: Feature) -> pa.Field:
                 'double': pa.float64(),
                 'string': pa.large_string(),
                 'date': pa.date64(),
+                'embedding': pa.large_list(pa.float32()),
                 'datetime': pa.float64(),
                 'list': pa.large_list(pa.int32()),
                 'array': pa.large_list(pa.int32()),
                 'bool': pa.bool_(),
             }
 
-            if feature.dtype.name in pa_types:
-                return pa.field(feature.name, pa_types[feature.dtype.name], nullable=is_nullable)
+            if dtype.name in pa_types:
+                return pa_types[dtype.name]
+
+            if dtype.is_datetime:
+                return pa.float64()
+
+            if dtype.is_array:
+                array_sub_dtype = dtype.array_subtype()
+                if array_sub_dtype:
+                    return pa.large_list(pa_dtype(array_sub_dtype))
+
+                return pa.large_list(pa.string())
+
+            raise ValueError(f"Unsupported dtype: {dtype}")
+
+        def pa_field(feature: Feature) -> pa.Field:
+            is_nullable = Optional() in (feature.constraints or set())
 
-            raise ValueError(f"Unsupported dtype: {feature.dtype}")
+            pa_type = pa_dtype(feature.dtype)
+            return pa.field(feature.name, pa_type, nullable=is_nullable)
 
         dtypes = dict(zip(df.columns, df.dtypes, strict=False))
         schemas = {}
@@ -479,18 +493,26 @@ def pa_field(feature: Feature) -> pa.Field:
         for request in requests:
             for feature in request.all_features.union(request.entities):
                 schemas[feature.name] = pa_field(feature)
+
                 if dtypes[feature.name] == pl.Null:
                     df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type))
-                elif feature.dtype.name == 'array':
-                    df = df.with_columns(pl.col(feature.name).cast(pl.List(pl.Int32())))
-                elif feature.dtype.name == 'datetime':
+                elif feature.dtype.is_datetime:
                     df = df.with_columns(pl.col(feature.name).dt.timestamp('ms').cast(pl.Float64()))
                 else:
                     df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type))
 
+        return df, schemas
+
+    async def insert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
+        import pyarrow as pa
+
+        df = await job.to_polars()
+        url = f"az://{self.path}"
+
+        df, schemas = self.df_to_deltalake_compatible(df, requests)
+
         orderd_schema = OrderedDict(sorted(schemas.items()))
         schema = list(orderd_schema.values())
-
         df.select(list(orderd_schema.keys())).write_delta(
             url,
             storage_options=self.config.read_creds(),
@@ -500,8 +522,6 @@ def pa_field(feature: Feature) -> pa.Field:
     async def upsert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
         import pyarrow as pa
 
-        from aligned.schemas.constraints import Optional
-        from aligned.schemas.feature import Feature
         from deltalake.exceptions import TableNotFoundError
 
         df = await job.to_polars()
@@ -509,49 +529,15 @@ async def upsert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> Non
 
         url = f"az://{self.path}"
         merge_on = set()
 
-        def pa_field(feature: Feature) -> pa.Field:
-            is_nullable = Optional() in (feature.constraints or set())
-
-            pa_types = {
-                'int8': pa.int8(),
-                'int16': pa.int16(),
-                'int32': pa.int32(),
-                'int64': pa.int64(),
-                'float': pa.float64(),
-                'double': pa.float64(),
-                'string': pa.large_string(),
-                'date': pa.date64(),
-                'datetime': pa.float64(),
-                'list': pa.large_list(pa.int32()),
-                'array': pa.large_list(pa.int32()),
-                'bool': pa.bool_(),
-            }
-
-            if feature.dtype.name in pa_types:
-                return pa.field(feature.name, pa_types[feature.dtype.name], nullable=is_nullable)
-
-            raise ValueError(f"Unsupported dtype: {feature.dtype}")
-
-        dtypes = dict(zip(df.columns, df.dtypes, strict=False))
         schemas = {}
 
         for request in requests:
             merge_on.update(request.entity_names)
-            for feature in request.all_features.union(request.entities):
-                schemas[feature.name] = pa_field(feature)
-                if dtypes[feature.name] == pl.Null:
-                    df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type))
-                elif feature.dtype.name == 'array':
-                    df = df.with_columns(pl.col(feature.name).cast(pl.List(pl.Int32())))
-                elif feature.dtype.name == 'datetime':
-                    df = df.with_columns(pl.col(feature.name).dt.timestamp('ms').cast(pl.Float64()))
-                else:
-                    df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type))
 
+        df, schemas = self.df_to_deltalake_compatible(df, requests)
         orderd_schema = OrderedDict(sorted(schemas.items()))
         schema = list(orderd_schema.values())
-
         predicate = ' AND '.join([f"s.{key} = t.{key}" for key in merge_on])
 
         try:
diff --git a/pyproject.toml b/pyproject.toml
index 98d05180..fb925467 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "aligned"
-version = "0.0.84"
+version = "0.0.85"
 description = "A data managment and lineage tool for ML applications."
 authors = ["Mats E. Mollestad "]
 license = "Apache-2.0"
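For reviewers who want to poke at the `feature.py` change locally, here is a minimal sketch of the behaviour I believe the new `array_subtype` helper and the `polars_type` branch give us. This is my reading of the diff, not code from the PR: it assumes `FeatureType` is constructed from its type-name string (as the `FeatureType(sub)` call above suggests) and that `NAME_POLARS_MAPPING` resolves `'int32'` to `pl.Int32`.

```python
import polars as pl

from aligned.schemas.feature import FeatureType

# 'array-int32' carries its element type after the 'array-' prefix,
# so the polars dtype becomes a typed list.
nested = FeatureType('array-int32')
assert nested.array_subtype() == FeatureType('int32')
assert nested.polars_type == pl.List(pl.Int32)

# A bare 'array' has no subtype, so polars_type falls back to List[Utf8].
plain = FeatureType('array')
assert plain.array_subtype() is None
assert plain.polars_type == pl.List(pl.Utf8)
```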
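The pyarrow side of `df_to_deltalake_compatible` follows the same shape: known leaf names hit the `pa_types` table, `array-*` names recurse on the element type through `pa_dtype`, and `large_list(string)` is the last-resort fallback. That recursion is what lets list columns keep their real element type instead of being forced to `large_list(int32)` as on the old `upsert` path. Below is a standalone sketch of the recursion with a trimmed-down type table; the helper here is illustrative, not the module's actual function.

```python
import pyarrow as pa

# Trimmed-down stand-in for the pa_types table in the diff.
PA_TYPES: dict[str, pa.DataType] = {
    'int32': pa.int32(),
    'int64': pa.int64(),
    'string': pa.large_string(),
    'bool': pa.bool_(),
}


def pa_dtype(name: str) -> pa.DataType:
    if name in PA_TYPES:
        return PA_TYPES[name]

    if name.startswith('array') and '-' in name:
        # 'array-int32' recurses on the element name, so nested
        # names like 'array-array-bool' resolve to nested lists.
        return pa.large_list(pa_dtype(name[len('array-') :]))

    raise ValueError(f"Unsupported dtype: {name}")


assert pa_dtype('array-int32') == pa.large_list(pa.int32())
assert pa_dtype('array-array-bool') == pa.large_list(pa.large_list(pa.bool_()))
# Note: in the diff itself a bare 'array' is caught by the pa_types table
# (as large_list(int32)); the large_list(string) branch is a last resort.
```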