diff --git a/README.md b/README.md
index 2e17153..570fbc1 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,54 @@ Overview of the proposed approach. *Left* given a video, we identify the cleanes
 }
 ```
+## Installation
+We recommend using the [**Anaconda**](https://www.anaconda.com/) package manager to avoid dependency/reproducibility
+problems.
+For Linux systems, you can find a conda installation
+guide [here](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html).
+
+1. Clone the repository
+
+```sh
+git clone https://github.com/miccunifi/TAPE
+```
+
+2. Install Python dependencies
+
+```sh
+conda create -n TAPE -y python=3.10
+conda activate TAPE
+cd TAPE
+chmod +x install_requirements.sh
+./install_requirements.sh TAPE
+```
+
+## Real-world inference
+To use our method to restore a real-world video, download the pre-trained model from the
+[release](https://github.com/miccunifi/TAPE/releases/tag/latest) and place it under
+the `TAPE/experiments/pretrained_model` directory. Then, run the following command:
+
+```sh
+python real_world_inference.py --input-path <path to input video> --output-path <path to output folder>
+```
+
+```
+--input-path                  Path to the video to restore
+--output-path                 Path to the output folder
+--checkpoint-path             Path to the pretrained model checkpoint (default=experiments/pretrained_model/checkpoint.pth)
+--num-input-frames            Number of input frames T for each input window (default=5)
+--num-reference-frames        Number of reference frames D for each input window (default=5)
+--preprocess-mode             Preprocessing mode, options: ['crop', 'resize', 'none']. 'crop' extracts the --patch-size center
+                              crop, 'resize' resizes the longest side to --patch-size while keeping the aspect ratio, 'none'
+                              applies no preprocessing (default=crop)
+--patch-size                  Maximum patch size for --preprocess-mode ['crop', 'resize'] (default=512)
+
+--frame-format                Frame format of the extracted and restored frames (default=jpg)
+--generate-combined-video     Whether to generate the combined video (i.e. input and restored videos side by side)
+--no-intermediate-products    Whether to delete intermediate products (i.e. input frames, restored frames, references)
+--batch-size                  Batch size (default=1)
+--num-workers                 Number of workers of the data loader (default=20)
+```
+
 ## Dataset
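As a quick illustration of the real-world inference command documented above, a full invocation with some of the optional flags might look like the following sketch (the input video path and output folder name are placeholders, not files shipped with the repository):

```sh
# Restore a clip, resize its longest side to 512 px, and also write a
# side-by-side input/restored comparison video (paths are illustrative)
python real_world_inference.py \
    --input-path videos/my_old_tape.mp4 \
    --output-path results \
    --preprocess-mode resize \
    --patch-size 512 \
    --generate-combined-video
```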

@@ -48,6 +96,7 @@ The dataset can be downloaded [here](https://drive.google.com/drive/folders/1NjT ## TO-DO: - [ ] Pre-trained model +- [ ] Real-world inference code - [ ] Testing code - [ ] Training code - [x] Synthetic dataset diff --git a/data/RealWorldVideoDataset.py b/data/RealWorldVideoDataset.py new file mode 100644 index 0000000..1eae7e9 --- /dev/null +++ b/data/RealWorldVideoDataset.py @@ -0,0 +1,93 @@ +import torch +from torch.utils.data import Dataset +from pathlib import Path +import numpy as np +from PIL import Image +from torchvision.transforms import ToTensor +import json +import cv2 + +from utils.utils import preprocess + + +class RealWorldVideoDataset(Dataset): + """ + Dataset for real world videos (i.e. no ground truth). Each item is given by a window of num_input_frames input + frames (to be restored) and a window of num_reference_frames reference frames. + + Args: + input_folder (Path): Path to the folder containing the input frames + num_input_frames (int): Number of input frames T of the input window + num_reference_frames (int): Number of reference frames D + references_file_path (Path): Path to the file containing the references for each frame + preprocess_mode (str): Preprocessing mode for when the size of the input frames is greater than the patch size. + Supported modes: ["crop", "resize"] + patch_size (int): Maximum patch size + frame_format (str): Format of the input frames + Returns: + dict with keys: + "imgs_lq" (torch.Tensor): Input frames + "imgs_ref" (torch.Tensor): Reference frames + "img_name" (str): Name of the center input frame + """ + + def __init__(self, + input_folder: Path, + num_input_frames: int = 5, + num_reference_frames: int = 5, + references_file_path: Path = "references.json", + preprocess_mode: str = "crop", + patch_size: int = 768, + frame_format: str = "jpg"): + self.input_folder = input_folder + self.num_input_frames = num_input_frames + self.num_reference_frames = num_reference_frames + self.preprocess_mode = preprocess_mode + self.patch_size = patch_size + + self.img_paths = sorted(list(input_folder.glob(f"*.{frame_format}"))) + + # Load references + with open(references_file_path, 'r') as f: + self.references = json.load(f) + + def __getitem__(self, idx): + img_name = self.img_paths[idx].name + + half_input_window_size = self.num_input_frames // 2 + idxs_imgs_lq = np.arange(idx - half_input_window_size, idx + half_input_window_size + 1) + idxs_imgs_lq = list(idxs_imgs_lq[(idxs_imgs_lq >= 0) & (idxs_imgs_lq <= len(self.img_paths) - 1)]) + imgs_lq = [] + for img_idx in idxs_imgs_lq: + img = cv2.imread(str(self.img_paths[img_idx])) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = img.astype(np.float32) / 255. 
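+            # NOTE: the array is already float32 in [0, 1], so the ToTensor() call below only
+            # converts the HWC numpy array to a CHW tensor without rescaling it a second time.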
+ img_t = ToTensor()(img) + imgs_lq.append(img_t) + + # Pad with black frames if the window is not complete + if len(imgs_lq) < self.num_input_frames: + black_frame = torch.zeros_like(imgs_lq[0]) + missing_frames_left = half_input_window_size - (idx - 0) + for _ in range(missing_frames_left): + imgs_lq.insert(0, black_frame) + missing_frames_right = half_input_window_size - (len(self.img_paths) - 1 - idx) + for _ in range(missing_frames_right): + imgs_lq.append(black_frame) + imgs_lq = torch.stack(imgs_lq) + + imgs_ref = [] + for ref_name in self.references[img_name]: + img_t = ToTensor()(Image.open(self.input_folder / ref_name)) + imgs_ref.append(img_t) + imgs_ref = torch.stack(imgs_ref) + + if self.preprocess_mode != "none": + imgs_lq, imgs_ref = preprocess([imgs_lq, imgs_ref], mode=self.preprocess_mode, patch_size=self.patch_size) + + return {"imgs_lq": imgs_lq, + "imgs_ref": imgs_ref, + "img_name": img_name} + + def __len__(self): + return len(self.img_paths) diff --git a/experiments/pretrained_model/placeholder.txt b/experiments/pretrained_model/placeholder.txt new file mode 100644 index 0000000..e69de29 diff --git a/install_requirements.sh b/install_requirements.sh new file mode 100644 index 0000000..ab21ea9 --- /dev/null +++ b/install_requirements.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +# Install packages +conda install -y pytorch==2.1.1 torchvision==0.16.1 pytorch-cuda=11.8 -c pytorch -c nvidia +pip install pandas==2.1.3 matplotlib==3.8.2 pyyaml==6.0.1 dotmap==1.3.30 tqdm==4.66.1 comet-ml==3.35.3 git+https://github.com/openai/clip.git@a1d0717 scikit-image==0.22.0 opencv-python==4.8.1.78 einops==0.7.0 diff --git a/models/mrsff.py b/models/mrsff.py new file mode 100644 index 0000000..2afbf74 --- /dev/null +++ b/models/mrsff.py @@ -0,0 +1,394 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from typing import Tuple +from einops import rearrange + +from utils.utils_models import (compute_mask_2D, window_partition_2D, window_reverse_2D, get_window_size, DropPath, Mlp, + trunc_normal_) + + +class AttentionPooling1d(nn.Module): + """ + Inspired by https://amaarora.github.io/posts/2023-03-11_Understanding_CLIP_part_2.html and + https://github.com/openai/CLIP/blob/a1d071733d7111c9c014f024669f959182114e33/clip/model.py#L58 + + Args: + dim (int): Input dimension. + num_heads (int): Number of attention heads. + sequence_length (int): Length of the sequence of transformer tokens. 
+ """ + def __init__(self, dim: int, num_heads: int, sequence_length: int): + super().__init__() + self.sequence_length = sequence_length + self.pos_embedding = nn.Parameter(torch.randn(sequence_length, dim) / dim ** 0.5) + self.q_proj = nn.Linear(dim, dim) + self.k_proj = nn.Linear(dim, dim) + self.v_proj = nn.Linear(dim, dim) + self.out_proj = nn.Linear(dim, dim) + self.num_heads = num_heads + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (B*T, M, N, C) + + Returns: + x (torch.Tensor): (B*T, N, C) + """ + avg = x.mean(dim=1, keepdim=True) # (B*T, 1, N, C) + x = torch.cat([avg, x], dim=1) # (B*T, M+1, N, C) + x = x + self.pos_embedding[None, None, :, :] # (B*T, M+1, N, C) + x = rearrange(x, 'b m n c -> (m n) b c') # ((M+1)*N, B*T, C) + + x, _ = F.multi_head_attention_forward( + query=x[:self.sequence_length], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.out_proj.weight, + out_proj_bias=self.out_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + x = rearrange(x, 'n b c -> b n c') # (B*T, N, C) + return x + + +class MultiReferenceWindowAttention(nn.Module): + """ Multi-Reference-(Shifted)Window-Multi-head Cross Attention (MR-(S)W-MCA) module with relative position bias. + It supports both shifted and non-shifted window. The query is the restored features, while the key and values + are the reference features. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, + dim: int, + window_size: Tuple[int], + num_heads: int, + qkv_bias: bool = True, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.act = nn.GELU() + + self.dim_reduction = AttentionPooling1d(dim=dim, num_heads=num_heads, sequence_length=window_size[0] * window_size[1]) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x: torch.Tensor, x_kv: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Args: + x: input features with shape of (num_windows*B, T, N, C) + x_kv: input features with shape of (num_windows*B, M, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + x_kv = x if x_kv is None else x_kv + B_, T, N, C = x.shape + _, M, _, _ = x_kv.shape + + q = self.q(x).reshape(B_, T, N, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + kv = self.kv(x_kv).reshape(B_, M, N, 2, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = q[0], kv[0], kv[1] # B_, T (M), nH, N, C/nH + + q = q.unsqueeze(2) # B_, T, 1, nH, N, C/nH + k = k.unsqueeze(1) # B_, 1, M, nH, N, C/nH + v = v.unsqueeze(1) # B_, 1, M, nH, N, C/nH + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # B_, T, M, nH, N, N + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias[None, None, None, ...] + + if mask is not None: + nW = mask.shape[0] + attn = rearrange(attn, '(b nW) t m nH n1 n2 -> b t m nW nH n1 n2', nW=nW) + mask = mask.unsqueeze(1)[None, None, None, ...] 
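+            # mask: (nW, N, N) -> (1, 1, 1, nW, 1, N, N); it broadcasts over the batch, the T input
+            # frames, the M reference frames and the attention heads when added to attn below.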
+ attn += mask + attn = rearrange(attn, 'b t m nW nH n1 n2 -> (b nW) t m nH n1 n2') + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, T, M, N, C) + + x = rearrange(x, 'b t m n c -> (b t) m n c') + x = self.dim_reduction(x) + x = rearrange(x, '(b t) n c -> b t n c', t=T) + + x = self.act(x) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MRSFFBlock(nn.Module): + """ A Multi-Reference Spatial Feature Fusion (MRSFF) block presented in the paper https://arxiv.org/abs/2310.14926. + It combines the restored and reference features. Based on the Swin Transformer 2D block implementation. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim: int, + num_heads: int, + window_size: Tuple[int] = (7, 7), + shift_size: Tuple[int] = (0, 0), + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size" + + self.norm_q = norm_layer(dim) + self.norm_kv = norm_layer(dim) + self.attn = MultiReferenceWindowAttention( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x: torch.Tensor, kv: torch.Tensor, mask_matrix: torch.Tensor) -> torch.Tensor: + """ Forward function. + + Args: + x (torch.Tensor): Input feature, tensor size (B, T, H, W, C). + kv (torch.Tensor): Reference feature, tensor size (B, M, H, W, C). + mask_matrix (torch.Tensor): Attention mask for cyclic shift. 
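+
+        Returns:
+            torch.Tensor: Output feature with the same shape as x, i.e. (B, T, H, W, C).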
+ """ + shortcut = x + x = self.forward_part1(x, kv, mask_matrix) + x = shortcut + self.drop_path(x) + x = x + self.forward_part2(x) + return x + + def forward_part1(self, x: torch.Tensor, kv: torch.Tensor, mask_matrix: torch.Tensor) -> torch.Tensor: + B, T, H, W, C = x.shape + x = rearrange(x, 'b t h w c -> (b t) h w c', b=B, t=T) + + _, M, _, _, _ = kv.shape + kv = rearrange(kv, 'b m h w c -> (b m) h w c', b=B, m=M) + + window_size, shift_size = get_window_size((H, W), self.window_size, self.shift_size) + + x = self.norm_q(x) + kv = self.norm_kv(kv) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_b = (window_size[0] - H % window_size[0]) % window_size[0] + pad_r = (window_size[1] - W % window_size[1]) % window_size[1] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + kv = F.pad(kv, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + shifted_kv = torch.roll(kv, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + shifted_kv = kv + attn_mask = None + + # partition windows + x_windows = window_partition_2D(shifted_x, window_size) # B*T*nW, Wh*Ww, C + kv_windows = window_partition_2D(shifted_kv, window_size) # B*M*nW, Wh*Ww, C + + _, N, C = x_windows.shape + x_windows = x_windows.reshape(-1, T, N, C) + kv_windows = kv_windows.reshape(-1, M, N, C) + + # MR-W-MCA/MR-SW-MCA + attn_windows = self.attn(x_windows, kv_windows, mask=attn_mask) # B*T*nW, Wd*Wh*Ww, C + + # merge windows + attn_windows = attn_windows.view(-1, *(window_size + (C,))) + shifted_x = window_reverse_2D(attn_windows, window_size, B * T, Hp, Wp) # B*T H' W' C + + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) + else: + x = shifted_x + + x = rearrange(x, '(b t) h w c -> b t h w c', b=B, t=T) + + if pad_r > 0 or pad_b > 0: + x = x[:, :, :H, :W, :].contiguous() + return x + + def forward_part2(self, x: torch.Tensor) -> torch.Tensor: + # FFN + return self.drop_path(self.mlp(self.norm2(x))) + + +class MRSFFLayer(nn.Module): + """ A Multi-Reference Spatial Feature Fusion (MRSFF) layer. + + Args: + dim (int): Number of input channels. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, + dim: int, + depth: int, + num_heads: int, + window_size: Tuple[int] = (7, 7), + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + norm_layer: nn.Module = nn.LayerNorm): + + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList([ + MRSFFBlock(dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + self.last_conv = nn.Conv2d(dim, dim, 3, 1, 1) + + def forward(self, x: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ Forward function. + Args: + x (torch.Tensor): Input feature, tensor size (B, C, T, H, W). + kv (torch.Tensor): Reference feature, tensor size (B, C, M, H, W). + """ + # calculate attention mask for SW-MSA + B, C, T, H, W = x.shape + window_size, shift_size = get_window_size((H, W), self.window_size, self.shift_size) + + x = rearrange(x, 'b c t h w -> b t h w c') + kv = rearrange(kv, 'b c m h w -> b m h w c') + residual = x.clone() + + Hp = int(np.ceil(H / window_size[0])) * window_size[0] + Wp = int(np.ceil(W / window_size[1])) * window_size[1] + attn_mask = compute_mask_2D(Hp, Wp, window_size, shift_size, x.device) + + for blk in self.blocks: + x = blk(x, kv, attn_mask) + + x = rearrange(x, 'b t h w c -> b t c h w').reshape(B * T, C, H, W) + x = self.last_conv(x) + x = rearrange(x.reshape(B, T, C, H, W), 'b t c h w -> b t h w c') + x = x + residual + x = rearrange(x, 'b t h w c -> b c t h w') + return x diff --git a/models/swin_feature_extractor.py b/models/swin_feature_extractor.py new file mode 100644 index 0000000..c9d81f5 --- /dev/null +++ b/models/swin_feature_extractor.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +from torchvision.models import swin_t, Swin_T_Weights +from einops import rearrange +from typing import List + + +class SwinFeatureExtractor(nn.Module): + + def __init__(self, layer_name_list: List[str] = None, use_input_norm: bool = True, use_range_norm: bool = False, + requires_grad: bool = False): + """Swin Transformer network for feature extraction. + + Args: + layer_name_list (List[str]): Forward function returns the corresponding + features according to the layer_name_list. + use_input_norm (bool): If True, x: [0, 1] --> (x - mean) / std. Default: True + use_range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. + requires_grad (bool): If true, the parameters of the feature extractor network will be + optimized. Default: False. 
+ """ + super(SwinFeatureExtractor, self).__init__() + if not layer_name_list: + self.layer_name_list = ["1", "3", "5"] + else: + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = use_range_norm + + self.swin_net = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1).features + + max_idx = 0 + for i, layer in enumerate(self.swin_net._modules.keys()): + if layer in self.layer_name_list: + max_idx = i + self.swin_net = self.swin_net[:max_idx + 1] + + if self.use_input_norm: + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + self.register_buffer('mean', mean) + self.register_buffer('std', std) + + if not requires_grad: + self.swin_net.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor) -> dict: + """Forward function. + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + Returns: + dict[str, Tensor]: Output features. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + + output = {} + for key, layer in self.swin_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = rearrange(x.clone(), 'b h w c -> b c h w') + + return output \ No newline at end of file diff --git a/models/swin_transformer_3d.py b/models/swin_transformer_3d.py new file mode 100644 index 0000000..5db87fc --- /dev/null +++ b/models/swin_transformer_3d.py @@ -0,0 +1,382 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from typing import Tuple +from einops import rearrange + +from utils.utils_models import (compute_mask_3D, window_partition_3D, window_reverse_3D, get_window_size, DropPath, Mlp, + trunc_normal_) + + +class PatchMerging(nn.Module): + """ + Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward function + + Args: + x: Input feature, tensor size (B, D, H, W, C). + """ + B, D, H, W, C = x.shape + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + +class PatchExpand(nn.Module): + """ + Patch Expand Layer + + Args: + embed_dim (int): Embedding dimension. 
+ """ + def __init__(self, embed_dim: int): + super().__init__() + self.before_conv = nn.Conv2d(embed_dim, embed_dim * 2, 3, 1, 1) + self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2) + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.after_conv = nn.Conv2d(embed_dim // 2, embed_dim // 2, 3, 1, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = x.shape + x = rearrange(x, 'b c t h w -> b t c h w').reshape(B * T, C, H, W) + x = self.before_conv(x) + x = self.pixel_shuffle(x) + x = self.after_conv(self.lrelu(x)) + _, C, H, W = x.shape + x = rearrange(x.reshape(B, T, C, H, W), 'b t c h w -> b c t h w') + return x + +class WindowAttention3D(nn.Module): + """ + Window based 3D multi-head self attention (W-MSA) module with relative position bias. + It supports both shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim: int, + window_size: Tuple[int], + num_heads: int, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wd, Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ Forward function. 
+ + Args: + x (torch.Tensor): input features with shape of (num_windows*B, N, C) + mask (torch.Tensor): (0/-inf) mask with shape of (num_windows, N, N) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:N, :N].reshape(-1)].reshape( + N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class SwinTransformerBlock3D(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. Default: False. + """ + + def __init__(self, dim: int, + num_heads: int, + window_size: Tuple[int] = (2, 7, 7), + shift_size: Tuple[int] = (0, 0, 0), + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + use_checkpoint: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + + assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention3D( + dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward_part1(self, x: torch.Tensor, mask_matrix: torch.Tensor) -> torch.Tensor: + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + + x = self.norm1(x) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + # partition windows + x_windows = window_partition_3D(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C + # merge windows + attn_windows = attn_windows.view(-1, *(window_size + (C,))) + shifted_x = window_reverse_3D(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :].contiguous() + return x + + def forward_part2(self, x: torch.Tensor) -> torch.Tensor: + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x: torch.Tensor, mask_matrix: torch.Tensor) -> torch.Tensor: + """ Forward function. + Args: + x (torch.Tensor): Input feature, tensor size (B, D, H, W, C). + mask_matrix (torch.Tensor): Attention mask for cyclic shift. + """ + + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + + return x + +class SwinTransformer3DLayer(nn.Module): + """ A basic Swin Transformer 3D layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (1,7,7). + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + sampling_operation (str | None, optional): Downsampling/upsampling operation at the end of the layer. Default: None + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, + dim: int, + depth: int, + num_heads: int, + window_size: Tuple[int] = (1, 7, 7), + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_scale: float = None, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + sampling_operation: str = None, + use_checkpoint: bool = False): + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock3D( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint) + for i in range(depth)]) + + self.last_conv = nn.Conv2d(dim, dim, 3, 1, 1) + + if sampling_operation is None: + self.sampling_operation = None + elif sampling_operation == "upsample": + self.sampling_operation = PatchExpand(embed_dim=dim) + elif sampling_operation == "downsample": + self.sampling_operation = PatchMerging(dim=dim, norm_layer=norm_layer) + else: + raise NotImplementedError("Unsupported sampling operation.") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward function. + Args: + x (torch.Tensor): Input feature, tensor size (B, C, D, H, W). + """ + # calculate attention mask for SW-MSA + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + residual = x.clone() + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask_3D(Dp, Hp, Wp, window_size, shift_size, x.device) + + for blk in self.blocks: + x = blk(x, attn_mask) + + x = rearrange(x, 'b d h w c -> b d c h w').reshape(B*D, C, H, W) + x = self.last_conv(x) + x = rearrange(x.reshape(B, D, C, H, W), 'b d c h w -> b d h w c') + x = x + residual + + if isinstance(self.sampling_operation, PatchExpand): + x = rearrange(x, 'b t h w c -> b c t h w') + x = self.sampling_operation(x) + x = rearrange(x, 'b c t h w -> b t h w c') + elif isinstance(self.sampling_operation, PatchMerging): + x = self.sampling_operation(x) + x = rearrange(x, 'b t h w c -> b c t h w') + return x diff --git a/models/swin_unet.py b/models/swin_unet.py new file mode 100644 index 0000000..9a2abea --- /dev/null +++ b/models/swin_unet.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn +from typing import Tuple, List +from einops import rearrange + +from utils.utils_models import trunc_normal_ +from models.swin_feature_extractor import SwinFeatureExtractor +from models.swin_transformer_3d import SwinTransformer3DLayer +from models.mrsff import MRSFFLayer + + +class SwinUNet(nn.Module): + """ + Swin-UNet network for analog video restoration presented in the paper https://arxiv.org/abs/2310.14926. + The network is composed of a Swin Transformer encoder and a Swin Transformer decoder with MRSFF blocks. + The network takes as input a window of T input frames and a window of D reference frames. The output is the restored + window of input frames. + + Args: + in_chans (int): Number of input channels. 
Default: 3 + embed_dim (int): Dimension of the token embeddings. Default: 96 + depths (List[int]): Depths of the Swin Transformer layers. Default: None. If None, use [2, 2, 6, 2]. + num_heads (List[int]): Number of attention heads for each layer. Default: None. If None, use [8, 8, 8, 8]. + window_size (Tuple[int]): Window size for each layer. Default: (2, 8, 8). + mlp_ratio (float): Ratio of the mlp hidden dimension to the embedding dimension. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + use_checkpoint (bool): If True, use gradient checkpointing to save memory. Default: False. + """ + def __init__(self, + in_chans: int = 3, + embed_dim: int = 96, + depths: List[int] = None, + num_heads: List[int] = None, + window_size: Tuple[int] = (2, 8, 8), + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.2, + norm_layer: nn.Module = nn.LayerNorm, + use_checkpoint: bool = False): + + super(SwinUNet, self).__init__() + if num_heads is None: + num_heads = [8, 8, 8, 8] + if depths is None: + depths = [2, 2, 6, 2] + self.embed_dim = embed_dim + + self.conv_input = nn.Conv2d(in_chans, embed_dim, kernel_size=3, stride=2, padding=1) + self.conv_output = nn.Conv2d(embed_dim // 2, in_chans, kernel_size=3, stride=1, padding=1) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.num_layers = len(depths) + + self.encoding_layers = nn.ModuleList() + for i_layer in range(0, self.num_layers - 1): + layer = SwinTransformer3DLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + sampling_operation="downsample", + use_checkpoint=use_checkpoint) + self.encoding_layers.append(layer) + + self.decoding_layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = SwinTransformer3DLayer( + dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), + depth=depths[self.num_layers - 1 - i_layer], + num_heads=num_heads[self.num_layers - 1 - i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:self.num_layers - 1 - i_layer]):sum(depths[:self.num_layers - 1 - i_layer + 1])], + norm_layer=norm_layer, + sampling_operation="upsample", + use_checkpoint=use_checkpoint) + self.decoding_layers.append(layer) + + self.mrsff_layers = nn.ModuleList() + for i_layer in range(0, self.num_layers - 1): + layer = MRSFFLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[self.num_layers - 1 - i_layer], + num_heads=num_heads[self.num_layers - 1 - i_layer], + window_size=window_size[1:], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:self.num_layers - 1 - 
i_layer]):sum(depths[:self.num_layers - 1 - i_layer + 1])], + norm_layer=norm_layer) + self.mrsff_layers.append(layer) + + ref_feature_extractor_layers = ["1", "3", "5"] + self.ref_feature_extractor = SwinFeatureExtractor(layer_name_list=ref_feature_extractor_layers, + use_input_norm=True, use_range_norm=False, requires_grad=False) + self.ref_feature_extractor_conv = nn.ModuleList() + for i, layer in enumerate(ref_feature_extractor_layers): + self.ref_feature_extractor_conv.append(nn.Sequential(nn.Conv2d(embed_dim * 2 ** i, embed_dim * 2 ** i * 4, 3, 1, 1), + nn.PixelShuffle(2))) + self.apply(self._init_weights) + + def forward_encoding(self, imgs_lq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, C, H, W = imgs_lq.shape + restored = rearrange(imgs_lq, 'b t c h w -> (b t) c h w') + restored = self.conv_input(restored) + restored = rearrange(restored, '(b t) c h w -> b t c h w', t=T) + restored = rearrange(restored, 'b t c h w -> b c t h w') + + # UNet encoder + residual = [restored] + for layer in self.encoding_layers: + restored = layer(restored.contiguous()) + residual.append(restored) + + return restored, residual + + def forward_decoding(self, restored: torch.Tensor, imgs_ref: torch.Tensor, residual: List[torch.Tensor]) -> torch.Tensor: + # Extract features from reference frames + _, M, _, _, _ = imgs_ref.shape + imgs_ref = rearrange(imgs_ref, 'b m c h w -> (b m) c h w') + with torch.no_grad(): + feat_ref = list(self.ref_feature_extractor(imgs_ref).values()) + for i in range(len(feat_ref)): + feat_ref[i] = self.ref_feature_extractor_conv[i](feat_ref[i]) + feat_ref[i] = rearrange(feat_ref[i], '(b m) c h w -> b m c h w', m=M) + feat_ref[i] = rearrange(feat_ref[i], 'b m c h w -> b c m h w') + + # UNet decoder + B, _, T, _, _ = restored.shape + for i, layer in enumerate(self.decoding_layers): + if i == 0: + restored = layer(restored) # Bottleneck layer + else: + restored += residual[-1 - i] # Encoder-decoder skip connection + restored_ref = self.mrsff_layers[-i](restored, feat_ref[-i]) # Combine restored and reference features + restored += restored_ref # MRSFF skip connection + restored = layer(restored) # Decoder layer + + restored = rearrange(restored, 'b c t h w -> b t c h w') + B, T, C, H, W = restored.shape + restored = self.conv_output(restored.reshape(B * T, C, H, W)) + restored = restored.reshape(B, T, -1, H, W) + return restored + + def forward(self, imgs_lq: torch.Tensor, imgs_ref: torch.Tensor) -> torch.Tensor: + """ + Forward function. + + Args: + imgs_lq (Tensor): Input frames with shape (b, t, c, h, w). + imgs_ref (Tensor): Reference frames with shape (b, d, c, h, w). 
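+
+        Returns:
+            Tensor: Restored frames with shape (b, t, c, h, w), i.e. the input frames plus the predicted residual.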
+ """ + out = imgs_lq.clone() + restored, residual = self.forward_encoding(imgs_lq) + restored = self.forward_decoding(restored, imgs_ref, residual) + return out + restored + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) diff --git a/real_world_inference.py b/real_world_inference.py new file mode 100644 index 0000000..955d5b0 --- /dev/null +++ b/real_world_inference.py @@ -0,0 +1,235 @@ +import torch +import torch.nn.functional as F +import numpy as np +from argparse import ArgumentParser +import json +from pathlib import Path +import cv2 +from tqdm import tqdm +import clip +from PIL import Image +from skimage.filters import threshold_otsu +import torchvision +import shutil + +from utils.prompts import prompts +from data.RealWorldVideoDataset import RealWorldVideoDataset +from models.swin_unet import SwinUNet + + +if torch.cuda.is_available(): + device = torch.device("cuda") +else: + device = torch.device("cpu") + + +def real_world_test(args): + """ + Restore a real-world video (i.e. without ground truth) using the pretrained model. + """ + + input_video_name = args.input_path.stem + output_folder = args.output_path / input_video_name + output_folder.mkdir(parents=True, exist_ok=False) + output_folder.mkdir(parents=True, exist_ok=True) + input_frames_folder = output_folder / "input_frames" + input_frames_folder.mkdir(parents=True, exist_ok=True) + restored_frames_folder = output_folder / "restored_frames" + restored_frames_folder.mkdir(parents=True, exist_ok=True) + references_file_path = output_folder / "references.json" + + ### 1) Frames extraction + print("Extracting frames from the video...") + input_video = cv2.VideoCapture(str(args.input_path)) + fps = input_video.get(cv2.CAP_PROP_FPS) + frame_width = int(input_video.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(input_video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frame_count = int(input_video.get(cv2.CAP_PROP_FRAME_COUNT)) + + for i in tqdm(range(frame_count)): + success, frame = input_video.read() + if not success: + raise Exception("Failed to read frame from video") + padded_i = str(i).zfill(len(str(frame_count))) # Pad to a number of digits large enough to contain the total number of frames + cv2.imwrite(str(input_frames_folder / f"{padded_i}.{args.frame_format}"), frame, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + + input_video.release() + + ### 2) Frame classification and references selection + print("Classifying frames and selecting references...") + clip_model, clip_preprocess = clip.load("RN50x4", device=device, jit=True) + output = {} + + # Extract text features using prompt ensembling + with torch.no_grad(), torch.cuda.amp.autocast(): + tokenized_prompts = clip.tokenize(prompts).to(device) + text_features = F.normalize(clip_model.encode_text(tokenized_prompts), dim=-1) + text_features = F.normalize(text_features.mean(dim=0), dim=-1).unsqueeze(0) # Prompt ensembling + + # Extract image features and compute similarity scores + img_features = [] + img_names = [] + similarity_scores = [] + for img_path in tqdm(sorted(list(input_frames_folder.glob("*"))), desc="Extracting CLIP image features"): + img_names.append(img_path.name) + img = Image.open(img_path) + preprocessed_img = clip_preprocess(img).to(device) + with torch.no_grad(), torch.cuda.amp.autocast(): + img_feat = 
F.normalize(clip_model.encode_image(preprocessed_img.unsqueeze(0)), dim=-1) + sim_score = img_feat @ text_features.T + img_features.append(img_feat.cpu()) + similarity_scores.append(sim_score.cpu().item()) + + img_names = np.array(img_names) + img_features = torch.cat(img_features, dim=0) + + # Classify frames + similarity_scores = np.array(similarity_scores) + sorted_similarity_scores = np.sort(similarity_scores) + threshold = threshold_otsu(sorted_similarity_scores) + threshold_index = sorted_similarity_scores.searchsorted(threshold) + indexes = np.argsort(similarity_scores)[:threshold_index] # Indexes of clean frames + + # Select references + for i, img_feat in enumerate(tqdm(img_features, desc="Selecting references")): + similarity = F.cosine_similarity(img_feat.unsqueeze(0), img_features[indexes], dim=-1) + similarity_indexes = torch.argsort(similarity, descending=True) + similarity_indexes = similarity_indexes[:args.num_reference_frames].numpy() + similar_imgs = img_names[similarity_indexes].tolist() + while len(similar_imgs) < args.num_reference_frames: # Pad with the first image if there are not enough similar images + similar_imgs.append(similar_imgs[0]) + output[img_names[i]] = similar_imgs + + # Save references + with open(references_file_path, 'w') as f: + json.dump(output, f) + + # Free memory + del clip_model + del text_features + del img_feat + torch.cuda.empty_cache() + + ### 3) Video restoration + print("Restoring the video...") + dataset = RealWorldVideoDataset(input_frames_folder, num_input_frames=args.num_input_frames, + num_reference_frames=args.num_reference_frames, + references_file_path=references_file_path, preprocess_mode=args.preprocess_mode, + patch_size=args.patch_size, frame_format=args.frame_format) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, + shuffle=False, pin_memory=True, drop_last=False) + + if args.preprocess_mode != "none" and (frame_width > args.patch_size or frame_height > args.patch_size): + if args.preprocess_mode == "crop": + new_frame_width = min(frame_width, args.patch_size) + new_frame_height = min(frame_height, args.patch_size) + elif args.preprocess_mode == "resize": + if frame_height > frame_height: + new_frame_height = args.patch_size + new_frame_width = int(frame_width * args.patch_size / frame_height) + else: + new_frame_width = args.patch_size + new_frame_height = int(frame_height * args.patch_size / frame_width) + else: + raise ValueError(f"Unknown preprocess mode: {args.preprocess_mode}") + else: + new_frame_width = frame_width + new_frame_height = frame_height + + output_video = cv2.VideoWriter(str(output_folder / f"restored_{input_video_name}.mp4"), + cv2.VideoWriter_fourcc(*'mp4v'), fps, (new_frame_width, new_frame_height)) + if args.generate_combined_video: + combined_output_video = cv2.VideoWriter(str(output_folder / f"combined_{input_video_name}.mp4"), + cv2.VideoWriter_fourcc(*'mp4v'), fps, (new_frame_width * 2, new_frame_height)) + else: + combined_output_video = None + + # Load model + model = SwinUNet() + state_dict = torch.load(args.checkpoint_path, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + model = model.eval().to(device) + + for batch in tqdm(dataloader, desc="Restoring frames"): + imgs_lq = batch["imgs_lq"] + imgs_ref = batch["imgs_ref"] + img_names = batch["img_name"] + + # Image size must be divisible by 16 (due to the 4 downsampling operations) + h, w = imgs_lq.shape[-2:] + pad_width = (16 - (w % 16)) % 16 + pad_height = (16 - 
(h % 16)) % 16 + pad = (0, pad_width, 0, pad_height) + imgs_lq = F.pad(imgs_lq, pad=pad, mode="constant", value=0).to(device) + imgs_ref = F.pad(imgs_ref, pad=pad, mode="constant", value=0).to(device) + + with torch.no_grad(), torch.cuda.amp.autocast(): + output = model(imgs_lq, imgs_ref) + output = torch.clamp(output, min=0, max=1) + + for i, img_name in enumerate(img_names): + img_num = int(img_name[:-4]) + restored_frame = output[i, args.num_input_frames // 2] + restored_frame = torchvision.transforms.functional.crop(restored_frame, top=0, left=0, height=h, width=w) + restored_frame = restored_frame.cpu().numpy().transpose(1, 2, 0) * 255 + restored_frame = cv2.cvtColor(restored_frame, cv2.COLOR_RGB2BGR).astype(np.uint8) + cv2.imwrite(str(restored_frames_folder / f"{img_num}.{args.frame_format}"), restored_frame) + + # Reconstruct the video + output_video.write(restored_frame) + if args.generate_combined_video: + input_frame = imgs_lq[i, args.num_input_frames // 2] + input_frame = torchvision.transforms.functional.crop(input_frame, top=0, left=0, height=h, width=w) + input_frame = input_frame.cpu().numpy().transpose(1, 2, 0) * 255 + input_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR).astype(np.uint8) + combined_frame = np.concatenate((input_frame, restored_frame), axis=1) + combined_output_video.write(combined_frame) + + output_video.release() + if args.generate_combined_video: + combined_output_video.release() + + # Free memory + del model + del imgs_lq + del imgs_ref + torch.cuda.empty_cache() + + if args.no_intermediate_products: + print("Deleting intermediate products...") + (output_folder / f"restored_{input_video_name}.mp4").rename(Path(args.output_path) / f"restored_{input_video_name}.mp4") + if args.generate_combined_video: + (output_folder / f"combined_{input_video_name}.mp4").rename(Path(args.output_path) / f"combined_{input_video_name}.mp4") + shutil.rmtree(output_folder) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument("--input-path", type=str, required=True, help="Path to the video to restore") + parser.add_argument("--output-path", type=str, required=True, help="Path to the output folder") + parser.add_argument("--checkpoint-path", type=str, default="experiments/pretrained_model/checkpoint.pth", + help="Path to the pretrained model checkpoint") + parser.add_argument("--num-input-frames", type=int, default=5, + help="Number of input frames T for each input window") + parser.add_argument("--num-reference-frames", type=int, default=5, + help="Number of reference frames D for each input window") + parser.add_argument("--preprocess-mode", type=str, default="crop", choices=["crop", "resize", "none"], + help="Preprocessing mode, options: ['crop', 'resize', 'none']. 'crop' extracts the --patch-size" + " center crop, 'resize' resizes the longest side to --patch-size while keeping the aspect" + " ratio, 'none' applies no preprocessing") + parser.add_argument("--patch-size", type=int, default=512, + help="Maximum patch size for --preprocess-mode ['crop', 'resize']") + parser.add_argument("--frame-format", type=str, default="jpg", + help="Frame format of the extracted and restored frames") + parser.add_argument("--generate-combined-video", action="store_true", + help="Whether to generate the combined video (i.e. input and restored videos side by side)") + parser.add_argument("--no-intermediate-products", action="store_true", + help="Whether to delete intermediate products (i.e. 
input frames, restored frames, references)") + parser.add_argument("--batch-size", type=int, default=1, help="Batch size") + parser.add_argument("--num-workers", type=int, default=20, help="Number of workers of the data loader") + + args = parser.parse_args() + + args.input_path = Path(args.input_path) + args.output_path = Path(args.output_path) + real_world_test(args) diff --git a/utils/prompts.py b/utils/prompts.py new file mode 100644 index 0000000..a4f4f82 --- /dev/null +++ b/utils/prompts.py @@ -0,0 +1,14 @@ +prompts = [ + 'an image with interlacing artifacts', + 'an image of a degraded photo', + 'a photo with distortions', + 'an image with color artifacts along rows', + 'an image of a noisy photo', + 'an image of a bad photo', + 'a jpeg corrupted image of a photo', + 'a pixelated image of a photo', + 'a blurry image of a photo', + 'a jpeg corrupted photo', + 'a pixelated photo', + 'a blurry photo' +] diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..e26c09e --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,69 @@ +import torch +from typing import Union, List + + +def preprocess(imgs: Union[List[torch.Tensor], torch.Tensor], mode: str = "crop", patch_size: int = 768)\ + -> Union[List[torch.Tensor], torch.Tensor]: + """Preprocesses a tensor of images or list of tensors of images. + + Args: + imgs (Union[List[torch.Tensor], torch.Tensor]): List of tensors of images or a single tensor of images. + mode (str, optional): Preprocess mode. Values can be in ["crop", "resize"]. + patch_size (int, optional): Maximum patch size + + Returns: + Union[List[torch.Tensor], torch.Tensor]: Preprocessed images. + """ + if isinstance(imgs, list): + return [preprocess(img, mode=mode, patch_size=patch_size) for img in imgs] + elif isinstance(imgs, torch.Tensor): + if mode == "crop": + return crop(imgs, patch_size=patch_size) + elif mode == "resize": + return resize(imgs, patch_size=patch_size) + else: + raise ValueError(f"Unknown preprocess mode: {mode}") + else: + raise TypeError(f"Unknown type for imgs: {type(imgs)}") + + +def crop(img: torch.Tensor, patch_size: int = 768) -> torch.Tensor: + """Center crops a tensor of images to patch_size. + + Args: + img (torch.Tensor): Tensor of images. + patch_size (int, optional): Maximum patch size + + Returns: + torch.Tensor: Cropped images. + """ + _, _, h, w = img.shape + if h > patch_size or w > patch_size: + h_start = max((h - patch_size) // 2, 0) + w_start = max((w - patch_size) // 2, 0) + return img[:, :, h_start:h_start + patch_size, w_start:w_start + patch_size] + else: + return img + + +def resize(img: torch.Tensor, patch_size: int = 768) -> torch.Tensor: + """Resizes a tensor of images so that the biggest dimension is equal to patch_size while keeping the aspect ratio. + + Args: + img (torch.Tensor): Tensor of images. + patch_size (int, optional): Maximum patch size + + Returns: + torch.Tensor: Resized images. 
+ """ + _, _, h, w = img.shape + if h > patch_size or w > patch_size: + if h > w: + new_h = patch_size + new_w = int(w * patch_size / h) + else: + new_w = patch_size + new_h = int(h * patch_size / w) + return torch.nn.functional.interpolate(img, size=(new_h, new_w), mode="bilinear") + else: + return img diff --git a/utils/utils_models.py b/utils/utils_models.py new file mode 100644 index 0000000..fd8257a --- /dev/null +++ b/utils/utils_models.py @@ -0,0 +1,244 @@ +import torch +import torch.nn as nn +from typing import Tuple +import warnings +import math +from functools import reduce +from operator import mul + + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None, + act_layer: nn.Module = nn.GELU, drop: float = 0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def compute_mask_3D(D: int, H: int, W: int, window_size: Tuple[int], shift_size: Tuple[int], device: torch.device)\ + -> torch.Tensor: + """ + Compute 3D mask for window-based multi-head self-attention + """ + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition_3D(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, 
+
+
+def compute_mask_3D(D: int, H: int, W: int, window_size: Tuple[int], shift_size: Tuple[int], device: torch.device)\
+        -> torch.Tensor:
+    """
+    Compute 3D mask for window-based multi-head self-attention
+    """
+    img_mask = torch.zeros((1, D, H, W, 1), device=device)  # 1 Dp Hp Wp 1
+    cnt = 0
+    for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
+        for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
+            for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
+                img_mask[:, d, h, w, :] = cnt
+                cnt += 1
+    mask_windows = window_partition_3D(img_mask, window_size)  # nW, ws[0]*ws[1]*ws[2], 1
+    mask_windows = mask_windows.squeeze(-1)  # nW, ws[0]*ws[1]*ws[2]
+    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+    return attn_mask
+
+
+def window_partition_3D(x: torch.Tensor, window_size: Tuple[int]) -> torch.Tensor:
+    """ Partition the input into windows. Attention will be conducted within the windows.
+    From https://github.com/JingyunLiang/VRT/blob/main/models/network_vrt.py
+
+    Args:
+        x (torch.Tensor): (B, D, H, W, C)
+        window_size (tuple[int]): window size
+    Returns:
+        windows (torch.Tensor): (B*num_windows, window_size*window_size, C)
+    """
+    B, D, H, W, C = x.shape
+    x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2],
+               window_size[2], C)
+    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
+
+    return windows
+
+
+def window_reverse_3D(windows: torch.Tensor, window_size: Tuple[int], B: int, D: int, H: int, W: int) -> torch.Tensor:
+    """ Reverse windows back to the original input. Attention was conducted within the windows.
+    From https://github.com/JingyunLiang/VRT/blob/main/models/network_vrt.py
+    Args:
+        windows (torch.Tensor): (B*num_windows, window_size[0]*window_size[1]*window_size[2], C)
+        window_size (tuple[int]): Window size
+        B (int): Batch size
+        D (int): Number of frames
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x (torch.Tensor): (B, D, H, W, C)
+    """
+    x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1],
+                     window_size[2], -1)
+    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
+
+    return x
+
+
+def window_partition_2D(x: torch.Tensor, window_size: Tuple[int]) -> torch.Tensor:
+    """ Partition the input into windows. Attention will be conducted within the windows.
+    Args:
+        x (torch.Tensor): (B, H, W, C)
+        window_size (tuple[int]): window size
+    Returns:
+        windows (torch.Tensor): (num_windows*B, window_size[0]*window_size[1], C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, reduce(mul, window_size), C)
+    return windows
+
+
+def window_reverse_2D(windows: torch.Tensor, window_size: Tuple[int], B: int, H: int, W: int) -> torch.Tensor:
+    """
+    Args:
+        windows (torch.Tensor): (num_windows*B, window_size[0]*window_size[1], C)
+        window_size (tuple[int]): Window size
+        B (int): Batch size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x (torch.Tensor): (B, H, W, C)
+    """
+    # Use window_size[1] along the width axis so non-square windows are reshaped correctly
+    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
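+
+# Illustrative usage sketch (not part of the original module): partitioning a (B, H, W, C)
+# feature map into non-overlapping windows and reversing the operation recovers the input.
+#
+#   x = torch.rand(1, 8, 8, 32)
+#   windows = window_partition_2D(x, (4, 4))                              # shape (4, 16, 32)
+#   torch.equal(window_reverse_2D(windows, (4, 4), B=1, H=8, W=8), x)     # True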
+
+
+def compute_mask_2D(H: int, W: int, window_size: Tuple[int], shift_size: Tuple[int], device: torch.device) -> torch.Tensor:
+    """
+    Compute 2D mask for window-based multi-head self-attention
+    """
+    img_mask = torch.zeros((1, H, W, 1), device=device)  # 1 H W 1
+    h_slices = (slice(-window_size[0]),
+                slice(-window_size[0], -shift_size[0]),
+                slice(-shift_size[0], None))
+    w_slices = (slice(-window_size[1]),
+                slice(-window_size[1], -shift_size[1]),
+                slice(-shift_size[1], None))
+    cnt = 0
+    for h in h_slices:
+        for w in w_slices:
+            img_mask[:, h, w, :] = cnt
+            cnt += 1
+
+    mask_windows = window_partition_2D(img_mask, window_size)  # nW, window_size[0]*window_size[1], 1
+    mask_windows = mask_windows.squeeze(-1)  # nW, window_size[0]*window_size[1]
+    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+    return attn_mask
+
+
+def get_window_size(x_size: Tuple[int], window_size: Tuple[int], shift_size: Tuple[int] = None)\
+        -> Tuple[int] | Tuple[Tuple[int]]:
+    use_window_size = list(window_size)
+    if shift_size is not None:
+        use_shift_size = list(shift_size)
+    for i in range(len(x_size)):
+        if x_size[i] <= window_size[i]:
+            use_window_size[i] = x_size[i]
+            if shift_size is not None:
+                use_shift_size[i] = 0
+
+    if shift_size is None:
+        return tuple(use_window_size)
+    else:
+        return tuple(use_window_size), tuple(use_shift_size)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        low = norm_cdf((a - mean) / std)
+        up = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [low, up], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution.
+    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+    The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        w = torch.empty(3, 5)
+        nn.init.trunc_normal_(w)
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
\ No newline at end of file