From 1cd07770a432a18523fffee1dea057bf480cb0b4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 1 Apr 2023 09:13:47 +0900 Subject: [PATCH 1/4] format by black --- library/model_util.py | 1963 ++++++++++++++++++++-------------------- networks/merge_lora.py | 397 ++++---- 2 files changed, 1179 insertions(+), 1181 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index a465bf7f0..b58330669 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -45,596 +45,574 @@ def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def assign_to_checkpoint( paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) - for path in paths: - new_path = path["new"] + for path in paths: + new_path = path["new"] - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] def linear_transformer_to_conv(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim == 2: - checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) def convert_ldm_unet_checkpoint(v2, checkpoint, config): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - unet_key = "model.diffusion_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - # オリジナル: - # if ["conv.weight", "conv.bias"] in output_block_list.values(): - # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) - - # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが - for l in output_block_list.values(): - l.sort() - - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する - if v2: - linear_transformer_to_conv(new_checkpoint) + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + # オリジナル: + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + + # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが + for l in output_block_list.values(): + l.sort() + + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する + if v2: + linear_transformer_to_conv(new_checkpoint) - return new_checkpoint + return new_checkpoint def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - # if len(vae_state_dict) == 0: - # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict - # vae_state_dict = checkpoint - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)} + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint + conv_attn_to_linear(new_checkpoint) + return new_checkpoint def create_unet_diffusers_config(v2): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # unet_params = original_config.model.params.unet_config.params - - block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - config = dict( - sample_size=UNET_PARAMS_IMAGE_SIZE, - in_channels=UNET_PARAMS_IN_CHANNELS, - out_channels=UNET_PARAMS_OUT_CHANNELS, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - ) - - return config + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # unet_params = original_config.model.params.unet_config.params + + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + ) + + return config def create_vae_diffusers_config(): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # vae_params = original_config.model.params.first_stage_config.params.ddconfig - # _ = original_config.model.params.first_stage_config.params.embed_dim - block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = dict( - sample_size=VAE_PARAMS_RESOLUTION, - in_channels=VAE_PARAMS_IN_CHANNELS, - out_channels=VAE_PARAMS_OUT_CH, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=VAE_PARAMS_Z_CHANNELS, - layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, - ) - return config + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config def convert_ldm_clip_checkpoint_v1(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] - return text_model_dict + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + return text_model_dict def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): - # 嫌になるくらい違うぞ! - def convert_key(key): - if not key.startswith("cond_stage_model"): - return None - - # common conversion - key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") - key = key.replace("cond_stage_model.model.", "text_model.") - - if "resblocks" in key: - # resblocks conversion - key = key.replace(".resblocks.", ".layers.") - if ".ln_" in key: - key = key.replace(".ln_", ".layer_norm") - elif ".mlp." in key: - key = key.replace(".c_fc.", ".fc1.") - key = key.replace(".c_proj.", ".fc2.") - elif '.attn.out_proj' in key: - key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") - elif '.attn.in_proj' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in SD: {key}") - elif '.positional_embedding' in key: - key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") - elif '.text_projection' in key: - key = None # 使われない??? - elif '.logit_scale' in key: - key = None # 使われない??? - elif '.token_embedding' in key: - key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") - elif '.ln_final' in key: - key = key.replace(".ln_final", ".final_layer_norm") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - # remove resblocks 23 - if '.resblocks.23.' in key: - continue - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if '.resblocks.23.' in key: - continue - if '.resblocks' in key and '.attn.in_proj_' in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) - - key_suffix = ".weight" if "weight" in key else ".bias" - key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") - key_pfx = key_pfx.replace("_weight", "") - key_pfx = key_pfx.replace("_bias", "") - key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") - new_sd[key_pfx + "q_proj" + key_suffix] = values[0] - new_sd[key_pfx + "k_proj" + key_suffix] = values[1] - new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - - # rename or add position_ids - ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" - if ANOTHER_POSITION_IDS_KEY in new_sd: - # waifu diffusion v1.4 - position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] - del new_sd[ANOTHER_POSITION_IDS_KEY] - else: - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - - new_sd["text_model.embeddings.position_ids"] = position_ids - return new_sd + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None + + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = None # 使われない??? + elif ".logit_scale" in key: + key = None # 使われない??? + elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if ".resblocks.23." in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if ".resblocks.23." in key: + continue + if ".resblocks" in key and ".attn.in_proj_" in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids + return new_sd + # endregion @@ -642,543 +620,546 @@ def convert_key(key): # region Diffusers->StableDiffusion の変換コード # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) + def conv_transformer_to_linear(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] def convert_unet_state_dict_to_sd(v2, unet_state_dict): - unet_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("time_embed.0.weight", "time_embedding.linear_1.weight"), - ("time_embed.0.bias", "time_embedding.linear_1.bias"), - ("time_embed.2.weight", "time_embedding.linear_2.weight"), - ("time_embed.2.bias", "time_embedding.linear_2.bias"), - ("input_blocks.0.0.weight", "conv_in.weight"), - ("input_blocks.0.0.bias", "conv_in.bias"), - ("out.0.weight", "conv_norm_out.weight"), - ("out.0.bias", "conv_norm_out.bias"), - ("out.2.weight", "conv_out.weight"), - ("out.2.bias", "conv_out.bias"), - ] - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0", "norm1"), - ("in_layers.2", "conv1"), - ("out_layers.0", "norm2"), - ("out_layers.3", "conv2"), - ("emb_layers.1", "time_emb_proj"), - ("skip_connection", "conv_shortcut"), - ] - - unet_conversion_map_layer = [] - for i in range(4): - # loop over downblocks/upblocks + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." - unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - if i > 0: - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - # buyer beware: this is a *brittle* function, - # and correct output requires that all of these pieces interact in - # the exact order in which I have arranged them. - mapping = {k: k for k in unet_state_dict.keys()} - for sd_name, hf_name in unet_conversion_map: - mapping[hf_name] = sd_name - for k, v in mapping.items(): - if "resnets" in k: - for sd_part, hf_part in unet_conversion_map_resnet: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, hf_part in unet_conversion_map_layer: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} - - if v2: - conv_transformer_to_linear(new_state_dict) - - return new_state_dict + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict # ================# # VAE Conversion # # ================# + def reshape_weight_for_sd(w): # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) + return w.reshape(*w.shape, 1, 1) def convert_vae_state_dict(vae_state_dict): - vae_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("nin_shortcut", "conv_shortcut"), - ("norm_out", "conv_norm_out"), - ("mid.attn_1.", "mid_block.attentions.0."), - ] - - for i in range(4): - # down_blocks have two resnets - for j in range(2): - hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." - sd_down_prefix = f"encoder.down.{i}.block.{j}." - vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) - - if i < 3: - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." - sd_downsample_prefix = f"down.{i}.downsample." - vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." - vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) - - # up_blocks have three resnets - # also, up blocks in hf are numbered in reverse from sd - for j in range(3): - hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." - vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) - - # this part accounts for mid blocks in both the encoder and the decoder - for i in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." - vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), - ] - - mapping = {k: k for k in vae_state_dict.keys()} - for k, v in mapping.items(): - for sd_part, hf_part in vae_conversion_map: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - if "attentions" in k: - for sd_part, hf_part in vae_conversion_map_attn: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ["q", "k", "v", "proj_out"] - for k, v in new_state_dict.items(): - for weight_name in weights_to_convert: - if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format") - new_state_dict[k] = reshape_weight_for_sd(v) - - return new_state_dict + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] + + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] + + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + + return new_state_dict # endregion # region 自作のモデル読み書きなど + def is_safetensors(path): - return os.path.splitext(path)[1].lower() == '.safetensors' + return os.path.splitext(path)[1].lower() == ".safetensors" def load_checkpoint_with_text_encoder_conversion(ckpt_path, device): - # text encoderの格納形式が違うモデルに対応する ('text_model'がない) - TEXT_ENCODER_KEY_REPLACEMENTS = [ - ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), - ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), - ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') - ] - - if is_safetensors(ckpt_path): - checkpoint = None - state_dict = load_file(ckpt_path) # , device) # may causes error - else: - checkpoint = torch.load(ckpt_path, map_location=device) - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), + ] + + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path) # , device) # may causes error else: - state_dict = checkpoint - checkpoint = None + checkpoint = torch.load(ckpt_path, map_location=device) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None - key_reps = [] - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - for key in state_dict.keys(): - if key.startswith(rep_from): - new_key = rep_to + key[len(rep_from):] - key_reps.append((key, new_key)) + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from) :] + key_reps.append((key, new_key)) - for key, new_key in key_reps: - state_dict[new_key] = state_dict[key] - del state_dict[key] + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] - return checkpoint, state_dict + return checkpoint, state_dict # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(v2) - converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) - - unet = UNet2DConditionModel(**unet_config).to(device) - info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) - - # Convert the VAE model. - vae_config = create_vae_diffusers_config() - converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) - - vae = AutoencoderKL(**vae_config).to(device) - info = vae.load_state_dict(converted_vae_checkpoint) - print("loading vae:", info) - - # convert text_model - if v2: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=23, - num_attention_heads=16, - max_position_embeddings=77, - hidden_act="gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=512, - torch_dtype="float32", - transformers_version="4.25.0.dev0", - ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - else: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None): + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(v2) + converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) + + unet = UNet2DConditionModel(**unet_config).to(device) + info = unet.load_state_dict(converted_unet_checkpoint) + print("loading u-net:", info) + + # Convert the VAE model. + vae_config = create_vae_diffusers_config() + converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) + + vae = AutoencoderKL(**vae_config).to(device) + info = vae.load_state_dict(converted_vae_checkpoint) + print("loading vae:", info) + + # convert text_model + if v2: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=23, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=512, + torch_dtype="float32", + transformers_version="4.25.0.dev0", + ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + else: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) - logging.set_verbosity_error() # don't show annoying warning - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) - logging.set_verbosity_warning() + logging.set_verbosity_error() # don't show annoying warning + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) + logging.set_verbosity_warning() - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + print("loading text encoder:", info) - return text_model, vae, unet + return text_model, vae, unet def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): - def convert_key(key): - # position_idsの除去 - if ".position_ids" in key: - return None - - # common - key = key.replace("text_model.encoder.", "transformer.") - key = key.replace("text_model.", "") - if "layers" in key: - # resblocks conversion - key = key.replace(".layers.", ".resblocks.") - if ".layer_norm" in key: - key = key.replace(".layer_norm", ".ln_") - elif ".mlp." in key: - key = key.replace(".fc1.", ".c_fc.") - key = key.replace(".fc2.", ".c_proj.") - elif '.self_attn.out_proj' in key: - key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") - elif '.self_attn.' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in DiffUsers model: {key}") - elif '.position_embedding' in key: - key = key.replace("embeddings.position_embedding.weight", "positional_embedding") - elif '.token_embedding' in key: - key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") - elif 'final_layer_norm' in key: - key = key.replace("final_layer_norm", "ln_final") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if 'layers' in key and 'q_proj' in key: - # 三つを結合 - key_q = key - key_k = key.replace("q_proj", "k_proj") - key_v = key.replace("q_proj", "v_proj") - - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") - new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") - new_sd[new_key] = value - - # 最後の層などを捏造するか - if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") - keys = list(new_sd.keys()) + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} for key in keys: - if key.startswith("transformer.resblocks.22."): - new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] - # Diffusersに含まれない重みを作っておく - new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) - new_sd['logit_scale'] = torch.tensor(1) + # attnの変換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value - return new_sd + # 最後の層などを捏造するか + if make_dummy_weights: + print("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) + for key in keys: + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + + # Diffusersに含まれない重みを作っておく + new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd["logit_scale"] = torch.tensor(1) + + return new_sd def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): - if ckpt_path is not None: - # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む - checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if checkpoint is None: # safetensors または state_dictのckpt - checkpoint = {} - strict = False + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False + else: + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] else: - strict = True - if "state_dict" in state_dict: - del state_dict["state_dict"] - else: - # 新しく作る - assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" - checkpoint = {} - state_dict = {} - strict = False - - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - assert not strict or key in state_dict, f"Illegal key in save SD: {key}" - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) - update_sd("model.diffusion_model.", unet_state_dict) - - # Convert the text encoder model - if v2: - make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる - text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) - update_sd("cond_stage_model.model.", text_enc_dict) - else: - text_enc_dict = text_encoder.state_dict() - update_sd("cond_stage_model.transformer.", text_enc_dict) - - # Convert the VAE - if vae is not None: - vae_dict = convert_vae_state_dict(vae.state_dict()) - update_sd("first_stage_model.", vae_dict) - - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {'state_dict': state_dict} - - # epoch and global_step are sometimes not int - try: - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] - except: - pass - - new_ckpt['epoch'] = epochs - new_ckpt['global_step'] = steps - - if is_safetensors(output_file): - # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file) - else: - torch.save(new_ckpt, output_file) - - return key_count + # 新しく作る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) + + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) + update_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + update_sd("cond_stage_model.transformer.", text_enc_dict) + # Convert the VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) -def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): - if pretrained_model_name_or_path is None: - # load default settings for v1/v2 - if v2: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + # epoch and global_step are sometimes not int + try: + if "epoch" in checkpoint: + epochs += checkpoint["epoch"] + if "global_step" in checkpoint: + steps += checkpoint["global_step"] + except: + pass + + new_ckpt["epoch"] = epochs + new_ckpt["global_step"] = steps + + if is_safetensors(output_file): + # TODO Tensor以外のdictの値を削除したほうがいいか + save_file(state_dict, output_file) else: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 - - scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") - if vae is None: - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") - - pipeline = StableDiffusionPipeline( - unet=unet, - text_encoder=text_encoder, - vae=vae, - scheduler=scheduler, - tokenizer=tokenizer, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=None, - ) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) + torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): + if pretrained_model_name_or_path is None: + # load default settings for v1/v2 + if v2: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 + else: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 + + scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) VAE_PREFIX = "first_stage_model." def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") - if os.path.isdir(vae_id) or not os.path.isfile(vae_id): - # Diffusers local/remote - try: - vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) - except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") - vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + print(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + print(f"exception occurs in loading vae: {e}") + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + return vae + + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu") + vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) return vae - # local - vae_config = create_vae_diffusers_config() - - if vae_id.endswith(".bin"): - # SD 1.5 VAE on Huggingface - converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") - else: - # StableDiffusion - vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) - else torch.load(vae_id, map_location="cpu")) - vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model - - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd - - # Convert the VAE model. - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae # endregion def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): - max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) - - resos = set() - - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) - - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) - - # # make additional resos - # if width >= height and width - divisible >= min_size: - # resos.add((width - divisible, height)) - # resos.add((height, width - divisible)) - # if height >= width and height - divisible >= min_size: - # resos.add((width, height - divisible)) - # resos.add((height - divisible, width)) - - size += divisible - - resos = list(resos) - resos.sort() - return resos - - -if __name__ == '__main__': - resos = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) - aspect_ratios = [w / h for w, h in resos] - print(aspect_ratios) - - ars = set() - for ar in aspect_ratios: - if ar in ars: - print("error! duplicate ar:", ar) - ars.add(ar) + max_width, max_height = max_reso + max_area = (max_width // divisible) * (max_height // divisible) + + resos = set() + + size = int(math.sqrt(max_area)) * divisible + resos.add((size, size)) + + size = min_size + while size <= max_size: + width = size + height = min(max_size, (max_area // (width // divisible)) * divisible) + resos.add((width, height)) + resos.add((height, width)) + + # # make additional resos + # if width >= height and width - divisible >= min_size: + # resos.add((width - divisible, height)) + # resos.add((height, width - divisible)) + # if height >= width and height - divisible >= min_size: + # resos.add((width, height - divisible)) + # resos.add((height - divisible, width)) + + size += divisible + + resos = list(resos) + resos.sort() + return resos + + +if __name__ == "__main__": + resos = make_bucket_resolutions((512, 768)) + print(len(resos)) + print(resos) + aspect_ratios = [w / h for w, h in resos] + print(aspect_ratios) + + ars = set() + for ar in aspect_ratios: + if ar in ars: + print("error! duplicate ar:", ar) + ars.add(ar) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 8d97392fd..101747856 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -1,4 +1,3 @@ - import math import argparse import os @@ -9,216 +8,234 @@ def load_state_dict(file_name, dtype): - if os.path.splitext(file_name)[1] == '.safetensors': - sd = load_file(file_name) - else: - sd = torch.load(file_name, map_location='cpu') - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) - return sd + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd def save_to_file(file_name, model, state_dict, dtype): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) - if os.path.splitext(file_name)[1] == '.safetensors': - save_file(model, file_name) - else: - torch.save(model, file_name) + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(model, file_name) + else: + torch.save(model, file_name) def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): - text_encoder.to(merge_dtype) - unet.to(merge_dtype) - - # create module map - name_to_module = {} - for i, root_module in enumerate([text_encoder, unet]): - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE - - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + '.' + name + '.' + child_name - lora_name = lora_name.replace('.', '_') - name_to_module[lora_name] = child_module - - for model, ratio in zip(models, ratios): - print(f"loading: {model}") - lora_sd = load_state_dict(model, merge_dtype) - - print(f"merging...") - for key in lora_sd.keys(): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[:key.index("lora_down")] + 'alpha' - - # find original module for this lora - module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" - if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") - continue - module = name_to_module[module_name] - # print(f"apply {key} to {module}") - - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - # W <- W + U * D - weight = module.weight - # print(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) - ).unsqueeze(2).unsqueeze(3) * scale + text_encoder.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder, unet]): + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale - - module.weight = torch.nn.Parameter(weight) + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + print(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # print(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + # print(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) def merge_lora_models(models, ratios, merge_dtype): - base_alphas = {} # alpha for merged model - base_dims = {} - - merged_sd = {} - for model, ratio in zip(models, ratios): - print(f"loading: {model}") - lora_sd = load_state_dict(model, merge_dtype) - - # get alpha and dim - alphas = {} # alpha for current model - dims = {} # dims for current model - for key in lora_sd.keys(): - if 'alpha' in key: - lora_module_name = key[:key.rfind(".alpha")] - alpha = float(lora_sd[key].detach().numpy()) - alphas[lora_module_name] = alpha - if lora_module_name not in base_alphas: - base_alphas[lora_module_name] = alpha - elif "lora_down" in key: - lora_module_name = key[:key.rfind(".lora_down")] - dim = lora_sd[key].size()[0] - dims[lora_module_name] = dim - if lora_module_name not in base_dims: - base_dims[lora_module_name] = dim - - for lora_module_name in dims.keys(): - if lora_module_name not in alphas: - alpha = dims[lora_module_name] - alphas[lora_module_name] = alpha - if lora_module_name not in base_alphas: - base_alphas[lora_module_name] = alpha - - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") - - # merge - print(f"merging...") - for key in lora_sd.keys(): - if 'alpha' in key: - continue - - lora_module_name = key[:key.rfind(".lora_")] - - base_alpha = base_alphas[lora_module_name] - alpha = alphas[lora_module_name] - - scale = math.sqrt(alpha / base_alpha) * ratio - - if key in merged_sd: - assert merged_sd[key].size() == lora_sd[key].size( - ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - merged_sd[key] = merged_sd[key] + lora_sd[key] * scale - else: - merged_sd[key] = lora_sd[key] * scale - - # set alpha to sd - for lora_module_name, alpha in base_alphas.items(): - key = lora_module_name + ".alpha" - merged_sd[key] = torch.tensor(alpha) - - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") - - return merged_sd + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + print(f"merging...") + for key in lora_sd.keys(): + if "alpha" in key: + continue + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + + print("merged model") + print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + return merged_sd def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None - merge_dtype = str_to_dtype(args.precision) - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype - if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + if args.sd_model is not None: + print(f"loading SD model: {args.sd_model}") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) - merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - print(f"saving SD model to: {args.save_to}") - model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, - args.sd_model, 0, 0, save_dtype, vae) - else: - state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) + print(f"saving SD model to: {args.save_to}") + model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) + else: + state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) - print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype) + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") - parser.add_argument("--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") - parser.add_argument("--sd_model", type=str, default=None, - help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--models", type=str, nargs='*', - help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--ratios", type=float, nargs='*', - help="ratios for each model / それぞれのLoRAモデルの比率") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - merge(args) + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--sd_model", + type=str, + default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", + ) + parser.add_argument( + "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + ) + parser.add_argument( + "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) From 4627b389ff07efea282ec47a58411ab78d920195 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 1 Apr 2023 09:15:57 +0900 Subject: [PATCH 2/4] fix device not specified in merge_lora.py --- library/model_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/model_util.py b/library/model_util.py index b58330669..35b0b6afe 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -812,7 +812,7 @@ def is_safetensors(path): return os.path.splitext(path)[1].lower() == ".safetensors" -def load_checkpoint_with_text_encoder_conversion(ckpt_path, device): +def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): # text encoderの格納形式が違うモデルに対応する ('text_model'がない) TEXT_ENCODER_KEY_REPLACEMENTS = [ ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), From 770a56193eeefce7d234fe897e16db7b778128b3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 1 Apr 2023 09:17:37 +0900 Subject: [PATCH 3/4] fix conv2d3x3 is not merged --- networks/merge_lora.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 101747856..2fa8861bc 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -42,7 +42,9 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE else: prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + target_replace_modules = ( + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: From 18d69d8e5e0fd515598041bb44cd01ee626b808f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 1 Apr 2023 09:21:27 +0900 Subject: [PATCH 4/4] update readme --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 168482fd5..76695270c 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +- 1 Apr. 2023, 2023/4/1: + - Fix an issue that `merge_lora.py` does not work with the latest version. + - Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights. + - 最新のバージョンで`merge_lora.py` が動作しない不具合を修正しました。 + - `merge_lora.py` で `no module found for LoRA weight: ...` と表示され Conv2d3x3 拡張の重みがマージされない不具合を修正しました。 - 31 Mar. 2023, 2023/3/31: - Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`. - Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)