Skip to content

Commit

Permalink
Updating ImplicitArray, Adding Array8Bit
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Jul 17, 2024
1 parent 067c50b commit 6ff6a70
Show file tree
Hide file tree
Showing 8 changed files with 699 additions and 150 deletions.
366 changes: 244 additions & 122 deletions .vscode/PythonImportHelper-v2-Completion.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/fjformer/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fjformer.core.implicit_array import (
ImplicitArray as ImplicitArray,
use_implicit_args as use_implicit_args,
implicit_compact as implicit_compact,
aux_field as aux_field,
UninitializedAval as UninitializedAval,
default_handler as default_handler,
Expand Down
41 changes: 23 additions & 18 deletions src/fjformer/core/implicit_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Key components:
- ImplicitArray: Abstract base class for symbolic array representations
- primitive_handler: Decorator for registering custom primitive handlers
- use_implicit_args: Decorator for functions to accept ImplicitArray arguments
- implicit_compact: Decorator for functions to accept ImplicitArray arguments
"""

import warnings
Expand All @@ -27,7 +27,6 @@
from jax.tree_util import register_pytree_with_keys_class
from plum import Dispatcher, Function

# Constants and global variables
_dispatch = Dispatcher()
_primitive_ids = count()

Expand Down Expand Up @@ -418,7 +417,7 @@ def materialize_nested(implicit_arr, full=False):
wrapped = lu.wrap_init(type(implicit_arr).materialize)
flat, in_tree = flatten_one_implicit_layer((implicit_arr,))
flat_fn, out_tree = flatten_fun_nokwargs(wrapped, in_tree)
out_flat = use_implicit_args(flat_fn.call_wrapped)(*flat)
out_flat = implicit_compact(flat_fn.call_wrapped)(*flat)
implicit_arr = jax.tree_util.tree_unflatten(out_tree(), out_flat)

if not full:
Expand Down Expand Up @@ -455,17 +454,27 @@ def _implicit_inner(main, *in_vals):
yield out_vals


def use_implicit_args(f: Callable) -> Callable:
def implicit_compact(f: Callable) -> Callable:
"""
Decorator which allows a function to accept arguments which subclass ImplicitArray, possibly
including further ImplicitArray instances as children.
Any number of arguments (including 0) may be ImplicitArrays.
A decorator that enables compact handling of ImplicitArray subclasses within a function.
This allows for seamless integration of custom array types in JAX operations.
This decorator can be used in combination with jax.jit for optimized execution.
Args:
f: The function to be decorated.
Returns:
A wrapped function that can handle ImplicitArray arguments.
A wrapped function that can handle both regular arrays and ImplicitArray instances.
Example:
>>> @jax.jit
>>> @implicit_compact
>>> def f(a, b):
... return jnp.dot(a, b)
>>> result = f(regular_array, regular_array)
>>> implicit_result = f(implicit_array, implicit_or_normal_array)
"""

@wraps(f)
Expand Down Expand Up @@ -544,7 +553,7 @@ class ImplicitArray(_ImplicitArrayBase):
Subclasses must implement the materialize method, which defines the relationship
between the implicit array and the value it represents. Subclasses are valid
arguments to functions decorated with qax.use_implicit_args.
arguments to functions decorated with core.implicit_compact.
The represented shape and dtype may be defined in various ways:
1. Explicitly passing shape/dtype keyword arguments at initialization
Expand Down Expand Up @@ -665,7 +674,7 @@ def handle_primitive(self, primitive, *args, params):
flat_args, in_tree = flatten_one_implicit_layer((args, params))
flat_handler, out_tree = flatten_fun(handler, in_tree)

result = use_implicit_args(flat_handler.call_wrapped)(*flat_args)
result = implicit_compact(flat_handler.call_wrapped)(*flat_args)
return jax.tree_util.tree_unflatten(out_tree(), result)

def __init_subclass__(cls, commute_ops=True, warn_on_materialize=True, **kwargs):
Expand Down Expand Up @@ -761,7 +770,7 @@ def wrap_jaxpr(jaxpr, vals_with_implicits, return_closed=True):
else:
literals = []

wrapped_fn = lu.wrap_init(use_implicit_args(partial(core.eval_jaxpr, jaxpr)))
wrapped_fn = lu.wrap_init(implicit_compact(partial(core.eval_jaxpr, jaxpr)))
flat_args, in_tree = jax.tree_util.tree_flatten((literals, *vals_with_implicits))
flat_fn, out_tree = flatten_fun_nokwargs(wrapped_fn, in_tree)

Expand All @@ -779,7 +788,7 @@ def wrap_jaxpr(jaxpr, vals_with_implicits, return_closed=True):

def _transform_jaxpr_output(jaxpr, jaxpr_args, orig_out_struct, out_transform):
def eval_fn(literals, *args):
output = use_implicit_args(partial(core.eval_jaxpr, jaxpr.jaxpr))(
output = implicit_compact(partial(core.eval_jaxpr, jaxpr.jaxpr))(
literals, *args
)
unflattened_output = orig_out_struct.unflatten(output)
Expand Down Expand Up @@ -877,11 +886,7 @@ def _handle_scan(primitive, *vals, params):
xs = vals[n_consts + n_carry :]

if any(isinstance(c, ImplicitArray) for c in carries):
warnings.warn(
"ImplicitArray in scan carries are not yet supported."
" If you need this feature please open an issue on the Qax repo:"
" https://github.com/davisyoshida/qax/issues"
)
warnings.warn("Not Supported Yet.")
carries = _materialize_all(carries)

sliced_xs = jax.tree_map(partial(jax.eval_shape, lambda x: x[0]), xs)
Expand Down Expand Up @@ -921,7 +926,7 @@ def _handle_scan(primitive, *vals, params):
def materialize_handler(primitive, *vals, params):
vals = _materialize_all(vals)
subfuns, bind_params = primitive.get_bind_params(params)
result = use_implicit_args(primitive.bind)(*subfuns, *vals, **bind_params)
result = implicit_compact(primitive.bind)(*subfuns, *vals, **bind_params)
return result


Expand Down
8 changes: 4 additions & 4 deletions src/fjformer/core/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
aux_field,
default_handler,
primitive_handler,
use_implicit_args,
implicit_compact,
)
from fjformer.core.types import Complement

Expand Down Expand Up @@ -198,12 +198,12 @@ def copy(self) -> "SymbolicConstant":
raise OperationError(f"Failed to copy SymbolicConstant: {str(e)}")


@use_implicit_args
@implicit_compact
def broadcast_to(val, shape):
return jnp.broadcast_to(val, shape)


@use_implicit_args
@implicit_compact
def astype(val, dtype):
return val.astype(dtype)

Expand Down Expand Up @@ -233,7 +233,7 @@ def _op_and_reshape(primitive, lhs, rhs, flip=False):
if flip:
lhs, rhs = (rhs, lhs)

@use_implicit_args
@implicit_compact
def inner(arg):
other = lhs
if flip:
Expand Down
4 changes: 2 additions & 2 deletions src/fjformer/core/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import tree_util
from jax.dtypes import float0

from fjformer.core.implicit_array import use_implicit_args
from fjformer.core.implicit_array import implicit_compact
from fjformer.core.symbols import SymbolicConstant


Expand Down Expand Up @@ -187,7 +187,7 @@ def apply_updates(params: optax.Params, updates: optax.Updates) -> optax.Params:
)
semi_flat_params = update_struct.flatten_up_to(params)

updated_flat = use_implicit_args(optax.apply_updates)(
updated_flat = implicit_compact(optax.apply_updates)(
semi_flat_params, updates_flat
)
updated = update_struct.unflatten(updated_flat)
Expand Down
Loading

0 comments on commit 6ff6a70

Please sign in to comment.