Merge pull request #381 from kohya-ss/dev
feature to upload to huggingface etc.
kohya-ss authored Apr 5, 2023
2 parents 8eb60ba + defefd7 commit b5c60d7
Showing 13 changed files with 703 additions and 177 deletions.
198 changes: 117 additions & 81 deletions README.md

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions fine_tune.py
@@ -231,9 +231,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
train_util.patch_accelerator_for_fp16_training(accelerator)

 # resume training
-if args.resume is not None:
-    print(f"resume training from state: {args.resume}")
-    accelerator.load_state(args.resume)
+train_util.resume_from_local_or_hf_if_specified(accelerator, args)

# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
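For context, here is a minimal sketch of what the new train_util.resume_from_local_or_hf_if_specified helper plausibly does. The real implementation lives in library/train_util.py; the resume_from_huggingface flag and the "user/repo/path_in_repo" layout below are assumptions for illustration, not this repo's exact code:

    import os
    from huggingface_hub import snapshot_download

    def resume_from_local_or_hf_if_specified(accelerator, args):
        if args.resume is None:
            return
        if not getattr(args, "resume_from_huggingface", False):
            # same behavior as the removed inline code: load a local accelerate state
            print(f"resume training from local state: {args.resume}")
            accelerator.load_state(args.resume)
            return
        # hypothetical Hub layout: args.resume is "user/repo/path_in_repo"
        parts = args.resume.split("/")
        repo_id, subfolder = "/".join(parts[:2]), "/".join(parts[2:])
        local_dir = snapshot_download(repo_id=repo_id, allow_patterns=[f"{subfolder}/*"])
        accelerator.load_state(os.path.join(local_dir, subfolder))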
85 changes: 69 additions & 16 deletions gen_img_diffusers.py
@@ -92,6 +92,7 @@

import library.model_util as model_util
import library.train_util as train_util
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo

@@ -634,6 +635,7 @@ def __call__(
img2img_noise=None,
clip_prompts=None,
clip_guide_images=None,
networks: Optional[List[LoRANetwork]] = None,
**kwargs,
):
r"""
@@ -717,6 +719,7 @@ def __call__(
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
regional_network = " AND " in prompt[0]

vae_batch_size = (
batch_size
@@ -1010,6 +1013,11 @@ def __call__(

# predict the noise residual
if self.control_nets:
if regional_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
else:
text_emb_last = text_embeddings
noise_pred = original_control_net.call_unet_and_control_net(
i,
num_latent_input,
@@ -1019,7 +1027,7 @@
i / len(timesteps),
latent_model_input,
t,
-text_embeddings,
+text_emb_last,
).sample
else:
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
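To see what the new strided slice selects, here is a self-contained illustration. The shapes are hypothetical, and it assumes each batch element's embeddings are laid out as [sub_1, ..., sub_k, negative], which is what makes index n - 2 the last positive sub-prompt:

    import torch

    batch_size, k = 2, 3                        # 2 images, 3 " AND " sub-prompts each
    n = k + 1                                   # sub-prompts plus one negative prompt
    text_embeddings = torch.randn(batch_size * n, 77, 768)
    text_emb_last = text_embeddings[n - 2::n]   # rows 2 and 6: last sub-prompt of each image
    assert text_emb_last.shape == (batch_size, 77, 768)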
@@ -1890,6 +1898,12 @@ def get_weighted_text_embeddings(
if isinstance(prompt, str):
prompt = [prompt]

# split each prompt on " AND "; every prompt must yield the same number of sub-prompts
new_prompts = []
for p in prompt:
new_prompts.extend(p.split(" AND "))
prompt = new_prompts

if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
if uncond_prompt is not None:
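The splitting itself is straightforward; a standalone example of the behavior with made-up prompts:

    prompts = ["a sunny sky AND a stone castle AND a dark forest"]
    sub_prompts = []
    for p in prompts:
        sub_prompts.extend(p.split(" AND "))
    # sub_prompts == ["a sunny sky", "a stone castle", "a dark forest"];
    # every prompt in the batch must yield the same number of sub-prompts so the
    # flattened embedding tensor stays rectangular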
@@ -2059,6 +2073,7 @@ class BatchDataExt(NamedTuple):
negative_scale: float
strength: float
network_muls: Optional[Tuple[float, ...]]
num_sub_prompts: Optional[int]


class BatchData(NamedTuple):
@@ -2276,16 +2291,20 @@ def __getattr__(self, item):
print(f"metadata for: {network_weight}: {metadata}")

network, weights_sd = imported_module.create_network_from_weights(
-    network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
+    network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.")
if network is None:
return

-if not args.network_merge:
+mergeable = hasattr(network, "merge_to")
+if args.network_merge and not mergeable:
+    print("network is not mergeable. ignoring merge option.")
+
+if not args.network_merge or not mergeable:
     network.apply_to(text_encoder, unet)
     info = network.load_state_dict(weights_sd, False)  # TODO: better to use network.load_weights
print(f"weights are loaded: {info}")

if args.opt_channels_last:
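When --network_merge is usable, the else-branch (not shown in this hunk) presumably calls network.merge_to to fold the LoRA deltas into the base weights. As a conceptual sketch of the standard LoRA merge for a linear layer, not this repo's exact code (which lives in networks/lora.py and also handles conv shapes, device, and dtype):

    import torch

    def merge_lora_into_linear(weight, down, up, multiplier, alpha, rank):
        # W' = W + multiplier * (alpha / rank) * (up @ down)
        scale = alpha / rank
        return weight + multiplier * scale * (up @ down)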
@@ -2349,12 +2368,12 @@ def __getattr__(self, item):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()

# handle Extended Textual Inversion and Textual Inversion
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI

# handle Textual Inversion
if args.textual_inversion_embeddings:
token_ids_embeds = []
for embeds_file in args.textual_inversion_embeddings:
@@ -2558,16 +2577,22 @@ def resize_images(imgs, size):
print(f"resize img2img mask images to {args.W}*{args.H}")
mask_images = resize_images(mask_images, (args.W, args.H))

+regional_network = False
 if networks and mask_images:
-    # reuse the mask as region information; currently only one image is supported
-    # TODO: handle a mix of multiple network classes
+    # reuse the mask as region information; currently only one image per command invocation is supported
+    regional_network = True
     print("use mask as region")
-    # import cv2
-    # for i in range(3):
-    #     cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
-    #     cv2.waitKey()
-    #     cv2.destroyAllWindows()
-    networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
+
+    size = None
+    for i, network in enumerate(networks):
+        if i < 3:
+            np_mask = np.array(mask_images[0])
+            np_mask = np_mask[:, :, i]
+            size = np_mask.shape
+        else:
+            np_mask = np.full(size, 255, dtype=np.uint8)
+        mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
+        network.set_region(i, i == len(networks) - 1, mask)
     mask_images = None

prev_image = None # for VGG16 guided
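A standalone illustration of the channel-per-network convention used in the loop above (R, G, B channels map to networks 0, 1, 2). The mask layout here is made up for the example:

    import numpy as np
    import torch

    rgb_mask = np.zeros((512, 512, 3), dtype=np.uint8)
    rgb_mask[:, :170, 0] = 255     # left third   -> network 0 (red channel)
    rgb_mask[:, 170:340, 1] = 255  # middle third -> network 1 (green channel)
    rgb_mask[:, 340:, 2] = 255     # right third  -> network 2 (blue channel)

    region_masks = [
        torch.from_numpy(rgb_mask[:, :, i].astype(np.float32) / 255.0) for i in range(3)
    ]
    # each mask is 1.0 where its network applies and 0.0 elsewhere; networks beyond
    # the first three fall back to an all-255 (everywhere) mask, as in the loop above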
@@ -2623,7 +2648,14 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
height_1st = height_1st - height_1st % 32

ext_1st = BatchDataExt(
-    width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls
+    width_1st,
+    height_1st,
+    args.highres_fix_steps,
+    ext.scale,
+    ext.negative_scale,
+    ext.strength,
+    ext.network_muls,
+    ext.num_sub_prompts,
)
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
images_1st = process_batch(batch_1st, True, True)
@@ -2651,7 +2683,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
(
return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image),
-(width, height, steps, scale, negative_scale, strength, network_muls),
+(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)

@@ -2743,8 +2775,11 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):

# generate
if networks:
shared = {}
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
if regional_network:
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)

images = pipe(
prompts,
@@ -2969,11 +3004,26 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
print("Use previous image as guide image.")
guide_image = prev_image

if regional_network:
num_sub_prompts = len(prompt.split(" AND "))
assert (
len(networks) <= num_sub_prompts
), "Number of networks must be less than or equal to number of sub prompts."
else:
num_sub_prompts = None

b1 = BatchData(
False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataExt(
-width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None
+width,
+height,
+steps,
+scale,
+negative_scale,
+strength,
+tuple(network_muls) if network_muls else None,
+num_sub_prompts,
),
)
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext:  # need to split the batch?
@@ -3197,6 +3247,9 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*",
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
)
# parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )

return parser
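Putting the regional pieces together, an invocation might look like the line below. The flag names are written from memory of this script's setup_parser (the mask flag in particular should be verified against --help), and the file names are placeholders:

    python gen_img_diffusers.py --ckpt model.safetensors \
      --prompt "a sunny sky AND a stone castle AND a dark forest" \
      --network_module networks.lora networks.lora networks.lora \
      --network_weights sky.safetensors castle.safetensors forest.safetensors \
      --network_mul 1.0 1.0 1.0 --mask_image region_mask.png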

78 changes: 78 additions & 0 deletions library/huggingface_util.py
@@ -0,0 +1,78 @@
from typing import BinaryIO, Optional, Union
from huggingface_hub import HfApi
from pathlib import Path
import argparse
import os

from library.utils import fire_in_thread


def exists_repo(
repo_id: str, repo_type: str, revision: str = "main", token: Optional[str] = None
):
api = HfApi(
token=token,
)
try:
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
return True
except Exception:
return False


def upload(
args: argparse.Namespace,
src: Union[str, Path, bytes, BinaryIO],
dest_suffix: str = "",
force_sync_upload: bool = False,
):
repo_id = args.huggingface_repo_id
repo_type = args.huggingface_repo_type
token = args.huggingface_token
path_in_repo = args.huggingface_path_in_repo + dest_suffix
private = args.huggingface_repo_visibility != "public"  # None (unset) also means private
api = HfApi(token=token)
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)

is_folder = (isinstance(src, str) and os.path.isdir(src)) or (
    isinstance(src, Path) and src.is_dir()
)

def uploader():
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
)

if args.async_upload and not force_sync_upload:
fire_in_thread(uploader)
else:
uploader()


def list_dir(
repo_id: str,
subfolder: str,
repo_type: str,
revision: str = "main",
token: Optional[str] = None,
):
api = HfApi(
token=token,
)
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
file_list = [
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
]
return file_list
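Two notes on the new module. First, fire_in_thread is imported from library/utils.py (also part of this PR); a plausible minimal implementation, assuming it simply runs the callable on a background thread, is:

    import threading

    def fire_in_thread(f, *args, **kwargs):
        threading.Thread(target=f, args=args, kwargs=kwargs).start()

Second, a hypothetical usage sketch of upload(), using only the argparse attributes the function actually reads; the repo id and paths are placeholders:

    import argparse
    from library import huggingface_util

    args = argparse.Namespace(
        huggingface_repo_id="your-name/your-model",
        huggingface_repo_type="model",
        huggingface_path_in_repo="checkpoints",
        huggingface_token=None,  # or a token with write access
        huggingface_repo_visibility=None,  # anything other than "public" means private
        async_upload=False,  # True would fire uploader() via fire_in_thread
    )
    huggingface_util.upload(args, "output/my-model", dest_suffix="/last")
    # uploads the folder to checkpoints/last in the repo, creating the repo if needed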
