Try CosineEmbeddingLoss for contrastive learning
Log the contrastive loss
AngledLuffa committed Dec 16, 2024
1 parent 8e8b7db commit 6f2b56a
Showing 2 changed files with 56 additions and 11 deletions.
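
For context on the change: torch.nn.CosineEmbeddingLoss takes two batches of embeddings and a target of +1 (pull the pair together) or -1 (push the pair apart until cosine similarity falls below the margin). This commit only uses targets of -1, on pairs of tree representations that should disagree. A minimal sketch of that behaviour, separate from the diff itself (tensor shapes are illustrative only):

import torch
import torch.nn as nn

# CosineEmbeddingLoss with target -1 penalizes pairs whose cosine similarity
# exceeds the margin: loss = max(0, cos(x1, x2) - margin) per pair.
loss_fn = nn.CosineEmbeddingLoss(margin=0.0)

x1 = torch.randn(4, 256)        # stand-in for hidden states of reparsed trees
x2 = torch.randn(4, 256)        # stand-in for hidden states of gold trees
target = -torch.ones(4)         # -1 marks every pair as a negative example

print(loss_fn(x1, x2, target).item())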
63 changes: 52 additions & 11 deletions stanza/models/constituency/parser_training.py
@@ -34,15 +34,16 @@

TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])

class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'contrastive_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
def __add__(self, other):
transitions_correct = self.transitions_correct + other.transitions_correct
transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect
repairs_used = self.repairs_used + other.repairs_used
fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used
epoch_loss = self.epoch_loss + other.epoch_loss
contrastive_loss = self.contrastive_loss + other.contrastive_loss
nans = self.nans + other.nans
return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(epoch_loss, contrastive_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)

def evaluate(args, model_file, retag_pipeline):
"""
@@ -339,6 +340,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
# Various experiments generally show about 0.5 F1 loss on various
# datasets when using 'mean' instead of 'sum' for reduction
# (Remember to adjust the weight decay when rerunning that experiment)
device = trainer.device

if args['loss'] == 'cross':
tlogger.info("Building CrossEntropyLoss(sum)")
process_outputs = lambda x: x
@@ -357,9 +360,14 @@
model_loss_function = LargeMarginInSoftmaxLoss(reduction='sum')
else:
raise ValueError("Unexpected loss term: %s" % args['loss'])

device = trainer.device
model_loss_function.to(device)

if args['contrastive_learning_rate'] > 0:
contrastive_loss_function = nn.CosineEmbeddingLoss(margin=args['contrastive_margin'])
contrastive_loss_function.to(device)
else:
contrastive_loss_function = None

transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0)
for (y, x) in enumerate(trainer.transitions)}
trainer.train()
@@ -409,7 +417,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
epoch_data = epoch_data + epoch_silver_data
epoch_data.sort(key=lambda x: len(x[1]))

epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args)
epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, contrastive_loss_function, epoch_data, oracle, args)

# print statistics
# by now we've forgotten about the original tags on the trees,
@@ -430,9 +438,15 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
"Transitions correct: %s" % epoch_stats.transitions_correct,
"Transitions incorrect: %s" % epoch_stats.transitions_incorrect,
"Total loss for epoch: %.5f" % epoch_stats.epoch_loss,
]
if args['contrastive_learning_rate'] > 0.0:
stats_log_lines.extend([
"Contrastive loss for epoch: %.5f" % epoch_stats.contrastive_loss
])
stats_log_lines.extend([
"Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
"Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1)
]
])
tlogger.info("\n ".join(stats_log_lines))

old_lr = trainer.optimizer.param_groups[0]['lr']
@@ -526,17 +540,17 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d

return trainer

def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args):
def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, contrastive_loss_function, epoch_data, oracle, args):
interval_starts = list(range(0, len(epoch_data), args['train_batch_size']))
random.shuffle(interval_starts)

optimizer = trainer.optimizer

epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0)
epoch_stats = EpochStats(0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)

for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, oracle, args)
batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, contrastive_loss_function, oracle, args)
trainer.batches_trained += 1

# Early in the training, some trees will be degenerate in a
@@ -562,7 +576,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m

return epoch_stats

def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args):
def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, contrastive_loss_function, oracle, args):
"""
Train the model for one batch
@@ -572,6 +586,29 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
... although the indentation does get pretty ridiculous if this is
merged into train_model_one_epoch and then iterate_training
"""
contrastive_loss = 0.0
if epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None:
reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True)
reparsed_states = [x.state for x in reparsed_results]
reparsed_trees = [x.constituents.value.value.value for x in reparsed_states]
reparsed_tree_hx = [x.constituents.value.value.tree_hx for x in reparsed_states]

gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=False, keep_scores=False)
gold_states = [x.state for x in gold_results]
gold_trees = [x.constituents.value.value.value for x in gold_states]
gold_tree_hx = [x.constituents.value.value.tree_hx for x in gold_states]

reparsed_negatives = [hx for hx, reparsed_tree, gold_tree in zip(reparsed_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
gold_negatives = [hx for hx, reparsed_tree, gold_tree in zip(gold_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]

if len(reparsed_negatives) > 0:
reparsed_negatives = torch.cat(reparsed_negatives, dim=0)
gold_negatives = torch.cat(gold_negatives, dim=0)

device = next(model.parameters()).device
target = -torch.ones(reparsed_negatives.shape[0]).to(device)
contrastive_loss = args['contrastive_learning_rate'] * contrastive_loss_function(reparsed_negatives, gold_negatives, target)

# now we add the state to the trees in the batch
# the state is built as a bulk operation
current_batch = model.initial_state_from_preterminals([x.preterminals for x in training_batch],
Expand Down Expand Up @@ -660,6 +697,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te

errors = process_outputs(errors)
tree_loss = model_loss_function(errors, answers)
tree_loss += contrastive_loss
tree_loss.backward()
if args['watch_regex']:
matched = False
@@ -678,12 +716,15 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
tlogger.info(" (none found!)")
if torch.any(torch.isnan(tree_loss)):
batch_loss = 0.0
contrastive_loss = 0.0
nans = 1
else:
batch_loss = tree_loss.item()
if not isinstance(contrastive_loss, float):
contrastive_loss = contrastive_loss.item()
nans = 0

return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(batch_loss, contrastive_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)

def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None):
"""
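Taken together, the new code in train_model_one_batch reparses the batch, pairs each reparsed tree's hidden state (tree_hx) with the gold tree's, keeps only the pairs where the reparsed tree differs from the gold tree, and pushes those pairs apart with the cosine embedding loss, scaled by contrastive_learning_rate and added to the transition loss. A standalone sketch of that selection step, using hypothetical (tree, hidden state) pairs in place of the parser states returned by parse_sentences and analyze_trees:

import torch
import torch.nn as nn

def contrastive_term(reparsed, gold, loss_fn, weight):
    # reparsed and gold are lists of (tree, hidden_state) pairs; keep only
    # the pairs where the reparsed tree does not match the gold tree
    negatives = [(r_hx, g_hx)
                 for (r_tree, r_hx), (g_tree, g_hx) in zip(reparsed, gold)
                 if r_tree != g_tree]
    if not negatives:
        return 0.0
    reparsed_hx = torch.cat([r_hx for r_hx, _ in negatives], dim=0)
    gold_hx = torch.cat([g_hx for _, g_hx in negatives], dim=0)
    # target of -1 for every kept pair: they are negative examples
    target = -torch.ones(reparsed_hx.shape[0], device=reparsed_hx.device)
    return weight * loss_fn(reparsed_hx, gold_hx, target)

loss_fn = nn.CosineEmbeddingLoss(margin=0.0)
reparsed = [("(S (NP a))", torch.randn(1, 8)), ("(S (VP b))", torch.randn(1, 8))]
gold     = [("(S (NP a))", torch.randn(1, 8)), ("(S (NP b))", torch.randn(1, 8))]
print(contrastive_term(reparsed, gold, loss_fn, weight=0.1))

The result is added directly to tree_loss before backward(), so despite the name, contrastive_learning_rate acts as a weighting factor on the extra term rather than an optimizer learning rate.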
4 changes: 4 additions & 0 deletions stanza/models/constituency_parser.py
@@ -553,6 +553,10 @@ def build_argparse():
parser.add_argument('--learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum')
parser.add_argument('--stage1_learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum (stage 1)')

parser.add_argument('--contrastive_initial_epoch', default=1, type=int, help='When to start contrastive learning')
parser.add_argument('--contrastive_margin', default=0.0, type=float, help='epsilon for the negative examples of contrastive learning')
parser.add_argument('--contrastive_learning_rate', default=0.0, type=float, help='Multiplicative factor for contrastive learning')

parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping')
parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')
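
The three new options map onto the training changes above: --contrastive_learning_rate both enables the term (when > 0) and scales it, --contrastive_margin is passed straight to CosineEmbeddingLoss, and --contrastive_initial_epoch delays the term until that epoch. A small sketch of that gating, with a hypothetical args dict standing in for the parsed command line:

import torch.nn as nn

# hypothetical parsed arguments; the keys mirror the new flags above
args = {
    'contrastive_initial_epoch': 1,
    'contrastive_margin': 0.0,
    'contrastive_learning_rate': 0.1,
}

contrastive_loss_function = None
if args['contrastive_learning_rate'] > 0:
    contrastive_loss_function = nn.CosineEmbeddingLoss(margin=args['contrastive_margin'])

def contrastive_enabled(epoch):
    # mirrors the epoch gate in train_model_one_batch
    return contrastive_loss_function is not None and epoch >= args['contrastive_initial_epoch']

print(contrastive_enabled(0), contrastive_enabled(1))   # False True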
