From 83c92c2a113983bba1059249e6e7568c089d8788 Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 7 Nov 2024 10:32:53 +0800 Subject: [PATCH 1/5] FLUX: rename flux to mlx_flux --- flux/dreambooth.py | 2 +- flux/{flux => mlx_flux}/__init__.py | 0 flux/{flux => mlx_flux}/autoencoder.py | 0 flux/{flux => mlx_flux}/clip.py | 0 flux/{flux => mlx_flux}/datasets.py | 0 flux/{flux => mlx_flux}/flux.py | 0 flux/{flux => mlx_flux}/layers.py | 0 flux/{flux => mlx_flux}/lora.py | 0 flux/{flux => mlx_flux}/model.py | 0 flux/{flux => mlx_flux}/sampler.py | 0 flux/{flux => mlx_flux}/t5.py | 0 flux/{flux => mlx_flux}/tokenizers.py | 0 flux/{flux => mlx_flux}/trainer.py | 0 flux/{flux => mlx_flux}/utils.py | 0 flux/txt2image.py | 2 +- 15 files changed, 2 insertions(+), 2 deletions(-) rename flux/{flux => mlx_flux}/__init__.py (100%) rename flux/{flux => mlx_flux}/autoencoder.py (100%) rename flux/{flux => mlx_flux}/clip.py (100%) rename flux/{flux => mlx_flux}/datasets.py (100%) rename flux/{flux => mlx_flux}/flux.py (100%) rename flux/{flux => mlx_flux}/layers.py (100%) rename flux/{flux => mlx_flux}/lora.py (100%) rename flux/{flux => mlx_flux}/model.py (100%) rename flux/{flux => mlx_flux}/sampler.py (100%) rename flux/{flux => mlx_flux}/t5.py (100%) rename flux/{flux => mlx_flux}/tokenizers.py (100%) rename flux/{flux => mlx_flux}/trainer.py (100%) rename flux/{flux => mlx_flux}/utils.py (100%) diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 48dcad47..9dcaffb3 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -13,7 +13,7 @@ from mlx.utils import tree_flatten, tree_map, tree_reduce from PIL import Image -from flux import FluxPipeline, Trainer, load_dataset +from mlx_flux import FluxPipeline, Trainer, load_dataset def generate_progress_images(iteration, flux, args): diff --git a/flux/flux/__init__.py b/flux/mlx_flux/__init__.py similarity index 100% rename from flux/flux/__init__.py rename to flux/mlx_flux/__init__.py diff --git a/flux/flux/autoencoder.py b/flux/mlx_flux/autoencoder.py similarity index 100% rename from flux/flux/autoencoder.py rename to flux/mlx_flux/autoencoder.py diff --git a/flux/flux/clip.py b/flux/mlx_flux/clip.py similarity index 100% rename from flux/flux/clip.py rename to flux/mlx_flux/clip.py diff --git a/flux/flux/datasets.py b/flux/mlx_flux/datasets.py similarity index 100% rename from flux/flux/datasets.py rename to flux/mlx_flux/datasets.py diff --git a/flux/flux/flux.py b/flux/mlx_flux/flux.py similarity index 100% rename from flux/flux/flux.py rename to flux/mlx_flux/flux.py diff --git a/flux/flux/layers.py b/flux/mlx_flux/layers.py similarity index 100% rename from flux/flux/layers.py rename to flux/mlx_flux/layers.py diff --git a/flux/flux/lora.py b/flux/mlx_flux/lora.py similarity index 100% rename from flux/flux/lora.py rename to flux/mlx_flux/lora.py diff --git a/flux/flux/model.py b/flux/mlx_flux/model.py similarity index 100% rename from flux/flux/model.py rename to flux/mlx_flux/model.py diff --git a/flux/flux/sampler.py b/flux/mlx_flux/sampler.py similarity index 100% rename from flux/flux/sampler.py rename to flux/mlx_flux/sampler.py diff --git a/flux/flux/t5.py b/flux/mlx_flux/t5.py similarity index 100% rename from flux/flux/t5.py rename to flux/mlx_flux/t5.py diff --git a/flux/flux/tokenizers.py b/flux/mlx_flux/tokenizers.py similarity index 100% rename from flux/flux/tokenizers.py rename to flux/mlx_flux/tokenizers.py diff --git a/flux/flux/trainer.py b/flux/mlx_flux/trainer.py similarity index 100% rename from flux/flux/trainer.py rename to flux/mlx_flux/trainer.py diff --git a/flux/flux/utils.py b/flux/mlx_flux/utils.py similarity index 100% rename from flux/flux/utils.py rename to flux/mlx_flux/utils.py diff --git a/flux/txt2image.py b/flux/txt2image.py index 5ebec81a..fc209b17 100644 --- a/flux/txt2image.py +++ b/flux/txt2image.py @@ -8,7 +8,7 @@ from PIL import Image from tqdm import tqdm -from flux import FluxPipeline +from mlx_flux import FluxPipeline def to_latent_size(image_size): From e61849a0032b2819961956bc62fd62dd34bb6199 Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 7 Nov 2024 12:35:49 +0800 Subject: [PATCH 2/5] FLUX: move cli to mlx_flux dir --- flux/{ => mlx_flux}/dreambooth.py | 21 ++++++++++++++------- flux/{ => mlx_flux}/txt2image.py | 23 +++++++++++++++-------- 2 files changed, 29 insertions(+), 15 deletions(-) rename flux/{ => mlx_flux}/dreambooth.py (98%) rename flux/{ => mlx_flux}/txt2image.py (92%) diff --git a/flux/dreambooth.py b/flux/mlx_flux/dreambooth.py similarity index 98% rename from flux/dreambooth.py rename to flux/mlx_flux/dreambooth.py index 9dcaffb3..8327af09 100644 --- a/flux/dreambooth.py +++ b/flux/mlx_flux/dreambooth.py @@ -1,19 +1,20 @@ # Copyright © 2024 Apple Inc. import argparse -import time -from functools import partial -from pathlib import Path - import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim import numpy as np +import time +from PIL import Image +from functools import partial from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce -from PIL import Image +from pathlib import Path -from mlx_flux import FluxPipeline, Trainer, load_dataset +from .datasets import load_dataset +from .flux import FluxPipeline +from .trainer import Trainer def generate_progress_images(iteration, flux, args): @@ -186,6 +187,7 @@ def setup_arg_parser(): optimizer = optim.Adam(learning_rate=lr_schedule) state = [flux.flow.state, optimizer.state, mx.random.state] + @partial(mx.compile, inputs=state, outputs=state) def single_step(x, t5_feat, clip_feat, guidance): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -196,12 +198,14 @@ def single_step(x, t5_feat, clip_feat, guidance): return loss + @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): return nn.value_and_grad(flux.flow, flux.training_loss)( x, t5_feat, clip_feat, guidance ) + @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -210,6 +214,7 @@ def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grad grads = tree_map(lambda a, b: a + b, prev_grads, grads) return loss, grads + @partial(mx.compile, inputs=state, outputs=state) def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -225,6 +230,7 @@ def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): return loss + # We simply route to the appropriate step based on whether we have # gradients from a previous step and whether we should be performing an # update or simply computing and accumulating gradients in this step. @@ -247,6 +253,7 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step): x, t5_feat, clip_feat, guidance, prev_grads ) + dataset = load_dataset(args.dataset) trainer = Trainer(flux, dataset, args) trainer.encode_dataset() @@ -266,7 +273,7 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step): if (i + 1) % 10 == 0: toc = time.time() - peak_mem = mx.metal.get_peak_memory() / 1024**3 + peak_mem = mx.metal.get_peak_memory() / 1024 ** 3 print( f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " diff --git a/flux/txt2image.py b/flux/mlx_flux/txt2image.py similarity index 92% rename from flux/txt2image.py rename to flux/mlx_flux/txt2image.py index fc209b17..358e8cad 100644 --- a/flux/txt2image.py +++ b/flux/mlx_flux/txt2image.py @@ -1,14 +1,13 @@ # Copyright © 2024 Apple Inc. import argparse - import mlx.core as mx import mlx.nn as nn import numpy as np from PIL import Image from tqdm import tqdm -from mlx_flux import FluxPipeline +from .flux import FluxPipeline def to_latent_size(image_size): @@ -39,7 +38,7 @@ def load_adapter(flux, adapter_file, fuse=False): flux.fuse_lora_layers() -if __name__ == "__main__": +def build_parser(): parser = argparse.ArgumentParser( description="Generate images from a textual prompt using stable diffusion" ) @@ -62,7 +61,11 @@ def load_adapter(flux, adapter_file, fuse=False): parser.add_argument("--adapter") parser.add_argument("--fuse-adapter", action="store_true") parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false") - args = parser.parse_args() + return parser + + +def main(): + args = build_parser().parse_args() # Load the models flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding) @@ -93,7 +96,7 @@ def load_adapter(flux, adapter_file, fuse=False): # First we get and eval the conditioning conditioning = next(latents) mx.eval(conditioning) - peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3 + peak_mem_conditioning = mx.metal.get_peak_memory() / 1024 ** 3 mx.metal.reset_peak_memory() # The following is not necessary but it may help in memory constrained @@ -108,15 +111,15 @@ def load_adapter(flux, adapter_file, fuse=False): # The following is not necessary but it may help in memory constrained # systems by reusing the memory kept by the flow transformer. del flux.flow - peak_mem_generation = mx.metal.get_peak_memory() / 1024**3 + peak_mem_generation = mx.metal.get_peak_memory() / 1024 ** 3 mx.metal.reset_peak_memory() # Decode them into images decoded = [] for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): - decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) + decoded.append(flux.decode(x_t[i: i + args.decoding_batch_size], latent_size)) mx.eval(decoded[-1]) - peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3 + peak_mem_decoding = mx.metal.get_peak_memory() / 1024 ** 3 peak_mem_overall = max( peak_mem_conditioning, peak_mem_generation, peak_mem_decoding ) @@ -148,3 +151,7 @@ def load_adapter(flux, adapter_file, fuse=False): print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB") print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB") print(f"Peak memory used overall: {peak_mem_overall:.3f}GB") + + +if __name__ == "__main__": + main() From 1c43a832802da9bab5a6fa00b3ca7baea2fb7720 Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 7 Nov 2024 12:45:56 +0800 Subject: [PATCH 3/5] FLUX: add setup config --- flux/mlx_flux/__init__.py | 1 + flux/mlx_flux/_version.py | 3 ++ flux/setup.py | 61 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 flux/mlx_flux/_version.py create mode 100644 flux/setup.py diff --git a/flux/mlx_flux/__init__.py b/flux/mlx_flux/__init__.py index b1122d75..08c272ae 100644 --- a/flux/mlx_flux/__init__.py +++ b/flux/mlx_flux/__init__.py @@ -1,5 +1,6 @@ # Copyright © 2024 Apple Inc. +from ._version import __version__ from .datasets import Dataset, load_dataset from .flux import FluxPipeline from .lora import LoRALinear diff --git a/flux/mlx_flux/_version.py b/flux/mlx_flux/_version.py new file mode 100644 index 00000000..87ee07a7 --- /dev/null +++ b/flux/mlx_flux/_version.py @@ -0,0 +1,3 @@ +# Copyright © 2023-2024 Apple Inc. + +__version__ = "0.1.0" diff --git a/flux/setup.py b/flux/setup.py new file mode 100644 index 00000000..7a92f78b --- /dev/null +++ b/flux/setup.py @@ -0,0 +1,61 @@ +# Copyright © 2024 Apple Inc. + +import os +import sys +from pathlib import Path + +from setuptools import find_namespace_packages, setup + +# 获取当前文件的父目录(项目根目录) +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +package_dir = os.path.join(ROOT_DIR, "mlx_flux") + +# 定义依赖列表 +requirements = [] +if os.path.exists(os.path.join(ROOT_DIR, "requirements.txt")): + with open(os.path.join(ROOT_DIR, "requirements.txt")) as fid: + requirements = [l.strip() for l in fid.readlines() if l.strip()] + +# 添加包路径 +sys.path.append(package_dir) + +from _version import __version__ + +try: + with open(os.path.join(ROOT_DIR, "README.md"), encoding="utf-8") as f: + long_description = f.read() +except FileNotFoundError: + long_description = "FLUX.1 on Apple silicon with MLX and the Hugging Face Hub" + +setup( + name="mlx-flux", + version=__version__, + description="FLUX.1 on Apple silicon with MLX and the Hugging Face Hub", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + readme="README.md", + author_email="mlx@group.apple.com", + author="MLX Contributors", + url="https://github.com/ml-explore/mlx-examples", + license="MIT", + install_requires=requirements, + + # Package configuration + packages=find_namespace_packages(include=["mlx_flux", "mlx_flux.*"]), # 明确指定包含的包 + package_data={ + "mlx_flux": ["*.py"], + }, + include_package_data=True, + + python_requires=">=3.8", + entry_points={ + "console_scripts": [ + # generate images + "mlx_flux.generate = mlx_flux.txt2image:main", + "mlx_flux.txt2image = mlx_flux.txt2image:main", + # fine-tuning model + "mlx_flux.lora = mlx_flux.dreambooth:main", + "mlx_flux.dreambooth = mlx_flux.dreambooth:main", + ] + }, +) From 39fd6d272f2464dd400aa63be1bad0e6ece59c9b Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 7 Nov 2024 12:51:22 +0800 Subject: [PATCH 4/5] FLUX: fix pre-commit lints --- flux/mlx_flux/dreambooth.py | 17 ++++++----------- flux/mlx_flux/txt2image.py | 9 +++++---- flux/setup.py | 6 +++--- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/flux/mlx_flux/dreambooth.py b/flux/mlx_flux/dreambooth.py index 8327af09..91049cb1 100644 --- a/flux/mlx_flux/dreambooth.py +++ b/flux/mlx_flux/dreambooth.py @@ -1,16 +1,17 @@ # Copyright © 2024 Apple Inc. import argparse +import time +from functools import partial +from pathlib import Path + import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim import numpy as np -import time -from PIL import Image -from functools import partial from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce -from pathlib import Path +from PIL import Image from .datasets import load_dataset from .flux import FluxPipeline @@ -187,7 +188,6 @@ def setup_arg_parser(): optimizer = optim.Adam(learning_rate=lr_schedule) state = [flux.flow.state, optimizer.state, mx.random.state] - @partial(mx.compile, inputs=state, outputs=state) def single_step(x, t5_feat, clip_feat, guidance): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -198,14 +198,12 @@ def single_step(x, t5_feat, clip_feat, guidance): return loss - @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): return nn.value_and_grad(flux.flow, flux.training_loss)( x, t5_feat, clip_feat, guidance ) - @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -214,7 +212,6 @@ def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grad grads = tree_map(lambda a, b: a + b, prev_grads, grads) return loss, grads - @partial(mx.compile, inputs=state, outputs=state) def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -230,7 +227,6 @@ def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): return loss - # We simply route to the appropriate step based on whether we have # gradients from a previous step and whether we should be performing an # update or simply computing and accumulating gradients in this step. @@ -253,7 +249,6 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step): x, t5_feat, clip_feat, guidance, prev_grads ) - dataset = load_dataset(args.dataset) trainer = Trainer(flux, dataset, args) trainer.encode_dataset() @@ -273,7 +268,7 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step): if (i + 1) % 10 == 0: toc = time.time() - peak_mem = mx.metal.get_peak_memory() / 1024 ** 3 + peak_mem = mx.metal.get_peak_memory() / 1024**3 print( f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " diff --git a/flux/mlx_flux/txt2image.py b/flux/mlx_flux/txt2image.py index 358e8cad..74892d6d 100644 --- a/flux/mlx_flux/txt2image.py +++ b/flux/mlx_flux/txt2image.py @@ -1,6 +1,7 @@ # Copyright © 2024 Apple Inc. import argparse + import mlx.core as mx import mlx.nn as nn import numpy as np @@ -96,7 +97,7 @@ def main(): # First we get and eval the conditioning conditioning = next(latents) mx.eval(conditioning) - peak_mem_conditioning = mx.metal.get_peak_memory() / 1024 ** 3 + peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3 mx.metal.reset_peak_memory() # The following is not necessary but it may help in memory constrained @@ -111,15 +112,15 @@ def main(): # The following is not necessary but it may help in memory constrained # systems by reusing the memory kept by the flow transformer. del flux.flow - peak_mem_generation = mx.metal.get_peak_memory() / 1024 ** 3 + peak_mem_generation = mx.metal.get_peak_memory() / 1024**3 mx.metal.reset_peak_memory() # Decode them into images decoded = [] for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): - decoded.append(flux.decode(x_t[i: i + args.decoding_batch_size], latent_size)) + decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) mx.eval(decoded[-1]) - peak_mem_decoding = mx.metal.get_peak_memory() / 1024 ** 3 + peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3 peak_mem_overall = max( peak_mem_conditioning, peak_mem_generation, peak_mem_decoding ) diff --git a/flux/setup.py b/flux/setup.py index 7a92f78b..e8235b15 100644 --- a/flux/setup.py +++ b/flux/setup.py @@ -39,14 +39,14 @@ url="https://github.com/ml-explore/mlx-examples", license="MIT", install_requires=requirements, - # Package configuration - packages=find_namespace_packages(include=["mlx_flux", "mlx_flux.*"]), # 明确指定包含的包 + packages=find_namespace_packages( + include=["mlx_flux", "mlx_flux.*"] + ), # 明确指定包含的包 package_data={ "mlx_flux": ["*.py"], }, include_package_data=True, - python_requires=">=3.8", entry_points={ "console_scripts": [ From 230215a50d6ebfb1ff70f9514a755a8a7b81d4c5 Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 7 Nov 2024 15:20:58 +0800 Subject: [PATCH 5/5] FLUX: dreambooth add main() def --- flux/mlx_flux/dreambooth.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flux/mlx_flux/dreambooth.py b/flux/mlx_flux/dreambooth.py index 91049cb1..2ea83eb6 100644 --- a/flux/mlx_flux/dreambooth.py +++ b/flux/mlx_flux/dreambooth.py @@ -155,7 +155,7 @@ def setup_arg_parser(): return parser -if __name__ == "__main__": +def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -285,3 +285,7 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step): if (i + 1) % 10 == 0: losses = [] tic = time.time() + + +if __name__ == "__main__": + main()