-
Notifications
You must be signed in to change notification settings - Fork 921
/
svc_merge.py
58 lines (45 loc) · 1.69 KB
/
svc_merge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import torch
import argparse
import collections
def load_model(checkpoint_path):
    """Load the ``"model_g"`` entry from a checkpoint file.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` whose
            top-level object can be indexed with the key ``"model_g"``.

    Returns:
        The object stored under ``"model_g"`` in the checkpoint.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing regular
            file.
        KeyError: If the loaded checkpoint has no ``"model_g"`` entry.
    """
    # Validate explicitly rather than with ``assert``: assertions are removed
    # when Python runs with -O, which would silently skip this check.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"checkpoint file not found: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file contents; only load
    # checkpoints that come from a trusted source.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    return checkpoint_dict["model_g"]
def save_model(state_dict, checkpoint_path):
    """Write ``state_dict`` to ``checkpoint_path`` under the ``'model_g'`` key.

    Args:
        state_dict: Object to store as the checkpoint's ``'model_g'`` entry.
        checkpoint_path: Destination passed to ``torch.save``.
    """
    # Wrap the weights so the file has the same layout load_model reads.
    checkpoint = {'model_g': state_dict}
    torch.save(checkpoint, checkpoint_path)
def average_model(model_list):
    """Average several state dicts entry by entry.

    Args:
        model_list: Non-empty sequence of mappings. The keys of the first
            mapping define the result; every other mapping must provide the
            same keys with values that can be added together.

    Returns:
        A ``collections.OrderedDict`` with the keys of the first mapping (in
        its order), each holding the sum of that entry over all models
        divided by the number of models via ``torch.div``.

    Raises:
        ValueError: If ``model_list`` is empty.
        KeyError: If a later mapping lacks a key of the first one.
    """
    # Fail with a clear message instead of an IndexError from model_list[0].
    if not model_list:
        raise ValueError("model_list must contain at least one model")

    model_count = float(len(model_list))
    model_average = collections.OrderedDict()
    for key in model_list[0]:
        # sum() starts from 0 and adds left to right, so the accumulation
        # order matches a manual running total and no input is modified.
        key_sum = sum(model[key] for model in model_list)
        model_average[key] = torch.div(key_sum, model_count)
    return model_average
# Example usage of average_model (s1 and s2 are state dicts from load_model):
# ss_list = []
# ss_list.append(s1)
# ss_list.append(s2)
# ss_merge = average_model(ss_list)
def merge_model(model1, model2, rate):
    """Blend two state dicts as ``rate * model1 + (1 - rate) * model2``.

    Args:
        model1: Mapping whose keys (and their order) define the result.
        model2: Mapping providing a value for every key of ``model1``.
        rate: Weight applied to ``model1``; ``model2`` gets ``1 - rate``.

    Returns:
        A ``collections.OrderedDict`` holding the weighted sum for each key
        of ``model1``.
    """
    return collections.OrderedDict(
        (name, rate * model1[name] + (1 - rate) * model2[name])
        for name in model1
    )
if __name__ == '__main__':
    # Command-line entry point: load two checkpoints, blend their "model_g"
    # weights with merge_model, and write the result to a new checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument('-m1', '--model1', type=str, required=True,
                        help="path to the first checkpoint")
    parser.add_argument('-m2', '--model2', type=str, required=True,
                        help="path to the second checkpoint")
    parser.add_argument('-r1', '--rate', type=float, required=True,
                        help="weight of model1, strictly between 0 and 1; "
                             "model2 is weighted by 1 - rate")
    # Optional; the default keeps the previously hard-coded output file name.
    parser.add_argument('-o', '--output', type=str,
                        default="sovits5.0_merge.pth",
                        help="path of the merged checkpoint to write")
    args = parser.parse_args()

    print(args.model1)
    print(args.model2)
    print(args.rate)

    # Report a bad rate as a usage error rather than with ``assert``, which
    # is stripped under ``python -O`` and would let invalid rates through.
    if not 0 < args.rate < 1:
        parser.error(f"{args.rate} should be in range (0, 1)")

    s1 = load_model(args.model1)
    s2 = load_model(args.model2)
    merge = merge_model(s1, s2, args.rate)
    save_model(merge, args.output)