diff --git a/poetry.lock b/poetry.lock index 16815cc..ca2b26a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1134,4 +1134,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "fc966464d3ffb01227a2e78352820e32055828c4b685e01e199bb663b4f0a1ba" +content-hash = "9506eb71838bf8c0d0acfe2bda9c32ee7706ef504e1c6f020537754840f04177" diff --git a/pyproject.toml b/pyproject.toml index 9ecdef2..d162212 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ torch = "^2.1.2" realesrgan = {git = "https://github.com/sberbank-ai/Real-ESRGAN.git"} torchvision = "^0.16.2" torchaudio = "^2.1.2" +tqdm = "^4.66.1" [tool.poetry.scripts] upscale = "video_upscaler.main:main" diff --git a/src/video_upscaler/main.py b/src/video_upscaler/main.py index 89895d9..7e62b5a 100644 --- a/src/video_upscaler/main.py +++ b/src/video_upscaler/main.py @@ -6,6 +6,13 @@ from RealESRGAN import RealESRGAN from argparse import ArgumentParser from pathlib import Path +from tqdm import tqdm + +def upscale_image(image: Image, device: torch.device, weights: Path, scale: int = 4) -> Image: + """Upscale a single image.""" + upscaler = get_upscaler(device, scale) + upscaler.load_weights(weights, download=True) + return upscaler.predict(image) def upscale_frame(frame: av.VideoFrame, upscaler: RealESRGAN) -> av.VideoFrame: """Upscale a single frame.""" @@ -13,22 +20,17 @@ def upscale_frame(frame: av.VideoFrame, upscaler: RealESRGAN) -> av.VideoFrame: sr_img = upscaler.predict(img) return av.VideoFrame.from_image(sr_img) -def upscale_video(input_path: Path, output_path: Path, scale: int = 4): - """Upscale video frames and reassemble the video.""" - if torch.cuda.is_available(): - device = torch.device('cuda') - print("Using CUDA.") - elif torch.backends.mps.is_available(): - device = torch.device('mps') - print("Using MPS.") - else: - device = torch.device('cpu') - print("Using CPU.") +def get_upscaler(device: torch.device, model_weights_path: Path, scale: int = 4) -> RealESRGAN: + """Get the upscaler model.""" upscaler = RealESRGAN(device, scale) - upscaler.load_weights('weights/RealESRGAN_x4.pth', download=True) + upscaler.load_weights(str(model_weights_path), download=True) + return upscaler + +def upscale_video(input_path: Path, upscaler: RealESRGAN, scale: int = 4) -> None: + """Upscale video frames and reassemble the video.""" input_container = av.open(str(input_path)) - output_container = av.open(str(output_path), 'w') + output_container = av.open(f"{input_path.stem}_HD.mp4", 'w') stream = input_container.streams.video[0] output_stream = output_container.add_stream('mpeg4', rate=stream.average_rate) @@ -36,25 +38,51 @@ def upscale_video(input_path: Path, output_path: Path, scale: int = 4): output_stream.height = stream.height * scale output_stream.pix_fmt = 'yuv420p' + total_frames = input_container.streams.video[0].frames + + print("Upscaling video...") + progress_bar = tqdm(total=total_frames, unit='frames', desc='Upscaling') + for frame in input_container.decode(stream): sr_frame = upscale_frame(frame, upscaler) packet = output_stream.encode(sr_frame) output_container.mux(packet) + progress_bar.update() + + progress_bar.close() - # Flush and close the containers + print("Flush and close the containers") + progress_bar = tqdm(total=total_frames, unit='frames', desc='Encoding/Muxing') for packet in output_stream.encode(): output_container.mux(packet) + progress_bar.update() + + progress_bar.close() input_container.close() output_container.close() +def get_device() -> torch.device: + if torch.cuda.is_available(): + print("Using CUDA.") + return torch.device('cuda') + elif torch.backends.mps.is_available(): + print("Using MPS.") + return torch.device('mps') + else: + print("Using CPU.") + return torch.device('cpu') + def main(): parser = ArgumentParser(description="Upscale video files.") parser.add_argument('input_file', type=Path, help='Path to the input video file.') - parser.add_argument('output_file', type=Path, help='Path for the output upscaled video file.') + parser.add_argument('--scale', type=int, default=4, help='Upscaling factor.') + parser.add_argument('--model_weights', type=Path, default=Path('weights/RealESRGAN_x4.pth'), help='Path to the model weights file.') + parser.add_argument('--download-weights', action='store_true', default=True, help='Download the model weights if they are not found locally.') args = parser.parse_args() - - upscale_video(args.input_file, args.output_file) + device = get_device() + upscaler = get_upscaler(device, args.model_weights, args.scale) + upscale_video(args.input_file, upscaler, args.scale) if __name__ == "__main__": main() diff --git a/tests/data/DSC_0141.jpeg b/tests/data/DSC_0141.jpeg new file mode 100644 index 0000000..8dfdac4 Binary files /dev/null and b/tests/data/DSC_0141.jpeg differ diff --git a/tests/test_image_upscale.py b/tests/test_image_upscale.py new file mode 100644 index 0000000..1f93174 --- /dev/null +++ b/tests/test_image_upscale.py @@ -0,0 +1,3 @@ +import pytest +import video_upscaler +