diff --git a/build_tools/configure.py b/build_tools/configure.py
index 70691e5..b104b45 100644
--- a/build_tools/configure.py
+++ b/build_tools/configure.py
@@ -20,15 +20,18 @@ import time
 
 parser = argparse.ArgumentParser(description='Determine setup options.')
-parser.add_argument('--use-pinned', default=False, action=argparse.BooleanOptionalAction)
+parser.add_argument('--use-pinned',
+                    default=False,
+                    action=argparse.BooleanOptionalAction)
 parser.add_argument('--add-version')
-
 '''Sets values in version_info.json

Sets values related to the library's version numbers, the pinned versions, and
whether the pinned versions should be exact or a lower bound. This is used for
configuring on demand for bumping or validating pinned versions.
 '''
+
+
 def main():
   project_root = os.path.dirname(os.path.dirname(__file__))
@@ -52,9 +55,9 @@ def main():
   print(json.dumps(version_info, indent=2))
 
   with open(os.path.join(project_root, "version_info.json"), 'w') as f:
-      f.write(json.dumps(version_info, indent=2))
-      f.write("\n")
+    f.write(json.dumps(version_info, indent=2))
+    f.write("\n")
+
 
 if __name__ == "__main__":
   main()
-
diff --git a/build_tools/pip_install.py b/build_tools/pip_install.py
index 045cedb..c553eed 100644
--- a/build_tools/pip_install.py
+++ b/build_tools/pip_install.py
@@ -17,6 +17,7 @@ import subprocess
 import sys
 
+
 def main():
   project_root = os.path.dirname(os.path.dirname(__file__))
@@ -39,5 +40,6 @@ def main():
   _, err = proc.communicate()
   print(err)
 
+
 if __name__ == "__main__":
-    main()
+  main()
diff --git a/build_tools/update_pinned.py b/build_tools/update_pinned.py
index 1f99fca..c3fd626 100644
--- a/build_tools/update_pinned.py
+++ b/build_tools/update_pinned.py
@@ -17,12 +17,13 @@ import subprocess
 import sys
 from subprocess import check_output
 
-
 '''Update the version_info.json with the currently installed libraries.'''
+
+
 def main():
   reqs = subprocess.check_output([sys.executable, '-m', 'pip', 'freeze'])
   reqs = [c.split("==") for c in reqs.decode('ascii').split('\n') if "==" in c]
-  reqs = {a : b for a, b in reqs }
+  reqs = {a: b for a, b in reqs}
 
   # Load version_info.json.
   project_root = os.path.dirname(os.path.dirname(__file__))
@@ -36,15 +37,18 @@ def main():
   pinned = version_info["pinned-versions"]
 
   for pinned_lib in pinned:
-      if pinned_lib in reqs:
-          pinned[pinned_lib] = reqs[pinned_lib]
+    if pinned_lib in reqs:
+      pinned[pinned_lib] = reqs[pinned_lib]
 
-  version_info = {a : version_info[a] for a in version_info if a == "pinned-versions"}
+  version_info = {
+      a: version_info[a] for a in version_info if a == "pinned-versions"
+  }
 
   with open(os.path.join(project_root, "version_info.json"), 'w') as f:
-      f.write(json.dumps(version_info, indent=2))
-      f.write("\n")
+    f.write(json.dumps(version_info, indent=2))
+    f.write("\n")
 
   print(pinned["iree-compiler"])
 
-if __name__ == "__main__":
+
+if __name__ == "__main__":
   main()
diff --git a/examples/aqt_dense_simulated.py b/examples/aqt_dense_simulated.py
index 7ba6b49..debb0b8 100644
--- a/examples/aqt_dense_simulated.py
+++ b/examples/aqt_dense_simulated.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
- """Test compiling and executing a basic AQT MatMul with IREE.""" from collections import namedtuple @@ -28,18 +27,19 @@ Params = namedtuple("Params", "weights,bias,activation_scale") params = [ - Params( - weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 0.001, - bias=jnp.arange(3, dtype=jnp.float32) * 10.0, - activation_scale=jnp.array(5.0), - ), - Params( - weights=jnp.arange(27, dtype=jnp.float32).reshape(3, 9) * 0.01, - bias=jnp.arange(9, dtype=jnp.float32) * 3.0, - activation_scale=jnp.array(5.0), - ), + Params( + weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 0.001, + bias=jnp.arange(3, dtype=jnp.float32) * 10.0, + activation_scale=jnp.array(5.0), + ), + Params( + weights=jnp.arange(27, dtype=jnp.float32).reshape(3, 9) * 0.01, + bias=jnp.arange(9, dtype=jnp.float32) * 3.0, + activation_scale=jnp.array(5.0), + ), ] + def dense(params, activation): precision = 8 lower_bound = -2**(precision - 1) + 1 @@ -57,6 +57,7 @@ def dense(params, activation): matmul_result = scaled_result / (params.activation_scale * weight_scale) return matmul_result + params.bias[jnp.newaxis, :] + class AqtDenseModule(Program): _params = params diff --git a/examples/aqt_matmul_native.py b/examples/aqt_matmul_native.py index e9f4a38..43b1c9c 100644 --- a/examples/aqt_matmul_native.py +++ b/examples/aqt_matmul_native.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Test compiling and executing a basic AQT MatMul with IREE.""" from collections import namedtuple @@ -28,10 +27,11 @@ Params = namedtuple("Params", "weights,activation_scale") params = Params( - weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 500.3, - activation_scale=jnp.array(5.0), + weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 500.3, + activation_scale=jnp.array(5.0), ) + class AqtMatmulModule(Program): _params = params @@ -52,8 +52,9 @@ def aqt_matmul_native(params, activation): weight_rounded = jnp.floor(weight_scaled + jnp.array(0.5)) weight_as_int = weight_rounded.astype(jnp.int8) - scaled_result = jax.lax.dot( - activation_as_int, weight_as_int, preferred_element_type=jnp.int32) + scaled_result = jax.lax.dot(activation_as_int, + weight_as_int, + preferred_element_type=jnp.int32) return scaled_result / (params.activation_scale * weight_scale) def compute_native(mdl, activation=like(activation_example)): diff --git a/examples/aqt_matmul_simulated.py b/examples/aqt_matmul_simulated.py index d7281e8..46fb285 100644 --- a/examples/aqt_matmul_simulated.py +++ b/examples/aqt_matmul_simulated.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Test compiling and executing a basic AQT MatMul with IREE.""" from collections import namedtuple @@ -28,10 +27,11 @@ Params = namedtuple("Params", "weights,activation_scale") params = Params( - weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 500.3, - activation_scale=jnp.array(5.0), + weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 500.3, + activation_scale=jnp.array(5.0), ) + class AqtMatmulModule(Program): _params = params diff --git a/examples/datasets.py b/examples/datasets.py index c8a7994..cc85982 100644 --- a/examples/datasets.py +++ b/examples/datasets.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Datasets used in examples.""" - import array import gzip import os @@ -24,7 +22,6 @@ import numpy as np - _DATA = "/tmp/jax_example_data/" @@ -64,8 +61,10 @@ def parse_images(filename): return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape(num_data, rows, cols) - for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", - "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]: + for filename in [ + "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz" + ]: _download(base_url + filename, filename) train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) diff --git a/examples/mnist_export.py b/examples/mnist_export.py index d2dfa98..374a6ca 100644 --- a/examples/mnist_export.py +++ b/examples/mnist_export.py @@ -35,10 +35,10 @@ register_pytree_node) from iree.jax import ( - like, - kernel, - IREE, - Program, + like, + kernel, + IREE, + Program, ) @@ -58,6 +58,7 @@ def main(args): with open(os.path.join(output_dir, "mnist_train.vmfb"), "wb") as f: f.write(binary.compiled_artifact) + def build_model(): init_random_params, predict = stax.serial( Dense(1024), diff --git a/examples/run_trainer.py b/examples/run_trainer.py index 33c17a5..e81d634 100644 --- a/examples/run_trainer.py +++ b/examples/run_trainer.py @@ -73,6 +73,7 @@ def get_examples(): num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, batch_size) print(f"Number of batches in dataset: {num_complete_batches}") + def data_stream(): rng = npr.RandomState(0) while True: diff --git a/iree/jax/builtins.py b/iree/jax/builtins.py index bed4e92..25769fc 100644 --- a/iree/jax/builtins.py +++ b/iree/jax/builtins.py @@ -48,7 +48,8 @@ class jit_kernel(tracing.CallableIntrinsic): def __init__(self, wrapped_f, *, wrap_with_jit: bool = True): self.wrapped_f = wrapped_f - self.jit_f = jax.jit(self.wrapped_f, backend="iree") if wrap_with_jit else self.wrapped_f + self.jit_f = jax.jit(self.wrapped_f, + backend="iree") if wrap_with_jit else self.wrapped_f def __repr__(self): return f"" diff --git a/iree/jax/exporter.py b/iree/jax/exporter.py index 5e199bf..38be678 100644 --- a/iree/jax/exporter.py +++ b/iree/jax/exporter.py @@ -143,7 +143,8 @@ def def_global_tree(self, # We fork between trackable things and static constants. Currently this # is just array vs not, but this should match Jax's heuristic. # TODO: Make sure this is the right way to detect array. 
-      if isinstance(concrete_leaf, jax.core.ShapedArray) or hasattr(concrete_leaf, "__array__"):
+      if isinstance(concrete_leaf, jax.core.ShapedArray) or hasattr(
+          concrete_leaf, "__array__"):
         leaf_symbol = f"{symbol_name}${tracked_leaf_count}"
         logger.debug("def_global_tree: array %s=%r:%r", leaf_symbol,
                      concrete_leaf.shape, concrete_leaf.dtype)
diff --git a/iree/jax/frontend.py b/iree/jax/frontend.py
index 6e38e27..4617a94 100644
--- a/iree/jax/frontend.py
+++ b/iree/jax/frontend.py
@@ -65,9 +65,8 @@ def aot(function, *args, **options):
   """
   xla_comp = jax.xla_computation(function)(*args)
   hlo_proto = xla_comp.as_serialized_hlo_module_proto()
-  return iree.compiler.tools.xla.compile_str(hlo_proto,
-                                             input_type=iree.compiler.InputType.XLA,
-                                             **options)
+  return iree.compiler.tools.xla.compile_str(
+      hlo_proto, input_type=iree.compiler.InputType.XLA, **options)
 
 
 # A more JAX-native approach to jitting would be desirable here, however
diff --git a/iree/jax/iree.py b/iree/jax/iree.py
index 66418c3..e006be6 100644
--- a/iree/jax/iree.py
+++ b/iree/jax/iree.py
@@ -21,6 +21,7 @@ import iree.runtime
 
 _config_cache: Dict[str, iree.runtime.system_api.Config] = dict()
 
+
 def get_rt_config(driver_name: str):
   driver = _config_cache.get(driver_name)
   if driver is None:
@@ -35,8 +36,10 @@ def get_rt_config(driver_name: str):
     "IREE",
 ]
 
+
 class IREE:
-  def __init__(self, program: Program, backends : List[str], runtimes : str):
+
+  def __init__(self, program: Program, backends: List[str], runtimes: str):
     self._program = program
     self._backends = backends
     self._runtime = runtimes
@@ -47,20 +50,20 @@ def __init__(self, program: Program, backends : List[str], runtimes : str):
     pass
 
   @staticmethod
-  def compile_program(
-      program: Program,
-      backends : List[str] = ["llvm-cpu"] ,
-      runtime : str = "local-task"):
+  def compile_program(program: Program,
+                      backends: List[str] = ["llvm-cpu"],
+                      runtime: str = "local-task"):
     try:
-        iree.compiler
+      iree.compiler
     except NameError:
-        raise Exception("iree.compiler library is required for binary compilation")
+      raise Exception(
+          "iree.compiler library is required for binary compilation")
 
     try:
-        iree.runtime
+      iree.runtime
     except NameError:
-        raise Exception("iree.runtime library is required for binary compilation")
+      raise Exception("iree.runtime library is required for binary compilation")
 
     binary = IREE(program, backends, runtime)
     binary.compiled_artifact
@@ -75,7 +78,7 @@ def compiled_artifact(self):
       ir_module.operation.write_bytecode(file=output)
       bytecode = output.getvalue()
       self._compiled_artifact = iree.compiler.tools.compile_str(
-        bytecode, target_backends=self._backends, input_type="mhlo")
+          bytecode, target_backends=self._backends, input_type="mhlo")
     return self._compiled_artifact
 
@@ -83,8 +86,10 @@ def compiled_artifact(self):
   def runtime_module(self):
     if not self._runtime_module:
       rt_config = get_rt_config(self._runtime)
-      vm_module = iree.runtime.VmModule.from_flatbuffer(self._instance, self.compiled_artifact)
-      self._runtime_module = iree.runtime.system_api.load_vm_module(vm_module, rt_config)
+      vm_module = iree.runtime.VmModule.from_flatbuffer(self._instance,
+                                                        self.compiled_artifact)
+      self._runtime_module = iree.runtime.system_api.load_vm_module(
+          vm_module, rt_config)
 
       info = Program.get_info(Program._get_instance(self._program))
       for fun, _ in info.class_info.export_functions:
@@ -92,14 +97,12 @@ def runtime_module(self):
 
     return self._runtime_module
 
-
   def __getattr__(self, name):
     try:
       return self._shadow_dict[name]
     except KeyError as e:
      raise AttributeError(f"Attribute {name} not defined") from e
AttributeError(f"Attribute {name} not defined") from e - def _create_runtime_trampoline(self, exported_function_name): """Creates a runtime trampoline function for the given exported function.""" diff --git a/iree/jax/jax_utils.py b/iree/jax/jax_utils.py index 3dc7425..a0d3190 100644 --- a/iree/jax/jax_utils.py +++ b/iree/jax/jax_utils.py @@ -18,8 +18,7 @@ from . import array_types from jaxlib.mlir import ( - ir, -) + ir,) import jax.core import jax.interpreters.mlir @@ -64,7 +63,8 @@ def abstractify(x) -> jax.core.AbstractValue: def unwrap_global_array(x) -> Optional[array_types.ExportedGlobalArray]: # TODO: Ugh. Ugh. - if isinstance(x, jax.core.ConcreteArray) or isinstance(x, jax.core.ShapedArray): + if isinstance(x, jax.core.ConcreteArray) or isinstance( + x, jax.core.ShapedArray): x = x.val if not isinstance(x, array_types.ExportedGlobalArray): return None diff --git a/iree/jax/program_api.py b/iree/jax/program_api.py index d9ad577..b1106c8 100644 --- a/iree/jax/program_api.py +++ b/iree/jax/program_api.py @@ -257,6 +257,7 @@ def __init__(self, class_info: ProgramClassInfo, # Program instance itself arbitrates access via getattr/setattr. self.shadow_dict = dict() + ################################################################################ # Use weak references to track info objects for program classes and instances ################################################################################ diff --git a/iree/jax/tracing.py b/iree/jax/tracing.py index 5ed89ff..10df45c 100644 --- a/iree/jax/tracing.py +++ b/iree/jax/tracing.py @@ -19,8 +19,7 @@ from jaxlib.mlir import ( ir,) from jaxlib.mlir.dialects import ( - func as func_d, -) + func as func_d,) import jax.core from jax.tree_util import (tree_map, tree_flatten, tree_unflatten) diff --git a/iree/jax/utils.py b/iree/jax/utils.py index c9afa7c..a111ef6 100644 --- a/iree/jax/utils.py +++ b/iree/jax/utils.py @@ -13,5 +13,3 @@ # limitations under the License. import weakref - - diff --git a/models/gpt2/config.py b/models/gpt2/config.py index 94d7bd1..94856ae 100644 --- a/models/gpt2/config.py +++ b/models/gpt2/config.py @@ -11,11 +11,12 @@ flags.DEFINE_string('ir_path', '/tmp/gpt2.mlir', 'Path for IR') flags.DEFINE_string('assets_path', assets_dir, 'Path for assets dir') + # Create a tuple with model configuration details as follows: # B - batch size # K - encoder sequence length # S - total sequence length # T - decode step size def get_config(): - config = collections.namedtuple('Config', ['B', 'K', 'S', 'T']) - return config(4, 8, 64, 1) \ No newline at end of file + config = collections.namedtuple('Config', ['B', 'K', 'S', 'T']) + return config(4, 8, 64, 1) diff --git a/models/gpt2/export.py b/models/gpt2/export.py index 61dedd5..6a606bf 100644 --- a/models/gpt2/export.py +++ b/models/gpt2/export.py @@ -18,6 +18,7 @@ FLAGS = absl.flags.FLAGS + # Configuration details. 
 # B - batch size
 # K - encoder sequence length
@@ -26,7 +27,7 @@ def CreateGpt2Model(name, B, K, S, T):
   L, _, _, Q, H, _ = model.model_sizes[name]
 
-  prompt_type = ShapedArray((B,K), dtype=jnp.int32)
+  prompt_type = ShapedArray((B, K), dtype=jnp.int32)
   t_type = ShapedArray((B,), dtype=jnp.int32)
   x_type = ShapedArray((B, T), dtype=jnp.int32)
   kv_type = model.init_kv(B, S, L, Q, H, dtype=jnp.float32, abstract=True)
@@ -75,23 +76,24 @@ def decode(self):
 
   return Gpt2Module
 
+
 def main(argv):
   cfg = config.get_config()
-  B = cfg.B # Batch size
-  K = cfg.K # Input sequence length
+  B = cfg.B  # Batch size
+  K = cfg.K  # Input sequence length
   S = cfg.S
-  T = cfg.T # Batched decode
+  T = cfg.T  # Batched decode
 
   module = CreateGpt2Model("gpt2", B, K, S, T)
 
   with open(FLAGS.ir_path, 'w') as f:
     f.write(str(Program.get_mlir_module(module)))
 
-  compiler.compile_file(
-      FLAGS.ir_path,
-      input_type="mhlo",
-      output_file=FLAGS.binary_path,
-      target_backends=["llvm-cpu"])
+  compiler.compile_file(FLAGS.ir_path,
+                        input_type="mhlo",
+                        output_file=FLAGS.binary_path,
+                        target_backends=["llvm-cpu"])
+
 
 if __name__ == '__main__':
   absl.app.run(main)
diff --git a/models/gpt2/model.py b/models/gpt2/model.py
index 3d36e69..f0ef282 100644
--- a/models/gpt2/model.py
+++ b/models/gpt2/model.py
@@ -40,12 +40,9 @@ def load_gpt2_model(name, gpt2_dir):
     w_i_bias = np.asarray(layer_root['mlp']['c_fc']['bias:0'])
     w_o = np.asarray(layer_root['mlp']['c_proj']['weight:0'])
     w_o_bias = np.asarray(layer_root['mlp']['c_proj']['bias:0'])
-    layer_params.append(((xnorm_scale, xnorm_bias),
-                         (wqkv, wqkv_bias),
-                         (wo, wo_bias),
-                         (ynorm_scale, ynorm_bias),
-                         (w_i, w_i_bias),
-                         (w_o, w_o_bias)))
+    layer_params.append(
+        ((xnorm_scale, xnorm_bias), (wqkv, wqkv_bias), (wo, wo_bias),
+         (ynorm_scale, ynorm_bias), (w_i, w_i_bias), (w_o, w_o_bias)))
     params.append(layer_params)
   fnorm_scale = np.asarray(root['ln_f']['gamma:0'])
   fnorm_bias = np.asarray(root['ln_f']['beta:0'])
@@ -67,12 +64,8 @@ def init_layer(E, F, Q, H, dtype):
   w_i_bias = jnp.ones((F,), dtype=dtype)
   w_o = jnp.ones((F, E), dtype=dtype)
   w_o_bias = jnp.ones((E,), dtype=dtype)
-  return ((xnorm_scale, xnorm_bias),
-          (wqkv, wqkv_bias),
-          (wo, wo_bias),
-          (ynorm_scale, ynorm_bias),
-          (w_i, w_i_bias),
-          (w_o, w_o_bias))
+  return ((xnorm_scale, xnorm_bias), (wqkv, wqkv_bias), (wo, wo_bias),
+          (ynorm_scale, ynorm_bias), (w_i, w_i_bias), (w_o, w_o_bias))
 
 
 def init(L, E, F, Q, H, V, dtype):
@@ -86,19 +79,18 @@ def init(L, E, F, Q, H, V, dtype):
 
 def init_kv(B, S, L, Q, H, dtype, abstract=False):
   if abstract:
-    ret = [abstract_arrays.ShapedArray((2, B, S, H, Q), dtype=dtype) for l in range(L)]
+    ret = [
+        abstract_arrays.ShapedArray((2, B, S, H, Q), dtype=dtype)
+        for l in range(L)
+    ]
     return ret
   return [jnp.zeros((2, B, S, H, Q), dtype=dtype) for l in range(L)]
 
 
 def fprop_layer(params, kv, x, t0, i, mask):
   """Run a single transformer layer."""
-  ((xnorm_scale, xnorm_bias),
-   (wqkv, wqkv_bias),
-   (wo, wo_bias),
-   (ynorm_scale, ynorm_bias),
-   (w_i, w_i_bias),
-   (w_o, w_o_bias)) = params
+  ((xnorm_scale, xnorm_bias), (wqkv, wqkv_bias), (wo, wo_bias),
+   (ynorm_scale, ynorm_bias), (w_i, w_i_bias), (w_o, w_o_bias)) = params
   # x = with_sharding_constraint(x, x_sharding)
   xnorm = jax.nn.normalize(x) * xnorm_scale + xnorm_bias
   qkv = jnp.einsum('bte,ihqe->ibthq', xnorm, wqkv) + wqkv_bias[:, None, None]
@@ -109,11 +101,11 @@ def fprop_layer(params, kv, x, t0, i, mask):
     new_kv = mask[None, :, :, None, None] * new_kv
     q = q * mask[:, :, None, None]
     kv = jax.lax.dynamic_update_slice(kv, new_kv, [0, i, 0, 0, 0])
-    k, v = jax.lax.dynamic_slice(kv, [0, i, 0, 0, 0], [2, x.shape[0], *kv.shape[2:]])
+    k, v = jax.lax.dynamic_slice(kv, [0, i, 0, 0, 0],
+                                 [2, x.shape[0], *kv.shape[2:]])
   elif t0 is not None:
     # "decoding" a single timestep
-    kv = jax.vmap(jax.lax.dynamic_update_slice,
-                  (1, 1, [None, 0, None, None]),
+    kv = jax.vmap(jax.lax.dynamic_update_slice, (1, 1, [None, 0, None, None]),
                   1)(kv, new_kv, [0, t0, 0, 0])
     k, v = kv
   else:
@@ -123,12 +115,11 @@ def fprop_layer(params, kv, x, t0, i, mask):
                     jnp.sqrt(v.shape[-1]), dtype=x.dtype)
   # s refers to timestep attended to; t refers to timestep attending
   s = jnp.arange(outer.shape[2])[None, None, :]
-  t = (0 if t0 is None else t0[:, None, None]
-      ) + jnp.arange(outer.shape[1])[None, :, None]
+  t = (0 if t0 is None else t0[:, None, None]) + jnp.arange(
+      outer.shape[1])[None, :, None]
   if i is not None or t0 is not None:
     invalid = t < s
-    outer = outer - jnp.asarray(
-        jnp.inf, dtype=x.dtype) * invalid[:, :, :, None]
+    outer = outer - jnp.asarray(jnp.inf, dtype=x.dtype) * invalid[:, :, :, None]
   alpha = jax.nn.softmax(outer, 2)
   inner = jnp.einsum('btsh,bshq->bthq', alpha, v)
   y = jnp.einsum('bthq,hqe->bte', inner, wo) + wo_bias + x
@@ -141,21 +132,19 @@ def fprop_layer(params, kv, x, t0, i, mask):
 
 
 def embed(embedding, x):
-  return jax.vmap(jax.vmap(
-      lambda emb, inputs: emb[inputs],
-      (None, 0), 0), (None, 0), 0)(embedding, x)
+  return jax.vmap(jax.vmap(lambda emb, inputs: emb[inputs], (None, 0), 0),
+                  (None, 0), 0)(embedding, x)
 
 
 def fprop(params, kv, x, t0, i, mask):
   (wte, wpe, layer_params, (fnorm_scale, fnorm_bias)) = params
-  x = embed(wte, x) + embed(wpe, (0 if t0 is None else t0[:, None]
-                                  ) + jnp.arange(
-                                      x.shape[1], dtype=x.dtype)[None, :])
+  x = embed(wte, x) + embed(wpe, (0 if t0 is None else t0[:, None]) +
+                            jnp.arange(x.shape[1], dtype=x.dtype)[None, :])
   if mask is not None:
     x = jnp.where(mask[:, :, None], x, 0)
   for l in range(len(layer_params)):
-    kv[l], x = jax.named_call(fprop_layer, name=f'L_{l}')(
-        layer_params[l], kv[l], x, t0, i, mask)
+    kv[l], x = jax.named_call(fprop_layer, name=f'L_{l}')(layer_params[l],
+                                                          kv[l], x, t0, i, mask)
   x = jax.nn.normalize(x) * fnorm_scale + fnorm_bias
   return kv, x
@@ -172,13 +161,15 @@ def encode(params, kv, prompt, i, t):
   mask = jnp.where(iota < length, 1, 0)
 
   kv, y = fprop(params, kv, prompt, jnp.array([0], dtype=jnp.int32), i, mask)
-  y = y [jnp.arange(t.shape[0]), t-1, :][:, None, :]
+  y = y[jnp.arange(t.shape[0]), t - 1, :][:, None, :]
   return kv, greedy(y, params[0])
 
+
 @jax.jit
 def encode_batch(params, x):
   return fprop(params, [None] * len(params[2]), x, None, None, None)[1]
 
+
 @functools.partial(jax.jit)
 def decode(params, kv, x, t):
   _, T = x.shape
@@ -189,8 +180,8 @@ def decode(params, kv, x, t):
 
 # order is (L=length, E=embed, F=ffn, Q=qkv, H=heads, V=vocab)
 model_sizes = {
-    'gpt2': (12, 768, 4*768, 64, 768//64, 50257),
-    '355m': (24, 1024, 4*1024, 64, 1024//64, 51200),
-    'gpt2-xl': (48, 1600, 4*1600, 64, 1600//64, 50257),
-    '52b': (64, 8192, 4*8192, 64, 8192//64, 51200),
+    'gpt2': (12, 768, 4 * 768, 64, 768 // 64, 50257),
+    '355m': (24, 1024, 4 * 1024, 64, 1024 // 64, 51200),
+    'gpt2-xl': (48, 1600, 4 * 1600, 64, 1600 // 64, 50257),
+    '52b': (64, 8192, 4 * 8192, 64, 8192 // 64, 51200),
 }
diff --git a/models/gpt2/test_export.py b/models/gpt2/test_export.py
index cb95662..ef3b105 100644
--- a/models/gpt2/test_export.py
+++ b/models/gpt2/test_export.py
@@ -13,22 +13,22 @@
 
 import config
 
-
 FLAGS = flags.FLAGS
 
+
 class ExportedModelTest(absltest.TestCase):
+
   def setUp(self):
     gpt2_dir = FLAGS.assets_path
-    self.tokenizer = GPT2Tokenizer(
-        vocab_file=path.join(gpt2_dir, 'vocab.json'),
-        merges_file=path.join(gpt2_dir, 'merges.txt'))
+    self.tokenizer = GPT2Tokenizer(vocab_file=path.join(gpt2_dir, 'vocab.json'),
+                                   merges_file=path.join(
+                                       gpt2_dir, 'merges.txt'))
     self.tokenize = self.tokenizer.encode
 
     with open(FLAGS.binary_path, 'rb') as f:
       config = iree_rt.Config("local-task")
       context = iree_rt.SystemContext(config=config)
-      vm_module = iree_rt.VmModule.from_flatbuffer(
-          config.vm_instance, f.read())
+      vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, f.read())
       context.add_vm_module(vm_module)
     self.module = context.modules.gpt2_module
     self.encode = self.module.encode
@@ -96,5 +96,6 @@ def test_batch(self):
     self.assertEqual(y0, e0)
     self.assertEqual(y1, e1)
 
+
 if __name__ == '__main__':
   absltest.main()
diff --git a/models/gpt2/test_jax.py b/models/gpt2/test_jax.py
index 69375d9..f05ccdf 100644
--- a/models/gpt2/test_jax.py
+++ b/models/gpt2/test_jax.py
@@ -23,6 +23,7 @@
 
 FLAGS = absl.flags.FLAGS
 
+
 def load_gpt2_tokenizer():
   builtins.open, tmp_open = open, builtins.open
   gpt2_dir = FLAGS.assets_path
@@ -32,6 +33,7 @@ def load_gpt2_tokenizer():
   builtins.open = tmp_open
   return tokenizer
 
+
 class GPT2RealWeightsTest(parameterized.TestCase):
 
   def setUp(self):
@@ -41,8 +43,7 @@ def setUp(self):
     self.params = model.load_gpt2_model(self.model_name, gpt2_dir)
     super().setUp()
 
-  @parameterized.parameters(*itertools.product(
-      ["cpu", "iree"]))
+  @parameterized.parameters(*itertools.product(["cpu", "iree"]))
   def test_batch_one(self, backend):
     dtype = jnp.float32
     S = 64
@@ -68,5 +69,6 @@ def test_batch_one(self, backend):
     self.assertEqual(self.tokenizer.decode(int(x1[0, 0])), ' six')
     self.assertEqual(self.tokenizer.decode(int(x2[0, 0])), ' seven')
 
+
 if __name__ == '__main__':
   absltest.main()
diff --git a/setup.py b/setup.py
index 714bc70..6cb54ef 100644
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,7 @@ def load_version_info():
 
 PACKAGE_VERSION = version_info.get("package-version") or "0.1dev1"
 
+
 def get_pinned_package(name):
   pinned_versions = version_info.get("pinned-versions")
   use_pinned = version_info.get("use-pinned")
@@ -44,6 +45,7 @@ def get_pinned_package(name):
   restriction = "==" if use_pinned else ">="
   return f"{name}{restriction}{pinned_versions[name]}"
 
+
 setup(
     name=f"iree-jax",
     version=f"{PACKAGE_VERSION}",
@@ -59,11 +61,7 @@ def get_pinned_package(name):
         get_pinned_package("jaxlib"),
     ],
     extras_require={
-        "xla": [
-            get_pinned_package("iree-tools-xla"),
-        ],
-        "test": [
-            "lit",
-        ]
+        "xla": [get_pinned_package("iree-tools-xla"),],
+        "test": ["lit",]
     },
 )
diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py
index 2347566..8a9711f 100644
--- a/tests/lit.cfg.py
+++ b/tests/lit.cfg.py
@@ -48,23 +48,24 @@
 
 # Find a suitable filecheck.
 filecheck_exe = None
 if filecheck_exe is None:
-    filecheck_exe = shutil.which("FileCheck")
-    if filecheck_exe:
-        print(f"Using LLVM FileCheck: {filecheck_exe}")
+  filecheck_exe = shutil.which("FileCheck")
+  if filecheck_exe:
+    print(f"Using LLVM FileCheck: {filecheck_exe}")
 if filecheck_exe is None:
-    filecheck_exe = shutil.which("filecheck")
-    if filecheck_exe:
-        print(f"Using pure python filecheck: {filecheck_exe}")
+  filecheck_exe = shutil.which("filecheck")
+  if filecheck_exe:
+    print(f"Using pure python filecheck: {filecheck_exe}")
 if filecheck_exe is not None:
-    config.substitutions.extend([
-        ('FileCheck', filecheck_exe),
-    ])
+  config.substitutions.extend([
+      ('FileCheck', filecheck_exe),
+  ])
 else:
-    print("FileCheck not found "
-          "(install pure python version with 'pip install filecheck')")
+  print("FileCheck not found "
+        "(install pure python version with 'pip install filecheck')")
 
 project_root = os.path.dirname(os.path.dirname(__file__))
-lit.llvm.llvm_config.with_environment(
-    "PYTHONPATH", project_root, append_path=True)
+lit.llvm.llvm_config.with_environment("PYTHONPATH",
+                                      project_root,
+                                      append_path=True)
 
 config.environment["FILECHECK_OPTS"] = "--dump-input=fail"
diff --git a/tests/program/fft.py b/tests/program/fft.py
index 7fb904e..311c807 100644
--- a/tests/program/fft.py
+++ b/tests/program/fft.py
@@ -27,6 +27,7 @@
 
 x = np.ones((1, 512), dtype=jnp.float32)
 
+
 class FFT(Program):
 
   def fft(self, x=like(x)):
diff --git a/tests/program/trivial_globals.py b/tests/program/trivial_globals.py
index 18b0e76..6a06980 100644
--- a/tests/program/trivial_globals.py
+++ b/tests/program/trivial_globals.py
@@ -67,4 +67,3 @@ def set_params(self, new_params=like(params)):
 # CHECK-SAME-DAG: _params$1
 # CHECK-SAME-DAG: %arg1
 print(Program.get_mlir_module(instance))
-