Video Transformer Network (https://arxiv.org/abs/2102.00719) #388

Open · wants to merge 1 commit into base: main
8 changes: 7 additions & 1 deletion MODEL_ZOO.md
@@ -25,6 +25,12 @@ We provided original pretrained models from Caffe2 on heavy models (testing Caff
| X3D | M | - | 16 x 5 | 75.1 | 76.2 | 3.8 | 4.73 | [`link`](https://dl.fbaipublicfiles.com/pyslowfast/x3d_models/x3d_m.pyth) | Kinetics/X3D_M |
| X3D | L | - | 16 x 5 | 76.9 | 77.5 | 6.2 | 18.37 | [`link`](https://dl.fbaipublicfiles.com/pyslowfast/x3d_models/x3d_l.pyth) | Kinetics/X3D_L |

## VTN model (details in projects/vtn)

| architecture | backbone | pretrain | frame length x sample rate | top1 | top5 | model | config |
| :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | ------------- | ------------- |
| VTN | ViT-B | ImageNet-21K | - | 77.72 | 93.24 | [`link`](https://researchpublic.blob.core.windows.net/vtn/VTN_VIT_B_KINETICS.pyth) | Kinetics/VIT_B_VTN |

## AVA

| architecture | depth | Pretrain Model | frame length x sample rate | MAP | AVA version | model |
@@ -67,4 +73,4 @@ We also release the imagenet pretrained model if finetuning from ImageNet is pre

| architecture | depth | Top1 | Top5 | model |
| ------------- | ------------- | ------------- | ------------- | ------------- |
| ResNet | R50 | 23.6 | 6.8 | [`link`](https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/R50_IN1K.pyth) |
| ResNet | R50 | 23.6 | 6.8 | [`link`](https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/R50_IN1K.pyth) |
3 changes: 3 additions & 0 deletions README.md
@@ -6,6 +6,7 @@ PySlowFast is an open source video understanding codebase from FAIR that provide
- [Non-local Neural Networks](https://arxiv.org/abs/1711.07971)
- [A Multigrid Method for Efficiently Training Video Models](https://arxiv.org/abs/1912.00998)
- [X3D: Progressive Network Expansion for Efficient Video Recognition](https://arxiv.org/abs/2004.04730)
- [Video Transformer Network](https://arxiv.org/abs/2102.00719)

<div align="center">
<img src="demo/ava_demo.gif" width="600px"/>
@@ -21,8 +22,10 @@ The goal of PySlowFast is to provide a high-performance, light-weight pytorch co
- I3D
- Non-local Network
- X3D
- VTN

## Updates
- We now support [VTN Model](https://arxiv.org/abs/2102.00719). See [`projects/vtn`](./projects/vtn/README.md) for more information.
- We now support [X3D Models](https://arxiv.org/abs/2004.04730). See [`projects/x3d`](./projects/x3d/README.md) for more information.
- We now support [Multigrid Training](https://arxiv.org/abs/1912.00998) for efficiently training video models. See [`projects/multigrid`](./projects/multigrid/README.md) for more information.
- PySlowFast is released in conjunction with our [ICCV 2019 Tutorial](https://alexander-kirillov.github.io/tutorials/visual-recognition-iccv19/).
60 changes: 60 additions & 0 deletions configs/Kinetics/VIT_B_VTN.yaml
@@ -0,0 +1,60 @@
TRAIN:
ENABLE: True
DATASET: kinetics
BATCH_SIZE: 16
EVAL_PERIOD: 1
CHECKPOINT_PERIOD: 1
AUTO_RESUME: True
EVAL_FULL_VIDEO: True
EVAL_NUM_FRAMES: 250
DATA:
NUM_FRAMES: 16
SAMPLING_RATE: 8
TARGET_FPS: 25
TRAIN_JITTER_SCALES: [256, 320]
TRAIN_CROP_SIZE: 224
TEST_CROP_SIZE: 224
INPUT_CHANNEL_NUM: [3]
SOLVER:
BASE_LR: 0.001
LR_POLICY: steps_with_relative_lrs
STEPS: [0, 13, 24]
LRS: [1, 0.1, 0.01]
MAX_EPOCH: 25
MOMENTUM: 0.9
OPTIMIZING_METHOD: sgd
MODEL:
NUM_CLASSES: 400
ARCH: VIT
MODEL_NAME: VTN
LOSS_FUNC: cross_entropy
DROPOUT_RATE: 0.5
VTN:
PRETRAINED: True
MLP_DIM: 768
DROP_PATH_RATE: 0.0
DROP_RATE: 0.0
HIDDEN_DIM: 768
MAX_POSITION_EMBEDDINGS: 288
NUM_ATTENTION_HEADS: 12
NUM_HIDDEN_LAYERS: 3
ATTENTION_MODE: 'sliding_chunks'
PAD_TOKEN_ID: -1
ATTENTION_WINDOW: [18, 18, 18]
INTERMEDIATE_SIZE: 3072
ATTENTION_PROBS_DROPOUT_PROB: 0.1
HIDDEN_DROPOUT_PROB: 0.1
TEST:
ENABLE: True
DATASET: kinetics
BATCH_SIZE: 16
NUM_ENSEMBLE_VIEWS: 1
NUM_SPATIAL_CROPS: 1
DATA_LOADER:
NUM_WORKERS: 8
PIN_MEMORY: True
NUM_GPUS: 4
NUM_SHARDS: 1
RNG_SEED: 0
OUTPUT_DIR: .
LOG_MODEL_INFO: False
70 changes: 70 additions & 0 deletions projects/vtn/README.md
@@ -0,0 +1,70 @@
# Video Transformer Network
Daniel Neimark, Omri Bar, Maya Zohar, Dotan Asselmann [[Paper](https://arxiv.org/abs/2102.00719)]

<div align="center">
<img src="fig/arch.png" width="400px" />
<img src="fig/vtn_demo.gif" width="400px" />
</div>
<br/>


## Installation
```
pip install timm
pip install transformers[torch]
```
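
As an optional sanity check (assuming a standard Python environment; not part of this PR), confirm that both dependencies import cleanly:

```python
# Illustrative check only: verify timm and transformers are installed.
import timm
import transformers

print("timm:", timm.__version__)
print("transformers:", transformers.__version__)
```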

## Getting started
To use VTN models, please refer to the configs under `configs/Kinetics`, or see
[MODEL_ZOO.md](https://github.com/facebookresearch/SlowFast/blob/master/MODEL_ZOO.md)
for pre-trained models*.
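
If you download a pre-trained checkpoint from the model zoo, you can inspect it before testing. This is an optional sketch (not part of this PR) that assumes the file follows the usual PySlowFast `.pyth` checkpoint layout, where the weights live under a `model_state` key:

```python
import torch

# Hypothetical local path to the checkpoint downloaded from MODEL_ZOO.md.
ckpt = torch.load("VTN_VIT_B_KINETICS.pyth", map_location="cpu")

# PySlowFast checkpoints are plain dicts; "model_state" holds the weights
# (an assumption based on the standard PySlowFast checkpoint format).
state = ckpt.get("model_state", ckpt)
print(len(state), "parameter tensors")
print(sorted(state.keys())[:5])
```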

To train ViT-B-VTN on your dataset (see [paper](https://arxiv.org/abs/2102.00719) for details):
```
python tools/run_net.py \
--cfg configs/Kinetics/VIT_B_VTN.yaml \
DATA.PATH_TO_DATA_DIR path_to_your_dataset \
```

To test the trained ViT-B-VTN on Kinetics-400 dataset:
```
python tools/run_net.py \
--cfg configs/Kinetics/VIT_B_VTN.yaml \
DATA.PATH_TO_DATA_DIR path_to_kinetics_dataset \
TRAIN.ENABLE False \
TEST.CHECKPOINT_FILE_PATH path_to_model \
TEST.CHECKPOINT_TYPE pytorch
```

\* VTN models in [MODEL_ZOO.md](https://github.com/facebookresearch/SlowFast/blob/master/MODEL_ZOO.md) produce slightly
different results than those reported in the paper due to differences between the PySlowFast code base and the
original code used to train the models (mainly around data and video loading).

## Citing VTN
If you find VTN useful for your research, please consider citing the paper using the following BibTeX entry.
```BibTeX
@article{neimark2021video,
title={Video Transformer Network},
author={Neimark, Daniel and Bar, Omri and Zohar, Maya and Asselmann, Dotan},
journal={arXiv preprint arXiv:2102.00719},
year={2021}
}
```


## Additional Qualitative Results

<div align="center">
<img src="fig/a.png" width="700px" /><p>
Label: Tai chi. Prediction: Tai chi.<p>
<img src="fig/b.png" width="700px" /><p>
Label: Chopping wood. Prediction: Chopping wood.<p>
<img src="fig/c.png" width="700px" /><p>
Label: Archery. Prediction: Archery.<p>
<img src="fig/d.png" width="700px" /><p>
Label: Throwing discus. Prediction: Flying kite.<p>
<img src="fig/e.png" width="700px" /><p>
Label: Surfing water. Prediction: Parasailing.<p>
</div>


Binary file added projects/vtn/fig/a.png
Binary file added projects/vtn/fig/arch.png
Binary file added projects/vtn/fig/b.png
Binary file added projects/vtn/fig/c.png
Binary file added projects/vtn/fig/d.png
Binary file added projects/vtn/fig/e.png
Binary file added projects/vtn/fig/vtn_demo.gif
53 changes: 53 additions & 0 deletions slowfast/config/defaults.py
@@ -75,6 +75,12 @@
# If set, clear all layer names according to the pattern provided.
_C.TRAIN.CHECKPOINT_CLEAR_NAME_PATTERN = () # ("backbone.",)

# If True, use all of the video's frames during evaluation.
_C.TRAIN.EVAL_FULL_VIDEO = False

# In case "EVAL_FULL_VIDEO" is True, this will set the number of frames to use for the full video (250 in VTN)
_C.TRAIN.EVAL_NUM_FRAMES = None

# ---------------------------------------------------------------------------- #
# Testing options
# ---------------------------------------------------------------------------- #
@@ -254,6 +260,53 @@
# pathway.
_C.SLOWFAST.FUSION_KERNEL_SZ = 5

# -----------------------------------------------------------------------------
# VTN options
# -----------------------------------------------------------------------------
_C.VTN = CfgNode()

# ViT: if True, will load pretrained weights for the backbone.
_C.VTN.PRETRAINED = True

# ViT: stochastic depth decay rule.
_C.VTN.DROP_PATH_RATE = 0.0

# ViT: dropout ratio.
_C.VTN.DROP_RATE = 0.0

# Longformer: the size of the embedding; this is the input size of the MLP head
# and should match the ViT output dimension.
_C.VTN.HIDDEN_DIM = 768

# Longformer: the maximum sequence length that this model might ever be used with.
_C.VTN.MAX_POSITION_EMBEDDINGS = 288

# Longformer: number of attention heads for each attention layer in the Transformer encoder.
_C.VTN.NUM_ATTENTION_HEADS = 12

# Longformer: number of hidden layers in the Transformer encoder.
_C.VTN.NUM_HIDDEN_LAYERS = 3

# Longformer: type of self-attention. The Longformer uses 'sliding_chunks' to process attention with a sliding window.
_C.VTN.ATTENTION_MODE = 'sliding_chunks'

# Longformer: The value used to pad input_ids.
_C.VTN.PAD_TOKEN_ID = -1

# Longformer: Size of an attention window around each token.
_C.VTN.ATTENTION_WINDOW = [18, 18, 18]

# Longformer: Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
_C.VTN.INTERMEDIATE_SIZE = 3072

# Longformer: The dropout ratio for the attention probabilities.
_C.VTN.ATTENTION_PROBS_DROPOUT_PROB = 0.1

# Longformer: The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
_C.VTN.HIDDEN_DROPOUT_PROB = 0.1

# MLP Head: the dimension of the MLP head hidden layer.
_C.VTN.MLP_DIM = 768

# -----------------------------------------------------------------------------
# Data options
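For readers unfamiliar with the Longformer options above, here is a minimal, illustrative sketch (not taken from this PR) of how the `_C.VTN` node could map onto a Hugging Face `LongformerConfig`. The `sliding_chunks` attention mode comes from the original AllenAI Longformer implementation, so storing it on the config is an assumption rather than a documented Hugging Face parameter:

```python
from transformers import LongformerConfig

def build_longformer_config(cfg):
    """Illustrative mapping from the VTN cfg node to a Longformer config."""
    lf_cfg = LongformerConfig(
        hidden_size=cfg.VTN.HIDDEN_DIM,
        num_hidden_layers=cfg.VTN.NUM_HIDDEN_LAYERS,
        num_attention_heads=cfg.VTN.NUM_ATTENTION_HEADS,
        intermediate_size=cfg.VTN.INTERMEDIATE_SIZE,
        max_position_embeddings=cfg.VTN.MAX_POSITION_EMBEDDINGS,
        attention_window=cfg.VTN.ATTENTION_WINDOW,      # one window size per layer
        pad_token_id=cfg.VTN.PAD_TOKEN_ID,
        attention_probs_dropout_prob=cfg.VTN.ATTENTION_PROBS_DROPOUT_PROB,
        hidden_dropout_prob=cfg.VTN.HIDDEN_DROPOUT_PROB,
    )
    # 'sliding_chunks' is specific to the original Longformer code; kept here
    # only as an extra attribute (assumption, unused by the stock HF model).
    lf_cfg.attention_mode = cfg.VTN.ATTENTION_MODE
    return lf_cfg
```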
23 changes: 16 additions & 7 deletions slowfast/datasets/decoder.py
@@ -25,7 +25,7 @@ def temporal_sampling(frames, start_idx, end_idx, num_samples):
index = torch.linspace(start_idx, end_idx, num_samples)
index = torch.clamp(index, 0, frames.shape[0] - 1).long()
frames = torch.index_select(frames, 0, index)
return frames
return frames, index


def get_start_end_idx(video_size, clip_size, clip_idx, num_clips):
@@ -212,7 +212,7 @@ def torchvision_decode(


def pyav_decode(
container, sampling_rate, num_frames, clip_idx, num_clips=10, target_fps=30
container, sampling_rate, num_frames, clip_idx, num_clips=10, target_fps=30, force_all_video=False
):
"""
Convert the video from its original fps to the target_fps. If the video
@@ -233,6 +233,7 @@ given video.
given video.
target_fps (int): the input video may have a different fps; convert it to
the target video fps before frame sampling.
force_all_video (bool): if True, fetch all of the video's frames.
Returns:
frames (tensor): decoded frames from the video. Return None if no
video stream was found.
@@ -246,7 +247,7 @@
frames_length = container.streams.video[0].frames
duration = container.streams.video[0].duration

if duration is None:
if duration is None or force_all_video:
# If failed to fetch the decoding information, decode the entire video.
decode_all_video = True
video_start_pts, video_end_pts = 0, math.inf
Expand Down Expand Up @@ -290,6 +291,7 @@ def decode(
target_fps=30,
backend="pyav",
max_spatial_scale=0,
force_all_video=False,
):
"""
Decode the video and perform temporal sampling.
@@ -313,6 +315,7 @@
max_spatial_scale (int): keep the aspect ratio and resize the frame so
that shorter edge size is max_spatial_scale. Only used in
`torchvision` backend.
force_all_video (bool): if True, fetch all of the video's frames. Only supported with the pyav backend.
Returns:
frames (tensor): decoded frames from the video.
"""
@@ -327,6 +330,7 @@
clip_idx,
num_clips,
target_fps,
force_all_video,
)
elif backend == "torchvision":
frames, fps, decode_all_video = torchvision_decode(
@@ -346,11 +350,11 @@
)
except Exception as e:
print("Failed to decode by {} with exception: {}".format(backend, e))
return None
return None, None

# Return None if the frames were not decoded successfully.
if frames is None or frames.size(0) == 0:
return None
return None, None

clip_sz = sampling_rate * num_frames / target_fps * fps
start_idx, end_idx = get_start_end_idx(
@@ -359,6 +363,11 @@
clip_idx if decode_all_video else 0,
num_clips if decode_all_video else 1,
)

if force_all_video:
# Avoid duplicating the last frame for videos shorter than 250 frames.
end_idx = min(float(frames.shape[0]), end_idx)

# Perform temporal sampling from the decoded video.
frames = temporal_sampling(frames, start_idx, end_idx, num_frames)
return frames
frames, frames_index = temporal_sampling(frames, start_idx, end_idx, num_frames)
return frames, frames_index
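To make the new return signature concrete, here is a small self-contained usage sketch (not part of the diff) of `temporal_sampling`, which now also returns the sampled frame indices:

```python
import torch

def temporal_sampling(frames, start_idx, end_idx, num_samples):
    # Mirrors the updated function above: sample num_samples frames between
    # start_idx and end_idx and also return the indices that were picked.
    index = torch.linspace(start_idx, end_idx, num_samples)
    index = torch.clamp(index, 0, frames.shape[0] - 1).long()
    frames = torch.index_select(frames, 0, index)
    return frames, index

video = torch.randn(100, 224, 224, 3)           # 100 decoded frames
clip, idx = temporal_sampling(video, 0, 99, 16)
print(clip.shape)   # torch.Size([16, 224, 224, 3])
print(idx[:4])      # tensor([ 0,  6, 13, 19])
```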
30 changes: 26 additions & 4 deletions slowfast/datasets/kinetics.py
@@ -70,6 +70,17 @@ def __init__(self, cfg, mode, num_retries=10):
cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
)

if self.mode in ["val", "test"] and cfg.TRAIN.EVAL_FULL_VIDEO:
# Support full-video evaluation.
self.force_all_video = True
self.num_frames = self.cfg.TRAIN.EVAL_NUM_FRAMES
self.sampling_rate = 1
self._num_clips = 1
else:
self.force_all_video = False
self.num_frames = self.cfg.DATA.NUM_FRAMES
self.sampling_rate = self.cfg.DATA.SAMPLING_RATE

logger.info("Constructing Kinetics {}...".format(mode))
self._construct_loader()

@@ -158,6 +169,16 @@ def __getitem__(self, index):
/ self.cfg.MULTIGRID.DEFAULT_S
)
)
if self.mode in ["val"] and self.cfg.TRAIN.EVAL_FULL_VIDEO:
# Support full-video evaluation:
# spatial_sample_index=1 takes only the center crop.
# Testing is deterministic and no jitter should be performed;
# min_scale, max_scale, and crop_size are expected to be the same.
# temporal_sample_index = -1  # could be random; in the end we take [0, inf]
spatial_sample_index = 1
min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
crop_size = self.cfg.DATA.TEST_CROP_SIZE
elif self.mode in ["test"]:
temporal_sample_index = (
self._spatial_temporal_idx[index]
@@ -189,7 +210,7 @@
)
sampling_rate = utils.get_random_sampling_rate(
self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE,
self.cfg.DATA.SAMPLING_RATE,
self.sampling_rate,
)
# Try to decode and sample a clip from a video. If the video can not be
# decoded, repeatly find a random video replacement that can be decoded.
@@ -220,16 +241,17 @@
continue

# Decode video. Meta info is used to perform selective decoding.
frames = decoder.decode(
frames, frames_index = decoder.decode(
video_container,
sampling_rate,
self.cfg.DATA.NUM_FRAMES,
self.num_frames,
temporal_sample_index,
self.cfg.TEST.NUM_ENSEMBLE_VIEWS,
video_meta=self._video_meta[index],
target_fps=self.cfg.DATA.TARGET_FPS,
backend=self.cfg.DATA.DECODING_BACKEND,
max_spatial_scale=min_scale,
force_all_video=self.force_all_video
)

# If decoding failed (wrong format, video is too short, and etc),
@@ -263,7 +285,7 @@
)

label = self._labels[index]
frames = utils.pack_pathway_output(self.cfg, frames)
frames = utils.pack_pathway_output(self.cfg, frames, frames_index)
return frames, label, index, {}
else:
raise RuntimeError(
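The dataset now forwards `frames_index` to `utils.pack_pathway_output`, but this PR does not show how that helper consumes the indices. One plausible use, given the Longformer position-embedding options above, is to derive temporal position ids for the attention layers. The following is a purely speculative sketch, not the authors' implementation:

```python
import torch

def frame_positions(frames_index, max_positions=288):
    # Speculative helper: shift the sampled frame indices so they fit inside the
    # Longformer position-embedding table, reserving slot 0 for padding (assumption).
    positions = frames_index.long() + 1
    return torch.clamp(positions, 0, max_positions - 1)

idx = torch.tensor([0, 6, 13, 19])
print(frame_positions(idx))  # tensor([ 1,  7, 14, 20])
```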