Added partitioned parquet source on azure blob
MatsMoll committed May 26, 2024
1 parent 17dd447 commit c21c9dc
Showing 5 changed files with 1,439 additions and 542 deletions.
146 changes: 139 additions & 7 deletions aligned/sources/azure_blob_storage.py
@@ -22,7 +22,6 @@
ParquetConfig,
StorageFileReference,
Directory,
PartitionedParquetFileSource,
data_file_freshness,
)
from aligned.storage import Storage
@@ -53,7 +52,7 @@ def azure_container_blob(path: str) -> AzurePath:
@dataclass
class AzureBlobConfig(Directory):
account_id_env: str
tenent_id_env: str
tenant_id_env: str
client_id_env: str
client_secret_env: str
account_name_env: str
@@ -71,7 +70,7 @@ def to_markdown(self) -> str:
2. Using Tenant Id, Client Id and Client Secret
- Tenant Id Env: `{self.tenent_id_env}`
- Tenant Id Env: `{self.tenant_id_env}`
- Client Id Env: `{self.client_id_env}`
- Client Secret Env: `{self.client_secret_env}`
"""
@@ -96,8 +95,15 @@ def partitioned_parquet_at(
mapping_keys: dict[str, str] | None = None,
config: ParquetConfig | None = None,
date_formatter: DateFormatter | None = None,
) -> PartitionedParquetFileSource:
raise NotImplementedError(type(self))
) -> AzureBlobPartitionedParquetDataSource:
return AzureBlobPartitionedParquetDataSource(
self,
directory,
partition_keys,
mapping_keys=mapping_keys or {},
parquet_config=config or ParquetConfig(),
date_formatter=date_formatter or DateFormatter.noop(),
)

def csv_at(
self,
@@ -164,7 +170,7 @@ def read_creds(self) -> dict[str, str]:
else:
return {
'account_name': account_name,
'tenant_id': os.environ[self.tenent_id_env],
'tenant_id': os.environ[self.tenant_id_env],
'client_id': os.environ[self.client_id_env],
'client_secret': os.environ[self.client_secret_env],
}
@@ -191,7 +197,26 @@ def parquet_at(
) -> AzureBlobParquetDataSource:
sub_path = self.sub_path / path
return self.config.parquet_at(
sub_path.as_posix(), date_formatter=date_formatter or DateFormatter.noop()
sub_path.as_posix(),
mapping_keys=mapping_keys,
date_formatter=date_formatter or DateFormatter.noop(),
)

def partitioned_parquet_at(
self,
directory: str,
partition_keys: list[str],
mapping_keys: dict[str, str] | None = None,
config: ParquetConfig | None = None,
date_formatter: DateFormatter | None = None,
) -> AzureBlobPartitionedParquetDataSource:
sub_path = self.sub_path / directory
return self.config.partitioned_parquet_at(
sub_path.as_posix(),
partition_keys,
mapping_keys=mapping_keys,
config=config,
date_formatter=date_formatter,
)

def csv_at(
@@ -382,6 +407,113 @@ def all_between_dates(
)


@dataclass
class AzureBlobPartitionedParquetDataSource(BatchDataSource, DataFileReference, ColumnFeatureMappable):
config: AzureBlobConfig
directory: str
partition_keys: list[str]
mapping_keys: dict[str, str] = field(default_factory=dict)
parquet_config: ParquetConfig = field(default_factory=ParquetConfig)
date_formatter: DateFormatter = field(default_factory=lambda: DateFormatter.noop())
type_name: str = 'azure_blob_partitiond_parquet'

@property
def to_markdown(self) -> str:
return f"""Type: *Azure Blob Partitioned Parquet File*
Directory: *{self.directory}*
Partition Keys: *{self.partition_keys}*
{self.config.to_markdown}"""

def job_group_key(self) -> str:
return f"{self.type_name}/{self.directory}"

def __hash__(self) -> int:
return hash(self.job_group_key())

@property
def storage(self) -> Storage:
return self.config.storage

async def schema(self) -> dict[str, FeatureType]:
try:
schema = (await self.to_lazy_polars()).schema
return {name: FeatureType.from_polars(pl_type) for name, pl_type in schema.items()}

except FileNotFoundError as error:
raise UnableToFindFileException() from error
except HTTPStatusError as error:
raise UnableToFindFileException() from error

async def read_pandas(self) -> pd.DataFrame:
return (await self.to_lazy_polars()).collect().to_pandas()

async def to_lazy_polars(self) -> pl.LazyFrame:
try:
url = f"az://{self.directory}/**/*.parquet"
creds = self.config.read_creds()
return pl.scan_parquet(url, storage_options=creds)
except FileNotFoundError as error:
raise UnableToFindFileException(self.directory) from error
except HTTPStatusError as error:
raise UnableToFindFileException(self.directory) from error
except pl.ComputeError as error:
raise UnableToFindFileException(self.directory) from error

async def write_pandas(self, df: pd.DataFrame) -> None:
await self.write_polars(pl.from_pandas(df).lazy())

async def write_polars(self, df: pl.LazyFrame) -> None:
url = f"az://{self.directory}"
creds = self.config.read_creds()
df.collect().to_pandas().to_parquet(url, partition_cols=self.partition_keys, storage_options=creds)

@classmethod
def multi_source_features_for(
cls, facts: RetrivalJob, requests: list[tuple[AzureBlobParquetDataSource, RetrivalRequest]]
) -> RetrivalJob:

source = requests[0][0]
if not isinstance(source, cls):
raise ValueError(f'Only {cls} is supported, received: {source}')

# Group based on config
return FileFactualJob(
source=source,
requests=[request for _, request in requests],
facts=facts,
date_formatter=source.date_formatter,
)

def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
return FileFactualJob(self, [request], facts, date_formatter=self.date_formatter)

def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
return FileFullJob(self, request, limit, date_formatter=self.date_formatter)

def all_between_dates(
self,
request: RetrivalRequest,
start_date: datetime,
end_date: datetime,
) -> RetrivalJob:
return FileDateJob(
source=self,
request=request,
start_date=start_date,
end_date=end_date,
date_formatter=self.date_formatter,
)

async def insert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
if len(requests) != 1:
raise ValueError(f"Only supports writing one request, got {len(requests)}.")
features = requests[0].all_returned_columns
df = await job.select(features).to_lazy_polars()
await self.write_polars(df)


@dataclass
class AzureBlobParquetDataSource(
BatchDataSource,
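
For reference, a minimal usage sketch of the new Azure source (not part of the commit). It assumes the five *_env fields shown in the hunk above are the only required AzureBlobConfig constructor arguments; the environment-variable names and the container path are illustrative.

import asyncio
from aligned.sources.azure_blob_storage import AzureBlobConfig

# Placeholder env-var names; point them at whatever variables hold the
# service-principal credentials in your setup.
config = AzureBlobConfig(
    account_id_env='AZURE_ACCOUNT_ID',
    tenant_id_env='AZURE_TENANT_ID',
    client_id_env='AZURE_CLIENT_ID',
    client_secret_env='AZURE_CLIENT_SECRET',
    account_name_env='AZURE_STORAGE_ACCOUNT',
)

# Hive-partitioned parquet files under "taxi/trips" in the container,
# split into year=/month= directories.
trips = config.partitioned_parquet_at(
    'taxi/trips',
    partition_keys=['year', 'month'],
)

async def preview() -> None:
    # Internally the source scans az://taxi/trips/**/*.parquet lazily with polars.
    lazy = await trips.to_lazy_polars()
    print(lazy.limit(5).collect())

asyncio.run(preview())
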
27 changes: 20 additions & 7 deletions aligned/sources/local.py
@@ -376,7 +376,9 @@ class ParquetConfig(Codable):


@dataclass
class PartitionedParquetFileSource(BatchDataSource, ColumnFeatureMappable, DataFileReference):
class PartitionedParquetFileSource(
BatchDataSource, ColumnFeatureMappable, DataFileReference, WritableFeatureSource
):
"""
A source pointing to a Parquet file
"""
@@ -460,23 +462,34 @@ def multi_source_features_for(
)

async def schema(self) -> dict[str, FeatureType]:
if self.path.startswith('http'):
parquet_schema = pl.scan_parquet(self.path).schema
else:
parquet_schema = pl.read_parquet_schema(self.path)

glob_path = f'{self.directory}/**/*.parquet'
parquet_schema = pl.scan_parquet(glob_path).schema
return {name: FeatureType.from_polars(pl_type) for name, pl_type in parquet_schema.items()}

async def feature_view_code(self, view_name: str) -> str:
from aligned.feature_view.feature_view import FeatureView

raw_schema = await self.schema()
schema = {name: feat.feature_factory for name, feat in raw_schema.items()}
data_source_code = f'FileSource.parquet_at("{self.path}")'
data_source_code = f'FileSource.partitioned_parquet_at("{self.directory}", {self.partition_keys})'
return FeatureView.feature_view_code_template(
schema, data_source_code, view_name, 'from aligned import FileSource'
)

async def insert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
if len(requests) != 1:
raise ValueError('Partitioned Parquet files only support one write request as of now')
request = requests[0]
job = job.select(request.all_returned_columns)
df = await job.to_lazy_polars()
await self.write_polars(df)

async def overwrite(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
import shutil

shutil.rmtree(self.directory)
await self.insert(job, requests)


@dataclass
class ParquetFileSource(BatchDataSource, ColumnFeatureMappable, DataFileReference):
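
The write/read pair both partitioned sources lean on is plain pandas and polars: to_parquet(partition_cols=...) writes one key=value/ directory per partition value, and pl.scan_parquet over a glob reads the tree back. A self-contained sketch of those mechanics, with a made-up 'events' path and columns:

import pandas as pd
import polars as pl

df = pd.DataFrame(
    {
        'passenger': [1, 2, 3, 4],
        'duration': [10.0, 12.5, 7.2, 30.1],
        'vendor_id': ['a', 'a', 'b', 'b'],
    }
)

# Writes events/vendor_id=a/<part>.parquet and events/vendor_id=b/<part>.parquet;
# the partition column lives in the directory names, not inside the files.
df.to_parquet('events', partition_cols=['vendor_id'])

# The same glob pattern the sources use in schema() and to_lazy_polars().
# hive_partitioning=True asks polars to recover vendor_id from the paths.
lazy = pl.scan_parquet('events/**/*.parquet', hive_partitioning=True)
print(lazy.collect())
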
4 changes: 2 additions & 2 deletions aligned/sources/tests/test_parquet.py
@@ -51,10 +51,10 @@ async def test_partition_parquet(point_in_time_data_test: DataTest) -> None:
continue

entities = compiled.entitiy_names
partition_keys = list(entities)

file_source = FileSource.partitioned_parquet_at(
f'test_data/temp/{view_name}',
partition_keys=list(entities),
f'test_data/temp/{view_name}', partition_keys=partition_keys
)
await file_source.write_polars(source.data.lazy())

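
Outside the fixture-driven test, the same roundtrip can be sketched with made-up data. FileSource.partitioned_parquet_at and write_polars appear in the diff above; the to_lazy_polars read-back is assumed to mirror the Azure variant.

import asyncio
import polars as pl
from aligned import FileSource

data = pl.DataFrame(
    {
        'id': [1, 1, 2, 2],
        'value': [10.0, 11.0, 20.0, 21.0],
    }
)

# Partition on the entity column, mirroring how the test derives
# partition_keys from the compiled view's entity names.
source = FileSource.partitioned_parquet_at(
    'test_data/temp/example_view',
    partition_keys=['id'],
)

async def roundtrip() -> None:
    await source.write_polars(data.lazy())
    # Assumed read-back API, matching the Azure partitioned source above.
    print((await source.to_lazy_polars()).collect())

asyncio.run(roundtrip())
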
