from typing import Optional, Union

import torch
from audiocraft.models import AudioGen, MusicGen

from .util import do_cleanup, object_to, obj_on_device, stack_audio_tensors, tensors_to, tensors_to_cpu

MODEL_NAMES = [
    "musicgen-small",
    "musicgen-medium",
    "musicgen-melody",
    "musicgen-large",
    "musicgen-melody-large",
    # TODO: stereo models seem not to be working out of the box
    # "musicgen-stereo-small",
    # "musicgen-stereo-medium",
    # "musicgen-stereo-melody",
    # "musicgen-stereo-large",
    # "musicgen-stereo-melody-large",
    "audiogen-medium",
]


class MusicgenLoader:
    """Load a pretrained MusicGen or AudioGen checkpoint from the facebook/ hub namespace and return (model, sample_rate)."""

    def __init__(self):
        self.model = None
        self.name = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model_name": (MODEL_NAMES,)}}

    RETURN_NAMES = ("MODEL", "SR")
    RETURN_TYPES = ("MUSICGEN_MODEL", "INT")
    FUNCTION = "load"
    CATEGORY = "audio"

    def load(self, model_name: str):
        self.unload()
        print(f"MusicgenLoader: loading {model_name}")

        self.name = "facebook/" + model_name
        model_class = AudioGen if "audiogen" in self.name else MusicGen

        self.model = model_class.get_pretrained(self.name)
        sr = self.model.sample_rate
        return self.model, sr

    def unload(self):
        if self.model is not None:
            # force move to cpu, delete/collect, clear cache
            self.model = object_to(self.model, empty_cuda_cache=False)
            del self.model
            # reset attributes so a repeated unload is a no-op instead of an AttributeError
            self.model = None
            self.name = None
            do_cleanup()
            print("MusicgenLoader: unloaded model")


class MusicgenGenerate:
    """Generate audio with a loaded MusicGen/AudioGen model: text-to-audio, continuation of an input clip, or unconditional."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MUSICGEN_MODEL",),
                "text": ("STRING", {"default": "", "multiline": True}),
                "batch_size": ("INT", {"default": 1, "min": 1}),
                "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 300.0, "step": 0.01}),
                "cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "top_k": ("INT", {"default": 250, "min": 0, "max": 10000, "step": 1}),
                "top_p": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "temperature": ("FLOAT", {"default": 1.0, "min": 0.001, "step": 0.001}),
                "seed": ("INT", {"default": 0, "min": 0}),
            },
            "optional": {"audio": ("AUDIO_TENSOR",)},
        }

    RETURN_NAMES = ("RAW_AUDIO",)
    RETURN_TYPES = ("AUDIO_TENSOR",)
    FUNCTION = "generate"
    CATEGORY = "audio"

    def generate(
        self,
        model: Union[AudioGen, MusicGen],
        text: str = "",
        batch_size: int = 1,
        duration: float = 10.0,
        cfg: float = 1.0,
        top_k: int = 250,
        top_p: float = 0.0,
        temperature: float = 1.0,
        seed: int = 0,
        audio: Optional[torch.Tensor] = None,
    ):
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # an empty prompt means no text conditioning
        if text == "":
            text = None

        model.set_generation_params(
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            duration=duration,
            cfg_coef=cfg,
        )
        with torch.random.fork_rng(), obj_on_device(model, dst=device, verbose_move=True) as m:
            torch.manual_seed(seed)
            text_input = [text] * batch_size

            if audio is not None:
                # do continuation with input audio and (optional) text prompting
                if isinstance(audio, list):
                    # left-padded stacking into batch tensor
                    audio = stack_audio_tensors(audio)

                if audio.shape[0] < batch_size:
                    # (try to) expand batch if smaller than requested
                    audio = audio.expand(batch_size, -1, -1)
                elif audio.shape[0] > batch_size:
                    # truncate batch if larger than requested
                    audio = audio[:batch_size]

                audio_input = tensors_to(audio, device)
                audio_out = m.generate_continuation(audio_input, model.sample_rate, text_input, progress=True)
            elif text is not None:
                # do text-to-music
                audio_out = m.generate(text_input, progress=True)
            else:
                # do unconditional music generation
                audio_out = m.generate_unconditional(batch_size, progress=True)

        audio_out = tensors_to_cpu(audio_out)
        audio_out = torch.unbind(audio_out)
        do_cleanup()
        return list(audio_out),


NODE_CLASS_MAPPINGS = {
    "MusicgenGenerate": MusicgenGenerate,
    "MusicgenLoader": MusicgenLoader,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "MusicgenGenerate": "Musicgen Generator",
    "MusicgenLoader": "Musicgen Loader",
}