Version 0.0.37: imports are optimized and Python 3.12 is supported
erfanzar committed Nov 15, 2023
1 parent 7910c2c commit 20ca2e6
Showing 5 changed files with 130 additions and 10 deletions.
125 changes: 122 additions & 3 deletions lib/python/EasyDel/__init__.py
@@ -1,4 +1,123 @@
-from .utils import make_shard_and_gather_fns, get_mesh
-from .trainer import TrainArguments, fsdp_train_step, get_training_modules, CausalLMTrainer
-__version__ = "0.0.36"
from .serve.torch_serve import (
PyTorchServer as PyTorchServer,
PytorchServerConfig as PytorchServerConfig
)
from .serve.jax_serve import (
JAXServer as JAXServer,
JaxServerConfig as JaxServerConfig
)
from .modules.llama.modelling_llama_flax import (
LlamaConfig as LlamaConfig,
FlaxLlamaForCausalLM as FlaxLlamaForCausalLM,
FlaxLlamaModel as FlaxLlamaModel
)
from .modules.gpt_j.modelling_gpt_j_flax import (
GPTJConfig as GPTJConfig,
FlaxGPTJForCausalLM as FlaxGPTJForCausalLM,
FlaxGPTJModule as FlaxGPTJModule,
FlaxGPTJModel as FlaxGPTJModel,
FlaxGPTJForCausalLMModule as FlaxGPTJForCausalLMModule
)
from .modules.t5.modelling_t5_flax import (
T5Config as T5Config,
FlaxT5ForConditionalGeneration as FlaxT5ForConditionalGeneration,
FlaxT5Model as FlaxT5Model
)
from .modules.falcon.modelling_falcon_flax import (
FalconConfig as FalconConfig,
FlaxFalconModel as FlaxFalconModel,
FlaxFalconForCausalLM as FlaxFalconForCausalLM
)
from .modules.opt.modelling_opt_flax import (
OPTConfig as OPTConfig,
FlaxOPTForCausalLM as FlaxOPTForCausalLM,
FlaxOPTModel as FlaxOPTModel
)
from .modules.mistral.modelling_mistral_flax import (
MistralConfig as MistralConfig,
FlaxMistralForCausalLM as FlaxMistralForCausalLM,
FlaxMistralModule as FlaxMistralModule
)
from .modules.palm.modelling_palm_flax import (
PalmModel as PalmModel,
PalmConfig as PalmConfig,
FlaxPalmForCausalLM as FlaxPalmForCausalLM
)

from .modules.mosaic_mpt.modelling_mpt_flax import (
MptConfig as MptConfig,
FlaxMptForCausalLM as FlaxMptForCausalLM,
FlaxMptModel as FlaxMptModel
)

from .modules.gpt_neo_x.modelling_gpt_neo_x_flax import (
GPTNeoXConfig as GPTNeoXConfig,
FlaxGPTNeoXModel as FlaxGPTNeoXModel,
FlaxGPTNeoXForCausalLM as FlaxGPTNeoXForCausalLM
)

from .modules.lucid_transformer.modelling_lt_flax import (
FlaxLTModel as FlaxLTModel,
FlaxLTModelModule as FlaxLTModelModule,
FlaxLTConfig as FlaxLTConfig,
FlaxLTForCausalLM as FlaxLTForCausalLM
)

from .utils.utils import (
get_mesh as get_mesh,
names_in_mesh as names_in_mesh,
get_names_from_partition_spec as get_names_from_partition_spec,
make_shard_and_gather_fns as make_shard_and_gather_fns,
with_sharding_constraint as with_sharding_constraint,
RNG as RNG
)

from .trainer import (
CausalLMTrainer, TrainArguments, fsdp_train_step, get_training_modules
)

from .linen import (
from_8bit as from_8bit,
Dense8Bit as Dense8Bit,
array_from_8bit as array_from_8bit,
array_to_bit8 as array_to_bit8,
to_8bit as to_8bit
)
from .smi import (
run as run,
initialise_tracking as initialise_tracking,
get_mem as get_mem
)

from .transform.llama import (
llama_from_pretrained as llama_from_pretrained,
llama_convert_flax_to_pt as llama_convert_flax_to_pt,
llama_convert_hf_to_flax_load as llama_convert_hf_to_flax_load,
llama_convert_hf_to_flax as llama_convert_hf_to_flax,
llama_easydel_to_hf as llama_easydel_to_hf
)
from .transform.mpt import (
mpt_convert_flax_to_pt_1b as mpt_convert_flax_to_pt_1b,
mpt_convert_pt_to_flax_1b as mpt_convert_pt_to_flax_1b,
mpt_convert_pt_to_flax_7b as mpt_convert_pt_to_flax_7b,
mpt_convert_flax_to_pt_7b as mpt_convert_flax_to_pt_7b,
mpt_from_pretrained as mpt_from_pretrained
)

from .transform.falcon import (
falcon_convert_pt_to_flax_7b as falcon_convert_pt_to_flax_7b,
falcon_convert_flax_to_pt_7b as falcon_convert_flax_to_pt_7b,
falcon_from_pretrained as falcon_from_pretrained,
falcon_convert_pt_to_flax as falcon_convert_pt_to_flax,
falcon_easydel_to_hf as falcon_easydel_to_hf
)
from .transform.mistral import (
mistral_convert_hf_to_flax as mistral_convert_hf_to_flax,
mistral_convert_hf_to_flax_load as mistral_convert_hf_to_flax_load,
mistral_convert_flax_to_pt as mistral_convert_flax_to_pt,
mistral_from_pretrained as mistral_from_pretrained,
mistral_convert_pt_to_flax as mistral_convert_pt_to_flax,
mistral_easydel_to_hf as mistral_easydel_to_hf
)

__version__ = "0.0.37"
4 changes: 2 additions & 2 deletions lib/python/EasyDel/utils/__init__.py
@@ -1,11 +1,11 @@
from .checker import package_checker, is_jax_available, is_torch_available, is_flax_available, is_tensorflow_available
from .utils import get_mesh, Timers, Timer, prefix_str, prefix_print, names_in_mesh, with_sharding_constraint, \
-    get_names_from_parition_spec, make_shard_and_gather_fns, RNG
+    get_names_from_partition_spec, make_shard_and_gather_fns, RNG

if is_jax_available():
    from .utils import make_shard_and_gather_fns
else:
    make_shard_and_gather_fns = ImportWarning
__all__ = ('package_checker', 'is_torch_available', 'is_tensorflow_available', 'is_jax_available', 'is_flax_available', \
           'make_shard_and_gather_fns', 'get_mesh', "Timers", "Timer", "prefix_str", "prefix_print", "names_in_mesh",
-           "with_sharding_constraint", "get_names_from_parition_spec", "make_shard_and_gather_fns", "RNG")
+           "with_sharding_constraint", "get_names_from_partition_spec", "make_shard_and_gather_fns", "RNG")
6 changes: 3 additions & 3 deletions lib/python/EasyDel/utils/utils.py
@@ -57,7 +57,7 @@ def gather_fn(tensor):
    return shard_fns, gather_fns


-def get_names_from_parition_spec(partition_specs):
+def get_names_from_partition_spec(partition_specs):
    names = set()
    if isinstance(partition_specs, dict):
        partition_specs = partition_specs.values()
@@ -67,7 +67,7 @@ def get_names_from_parition_spec(partition_specs):
        elif isinstance(item, str):
            names.add(item)
        else:
-            names.update(get_names_from_parition_spec(item))
+            names.update(get_names_from_partition_spec(item))

    return list(names)

@@ -77,7 +77,7 @@ def names_in_mesh(*names):


def with_sharding_constraint(x, partition_specs):
-    axis_names = get_names_from_parition_spec(partition_specs)
+    axis_names = get_names_from_partition_spec(partition_specs)
    if names_in_mesh(*axis_names):
        x = wsc(x, partition_specs)
    return x
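
The rename from get_names_from_parition_spec to get_names_from_partition_spec fixes the typo without changing behaviour: the helper walks a partition spec (strings, nested containers, or a dict of specs) and collects the mesh axis names it mentions, which with_sharding_constraint then checks against the current mesh before applying the constraint. A small sketch of how the renamed helper behaves; it is illustrative, not part of the commit, and the axis names are invented:

# Illustrative sketch, not part of this diff; the mesh axis names are invented.
from jax.sharding import PartitionSpec

from EasyDel import get_names_from_partition_spec

spec = PartitionSpec(("dp", "fsdp"), "mp")
print(sorted(get_names_from_partition_spec(spec)))  # ['dp', 'fsdp', 'mp']
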
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[tool.black]
line-length = 88
-target-version = ["py38", "py39", "py310", "py311"]
+target-version = ["py38", "py39", "py310", "py311","py312"]

[tool.isort]
line_length = 88
3 changes: 2 additions & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
    name='EasyDeL',
-    version='0.0.36',
+    version='0.0.37',
    author='Erfan Zare Chavoshi',
    author_email='[email protected]',
    description='An open-source library to make training faster and more optimized in Jax/Flax',
@@ -20,6 +20,7 @@
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
+        'Programming Language :: Python :: 3.12',
    ],
    keywords='machine learning, deep learning, pytorch, jax, flax',
    install_requires=[

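Together with the target-version change in pyproject.toml above, the new classifier declares support for CPython 3.8 through 3.12. A trivial guard that a downstream project could use to mirror that range is sketched below; EasyDeL itself does not ship this check, it is purely illustrative:

import sys

# Supported interpreters for this release, per the classifiers and
# target-version entries above: CPython 3.8 through 3.12.
if not ((3, 8) <= sys.version_info[:2] <= (3, 12)):
    raise RuntimeError(
        f"EasyDeL 0.0.37 targets Python 3.8-3.12, found {sys.version.split()[0]}"
    )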