Skip to content

Commit

Permalink
Some improvements to the interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Kreitzer committed Jan 16, 2024
1 parent c30cbe0 commit e5d1545
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 18 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ torch = "^2.1.2"
realesrgan = {git = "https://github.com/sberbank-ai/Real-ESRGAN.git"}
torchvision = "^0.16.2"
torchaudio = "^2.1.2"
tqdm = "^4.66.1"

[tool.poetry.scripts]
upscale = "video_upscaler.main:main"
Expand Down
62 changes: 45 additions & 17 deletions src/video_upscaler/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,55 +6,83 @@
from RealESRGAN import RealESRGAN
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm

def upscale_image(image: Image, device: torch.device, weights: Path, scale: int = 4) -> Image:
"""Upscale a single image."""
upscaler = get_upscaler(device, scale)
upscaler.load_weights(weights, download=True)
return upscaler.predict(image)

def upscale_frame(frame: av.VideoFrame, upscaler: RealESRGAN) -> av.VideoFrame:
"""Upscale a single frame."""
img = frame.to_image().convert('RGB')
sr_img = upscaler.predict(img)
return av.VideoFrame.from_image(sr_img)

def upscale_video(input_path: Path, output_path: Path, scale: int = 4):
"""Upscale video frames and reassemble the video."""
if torch.cuda.is_available():
device = torch.device('cuda')
print("Using CUDA.")
elif torch.backends.mps.is_available():
device = torch.device('mps')
print("Using MPS.")
else:
device = torch.device('cpu')
print("Using CPU.")
def get_upscaler(device: torch.device, model_weights_path: Path, scale: int = 4) -> RealESRGAN:
"""Get the upscaler model."""
upscaler = RealESRGAN(device, scale)
upscaler.load_weights('weights/RealESRGAN_x4.pth', download=True)
upscaler.load_weights(str(model_weights_path), download=True)
return upscaler

def upscale_video(input_path: Path, upscaler: RealESRGAN, scale: int = 4) -> None:
"""Upscale video frames and reassemble the video."""

input_container = av.open(str(input_path))
output_container = av.open(str(output_path), 'w')
output_container = av.open(f"{input_path.stem}_HD.mp4", 'w')

stream = input_container.streams.video[0]
output_stream = output_container.add_stream('mpeg4', rate=stream.average_rate)
output_stream.width = stream.width * scale
output_stream.height = stream.height * scale
output_stream.pix_fmt = 'yuv420p'

total_frames = input_container.streams.video[0].frames

print("Upscaling video...")
progress_bar = tqdm(total=total_frames, unit='frames', desc='Upscaling')

for frame in input_container.decode(stream):
sr_frame = upscale_frame(frame, upscaler)
packet = output_stream.encode(sr_frame)
output_container.mux(packet)
progress_bar.update()

progress_bar.close()

# Flush and close the containers
print("Flush and close the containers")
progress_bar = tqdm(total=total_frames, unit='frames', desc='Encoding/Muxing')
for packet in output_stream.encode():
output_container.mux(packet)
progress_bar.update()

progress_bar.close()

input_container.close()
output_container.close()

def get_device() -> torch.device:
if torch.cuda.is_available():
print("Using CUDA.")
return torch.device('cuda')
elif torch.backends.mps.is_available():
print("Using MPS.")
return torch.device('mps')
else:
print("Using CPU.")
return torch.device('cpu')

def main():
parser = ArgumentParser(description="Upscale video files.")
parser.add_argument('input_file', type=Path, help='Path to the input video file.')
parser.add_argument('output_file', type=Path, help='Path for the output upscaled video file.')
parser.add_argument('--scale', type=int, default=4, help='Upscaling factor.')
parser.add_argument('--model_weights', type=Path, default=Path('weights/RealESRGAN_x4.pth'), help='Path to the model weights file.')
parser.add_argument('--download-weights', action='store_true', default=True, help='Download the model weights if they are not found locally.')
args = parser.parse_args()

upscale_video(args.input_file, args.output_file)
device = get_device()
upscaler = get_upscaler(device, args.model_weights, args.scale)
upscale_video(args.input_file, upscaler, args.scale)

if __name__ == "__main__":
main()
Expand Down
Binary file added tests/data/DSC_0141.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions tests/test_image_upscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import pytest
import video_upscaler

0 comments on commit e5d1545

Please sign in to comment.