From 1892c82a600baae741c1973579ac85f0bb99c1d3 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Sun, 2 Apr 2023 19:43:34 +0000 Subject: [PATCH 1/5] Reinstantiate weighted captions after a necessary revert to Main --- fine_tune.py | 21 +- library/custom_train_functions.py | 315 ++++++++++++++++++++++++++++++ library/train_util.py | 4 +- train_db.py | 22 ++- train_network.py | 18 +- 5 files changed, 360 insertions(+), 20 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 637a729a8..658079df3 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -21,7 +21,7 @@ BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings def train(args): @@ -284,10 +284,19 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings(tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) @@ -427,4 +436,4 @@ def setup_parser() -> argparse.ArgumentParser: args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) + train(args) \ No newline at end of file diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index dde0bdd43..cb3f1bdda 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,5 +1,8 @@ import torch import argparse +import re +from typing import List, Optional, Union + def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): alphas_cumprod = noise_scheduler.alphas_cumprod @@ -16,3 +19,315 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): def add_custom_train_arguments(parser: argparse.ArgumentParser): parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨") + parser.add_argument("--weighted_captions", action="store_true", default=False, help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.") + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
+ Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    tokenizer,
    text_encoder,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the last token is an ordinary token
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # the chunk is only BOS followed by PAD
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = text_encoder(text_input_chunk)[0]
            else:
                enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        text_embeddings = text_encoder(text_input)[0]
    return text_embeddings


def get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned local weights using brackets. For example, the
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to those words are multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is rescaled afterwards to preserve the original mean.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The maximum length of the prompt embeddings as a multiple of the maximum output length of the text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the token length is a multiple of the text encoder's capacity, whether to reserve the starting and
            ending tokens in each chunk in the middle.
    """
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
    # prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
    # prompt_weights = [[1.0] * len(token) for token in prompt_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        tokenizer,
        text_encoder,
        prompt_tokens,
        tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)

    # assign weights to the prompts and normalize in the sense of mean
    previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
    current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    return text_embeddings
\ No newline at end of file
diff --git a/library/train_util.py b/library/train_util.py
index 59dbc44c7..07bf28260 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -950,10 +950,10 @@ def __getitem__(self, index):
         example["images"] = images

         example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
+        example["captions"] = captions

         if self.debug_dataset:
             example["image_keys"] = bucket[image_index : image_index + self.batch_size]
-            example["captions"] = captions

         return example

@@ -3097,4
+3097,4 @@ def __call__(self, examples): # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) - return examples[0] + return examples[0] \ No newline at end of file diff --git a/train_db.py b/train_db.py index b3eead941..5165da0ac 100644 --- a/train_db.py +++ b/train_db.py @@ -23,8 +23,7 @@ BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight - +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings def train(args): train_util.verify_training_args(args) @@ -273,10 +272,19 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings(tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) @@ -426,4 +434,4 @@ def setup_parser() -> argparse.ArgumentParser: args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) + train(args) \ No newline at end of file diff --git a/train_network.py b/train_network.py index 476f76dfc..d22c4378e 100644 --- a/train_network.py +++ b/train_network.py @@ -25,7 +25,7 @@ BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings # TODO 他のスクリプトと共通化する @@ -538,9 +538,17 @@ def train(args): with torch.set_grad_enabled(train_text_encoder): # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings(tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -721,4 +729,4 @@ def setup_parser() -> argparse.ArgumentParser: args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) + train(args) \ No newline at end of file From dbab72153fbbce07f641d8f6402a682a5ec29480 Mon Sep 17 00:00:00 2001 From: AI-Casanova <54461896+AI-Casanova@users.noreply.github.com> Date: Sat, 8 Apr 2023 00:44:56 -0500 Subject: [PATCH 2/5] Clean up custom_train_functions.py Removed commented out lines from 
earlier bugfix. --- library/custom_train_functions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index cb3f1bdda..c5e7ab39f 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -283,8 +283,6 @@ def get_weighted_text_embeddings( prompt = [prompt] prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2) - # prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - # prompt_weights = [[1.0] * len(token) for token in prompt_tokens] # round up the longest length of tokens to a multiple of (model_max_length - 2) max_length = max([len(token) for token in prompt_tokens]) @@ -330,4 +328,4 @@ def get_weighted_text_embeddings( current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - return text_embeddings \ No newline at end of file + return text_embeddings From a876f2d3fba260714063c1e8852d10d170243521 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 8 Apr 2023 21:36:35 +0900 Subject: [PATCH 3/5] format by black --- library/custom_train_functions.py | 50 +++++++++++++++++++------------ 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index c5e7ab39f..9c0c40284 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -4,22 +4,34 @@ from typing import List, Optional, Union -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): - alphas_cumprod = noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) - sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) - alpha = sqrt_alphas_cumprod - sigma = sqrt_one_minus_alphas_cumprod - all_snr = (alpha / sigma) ** 2 - snr = torch.stack([all_snr[t] for t in timesteps]) - gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) - snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper - loss = loss * snr_weight - return loss +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + snr = torch.stack([all_snr[t] for t in timesteps]) + gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) + snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper + loss = loss * snr_weight + return loss + def add_custom_train_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨") - parser.add_argument("--weighted_captions", action="store_true", default=False, help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.") + parser.add_argument( + "--min_snr_gamma", + type=float, + default=None, + help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 
5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", + ) + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.", + ) + re_attention = re.compile( r""" @@ -283,10 +295,10 @@ def get_weighted_text_embeddings( prompt = [prompt] prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2) - + # round up the longest length of tokens to a multiple of (model_max_length - 2) max_length = max([len(token) for token in prompt_tokens]) - + max_embeddings_multiples = min( max_embeddings_multiples, (max_length - 1) // (tokenizer.model_max_length - 2) + 1, @@ -308,7 +320,7 @@ def get_weighted_text_embeddings( chunk_length=tokenizer.model_max_length, ) prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) - + # get the embeddings text_embeddings = get_unweighted_text_embeddings( tokenizer, @@ -321,11 +333,11 @@ def get_weighted_text_embeddings( no_boseos_middle=no_boseos_middle, ) prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) - + # assign weights to the prompts and normalize in the sense of mean previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1) current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - + return text_embeddings From 6a5f87d874031b66152de2356a96076061e596de Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 8 Apr 2023 21:45:57 +0900 Subject: [PATCH 4/5] disable weighted captions in TI/XTI training --- library/custom_train_functions.py | 15 ++++++++------- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 9c0c40284..7eb829fa4 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -18,19 +18,20 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): return loss -def add_custom_train_arguments(parser: argparse.ArgumentParser): +def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True): parser.add_argument( "--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", ) - parser.add_argument( - "--weighted_captions", - action="store_true", - default=False, - help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.", - ) + if support_weighted_captions: + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. 
/ 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
        )


re_attention = re.compile(
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index d8d803a42..98639345d 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -549,7 +549,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_training_arguments(parser, True)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
-    custom_train_functions.add_custom_train_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser, False)

     parser.add_argument(
         "--save_model_as",
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 9bd775efe..db46ad1b7 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -603,7 +603,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_training_arguments(parser, True)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
-    custom_train_functions.add_custom_train_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser, False)

     parser.add_argument(
         "--save_model_as",

From 08c54dcf22cd02997ccd413fd5207fb0ac60e7ac Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 8 Apr 2023 21:58:22 +0900
Subject: [PATCH 5/5] update readme

---
 README.md | 104 +++++++++---------------------------------------
 1 file changed, 16 insertions(+), 88 deletions(-)

diff --git a/README.md b/README.md
index 61a9748b0..6f1178403 100644
--- a/README.md
+++ b/README.md
@@ -127,6 +127,22 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

 ## Change History

+### 8 Apr. 2023, 2023/4/8:
+
+- Added support for training with weighted captions. Thanks to AI-Casanova for the great contribution!
+  - Please refer to the PR for details: [PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
+  - Specify the `--weighted_captions` option. It is available for all training scripts except Textual Inversion and XTI.
+  - This option is also applicable to token strings of the DreamBooth method.
+  - The syntax for weighted captions is almost the same as the Web UI, and you can use things like `(abc)`, `[abc]`, and `(abc:1.23)`. Nesting is also possible.
+  - If you include a comma in the parentheses, the parentheses will not be properly matched in the prompt shuffle/dropout, so do not include a comma in the parentheses.
+
+- 重みづけキャプションによる学習に対応しました。 AI-Casanova 氏の素晴らしい貢献に感謝します。
+  - 詳細はこちらをご確認ください。[PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
+  - `--weighted_captions` オプションを指定してください。Textual InversionおよびXTIを除く学習スクリプトで使用可能です。
+  - キャプションだけでなく DreamBooth 手法の token string でも有効です。
+  - 重みづけキャプションの記法はWeb UIとほぼ同じで、`(abc)`や`[abc]`、`(abc:1.23)`などが使用できます。入れ子も可能です。
+  - 括弧内にカンマを含めるとプロンプトのshuffle/dropoutで括弧の対応付けがおかしくなるため、括弧内にはカンマを含めないでください。
+
 ### 6 Apr. 2023, 2023/4/6:

 - There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while.
@@ -147,7 +163,6 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。 - - モデルおよびstateをHuggingFaceにアップロードする機能を各スクリプトに追加しました。 [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348) ddPn08 氏の貢献に感謝します。 - `--huggingface_repo_id`が指定されているとモデル保存時に同時にHuggingFaceにアップロードします。 - アクセストークンの取り扱いに注意してください。[HuggingFaceのドキュメント](https://huggingface.co/docs/hub/security-tokens)を参照してください。 @@ -163,93 +178,6 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - プロンプトを` AND `で区切ると各サブプロンプトが順にLoRAに適用されます。`--mask_path` がマスク画像として扱われます。サブプロンプトの数とLoRAの数は一致している必要があります。 -### 4 Apr. 2023, 2023/4/4, Release 0.6.0: -- There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while. -- The learning rate and dim (rank) of each block may not work with other modules (LyCORIS, etc.) because the module needs to be changed. - -- Fix some bugs and add some features. - - Fix an issue that `.json` format dataset config files cannot be read. [issue #351](https://github.com/kohya-ss/sd-scripts/issues/351) Thanks to rockerBOO! - - Raise an error when an invalid `--lr_warmup_steps` option is specified (when warmup is not valid for the specified scheduler). [PR #364](https://github.com/kohya-ss/sd-scripts/pull/364) Thanks to shirayu! - - Add `min_snr_gamma` to metadata in `train_network.py`. [PR #373](https://github.com/kohya-ss/sd-scripts/pull/373) Thanks to rockerBOO! - - Fix the data type handling in `fine_tune.py`. This may fix an error that occurs in some environments when using xformers, npz format cache, and mixed_precision. - -- Add options to `train_network.py` to specify block weights for learning rates. [PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) Thanks to u-haru for the great contribution! - - Specify the weights of 25 blocks for the full model. - - No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values ​​for the argument for consistency. - - Specify the following arguments with `--network_args`. - - `down_lr_weight` : Specify the learning rate weight of the down blocks of U-Net. The following can be specified. - - The weight for each block: Specify 12 numbers such as `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"`. - - Specify from preset: Specify such as `"down_lr_weight=sine"` (the weights by sine curve). sine, cosine, linear, reverse_linear, zeros can be specified. Also, if you add `+number` such as `"down_lr_weight=cosine+.25"`, the specified number is added (such as 0.25~1.25). - - `mid_lr_weight` : Specify the learning rate weight of the mid block of U-Net. Specify one number such as `"down_lr_weight=0.5"`. - - `up_lr_weight` : Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight. - - If you omit the some arguments, the 1.0 is used. Also, if you set the weight to 0, the LoRA modules of that block are not created. - - `block_lr_zero_threshold` : If the weight is not more than this value, the LoRA module is not created. The default is 0. - -- Add options to `train_network.py` to specify block dims (ranks) for variable rank. - - Specify 25 values ​​for the full model of 25 blocks. Some blocks do not have LoRA, but specify 25 values ​​always. 
- - Specify the following arguments with `--network_args`. - - `block_dims` : Specify the dim (rank) of each block. Specify 25 numbers such as `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`. - - `block_alphas` : Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used. - - `conv_block_dims` : Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. - - `conv_block_alphas` : Specify the alpha of each block when expanding LoRA to Conv2d 3x3. If omitted, the value of conv_alpha is used. - -- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。 -- 階層別学習率、階層別dim(rank)についてはモジュール側の変更が必要なため、当リポジトリ内のnetworkモジュール以外(LyCORISなど)では現在は動作しないと思われます。 - -- いくつかのバグ修正、機能追加を行いました。 - - `.json`形式のdataset設定ファイルを読み込めない不具合を修正しました。 [issue #351](https://github.com/kohya-ss/sd-scripts/issues/351) rockerBOO 氏に感謝します。 - - 無効な`--lr_warmup_steps` オプション(指定したスケジューラでwarmupが無効な場合)を指定している場合にエラーを出すようにしました。 [PR #364](https://github.com/kohya-ss/sd-scripts/pull/364) shirayu 氏に感謝します。 - - `train_network.py` で `min_snr_gamma` をメタデータに追加しました。 [PR #373](https://github.com/kohya-ss/sd-scripts/pull/373) rockerBOO 氏に感謝します。 - - `fine_tune.py` でデータ型の取り扱いが誤っていたのを修正しました。一部の環境でxformersを使い、npz形式のキャッシュ、mixed_precisionで学習した時にエラーとなる不具合が解消されるかもしれません。 - -- 階層別学習率を `train_network.py` で指定できるようになりました。[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) u-haru 氏の多大な貢献に感謝します。 - - フルモデルの25個のブロックの重みを指定できます。 - - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 - -`--network_args` で以下の引数を指定してください。 - - `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。 - - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。 - - プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。 - - `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。 - - `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。 - - 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。 - - `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。 - -- 階層別dim (rank)を `train_network.py` で指定できるようになりました。 - - フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。 - - `--network_args` で以下の引数を指定してください。 - - `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。 - - `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。 - - `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。 - - `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。 - -- 階層別学習率コマンドライン指定例 / Examples of block learning rate command line specification: - - ` --network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"` - - ` --network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"` - -- 階層別学習率tomlファイル指定例 / Examples of block learning rate toml file specification - - `network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",]` - - `network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", 
"mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ]` - - -- 階層別dim (rank)コマンドライン指定例 / Examples of block dim (rank) command line specification: - - ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"` - - ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` - - ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` - -- 階層別dim (rank)tomlファイル指定例 / Examples of block dim (rank) toml file specification - - `network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",]` - - `network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",]` - - ## Sample image generation during training A prompt file might look like this, for example