Commit b6e3cf6

ruff
liquidcarbon committed Sep 28, 2024
1 parent 411e9da commit b6e3cf6
Showing 3 changed files with 214 additions and 155 deletions.
188 changes: 102 additions & 86 deletions affinity.py
@@ -2,11 +2,13 @@
Module for creating well-documented datasets, with types and annotations.
"""

import numpy as np
import pandas as pd
from importlib import import_module
from time import time
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import pandas as pd


def try_import(module) -> Optional[object]:
try:
@@ -15,11 +17,12 @@ def try_import(module) -> Optional[object]:
print(f"{module} not found in the current environment")
return


if TYPE_CHECKING:
import duckdb # type: ignore
import polars as pl # type: ignore
import pyarrow as pa # type: ignore
import pyarrow.parquet as pq # type: ignore
import polars as pl # type: ignore
else:
duckdb = try_import("duckdb")
pl = try_import("polars")
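
The try_import helper keeps heavy dependencies optional: under TYPE_CHECKING the imports are real so IDEs and type checkers resolve the names, while at runtime a missing package degrades to None instead of raising. A minimal standalone sketch of the pattern; the body of the try block is collapsed in this diff, so its reconstruction here is an assumption:

from importlib import import_module
from typing import Optional

def try_import(module) -> Optional[object]:
    try:
        return import_module(module)  # assumed: the actual line is collapsed in the diff
    except ImportError:
        print(f"{module} not found in the current environment")
        return None

pl = try_import("polars")  # a module object, or None when polars is absent
if pl is not None:
    df = pl.DataFrame({"a": [1, 2, 3]})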
@@ -28,14 +31,15 @@ def try_import(module) -> Optional[object]:


class Descriptor:
"""Base class for scalars and vectors."""

def __get__(self, instance, owner):
return self if not instance else instance.__dict__[self.name]

def __set__(self, instance, values):
try:
_values = self.array_class(
values if values is not None else [],
dtype=self.dtype
values if values is not None else [], dtype=self.dtype
)
except OverflowError as e:
raise e
@@ -56,7 +60,7 @@ def info(self):

@classmethod
def factory(cls, dtype, array_class=pd.Series, cls_name=None):
"""Factory method to create classes.
"""Factory method for creating typed classes.
Reverted to explicit class declarations.
Unable to convince IDEs that factory-made classes are not of "DescriptorType".
@@ -65,6 +69,7 @@ def factory(cls, dtype, array_class=pd.Series, cls_name=None):
class DescriptorType(cls):
def __init__(self, comment=None, *, values=None, array_class=array_class):
super().__init__(dtype, values, comment, array_class)

if cls_name:
DescriptorType.__name__ = cls_name
return DescriptorType
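
The factory remains available even though, as its docstring notes, the library reverted to explicit class declarations (see the typed scalars and vectors at the bottom of this file). A hypothetical use, producing a class equivalent to the hand-written VectorI64:

# Hypothetical: build a typed vector class instead of declaring it by hand.
VectorI64Made = Vector.factory(pd.Int64Dtype(), cls_name="VectorI64Made")
col = VectorI64Made("an int64 column", values=[1, 2, 3])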
@@ -87,9 +92,11 @@ def __repr__(self):


class Vector(Descriptor):
"""Vectors are typed arrays of values."""

@classmethod
def from_scalar(cls, scalar: Scalar, length=1):
_value = [] if (not length or scalar.value is None) else [scalar.value]*length
_value = [] if (not length or scalar.value is None) else [scalar.value] * length
instance = cls(scalar.dtype, _value, scalar.comment, scalar.array_class)
instance.scalar = scalar.value
return instance
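
from_scalar broadcasts a single value into a vector of the requested length (an empty one when length is falsy or the value is None) and keeps the originating value on instance.scalar. A hypothetical call:

# Hypothetical: broadcast one scalar into a length-3 vector.
s = ScalarI64("the answer", value=42)
v = Vector.from_scalar(s, length=3)  # holds [42, 42, 42] as nullable Int64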
@@ -126,39 +133,41 @@ class DatasetMeta(type):
def __repr__(cls) -> str:
_lines = [cls.__name__]
for k, v in cls.__dict__.items():
if isinstance (v, Descriptor):
if isinstance(v, Descriptor):
_lines.append(f"{k}: {v.info}")
return "\n".join(_lines)


class Dataset(metaclass=DatasetMeta):
"""Base class for typed, annotated datasets."""

save_to = {"partition": tuple(), "prefix": "", "file": ""}

@classmethod
def get_scalars(cls):
return {k: None for k,v in cls.__dict__.items() if isinstance(v, Scalar)}
return {k: None for k, v in cls.__dict__.items() if isinstance(v, Scalar)}

@classmethod
def get_vectors(cls):
return {k: None for k,v in cls.__dict__.items() if isinstance(v, Vector)}
return {k: None for k, v in cls.__dict__.items() if isinstance(v, Vector)}

@classmethod
def get_dict(cls):
return dict(cls())

def __init__(self, **fields: Union[Scalar|Vector]):
def __init__(self, **fields: Union[Scalar, Vector]):
"""Create dataset, dynamically setting field values.
Vectors are initialized first, ensuring all are of equal length.
Scalars are filled in afterwards.
"""

self.origin = {"created_ts": int(time() * 1000)}
_sizes = {}
self._vectors = self.__class__.get_vectors()
self._scalars = self.__class__.get_scalars()
if len(self._vectors) == 0 and len(self._scalars) == 0:
raise ValueError("no attributes defined in your dataset")
self.origin = {"created_ts": int(time() * 1000)}
_sizes = {}
for vector_name in self._vectors:
field_data = fields.get(vector_name)
setattr(self, vector_name, field_data)
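
In use, a dataset is declared by subclassing Dataset with typed class attributes, and the constructor fills them exactly as the docstring says: vectors first, checked for equal length, then scalars. A minimal sketch, assuming the module is importable as affinity; the class and field names are hypothetical:

import affinity as af

class Sensors(af.Dataset):
    """Sensor readings."""  # becomes table_comment in the metadata property
    t = af.VectorF64("timestamp, s")
    temp = af.VectorF32("temperature, K")

data = Sensors(t=[0.0, 0.1], temp=[293.1, 293.4])
print(data.shape)  # (2, 2): two rows, two fields
print(data.df)     # two-column pandas DataFrame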
@@ -196,7 +205,9 @@ def build(cls, query=None, dataframe=None, **kwargs):
return cls.from_dataframe(dataframe, **kwargs)

@classmethod
def from_dataframe(cls, dataframe: pd.DataFrame | Optional['pl.DataFrame'], **kwargs):
def from_dataframe(
cls, dataframe: pd.DataFrame | Optional["pl.DataFrame"], **kwargs
):
instance = cls()
for i, k in enumerate(dict(instance)):
if kwargs.get("rename") in (None, False):
@@ -213,7 +224,7 @@ def from_sql(cls, query: str, **kwargs):
if kwargs.get("method") in ("polars",):
query_results = duckdb.sql(query).pl()
instance = cls.from_dataframe(query_results, **kwargs)
instance.origin["source"] += f'\nquery:\n{query}'
instance.origin["source"] += f"\nquery:\n{query}"
return instance
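
from_sql routes the query through DuckDB, materializing via polars when method="polars", and appends the query text to the instance's origin. Continuing the hypothetical Sensors sketch, with an assumed readings.csv whose columns line up with the fields:

# Hypothetical: DuckDB scans the CSV; result columns map onto t and temp.
data = Sensors.from_sql("SELECT t, temp FROM 'readings.csv'")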

def __eq__(self, other):
@@ -225,9 +236,7 @@ def __len__(self) -> int:
def __iter__(self):
"""Yields attr names and values, in same order as defined in class."""
yield from (
(k, self.__dict__[k])
for k in self.__class__.__dict__
if k in self.__dict__
(k, self.__dict__[k]) for k in self.__class__.__dict__ if k in self.__dict__
)

def __repr__(self):
@@ -238,54 +247,6 @@ def __repr__(self):
lines.append(f"{k} = {v}".replace(", '...',", " ..."))
return "\n".join(lines)

def is_dataset(self, key):
attr = getattr(self, key, None)
if attr is None or len(attr) == 0 or isinstance(attr, Scalar):
return False
else:
return all(isinstance(v, Dataset) for v in attr)

def sql(self, query, **replacements):
"""Query the dataset with DuckDB.
DuckDB uses replacement scans to query python objects.
Class instance attributes like `FROM self.df` must be registered as views.
This is what **replacements kwargs are for.
By default, df=self.df (pandas dataframe) is used.
The registered views persist across queries. RAM impact TBD.
"""
if replacements.get("df") is None:
duckdb.register("df", self.df)
for k, v in replacements.items():
duckdb.register(k, v)
return duckdb.sql(query)

def flatten(self):
"""List of dicts? Dict of lists? TBD"""
raise NotImplementedError

def model_dump(self) -> dict:
"""Similar to Pydantic's model_dump; alias for dict."""
return self.dict

def to_parquet(self, path, engine="duckdb", **kwargs):
if engine == "arrow":
pq.write_table(self.arrow, path)
if engine == "duckdb":
kv_metadata = []
for k, v in self.metadata.items():
if isinstance(v, str) and "'" in v:
_v = {v.replace("'", "''")} # must escape single quotes
kv_metadata.append(f"{k}: '{_v}'")
else:
kv_metadata.append(f"{k}: '{v}'")
self.sql(f"""
COPY (SELECT * FROM df) TO {path} (
FORMAT PARQUET,
KV_METADATA {{ {", ".join(kv_metadata)} }}
);""", **kwargs)
return path

@property
def shape(self):
return len(self), len(self._vectors) + len(self._scalars)
@@ -306,7 +267,7 @@ def metadata(self) -> dict:
return {
"table_comment": self.__class__.__doc__,
**self.data_dict,
**self.origin
**self.origin,
}

@property
@@ -317,7 +278,6 @@ def df(self) -> pd.DataFrame:
}
return pd.DataFrame(_dict)


@property
def df4(self) -> pd.DataFrame:
if len(self) > 4:
@@ -340,89 +300,145 @@ def arrow(self) -> "pa.Table":
def pl(self) -> "pl.DataFrame":
return pl.DataFrame(dict(self))

def is_dataset(self, key):
attr = getattr(self, key, None)
if attr is None or len(attr) == 0 or isinstance(attr, Scalar):
return False
else:
return all(isinstance(v, Dataset) for v in attr)

def sql(self, query, **replacements):
"""Query the dataset with DuckDB.
DuckDB uses replacement scans to query python objects.
Class instance attributes like `FROM self.df` must be registered as views.
This is what **replacements kwargs are for.
By default, df=self.df (pandas dataframe) is used.
The registered views persist across queries. RAM impact TBD.
"""
if replacements.get("df") is None:
duckdb.register("df", self.df)
for k, v in replacements.items():
duckdb.register(k, v)
return duckdb.sql(query)
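
Concretely, every keyword passed through **replacements becomes a DuckDB view queryable under that name, with df pre-registered from self.df. A hypothetical join against an extra frame, continuing the Sensors sketch:

ref = pd.DataFrame({"t": [0.0], "label": ["start"]})
rel = data.sql("SELECT df.*, ref.label FROM df LEFT JOIN ref USING (t)", ref=ref)
print(rel.df())  # materialize the DuckDB relation back to pandas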

def flatten(self):
"""List of dicts? Dict of lists? TBD"""
raise NotImplementedError

def model_dump(self) -> dict:
"""Similar to Pydantic's model_dump; alias for dict."""
return self.dict

def to_parquet(self, path, engine="duckdb", **kwargs):
if engine == "arrow":
pq.write_table(self.arrow, path)
if engine == "duckdb":
kv_metadata = []
for k, v in self.metadata.items():
if isinstance(v, str) and "'" in v:
_v = {v.replace("'", "''")} # must escape single quotes
kv_metadata.append(f"{k}: '{_v}'")
else:
kv_metadata.append(f"{k}: '{v}'")
self.sql(
f"""
COPY (SELECT * FROM df) TO {path} (
FORMAT PARQUET,
KV_METADATA {{ {", ".join(kv_metadata)} }}
);""",
**kwargs,
)
return path
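
With the duckdb engine, the data dictionary and provenance travel in the Parquet footer as key-value metadata, which pyarrow can read back. A hypothetical round trip, file name assumed:

import pyarrow.parquet as pq

data.to_parquet("sensors.parquet", engine="duckdb")
meta = pq.read_table("sensors.parquet").schema.metadata  # bytes-keyed dict
print(meta[b"table_comment"])  # the Sensors docstring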

def save(self):
"""Path and format constructed from `save_to` attribute."""
raise NotImplementedError


### Typed scalars and vectors


class ScalarObject(Scalar):
def __init__(self, comment: str, *, value=None, array_class=pd.Series):
super().__init__(dtype=object, comment=comment, value=value, array_class=array_class)
super().__init__(object, value, comment, array_class)


class ScalarBool(Scalar):
def __init__(self, comment: str, *, value=None, array_class=pd.Series):
super().__init__(dtype="boolean", comment=comment, value=value, array_class=array_class)
super().__init__("boolean", value, comment, array_class)


class ScalarI8(Scalar):
def __init__(self, comment: str, *, value=None, array_class=pd.Series):
super().__init__(dtype=pd.Int8Dtype(), comment=comment, value=value, array_class=array_class)
super().__init__(pd.Int8Dtype(), value, comment, array_class)


class ScalarI16(Scalar):
def __init__(self, comment: str, *, value=None, array_class=pd.Series):
super().__init__(dtype=pd.Int16Dtype(), comment=comment, value=value, array_class=array_class)
super().__init__(pd.Int16Dtype(), value, comment, array_class)


class ScalarI32(Scalar):
def __init__(self, comment: str, *, value=None, array_class=pd.Series):
super().__init__(dtype=pd.Int32Dtype(), comment=comment, value=value, array_class=array_class)
super().__init__(pd.Int32Dtype(), value, comment, array_class)


class ScalarI64(Scalar):
def __init__(self, comment: str, *, value=None, array_class=pd.Series):
super().__init__(dtype=pd.Int64Dtype(), comment=comment, value=value, array_class=array_class)
super().__init__(pd.Int64Dtype(), value, comment, array_class)


class ScalarF32(Scalar):
def __init__(self, comment: str, *, value=None, array_class=pd.Series):
super().__init__(dtype=np.float32, comment=comment, value=value, array_class=array_class)
super().__init__(np.float32, value, comment, array_class)


class ScalarF64(Scalar):
def __init__(self, comment: str, *, value=None, array_class=pd.Series):
super().__init__(dtype=np.float64, comment=comment, value=value, array_class=array_class)
super().__init__(np.float64, value, comment, array_class)


class VectorObject(Vector):
def __init__(self, comment: str, *, values=None, array_class=pd.Series):
super().__init__(dtype=object, comment=comment, values=values, array_class=array_class)
super().__init__(object, values, comment, array_class)


class VectorBool(Vector):
def __init__(self, comment: str, *, values=None, array_class=pd.Series):
super().__init__(dtype="boolean", comment=comment, values=values, array_class=array_class)
super().__init__("boolean", values, comment, array_class)


class VectorI8(Vector):
def __init__(self, comment: str, *, values=None, array_class=pd.Series):
super().__init__(dtype=pd.Int8Dtype(), comment=comment, values=values, array_class=array_class)
super().__init__(pd.Int8Dtype(), values, comment, array_class)


class VectorI16(Vector):
def __init__(self, comment: str, *, values=None, array_class=pd.Series):
super().__init__(dtype=pd.Int16Dtype(), comment=comment, values=values, array_class=array_class)
super().__init__(pd.Int16Dtype(), values, comment, array_class)


class VectorI32(Vector):
def __init__(self, comment: str, *, values=None, array_class=pd.Series):
super().__init__(dtype=pd.Int32Dtype(), comment=comment, values=values, array_class=array_class)
super().__init__(pd.Int32Dtype(), values, comment, array_class)


class VectorI64(Vector):
def __init__(self, comment: str, *, values=None, array_class=pd.Series):
super().__init__(dtype=pd.Int64Dtype(), comment=comment, values=values, array_class=array_class)
super().__init__(pd.Int64Dtype(), values, comment, array_class)


class VectorF16(Vector):
def __init__(self, comment: str, *, values=None, array_class=pd.Series):
super().__init__(dtype=np.float16, comment=comment, values=values, array_class=array_class)
super().__init__(np.float16, values, comment, array_class)


class VectorF32(Vector):
def __init__(self, comment: str, *, values=None, array_class=pd.Series):
super().__init__(dtype=np.float32, comment=comment, values=values, array_class=array_class)
super().__init__(np.float32, values, comment, array_class)


class VectorF64(Vector):
def __init__(self, comment: str, *, values=None, array_class=pd.Series):
super().__init__(dtype=np.float64, comment=comment, values=values, array_class=array_class)
super().__init__(np.float64, values, comment, array_class)
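
Each of these classes pins a concrete dtype, so values are validated and coerced on assignment, and the nullable pandas integer types keep missing values intact. A brief hypothetical check:

v = VectorI8("small ints", values=[1, None, 3])
print(v)  # backed by a pandas Series of dtype Int8, with <NA> in the middle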