# Waveformer.py: command-line script that extracts selected target sounds from an audio file.
import argparse
import os
import torch
import torchaudio
import wget
from src.helpers import utils
from src.training.dcc_tf import Net as Waveformer
# Names of the sound classes the model can extract, one per whitespace-separated
# token. Order matters: a label's position in this list is the index set in the
# query vector passed to the model, so never reorder or insert in the middle.
TARGETS = """
    Acoustic_guitar Applause Bark Bass_drum Burping_or_eructation Bus
    Cello Chime Clarinet Computer_keyboard Cough Cowbell
    Double_bass Drawer_open_or_close Electric_piano Fart Finger_snapping
    Fireworks Flute Glockenspiel Gong Gunshot_or_gunfire Harmonica
    Hi-hat Keys_jangling Knock Laughter Meow Microwave_oven
    Oboe Saxophone Scissors Shatter Snare_drum Squeak
    Tambourine Tearing Telephone Trumpet Violin_or_fiddle Writing
""".split()
def _download_if_missing(path, url, message):
    """Download ``url`` into ``path`` with ``wget`` unless ``path`` already exists.

    Args:
        path: Local file name that is checked for and written to.
        url: Remote location of the file.
        message: %-format string announcing the download; receives ``url``.
    """
    if os.path.exists(path):
        return
    print(message % url)
    # Write to exactly the name checked above instead of a name derived
    # from the URL, so the existence check and the download always agree.
    wget.download(url, out=path)
    # wget's progress bar does not end its line; terminate it so later
    # output starts on a fresh line.
    print()


def _build_query(targets):
    """Return a (1, len(TARGETS)) float tensor with 1.0 at each selected target.

    An empty ``targets`` selects every class. Every name must be a member of
    TARGETS (validated by ``main`` before this is called).
    """
    if not targets:
        return torch.ones(1, len(TARGETS))
    query = torch.zeros(1, len(TARGETS))
    for t in targets:
        query[0, TARGETS.index(t)] = 1.0
    return query


def main():
    """Command-line entry point: extract the requested target sounds from an audio file.

    Downloads the default config and checkpoint if they are missing, runs the
    model on the input file, and writes the result to the output path. Exits
    with a usage error if a target name is unknown or the output already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Path to the input audio file.")
    parser.add_argument(
        "output",
        type=str,
        help="Path to the output audio file (output is written in the .wav format).",
    )
    parser.add_argument(
        "--targets",
        nargs="+",
        type=str,
        default=[],
        help="Targets to output. Pick a subset of: %s" % TARGETS,
    )
    args = parser.parse_args()

    # Validate user input before any download or model work, so mistakes fail
    # fast with a usage message rather than a traceback from TARGETS.index.
    unknown = [t for t in args.targets if t not in TARGETS]
    if unknown:
        parser.error("Unknown target(s) %s. Pick a subset of: %s" % (unknown, TARGETS))
    # Checked up front (not after inference) and not with `assert`, which
    # `python -O` strips and which would then let an existing file be clobbered.
    if os.path.exists(args.output):
        parser.error("Output file already exists.")

    _download_if_missing(
        "default_config.json",
        "https://targetsound.cs.washington.edu/files/default_config.json",
        "Downloading model configuration from %s:",
    )
    _download_if_missing(
        "default_ckpt.pt",
        "https://targetsound.cs.washington.edu/files/default_ckpt.pt",
        "Downloading the checkpoint from %s:",
    )

    # Instantiate model
    params = utils.Params("default_config.json")
    model = Waveformer(**params.model_params)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. The checkpoint is fetched over the network above, so only use a
    # trusted source (newer torch versions also accept `weights_only=True`).
    model.load_state_dict(
        torch.load("default_ckpt.pt", map_location=device)["model_state_dict"]
    )
    model.to(device).eval()

    # Inference runs at this rate; other input rates are resampled in and back out.
    model_fs = 44100

    # Read input audio
    mixture, fs = torchaudio.load(args.input)
    if fs != model_fs:
        mixture = torchaudio.functional.resample(mixture, orig_freq=fs, new_freq=model_fs)
    # Prepend a size-1 dimension (presumably the batch axis; squeezed off below).
    mixture = mixture.unsqueeze(0).to(device)
    print("Loaded input audio from %s" % args.input)

    query = _build_query(args.targets)
    with torch.inference_mode():
        output = model(mixture, query.to(device)).squeeze(0).cpu()
    if fs != model_fs:
        output = torchaudio.functional.resample(output, orig_freq=model_fs, new_freq=fs)

    print("Inference done. Saving output audio to %s" % args.output)
    torchaudio.save(args.output, output, fs)


if __name__ == "__main__":
    main()