From 6ff6a700de547c813917fd36b6b77f106d1cf0f2 Mon Sep 17 00:00:00 2001 From: erfanzar Date: Wed, 17 Jul 2024 14:38:17 +0330 Subject: [PATCH] Updating `ImplicitArray`, Adding `Array8Bit` --- .vscode/PythonImportHelper-v2-Completion.json | 366 ++++++++++----- src/fjformer/core/__init__.py | 2 +- src/fjformer/core/implicit_array.py | 41 +- src/fjformer/core/symbols.py | 8 +- src/fjformer/core/utilities.py | 4 +- src/fjformer/custom_array/quantized_8bit.py | 422 ++++++++++++++++++ src/fjformer/lora/lora_core.py | 2 +- src/fjformer/lora/rapture.py | 4 +- 8 files changed, 699 insertions(+), 150 deletions(-) create mode 100644 src/fjformer/custom_array/quantized_8bit.py diff --git a/.vscode/PythonImportHelper-v2-Completion.json b/.vscode/PythonImportHelper-v2-Completion.json index c48fb87..cee7949 100644 --- a/.vscode/PythonImportHelper-v2-Completion.json +++ b/.vscode/PythonImportHelper-v2-Completion.json @@ -371,6 +371,30 @@ "detail": "typing", "documentation": {} }, + { + "label": "Optional", + "importPath": "typing", + "description": "typing", + "isExtraImport": true, + "detail": "typing", + "documentation": {} + }, + { + "label": "Sequence", + "importPath": "typing", + "description": "typing", + "isExtraImport": true, + "detail": "typing", + "documentation": {} + }, + { + "label": "Union", + "importPath": "typing", + "description": "typing", + "isExtraImport": true, + "detail": "typing", + "documentation": {} + }, { "label": "Mapping", "importPath": "typing", @@ -1051,14 +1075,6 @@ "detail": "typing", "documentation": {} }, - { - "label": "Tuple", - "importPath": "typing", - "description": "typing", - "isExtraImport": true, - "detail": "typing", - "documentation": {} - }, { "label": "flax.struct", "kind": 6, @@ -1150,6 +1166,14 @@ "detail": "dataclasses", "documentation": {} }, + { + "label": "dataclass", + "importPath": "dataclasses", + "description": "dataclasses", + "isExtraImport": true, + "detail": "dataclasses", + "documentation": {} + }, { "label": "calibration", "importPath": "fjformer.bit_quantization", @@ -1359,6 +1383,30 @@ "detail": "jax", "documentation": {} }, + { + "label": "Array", + "importPath": "jax", + "description": "jax", + "isExtraImport": true, + "detail": "jax", + "documentation": {} + }, + { + "label": "lax", + "importPath": "jax", + "description": "jax", + "isExtraImport": true, + "detail": "jax", + "documentation": {} + }, + { + "label": "numpy", + "importPath": "jax", + "description": "jax", + "isExtraImport": true, + "detail": "jax", + "documentation": {} + }, { "label": "numpy", "importPath": "jax", @@ -1551,6 +1599,14 @@ "detail": "jax", "documentation": {} }, + { + "label": "lax", + "importPath": "jax", + "description": "jax", + "isExtraImport": true, + "detail": "jax", + "documentation": {} + }, { "label": "Array", "importPath": "jax", @@ -2102,7 +2158,7 @@ "documentation": {} }, { - "label": "use_implicit_args", + "label": "implicit_compact", "importPath": "fjformer.core.implicit_array", "description": "fjformer.core.implicit_array", "isExtraImport": true, @@ -2110,7 +2166,7 @@ "documentation": {} }, { - "label": "use_implicit_args", + "label": "implicit_compact", "importPath": "fjformer.core.implicit_array", "description": "fjformer.core.implicit_array", "isExtraImport": true, @@ -2158,6 +2214,47 @@ "detail": "fjformer.core.symbols", "documentation": {} }, + { + "label": "fjformer.core", + "kind": 6, + "isExtraImport": true, + "importPath": "fjformer.core", + "description": "fjformer.core", + "detail": "fjformer.core", + "documentation": {} + }, + { + "label": "EmptyNode", + "importPath": "fjformer.core", + "description": "fjformer.core", + "isExtraImport": true, + "detail": "fjformer.core", + "documentation": {} + }, + { + "label": "materialize_nested", + "importPath": "fjformer.core", + "description": "fjformer.core", + "isExtraImport": true, + "detail": "fjformer.core", + "documentation": {} + }, + { + "label": "tree_map_with_implicit", + "importPath": "fjformer.core", + "description": "fjformer.core", + "isExtraImport": true, + "detail": "fjformer.core", + "documentation": {} + }, + { + "label": "implicit_compact", + "importPath": "fjformer.core", + "description": "fjformer.core", + "isExtraImport": true, + "detail": "fjformer.core", + "documentation": {} + }, { "label": "chex", "kind": 6, @@ -2265,70 +2362,6 @@ "detail": "fjformer", "documentation": {} }, - { - "label": "EmptyNode", - "importPath": "fjformer.core", - "description": "fjformer.core", - "isExtraImport": true, - "detail": "fjformer.core", - "documentation": {} - }, - { - "label": "materialize_nested", - "importPath": "fjformer.core", - "description": "fjformer.core", - "isExtraImport": true, - "detail": "fjformer.core", - "documentation": {} - }, - { - "label": "tree_map_with_implicit", - "importPath": "fjformer.core", - "description": "fjformer.core", - "isExtraImport": true, - "detail": "fjformer.core", - "documentation": {} - }, - { - "label": "use_implicit_args", - "importPath": "fjformer.core", - "description": "fjformer.core", - "isExtraImport": true, - "detail": "fjformer.core", - "documentation": {} - }, - { - "label": "ImplicitArray", - "importPath": "fjformer.core", - "description": "fjformer.core", - "isExtraImport": true, - "detail": "fjformer.core", - "documentation": {} - }, - { - "label": "primitive_handler", - "importPath": "fjformer.core", - "description": "fjformer.core", - "isExtraImport": true, - "detail": "fjformer.core", - "documentation": {} - }, - { - "label": "use_implicit_args", - "importPath": "fjformer.core", - "description": "fjformer.core", - "isExtraImport": true, - "detail": "fjformer.core", - "documentation": {} - }, - { - "label": "ArrayValue", - "importPath": "fjformer.core", - "description": "fjformer.core", - "isExtraImport": true, - "detail": "fjformer.core", - "documentation": {} - }, { "label": "freeze_subtrees", "importPath": "fjformer.core.utilities", @@ -2927,6 +2960,14 @@ "detail": "fjformer.functions.loss_functions", "documentation": {} }, + { + "label": "Array8Bit", + "importPath": "fjformer.custom_array.quantized_8bit", + "description": "fjformer.custom_array.quantized_8bit", + "isExtraImport": true, + "detail": "fjformer.custom_array.quantized_8bit", + "documentation": {} + }, { "label": "project", "kind": 5, @@ -3553,7 +3594,7 @@ "kind": 6, "importPath": "src.fjformer.core.implicit_array", "description": "src.fjformer.core.implicit_array", - "peekOfCode": "class ImplicitArray(_ImplicitArrayBase):\n \"\"\"\n Abstract class for representing an abstract array without instantiation.\n Subclasses must implement the materialize method, which defines the relationship\n between the implicit array and the value it represents. Subclasses are valid\n arguments to functions decorated with qax.use_implicit_args.\n The represented shape and dtype may be defined in various ways:\n 1. Explicitly passing shape/dtype keyword arguments at initialization\n 2. Overriding the default_shape/default_dtype class variables\n 3. Overriding the compute_shape/compute_dtype methods", + "peekOfCode": "class ImplicitArray(_ImplicitArrayBase):\n \"\"\"\n Abstract class for representing an abstract array without instantiation.\n Subclasses must implement the materialize method, which defines the relationship\n between the implicit array and the value it represents. Subclasses are valid\n arguments to functions decorated with core.implicit_compact.\n The represented shape and dtype may be defined in various ways:\n 1. Explicitly passing shape/dtype keyword arguments at initialization\n 2. Overriding the default_shape/default_dtype class variables\n 3. Overriding the compute_shape/compute_dtype methods", "detail": "src.fjformer.core.implicit_array", "documentation": {} }, @@ -3657,11 +3698,11 @@ "documentation": {} }, { - "label": "use_implicit_args", + "label": "implicit_compact", "kind": 2, "importPath": "src.fjformer.core.implicit_array", "description": "src.fjformer.core.implicit_array", - "peekOfCode": "def use_implicit_args(f: Callable) -> Callable:\n \"\"\"\n Decorator which allows a function to accept arguments which subclass ImplicitArray, possibly\n including further ImplicitArray instances as children.\n Any number of arguments (including 0) may be ImplicitArrays.\n Args:\n f: The function to be decorated.\n Returns:\n A wrapped function that can handle ImplicitArray arguments.\n \"\"\"", + "peekOfCode": "def implicit_compact(f: Callable) -> Callable:\n \"\"\"\n A decorator that enables compact handling of ImplicitArray subclasses within a function.\n This allows for seamless integration of custom array types in JAX operations.\n This decorator can be used in combination with jax.jit for optimized execution.\n Args:\n f: The function to be decorated.\n Returns:\n A wrapped function that can handle both regular arrays and ImplicitArray instances.\n Example:", "detail": "src.fjformer.core.implicit_array", "documentation": {} }, @@ -3688,7 +3729,7 @@ "kind": 2, "importPath": "src.fjformer.core.implicit_array", "description": "src.fjformer.core.implicit_array", - "peekOfCode": "def wrap_jaxpr(jaxpr, vals_with_implicits, return_closed=True):\n if isinstance(jaxpr, jax.core.ClosedJaxpr):\n literals = jaxpr.literals\n jaxpr = jaxpr.jaxpr\n else:\n literals = []\n wrapped_fn = lu.wrap_init(use_implicit_args(partial(core.eval_jaxpr, jaxpr)))\n flat_args, in_tree = jax.tree_util.tree_flatten((literals, *vals_with_implicits))\n flat_fn, out_tree = flatten_fun_nokwargs(wrapped_fn, in_tree)\n new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(", + "peekOfCode": "def wrap_jaxpr(jaxpr, vals_with_implicits, return_closed=True):\n if isinstance(jaxpr, jax.core.ClosedJaxpr):\n literals = jaxpr.literals\n jaxpr = jaxpr.jaxpr\n else:\n literals = []\n wrapped_fn = lu.wrap_init(implicit_compact(partial(core.eval_jaxpr, jaxpr)))\n flat_args, in_tree = jax.tree_util.tree_flatten((literals, *vals_with_implicits))\n flat_fn, out_tree = flatten_fun_nokwargs(wrapped_fn, in_tree)\n new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(", "detail": "src.fjformer.core.implicit_array", "documentation": {} }, @@ -3697,7 +3738,7 @@ "kind": 2, "importPath": "src.fjformer.core.implicit_array", "description": "src.fjformer.core.implicit_array", - "peekOfCode": "def materialize_handler(primitive, *vals, params):\n vals = _materialize_all(vals)\n subfuns, bind_params = primitive.get_bind_params(params)\n result = use_implicit_args(primitive.bind)(*subfuns, *vals, **bind_params)\n return result\ndef _broadcast_tuple(t, trees):\n if isinstance(trees, jax.tree_util.PyTreeDef):\n trees = jax.tree_util.tree_unflatten(trees, range(trees.num_leaves))\n assert len(t) == len(trees)\n return tuple(", + "peekOfCode": "def materialize_handler(primitive, *vals, params):\n vals = _materialize_all(vals)\n subfuns, bind_params = primitive.get_bind_params(params)\n result = implicit_compact(primitive.bind)(*subfuns, *vals, **bind_params)\n return result\ndef _broadcast_tuple(t, trees):\n if isinstance(trees, jax.tree_util.PyTreeDef):\n trees = jax.tree_util.tree_unflatten(trees, range(trees.num_leaves))\n assert len(t) == len(trees)\n return tuple(", "detail": "src.fjformer.core.implicit_array", "documentation": {} }, @@ -3841,7 +3882,7 @@ "kind": 5, "importPath": "src.fjformer.core.implicit_array", "description": "src.fjformer.core.implicit_array", - "peekOfCode": "_default_handlers = {\n \"cond\": _handle_cond,\n \"remat2\": _handle_remat2,\n \"pjit\": _handle_pjit,\n \"scan\": _handle_scan,\n}\ndef materialize_handler(primitive, *vals, params):\n vals = _materialize_all(vals)\n subfuns, bind_params = primitive.get_bind_params(params)\n result = use_implicit_args(primitive.bind)(*subfuns, *vals, **bind_params)", + "peekOfCode": "_default_handlers = {\n \"cond\": _handle_cond,\n \"remat2\": _handle_remat2,\n \"pjit\": _handle_pjit,\n \"scan\": _handle_scan,\n}\ndef materialize_handler(primitive, *vals, params):\n vals = _materialize_all(vals)\n subfuns, bind_params = primitive.get_bind_params(params)\n result = implicit_compact(primitive.bind)(*subfuns, *vals, **bind_params)", "detail": "src.fjformer.core.implicit_array", "documentation": {} }, @@ -3877,7 +3918,7 @@ "kind": 2, "importPath": "src.fjformer.core.symbols", "description": "src.fjformer.core.symbols", - "peekOfCode": "def broadcast_to(val, shape):\n return jnp.broadcast_to(val, shape)\n@use_implicit_args\ndef astype(val, dtype):\n return val.astype(dtype)\n@primitive_handler(\n [\n \"reshape\",\n \"broadcast_in_dim\",\n \"reduce_min\",", + "peekOfCode": "def broadcast_to(val, shape):\n return jnp.broadcast_to(val, shape)\n@implicit_compact\ndef astype(val, dtype):\n return val.astype(dtype)\n@primitive_handler(\n [\n \"reshape\",\n \"broadcast_in_dim\",\n \"reduce_min\",", "detail": "src.fjformer.core.symbols", "documentation": {} }, @@ -4124,6 +4165,123 @@ "detail": "src.fjformer.core.utilities", "documentation": {} }, + { + "label": "Array8Bit", + "kind": 6, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "class Array8Bit(core.ImplicitArray):\n \"\"\"\n Custom 8-bit Quantized Array for efficient manipulation of JAX arrays.\n This class provides methods for quantizing and dequantizing JAX arrays to 8-bit\n representation, which can significantly reduce memory usage and potentially\n improve computation speed for certain operations.\n Attributes:\n array_quantized (core.ArrayValue): The quantized array data.\n scale (core.ArrayValue): Scaling factors for dequantization.\n min_vals (core.ArrayValue): Minimum values used in quantization.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_dot_general", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_dot_general(primitive, lhs: ArrayType, rhs: ArrayType, *args, **kwargs):\n \"\"\"\n Custom handler for JAX's dot_general operation.\n Materializes Array8Bit inputs before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n lhs (ArrayType): Left-hand side array.\n rhs (ArrayType): Right-hand side array.\n *args: Variable length argument list.\n **kwargs: Arbitrary keyword arguments.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_add", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_add(primitive, x: ArrayType, y: ArrayType):\n \"\"\"\n Custom handler for JAX's add operation.\n Materializes Array8Bit inputs before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n x (ArrayType): First array to add.\n y (ArrayType): Second array to add.\n Returns:\n The result of lax.add operation.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_reduce", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_reduce(\n primitive, operand: ArrayType, init_value: ArrayType, *args, **kwargs\n):\n \"\"\"\n Custom handler for JAX's reduce operation.\n Materializes Array8Bit inputs before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n operand (ArrayType): The array to be reduced.\n init_value (ArrayType): The initial value for the reduction.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_mul", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_mul(primitive, x: ArrayType, y: ArrayType):\n \"\"\"\n Custom handler for JAX's mul operation.\n Materializes Array8Bit inputs before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n x (ArrayType): First array to multiply.\n y (ArrayType): Second array to multiply.\n Returns:\n The result of lax.mul operation.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_transpose", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_transpose(primitive, operand: ArrayType, *args, **kwargs):\n \"\"\"\n Custom handler for JAX's transpose operation.\n Materializes Array8Bit input before performing the operation.\n Re-quantizes the result if the input was Array8Bit.\n Args:\n primitive: The JAX primitive being handled.\n operand (ArrayType): The array to be transposed.\n *args: Variable length argument list.\n **kwargs: Arbitrary keyword arguments.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_conv", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_conv(primitive, lhs: ArrayType, rhs: ArrayType, *args, **kwargs):\n \"\"\"\n Custom handler for JAX's conv_general_dilated operation.\n Materializes Array8Bit inputs before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n lhs (ArrayType): Left-hand side array (input).\n rhs (ArrayType): Right-hand side array (kernel).\n *args: Variable length argument list.\n **kwargs: Arbitrary keyword arguments.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_max", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_max(primitive, x: ArrayType, y: ArrayType, *args, **kwargs):\n \"\"\"\n Custom handler for JAX's max operation.\n Materializes Array8Bit inputs before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n x (ArrayType): First array for max comparison.\n y (ArrayType): Second array for max comparison.\n *args: Variable length argument list.\n **kwargs: Arbitrary keyword arguments.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_exp", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_exp(primitive, x: ArrayType, *args, **kwargs):\n \"\"\"\n Custom handler for JAX's exp operation.\n Materializes Array8Bit input before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n x (ArrayType): The array to apply exponential to.\n *args: Variable length argument list.\n **kwargs: Arbitrary keyword arguments.\n Returns:", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_log", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_log(primitive, x: ArrayType, *args, **kwargs):\n \"\"\"\n Custom handler for JAX's log operation.\n Materializes Array8Bit input before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n x (ArrayType): The array to apply logarithm to.\n *args: Variable length argument list.\n **kwargs: Arbitrary keyword arguments.\n Returns:", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_reshape", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_reshape(primitive, operand: ArrayType, *args, **kwargs):\n \"\"\"\n Custom handler for JAX's reshape operation.\n Materializes Array8Bit input before performing the operation.\n Re-quantizes the result if the input was Array8Bit.\n Args:\n primitive: The JAX primitive being handled.\n operand (ArrayType): The array to be reshaped.\n *args: Variable length argument list.\n **kwargs: Arbitrary keyword arguments.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "handle_concatenate", + "kind": 2, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "def handle_concatenate(primitive, operands: Sequence[ArrayType], *args, **kwargs):\n \"\"\"\n Custom handler for JAX's concatenate operation.\n Materializes Array8Bit inputs before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n operands (Sequence[ArrayType]): Sequence of arrays to concatenate.\n *args: Variable length argument list.\n **kwargs: Arbitrary keyword arguments.\n Returns:", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, + { + "label": "ArrayType", + "kind": 5, + "importPath": "src.fjformer.custom_array.quantized_8bit", + "description": "src.fjformer.custom_array.quantized_8bit", + "peekOfCode": "ArrayType = Union[Array, Array8Bit]\n@core.primitive_handler(\"dot_general\")\ndef handle_dot_general(primitive, lhs: ArrayType, rhs: ArrayType, *args, **kwargs):\n \"\"\"\n Custom handler for JAX's dot_general operation.\n Materializes Array8Bit inputs before performing the operation.\n Args:\n primitive: The JAX primitive being handled.\n lhs (ArrayType): Left-hand side array.\n rhs (ArrayType): Right-hand side array.", + "detail": "src.fjformer.custom_array.quantized_8bit", + "documentation": {} + }, { "label": "global_norm", "kind": 2, @@ -4912,7 +5070,7 @@ "kind": 6, "importPath": "src.fjformer.lora.lora_core", "description": "src.fjformer.lora.lora_core", - "peekOfCode": "class LoraError(Exception):\n \"\"\"Base exception for LoRA-related errors.\"\"\"\n pass\nclass UnsupportedOperationError(LoraError):\n \"\"\"Raised when an unsupported operation is encountered.\"\"\"\n pass\ndef lora(f: Any) -> Any:\n \"\"\"Decorator for LoRA-compatible functions.\"\"\"\n return cr.use_implicit_args(f)\n@dataclass", + "peekOfCode": "class LoraError(Exception):\n \"\"\"Base exception for LoRA-related errors.\"\"\"\n pass\nclass UnsupportedOperationError(LoraError):\n \"\"\"Raised when an unsupported operation is encountered.\"\"\"\n pass\ndef lora(f: Any) -> Any:\n \"\"\"Decorator for LoRA-compatible functions.\"\"\"\n return cr.implicit_compact(f)\n@dataclass", "detail": "src.fjformer.lora.lora_core", "documentation": {} }, @@ -4921,7 +5079,7 @@ "kind": 6, "importPath": "src.fjformer.lora.lora_core", "description": "src.fjformer.lora.lora_core", - "peekOfCode": "class UnsupportedOperationError(LoraError):\n \"\"\"Raised when an unsupported operation is encountered.\"\"\"\n pass\ndef lora(f: Any) -> Any:\n \"\"\"Decorator for LoRA-compatible functions.\"\"\"\n return cr.use_implicit_args(f)\n@dataclass\nclass LoraWeight(cr.ImplicitArray):\n \"\"\"Represents a LoRA (Low-Rank Adaptation) weight.\"\"\"\n w: cr.ArrayValue # M x N (2D)", + "peekOfCode": "class UnsupportedOperationError(LoraError):\n \"\"\"Raised when an unsupported operation is encountered.\"\"\"\n pass\ndef lora(f: Any) -> Any:\n \"\"\"Decorator for LoRA-compatible functions.\"\"\"\n return cr.implicit_compact(f)\n@dataclass\nclass LoraWeight(cr.ImplicitArray):\n \"\"\"Represents a LoRA (Low-Rank Adaptation) weight.\"\"\"\n w: cr.ArrayValue # M x N (2D)", "detail": "src.fjformer.lora.lora_core", "documentation": {} }, @@ -4939,7 +5097,7 @@ "kind": 2, "importPath": "src.fjformer.lora.lora_core", "description": "src.fjformer.lora.lora_core", - "peekOfCode": "def lora(f: Any) -> Any:\n \"\"\"Decorator for LoRA-compatible functions.\"\"\"\n return cr.use_implicit_args(f)\n@dataclass\nclass LoraWeight(cr.ImplicitArray):\n \"\"\"Represents a LoRA (Low-Rank Adaptation) weight.\"\"\"\n w: cr.ArrayValue # M x N (2D)\n a: cr.ArrayValue # k x N (2D)\n b: cr.ArrayValue # M x k (2D)\n alpha: float = cr.aux_field(default=1.00)", + "peekOfCode": "def lora(f: Any) -> Any:\n \"\"\"Decorator for LoRA-compatible functions.\"\"\"\n return cr.implicit_compact(f)\n@dataclass\nclass LoraWeight(cr.ImplicitArray):\n \"\"\"Represents a LoRA (Low-Rank Adaptation) weight.\"\"\"\n w: cr.ArrayValue # M x N (2D)\n a: cr.ArrayValue # k x N (2D)\n b: cr.ArrayValue # M x k (2D)\n alpha: float = cr.aux_field(default=1.00)", "detail": "src.fjformer.lora.lora_core", "documentation": {} }, @@ -6788,48 +6946,12 @@ "detail": "test.test_cross_ent_loss_and_acc", "documentation": {} }, - { - "label": "QuantizedArray", - "kind": 6, - "importPath": "env", - "description": "env", - "peekOfCode": "class QuantizedArray(ImplicitArray):\n array_quant: ArrayValue\n scale: ArrayValue\n min_vals: ArrayValue\n def materialize(self):\n return self.dequantize(\n array_quant=self.array_quant,\n scale=self.scale,\n min_vals=self.min_vals,\n float_dtype=self.dtype,", - "detail": "env", - "documentation": {} - }, - { - "label": "quantize", - "kind": 2, - "importPath": "env", - "description": "env", - "peekOfCode": "def quantize(array: Array, axis: int = -1) -> Tuple[Array, Array, Array]:\n min_vals = jnp.min(array, axis=axis, keepdims=True)\n max_vals = jnp.max(array, axis=axis, keepdims=True)\n # Compute the scaling factors\n scale = (max_vals - min_vals) / (2**7 - 1)\n # Quantize the data\n quantized_data = jnp.round((array - min_vals) / scale)\n # Clip the quantized values to ensure they lie within the representable range\n quantized_data = jnp.clip(quantized_data, 0, 2**7 - 1).astype(jnp.uint8)\n return quantized_data, scale, min_vals", - "detail": "env", - "documentation": {} - }, - { - "label": "dequantize", - "kind": 2, - "importPath": "env", - "description": "env", - "peekOfCode": "def dequantize(\n array_quant: Array,\n scale: Array,\n min_vals: Array,\n float_dtype: jnp.dtype = jnp.float16,\n):\n return (array_quant * scale + min_vals).astype(float_dtype)\n@dataclass\nclass QuantizedArray(ImplicitArray):\n array_quant: ArrayValue", - "detail": "env", - "documentation": {} - }, - { - "label": "get_binop_result_shape_dtype", - "kind": 2, - "importPath": "env", - "description": "env", - "peekOfCode": "def get_binop_result_shape_dtype(a, b):\n out_shape = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(b))\n out_dtype = jnp.result_type(a.dtype, b.dtype)\n return out_shape, out_dtype\n@jax.jit\n@use_implicit_args\ndef f(x, y):\n return (x + y)[0, 0]\ndef main():\n orginal_array = jax.random.normal(jax.random.PRNGKey(0), (512, 64))", - "detail": "env", - "documentation": {} - }, { "label": "f", "kind": 2, "importPath": "env", "description": "env", - "peekOfCode": "def f(x, y):\n return (x + y)[0, 0]\ndef main():\n orginal_array = jax.random.normal(jax.random.PRNGKey(0), (512, 64))\n quantized_array = QuantizedArray.quantize(orginal_array)\n print(f(quantized_array, jnp.ones(64)))\n print((orginal_array + jnp.ones(64))[0, 0])\nif __name__ == \"__main__\":\n main()", + "peekOfCode": "def f(x, xp):\n r = x @ xp\n r *= 0.689435468\n r = r @ xp.T\n return r\ndef main():\n x = jax.random.normal(jax.random.key(0), (512, 64), dtype=jnp.float32)\n xp = jax.random.normal(jax.random.key(1), (64, 256), dtype=jnp.float32)\n quantized_x = Array8Bit.quantize(x) # Now it's quantized Array\n quantized_xp = Array8Bit.quantize(xp) # Now it's quantized Array", "detail": "env", "documentation": {} }, @@ -6838,7 +6960,7 @@ "kind": 2, "importPath": "env", "description": "env", - "peekOfCode": "def main():\n orginal_array = jax.random.normal(jax.random.PRNGKey(0), (512, 64))\n quantized_array = QuantizedArray.quantize(orginal_array)\n print(f(quantized_array, jnp.ones(64)))\n print((orginal_array + jnp.ones(64))[0, 0])\nif __name__ == \"__main__\":\n main()", + "peekOfCode": "def main():\n x = jax.random.normal(jax.random.key(0), (512, 64), dtype=jnp.float32)\n xp = jax.random.normal(jax.random.key(1), (64, 256), dtype=jnp.float32)\n quantized_x = Array8Bit.quantize(x) # Now it's quantized Array\n quantized_xp = Array8Bit.quantize(xp) # Now it's quantized Array\n result = f(x, xp)\n q_result = f(quantized_x, quantized_xp)\n print(result[0, 0])\n print(q_result[0, 0])\nif __name__ == \"__main__\":", "detail": "env", "documentation": {} }, diff --git a/src/fjformer/core/__init__.py b/src/fjformer/core/__init__.py index b07cd8c..056dc6c 100644 --- a/src/fjformer/core/__init__.py +++ b/src/fjformer/core/__init__.py @@ -1,6 +1,6 @@ from fjformer.core.implicit_array import ( ImplicitArray as ImplicitArray, - use_implicit_args as use_implicit_args, + implicit_compact as implicit_compact, aux_field as aux_field, UninitializedAval as UninitializedAval, default_handler as default_handler, diff --git a/src/fjformer/core/implicit_array.py b/src/fjformer/core/implicit_array.py index 5b79061..24abdd9 100644 --- a/src/fjformer/core/implicit_array.py +++ b/src/fjformer/core/implicit_array.py @@ -6,7 +6,7 @@ Key components: - ImplicitArray: Abstract base class for symbolic array representations - primitive_handler: Decorator for registering custom primitive handlers -- use_implicit_args: Decorator for functions to accept ImplicitArray arguments +- implicit_compact: Decorator for functions to accept ImplicitArray arguments """ import warnings @@ -27,7 +27,6 @@ from jax.tree_util import register_pytree_with_keys_class from plum import Dispatcher, Function -# Constants and global variables _dispatch = Dispatcher() _primitive_ids = count() @@ -418,7 +417,7 @@ def materialize_nested(implicit_arr, full=False): wrapped = lu.wrap_init(type(implicit_arr).materialize) flat, in_tree = flatten_one_implicit_layer((implicit_arr,)) flat_fn, out_tree = flatten_fun_nokwargs(wrapped, in_tree) - out_flat = use_implicit_args(flat_fn.call_wrapped)(*flat) + out_flat = implicit_compact(flat_fn.call_wrapped)(*flat) implicit_arr = jax.tree_util.tree_unflatten(out_tree(), out_flat) if not full: @@ -455,17 +454,27 @@ def _implicit_inner(main, *in_vals): yield out_vals -def use_implicit_args(f: Callable) -> Callable: +def implicit_compact(f: Callable) -> Callable: """ - Decorator which allows a function to accept arguments which subclass ImplicitArray, possibly - including further ImplicitArray instances as children. - Any number of arguments (including 0) may be ImplicitArrays. + A decorator that enables compact handling of ImplicitArray subclasses within a function. + This allows for seamless integration of custom array types in JAX operations. + + This decorator can be used in combination with jax.jit for optimized execution. Args: f: The function to be decorated. Returns: - A wrapped function that can handle ImplicitArray arguments. + A wrapped function that can handle both regular arrays and ImplicitArray instances. + + Example: + >>> @jax.jit + >>> @implicit_compact + >>> def f(a, b): + ... return jnp.dot(a, b) + + >>> result = f(regular_array, regular_array) + >>> implicit_result = f(implicit_array, implicit_or_normal_array) """ @wraps(f) @@ -544,7 +553,7 @@ class ImplicitArray(_ImplicitArrayBase): Subclasses must implement the materialize method, which defines the relationship between the implicit array and the value it represents. Subclasses are valid - arguments to functions decorated with qax.use_implicit_args. + arguments to functions decorated with core.implicit_compact. The represented shape and dtype may be defined in various ways: 1. Explicitly passing shape/dtype keyword arguments at initialization @@ -665,7 +674,7 @@ def handle_primitive(self, primitive, *args, params): flat_args, in_tree = flatten_one_implicit_layer((args, params)) flat_handler, out_tree = flatten_fun(handler, in_tree) - result = use_implicit_args(flat_handler.call_wrapped)(*flat_args) + result = implicit_compact(flat_handler.call_wrapped)(*flat_args) return jax.tree_util.tree_unflatten(out_tree(), result) def __init_subclass__(cls, commute_ops=True, warn_on_materialize=True, **kwargs): @@ -761,7 +770,7 @@ def wrap_jaxpr(jaxpr, vals_with_implicits, return_closed=True): else: literals = [] - wrapped_fn = lu.wrap_init(use_implicit_args(partial(core.eval_jaxpr, jaxpr))) + wrapped_fn = lu.wrap_init(implicit_compact(partial(core.eval_jaxpr, jaxpr))) flat_args, in_tree = jax.tree_util.tree_flatten((literals, *vals_with_implicits)) flat_fn, out_tree = flatten_fun_nokwargs(wrapped_fn, in_tree) @@ -779,7 +788,7 @@ def wrap_jaxpr(jaxpr, vals_with_implicits, return_closed=True): def _transform_jaxpr_output(jaxpr, jaxpr_args, orig_out_struct, out_transform): def eval_fn(literals, *args): - output = use_implicit_args(partial(core.eval_jaxpr, jaxpr.jaxpr))( + output = implicit_compact(partial(core.eval_jaxpr, jaxpr.jaxpr))( literals, *args ) unflattened_output = orig_out_struct.unflatten(output) @@ -877,11 +886,7 @@ def _handle_scan(primitive, *vals, params): xs = vals[n_consts + n_carry :] if any(isinstance(c, ImplicitArray) for c in carries): - warnings.warn( - "ImplicitArray in scan carries are not yet supported." - " If you need this feature please open an issue on the Qax repo:" - " https://github.com/davisyoshida/qax/issues" - ) + warnings.warn("Not Supported Yet.") carries = _materialize_all(carries) sliced_xs = jax.tree_map(partial(jax.eval_shape, lambda x: x[0]), xs) @@ -921,7 +926,7 @@ def _handle_scan(primitive, *vals, params): def materialize_handler(primitive, *vals, params): vals = _materialize_all(vals) subfuns, bind_params = primitive.get_bind_params(params) - result = use_implicit_args(primitive.bind)(*subfuns, *vals, **bind_params) + result = implicit_compact(primitive.bind)(*subfuns, *vals, **bind_params) return result diff --git a/src/fjformer/core/symbols.py b/src/fjformer/core/symbols.py index ca03f03..18cd722 100644 --- a/src/fjformer/core/symbols.py +++ b/src/fjformer/core/symbols.py @@ -21,7 +21,7 @@ aux_field, default_handler, primitive_handler, - use_implicit_args, + implicit_compact, ) from fjformer.core.types import Complement @@ -198,12 +198,12 @@ def copy(self) -> "SymbolicConstant": raise OperationError(f"Failed to copy SymbolicConstant: {str(e)}") -@use_implicit_args +@implicit_compact def broadcast_to(val, shape): return jnp.broadcast_to(val, shape) -@use_implicit_args +@implicit_compact def astype(val, dtype): return val.astype(dtype) @@ -233,7 +233,7 @@ def _op_and_reshape(primitive, lhs, rhs, flip=False): if flip: lhs, rhs = (rhs, lhs) - @use_implicit_args + @implicit_compact def inner(arg): other = lhs if flip: diff --git a/src/fjformer/core/utilities.py b/src/fjformer/core/utilities.py index 6d5a359..829b0cc 100644 --- a/src/fjformer/core/utilities.py +++ b/src/fjformer/core/utilities.py @@ -7,7 +7,7 @@ from jax import tree_util from jax.dtypes import float0 -from fjformer.core.implicit_array import use_implicit_args +from fjformer.core.implicit_array import implicit_compact from fjformer.core.symbols import SymbolicConstant @@ -187,7 +187,7 @@ def apply_updates(params: optax.Params, updates: optax.Updates) -> optax.Params: ) semi_flat_params = update_struct.flatten_up_to(params) - updated_flat = use_implicit_args(optax.apply_updates)( + updated_flat = implicit_compact(optax.apply_updates)( semi_flat_params, updates_flat ) updated = update_struct.unflatten(updated_flat) diff --git a/src/fjformer/custom_array/quantized_8bit.py b/src/fjformer/custom_array/quantized_8bit.py new file mode 100644 index 0000000..4b4a6c0 --- /dev/null +++ b/src/fjformer/custom_array/quantized_8bit.py @@ -0,0 +1,422 @@ +from dataclasses import dataclass +from typing import Optional, Sequence, Union + +import fjformer.core as core +from jax import Array, lax +from jax import numpy as jnp + + +@dataclass +class Array8Bit(core.ImplicitArray): + """ + Custom 8-bit Quantized Array for efficient manipulation of JAX arrays. + + This class provides methods for quantizing and dequantizing JAX arrays to 8-bit + representation, which can significantly reduce memory usage and potentially + improve computation speed for certain operations. + + Attributes: + array_quantized (core.ArrayValue): The quantized array data. + scale (core.ArrayValue): Scaling factors for dequantization. + min_vals (core.ArrayValue): Minimum values used in quantization. + shape (tuple): Shape of the quantized array. + dtype (jnp.dtype): Original dtype of the array before quantization. + + Example: + >>> import jax + >>> import jax.numpy as jnp + + >>> x = jax.random.normal(jax.random.key(0), (512, 64), dtype=jnp.float32) + >>> xp = jax.random.normal(jax.random.key(1), (64, 256), dtype=jnp.float32) + + >>> quantized_x = Array8Bit.quantize(x) + >>> quantized_xp = Array8Bit.quantize(xp) + + >>> @jax.jit + >>> @core.implicit_compact + >>> def f(a, b): + ... return jnp.dot(a, b) + + >>> result = f(x, xp) + >>> q_result = f(quantized_x, quantized_xp) + + >>> print(jnp.allclose(result, q_result, rtol=1e-2, atol=1e-2)) + True + """ + + array_quantized: core.ArrayValue + scale: core.ArrayValue + min_vals: core.ArrayValue + + def materialize(self) -> Array: + """ + Materialize the quantized array back to its original representation. + + Returns: + Array: The dequantized array in its original dtype. + """ + return self.dequantize( + array_quantized=self.array_quantized, + scale=self.scale, + min_vals=self.min_vals, + float_dtype=self.dtype, + ) + + @classmethod + def quantize( + cls, array: Array, axis: int = -1, dtype: Optional[jnp.dtype] = None + ) -> "Array8Bit": + """ + Quantize a JAX array to 8-bit representation. + + Args: + array (Array): The input array to quantize. + axis (int, optional): The axis along which to compute min and max. Defaults to -1. + dtype (jnp.dtype, optional): The desired dtype for the output. If None, uses the input array's dtype. + + Returns: + Array8Bit: The quantized array. + """ + min_vals = jnp.min(array, axis=axis, keepdims=True) + max_vals = jnp.max(array, axis=axis, keepdims=True) + + # Compute the scaling factors + scale = (max_vals - min_vals) / 255 + + # Quantize the data + quantized_data = jnp.round((array - min_vals) / scale) + + # Clip the quantized values to ensure they lie within the representable range + quantized_data = jnp.clip(quantized_data, 0, 255).astype(jnp.uint8) + + return cls( + array_quantized=quantized_data, + scale=scale, + min_vals=min_vals, + shape=quantized_data.shape, + dtype=dtype or array.dtype, + ) + + @staticmethod + def dequantize( + array_quantized: Array, + scale: Array, + min_vals: Array, + float_dtype: jnp.dtype, + ) -> Array: + """ + Dequantize an 8-bit array back to its original representation. + + Args: + array_quantized (Array): The quantized array data. + scale (Array): The scaling factors used in quantization. + min_vals (Array): The minimum values used in quantization. + float_dtype (jnp.dtype): The desired output dtype. + + Returns: + Array: The dequantized array. + """ + return (array_quantized.astype(float_dtype) * scale + min_vals).astype( + float_dtype + ) + + def __getitem__(self, idx: Union[int, slice, tuple]) -> Array: + """ + Enable indexing of the quantized array. + + Args: + idx (Union[int, slice, tuple]): The index or slice to access. + + Returns: + Array: The dequantized slice of the array. + """ + quantized_slice = self.array_quantized[idx] + scale_slice = self.scale[idx] if self.scale.ndim > 0 else self.scale + min_vals_slice = self.min_vals[idx] if self.min_vals.ndim > 0 else self.min_vals + return self.dequantize(quantized_slice, scale_slice, min_vals_slice, self.dtype) + + def __repr__(self) -> str: + return f"Array8Bit(shape={self.shape}, dtype={self.dtype})" + + @property + def nbytes(self) -> int: + """ + Calculate the total number of bytes used by the quantized representation. + + Returns: + int: The number of bytes used. + """ + return self.array_quantized.nbytes + self.scale.nbytes + self.min_vals.nbytes + + def memory_savings(self) -> float: + """ + Calculate the memory savings compared to the original array. + + Returns: + float: The percentage of memory saved. + """ + original_size = jnp.prod(jnp.array(self.shape)) * jnp.dtype(self.dtype).itemsize + return (1 - self.nbytes / original_size) * 100 + + +ArrayType = Union[Array, Array8Bit] + + +@core.primitive_handler("dot_general") +def handle_dot_general(primitive, lhs: ArrayType, rhs: ArrayType, *args, **kwargs): + """ + Custom handler for JAX's dot_general operation. + + Materializes Array8Bit inputs before performing the operation. + + Args: + primitive: The JAX primitive being handled. + lhs (ArrayType): Left-hand side array. + rhs (ArrayType): Right-hand side array. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of lax.dot_general operation. + """ + if isinstance(lhs, Array8Bit): + lhs = lhs.materialize() + if isinstance(rhs, Array8Bit): + rhs = rhs.materialize() + return lax.dot_general(lhs=lhs, rhs=rhs, *args, **kwargs) + + +@core.primitive_handler("add") +def handle_add(primitive, x: ArrayType, y: ArrayType): + """ + Custom handler for JAX's add operation. + + Materializes Array8Bit inputs before performing the operation. + + Args: + primitive: The JAX primitive being handled. + x (ArrayType): First array to add. + y (ArrayType): Second array to add. + + Returns: + The result of lax.add operation. + """ + if isinstance(x, Array8Bit): + x = x.materialize() + if isinstance(y, Array8Bit): + y = y.materialize() + return lax.add(x, y) + + +@core.primitive_handler("reduce") +def handle_reduce( + primitive, operand: ArrayType, init_value: ArrayType, *args, **kwargs +): + """ + Custom handler for JAX's reduce operation. + + Materializes Array8Bit inputs before performing the operation. + + Args: + primitive: The JAX primitive being handled. + operand (ArrayType): The array to be reduced. + init_value (ArrayType): The initial value for the reduction. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of lax.reduce operation. + """ + if isinstance(operand, Array8Bit): + operand = operand.materialize() + if isinstance(init_value, Array8Bit): + init_value = init_value.materialize() + return lax.reduce(operand, init_value, *args, **kwargs) + + +@core.primitive_handler("mul") +def handle_mul(primitive, x: ArrayType, y: ArrayType): + """ + Custom handler for JAX's mul operation. + + Materializes Array8Bit inputs before performing the operation. + + Args: + primitive: The JAX primitive being handled. + x (ArrayType): First array to multiply. + y (ArrayType): Second array to multiply. + + Returns: + The result of lax.mul operation. + """ + if isinstance(x, Array8Bit): + x = x.materialize() + if isinstance(y, Array8Bit): + y = y.materialize() + return lax.mul(x, y) + + +@core.primitive_handler("transpose") +def handle_transpose(primitive, operand: ArrayType, *args, **kwargs): + """ + Custom handler for JAX's transpose operation. + + Materializes Array8Bit input before performing the operation. + Re-quantizes the result if the input was Array8Bit. + + Args: + primitive: The JAX primitive being handled. + operand (ArrayType): The array to be transposed. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of lax.transpose operation, potentially re-quantized. + """ + _org_8 = False + if isinstance(operand, Array8Bit): + operand = operand.materialize() + _org_8 = True + operand = lax.transpose(operand, *args, **kwargs) + if _org_8: + operand = Array8Bit.quantize(operand, dtype=operand.dtype) + return operand + + +@core.primitive_handler("conv_general_dilated") +def handle_conv(primitive, lhs: ArrayType, rhs: ArrayType, *args, **kwargs): + """ + Custom handler for JAX's conv_general_dilated operation. + + Materializes Array8Bit inputs before performing the operation. + + Args: + primitive: The JAX primitive being handled. + lhs (ArrayType): Left-hand side array (input). + rhs (ArrayType): Right-hand side array (kernel). + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of lax.conv operation. + """ + if isinstance(lhs, Array8Bit): + lhs = lhs.materialize() + if isinstance(rhs, Array8Bit): + rhs = rhs.materialize() + return lax.conv_general_dilated(lhs, rhs, *args, **kwargs) + + +@core.primitive_handler("max") +def handle_max(primitive, x: ArrayType, y: ArrayType, *args, **kwargs): + """ + Custom handler for JAX's max operation. + + Materializes Array8Bit inputs before performing the operation. + + Args: + primitive: The JAX primitive being handled. + x (ArrayType): First array for max comparison. + y (ArrayType): Second array for max comparison. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of lax.max operation. + """ + if isinstance(x, Array8Bit): + x = x.materialize() + if isinstance(y, Array8Bit): + y = y.materialize() + return lax.max(x, y, *args, **kwargs) + + +@core.primitive_handler("exp") +def handle_exp(primitive, x: ArrayType, *args, **kwargs): + """ + Custom handler for JAX's exp operation. + + Materializes Array8Bit input before performing the operation. + + Args: + primitive: The JAX primitive being handled. + x (ArrayType): The array to apply exponential to. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of lax.exp operation. + """ + if isinstance(x, Array8Bit): + x = x.materialize() + return lax.exp(x, *args, **kwargs) + + +@core.primitive_handler("log") +def handle_log(primitive, x: ArrayType, *args, **kwargs): + """ + Custom handler for JAX's log operation. + + Materializes Array8Bit input before performing the operation. + + Args: + primitive: The JAX primitive being handled. + x (ArrayType): The array to apply logarithm to. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of lax.log operation. + """ + if isinstance(x, Array8Bit): + x = x.materialize() + return lax.log(x, *args, **kwargs) + + +@core.primitive_handler("reshape") +def handle_reshape(primitive, operand: ArrayType, *args, **kwargs): + """ + Custom handler for JAX's reshape operation. + + Materializes Array8Bit input before performing the operation. + Re-quantizes the result if the input was Array8Bit. + + Args: + primitive: The JAX primitive being handled. + operand (ArrayType): The array to be reshaped. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of lax.reshape operation, potentially re-quantized. + """ + _org_8 = False + if isinstance(operand, Array8Bit): + operand = operand.materialize() + _org_8 = True + operand = lax.reshape(operand, *args, **kwargs) + if _org_8: + operand = Array8Bit.quantize(operand, dtype=operand.dtype) + return operand + + +@core.primitive_handler("concatenate") +def handle_concatenate(primitive, operands: Sequence[ArrayType], *args, **kwargs): + """ + Custom handler for JAX's concatenate operation. + + Materializes Array8Bit inputs before performing the operation. + + Args: + primitive: The JAX primitive being handled. + operands (Sequence[ArrayType]): Sequence of arrays to concatenate. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of lax.concatenate operation. + """ + materialized_operands = [ + op.materialize() if isinstance(op, Array8Bit) else op for op in operands + ] + return lax.concatenate(materialized_operands, *args, **kwargs) diff --git a/src/fjformer/lora/lora_core.py b/src/fjformer/lora/lora_core.py index 591f4e6..7917323 100644 --- a/src/fjformer/lora/lora_core.py +++ b/src/fjformer/lora/lora_core.py @@ -24,7 +24,7 @@ class UnsupportedOperationError(LoraError): def lora(f: Any) -> Any: """Decorator for LoRA-compatible functions.""" - return cr.use_implicit_args(f) + return cr.implicit_compact(f) @dataclass diff --git a/src/fjformer/lora/rapture.py b/src/fjformer/lora/rapture.py index f27267e..cf570be 100644 --- a/src/fjformer/lora/rapture.py +++ b/src/fjformer/lora/rapture.py @@ -11,7 +11,7 @@ EmptyNode, materialize_nested, tree_map_with_implicit, - use_implicit_args, + implicit_compact, ) from fjformer.core.utilities import freeze_subtrees, freeze_keys from fjformer.lora.lora_core import LoraWeight @@ -432,7 +432,7 @@ def apply_lora( ) tx = self.wrap_tx(tx=tx, lora_spec=lora_spec) opt_state = tx.init(lora_parameters) - lora_model = use_implicit_args(module) + lora_model = implicit_compact(module) return RaptureModule( lora_opt_state=opt_state,