An unofficial code repository that attempts to implement ConCLR (AAAI 2022).
Install the dependencies:

```
pip install -r requirements.txt
```
Train the ABINet-Vision-ConCLR model:

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/conclr_pretrain_vision_model.yaml
```
Evaluate the ABINet-Vision-ConCLR model:

```
CUDA_VISIBLE_DEVICES=0 python main.py --config=configs/conclr_pretrain_vision_model.yaml --phase test --image_only
```

Run the demo:

```
python demo.py --config=configs/conclr_pretrain_vision_model.yaml --input=test/img/path
```
Additional flags:
- `--config /path/to/config` :: set the path of the configuration file
- `--input /path/to/image-directory` :: set the path of the image directory or a wildcard path
- `--checkpoint /path/to/checkpoint` :: set the path of the trained model
- `--cuda [-1|0|1|2|3...]` :: set the CUDA device id; the default is -1, which stands for CPU
- `--model_eval [alignment|vision]` :: specify which sub-model to use
- `--image_only` :: disable dumping visualizations of attention masks
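For example, to run the demo on GPU 0 with a trained checkpoint (the checkpoint path below is a placeholder, not a file shipped with this repo):

```
python demo.py --config=configs/conclr_pretrain_vision_model.yaml --input=test/img/path --checkpoint=workdir/conclr-vision.pth --cuda 0 --model_eval vision
```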
The authors’ key insight is that by pulling together embeddings of the same character in different contexts and pushing apart embeddings of different characters, we can guide the model to learn a representation that better balances intrinsic and contextual information. Here are some key components of ConCLR.
Before the batch is fed into the model, Context-based Data Augmentation (ConAug) randomly permutes the input batch twice and randomly concatenates each permutation to the left or right of the original batch, as shown in the figure below:
Because ConAug needs to process the input after it has been organized into a batch but before it is fed into the model, I use fastai's Callback mechanism to implement it. I implemented the `on_batch_begin` method, which is called right before a batch is fed into the model. Here is my implementation:
```python
import torch
import torch.nn.functional as F
from fastai.callback import Callback  # fastai v1 callback API


class ConAugPretransform(Callback):
    def __init__(self, *args, **kwargs):
        self.aug_type = kwargs["aug_type"]
        self.max_length = kwargs["max_length"]
        self.test_conaug = kwargs["test_conaug"]
        self.charset_path = kwargs["charset_path"]

    def _augment_images(self, images, permuted_indices, left):
        # Concatenate the permuted batch to the left or right of the original
        # batch along the width dimension, then resize back to the original size.
        permuted_images = images[permuted_indices]
        augmented_images = torch.cat(
            [permuted_images if left else images, images if left else permuted_images],
            dim=3,
        )
        original_height, original_width = images.size(2), images.size(3)
        augmented_images = F.interpolate(
            augmented_images,
            size=(original_height, original_width),
            mode="bilinear",
            align_corners=False,
        )
        return augmented_images

    def _visualize_augmented_batch(self, x, x1, x2, y, y1, y2, yl, yl1, yl2):
        ...

    def _process_gt_labels(self, gt_labels, gt_lengths, permuted_indices, left):
        # Build the label of each concatenated image: the two texts are joined,
        # one of the two EOS tokens is dropped, and the result is truncated to
        # max_length characters plus a single trailing EOS.
        new_gt_labels = gt_labels.clone()
        new_gt_lengths = gt_lengths.clone()
        for i, perm in enumerate(permuted_indices):
            if left:
                new_gt = torch.cat(
                    (
                        gt_labels[perm][: gt_lengths[perm] - 1],
                        gt_labels[i][: self.max_length + 2 - gt_lengths[perm] - 1],
                        torch.unsqueeze(gt_labels[perm][-1], dim=0),
                    ),
                    dim=0,
                )
            else:
                new_gt = torch.cat(
                    (
                        gt_labels[i][: gt_lengths[i] - 1],
                        gt_labels[perm][: self.max_length + 2 - gt_lengths[i] - 1],
                        torch.unsqueeze(gt_labels[i][-1], dim=0),
                    ),
                    dim=0,
                )
            new_gt_labels[i] = new_gt
        # Combined length = len(text_a) + len(text_b) - 1 (the two sequences
        # share one EOS), capped at max_length + 1.
        new_gt_lengths += gt_lengths[permuted_indices] - 1
        new_gt_lengths = torch.minimum(
            new_gt_lengths, torch.tensor(self.max_length + 1)
        )
        return new_gt_labels, new_gt_lengths

    def on_batch_begin(self, last_input, last_target, **kwargs) -> dict:
        images = last_input[0]
        # Draw two independent permutations and two independent left/right flips.
        permuted_indices_one = torch.randperm(images.size(0))
        permuted_indices_two = torch.randperm(images.size(0))
        left_one = torch.randint(0, 2, (1,), dtype=torch.bool)
        left_two = torch.randint(0, 2, (1,), dtype=torch.bool)
        augmented_batch_one = self._augment_images(
            images, permuted_indices_one, left_one
        )
        augmented_batch_two = self._augment_images(
            images, permuted_indices_two, left_two
        )
        last_input = ((images, augmented_batch_one, augmented_batch_two), last_input[1])
        gt_labels, gt_lengths = last_target[0], last_target[1]
        gt_labels_one, gt_lengths_one = self._process_gt_labels(
            gt_labels, gt_lengths, permuted_indices_one, left_one
        )
        gt_labels_two, gt_lengths_two = self._process_gt_labels(
            gt_labels, gt_lengths, permuted_indices_two, left_two
        )
        last_target[0] = [gt_labels, gt_labels_one, gt_labels_two]
        last_target[1] = [gt_lengths, gt_lengths_one, gt_lengths_two]
        if self.test_conaug:
            self._visualize_augmented_batch(
                images,
                augmented_batch_one,
                augmented_batch_two,
                gt_labels,
                gt_labels_one,
                gt_labels_two,
                gt_lengths,
                gt_lengths_one,
                gt_lengths_two,
            )
        return {"last_input": last_input, "last_target": last_target}
```
In this function, I process the augmented batches together with their labels (the augmented batches are also used to compute the recognition loss, so processing the labels is necessary).
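As a sanity check, the batch hook can be exercised directly on dummy tensors. This is a minimal sketch; the shapes, the 37-class toy charset, the `aug_type` value, and the charset path below are assumptions for illustration, not values taken from this repo:

```python
import torch
import torch.nn.functional as F

cb = ConAugPretransform(
    aug_type="rand",                     # assumed config value
    max_length=25,                       # assumed max text length
    test_conaug=False,
    charset_path="data/charset_36.txt",  # hypothetical path
)
images = torch.randn(8, 3, 32, 128)                    # (B, C, H, W)
lengths = torch.randint(2, 10, (8,))                   # text lengths incl. EOS
labels = F.one_hot(torch.randint(0, 37, (8, 26)), 37)  # (B, max_length+1, classes)
out = cb.on_batch_begin((images, None), [labels, lengths])
(x, x1, x2), _ = out["last_input"]
assert x1.shape == images.shape  # augmented batch keeps the original size
```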
I plotted the outputs of ConAug to verify the correctness of my implementation, as follows:
Feeding the two augmented batches into the backbone and decoder, we obtain their corresponding glimpse vectors. The authors mention that, as SimCLR observes, using embeddings directly for contrastive learning harms model performance, so a projection head is needed to filter out irrelevant information. Additionally, SimCLR discusses the benefit of non-linear layers in the projection head for performance. Consequently, I implemented the head using the following code.
```python
import torch.nn as nn
import torch.nn.functional as F


class ProjectionHead(nn.Module):
    def __init__(self, in_channel, out_channel, head="mlp") -> None:
        super().__init__()
        if head == "linear":
            self.head = nn.Linear(in_channel, out_channel)
        elif head == "mlp":
            # Non-linear head, as recommended by SimCLR.
            self.head = nn.Sequential(
                nn.Linear(in_channel, in_channel),
                nn.ReLU(inplace=True),
                nn.Linear(in_channel, out_channel),
            )
        else:
            raise ValueError(f"unknown head type: {head}")

    def forward(self, x):
        # Project and L2-normalize so that dot products are cosine similarities.
        embed = F.normalize(self.head(x), dim=-1, p=2)
        return embed
```
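For example, projecting a batch of decoder glimpse vectors (the dimensions below are assumptions for illustration, not this repo's actual configuration):

```python
import torch

# Hypothetical shapes: 32 sequences of 26 glimpse vectors of size 512.
glimpses = torch.randn(32, 26, 512)
proj = ProjectionHead(in_channel=512, out_channel=128, head="mlp")
embeddings = proj(glimpses)  # (32, 26, 128), L2-normalized along the last dim
```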
Through ConAug, the backbone, the decoder, and the projection head, we have obtained two embeddings corresponding to the two augmented batches. Now, how do we “pull together embeddings of the same character in different contexts and push apart embeddings of different characters”? The answer is the contrastive loss.
Here is the formulation:
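In essence (my transcription; see the paper for the exact notation), it is a supervised-contrastive objective applied per character position: with $z_i$ the projected embedding at position $i$, $P(i)$ the other positions carrying the same character label, $A(i)$ all positions other than $i$, and $\tau$ the temperature,

$$\mathcal{L}_{clr} = \operatorname*{mean}_{\text{batch}} \sum_{i} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}$$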
I won’t go into the meaning of every symbol in the formula here; it can be found in the paper. My code is implemented as follows.
```python
import torch
import torch.nn as nn


class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=2):
        super().__init__()
        self.temperature = temperature
        self.losses = {}

    def _clr_loss(self, embed1, embed2, labels1, labels2, loss_name, record=True):
        # embed*: (B, T, D) projected embeddings; labels*: (B, T) character ids.
        embed = torch.cat((embed1, embed2), dim=1)     # (B, 2T, D)
        labels = torch.cat((labels1, labels2), dim=1)  # (B, 2T)
        _, seq_length = labels.shape
        s = torch.div(
            embed @ embed.transpose(2, 1), self.temperature
        )  # pairwise similarities, (B, 2T, 2T)
        am = (
            ~torch.eye(seq_length, dtype=torch.bool, device=s.device)
            .unsqueeze(0)
            .expand(*s.shape)
        )  # negative mask: every position except itself (includes pm)
        pm = (
            labels[:, :, None]
            == labels[:, None, :]  # expand labels and compare with the transpose
        ) & am  # positive mask: same character at a different position
        # exclude eos (label 0) from acting as an anchor
        p_num = torch.where(
            labels.unsqueeze(2) != 0, pm.sum(dim=1).unsqueeze(2), 0
        )  # positives per anchor (pm is symmetric)
        # include eos
        # p_num = pm.sum(dim=1).unsqueeze(2)
        em = p_num.bool()  # anchors that have at least one positive
        s = torch.masked_fill(s, ~em, 0)  # zero out anchors without positives
        p = torch.masked_fill(s, ~pm, 0)  # keep only positive-pair similarities
        # log of the denominator; -inf excludes self-similarity from the sum
        a = torch.logsumexp(s.masked_fill(~am, float("-inf")), dim=2)
        a = torch.masked_fill(a, ~em.squeeze(2), 0)
        smx = torch.where(
            p != 0,
            p - a.unsqueeze(2),  # log-softmax term of each positive pair
            torch.zeros((), device=s.device),
        )
        dv = torch.where(
            em.squeeze(2),
            # clamp avoids 0/0 (and NaN gradients) for anchors without positives
            -torch.sum(smx, dim=2) / p_num.squeeze(2).clamp(min=1),
            torch.zeros((), device=s.device),
        )
        l_pair = torch.sum(dv, dim=1)     # sum over positions
        loss = torch.mean(l_pair, dim=0)  # mean over the batch
        if record and loss_name is not None:
            self.losses[f"{loss_name}_loss"] = loss
        return loss

    def forward(self, X, Y, *args):
        self.losses = {}
        features, loss_name = X[0].get("proj"), X[0].get("name")
        return self._clr_loss(
            features[0],
            features[1],
            torch.argmax(Y[0], dim=2),  # recover character ids from one-hot labels
            torch.argmax(Y[1], dim=2),
            loss_name,
            *args,
        )
```
To improve parallel efficiency, I compute the contrastive loss over the entire batch at once with matrix operations; by applying the masks above, the result is equivalent to the per-pair loss function.
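As a quick smoke test (the shapes and the 5-class toy charset below are made up for illustration):

```python
import torch
import torch.nn.functional as F

loss_fn = ContrastiveLoss(temperature=2)
embed1 = F.normalize(torch.randn(2, 4, 8), dim=-1)  # (B=2, T=4, D=8)
embed2 = F.normalize(torch.randn(2, 4, 8), dim=-1)
labels1 = F.one_hot(torch.randint(0, 5, (2, 4)), 5).float()  # one-hot labels
labels2 = F.one_hot(torch.randint(0, 5, (2, 4)), 5).float()
X = [{"proj": (embed1, embed2), "name": "clr"}]
loss = loss_fn(X, [labels1, labels2])
print(loss.item(), loss_fn.losses)
```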
To demonstrate the effectiveness of the reproduction, I trained the ConCLR-Vision and ABINet-Vision models with the same training settings. Due to constraints on time and computational resources, both models were trained on the SynthText (ST) dataset for 4 epochs using four NVIDIA 2080 Ti GPUs (on AutoDL 🥲). The learning rate was initialized at 1e-4 and decayed to 1e-5 by the final epoch.
| Model | Dataset | CCR | CWR |
|---|---|---|---|
| conclr-vision | IIIT5k | 0.942 | 0.891 |
| conclr-vision | IC13 | 0.963 | 0.890 |
| conclr-vision | CUTE80 | 0.928 | 0.910 |
| abinet-vision | IIIT5k | 0.926 | 0.872 |
| abinet-vision | IC13 | 0.938 | 0.852 |
| abinet-vision | CUTE80 | 0.771 | 0.882 |
The results indicate that using ConCLR led to significant improvements in both CCR and CWR. However, this experiment was conducted on a dataset that was not sufficiently large, and the training duration was relatively short. Due to resource constraints I have not yet been able to run a more comprehensive evaluation; more comprehensive experiments are underway…
By the way, the training logs are stored in the `log/` directory; you can view them with TensorBoard or just read the .txt files.
(There might be errors in my implementation, and I would greatly appreciate any feedback on this matter 🥰)