add real-world inference
LorenzoAgnolucci authored Nov 27, 2023
1 parent f07985b commit 96cd431
Showing 12 changed files with 1,741 additions and 0 deletions.
49 changes: 49 additions & 0 deletions README.md
@@ -34,6 +34,54 @@ Overview of the proposed approach. *Left*: given a video, we identify the cleanest
}
```

## Installation
We recommend using the [**Anaconda**](https://www.anaconda.com/) package manager to avoid dependency/reproducibility
problems.
For Linux systems, you can find a conda installation
guide [here](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html).

1. Clone the repository

```sh
git clone https://github.com/miccunifi/TAPE
```

2. Install Python dependencies

```sh
conda create -n TAPE -y python=3.10
conda activate TAPE
cd TAPE
chmod +x install_requirements.sh
./install_requirements.sh TAPE
```
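To verify the environment, a quick sanity check can be run (a minimal sketch, assuming the versions pinned in `install_requirements.sh` and a CUDA-capable machine):

```python
import torch
import torchvision

# Versions pinned by install_requirements.sh
print(torch.__version__)        # expected: 2.1.1
print(torchvision.__version__)  # expected: 0.16.1

# Should print True on a machine with a working CUDA setup
# (the install script targets pytorch-cuda=11.8)
print(torch.cuda.is_available())
```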

## Real-world inference
To use our method for restoring a real-world video, download the pre-trained model from the
[release](https://github.com/miccunifi/TAPE/releases/tag/latest) and place it under
the `TAPE/experiments/pretrained_model` directory. Then, run the following command:

```sh
python real_world_inference.py --input-path <path_to_video> --output-path <path_to_output_folder>
```

The available options are:

```
--input-path <str> Path to the video to restore
--output-path <str> Path to the output folder
--checkpoint-path <str> Path to the pretrained model checkpoint (default=experiments/pretrained_model/checkpoint.pth)
--num-input-frames <int> Number of input frames T for each input window (default=5)
--num-reference-frames <int> Number of reference frames D for each input window (default=5)
--preprocess-mode <str> Preprocessing mode, options: ['crop', 'resize', 'none']. 'crop' extracts the --patch-size center
crop, 'resize' resizes the longest side to --patch-size while keeping the aspect ratio, 'none'
applies no preprocessing (default=crop)
--patch-size <int> Maximum patch size for --preprocess-mode ['crop', 'resize'] (default=512)
--frame-format <str> Frame format of the extracted and restored frames (default=jpg)
--generate-combined-video <store_true> Whether to generate the combined video (i.e. input and restored videos side by side)
--no-intermediate-products <store_true> Whether to delete intermediate products (i.e. input frames, restored frames, references)
--batch-size <int> Batch size (default=1)
--num-workers <int> Number of workers of the data loader (default=20)
```
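For reference, here is a hypothetical sketch of how the options above could map onto an `argparse` parser; the actual `real_world_inference.py` may be organized differently:

```python
import argparse

# Hypothetical reconstruction of the CLI from the options documented above
parser = argparse.ArgumentParser(description="TAPE real-world video restoration")
parser.add_argument("--input-path", type=str, required=True,
                    help="Path to the video to restore")
parser.add_argument("--output-path", type=str, required=True,
                    help="Path to the output folder")
parser.add_argument("--checkpoint-path", type=str,
                    default="experiments/pretrained_model/checkpoint.pth",
                    help="Path to the pretrained model checkpoint")
parser.add_argument("--num-input-frames", type=int, default=5,
                    help="Number of input frames T for each input window")
parser.add_argument("--num-reference-frames", type=int, default=5,
                    help="Number of reference frames D for each input window")
parser.add_argument("--preprocess-mode", type=str, default="crop",
                    choices=["crop", "resize", "none"],
                    help="Preprocessing applied when frames exceed --patch-size")
parser.add_argument("--patch-size", type=int, default=512,
                    help="Maximum patch size for crop/resize preprocessing")
parser.add_argument("--frame-format", type=str, default="jpg",
                    help="Frame format of the extracted and restored frames")
parser.add_argument("--generate-combined-video", action="store_true",
                    help="Also render input and restored videos side by side")
parser.add_argument("--no-intermediate-products", action="store_true",
                    help="Delete intermediate frames and references when done")
parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
parser.add_argument("--num-workers", type=int, default=20,
                    help="Number of workers of the data loader")
args = parser.parse_args()
```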

## Dataset

<p align='center'>
@@ -48,6 +96,7 @@ The dataset can be downloaded [here](https://drive.google.com/drive/folders/1NjT

## TO-DO:
- [ ] Pre-trained model
- [ ] Real-world inference code
- [ ] Testing code
- [ ] Training code
- [x] Synthetic dataset
93 changes: 93 additions & 0 deletions data/RealWorldVideoDataset.py
@@ -0,0 +1,93 @@
import torch
from torch.utils.data import Dataset
from pathlib import Path
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor
import json
import cv2

from utils.utils import preprocess


class RealWorldVideoDataset(Dataset):
"""
    Dataset for real-world videos (i.e. no ground truth). Each item is given by a window of num_input_frames input
    frames (to be restored) and a window of num_reference_frames reference frames.
Args:
input_folder (Path): Path to the folder containing the input frames
num_input_frames (int): Number of input frames T of the input window
num_reference_frames (int): Number of reference frames D
references_file_path (Path): Path to the file containing the references for each frame
preprocess_mode (str): Preprocessing mode for when the size of the input frames is greater than the patch size.
Supported modes: ["crop", "resize"]
patch_size (int): Maximum patch size
frame_format (str): Format of the input frames
Returns:
dict with keys:
"imgs_lq" (torch.Tensor): Input frames
"imgs_ref" (torch.Tensor): Reference frames
"img_name" (str): Name of the center input frame
"""

def __init__(self,
input_folder: Path,
num_input_frames: int = 5,
num_reference_frames: int = 5,
references_file_path: Path = "references.json",
preprocess_mode: str = "crop",
patch_size: int = 768,
frame_format: str = "jpg"):
self.input_folder = input_folder
self.num_input_frames = num_input_frames
self.num_reference_frames = num_reference_frames
self.preprocess_mode = preprocess_mode
self.patch_size = patch_size

self.img_paths = sorted(list(input_folder.glob(f"*.{frame_format}")))

# Load references
with open(references_file_path, 'r') as f:
self.references = json.load(f)

def __getitem__(self, idx):
img_name = self.img_paths[idx].name

half_input_window_size = self.num_input_frames // 2
idxs_imgs_lq = np.arange(idx - half_input_window_size, idx + half_input_window_size + 1)
idxs_imgs_lq = list(idxs_imgs_lq[(idxs_imgs_lq >= 0) & (idxs_imgs_lq <= len(self.img_paths) - 1)])
imgs_lq = []
for img_idx in idxs_imgs_lq:
img = cv2.imread(str(self.img_paths[img_idx]))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img.astype(np.float32) / 255.
img_t = ToTensor()(img)
imgs_lq.append(img_t)

# Pad with black frames if the window is not complete
if len(imgs_lq) < self.num_input_frames:
black_frame = torch.zeros_like(imgs_lq[0])
            missing_frames_left = half_input_window_size - idx
for _ in range(missing_frames_left):
imgs_lq.insert(0, black_frame)
missing_frames_right = half_input_window_size - (len(self.img_paths) - 1 - idx)
for _ in range(missing_frames_right):
imgs_lq.append(black_frame)
imgs_lq = torch.stack(imgs_lq)

imgs_ref = []
for ref_name in self.references[img_name]:
img_t = ToTensor()(Image.open(self.input_folder / ref_name))
imgs_ref.append(img_t)
imgs_ref = torch.stack(imgs_ref)

if self.preprocess_mode != "none":
imgs_lq, imgs_ref = preprocess([imgs_lq, imgs_ref], mode=self.preprocess_mode, patch_size=self.patch_size)

return {"imgs_lq": imgs_lq,
"imgs_ref": imgs_ref,
"img_name": img_name}

def __len__(self):
return len(self.img_paths)
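A minimal usage sketch for the dataset, assuming the video has already been split into frames and a `references.json` mapping exists (the paths below are illustrative, not the pipeline's actual layout):

```python
from pathlib import Path

from torch.utils.data import DataLoader

from data.RealWorldVideoDataset import RealWorldVideoDataset

# Illustrative paths: extracted input frames and the precomputed reference mapping
dataset = RealWorldVideoDataset(input_folder=Path("frames"),
                                num_input_frames=5,
                                num_reference_frames=5,
                                references_file_path=Path("frames/references.json"),
                                preprocess_mode="crop",
                                patch_size=512,
                                frame_format="jpg")
loader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=False)

for batch in loader:
    imgs_lq = batch["imgs_lq"]    # (B, T, C, H, W) window of frames to restore
    imgs_ref = batch["imgs_ref"]  # (B, D, C, H, W) reference frames
    img_name = batch["img_name"]  # name(s) of the center input frame
```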
5 changes: 5 additions & 0 deletions install_requirements.sh
@@ -0,0 +1,5 @@
#!/bin/bash

# Install packages
conda install -y pytorch==2.1.1 torchvision==0.16.1 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install pandas==2.1.3 matplotlib==3.8.2 pyyaml==6.0.1 dotmap==1.3.30 tqdm==4.66.1 comet-ml==3.35.3 git+https://github.com/openai/clip.git@a1d0717 scikit-image==0.22.0 opencv-python==4.8.1.78 einops==0.7.0
