Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lightning App demo #2

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This repository provides code for the Waveformer architecture proposed in the paper. Waveformer is a low-latency target sound extraction model implementing streaming inference -- the model processes a ~10 ms input audio chunk at each time step, while only looking at past chunks and no future chunks. On a Core i5 CPU using a single thread, real-time factors (RTFs) of different model configurations range from 0.66 to 0.94, with an end-to-end latency less than 20 ms.

[![Gradio demo](https://img.shields.io/badge/arxiv-abs-green)](https://arxiv.org/abs/2211.02250) [![Gradio demo](https://img.shields.io/badge/arxiv-pdf-green)](https://arxiv.org/pdf/2211.02250) [![Gradio demo](https://img.shields.io/badge/Gradio-app-blue)](https://huggingface.co/spaces/uwx/waveformer)
[![App Gallery](https://bit.ly/3xTcccO)](https://01ghh2pnbdet9ex9sdqqsnpxwh.litng-ai-03.litng.ai/view)

<video src="https://user-images.githubusercontent.com/16723254/199796287-e6aa464d-7da4-4941-b356-0668d96d9184.mp4"></video>

Expand Down
80 changes: 80 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import json
import os

import gradio as gr
import lightning as L
import torch
import torchaudio
import wget
from lightning.app.components.serve import ServeGradio

from Waveformer import TARGETS
from Waveformer import Waveformer as WaveformerModel


class ModelDemo(ServeGradio):
    """Gradio demo that serves the Waveformer target-sound-extraction model.

    The UI takes an audio clip plus a set of target labels (from ``TARGETS``)
    and returns the audio produced by the model for those labels.
    """

    inputs = [
        gr.Audio(label="Input audio"),
        gr.CheckboxGroup(choices=TARGETS, label="Extract target sound"),
    ]
    outputs = gr.Audio(label="Output audio")
    examples = [["data/Sample.wav"]]
    enable_queue: bool = False

    # Sample rate the audio is converted to before it is fed to the model.
    _MODEL_SAMPLE_RATE = 44100
    # Scale between int16 PCM samples and the [-1, 1) floats given to the model.
    _INT16_SCALE = 2.0**15

    def __init__(self, *args, **kwargs):
        """Create the demo work, defaulting to a ``cpu-medium`` cloud machine.

        Positional and keyword arguments are forwarded to ``ServeGradio``.
        A caller-supplied ``cloud_compute`` takes precedence over the default.
        """
        # setdefault avoids a "multiple values for keyword argument" TypeError
        # when the caller passes its own cloud_compute.
        kwargs.setdefault("cloud_compute", L.CloudCompute("cpu-medium"))
        super().__init__(*args, **kwargs)
        self._device = None

    def build_model(self):
        """Download the config/checkpoint if missing and return the model.

        Files are looked up in (and downloaded to) the current working
        directory. The model is moved to CUDA when available, otherwise CPU,
        and put in eval mode. The chosen device is stored in ``self._device``.
        """
        if not os.path.exists("default_config.json"):
            config_url = (
                "https://targetsound.cs.washington.edu/files/default_config.json"
            )
            print("Downloading model configuration from %s:" % config_url)
            wget.download(config_url)

        if not os.path.exists("default_ckpt.pt"):
            ckpt_url = "https://targetsound.cs.washington.edu/files/default_ckpt.pt"
            print("\nDownloading the checkpoint from %s:" % ckpt_url)
            wget.download(ckpt_url)

        # Instantiate model
        with open("default_config.json", encoding="utf-8") as f:
            params = json.load(f)
        model = WaveformerModel(**params["model_params"])

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._device = device
        print(f"loading model on {device}")

        # NOTE(security): torch.load unpickles the downloaded checkpoint; only
        # use checkpoints from a trusted source.
        checkpoint = torch.load("default_ckpt.pt", map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        return model.to(device).eval()

    @torch.inference_mode()
    def predict(self, audio, label_choices):
        """Run the model on ``audio`` for the selected target labels.

        Args:
            audio: ``(sample_rate, samples)`` tuple as delivered by the
                ``gr.Audio`` input. Samples are assumed to be int16-scaled
                PCM (they are divided by 2**15) -- TODO confirm for non-int16
                uploads.
            label_choices: Iterable of label strings, each present in
                ``TARGETS``. ``None`` is treated as "no label selected".

        Returns:
            ``(sample_rate, samples)`` where ``samples`` is an int16 NumPy
            array and ``sample_rate`` is the rate of the returned audio
            (always 44100, since the input is resampled to that rate).
        """
        fs, mixture = audio
        mixture = torch.as_tensor(mixture, dtype=torch.float32)

        # Gradio documents multi-channel audio as (samples, channels); the
        # code below builds a single-channel (1, 1, T) input, so downmix
        # first. This also keeps the time axis last, which resample requires.
        if mixture.ndim == 2:
            mixture = mixture.mean(dim=1)

        if fs != self._MODEL_SAMPLE_RATE:
            mixture = torchaudio.functional.resample(
                mixture, orig_freq=fs, new_freq=self._MODEL_SAMPLE_RATE
            )

        mixture = mixture.unsqueeze(0).unsqueeze(0) / self._INT16_SCALE

        # Construct the query vector (multi-hot over TARGETS)
        query = torch.zeros(1, len(TARGETS), device=self._device)
        for label in label_choices or ():
            query[0, TARGETS.index(label)] = 1.0

        output = self._INT16_SCALE * self.model(mixture.to(self._device), query)

        # Clamp before the int16 cast so loud samples saturate instead of
        # wrapping around.
        pcm = output.squeeze(0).squeeze(0).clamp(
            -self._INT16_SCALE, self._INT16_SCALE - 1
        )

        # The audio was resampled above, so report the model rate rather than
        # the original input rate; otherwise playback speed/pitch is wrong.
        return self._MODEL_SAMPLE_RATE, pcm.to(torch.short).cpu().numpy()


# Module-level Lightning app wrapping a single ModelDemo instance.
# NOTE(review): presumably the object picked up by `lightning run app app.py` -- confirm.
app = L.LightningApp(ModelDemo())
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ seaborn
ipykernel
scaper
wget
gradio