diff --git a/ai/TextReID/.gitignore b/ai/TextReID/.gitignore new file mode 100644 index 0000000000..beb77e0954 --- /dev/null +++ b/ai/TextReID/.gitignore @@ -0,0 +1,13 @@ +.ipynb_checkpoints +*.pyc +*.ipynb +*.npy + +output/ +datasets/ +pretrained/ +__pycache__/ +condor_log/ +.cache/ +.nv/ +docker_stderror diff --git a/ai/TextReID/.pre-commit-config.yaml b/ai/TextReID/.pre-commit-config.yaml new file mode 100644 index 0000000000..600799c38d --- /dev/null +++ b/ai/TextReID/.pre-commit-config.yaml @@ -0,0 +1,33 @@ +repos: + + - repo: https://github.com/psf/black + rev: 20.8b1 # Replace by any tag/version: https://github.com/psf/black/tags + hooks: + - id: black + language_version: python3 # Should be a command that runs python3.6+ + + # isort + - repo: https://github.com/timothycrosley/isort + rev: 5.6.4 + hooks: + - id: isort + + # flake8 + - repo: https://github.com/PyCQA/flake8 + rev: 3.8.3 + hooks: + - id: flake8 + args: ["--config=setup.cfg", "--ignore=W504, W503, E501, E203, E741, F821"] + + # pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace # Trim trailing whitespace + - id: check-merge-conflict # Check for files that contain merge conflict strings + - id: end-of-file-fixer # Make sure files end in a newline and only a newline + - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 + - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- + args: ["--remove"] + - id: mixed-line-ending # Replace or check mixed line ending + args: ["--fix=lf"] diff --git a/ai/TextReID/README.md b/ai/TextReID/README.md new file mode 100644 index 0000000000..0ce9868170 --- /dev/null +++ b/ai/TextReID/README.md @@ -0,0 +1,93 @@ +# Text Based Person Search with Limited Data + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/text-based-person-search-with-limited-data/nlp-based-person-retrival-on-cuhk-pedes)](https://paperswithcode.com/sota/nlp-based-person-retrival-on-cuhk-pedes?p=text-based-person-search-with-limited-data) + +This is the codebase for our [BMVC 2021 paper](https://arxiv.org/abs/2110.10807). + +Slides and video for the online presentation are now available at [BMVC 2021 virtual conference website](https://www.bmvc2021-virtualconference.com/conference/papers/paper_0044.html). + +## Updates +- (10/12/2021) Add download link of trained models. +- (06/12/2021) Code refactor for easy reproduce. +- (20/10/2021) Code released! + +## Abstract +Text-based person search (TBPS) aims at retrieving a target person from an image gallery with a descriptive text query. +Solving such a fine-grained cross-modal retrieval task is challenging, which is further hampered by the lack of large-scale datasets. +In this paper, we present a framework with two novel components to handle the problems brought by limited data. +Firstly, to fully utilize the existing small-scale benchmarking datasets for more discriminative feature learning, we introduce a cross-modal momentum contrastive learning framework to enrich the training data for a given mini-batch. Secondly, we propose to transfer knowledge learned from existing coarse-grained large-scale datasets containing image-text pairs from drastically different problem domains to compensate for the lack of TBPS training data. A transfer learning method is designed so that useful information can be transferred despite the large domain gap. 
Armed with these components, our method achieves new state of the art on the CUHK-PEDES dataset with significant improvements over the prior art in terms of Rank-1 and mAP.
+
+## Results
+![image](https://user-images.githubusercontent.com/37724292/144879635-86ab9c7b-0317-4b42-ac46-a37b06853d18.png)
+
+## Installation
+### Setup environment
+```bash
+conda create -n txtreid-env python=3.7
+conda activate txtreid-env
+git clone https://github.com/BrandonHanx/TextReID.git
+cd TextReID
+pip install -r requirements.txt
+pre-commit install
+```
+### Get CUHK-PEDES dataset
+- Request the images from [Dr. Shuang Li](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description).
+- Download the pre-processed captions we provide from [Google Drive](https://drive.google.com/file/d/1V4d8OjFket5SaQmBVozFFeflNs6f9e1R/view?usp=sharing).
+- Organize the dataset as follows:
+```bash
+datasets
+└── cuhkpedes
+    ├── annotations
+    │   ├── test.json
+    │   ├── train.json
+    │   └── val.json
+    ├── clip_vocab_vit.npy
+    └── imgs
+        ├── cam_a
+        ├── cam_b
+        ├── CUHK01
+        ├── CUHK03
+        ├── Market
+        ├── test_query
+        └── train_query
+```
+
+### Download CLIP weights
+```bash
+mkdir pretrained/clip/
+cd pretrained/clip
+wget https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt
+wget https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt
+cd -
+
+```
+
+### Train
+```bash
+python train_net.py \
+--config-file configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml \
+--use-tensorboard
+```
+### Inference
+```bash
+python test_net.py \
+--config-file configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml \
+--checkpoint-file output/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048/best.pth
+```
+You can download our trained models (with CLIP RN50 and RN101) from [Google Drive](https://drive.google.com/drive/folders/1MoceVsLiByg3Sg8_9yByGSvR3ru15hJL?usp=sharing).
+
+## TODO
+- [ ] Try larger pre-trained CLIP models.
+- [ ] Fix the bug of multi-GPU running.
+- [ ] Add dataloader for [ICFG-PEDES](https://github.com/zifyloo/SSAN).
+
+## Citation
+If you find this project useful for your research, please use the following BibTeX entry.
+``` +@inproceedings{han2021textreid, + title={Text-Based Person Search with Limited Data}, + author={Han, Xiao and He, Sen and Zhang, Li and Xiang, Tao}, + booktitle={BMVC}, + year={2021} +} +``` diff --git a/ai/TextReID/configs/cuhkpedes/baseline_gru_cliprn101_ls_bs128.yaml b/ai/TextReID/configs/cuhkpedes/baseline_gru_cliprn101_ls_bs128.yaml new file mode 100644 index 0000000000..34da8555f5 --- /dev/null +++ b/ai/TextReID/configs/cuhkpedes/baseline_gru_cliprn101_ls_bs128.yaml @@ -0,0 +1,41 @@ +MODEL: + WEIGHT: "imagenet" + FREEZE: False + VISUAL_MODEL: "m_resnet101" + TEXTUAL_MODEL: "bigru" + NUM_CLASSES: 11003 + GRU: + ONEHOT: "clip_vit" + EMBEDDING_SIZE: 512 + NUM_UNITS: 512 + VOCABULARY_SIZE: 512 + DROPOUT_KEEP_PROB: 1.0 + MAX_LENGTH: 100 + RESNET: + RES5_STRIDE: 1 + EMBEDDING: + EMBED_HEAD: 'simple' + FEATURE_SIZE: 256 + DROPOUT_PROB: 0.0 + EPSILON: 0.1 +INPUT: + HEIGHT: 384 + WIDTH: 128 + USE_AUG: True + PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] + PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] +DATASETS: + TRAIN: ("cuhkpedes_train", ) + TEST: ("cuhkpedes_test", ) +SOLVER: + IMS_PER_BATCH: 128 + NUM_EPOCHS: 80 + BASE_LR: 0.0001 + WEIGHT_DECAY: 0.00004 + CHECKPOINT_PERIOD: 40 + LRSCHEDULER: 'step' + STEPS: (40, 70) + WARMUP_FACTOR: 0.1 + WARMUP_EPOCHS: 5 +TEST: + IMS_PER_BATCH: 128 diff --git a/ai/TextReID/configs/cuhkpedes/baseline_gru_cliprn50_ls_bs128.yaml b/ai/TextReID/configs/cuhkpedes/baseline_gru_cliprn50_ls_bs128.yaml new file mode 100644 index 0000000000..93690cdf5d --- /dev/null +++ b/ai/TextReID/configs/cuhkpedes/baseline_gru_cliprn50_ls_bs128.yaml @@ -0,0 +1,41 @@ +MODEL: + WEIGHT: "imagenet" + FREEZE: False + VISUAL_MODEL: "m_resnet50" + TEXTUAL_MODEL: "bigru" + NUM_CLASSES: 11003 + GRU: + ONEHOT: "clip_vit" + EMBEDDING_SIZE: 512 + NUM_UNITS: 512 + VOCABULARY_SIZE: 512 + DROPOUT_KEEP_PROB: 1.0 + MAX_LENGTH: 100 + RESNET: + RES5_STRIDE: 1 + EMBEDDING: + EMBED_HEAD: 'simple' + FEATURE_SIZE: 256 + DROPOUT_PROB: 0.0 + EPSILON: 0.1 +INPUT: + HEIGHT: 384 + WIDTH: 128 + USE_AUG: True + PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] + PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] +DATASETS: + TRAIN: ("cuhkpedes_train", ) + TEST: ("cuhkpedes_test", ) +SOLVER: + IMS_PER_BATCH: 128 + NUM_EPOCHS: 80 + BASE_LR: 0.0001 + WEIGHT_DECAY: 0.00004 + CHECKPOINT_PERIOD: 40 + LRSCHEDULER: 'step' + STEPS: (40, 70) + WARMUP_FACTOR: 0.1 + WARMUP_EPOCHS: 5 +TEST: + IMS_PER_BATCH: 128 diff --git a/ai/TextReID/configs/cuhkpedes/baseline_gru_rn50_ls_bs128.yaml b/ai/TextReID/configs/cuhkpedes/baseline_gru_rn50_ls_bs128.yaml new file mode 100644 index 0000000000..958fb9d52c --- /dev/null +++ b/ai/TextReID/configs/cuhkpedes/baseline_gru_rn50_ls_bs128.yaml @@ -0,0 +1,39 @@ +MODEL: + WEIGHT: "imagenet" + FREEZE: False + VISUAL_MODEL: "resnet50" + TEXTUAL_MODEL: "bigru" + NUM_CLASSES: 11003 + GRU: + ONEHOT: "yes" + EMBEDDING_SIZE: 512 + NUM_UNITS: 512 + VOCABULARY_SIZE: 12000 + DROPOUT_KEEP_PROB: 1.0 + MAX_LENGTH: 100 + RESNET: + RES5_STRIDE: 1 + EMBEDDING: + EMBED_HEAD: 'simple' + FEATURE_SIZE: 256 + DROPOUT_PROB: 0.0 + EPSILON: 0.1 +INPUT: + HEIGHT: 384 + WIDTH: 128 + USE_AUG: True +DATASETS: + TRAIN: ("cuhkpedes_train", ) + TEST: ("cuhkpedes_test", ) +SOLVER: + IMS_PER_BATCH: 128 + NUM_EPOCHS: 80 + BASE_LR: 0.0001 + WEIGHT_DECAY: 0.00004 + CHECKPOINT_PERIOD: 40 + LRSCHEDULER: 'step' + STEPS: (40, 70) + WARMUP_FACTOR: 0.1 + WARMUP_EPOCHS: 5 +TEST: + IMS_PER_BATCH: 128 diff --git a/ai/TextReID/configs/cuhkpedes/moco_gru_cliprn101_ls_bs128_2048.yaml 
b/ai/TextReID/configs/cuhkpedes/moco_gru_cliprn101_ls_bs128_2048.yaml
new file mode 100644
index 0000000000..0dd64a7638
--- /dev/null
+++ b/ai/TextReID/configs/cuhkpedes/moco_gru_cliprn101_ls_bs128_2048.yaml
@@ -0,0 +1,44 @@
+MODEL:
+  WEIGHT: "imagenet"
+  FREEZE: False
+  VISUAL_MODEL: "m_resnet101"
+  TEXTUAL_MODEL: "bigru"
+  NUM_CLASSES: 11003
+  GRU:
+    ONEHOT: "clip_vit"
+    EMBEDDING_SIZE: 512
+    NUM_UNITS: 512
+    VOCABULARY_SIZE: 512
+    DROPOUT_KEEP_PROB: 1.0
+    MAX_LENGTH: 100
+  RESNET:
+    RES5_STRIDE: 1
+  EMBEDDING:
+    EMBED_HEAD: 'moco'
+    FEATURE_SIZE: 256
+    DROPOUT_PROB: 0.0
+    EPSILON: 0.1
+  MOCO:
+    FC: False
+    K: 2048
+INPUT:
+  HEIGHT: 384
+  WIDTH: 128
+  USE_AUG: True
+  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
+  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
+DATASETS:
+  TRAIN: ("cuhkpedes_train", )
+  TEST: ("cuhkpedes_test", )
+SOLVER:
+  IMS_PER_BATCH: 128
+  NUM_EPOCHS: 80
+  BASE_LR: 0.0001
+  WEIGHT_DECAY: 0.00004
+  CHECKPOINT_PERIOD: 40
+  LRSCHEDULER: 'step'
+  STEPS: (40, 70)
+  WARMUP_FACTOR: 0.1
+  WARMUP_EPOCHS: 5
+TEST:
+  IMS_PER_BATCH: 128
diff --git a/ai/TextReID/configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml b/ai/TextReID/configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml
new file mode 100644
index 0000000000..da4b86793a
--- /dev/null
+++ b/ai/TextReID/configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml
@@ -0,0 +1,44 @@
+MODEL:
+  WEIGHT: "imagenet"
+  FREEZE: False
+  VISUAL_MODEL: "m_resnet50"
+  TEXTUAL_MODEL: "bigru"
+  NUM_CLASSES: 11003
+  GRU:
+    ONEHOT: "clip_vit"
+    EMBEDDING_SIZE: 512
+    NUM_UNITS: 512
+    VOCABULARY_SIZE: 512
+    DROPOUT_KEEP_PROB: 1.0
+    MAX_LENGTH: 100
+  RESNET:
+    RES5_STRIDE: 1
+  EMBEDDING:
+    EMBED_HEAD: 'moco'
+    FEATURE_SIZE: 256
+    DROPOUT_PROB: 0.0
+    EPSILON: 0.1
+  MOCO:
+    FC: False
+    K: 2048
+INPUT:
+  HEIGHT: 384
+  WIDTH: 128
+  USE_AUG: True
+  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
+  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
+DATASETS:
+  TRAIN: ("cuhkpedes_train", )
+  TEST: ("cuhkpedes_test", )
+SOLVER:
+  IMS_PER_BATCH: 128
+  NUM_EPOCHS: 80
+  BASE_LR: 0.0001
+  WEIGHT_DECAY: 0.00004
+  CHECKPOINT_PERIOD: 40
+  LRSCHEDULER: 'step'
+  STEPS: (40, 70)
+  WARMUP_FACTOR: 0.1
+  WARMUP_EPOCHS: 5
+TEST:
+  IMS_PER_BATCH: 128
diff --git a/ai/TextReID/encoding.py b/ai/TextReID/encoding.py
new file mode 100644
index 0000000000..3754ec748b
--- /dev/null
+++ b/ai/TextReID/encoding.py
@@ -0,0 +1,35 @@
+import json
+import re
+
+
+def encode(query):
+    """Map a raw text query to the one-hot word ids used by the CUHK-PEDES annotations."""
+    file_path = "./datasets/cuhkpedes/annotations/test.json"
+    with open(file_path, "r") as file:
+        data = json.load(file)
+
+    word_dict = {}  # word : one-hot id
+    max_onehot = -1
+
+    for i in range(len(data["annotations"])):
+        words = re.sub(r'[^a-zA-Z0-9\s]', '', data["annotations"][i]["sentence"])
+        words = words.split()
+        for word, onehot in zip(words, data["annotations"][i]["onehot"]):
+            if onehot > max_onehot:
+                max_onehot = onehot
+            if word.lower() not in word_dict.keys():
+                word_dict[word.lower()] = onehot
+
+    output = []
+    query = re.sub(r'[^a-zA-Z0-9\s]', '', query)
+    for w in query.split():
+        try:
+            output.append(word_dict[w.lower()])
+        except KeyError as e:
+            # Out-of-vocabulary word: warn and skip it so the output stays a list of
+            # integer ids that torch.tensor() in lib/engine/inference.py can consume.
+            # (An alternative would be to grow word_dict with a new id, e.g. max_onehot + 1.)
+            print("Key %s not found in the dictionary." % e.args[0])
+
+    # print(word_dict)
+    return output
diff --git a/ai/TextReID/lib/__init__.py b/ai/TextReID/lib/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git
a/ai/TextReID/lib/config/__init__.py b/ai/TextReID/lib/config/__init__.py new file mode 100644 index 0000000000..57123a05f6 --- /dev/null +++ b/ai/TextReID/lib/config/__init__.py @@ -0,0 +1,3 @@ +from .defaults import _C as cfg + +__all__ = ["cfg"] diff --git a/ai/TextReID/lib/config/defaults.py b/ai/TextReID/lib/config/defaults.py new file mode 100644 index 0000000000..d2dc667e97 --- /dev/null +++ b/ai/TextReID/lib/config/defaults.py @@ -0,0 +1,144 @@ +from yacs.config import CfgNode as CN + +_C = CN() +_C.ROOT = "./" + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASETS = CN() +_C.DATASETS.TRAIN = () +_C.DATASETS.TEST = () +_C.DATASETS.USE_ONEHOT = True + + +# ----------------------------------------------------------------------------- +# DataLoader +# ----------------------------------------------------------------------------- +_C.DATALOADER = CN() +# Number of data loading threads +_C.DATALOADER.NUM_WORKERS = 4 +_C.DATALOADER.IMS_PER_ID = 4 +_C.DATALOADER.EN_SAMPLER = True + + +# ----------------------------------------------------------------------------- +# Input +# ----------------------------------------------------------------------------- +_C.INPUT = CN() +_C.INPUT.HEIGHT = 224 +_C.INPUT.WIDTH = 224 +_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] +_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] +_C.INPUT.PADDING = 10 +_C.INPUT.USE_AUG = False + + +# ----------------------------------------------------------------------------- +# Model +# ----------------------------------------------------------------------------- +_C.MODEL = CN() +_C.MODEL.DEVICE = "cuda" +_C.MODEL.VISUAL_MODEL = "resnet50" +_C.MODEL.TEXTUAL_MODEL = "bilstm" +_C.MODEL.NUM_CLASSES = 11003 +_C.MODEL.FREEZE = False +_C.MODEL.WEIGHT = "imagenet" + + +# ----------------------------------------------------------------------------- +# MoCo +# ----------------------------------------------------------------------------- +_C.MODEL.MOCO = CN() +_C.MODEL.MOCO.K = 1024 +_C.MODEL.MOCO.M = 0.999 +_C.MODEL.MOCO.FC = True + + +# ----------------------------------------------------------------------------- +# GRU +# ----------------------------------------------------------------------------- +_C.MODEL.GRU = CN() +_C.MODEL.GRU.ONEHOT = "yes" +_C.MODEL.GRU.EMBEDDING_SIZE = 512 +_C.MODEL.GRU.NUM_UNITS = 512 +_C.MODEL.GRU.VOCABULARY_SIZE = 12000 +_C.MODEL.GRU.DROPOUT_KEEP_PROB = 0.7 +_C.MODEL.GRU.MAX_LENGTH = 100 +_C.MODEL.GRU.NUM_LAYER = 1 + + +# ----------------------------------------------------------------------------- +# Resnet +# ----------------------------------------------------------------------------- +_C.MODEL.RESNET = CN() +_C.MODEL.RESNET.RES5_STRIDE = 2 +_C.MODEL.RESNET.RES5_DILATION = 1 +_C.MODEL.RESNET.PRETRAINED = None + + +# ----------------------------------------------------------------------------- +# Embedding +# ----------------------------------------------------------------------------- +_C.MODEL.EMBEDDING = CN() +_C.MODEL.EMBEDDING.EMBED_HEAD = "simple" +_C.MODEL.EMBEDDING.FEATURE_SIZE = 512 +_C.MODEL.EMBEDDING.DROPOUT_PROB = 0.3 +_C.MODEL.EMBEDDING.EPSILON = 0.0 + + +# ----------------------------------------------------------------------------- +# Solver +# ----------------------------------------------------------------------------- +_C.SOLVER = CN() +_C.SOLVER.IMS_PER_BATCH = 16 +_C.SOLVER.NUM_EPOCHS = 100 +_C.SOLVER.CHECKPOINT_PERIOD = 1 +_C.SOLVER.EVALUATE_PERIOD = 1 + 
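+# NOTE: the solver values below are repository-wide defaults; the YAML files under
+# configs/cuhkpedes/ override only the keys they list (BASE_LR, STEPS, WARMUP_EPOCHS, ...).
+# A quick sanity check of the merged configuration, assuming the standard yacs API:
+#   from lib.config import cfg
+#   cfg.merge_from_file("configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml")
+#   print(cfg.SOLVER.BASE_LR)  # 0.0001 from the YAML instead of the 0.0002 default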
+_C.SOLVER.OPTIMIZER = "Adam" +_C.SOLVER.BASE_LR = 0.0002 +_C.SOLVER.BIAS_LR_FACTOR = 2 + +_C.SOLVER.WEIGHT_DECAY = 0.00004 +_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0 + +_C.SOLVER.ADAM_ALPHA = 0.9 +_C.SOLVER.ADAM_BETA = 0.999 +_C.SOLVER.SGD_MOMENTUM = 0.9 + +_C.SOLVER.LRSCHEDULER = "step" + +_C.SOLVER.WARMUP_FACTOR = 1.0 / 3 +_C.SOLVER.WARMUP_EPOCHS = 10 +_C.SOLVER.WARMUP_METHOD = "linear" + +_C.SOLVER.GAMMA = 0.1 +_C.SOLVER.STEPS = (500,) + +_C.SOLVER.POWER = 0.9 +_C.SOLVER.TARGET_LR = 0.0001 + + +# ---------------------------------------------------------------------------- # +# Specific test options +# ---------------------------------------------------------------------------- # +_C.TEST = CN() +# Number of images per batch +# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will +# see 2 images per batch +_C.TEST.IMS_PER_BATCH = 16 + + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # + + +# ---------------------------------------------------------------------------- # +# Precision options +# ---------------------------------------------------------------------------- # +# Precision of input, allowable: (float32, float16) +_C.DTYPE = "float32" +# Enable verbosity in apex.amp +_C.AMP_VERBOSE = False diff --git a/ai/TextReID/lib/config/paths_catalog.py b/ai/TextReID/lib/config/paths_catalog.py new file mode 100644 index 0000000000..efc5feb4e7 --- /dev/null +++ b/ai/TextReID/lib/config/paths_catalog.py @@ -0,0 +1,49 @@ +import os + + +class DatasetCatalog: + DATA_DIR = "datasets" + # DATA_DIR = "datasets/runs/detect/exp55" + """DATASETS = { + "cuhkpedes_train": { + "img_dir": "cuhkpedes", + "ann_file": "cuhkpedes/annotations/train.json", + }, + "cuhkpedes_val": { + "img_dir": "cuhkpedes", + "ann_file": "cuhkpedes/annotations/val.json", + }, + "cuhkpedes_test": { + "img_dir": "cuhkpedes", + "ann_file": "cuhkpedes/annotations/test.json", + }, + }""" + DATASETS = { + "cuhkpedes_train": { + "img_dir": "cuhkpedes", + "ann_file": "cuhkpedes/annotations2/annotations.json", + }, + "cuhkpedes_val": { + "img_dir": "cuhkpedes", + "ann_file": "cuhkpedes/annotations2/annotations.json", + }, + "cuhkpedes_test": { + "img_dir": "cuhkpedes", + "ann_file": "cuhkpedes/annotations2/annotations.json", + }, + } + + @staticmethod + def get(root, name): + if "cuhkpedes" in name: + data_dir = DatasetCatalog.DATA_DIR + attrs = DatasetCatalog.DATASETS[name] + args = dict( + root=os.path.join(root, data_dir, attrs["img_dir"]), + ann_file=os.path.join(root, data_dir, attrs["ann_file"]), + ) + return dict( + factory="CUHKPEDESDataset", + args=args, + ) + raise RuntimeError("Dataset not available: {}".format(name)) diff --git a/ai/TextReID/lib/data/__init__.py b/ai/TextReID/lib/data/__init__.py new file mode 100644 index 0000000000..f4d24b3b4b --- /dev/null +++ b/ai/TextReID/lib/data/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .build import make_data_loader + +__all__ = ["make_data_loader"] diff --git a/ai/TextReID/lib/data/build.py b/ai/TextReID/lib/data/build.py new file mode 100644 index 0000000000..f0bae8341b --- /dev/null +++ b/ai/TextReID/lib/data/build.py @@ -0,0 +1,115 @@ +import torch.utils.data + +from lib.config.paths_catalog import DatasetCatalog +from lib.utils.comm import get_world_size + +from . import datasets as D +from . 
import samplers +from .collate_batch import collate_fn +from .transforms import build_transforms + + +def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True): + if not isinstance(dataset_list, (list, tuple)): + raise RuntimeError( + "dataset_list should be a list of strings, got {}".format(dataset_list) + ) + datasets = [] + for dataset_name in dataset_list: + data = dataset_catalog.get(cfg.ROOT, dataset_name) + factory = getattr(D, data["factory"]) + args = data["args"] + args["transforms"] = transforms + + if data["factory"] == "CUHKPEDESDataset": + args["use_onehot"] = cfg.DATASETS.USE_ONEHOT + args["max_length"] = 105 + + # make dataset from factory + dataset = factory(**args) + datasets.append(dataset) + + # for testing, return a list of datasets + if not is_train: + return datasets + + # for training, concatenate all datasets into a single one + dataset = datasets[0] + if len(datasets) > 1: + dataset = D.ConcatDataset(datasets) + + return [dataset] + + +def make_data_sampler(dataset, shuffle, distributed): + if distributed: + return torch.utils.data.distributed.DistributedSampler(dataset) + if shuffle: + sampler = torch.utils.data.sampler.RandomSampler(dataset) + else: + sampler = torch.utils.data.sampler.SequentialSampler(dataset) + return sampler + + +def make_batch_data_sampler(cfg, dataset, sampler, images_per_batch, is_train=True): + if is_train and cfg.DATALOADER.EN_SAMPLER: + batch_sampler = samplers.TripletSampler( + sampler, + dataset, + images_per_batch, + cfg.DATALOADER.IMS_PER_ID, + drop_last=True, + ) + else: + batch_sampler = torch.utils.data.sampler.BatchSampler( + sampler, images_per_batch, drop_last=is_train + ) + return batch_sampler + + +def make_data_loader(cfg, is_train=True, is_distributed=False): + num_gpus = get_world_size() + if is_train: + images_per_batch = cfg.SOLVER.IMS_PER_BATCH + assert ( + images_per_batch % num_gpus == 0 + ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format( + images_per_batch, num_gpus + ) + images_per_gpu = images_per_batch // num_gpus + shuffle = True + else: + images_per_batch = cfg.TEST.IMS_PER_BATCH + assert ( + images_per_batch % num_gpus == 0 + ), "TEST.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format( + images_per_batch, num_gpus + ) + images_per_gpu = images_per_batch // num_gpus + shuffle = is_distributed + + dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST + + transforms = build_transforms(cfg, is_train) + + datasets = build_dataset(cfg, dataset_list, transforms, DatasetCatalog, is_train) + + data_loaders = [] + for dataset in datasets: + sampler = make_data_sampler(dataset, shuffle, is_distributed) + batch_sampler = make_batch_data_sampler( + cfg, dataset, sampler, images_per_gpu, is_train + ) + num_workers = cfg.DATALOADER.NUM_WORKERS + data_loader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + ) + data_loaders.append(data_loader) + if is_train: + # during training, a single (possibly concatenated) data_loader is returned + assert len(data_loaders) == 1 + return data_loaders[0] + return data_loaders diff --git a/ai/TextReID/lib/data/collate_batch.py b/ai/TextReID/lib/data/collate_batch.py new file mode 100644 index 0000000000..f39f233488 --- /dev/null +++ b/ai/TextReID/lib/data/collate_batch.py @@ -0,0 +1,9 @@ +import torch + + +def collate_fn(batch): + transposed_batch = list(zip(*batch)) + images = torch.stack(transposed_batch[0]) 
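+    # transposed_batch[0] is a tuple of equally sized image tensors (all resized to
+    # INPUT.HEIGHT x INPUT.WIDTH by build_transforms), so they can be stacked into one
+    # (B, C, H, W) batch; captions have varying token lengths and are kept as a tuple.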
+ captions = transposed_batch[1] + img_ids = transposed_batch[2] + return images, captions, img_ids diff --git a/ai/TextReID/lib/data/metrics/__init__.py b/ai/TextReID/lib/data/metrics/__init__.py new file mode 100644 index 0000000000..c392671654 --- /dev/null +++ b/ai/TextReID/lib/data/metrics/__init__.py @@ -0,0 +1,3 @@ +from .evaluation import evaluation + +__all__ = ["evaluation"] diff --git a/ai/TextReID/lib/data/metrics/evaluation.py b/ai/TextReID/lib/data/metrics/evaluation.py new file mode 100644 index 0000000000..4b0b079016 --- /dev/null +++ b/ai/TextReID/lib/data/metrics/evaluation.py @@ -0,0 +1,222 @@ +import logging +import os + +import numpy as np +import torch +import torch.nn.functional as F +from torchvision.transforms.functional import to_pil_image +import matplotlib.pyplot as plt +from PIL import Image + +from lib.utils.logger import table_log + +# 텐서보드 import +from torchvision.utils import make_grid +from torch.utils.tensorboard import SummaryWriter + +def rank(similarity, q_pids, g_pids, topk=[1, 5, 10], get_mAP=True): + max_rank = max(topk) + if get_mAP: + indices = torch.argsort(similarity, dim=1, descending=True) + else: + # acclerate sort with topk + _, indices = torch.topk( + similarity, k=max_rank, dim=1, largest=True, sorted=True + ) # q * topk + indices = indices.to(g_pids.device) + pred_labels = g_pids[indices] # q * k + matches = pred_labels.eq(q_pids.view(-1, 1)) # q * k + + all_cmc = matches[:, :max_rank].cumsum(1) + all_cmc[all_cmc > 1] = 1 + all_cmc = all_cmc.float().mean(0) * 100 + all_cmc = all_cmc[topk - 1] + + if not get_mAP: + return all_cmc, indices + + num_rel = matches.sum(1) # q + tmp_cmc = matches.cumsum(1) # q * k + tmp_cmc = [tmp_cmc[:, i] / (i + 1.0) for i in range(tmp_cmc.shape[1])] + tmp_cmc = torch.stack(tmp_cmc, 1) * matches + AP = tmp_cmc.sum(1) / num_rel # q + mAP = AP.mean() * 100 + return all_cmc, mAP, indices + + +def jaccard(a_list, b_list): + return float(len(set(a_list) & set(b_list))) / float(len(set(a_list) | set(b_list))) + + +def jaccard_mat(row_nn, col_nn): + jaccard_sim = np.zeros((row_nn.shape[0], col_nn.shape[0])) + # FIXME: need optimization + for i in range(row_nn.shape[0]): + for j in range(col_nn.shape[0]): + jaccard_sim[i, j] = jaccard(row_nn[i], col_nn[j]) + return torch.from_numpy(jaccard_sim) + + +def k_reciprocal(q_feats, g_feats, neighbor_num=5, alpha=0.05): + qg_sim = torch.matmul(q_feats, g_feats.t()) # q * g + gg_sim = torch.matmul(g_feats, g_feats.t()) # g * g + + qg_indices = torch.argsort(qg_sim, dim=1, descending=True) + gg_indices = torch.argsort(gg_sim, dim=1, descending=True) + + qg_nn = qg_indices[:, :neighbor_num] # q * n + gg_nn = gg_indices[:, :neighbor_num] # g * n + + jaccard_sim = jaccard_mat(qg_nn.cpu().numpy(), gg_nn.cpu().numpy()) # q * g + jaccard_sim = jaccard_sim.to(qg_sim.device) + return alpha * jaccard_sim # q * g + + +def get_unique(image_ids): + keep_idx = {} + for idx, image_id in enumerate(image_ids): + if image_id not in keep_idx.keys(): + keep_idx[image_id] = idx + return torch.tensor(list(keep_idx.values())) + + +def evaluation( + dataset, + predictions, + output_folder, + topk, + cap, + save_data=True, + rerank=True, +): + logger = logging.getLogger("PersonSearch.inference") + data_dir = os.path.join(output_folder, "inference_data.npz") + + if predictions is None: + inference_data = np.load(data_dir) + logger.info("Load inference data from {}".format(data_dir)) + image_pid = torch.tensor(inference_data["image_pid"]) + text_pid = torch.tensor(inference_data["text_pid"]) + 
similarity = torch.tensor(inference_data["similarity"])
+        if rerank:
+            rvn_mat = torch.tensor(inference_data["rvn_mat"])
+            rtn_mat = torch.tensor(inference_data["rtn_mat"])
+    else:
+        image_ids, pids = [], []
+        image_global, text_global = [], []
+
+        # FIXME: need optimization
+        for idx, prediction in predictions.items():
+            image_id, pid = dataset.get_id_info(idx)
+            image_ids.append(image_id)
+            pids.append(pid)
+            image_global.append(prediction[0])
+            if len(prediction) == 2:  # only one text query was fed in, so some predictions contain no text embedding
+                text_global.append(prediction[1])
+
+        pids = list(map(int, pids))
+        image_pid = torch.tensor(pids)
+        text_pid = torch.tensor(pids)
+        image_global = torch.stack(image_global, dim=0)
+        text_global = torch.stack(text_global, dim=0)
+
+        keep_idx = get_unique(image_ids)
+        image_global = image_global[keep_idx]
+        image_pid = image_pid[keep_idx]
+
+        image_global = F.normalize(image_global, p=2, dim=1)
+        text_global = F.normalize(text_global, p=2, dim=1)
+
+        similarity = torch.matmul(text_global, image_global.t())
+
+        writer = SummaryWriter()
+        """# return the top-10 results over the entire flattened similarity matrix
+        flatten_sim = similarity.view(-1)
+        top_k = 10
+        sorted_indices = torch.argsort(flatten_sim, descending=True)
+        sorted_values = flatten_sim[sorted_indices]
+        print(cap[0])
+        images = []
+        for index, value in zip(sorted_indices[:top_k], sorted_values[:top_k]):
+            img, caption, idx, query = dataset.__getitem__(index)
+            images.append(img)
+            if value < 0.6:
+                break
+            print(f"Index: {index}, Similarity: {value}")
+        writer.add_image(f"Image", images)
+        writer.add_text(f"Caption", cap[0])"""
+        # retrieve the top-10 results for each query
+        for i in range(4):
+            sorted_indices = torch.argsort(similarity[i], descending=True)
+            sorted_values = similarity[i][sorted_indices]
+            top_k = 10
+            images = []
+            similarities = ""
+            print(cap[i])
+            for index, value in zip(sorted_indices[:top_k], sorted_values[:top_k]):
+                image_id, pid = dataset.get_id_info(index)  # was get_id_info(idx): report the pid of the retrieved image
+                img, caption, idx, query = dataset.__getitem__(index)
+                images.append(img)
+                print(f"Index: {index}, Similarity: {value}, pid: {pid}")
+                similarities += str(value) + "\t"
+            grid_img = make_grid(images, nrow=10)
+            writer.add_image(f"Image Grid for Query {i}", grid_img)
+            writer.add_text(f"Captions for Query {i}", cap[i])
+        writer.close()
+
+
+
+    """
+    if rerank:
+        rtn_mat = k_reciprocal(image_global, text_global)
+        rvn_mat = k_reciprocal(text_global, image_global)
+
+    if save_data:
+        if not rerank:
+            np.savez(
+                data_dir,
+                image_pid=image_pid.cpu().numpy(),
+                text_pid=text_pid.cpu().numpy(),
+                similarity=similarity.cpu().numpy(),
+            )
+        else:
+            np.savez(
+                data_dir,
+                image_pid=image_pid.cpu().numpy(),
+                text_pid=text_pid.cpu().numpy(),
+                similarity=similarity.cpu().numpy(),
+                rvn_mat=rvn_mat.cpu().numpy(),
+                rtn_mat=rtn_mat.cpu().numpy(),
+            )
+
+    topk = torch.tensor(topk)
+
+    if rerank:
+        i2t_cmc, i2t_mAP, _ = rank(
+            similarity.t(), image_pid, text_pid, topk, get_mAP=True
+        )
+        t2i_cmc, t2i_mAP, _ = rank(similarity, text_pid, image_pid, topk, get_mAP=True)
+        re_i2t_cmc, re_i2t_mAP, _ = rank(
+            rtn_mat + similarity.t(), image_pid, text_pid, topk, get_mAP=True
+        )
+        re_t2i_cmc, re_t2i_mAP, _ = rank(
+            rvn_mat + similarity, text_pid, image_pid, topk, get_mAP=True
+        )
+        cmc_results = torch.stack([topk, t2i_cmc, re_t2i_cmc, i2t_cmc, re_i2t_cmc])
+        mAP_results = torch.stack(
+            [torch.zeros_like(t2i_mAP), t2i_mAP, re_t2i_mAP, i2t_mAP, re_i2t_mAP]
+        ).unsqueeze(-1)
+        results = torch.cat([cmc_results, mAP_results], dim=1)
+        results = results.t().cpu().numpy().tolist()
+        results[-1][0] = "mAP"
+
logger.info( + "\n" + + table_log(results, headers=["topk", "t2i", "re-t2i", "i2t", "re-i2t"]) + ) + else: + t2i_cmc, _ = rank(similarity, text_pid, image_pid, topk, get_mAP=False) + i2t_cmc, _ = rank(similarity.t(), image_pid, text_pid, topk, get_mAP=False) + results = torch.stack((topk, t2i_cmc, i2t_cmc)).t().cpu().numpy() + logger.info("\n" + table_log(results, headers=["topk", "t2i", "i2t"])) + return t2i_cmc[0] +""" \ No newline at end of file diff --git a/ai/TextReID/lib/data/samplers/__init__.py b/ai/TextReID/lib/data/samplers/__init__.py new file mode 100644 index 0000000000..f598400200 --- /dev/null +++ b/ai/TextReID/lib/data/samplers/__init__.py @@ -0,0 +1,3 @@ +from .triplet_batch_sampler import TripletSampler + +__all__ = ["TripletSampler"] diff --git a/ai/TextReID/lib/data/samplers/triplet_batch_sampler.py b/ai/TextReID/lib/data/samplers/triplet_batch_sampler.py new file mode 100644 index 0000000000..be8ca60e2e --- /dev/null +++ b/ai/TextReID/lib/data/samplers/triplet_batch_sampler.py @@ -0,0 +1,129 @@ +import copy +import math +import random +from collections import defaultdict + +import torch +from torch.utils.data.sampler import BatchSampler + + +def _split(tensor, size, dim=0, drop_last=False): + if dim < 0: + dim += tensor.dim() + dim_size = tensor.size(dim) + + if dim_size < size: + times = math.ceil(size / dim_size) + tensor = tensor.repeat_interleave(times) + dim_size = size + + split_size = size + num_splits = (dim_size + split_size - 1) // split_size + last_split_size = split_size - (split_size * num_splits - dim_size) + + def get_split_size(i): + return split_size if i < num_splits - 1 else last_split_size + + if drop_last and last_split_size != split_size: + total_num_splits = num_splits - 1 + else: + total_num_splits = num_splits + + return list( + tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i))) + for i in range(0, total_num_splits) + ) + + +def _merge(splits, pids, num_pids_per_batch): + avaible_pids = copy.deepcopy(pids) + merged = [] + + while len(avaible_pids) >= num_pids_per_batch: + batch = [] + selected_pids = random.sample(avaible_pids, num_pids_per_batch) + for pid in selected_pids: + batch_idxs = splits[pid].pop(0) + batch.extend(batch_idxs.tolist()) + if len(splits[pid]) == 0: + avaible_pids.remove(pid) + merged.append(batch) + return merged + + +def _map(dataset): + id_to_img_map = [] + for i in range(len(dataset)): + _, pid = dataset.get_id_info(i) + id_to_img_map.append(pid) + return id_to_img_map + + +class TripletSampler(BatchSampler): + """ + Randomly sample N identities, then for each identity, + randomly sample K instances, therefore batch size is N*K. + Args: + - data_source (list): list of (img_path, pid, camid). + - num_instances (int): number of instances per identity in a batch. + - batch_size (int): number of examples in a batch. 
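+    Example:
+        with SOLVER.IMS_PER_BATCH = 128 and DATALOADER.IMS_PER_ID = 4 (the values used by
+        the provided configs and defaults), every batch holds 32 identities x 4 images.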
+ """ + + def __init__(self, sampler, data_source, batch_size, images_per_pid, drop_last): + super(TripletSampler, self).__init__(sampler, batch_size, drop_last) + self.num_instances = images_per_pid + self.num_pids_per_batch = batch_size // images_per_pid + self.id_to_img_map = _map(data_source) + self.index_dict = defaultdict(list) + for index, pid in enumerate(self.id_to_img_map): + self.index_dict[pid].append(index) + self.pids = list(self.index_dict.keys()) + + self.group_ids = torch.as_tensor(self.id_to_img_map) + self.groups = torch.unique(self.group_ids).sort(0)[0] + + self._can_reuse_batches = False + + def _prepare_batches(self): + dataset_size = len(self.group_ids) + sampled_ids = torch.as_tensor(list(self.sampler)) + order = torch.full((dataset_size,), -1, dtype=torch.int64) + order[sampled_ids] = torch.arange(len(sampled_ids)) + + mask = order >= 0 + clusters = [(self.group_ids == i) & mask for i in self.groups] + relative_order = [order[cluster] for cluster in clusters] + permutation_ids = [s.sort()[0] for s in relative_order] + permuted_clusters = [sampled_ids[idx] for idx in permutation_ids] + + splits = defaultdict(list) + for idx, c in enumerate(permuted_clusters): + splits[idx] = _split(c, self.num_instances, drop_last=True) + merged = _merge(splits, self.pids, self.num_pids_per_batch) + + first_element_of_batch = [t[0] for t in merged] + inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())} + first_index_of_batch = torch.as_tensor( + [inv_sampled_ids_map[s] for s in first_element_of_batch] + ) + permutation_order = first_index_of_batch.sort(0)[1].tolist() + batches = [merged[i] for i in permutation_order] + + return batches + + def __iter__(self): + if self._can_reuse_batches: + batches = self._batches + self._can_reuse_batches = False + else: + batches = self._prepare_batches() + self._batches = batches + + for batch in iter(batches): + yield batch + + def __len__(self): + if not hasattr(self, "_batches"): + self._batches = self._prepare_batches() + self._can_reuse_batches = True + return len(self._batches) diff --git a/ai/TextReID/lib/data/transforms.py b/ai/TextReID/lib/data/transforms.py new file mode 100644 index 0000000000..a0c774babf --- /dev/null +++ b/ai/TextReID/lib/data/transforms.py @@ -0,0 +1,43 @@ +import torchvision.transforms as T + + +def build_transforms(cfg, is_train=True): + height = cfg.INPUT.HEIGHT + width = cfg.INPUT.WIDTH + use_aug = cfg.INPUT.USE_AUG + + normalize_transform = T.Normalize( + mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD + ) + + if is_train: + if use_aug: + transform = T.Compose( + [ + T.Resize((height, width)), + T.RandomHorizontalFlip(0.5), + T.Pad(cfg.INPUT.PADDING), + T.RandomCrop((height, width)), + T.ToTensor(), + normalize_transform, + T.RandomErasing(scale=(0.02, 0.4), value=cfg.INPUT.PIXEL_MEAN), + ] + ) + else: + transform = T.Compose( + [ + T.Resize((height, width)), + T.RandomHorizontalFlip(0.5), + T.ToTensor(), + normalize_transform, + ] + ) + else: + transform = T.Compose( + [ + T.Resize((height, width)), + T.ToTensor(), + normalize_transform, + ] + ) + return transform diff --git a/ai/TextReID/lib/engine/__init__.py b/ai/TextReID/lib/engine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ai/TextReID/lib/engine/inference.py b/ai/TextReID/lib/engine/inference.py new file mode 100644 index 0000000000..e89a15b987 --- /dev/null +++ b/ai/TextReID/lib/engine/inference.py @@ -0,0 +1,107 @@ +import datetime +import logging +import os +import time +from collections import 
defaultdict + +import torch +from tqdm import tqdm + +from lib.data.metrics import evaluation +from lib.utils.comm import all_gather, is_main_process, synchronize + +from lib.utils.caption import Caption + +from encoding import encode + + +def compute_on_dataset(model, data_loader, cap, device): + model.eval() + results_dict = defaultdict(list) + for batch in tqdm(data_loader): + images, captions, image_ids = batch + images = images.to(device) + # captions = [captions[0].to(device)] # 첫 번째 캡션만 사용 + caption = input("\nText Query Input: ") + cap.append(caption) + caption = encode(caption) + caption = torch.tensor(caption) + captions = [Caption([caption]).to(device)] + # captions = [caption.to(device) for caption in captions] + with torch.no_grad(): + output = model(images, captions) + for result in output: + for img_id, pred in zip(image_ids, result): + results_dict[img_id].append(pred) + return results_dict, cap + + +def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu): + all_predictions = all_gather(predictions_per_gpu) + if not is_main_process(): + return + # merge the list of dicts + predictions = {} + for p in all_predictions: + predictions.update(p) + # convert a dict where the key is the index in a list + image_ids = list(sorted(predictions.keys())) + if len(image_ids) != image_ids[-1] + 1: + logger = logging.getLogger("PersonSearch.inference") + logger.warning( + "Number of images that were gathered from multiple processes is not " + "a contiguous set. Some images might be missing from the evaluation" + ) + return predictions + + +def inference( + model, + data_loader, + dataset_name="cuhkpedes-test", + device="cuda", + output_folder="", + save_data=True, + rerank=True, +): + logger = logging.getLogger("PersonSearch.inference") + dataset = data_loader.dataset + logger.info( + "Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)) + ) + + predictions = None + if not os.path.exists(os.path.join(output_folder, "inference_data.npz")): + # convert to a torch.device for efficiency + device = torch.device(device) + num_devices = ( + torch.distributed.get_world_size() + if torch.distributed.is_initialized() + else 1 + ) + start_time = time.time() + + predictions, cap = compute_on_dataset(model, data_loader, [], device) + # wait for all processes to complete before measuring the time + synchronize() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=total_time)) + logger.info( + "Total inference time: {} ({} s / img per device, on {} devices)".format( + total_time_str, total_time * num_devices / len(dataset), num_devices + ) + ) + predictions = _accumulate_predictions_from_multiple_gpus(predictions) + + if not is_main_process(): + return + + return evaluation( + dataset=dataset, + predictions=predictions, + output_folder=output_folder, + save_data=save_data, + rerank=rerank, + topk=[1, 5, 10], + cap=cap, + ) diff --git a/ai/TextReID/lib/engine/trainer.py b/ai/TextReID/lib/engine/trainer.py new file mode 100644 index 0000000000..dd017b1f8f --- /dev/null +++ b/ai/TextReID/lib/engine/trainer.py @@ -0,0 +1,139 @@ +import datetime +import logging +import time + +import torch +import torch.distributed as dist + +from lib.utils.comm import get_world_size + +from .inference import inference + + +def reduce_loss_dict(loss_dict): + """ + Reduce the loss dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + loss_dict, after reduction. 
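+    Example (world_size = 2): if rank 0 holds {"loss": 1.0} and rank 1 holds {"loss": 3.0},
+    only rank 0 is guaranteed to receive the averaged {"loss": 2.0} after the reduce.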
+ """ + world_size = get_world_size() + if world_size < 2: + return loss_dict + with torch.no_grad(): + loss_names = [] + all_losses = [] + for k in sorted(loss_dict.keys()): + loss_names.append(k) + all_losses.append(loss_dict[k]) + all_losses = torch.stack(all_losses, dim=0) + dist.reduce(all_losses, dst=0) + if dist.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + all_losses /= world_size + reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} + return reduced_losses + + +def do_train( + model, + data_loader, + data_loader_val, + optimizer, + scheduler, + checkpointer, + meters, + device, + checkpoint_period, + evaluate_period, + arguments, +): + logger = logging.getLogger("PersonSearch.trainer") + logger.info("Start training") + + max_epoch = arguments["max_epoch"] + epoch = arguments["epoch"] + max_iter = max_epoch * len(data_loader) + iteration = arguments["iteration"] + distributed = arguments["distributed"] + + best_top1 = 0.0 + start_training_time = time.time() + end = time.time() + + while epoch < max_epoch: + if distributed: + data_loader.sampler.set_epoch(epoch) + + epoch += 1 + model.train() + arguments["epoch"] = epoch + + for step, (images, captions, _) in enumerate(data_loader): + data_time = time.time() - end + inner_iter = step + iteration += 1 + arguments["iteration"] = iteration + + images = images.to(device) + captions = [caption.to(device) for caption in captions] + + loss_dict = model(images, captions) + losses = sum(loss for loss in loss_dict.values()) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = reduce_loss_dict(loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + meters.update(loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + losses.backward() + optimizer.step() + + batch_time = time.time() - end + end = time.time() + meters.update(time=batch_time, data=data_time) + + eta_seconds = meters.time.global_avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + + if inner_iter % 1 == 0: + logger.info( + meters.delimiter.join( + [ + "eta: {eta}", + "epoch [{epoch}][{inner_iter}/{num_iter}]", + "{meters}", + "lr: {lr:.6f}", + "max mem: {memory:.0f}", + ] + ).format( + eta=eta_string, + epoch=epoch, + inner_iter=inner_iter, + num_iter=len(data_loader), + meters=str(meters), + lr=optimizer.param_groups[-1]["lr"], + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + ) + ) + + scheduler.step() + + if epoch % evaluate_period == 0: + top1 = inference(model, data_loader_val[0], save_data=False, rerank=False) + meters.update(top1=top1) + if top1 > best_top1: + best_top1 = top1 + checkpointer.save("best", **arguments) + + if epoch % checkpoint_period == 0: + checkpointer.save("epoch_{:d}".format(epoch), **arguments) + + total_training_time = time.time() - start_training_time + total_time_str = str(datetime.timedelta(seconds=total_training_time)) + logger.info( + "Total training time: {} ({:.4f} s / it)".format( + total_time_str, total_training_time / (max_iter) + ) + ) diff --git a/ai/TextReID/lib/models/backbones/__init__.py b/ai/TextReID/lib/models/backbones/__init__.py new file mode 100644 index 0000000000..487df49ec5 --- /dev/null +++ b/ai/TextReID/lib/models/backbones/__init__.py @@ -0,0 +1,3 @@ +from .build import build_textual_model, build_visual_model + +__all__ = ["build_textual_model", "build_visual_model"] diff --git a/ai/TextReID/lib/models/backbones/build.py 
b/ai/TextReID/lib/models/backbones/build.py new file mode 100644 index 0000000000..b67bd9017b --- /dev/null +++ b/ai/TextReID/lib/models/backbones/build.py @@ -0,0 +1,17 @@ +from .gru import build_gru +from .m_resnet import build_m_resnet +from .resnet import build_resnet + + +def build_visual_model(cfg): + if cfg.MODEL.VISUAL_MODEL in ["resnet50", "resnet101"]: + return build_resnet(cfg) + if cfg.MODEL.VISUAL_MODEL in ["m_resnet50", "m_resnet101"]: + return build_m_resnet(cfg) + raise NotImplementedError + + +def build_textual_model(cfg): + if cfg.MODEL.TEXTUAL_MODEL == "bigru": + return build_gru(cfg, bidirectional=True) + raise NotImplementedError diff --git a/ai/TextReID/lib/models/backbones/gru.py b/ai/TextReID/lib/models/backbones/gru.py new file mode 100644 index 0000000000..18b13b7fe7 --- /dev/null +++ b/ai/TextReID/lib/models/backbones/gru.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn + +from lib.utils.directory import load_vocab_dict + + +class GRU(nn.Module): + def __init__( + self, + hidden_dim, + vocab_size, + embed_size, + num_layers, + drop_out, + bidirectional, + use_onehot, + root, + ): + super().__init__() + + self.use_onehot = use_onehot + + # word embedding + if use_onehot == "yes": + self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0) + else: + if vocab_size == embed_size: + self.embed = None + else: + self.embed = nn.Linear(vocab_size, embed_size) + + vocab_dict = load_vocab_dict(root, use_onehot) + assert vocab_size == vocab_dict.shape[1] + self.vocab_dict = torch.tensor(vocab_dict).cuda().float() + + self.gru = nn.GRU( + embed_size, + hidden_dim, + num_layers=num_layers, + dropout=drop_out, + bidirectional=bidirectional, + bias=False, + ) + self.out_channels = hidden_dim * 2 if bidirectional else hidden_dim + + self._init_weight() + + def forward(self, captions): + text = torch.stack([caption.text for caption in captions], dim=1) + text_length = torch.stack([caption.length for caption in captions], dim=1) + + text_length = text_length.view(-1) + text = text.view(-1, text.size(-1)) # b x l + + if not self.use_onehot == "yes": + bs, length = text.shape[0], text.shape[-1] + text = text.view(-1) # bl + text = self.vocab_dict[text].reshape(bs, length, -1) # b x l x vocab_size + if self.embed is not None: + text = self.embed(text) + + gru_out = self.gru_out(text, text_length) + gru_out, _ = torch.max(gru_out, dim=1) + return gru_out + + def gru_out(self, embed, text_length): + + _, idx_sort = torch.sort(text_length, dim=0, descending=True) + _, idx_unsort = torch.sort(idx_sort, dim=0) + + embed_sort = embed.index_select(0, idx_sort) + length_list = text_length[idx_sort] + pack = nn.utils.rnn.pack_padded_sequence( + embed_sort, length_list.cpu(), batch_first=True + ) + + gru_sort_out, _ = self.gru(pack) + gru_sort_out = nn.utils.rnn.pad_packed_sequence(gru_sort_out, batch_first=True) + gru_sort_out = gru_sort_out[0] + + gru_out = gru_sort_out.index_select(0, idx_unsort) + return gru_out + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_uniform_(m.weight.data, 1) + nn.init.constant(m.bias.data, 0) + + +def build_gru(cfg, bidirectional): + use_onehot = cfg.MODEL.GRU.ONEHOT + hidden_dim = cfg.MODEL.GRU.NUM_UNITS + vocab_size = cfg.MODEL.GRU.VOCABULARY_SIZE + embed_size = cfg.MODEL.GRU.EMBEDDING_SIZE + num_layer = cfg.MODEL.GRU.NUM_LAYER + drop_out = 1 - cfg.MODEL.GRU.DROPOUT_KEEP_PROB + root = cfg.ROOT + + model = GRU( + hidden_dim, + vocab_size, + embed_size, + num_layer, + drop_out, + 
bidirectional, + use_onehot, + root, + ) + + if cfg.MODEL.FREEZE: + for m in [model.embed, model.gru]: + m.eval() + for param in m.parameters(): + param.requires_grad = False + + return model diff --git a/ai/TextReID/lib/models/backbones/m_resnet.py b/ai/TextReID/lib/models/backbones/m_resnet.py new file mode 100644 index 0000000000..d2003ee3af --- /dev/null +++ b/ai/TextReID/lib/models/backbones/m_resnet.py @@ -0,0 +1,307 @@ +import logging +import math +import os +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__( + self, + spacial_dim, + embed_dim, + num_heads, + output_dim=None, + patch_size=1, + ): + super().__init__() + self.spacial_dim = spacial_dim + self.proj_conv = None + if patch_size > 1: + self.proj_conv = nn.Conv2d( + embed_dim, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + self.positional_embedding = nn.Parameter( + torch.randn( + (spacial_dim[0] // patch_size) * (spacial_dim[1] // patch_size) + 1, + embed_dim, + ) + / embed_dim ** 0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + if self.proj_conv is not None: + x = self.proj_conv(x) + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1 + ) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + 
out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__( + self, + layers, + output_dim, + heads, + last_stride=1, + input_resolution=(224, 224), + width=64, + ): + super().__init__() + self.output_dim = output_dim + self.out_channels = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=last_stride) + + embed_dim = width * 32 # the ResNet feature dimension + down_ratio = 16 if last_stride == 1 else 32 + spacial_dim = ( + input_resolution[0] // down_ratio, + input_resolution[1] // down_ratio, + ) + self.attnpool = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +def resize_pos_embed(posemb, gs_new): + # Rescale the grid of position embeddings when loading from state_dict. 
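+    # e.g. the released CLIP RN50/RN101 weights carry a 7x7 grid (49 patch rows plus the
+    # leading mean-pooled token row); with 384x128 inputs and RES5_STRIDE = 1 the attention
+    # pool needs a 24x8 grid, so the grid rows are bilinearly interpolated to gs_new while
+    # the leading token row is kept unchanged.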
+ logger = logging.getLogger("PersonSearch.train") + posemb_tok, posemb_grid = posemb[:1], posemb[1:] + gs_old = int(math.sqrt(len(posemb_grid))) + logger.info("Resized position embedding: {} to {}".format((gs_old, gs_old), gs_new)) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate( + posemb_grid, size=gs_new, mode="bilinear", align_corners=False + ) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=0) + return posemb + + +def state_filter(state_dict, final_stage_resolution): + out_dict = {} + for k, v in state_dict.items(): + if k.startswith("visual."): + k = k[7:] + if k == "attnpool.positional_embedding" and final_stage_resolution != (7, 7): + v = resize_pos_embed(v, final_stage_resolution) + out_dict[k] = v + return out_dict + + +def modified_resnet50( + input_resolution, + last_stride, + pretrained_path=None, +): + model = ModifiedResNet( + layers=[3, 4, 6, 3], + output_dim=1024, + heads=32, + last_stride=last_stride, + input_resolution=input_resolution, + ) + if pretrained_path: + p = torch.jit.load(pretrained_path).state_dict() + model.load_state_dict( + state_filter( + p, + final_stage_resolution=model.attnpool.spacial_dim, + ), + strict=False, + ) + return model + + +def modified_resnet101( + input_resolution, + last_stride, + pretrained_path=None, +): + model = ModifiedResNet( + layers=[3, 4, 23, 3], + output_dim=512, + heads=32, + last_stride=last_stride, + input_resolution=input_resolution, + ) + if pretrained_path: + p = torch.jit.load(pretrained_path).state_dict() + model.load_state_dict( + state_filter( + p, + final_stage_resolution=model.attnpool.spacial_dim, + ), + strict=False, + ) + return model + + +def build_m_resnet(cfg): + if cfg.MODEL.VISUAL_MODEL in ["m_resnet50", "m_resnet"]: + model = modified_resnet50( + (cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH), + cfg.MODEL.RESNET.RES5_STRIDE, + pretrained_path=os.path.join(cfg.ROOT, "pretrained/clip/RN50.pt"), + ) + elif cfg.MODEL.VISUAL_MODEL == "m_resnet101": + model = modified_resnet101( + (cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH), + cfg.MODEL.RESNET.RES5_STRIDE, + pretrained_path=os.path.join(cfg.ROOT, "pretrained/clip/RN101.pt"), + ) + return model diff --git a/ai/TextReID/lib/models/backbones/resnet.py b/ai/TextReID/lib/models/backbones/resnet.py new file mode 100644 index 0000000000..0f78d941ed --- /dev/null +++ b/ai/TextReID/lib/models/backbones/resnet.py @@ -0,0 +1,235 @@ +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + """3x3 convolution with padding""" + # original padding is 1; original dilation is 1 + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride, dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + 
residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + # original padding is 1; original dilation is 1 + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + model_arch, + res5_stride=2, + res5_dilation=1, + pretrained=True, + ): + super().__init__() + block = model_arch.block + layers = model_arch.stage + + self.inplanes = 64 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=res5_stride, dilation=res5_dilation + ) + + if pretrained is None: + self.load_state_dict(remove_fc(model_zoo.load_url(model_arch.url))) + else: + self.load_state_dict(torch.load(pretrained)) + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.out_channels = 512 * block.expansion + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, dilation)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.avgpool(x) + + return x + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def remove_fc(state_dict): + """Remove the fc layer parameters from state_dict.""" + for key in list(state_dict.keys()): + if key.startswith("fc."): + del state_dict[key] + return state_dict + + +resnet = namedtuple("resnet", ["block", "stage", "url"]) +model_archs = {} +model_archs["resnet18"] = resnet( + BasicBlock, + [2, 2, 2, 2], + 
"https://download.pytorch.org/models/resnet18-5c106cde.pth", +) +model_archs["resnet34"] = resnet( + BasicBlock, + [3, 4, 6, 3], + "https://download.pytorch.org/models/resnet34-333f7ec4.pth", +) +model_archs["resnet50"] = resnet( + Bottleneck, + [3, 4, 6, 3], + "https://download.pytorch.org/models/resnet50-19c8e357.pth", +) +model_archs["resnet101"] = resnet( + Bottleneck, + [3, 4, 23, 3], + "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", +) +model_archs["resnet152"] = resnet( + Bottleneck, + [3, 8, 36, 3], + "https://download.pytorch.org/models/resnet152-b121ed2d.pth", +) + + +def build_resnet(cfg): + arch = cfg.MODEL.VISUAL_MODEL + res5_stride = cfg.MODEL.RESNET.RES5_STRIDE + res5_dilation = cfg.MODEL.RESNET.RES5_DILATION + pretrained = cfg.MODEL.RESNET.PRETRAINED + + model_arch = model_archs[arch] + model = ResNet( + model_arch, + res5_stride, + res5_dilation, + pretrained=pretrained, + ) + + if cfg.MODEL.FREEZE: + for m in [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3]: + m.eval() + for param in m.parameters(): + param.requires_grad = False + + return model diff --git a/ai/TextReID/lib/models/embeddings/__init__.py b/ai/TextReID/lib/models/embeddings/__init__.py new file mode 100644 index 0000000000..e0dc2533e0 --- /dev/null +++ b/ai/TextReID/lib/models/embeddings/__init__.py @@ -0,0 +1,3 @@ +from .build import build_embed + +__all__ = ["build_embed"] diff --git a/ai/TextReID/lib/models/embeddings/build.py b/ai/TextReID/lib/models/embeddings/build.py new file mode 100644 index 0000000000..4273b331dd --- /dev/null +++ b/ai/TextReID/lib/models/embeddings/build.py @@ -0,0 +1,9 @@ +from .simple_head.head import build_simple_head + + +def build_embed(cfg, visual_out_channels, textual_out_channels): + + if cfg.MODEL.EMBEDDING.EMBED_HEAD == "simple": + return build_simple_head(cfg, visual_out_channels, textual_out_channels) + else: + raise NotImplementedError diff --git a/ai/TextReID/lib/models/embeddings/moco_head/__init__.py b/ai/TextReID/lib/models/embeddings/moco_head/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ai/TextReID/lib/models/embeddings/moco_head/head.py b/ai/TextReID/lib/models/embeddings/moco_head/head.py new file mode 100644 index 0000000000..5f54cf4b76 --- /dev/null +++ b/ai/TextReID/lib/models/embeddings/moco_head/head.py @@ -0,0 +1,187 @@ +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .loss import make_loss_evaluator + + +class MoCoHead(nn.Module): + def __init__( + self, + cfg, + visual_model, + textual_model, + ): + super().__init__() + self.embed_size = cfg.MODEL.EMBEDDING.FEATURE_SIZE + self.K = cfg.MODEL.MOCO.K + self.m = cfg.MODEL.MOCO.M + self.fc = cfg.MODEL.MOCO.FC + + self.v_encoder_q = visual_model + self.t_encoder_q = textual_model + self.v_encoder_k = copy.deepcopy(visual_model) + self.t_encoder_k = copy.deepcopy(textual_model) + for param in self.v_encoder_k.parameters(): + param.requires_grad = False + for param in self.t_encoder_k.parameters(): + param.requires_grad = False + + if self.fc: + self.v_fc_q = nn.Sequential( + nn.Linear(visual_model.out_channels, self.embed_size), + nn.ReLU(), + nn.Linear(self.embed_size, self.embed_size), + ) + self.t_fc_q = nn.Sequential( + nn.Linear(textual_model.out_channels, self.embed_size), + nn.ReLU(), + nn.Linear(self.embed_size, self.embed_size), + ) + self.v_fc_k = copy.deepcopy(self.v_fc_q) + self.t_fc_k = copy.deepcopy(self.t_fc_q) + for param in self.v_fc_k.parameters(): + param.requires_grad = False + for 
param in self.t_fc_k.parameters(): + param.requires_grad = False + + self.v_embed_layer = nn.Linear(visual_model.out_channels, self.embed_size) + self.t_embed_layer = nn.Linear(textual_model.out_channels, self.embed_size) + + self.register_buffer("t_queue", torch.rand(self.embed_size, self.K)) + self.t_queue = F.normalize(self.t_queue, dim=0) + self.register_buffer("v_queue", torch.rand(self.embed_size, self.K)) + self.v_queue = F.normalize(self.v_queue, dim=0) + # initialize id label as -1 + self.register_buffer("id_queue", -torch.ones((1, self.K), dtype=torch.long)) + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + self.loss_evaluator = make_loss_evaluator(cfg) + self._init_weight() + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out") + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + @torch.no_grad() + def _momentum_update_key_encoder(self): + """ + Momentum update of the key encoder + """ + for param_q, param_k in zip( + self.v_encoder_q.parameters(), self.v_encoder_k.parameters() + ): + param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) + for param_q, param_k in zip( + self.t_encoder_q.parameters(), self.t_encoder_k.parameters() + ): + param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) + if self.fc: + for param_q, param_k in zip( + self.v_fc_q.parameters(), self.v_fc_k.parameters() + ): + param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) + for param_q, param_k in zip( + self.t_fc_q.parameters(), self.t_fc_k.parameters() + ): + param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) + + @torch.no_grad() + def _dequeue_and_enqueue(self, v_keys, t_keys, id_keys): + batch_size = v_keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.K % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.v_queue[:, ptr : ptr + batch_size] = v_keys.T + self.t_queue[:, ptr : ptr + batch_size] = t_keys.T + self.id_queue[:, ptr : ptr + batch_size] = id_keys.T + + ptr = (ptr + batch_size) % self.K # move pointer + self.queue_ptr[0] = ptr + + def forward(self, images, captions): + N = images.shape[0] + + v_embed = self.v_encoder_q(images) + t_embed = self.t_encoder_q(captions) + + if self.training: + if self.fc: + v_embed_q = self.v_fc_q(v_embed) + t_embed_q = self.t_fc_q(t_embed) + v_embed = self.v_embed_layer(v_embed) + t_embed = self.t_embed_layer(t_embed) + v_embed_q = F.normalize(v_embed_q, dim=1) + t_embed_q = F.normalize(t_embed_q, dim=1) + else: + v_embed = self.v_embed_layer(v_embed) + t_embed = self.t_embed_layer(t_embed) + v_embed_q = F.normalize(v_embed, dim=1) + t_embed_q = F.normalize(t_embed, dim=1) + id_q = torch.stack([caption.get_field("id") for caption in captions]).long() + + with torch.no_grad(): + self._momentum_update_key_encoder() + v_embed_k = self.v_encoder_k(images) + if self.fc: + v_embed_k = self.v_fc_k(v_embed_k) + else: + v_embed_k = self.v_embed_layer(v_embed_k) + v_embed_k = F.normalize(v_embed_k, dim=1) + t_embed_k = self.t_encoder_k(captions) + if self.fc: + t_embed_k = self.t_fc_k(t_embed_k) + else: + t_embed_k = self.t_embed_layer(t_embed_k) + t_embed_k = F.normalize(t_embed_k, dim=1) + + # regard same instance ids as positive sapmles, we need filter them out + pos_idx = ( + self.id_queue.expand(N, self.K) + .eq(id_q.unsqueeze(-1)) + .nonzero(as_tuple=False)[:, 1] + ) 
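+            # Build the set of negative queue indices: concatenating all queue
+            # indices with the positive indices and keeping entries that appear
+            # exactly once leaves only slots whose id is absent from this batch.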
+ unique, counts = torch.unique( + torch.cat([torch.arange(self.K).long().cuda(), pos_idx]), + return_counts=True, + ) + neg_idx = unique[counts == 1] + + # v positive logits: Nx1 + v_pos = torch.einsum("nc,nc->n", [v_embed_q, t_embed_k]).unsqueeze(-1) + # v negative logits: NxK + t_queue = self.t_queue.clone().detach() + t_queue = t_queue[:, neg_idx] + v_neg = torch.einsum("nc,ck->nk", [v_embed_q, t_queue]) + # t positive logits: Nx1 + t_pos = torch.einsum("nc,nc->n", [t_embed_q, v_embed_k]).unsqueeze(-1) + # t negative logits: NxK + v_queue = self.v_queue.clone().detach() + v_queue = v_queue[:, neg_idx] + t_neg = torch.einsum("nc,ck->nk", [t_embed_q, v_queue]) + + losses = self.loss_evaluator( + v_embed, t_embed, v_pos, v_neg, t_pos, t_neg, id_q + ) + self._dequeue_and_enqueue(v_embed_k, t_embed_k, id_q) + return losses + + v_embed = self.v_embed_layer(v_embed) + t_embed = self.t_embed_layer(t_embed) + outputs = list() + outputs.append(v_embed) + outputs.append(t_embed) + return outputs + + +def build_moco_head(cfg, visual_model, textual_model): + return MoCoHead(cfg, visual_model, textual_model) diff --git a/ai/TextReID/lib/models/embeddings/moco_head/loss.py b/ai/TextReID/lib/models/embeddings/moco_head/loss.py new file mode 100644 index 0000000000..5ecabfaaf7 --- /dev/null +++ b/ai/TextReID/lib/models/embeddings/moco_head/loss.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +import lib.models.losses as losses + + +class LossComputation(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.projection = Parameter( + torch.randn(cfg.MODEL.EMBEDDING.FEATURE_SIZE, cfg.MODEL.NUM_CLASSES), + requires_grad=True, + ) + self.epsilon = cfg.MODEL.EMBEDDING.EPSILON + # self.T = Parameter(torch.tensor(0.07), requires_grad=True) + self.T = 0.07 + nn.init.xavier_uniform_(self.projection.data, gain=1) + + def forward(self, v_embed, t_embed, v_pos, v_neg, t_pos, t_neg, labels): + loss = { + "instance_loss": losses.instance_loss( + self.projection, + v_embed, + t_embed, + labels, + epsilon=self.epsilon, + ), + "infonce_loss": losses.infonce_loss( + v_pos, + v_neg, + t_pos, + t_neg, + self.T, + ), + "global_align_loss": losses.global_align_loss(v_embed, t_embed, labels), + } + return loss + + +def make_loss_evaluator(cfg): + return LossComputation(cfg) diff --git a/ai/TextReID/lib/models/embeddings/simple_head/__init__.py b/ai/TextReID/lib/models/embeddings/simple_head/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ai/TextReID/lib/models/embeddings/simple_head/head.py b/ai/TextReID/lib/models/embeddings/simple_head/head.py new file mode 100644 index 0000000000..765b18a59e --- /dev/null +++ b/ai/TextReID/lib/models/embeddings/simple_head/head.py @@ -0,0 +1,52 @@ +import torch.nn as nn + +from .loss import make_loss_evaluator + + +class SimpleHead(nn.Module): + def __init__( + self, + cfg, + visual_size, + textual_size, + ): + super().__init__() + self.embed_size = cfg.MODEL.EMBEDDING.FEATURE_SIZE + + self.visual_embed_layer = nn.Linear(visual_size, self.embed_size) + self.textual_embed_layer = nn.Linear(textual_size, self.embed_size) + + self.loss_evaluator = make_loss_evaluator(cfg) + self._init_weight() + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out") + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, 
visual_feature, textual_feature, captions): + batch_size = visual_feature.size(0) + + visual_embed = visual_feature.view(batch_size, -1) + textual_embed = textual_feature.view(batch_size, -1) + + visual_embed = self.visual_embed_layer(visual_embed) + textual_embed = self.textual_embed_layer(textual_embed) + + if self.training: + losses = self.loss_evaluator(visual_embed, textual_embed, captions) + return None, losses + + outputs = list() + outputs.append(visual_embed) + outputs.append(textual_embed) + return outputs, None + + +def build_simple_head(cfg, visual_size, textual_size): + model = SimpleHead(cfg, visual_size, textual_size) + return model diff --git a/ai/TextReID/lib/models/embeddings/simple_head/loss.py b/ai/TextReID/lib/models/embeddings/simple_head/loss.py new file mode 100644 index 0000000000..5b36d36a70 --- /dev/null +++ b/ai/TextReID/lib/models/embeddings/simple_head/loss.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +import lib.models.losses as losses + + +class LossComputation(nn.Module): + def __init__(self, cfg): + super().__init__() + self.epsilon = cfg.MODEL.EMBEDDING.EPSILON + self.scale_pos = 10.0 + self.scale_neg = 40.0 + + self.projection = Parameter( + torch.randn(cfg.MODEL.EMBEDDING.FEATURE_SIZE, cfg.MODEL.NUM_CLASSES), + requires_grad=True, + ) + nn.init.xavier_uniform_(self.projection.data, gain=1) + + def forward( + self, + visual_embed, + textual_embed, + captions, + ): + labels = torch.stack([caption.get_field("id") for caption in captions]).long() + loss = { + "instance_loss": losses.instance_loss( + self.projection, + visual_embed, + textual_embed, + labels, + epsilon=self.epsilon, + ), + "global_align_loss": losses.global_align_loss( + visual_embed, + textual_embed, + labels, + scale_pos=self.scale_pos, + scale_neg=self.scale_neg, + ), + } + return loss + + +def make_loss_evaluator(cfg): + return LossComputation(cfg) diff --git a/ai/TextReID/lib/models/losses.py b/ai/TextReID/lib/models/losses.py new file mode 100644 index 0000000000..4511b467ed --- /dev/null +++ b/ai/TextReID/lib/models/losses.py @@ -0,0 +1,217 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CrossEntropyLabelSmooth(nn.Module): + """Cross entropy loss with label smoothing regularizer. + + Reference: + Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. + Equation: y = (1 - epsilon) * y + epsilon / K. + + Args: + num_classes (int): number of classes. + epsilon (float): weight. 
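+        use_gpu (bool): if True, move the smoothed target distribution to the GPU
+            before computing the loss.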
+ """ + + def __init__(self, num_classes, epsilon=0.1, use_gpu=True): + super().__init__() + self.num_classes = num_classes + self.epsilon = epsilon + self.use_gpu = use_gpu + self.logsoftmax = nn.LogSoftmax(dim=1) + + def forward(self, inputs, targets): + """ + Args: + inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) + targets: ground truth labels with shape (num_classes) + """ + log_probs = self.logsoftmax(inputs) + targets = torch.zeros(log_probs.size()).scatter_( + 1, targets.unsqueeze(1).data.cpu(), 1 + ) + if self.use_gpu: + targets = targets.cuda() + targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes + loss = (-targets * log_probs).mean(0).sum() + return loss + + +def instance_loss( + projection, visual_embed, textual_embed, labels, scale=1, norm=False, epsilon=0.0 +): + if norm: + visual_norm = F.normalize(visual_embed, p=2, dim=-1) + textual_norm = F.normalize(textual_embed, p=2, dim=-1) + else: + visual_norm = visual_embed + textual_norm = textual_embed + projection_norm = F.normalize(projection, p=2, dim=0) + + visual_logits = scale * torch.matmul(visual_norm, projection_norm) + textual_logits = scale * torch.matmul(textual_norm, projection_norm) + + if epsilon > 0: + criterion = CrossEntropyLabelSmooth(num_classes=projection_norm.shape[1]) + else: + criterion = nn.CrossEntropyLoss(reduction="mean") + loss = criterion(visual_logits, labels) + criterion(textual_logits, labels) + + return loss + + +def cmpc_loss(projection, visual_embed, textual_embed, labels, verbose=False): + """ + Cross-Modal Projection Classfication loss (CMPC) + :param image_embeddings: Tensor with dtype torch.float32 + :param text_embeddings: Tensor with dtype torch.float32 + :param labels: Tensor with dtype torch.int32 + :return: + """ + visual_norm = F.normalize(visual_embed, p=2, dim=1) + textual_norm = F.normalize(textual_embed, p=2, dim=1) + projection_norm = F.normalize(projection, p=2, dim=0) + + image_proj_text = ( + torch.sum(visual_embed * textual_norm, dim=1, keepdim=True) * textual_norm + ) + text_proj_image = ( + torch.sum(textual_embed * visual_norm, dim=1, keepdim=True) * visual_norm + ) + + image_logits = torch.matmul(image_proj_text, projection_norm) + text_logits = torch.matmul(text_proj_image, projection_norm) + + criterion = nn.CrossEntropyLoss(reduction="mean") + loss = criterion(image_logits, labels) + criterion(text_logits, labels) + + # classification accuracy for observation + if verbose: + image_pred = torch.argmax(image_logits, dim=1) + text_pred = torch.argmax(text_logits, dim=1) + + image_precision = torch.mean((image_pred == labels).float()) + text_precision = torch.mean((text_pred == labels).float()) + + return loss, image_precision, text_precision + return loss + + +def global_align_loss( + visual_embed, + textual_embed, + labels, + alpha=0.6, + beta=0.4, + scale_pos=10, + scale_neg=40, +): + batch_size = labels.size(0) + visual_norm = F.normalize(visual_embed, p=2, dim=1) + textual_norm = F.normalize(textual_embed, p=2, dim=1) + similarity = torch.matmul(visual_norm, textual_norm.t()) + labels_ = ( + labels.expand(batch_size, batch_size) + .eq(labels.expand(batch_size, batch_size).t()) + .float() + ) + + pos_inds = labels_ == 1 + neg_inds = labels_ == 0 + loss_pos = torch.log(1 + torch.exp(-scale_pos * (similarity[pos_inds] - alpha))) + loss_neg = torch.log(1 + torch.exp(scale_neg * (similarity[neg_inds] - beta))) + loss = (loss_pos.sum() + loss_neg.sum()) * 2.0 + + loss /= batch_size + return loss + + +def 
global_align_loss_from_sim( + similarity, + labels, + alpha=0.6, + beta=0.4, + scale_pos=10, + scale_neg=40, +): + batch_size = labels.size(0) + labels_ = ( + labels.expand(batch_size, batch_size) + .eq(labels.expand(batch_size, batch_size).t()) + .float() + ) + + pos_inds = labels_ == 1 + neg_inds = labels_ == 0 + loss_pos = torch.log(1 + torch.exp(-scale_pos * (similarity[pos_inds] - alpha))) + loss_neg = torch.log(1 + torch.exp(scale_neg * (similarity[neg_inds] - beta))) + loss = (loss_pos.sum() + loss_neg.sum()) * 2.0 + + loss /= batch_size + return loss + + +def cmpm_loss(visual_embed, textual_embed, labels, verbose=False, epsilon=1e-8): + """ + Cross-Modal Projection Matching Loss(CMPM) + :param image_embeddings: Tensor with dtype torch.float32 + :param text_embeddings: Tensor with dtype torch.float32 + :param labels: Tensor with dtype torch.int32 + :return: + i2t_loss: cmpm loss for image projected to text + t2i_loss: cmpm loss for text projected to image + pos_avg_sim: average cosine-similarity for positive pairs + neg_avg_sim: averate cosine-similarity for negative pairs + """ + + batch_size = visual_embed.shape[0] + labels_reshape = torch.reshape(labels, (batch_size, 1)) + labels_dist = labels_reshape - labels_reshape.t() + labels_mask = labels_dist == 0 + + visual_norm = F.normalize(visual_embed, p=2, dim=1) + textual_norm = F.normalize(textual_embed, p=2, dim=1) + image_proj_text = torch.matmul(visual_embed, textual_norm.t()) + text_proj_image = torch.matmul(textual_embed, visual_norm.t()) + + # normalize the true matching distribution + labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1) + + i2t_pred = F.softmax(image_proj_text, dim=1) + i2t_loss = i2t_pred * ( + F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + epsilon) + ) + + t2i_pred = F.softmax(text_proj_image, dim=1) + t2i_loss = t2i_pred * ( + F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + epsilon) + ) + + loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean( + torch.sum(t2i_loss, dim=1) + ) + + if verbose: + sim_cos = torch.matmul(visual_norm, textual_norm.t()) + + pos_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask)) + neg_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask == 0)) + + return loss, pos_avg_sim, neg_avg_sim + return loss + + +def infonce_loss( + v_pos, + v_neg, + t_pos, + t_neg, + T=0.07, +): + v_logits = torch.cat([v_pos, v_neg], dim=1) / T + t_logits = torch.cat([t_pos, t_neg], dim=1) / T + labels = torch.zeros(v_logits.shape[0], dtype=torch.long).cuda() + loss = F.cross_entropy(v_logits, labels) + F.cross_entropy(t_logits, labels) + return loss diff --git a/ai/TextReID/lib/models/model.py b/ai/TextReID/lib/models/model.py new file mode 100644 index 0000000000..7500cb3c11 --- /dev/null +++ b/ai/TextReID/lib/models/model.py @@ -0,0 +1,45 @@ +from torch import nn + +from .backbones import build_textual_model, build_visual_model +from .embeddings import build_embed +from .embeddings.moco_head.head import build_moco_head + + +class Model(nn.Module): + def __init__(self, cfg): + super().__init__() + self.visual_model = build_visual_model(cfg) + self.textual_model = build_textual_model(cfg) + + if cfg.MODEL.EMBEDDING.EMBED_HEAD == "moco": + self.embed_model = build_moco_head( + cfg, self.visual_model, self.textual_model + ) + self.embed_type = "moco" + else: + self.embed_model = build_embed( + cfg, self.visual_model.out_channels, self.textual_model.out_channels + ) + self.embed_type = "normal" + + def forward(self, 
images, captions): + if self.embed_type == "moco": + return self.embed_model(images, captions) + + visual_feat = self.visual_model(images) + textual_feat = self.textual_model(captions) + + outputs_embed, losses_embed = self.embed_model( + visual_feat, textual_feat, captions + ) + + if self.training: + losses = {} + losses.update(losses_embed) + return losses + + return outputs_embed + + +def build_model(cfg): + return Model(cfg) diff --git a/ai/TextReID/lib/solver/__init__.py b/ai/TextReID/lib/solver/__init__.py new file mode 100644 index 0000000000..d5ad372d0a --- /dev/null +++ b/ai/TextReID/lib/solver/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .build import make_lr_scheduler, make_optimizer +from .lr_scheduler import LRSchedulerWithWarmup + +__all__ = ["make_lr_scheduler", "make_optimizer", "LRSchedulerWithWarmup"] diff --git a/ai/TextReID/lib/solver/build.py b/ai/TextReID/lib/solver/build.py new file mode 100644 index 0000000000..6a72205c5f --- /dev/null +++ b/ai/TextReID/lib/solver/build.py @@ -0,0 +1,55 @@ +import torch + +from .lr_scheduler import LRSchedulerWithWarmup + + +def make_optimizer(cfg, model): + params = [] + + for key, value in model.named_parameters(): + if not value.requires_grad: + continue + lr = cfg.SOLVER.BASE_LR + weight_decay = cfg.SOLVER.WEIGHT_DECAY + if "bias" in key: + lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR + weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS + params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] + + if cfg.SOLVER.OPTIMIZER == "SGD": + optimizer = torch.optim.SGD( + params, lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.SGD_MOMENTUM + ) + elif cfg.SOLVER.OPTIMIZER == "Adam": + optimizer = torch.optim.Adam( + params, + lr=cfg.SOLVER.BASE_LR, + betas=(cfg.SOLVER.ADAM_ALPHA, cfg.SOLVER.ADAM_BETA), + eps=1e-8, + ) + elif cfg.SOLVER.OPTIMIZER == "AdamW": + optimizer = torch.optim.AdamW( + params, + lr=cfg.SOLVER.BASE_LR, + betas=(cfg.SOLVER.ADAM_ALPHA, cfg.SOLVER.ADAM_BETA), + eps=1e-8, + ) + else: + NotImplementedError + + return optimizer + + +def make_lr_scheduler(cfg, optimizer): + return LRSchedulerWithWarmup( + optimizer, + milestones=cfg.SOLVER.STEPS, + gamma=cfg.SOLVER.GAMMA, + warmup_factor=cfg.SOLVER.WARMUP_FACTOR, + warmup_epochs=cfg.SOLVER.WARMUP_EPOCHS, + warmup_method=cfg.SOLVER.WARMUP_METHOD, + total_epochs=cfg.SOLVER.NUM_EPOCHS, + mode=cfg.SOLVER.LRSCHEDULER, + target_lr=cfg.SOLVER.TARGET_LR, + power=cfg.SOLVER.POWER, + ) diff --git a/ai/TextReID/lib/solver/lr_scheduler.py b/ai/TextReID/lib/solver/lr_scheduler.py new file mode 100644 index 0000000000..9c63c6a1bd --- /dev/null +++ b/ai/TextReID/lib/solver/lr_scheduler.py @@ -0,0 +1,87 @@ +from bisect import bisect_right +from math import cos, pi + +from torch.optim.lr_scheduler import _LRScheduler + + +class LRSchedulerWithWarmup(_LRScheduler): + def __init__( + self, + optimizer, + milestones, + gamma=0.1, + mode="step", + warmup_factor=1.0 / 3, + warmup_epochs=10, + warmup_method="linear", + total_epochs=100, + target_lr=0, + power=0.9, + last_epoch=-1, + ): + if not list(milestones) == sorted(milestones): + raise ValueError( + "Milestones should be a list of" + " increasing integers. 
Got {}".format(milestones), + ) + if mode not in ("step", "exp", "poly", "cosine", "linear"): + raise ValueError( + "Only 'step', 'exp', 'poly' or 'cosine' learning rate scheduler accepted" + "got {}".format(mode) + ) + if warmup_method not in ("constant", "linear"): + raise ValueError( + "Only 'constant' or 'linear' warmup_method accepted" + "got {}".format(warmup_method) + ) + self.milestones = milestones + self.mode = mode + self.gamma = gamma + self.warmup_factor = warmup_factor + self.warmup_epochs = warmup_epochs + self.warmup_method = warmup_method + self.total_epochs = total_epochs + self.target_lr = target_lr + self.power = power + super().__init__(optimizer, last_epoch) + + def get_lr(self): + + if self.last_epoch < self.warmup_epochs: + if self.warmup_method == "constant": + warmup_factor = self.warmup_factor + elif self.warmup_method == "linear": + alpha = self.last_epoch / self.warmup_epochs + warmup_factor = self.warmup_factor * (1 - alpha) + alpha + return [base_lr * warmup_factor for base_lr in self.base_lrs] + + if self.mode == "step": + return [ + base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) + for base_lr in self.base_lrs + ] + + epoch_ratio = (self.last_epoch - self.warmup_epochs) / ( + self.total_epochs - self.warmup_epochs + ) + + if self.mode == "exp": + factor = epoch_ratio + return [base_lr * self.power ** factor for base_lr in self.base_lrs] + if self.mode == "linear": + factor = 1 - epoch_ratio + return [base_lr * factor for base_lr in self.base_lrs] + + if self.mode == "poly": + factor = 1 - epoch_ratio + return [ + self.target_lr + (base_lr - self.target_lr) * self.power ** factor + for base_lr in self.base_lrs + ] + if self.mode == "cosine": + factor = 0.5 * (1 + cos(pi * epoch_ratio)) + return [ + self.target_lr + (base_lr - self.target_lr) * factor + for base_lr in self.base_lrs + ] + raise NotImplementedError diff --git a/ai/TextReID/lib/utils/caption.py b/ai/TextReID/lib/utils/caption.py new file mode 100644 index 0000000000..27aab3674f --- /dev/null +++ b/ai/TextReID/lib/utils/caption.py @@ -0,0 +1,99 @@ +import torch + + +class Caption(object): + def __init__( + self, text, length=None, max_length=None, padded=False, dtype=torch.int64 + ): + device = text.device if isinstance(text, torch.Tensor) else torch.device("cpu") + if isinstance(text, list): + text = [torch.as_tensor(line, dtype=dtype, device=device) for line in text] + if length is None: + length = torch.stack( + [ + torch.tensor(line.size(0), dtype=torch.int64, device=device) + for line in text + ] + ) + if max_length is None: + max_length = max([line.size(-1) for line in text]) + elif isinstance(text, str): + if length is None: + length = len(text.split()) + else: + text = torch.as_tensor(text, dtype=dtype, device=device) + if length is None: + length = torch.tensor(text.size(-1), dtype=torch.int64, device=device) + if max_length is None: + max_length = text.size(-1) + + if not padded and not isinstance(text, str): + text = self.pad(text, max_length, device) + + self.text = text + self.length = length + self.max_length = max_length + self.padded = True + self.dtype = dtype + self.extra_fields = {} + + @staticmethod + def pad(text, max_length, device): + padded = [] + for line in text: + length = line.size(0) + if length < max_length: + pad = torch.zeros( + (max_length - length), dtype=torch.int64, device=device + ) + padded.append(torch.cat((line, pad))) + else: + padded.append(line[:max_length]) + return torch.stack(padded) + + def add_field(self, field, field_data): + 
self.extra_fields[field] = field_data + + def get_field(self, field): + return self.extra_fields[field] + + def has_field(self, field): + return field in self.extra_fields + + def fields(self): + return list(self.extra_fields.keys()) + + # Tensor-like methods + + def to(self, device): + cap = Caption( + self.text, + self.length, + self.max_length, + self.padded, + self.dtype, + ) + if not isinstance(self.text, str): + cap.text = cap.text.to(device) + cap.length = cap.length.to(device) + for k, v in self.extra_fields.items(): + if hasattr(v, "to"): + v = v.to(device) + cap.add_field(k, v) + return cap + + def __getitem__(self, item): + cap = Caption(self.text[item], self.max_length, self.padded) + for k, v in self.extra_fields.items(): + cap.add_field(k, v[item]) + return cap + + def __len__(self): + return len(self.text) + + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "length={}, ".format(self.length) + s += "max_length={}, ".format(self.max_length) + s += "padded={}, ".format(self.padded) + return s diff --git a/ai/TextReID/lib/utils/checkpoint.py b/ai/TextReID/lib/utils/checkpoint.py new file mode 100644 index 0000000000..ec27a10912 --- /dev/null +++ b/ai/TextReID/lib/utils/checkpoint.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import os +from collections import OrderedDict + +import torch + + +class Checkpointer: + def __init__( + self, + model, + optimizer=None, + scheduler=None, + save_dir="", + save_to_disk=None, + logger=None, + ): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.save_dir = save_dir + self.save_to_disk = save_to_disk + if logger is None: + logger = logging.getLogger(__name__) + self.logger = logger + + def save(self, name, **kwargs): + if not self.save_dir: + return + + if not self.save_to_disk: + return + + data = {} + data["model"] = self.model.state_dict() + if self.optimizer is not None: + data["optimizer"] = self.optimizer.state_dict() + if self.scheduler is not None: + data["scheduler"] = self.scheduler.state_dict() + data.update(kwargs) + + save_file = os.path.join(self.save_dir, "{}.pth".format(name)) + self.logger.info("Saving checkpoint to {}".format(save_file)) + torch.save(data, save_file) + + def load(self, f=None): + if not f: + # no checkpoint could be found + self.logger.info("No checkpoint found.") + return {} + self.logger.info("Loading checkpoint from {}".format(f)) + checkpoint = self._load_file(f) + self._load_model(checkpoint) + + def resume(self, f=None): + if not f: + # no checkpoint could be found + self.logger.info("No checkpoint found.") + return {} + self.logger.info("Loading checkpoint from {}".format(f)) + checkpoint = self._load_file(f) + self._load_model(checkpoint) + if "optimizer" in checkpoint and self.optimizer: + self.logger.info("Loading optimizer from {}".format(f)) + self.optimizer.load_state_dict(checkpoint.pop("optimizer")) + if "scheduler" in checkpoint and self.scheduler: + self.logger.info("Loading scheduler from {}".format(f)) + self.scheduler.load_state_dict(checkpoint.pop("scheduler")) + # return any further checkpoint data + return checkpoint + + def _load_file(self, f): + return torch.load(f, map_location=torch.device("cpu")) + + def _load_model(self, checkpoint, except_keys=None): + load_state_dict(self.model, checkpoint.pop("model"), except_keys) + + +def check_key(key, except_keys): + if except_keys is None: + return False + else: + for except_key in except_keys: + if except_key in key: + 
return True + return False + + +def align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys=None): + current_keys = sorted(list(model_state_dict.keys())) + loaded_keys = sorted(list(loaded_state_dict.keys())) + # get a matrix of string matches, where each (i, j) entry correspond to the size of the + # loaded_key string, if it matches + match_matrix = [ + len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys + ] + match_matrix = torch.as_tensor(match_matrix).view( + len(current_keys), len(loaded_keys) + ) + max_match_size, idxs = match_matrix.max(1) + # remove indices that correspond to no-match + idxs[max_match_size == 0] = -1 + + # used for logging + max_size = max([len(key) for key in current_keys]) if current_keys else 1 + max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 + log_str_template = "{: <{}} loaded from {: <{}} of shape {}" + logger = logging.getLogger("PersonSearch.checkpoint") + for idx_new, idx_old in enumerate(idxs.tolist()): + if idx_old == -1: + continue + key = current_keys[idx_new] + key_old = loaded_keys[idx_old] + if check_key(key, except_keys): + continue + model_state_dict[key] = loaded_state_dict[key_old] + logger.info( + log_str_template.format( + key, + max_size, + key_old, + max_size_loaded, + tuple(loaded_state_dict[key_old].shape), + ) + ) + + +def strip_prefix_if_present(state_dict, prefix): + keys = sorted(state_dict.keys()) + if not all(key.startswith(prefix) for key in keys): + return state_dict + stripped_state_dict = OrderedDict() + for key, value in state_dict.items(): + stripped_state_dict[key.replace(prefix, "")] = value + return stripped_state_dict + + +def load_state_dict(model, loaded_state_dict, except_keys=None): + model_state_dict = model.state_dict() + # if the state_dict comes from a model that was wrapped in a + # DataParallel or DistributedDataParallel during serialization, + # remove the "module" prefix before performing the matching + loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") + align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys) + + # use strict loading + model.load_state_dict(model_state_dict) diff --git a/ai/TextReID/lib/utils/comm.py b/ai/TextReID/lib/utils/comm.py new file mode 100644 index 0000000000..ae4b556fcd --- /dev/null +++ b/ai/TextReID/lib/utils/comm.py @@ -0,0 +1,116 @@ +""" +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. 
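+All helpers fall back to single-process behaviour when torch.distributed is
+unavailable or has not been initialized.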
+""" + +import pickle + +import torch +import torch.distributed as dist + + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.IntTensor([tensor.numel()]).to("cuda") + size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/ai/TextReID/lib/utils/directory.py b/ai/TextReID/lib/utils/directory.py new file mode 100644 index 0000000000..9c9b8fd39c --- /dev/null +++ b/ai/TextReID/lib/utils/directory.py @@ -0,0 +1,30 @@ +import os + +import numpy as np + + +def makedir(root): + if not os.path.exists(root): + os.makedirs(root) + + +def load_vocab_dict(root, use_onehot): + if use_onehot == "bert_c4": + vocab_dict = np.load( + os.path.join(root, "./datasets/cuhkpedes/bert_vocab_c4.npy") + ) + elif use_onehot == "bert_l2": + vocab_dict = np.load( + os.path.join(root, "./datasets/cuhkpedes/bert_vocab_l2.npy") + ) + elif use_onehot == "clip_vit": + vocab_dict = np.load( + os.path.join(root, "./datasets/cuhkpedes/clip_vocab_vit.npy") + ) + elif use_onehot == "clip_rn50x4": + vocab_dict = np.load( + os.path.join(root, "./datasets/cuhkpedes/clip_vocab_rn50x4.npy") + ) + else: + NotImplementedError + return vocab_dict diff --git a/ai/TextReID/lib/utils/logger.py b/ai/TextReID/lib/utils/logger.py new file mode 100644 index 0000000000..6fa08bc150 --- /dev/null +++ b/ai/TextReID/lib/utils/logger.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import os +import sys + +from tabulate import tabulate + + +def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + # don't log results for the non-master process + if distributed_rank > 0: + return logger + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + if save_dir: + fh = logging.FileHandler(os.path.join(save_dir, filename), mode="w") + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger + + +def table_log(cols, headers): + return tabulate(cols, headers=headers, tablefmt="grid") diff --git a/ai/TextReID/lib/utils/metric_logger.py b/ai/TextReID/lib/utils/metric_logger.py new file mode 100644 index 0000000000..74c7c1ace7 --- /dev/null +++ b/ai/TextReID/lib/utils/metric_logger.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import time +from collections import defaultdict, deque +from datetime import datetime + +import torch + +from .comm import is_main_process + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. 
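+    Values are stored both in a fixed-size deque (for the windowed median and
+    average) and as running totals (for the global average).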
+ """ + + def __init__(self, window_size=20): + self.deque = deque(maxlen=window_size) + self.series = [] + self.total = 0.0 + self.count = 0 + + def update(self, value): + self.deque.append(value) + self.series.append(value) + self.count += 1 + self.total += value + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque)) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) + ) + return self.delimiter.join(loss_str) + + +class TensorboardLogger(MetricLogger): + def __init__(self, log_dir, start_iter=0, delimiter="\t"): + super(TensorboardLogger, self).__init__(delimiter) + self.iteration = start_iter + self.writer = self._get_tensorboard_writer(log_dir) + + @staticmethod + def _get_tensorboard_writer(log_dir): + try: + from tensorboardX import SummaryWriter + except ImportError: + raise ImportError( + "To use tensorboard please install tensorboardX " + "[ pip install tensorflow tensorboardX ]." 
+ ) + + if is_main_process(): + timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H:%M") + tb_logger = SummaryWriter("{}-{}".format(log_dir, timestamp)) + return tb_logger + else: + return None + + def update(self, **kwargs): + super(TensorboardLogger, self).update(**kwargs) + if self.writer: + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.writer.add_scalar(k, v, self.iteration) + self.iteration += 1 diff --git a/ai/TextReID/requirements.txt b/ai/TextReID/requirements.txt new file mode 100644 index 0000000000..c6ca01f02b --- /dev/null +++ b/ai/TextReID/requirements.txt @@ -0,0 +1,23 @@ +backports.entry-points-selectable==1.1.1 +certifi==2021.10.8 +cfgv==3.3.1 +distlib==0.3.3 +filelock==3.4.0 +identify==2.4.0 +importlib-metadata==4.8.2 +nodeenv==1.6.0 +numpy==1.21.4 +Pillow==8.4.0 +platformdirs==2.4.0 +pre-commit==2.16.0 +PyYAML==6.0 +six==1.16.0 +tabulate==0.8.9 +toml==0.10.2 +torch==1.10.0 +torchvision==0.11.1 +tqdm==4.62.3 +typing_extensions==4.0.1 +virtualenv==20.10.0 +yacs==0.1.8 +zipp==3.6.0 diff --git a/ai/TextReID/run.sh b/ai/TextReID/run.sh new file mode 100644 index 0000000000..e1c4c095da --- /dev/null +++ b/ai/TextReID/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +PYTHONHOME="/vol/research/xmodal_dl/txtreid-env/bin" +HOME="/vol/research/xmodal_dl/TextReID" + +echo $HOME +echo 'args:' $@ + +$PYTHONHOME/python $HOME/train_net.py --root $HOME $@ diff --git a/ai/TextReID/run.submit_file b/ai/TextReID/run.submit_file new file mode 100644 index 0000000000..ee59b5fce6 --- /dev/null +++ b/ai/TextReID/run.submit_file @@ -0,0 +1,29 @@ +executable = run.sh + +universe = docker +docker_image = nvidia/cuda:11.1-runtime-ubuntu18.04 + +log = condor_log/c$(cluster).p$(process).log +output = condor_log/c$(cluster).p$(process).out +error = condor_log/c$(cluster).p$(process).error + +environment = "mount=/vol/research/xmodal_dl/" + ++CanCheckpoint = True ++GPUMem = 11000 ++JobRunTime = 12 + +should_transfer_files = True +stream_output = True + +request_GPUs = 1 +request_CPUs = 1 +request_memory = 11G +requirements = (CUDAGlobalMemoryMb > 4500) && \ + (HasDocker) && \ + (CUDACapability > 2.0) && \ + (CUDADeviceName == "GeForce RTX 3090") + +queue arguments from ( + --config-file $ENV(PWD)/configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml +) diff --git a/ai/TextReID/runs/Mar31_04-52-42_2e078808ccc9/events.out.tfevents.1711860762.2e078808ccc9.3553.0 b/ai/TextReID/runs/Mar31_04-52-42_2e078808ccc9/events.out.tfevents.1711860762.2e078808ccc9.3553.0 new file mode 100644 index 0000000000..730a65cee0 Binary files /dev/null and b/ai/TextReID/runs/Mar31_04-52-42_2e078808ccc9/events.out.tfevents.1711860762.2e078808ccc9.3553.0 differ diff --git a/ai/TextReID/runs/Mar31_04-53-54_2e078808ccc9/events.out.tfevents.1711860834.2e078808ccc9.3916.0 b/ai/TextReID/runs/Mar31_04-53-54_2e078808ccc9/events.out.tfevents.1711860834.2e078808ccc9.3916.0 new file mode 100644 index 0000000000..ff2e8d46df Binary files /dev/null and b/ai/TextReID/runs/Mar31_04-53-54_2e078808ccc9/events.out.tfevents.1711860834.2e078808ccc9.3916.0 differ diff --git a/ai/TextReID/test_net.py b/ai/TextReID/test_net.py new file mode 100644 index 0000000000..1ac1b6f0d5 --- /dev/null +++ b/ai/TextReID/test_net.py @@ -0,0 +1,124 @@ +import argparse +import os + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed + +from lib.config import cfg +from lib.data import make_data_loader +from 
lib.engine.inference import inference +from lib.models.model import build_model +from lib.utils.checkpoint import Checkpointer +from lib.utils.comm import get_rank, synchronize +from lib.utils.directory import makedir +from lib.utils.logger import setup_logger + + +def main(): + parser = argparse.ArgumentParser( + description="PyTorch Image-Text Matching Inference" + ) + parser.add_argument( + "--root", + default="./", + help="root path", + type=str, + ) + parser.add_argument( + "--config-file", + default="", + metavar="FILE", + help="path to config file", + type=str, + ) + parser.add_argument( + "--checkpoint-file", + default="", + metavar="FILE", + help="path to checkpoint file", + type=str, + ) + parser.add_argument( + "--local_rank", + default=0, + type=int, + ) + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--load-result", + help="Use saved reslut as prediction", + action="store_true", + ) + + args = parser.parse_args() + + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + distributed = num_gpus > 1 + + if distributed: + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend="nccl", init_method="env://") + synchronize() + + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.ROOT = args.root + cfg.freeze() + + model = build_model(cfg) + model.to(cfg.MODEL.DEVICE) + + output_dir = os.path.join( + args.root, "./output", "/".join(args.config_file.split("/")[-2:])[:-5] + ) + checkpointer = Checkpointer(model, save_dir=output_dir) + _ = checkpointer.load(args.checkpoint_file) + + output_folders = list() + dataset_names = cfg.DATASETS.TEST + for dataset_name in dataset_names: + output_folder = os.path.join(output_dir, "inference", dataset_name) + makedir(output_folder) + output_folders.append(output_folder) + + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + distributed = num_gpus > 1 + + data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed) + for output_folder, dataset_name, data_loader_val in zip( + output_folders, dataset_names, data_loaders_val + ): + logger = setup_logger("PersonSearch", output_folder, get_rank()) + logger.info("Using {} GPUs".format(num_gpus)) + logger.info(cfg) + + inference( + model, + data_loader_val, + dataset_name=dataset_name, + device=cfg.MODEL.DEVICE, + output_folder=output_folder, + save_data=False, + rerank=True, + ) + synchronize() + + print("finish") + print(len(data_loaders_val)) + + for i in data_loaders_val: + print("asdf") + #print(len(images), len(captions), len(image_ids)) + break + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ai/TextReID/train_net.py b/ai/TextReID/train_net.py new file mode 100644 index 0000000000..5a4c6b74ee --- /dev/null +++ b/ai/TextReID/train_net.py @@ -0,0 +1,187 @@ +import argparse +import os +import random + +import numpy as np +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed + +from lib.config import cfg +from lib.data import make_data_loader +from lib.engine.trainer import do_train +from lib.models.model import build_model +from lib.solver import make_lr_scheduler, make_optimizer +from lib.utils.checkpoint import Checkpointer +from lib.utils.comm import get_rank, synchronize +from lib.utils.directory import makedir +from lib.utils.logger import setup_logger 
+from lib.utils.metric_logger import MetricLogger, TensorboardLogger + + +def set_random_seed(seed=0): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def train(cfg, output_dir, local_rank, distributed, resume_from, use_tensorboard): + data_loader = make_data_loader( + cfg, + is_train=True, + is_distributed=distributed, + ) + data_loader_val = make_data_loader( + cfg, + is_train=False, + is_distributed=distributed, + ) + model = build_model(cfg) + device = torch.device(cfg.MODEL.DEVICE) + model.to(device) + + optimizer = make_optimizer(cfg, model) + scheduler = make_lr_scheduler(cfg, optimizer) + + if distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[local_rank], + output_device=local_rank, + # this should be removed if we update BatchNorm stats + broadcast_buffers=False, + ) + + arguments = {} + arguments["iteration"] = 0 + arguments["epoch"] = 0 + + save_to_disk = get_rank() == 0 + checkpointer = Checkpointer(model, optimizer, scheduler, output_dir, save_to_disk) + if cfg.MODEL.WEIGHT != "imagenet": + if os.path.isfile(cfg.MODEL.WEIGHT): + checkpointer.load(cfg.MODEL.WEIGHT) + else: + raise IOError("{} is not a checkpoint file".format(cfg.MODEL.WEIGHT)) + if resume_from: + if os.path.isfile(resume_from): + extra_checkpoint_data = checkpointer.resume(resume_from) + arguments.update(extra_checkpoint_data) + else: + raise IOError("{} is not a checkpoint file".format(resume_from)) + + if use_tensorboard: + meters = TensorboardLogger( + log_dir=os.path.join(output_dir, "tensorboard"), + start_iter=arguments["iteration"], + delimiter=" ", + ) + else: + meters = MetricLogger(delimiter=" ") + + checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD + evaluate_period = cfg.SOLVER.EVALUATE_PERIOD + arguments["max_epoch"] = cfg.SOLVER.NUM_EPOCHS + arguments["distributed"] = distributed + + do_train( + model, + data_loader, + data_loader_val, + optimizer, + scheduler, + checkpointer, + meters, + device, + checkpoint_period, + evaluate_period, + arguments, + ) + + +def main(): + set_random_seed() + + parser = argparse.ArgumentParser(description="PyTorch Person Search Training") + parser.add_argument( + "--root", + default="./", + help="root path", + type=str, + ) + parser.add_argument( + "--config-file", + default="", + metavar="FILE", + help="path to config file", + type=str, + ) + parser.add_argument( + "--resume-from", + help="the checkpoint file to resume from", + type=str, + ) + parser.add_argument( + "--local_rank", + default=0, + type=int, + ) + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--use-tensorboard", + dest="use_tensorboard", + help="Use tensorboardX logger (Requires tensorboardX and tensorflow installed)", + action="store_true", + default=False, + ) + + args = parser.parse_args() + + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + args.distributed = num_gpus > 1 + + if args.distributed: + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend="nccl", init_method="env://") + synchronize() + + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.ROOT = args.root + cfg.freeze() + + output_dir = os.path.join( + args.root, "./output", "/".join(args.config_file.split("/")[-2:])[:-5] + ) + makedir(output_dir) + + logger = setup_logger("PersonSearch", output_dir, get_rank()) + logger.info("Using {} 
GPUs".format(num_gpus)) + logger.info(args) + + logger.info("Loaded configuration file {}".format(args.config_file)) + with open(args.config_file, "r") as cf: + config_str = "\n" + cf.read() + logger.info(config_str) + logger.info("Running with config:\n{}".format(cfg)) + + train( + cfg, + output_dir, + args.local_rank, + args.distributed, + args.resume_from, + args.use_tensorboard, + ) + + +if __name__ == "__main__": + main()