Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FLUX: Add support for setup configuration to publish module #1096

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flux/flux/__init__.py → flux/mlx_flux/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc.

from ._version import __version__
from .datasets import Dataset, load_dataset
from .flux import FluxPipeline
from .lora import LoRALinear
Expand Down
3 changes: 3 additions & 0 deletions flux/mlx_flux/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.1.0"
File renamed without changes.
File renamed without changes.
File renamed without changes.
10 changes: 8 additions & 2 deletions flux/dreambooth.py → flux/mlx_flux/dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image

from flux import FluxPipeline, Trainer, load_dataset, save_config
from .datasets import load_dataset
from .flux import FluxPipeline
from .trainer import Trainer


def generate_progress_images(iteration, flux, args):
Expand Down Expand Up @@ -153,7 +155,7 @@ def setup_arg_parser():
return parser


if __name__ == "__main__":
def main():
parser = setup_arg_parser()
args = parser.parse_args()

Expand Down Expand Up @@ -290,3 +292,7 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):

save_adapters("final_adapters.safetensors", flux, args)
print(f"Training successful. Saved final weights to {args.adapter_file}.")


if __name__ == "__main__":
main()
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
14 changes: 11 additions & 3 deletions flux/txt2image.py → flux/mlx_flux/txt2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from PIL import Image
from tqdm import tqdm

from flux import FluxPipeline
from .flux import FluxPipeline


def to_latent_size(image_size):
Expand Down Expand Up @@ -39,7 +39,7 @@ def load_adapter(flux, adapter_file, fuse=False):
flux.fuse_lora_layers()


if __name__ == "__main__":
def build_parser():
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
)
Expand All @@ -62,7 +62,11 @@ def load_adapter(flux, adapter_file, fuse=False):
parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
args = parser.parse_args()
return parser


def main():
args = build_parser().parse_args()

# Load the models
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
Expand Down Expand Up @@ -148,3 +152,7 @@ def load_adapter(flux, adapter_file, fuse=False):
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")


if __name__ == "__main__":
main()
File renamed without changes.
61 changes: 61 additions & 0 deletions flux/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright © 2024 Apple Inc.

import os
import sys
from pathlib import Path

from setuptools import find_namespace_packages, setup

# 获取当前文件的父目录(项目根目录)
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
package_dir = os.path.join(ROOT_DIR, "mlx_flux")

# 定义依赖列表
requirements = []
if os.path.exists(os.path.join(ROOT_DIR, "requirements.txt")):
with open(os.path.join(ROOT_DIR, "requirements.txt")) as fid:
requirements = [l.strip() for l in fid.readlines() if l.strip()]

# 添加包路径
sys.path.append(package_dir)

from _version import __version__

try:
with open(os.path.join(ROOT_DIR, "README.md"), encoding="utf-8") as f:
long_description = f.read()
except FileNotFoundError:
long_description = "FLUX.1 on Apple silicon with MLX and the Hugging Face Hub"

setup(
name="mlx-flux",
version=__version__,
description="FLUX.1 on Apple silicon with MLX and the Hugging Face Hub",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
readme="README.md",
author_email="[email protected]",
author="MLX Contributors",
url="https://github.com/ml-explore/mlx-examples",
license="MIT",
install_requires=requirements,
# Package configuration
packages=find_namespace_packages(
include=["mlx_flux", "mlx_flux.*"]
), # 明确指定包含的包
package_data={
"mlx_flux": ["*.py"],
},
include_package_data=True,
python_requires=">=3.8",
entry_points={
"console_scripts": [
# generate images
"mlx_flux.generate = mlx_flux.txt2image:main",
"mlx_flux.txt2image = mlx_flux.txt2image:main",
# fine-tuning model
"mlx_flux.lora = mlx_flux.dreambooth:main",
"mlx_flux.dreambooth = mlx_flux.dreambooth:main",
]
},
)