Skip to content

Commit

Permalink
Format python code
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Feb 1, 2023
1 parent b9b3e3c commit d0999c0
Show file tree
Hide file tree
Showing 26 changed files with 156 additions and 146 deletions.
13 changes: 8 additions & 5 deletions build_tools/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@
import time

parser = argparse.ArgumentParser(description='Determine setup options.')
parser.add_argument('--use-pinned', default=False, action=argparse.BooleanOptionalAction)
parser.add_argument('--use-pinned',
default=False,
action=argparse.BooleanOptionalAction)
parser.add_argument('--add-version')

'''Sets values in version_info.json
Sets values related to the libraries version numbers, the pinned versions, and
whether the pinned versions should be exact or a lower bound. This is used for
configuring on demand for bumping or validating pinned versions.
'''


def main():
project_root = os.path.dirname(os.path.dirname(__file__))

Expand All @@ -52,9 +55,9 @@ def main():
print(json.dumps(version_info, indent=2))

with open(os.path.join(project_root, "version_info.json"), 'w') as f:
f.write(json.dumps(version_info, indent=2))
f.write("\n")
f.write(json.dumps(version_info, indent=2))
f.write("\n")


if __name__ == "__main__":
main()

4 changes: 3 additions & 1 deletion build_tools/pip_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import subprocess
import sys


def main():
project_root = os.path.dirname(os.path.dirname(__file__))

Expand All @@ -39,5 +40,6 @@ def main():
_, err = proc.communicate()
print(err)


if __name__ == "__main__":
main()
main()
20 changes: 12 additions & 8 deletions build_tools/update_pinned.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import subprocess
import sys
from subprocess import check_output

'''Update the version_info.json with the currently installed libraries.'''


def main():
reqs = subprocess.check_output([sys.executable, '-m', 'pip', 'freeze'])
reqs = [c.split("==") for c in reqs.decode('ascii').split('\n') if "==" in c]
reqs = {a : b for a, b in reqs }
reqs = {a: b for a, b in reqs}

# Load version_info.json.
project_root = os.path.dirname(os.path.dirname(__file__))
Expand All @@ -36,15 +37,18 @@ def main():

pinned = version_info["pinned-versions"]
for pinned_lib in pinned:
if pinned_lib in reqs:
pinned[pinned_lib] = reqs[pinned_lib]
if pinned_lib in reqs:
pinned[pinned_lib] = reqs[pinned_lib]

version_info = {a : version_info[a] for a in version_info if a == "pinned-versions"}
version_info = {
a: version_info[a] for a in version_info if a == "pinned-versions"
}

with open(os.path.join(project_root, "version_info.json"), 'w') as f:
f.write(json.dumps(version_info, indent=2))
f.write("\n")
f.write(json.dumps(version_info, indent=2))
f.write("\n")
print(pinned["iree-compiler"])

if __name__ == "__main__":

if __name__ == "__main__":
main()
23 changes: 12 additions & 11 deletions examples/aqt_dense_simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test compiling and executing a basic AQT MatMul with IREE."""

from collections import namedtuple
Expand All @@ -28,18 +27,19 @@

Params = namedtuple("Params", "weights,bias,activation_scale")
params = [
Params(
weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 0.001,
bias=jnp.arange(3, dtype=jnp.float32) * 10.0,
activation_scale=jnp.array(5.0),
),
Params(
weights=jnp.arange(27, dtype=jnp.float32).reshape(3, 9) * 0.01,
bias=jnp.arange(9, dtype=jnp.float32) * 3.0,
activation_scale=jnp.array(5.0),
),
Params(
weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 0.001,
bias=jnp.arange(3, dtype=jnp.float32) * 10.0,
activation_scale=jnp.array(5.0),
),
Params(
weights=jnp.arange(27, dtype=jnp.float32).reshape(3, 9) * 0.01,
bias=jnp.arange(9, dtype=jnp.float32) * 3.0,
activation_scale=jnp.array(5.0),
),
]


def dense(params, activation):
precision = 8
lower_bound = -2**(precision - 1) + 1
Expand All @@ -57,6 +57,7 @@ def dense(params, activation):
matmul_result = scaled_result / (params.activation_scale * weight_scale)
return matmul_result + params.bias[jnp.newaxis, :]


class AqtDenseModule(Program):

_params = params
Expand Down
11 changes: 6 additions & 5 deletions examples/aqt_matmul_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test compiling and executing a basic AQT MatMul with IREE."""

from collections import namedtuple
Expand All @@ -28,10 +27,11 @@

Params = namedtuple("Params", "weights,activation_scale")
params = Params(
weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 500.3,
activation_scale=jnp.array(5.0),
weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 500.3,
activation_scale=jnp.array(5.0),
)


class AqtMatmulModule(Program):

_params = params
Expand All @@ -52,8 +52,9 @@ def aqt_matmul_native(params, activation):
weight_rounded = jnp.floor(weight_scaled + jnp.array(0.5))
weight_as_int = weight_rounded.astype(jnp.int8)

scaled_result = jax.lax.dot(
activation_as_int, weight_as_int, preferred_element_type=jnp.int32)
scaled_result = jax.lax.dot(activation_as_int,
weight_as_int,
preferred_element_type=jnp.int32)
return scaled_result / (params.activation_scale * weight_scale)

def compute_native(mdl, activation=like(activation_example)):
Expand Down
6 changes: 3 additions & 3 deletions examples/aqt_matmul_simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test compiling and executing a basic AQT MatMul with IREE."""

from collections import namedtuple
Expand All @@ -28,10 +27,11 @@

Params = namedtuple("Params", "weights,activation_scale")
params = Params(
weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 500.3,
activation_scale=jnp.array(5.0),
weights=jnp.arange(18, dtype=jnp.float32).reshape(6, 3) * 500.3,
activation_scale=jnp.array(5.0),
)


class AqtMatmulModule(Program):

_params = params
Expand Down
9 changes: 4 additions & 5 deletions examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Datasets used in examples."""


import array
import gzip
import os
Expand All @@ -24,7 +22,6 @@

import numpy as np


_DATA = "/tmp/jax_example_data/"


Expand Down Expand Up @@ -64,8 +61,10 @@ def parse_images(filename):
return np.array(array.array("B", fh.read()),
dtype=np.uint8).reshape(num_data, rows, cols)

for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
for filename in [
"train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"
]:
_download(base_url + filename, filename)

train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
Expand Down
9 changes: 5 additions & 4 deletions examples/mnist_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
register_pytree_node)

from iree.jax import (
like,
kernel,
IREE,
Program,
like,
kernel,
IREE,
Program,
)


Expand All @@ -58,6 +58,7 @@ def main(args):
with open(os.path.join(output_dir, "mnist_train.vmfb"), "wb") as f:
f.write(binary.compiled_artifact)


def build_model():
init_random_params, predict = stax.serial(
Dense(1024),
Expand Down
1 change: 1 addition & 0 deletions examples/run_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def get_examples():
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
print(f"Number of batches in dataset: {num_complete_batches}")

def data_stream():
rng = npr.RandomState(0)
while True:
Expand Down
3 changes: 2 additions & 1 deletion iree/jax/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class jit_kernel(tracing.CallableIntrinsic):

def __init__(self, wrapped_f, *, wrap_with_jit: bool = True):
self.wrapped_f = wrapped_f
self.jit_f = jax.jit(self.wrapped_f, backend="iree") if wrap_with_jit else self.wrapped_f
self.jit_f = jax.jit(self.wrapped_f,
backend="iree") if wrap_with_jit else self.wrapped_f

def __repr__(self):
return f"<Exportable Pure Func: {self.wrapped_f}>"
Expand Down
3 changes: 2 additions & 1 deletion iree/jax/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def def_global_tree(self,
# We fork between trackable things and static constants. Currently this
# is just array vs not, but this should match Jax's heuristic.
# TODO: Make sure this is the right way to detect array.
if isinstance(concrete_leaf, jax.core.ShapedArray) or hasattr(concrete_leaf, "__array__"):
if isinstance(concrete_leaf, jax.core.ShapedArray) or hasattr(
concrete_leaf, "__array__"):
leaf_symbol = f"{symbol_name}${tracked_leaf_count}"
logger.debug("def_global_tree: array %s=%r:%r", leaf_symbol,
concrete_leaf.shape, concrete_leaf.dtype)
Expand Down
5 changes: 2 additions & 3 deletions iree/jax/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ def aot(function, *args, **options):
"""
xla_comp = jax.xla_computation(function)(*args)
hlo_proto = xla_comp.as_serialized_hlo_module_proto()
return iree.compiler.tools.xla.compile_str(hlo_proto,
input_type=iree.compiler.InputType.XLA,
**options)
return iree.compiler.tools.xla.compile_str(
hlo_proto, input_type=iree.compiler.InputType.XLA, **options)


# A more JAX-native approach to jitting would be desireable here, however
Expand Down
31 changes: 17 additions & 14 deletions iree/jax/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import iree.runtime

_config_cache: Dict[str, iree.runtime.system_api.Config] = dict()

def get_rt_config(driver_name: str):
driver = _config_cache.get(driver_name)
if driver is None:
Expand All @@ -35,8 +36,10 @@ def get_rt_config(driver_name: str):
"IREE",
]


class IREE:
def __init__(self, program: Program, backends : List[str], runtimes : str):

def __init__(self, program: Program, backends: List[str], runtimes: str):
self._program = program
self._backends = backends
self._runtime = runtimes
Expand All @@ -47,20 +50,20 @@ def __init__(self, program: Program, backends : List[str], runtimes : str):
pass

@staticmethod
def compile_program(
program: Program,
backends : List[str] = ["llvm-cpu"] ,
runtime : str = "local-task"):
def compile_program(program: Program,
backends: List[str] = ["llvm-cpu"],
runtime: str = "local-task"):

try:
iree.compiler
iree.compiler
except NameError:
raise Exception("iree.compiler library is required for binary compilation")
raise Exception(
"iree.compiler library is required for binary compilation")

try:
iree.runtime
iree.runtime
except NameError:
raise Exception("iree.runtime library is required for binary compilation")
raise Exception("iree.runtime library is required for binary compilation")

binary = IREE(program, backends, runtime)
binary.compiled_artifact
Expand All @@ -75,31 +78,31 @@ def compiled_artifact(self):
ir_module.operation.write_bytecode(file=output)
bytecode = output.getvalue()
self._compiled_artifact = iree.compiler.tools.compile_str(
bytecode, target_backends=self._backends, input_type="mhlo")
bytecode, target_backends=self._backends, input_type="mhlo")

return self._compiled_artifact

@property
def runtime_module(self):
if not self._runtime_module:
rt_config = get_rt_config(self._runtime)
vm_module = iree.runtime.VmModule.from_flatbuffer(self._instance, self.compiled_artifact)
self._runtime_module = iree.runtime.system_api.load_vm_module(vm_module, rt_config)
vm_module = iree.runtime.VmModule.from_flatbuffer(self._instance,
self.compiled_artifact)
self._runtime_module = iree.runtime.system_api.load_vm_module(
vm_module, rt_config)

info = Program.get_info(Program._get_instance(self._program))
for fun, _ in info.class_info.export_functions:
self._shadow_dict[fun] = self._runtime_module[fun]

return self._runtime_module


def __getattr__(self, name):
try:
return self._shadow_dict[name]
except KeyError as e:
raise AttributeError(f"Attribute {name} not defined") from e


def _create_runtime_trampoline(self, exported_function_name):
"""Creates a runtime trampoline function for the given exported function."""

Expand Down
6 changes: 3 additions & 3 deletions iree/jax/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from . import array_types

from jaxlib.mlir import (
ir,
)
ir,)

import jax.core
import jax.interpreters.mlir
Expand Down Expand Up @@ -64,7 +63,8 @@ def abstractify(x) -> jax.core.AbstractValue:

def unwrap_global_array(x) -> Optional[array_types.ExportedGlobalArray]:
# TODO: Ugh. Ugh.
if isinstance(x, jax.core.ConcreteArray) or isinstance(x, jax.core.ShapedArray):
if isinstance(x, jax.core.ConcreteArray) or isinstance(
x, jax.core.ShapedArray):
x = x.val
if not isinstance(x, array_types.ExportedGlobalArray):
return None
Expand Down
Loading

0 comments on commit d0999c0

Please sign in to comment.