Skip to content

Commit

Permalink
Merge pull request #117 from matthewwardrop/add_support_for_patsy_Q
Browse files Browse the repository at this point in the history
Add support for the patsy `Q` transform.
  • Loading branch information
matthewwardrop authored Apr 22, 2023
2 parents a7ecebd + 3449e72 commit 05292aa
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 18 deletions.
1 change: 1 addition & 0 deletions docsite/docs/guides/grammar.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ that have *not* been implemented by `formulaic` are explicitly noted also.
| Transform | Description | Formulaic | Patsy | R |
|----------:|:------------|:---------:|:-----:|:-:|
| `I(...)` | Identity transform, allowing arbitrary Python/R operations, e.g. `I(x+y)`. Note that in `formulaic`, it is more idiomatic to use `{x+y}`. ||||
| `Q('<column_name>')` | Look up feature by potentially exotic name, e.g. `Q('wacky name!')`. Note that in `formulaic`, it is more idiomatic to use ``` `wacky name!` ```. ||||
| `C(...)` | Categorically encode a column, e.g. `C(x)` ||||
| `center(...)` | Shift column data so mean is zero. ||||
| `scale(...)` | Shift column so mean is zero and variance is 1. ||[^6] ||
Expand Down
4 changes: 3 additions & 1 deletion formulaic/materializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def __init__(self, data, context=None, **params):
self._init()

self.layered_context = LayeredMapping(
self.data_context, self.context, TRANSFORMS
LayeredMapping(self.data_context, name="data"),
LayeredMapping(self.context, name="context"),
LayeredMapping(TRANSFORMS, name="transforms"),
)

self.factor_cache = {}
Expand Down
6 changes: 6 additions & 0 deletions formulaic/transforms/patsy_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@ def Treatment(reference=TreatmentContrasts.MISSING):
return TreatmentContrasts(base=reference)


@stateful_transform
def Q(variable, _context=None):
return _context.data[variable]


PATSY_COMPAT_TRANSFORMS = {
"standardize": standardize,
"Q": Q,
"Treatment": Treatment,
"Poly": PolyContrasts,
"Sum": SumContrasts,
Expand Down
59 changes: 47 additions & 12 deletions formulaic/utils/layered_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
from collections.abc import MutableMapping
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

# Cached property was introduced in Python 3.8 (we currently support 3.7)
try:
from functools import cached_property
except ImportError: # pragma: no cover
from cached_property import cached_property


class LayeredMapping(MutableMapping):
"""
Expand All @@ -10,15 +16,19 @@ class LayeredMapping(MutableMapping):
bottom until the key is found or the stack is exhausted. Mutations are
stored in an additional layer local only to the `LayeredMapping` instance,
and the layers passed in are never mutated.
Nest named layers can be extracted via attribute lookups, or via
`.named_layers`.
"""

def __init__(self, *layers: Tuple[Optional[Mapping]]):
def __init__(self, *layers: Tuple[Optional[Mapping]], name: Optional[str] = None):
"""
Crepare a `LayeredMapping` instance, populating it with the nominated
layers.
"""
self.mutations: Dict = {}
self.layers: List[Mapping] = self.__filter_layers(layers)
self.name = name
self._mutations: Dict = {}
self._layers: List[Mapping] = self.__filter_layers(layers)

@staticmethod
def __filter_layers(layers: Iterable[Mapping]) -> List[Mapping]:
Expand All @@ -28,36 +38,37 @@ def __filter_layers(layers: Iterable[Mapping]) -> List[Mapping]:
return [layer for layer in layers if layer is not None]

def __getitem__(self, key: Any) -> Any:
for layer in [self.mutations, *self.layers]:
for layer in [self._mutations, *self._layers]:
if key in layer:
return layer[key]
raise KeyError(key)

def __setitem__(self, key: Any, value: Any):
self.mutations[key] = value
self._mutations[key] = value

def __delitem__(self, key: Any):
if key in self.mutations:
del self.mutations[key]
if key in self._mutations:
del self._mutations[key]
else:
raise KeyError(f"Key '{key}' not found in mutable layer.")

def __iter__(self):
keys = set()
for layer in [self.mutations, *self.layers]:
for layer in [self._mutations, *self._layers]:
for key in layer:
if key not in keys:
keys.add(key)
yield key

def __len__(self):
return len(set(itertools.chain(self.mutations, *self.layers)))
return len(set(itertools.chain(self._mutations, *self._layers)))

def with_layers(
self,
*layers: Tuple[Optional[Mapping]],
prepend: bool = True,
inplace: bool = False,
name: Optional[str] = None,
) -> "LayeredMapping":
"""
Return a copy of this `LayeredMapping` instance with additional layers
Expand All @@ -78,10 +89,34 @@ def with_layers(
return self

if inplace:
self.layers = (
[*layers, *self.layers] if prepend else [*self.layers, *layers]
self._layers = (
[*layers, *self._layers] if prepend else [*self._layers, *layers]
)
self.name = name
if "named_layers" in self.__dict__:
del self.named_layers
return self

new_layers = [*layers, self] if prepend else [self, *layers]
return LayeredMapping(*new_layers)
return LayeredMapping(*new_layers, name=name)

# Named layer lookups and caching

@cached_property
def named_layers(self):
named_layers = {}
local = {}
for layer in reversed(self._layers):
if isinstance(layer, LayeredMapping):
if layer.name:
local[layer.name] = layer
named_layers.update(layer.named_layers)
named_layers.update(local)
if self.name:
named_layers[self.name] = self
return named_layers

def __getattr__(self, attr):
if attr not in self.named_layers:
raise AttributeError(f"{repr(attr)} does not correspond to a named layer.")
return self.named_layers[attr]
27 changes: 24 additions & 3 deletions formulaic/utils/stateful_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def stateful_transform(func: Callable) -> Callable:
- _state: The existing state or an empty dictionary.
- _metadata: Any extra metadata passed about the factor being evaluated.
- _spec: The `ModelSpec` instance being evaluated (or an empty `ModelSpec`).
- _context: A mapping of the name to value for all the variables available
in the formula evaluation context (including data column names).
If the callable has any of these in its signature, these will be passed onto
it; otherwise, they will be swallowed by the stateful transform wrapper.
Expand All @@ -37,10 +39,12 @@ def stateful_transform(func: Callable) -> Callable:
The stateful transform callable.
"""
func = functools.singledispatch(func)
params = inspect.signature(func).parameters.keys()
params = set(inspect.signature(func).parameters.keys())

@functools.wraps(func)
def wrapper(data, *args, _metadata=None, _state=None, _spec=None, **kwargs):
def wrapper(
data, *args, _metadata=None, _state=None, _spec=None, _context=None, **kwargs
):
from formulaic.model_spec import ModelSpec

_state = {} if _state is None else _state
Expand All @@ -49,6 +53,8 @@ def wrapper(data, *args, _metadata=None, _state=None, _spec=None, **kwargs):
extra_params["_metadata"] = _metadata
if "_spec" in params:
extra_params["_spec"] = _spec or ModelSpec(formula=[])
if "_context" in params:
extra_params["_context"] = _context

if isinstance(data, dict):
results = {}
Expand All @@ -63,7 +69,14 @@ def wrapper(data, *args, _metadata=None, _state=None, _spec=None, **kwargs):
if statum:
_state[key] = statum
return results
return func(data, *args, _state=_state, **extra_params, **kwargs)

return func(
data,
*args,
**({"_state": _state} if "_state" in params else {}),
**extra_params,
**kwargs,
)

wrapper.__is_stateful_transform__ = True
return wrapper
Expand Down Expand Up @@ -127,6 +140,12 @@ def stateful_eval(
name = name.replace('"', r'\\\\"')
if name not in state:
state[name] = {}
node.keywords.append(
ast.keyword(
"_context",
ast.parse("__FORMULAIC_CONTEXT__", mode="eval").body,
)
)
node.keywords.append(
ast.keyword(
"_metadata",
Expand All @@ -145,6 +164,7 @@ def stateful_eval(
# Compile mutated AST
code = compile(ast.fix_missing_locations(code), "", "eval")

assert "__FORMULAIC_CONTEXT__" not in env
assert "__FORMULAIC_METADATA__" not in env
assert "__FORMULAIC_STATE__" not in env
assert "__FORMULAIC_SPEC__" not in env
Expand All @@ -155,6 +175,7 @@ def stateful_eval(
{},
LayeredMapping(
{
"__FORMULAIC_CONTEXT__": env,
"__FORMULAIC_METADATA__": metadata,
"__FORMULAIC_SPEC__": spec,
"__FORMULAIC_STATE__": state,
Expand Down
6 changes: 6 additions & 0 deletions tests/transforms/test_patsy_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ def test_Treatment():
).values
== numpy.array([[1, 0, 0], [1, 1, 0], [1, 0, 1]])
)


def test_Q():
assert model_matrix(
"Q('x')", pandas.DataFrame({"x": [1, 2, 3]})
).model_spec.column_names == ("Intercept", "Q('x')")
30 changes: 28 additions & 2 deletions tests/utils/test_layered_mapping.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest

from formulaic.utils.layered_mapping import LayeredMapping
Expand Down Expand Up @@ -31,10 +33,34 @@ def test_layered_context():

# Test mutations
layered["f"] = 10
assert layered.mutations == {"f": 10}
assert layered._mutations == {"f": 10}

del layered["f"]
assert layered.mutations == {}
assert layered._mutations == {}

with pytest.raises(KeyError):
del layered["a"]


def test_named_layered_mappings():

data_layer = LayeredMapping({"data": 1}, name="data")
context_layer = LayeredMapping({"context": "context"}, name="context")
layers = LayeredMapping({"data": None, "context": None}, data_layer, context_layer)

assert sorted(layers.named_layers) == ["context", "data"]
assert layers["data"] is None
assert layers["context"] is None
assert layers.data["data"] == 1
assert layers.context["context"] == "context"

assert layers.with_layers({"data": 2}, inplace=True)["data"] == 2
assert sorted(
layers.with_layers({"data": 2}, inplace=True, name="toplevel").named_layers
) == ["context", "data", "toplevel"]

with pytest.raises(
AttributeError,
match=re.escape("'missing' does not correspond to a named layer."),
):
layers.missing

0 comments on commit 05292aa

Please sign in to comment.