-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Version
0.0.37
imports are optimized and python 3.12 is supported
- Loading branch information
Showing
5 changed files
with
130 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,123 @@ | ||
from .utils import make_shard_and_gather_fns, get_mesh | ||
from .trainer import TrainArguments, fsdp_train_step, get_training_modules, CausalLMTrainer | ||
__version__ = "0.0.36" | ||
from .serve.torch_serve import ( | ||
PyTorchServer as PyTorchServer, | ||
PytorchServerConfig as PytorchServerConfig | ||
) | ||
from .serve.jax_serve import ( | ||
JAXServer as JAXServer, | ||
JaxServerConfig as JaxServerConfig | ||
) | ||
from .modules.llama.modelling_llama_flax import ( | ||
LlamaConfig as LlamaConfig, | ||
FlaxLlamaForCausalLM as FlaxLlamaForCausalLM, | ||
FlaxLlamaModel as FlaxLlamaModel | ||
) | ||
from .modules.gpt_j.modelling_gpt_j_flax import ( | ||
GPTJConfig as GPTJConfig, | ||
FlaxGPTJForCausalLM as FlaxGPTJForCausalLM, | ||
FlaxGPTJModule as FlaxGPTJModule, | ||
FlaxGPTJModel as FlaxGPTJModel, | ||
FlaxGPTJForCausalLMModule as FlaxGPTJForCausalLMModule | ||
) | ||
from .modules.t5.modelling_t5_flax import ( | ||
T5Config as T5Config, | ||
FlaxT5ForConditionalGeneration as FlaxT5ForConditionalGeneration, | ||
FlaxT5Model as FlaxT5Model | ||
) | ||
from .modules.falcon.modelling_falcon_flax import ( | ||
FalconConfig as FalconConfig, | ||
FlaxFalconModel as FlaxFalconModel, | ||
FlaxFalconForCausalLM as FlaxFalconForCausalLM | ||
) | ||
from .modules.opt.modelling_opt_flax import ( | ||
OPTConfig as OPTConfig, | ||
FlaxOPTForCausalLM as FlaxOPTForCausalLM, | ||
FlaxOPTModel as FlaxOPTModel | ||
) | ||
from .modules.mistral.modelling_mistral_flax import ( | ||
MistralConfig as MistralConfig, | ||
FlaxMistralForCausalLM as FlaxMistralForCausalLM, | ||
FlaxMistralModule as FlaxMistralModule | ||
) | ||
from .modules.palm.modelling_palm_flax import ( | ||
PalmModel as PalmModel, | ||
PalmConfig as PalmConfig, | ||
FlaxPalmForCausalLM as FlaxPalmForCausalLM | ||
) | ||
|
||
from .modules.mosaic_mpt.modelling_mpt_flax import ( | ||
MptConfig as MptConfig, | ||
FlaxMptForCausalLM as FlaxMptForCausalLM, | ||
FlaxMptModel as FlaxMptModel | ||
) | ||
|
||
from .modules.gpt_neo_x.modelling_gpt_neo_x_flax import ( | ||
GPTNeoXConfig as GPTNeoXConfig, | ||
FlaxGPTNeoXModel as FlaxGPTNeoXModel, | ||
FlaxGPTNeoXForCausalLM as FlaxGPTNeoXForCausalLM | ||
) | ||
|
||
from .modules.lucid_transformer.modelling_lt_flax import ( | ||
FlaxLTModel as FlaxLTModel, | ||
FlaxLTModelModule as FlaxLTModelModule, | ||
FlaxLTConfig as FlaxLTConfig, | ||
FlaxLTForCausalLM as FlaxLTForCausalLM | ||
) | ||
|
||
from .utils.utils import ( | ||
get_mesh as get_mesh, | ||
names_in_mesh as names_in_mesh, | ||
get_names_from_partition_spec as get_names_from_partition_spec, | ||
make_shard_and_gather_fns as make_shard_and_gather_fns, | ||
with_sharding_constraint as with_sharding_constraint, | ||
RNG as RNG | ||
) | ||
|
||
from .trainer import ( | ||
CausalLMTrainer, TrainArguments, fsdp_train_step, get_training_modules | ||
) | ||
|
||
from .linen import ( | ||
from_8bit as from_8bit, | ||
Dense8Bit as Dense8Bit, | ||
array_from_8bit as array_from_8bit, | ||
array_to_bit8 as array_to_bit8, | ||
to_8bit as to_8bit | ||
) | ||
from .smi import ( | ||
run as run, | ||
initialise_tracking as initialise_tracking, | ||
get_mem as get_mem | ||
) | ||
|
||
from .transform.llama import ( | ||
llama_from_pretrained as llama_from_pretrained, | ||
llama_convert_flax_to_pt as llama_convert_flax_to_pt, | ||
llama_convert_hf_to_flax_load as llama_convert_hf_to_flax_load, | ||
llama_convert_hf_to_flax as llama_convert_hf_to_flax, | ||
llama_easydel_to_hf as llama_easydel_to_hf | ||
) | ||
from .transform.mpt import ( | ||
mpt_convert_flax_to_pt_1b as mpt_convert_flax_to_pt_1b, | ||
mpt_convert_pt_to_flax_1b as mpt_convert_pt_to_flax_1b, | ||
mpt_convert_pt_to_flax_7b as mpt_convert_pt_to_flax_7b, | ||
mpt_convert_flax_to_pt_7b as mpt_convert_flax_to_pt_7b, | ||
mpt_from_pretrained as mpt_from_pretrained | ||
) | ||
|
||
from .transform.falcon import ( | ||
falcon_convert_pt_to_flax_7b as falcon_convert_pt_to_flax_7b, | ||
falcon_convert_flax_to_pt_7b as falcon_convert_flax_to_pt_7b, | ||
falcon_from_pretrained as falcon_from_pretrained, | ||
falcon_convert_pt_to_flax as falcon_convert_pt_to_flax, | ||
falcon_easydel_to_hf as falcon_easydel_to_hf | ||
) | ||
from .transform.mistral import ( | ||
mistral_convert_hf_to_flax as mistral_convert_hf_to_flax, | ||
mistral_convert_hf_to_flax_load as mistral_convert_hf_to_flax_load, | ||
mistral_convert_flax_to_pt as mistral_convert_flax_to_pt, | ||
mistral_from_pretrained as mistral_from_pretrained, | ||
mistral_convert_pt_to_flax as mistral_convert_pt_to_flax, | ||
mistral_easydel_to_hf as mistral_easydel_to_hf | ||
) | ||
|
||
__version__ = "0.0.37" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
from .checker import package_checker, is_jax_available, is_torch_available, is_flax_available, is_tensorflow_available | ||
from .utils import get_mesh, Timers, Timer, prefix_str, prefix_print, names_in_mesh, with_sharding_constraint, \ | ||
get_names_from_parition_spec, make_shard_and_gather_fns, RNG | ||
get_names_from_partition_spec, make_shard_and_gather_fns, RNG | ||
|
||
if is_jax_available(): | ||
from .utils import make_shard_and_gather_fns | ||
else: | ||
make_shard_and_gather_fns = ImportWarning | ||
__all__ = ('package_checker', 'is_torch_available', 'is_tensorflow_available', 'is_jax_available', 'is_flax_available', \ | ||
'make_shard_and_gather_fns', 'get_mesh', "Timers", "Timer", "prefix_str", "prefix_print", "names_in_mesh", | ||
"with_sharding_constraint", "get_names_from_parition_spec", "make_shard_and_gather_fns", "RNG") | ||
"with_sharding_constraint", "get_names_from_partition_spec", "make_shard_and_gather_fns", "RNG") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
setup( | ||
name='EasyDeL', | ||
version='0.0.36', | ||
version='0.0.37', | ||
author='Erfan Zare Chavoshi', | ||
author_email='[email protected]', | ||
description='An open-source library to make training faster and more optimized in Jax/Flax', | ||
|
@@ -20,6 +20,7 @@ | |
'Programming Language :: Python :: 3.9', | ||
'Programming Language :: Python :: 3.10', | ||
'Programming Language :: Python :: 3.11', | ||
'Programming Language :: Python :: 3.12', | ||
], | ||
keywords='machine learning, deep learning, pytorch, jax, flax', | ||
install_requires=[ | ||
|