From 7a9bddeb02884d0dcbcae585ea5e5843a3ec42c4 Mon Sep 17 00:00:00 2001
From: Alessandro Ragano <44505487+alessandroragano@users.noreply.github.com>
Date: Tue, 17 Sep 2024 14:29:10 +0100
Subject: [PATCH] Fix refactoring params embeddings

---
 src/training/train_triplet.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/training/train_triplet.py b/src/training/train_triplet.py
index 6175b84..0755ce0 100644
--- a/src/training/train_triplet.py
+++ b/src/training/train_triplet.py
@@ -96,7 +96,8 @@ def __init__(self, config_file):
 
         # Create optimizer with adaptive learning rate
         if self.config['freeze_convnet']:
-            params_names_embeddings = [f'embeddings.{j}.weight' for j in range(7)] + [f'embeddings.{j}.bias' for j in range(7)]
+            # params_names_embeddings = [f'embeddings.{j}.weight' for j in range(7)] + [f'embeddings.{j}.bias' for j in range(7)]
+            params_names_embeddings = ['embedding_layer.1.weight', 'embedding_layer.1.bias']
             params_pt = [param for name, param in self.model.named_parameters() if name not in params_names_embeddings]
             params_embeddings = [param for name, param in self.model.named_parameters() if name in params_names_embeddings]
             # Overwrite optimizer
@@ -485,4 +486,4 @@ def get_nmr_embeddings(self):
     # Use this function to check if cdist gives the same result
     def euclidean_dist(self, emb_a, emb_b):
         dist = np.sqrt(np.dot(emb_a - emb_b, (emb_a - emb_b).T))
-        return dist
\ No newline at end of file
+        return dist
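
Below is a minimal sketch, not taken from the patched repository, of how the two parameter
groups produced by this change can be given separate learning rates in PyTorch. It assumes
the trainable projection head is registered as an nn.Sequential named 'embedding_layer'
whose nn.Linear sits at index 1 (hence the names 'embedding_layer.1.weight' and
'embedding_layer.1.bias'); the model definition, optimizer choice, and learning-rate
values are illustrative assumptions, not the actual train_triplet.py code.

import torch
import torch.nn as nn


class TripletModel(nn.Module):
    """Illustrative stand-in for the model used in train_triplet.py."""

    def __init__(self):
        super().__init__()
        # Placeholder for the pretrained ConvNet backbone that gets frozen.
        self.convnet = nn.Sequential(nn.Conv1d(1, 8, kernel_size=3), nn.ReLU())
        # Trainable head; the Linear layer at index 1 yields parameter names
        # 'embedding_layer.1.weight' and 'embedding_layer.1.bias'.
        self.embedding_layer = nn.Sequential(nn.Flatten(), nn.Linear(8 * 14, 32))

    def forward(self, x):
        return self.embedding_layer(self.convnet(x))


model = TripletModel()

# Same filtering logic as the patched code, using the new explicit names.
params_names_embeddings = ['embedding_layer.1.weight', 'embedding_layer.1.bias']
params_pt = [p for n, p in model.named_parameters() if n not in params_names_embeddings]
params_embeddings = [p for n, p in model.named_parameters() if n in params_names_embeddings]

# Two parameter groups: a small learning rate for the pretrained backbone and a
# larger one for the embedding head. The values here are illustrative only.
optimizer = torch.optim.Adam([
    {'params': params_pt, 'lr': 1e-5},
    {'params': params_embeddings, 'lr': 1e-3},
])

If the name filter does not match the model's actual parameter names (for example after a
refactor renames the head, which is what this patch appears to correct), params_embeddings
silently ends up empty and every parameter falls into the backbone group, so hard-coding
the current names or deriving them from the model is safer than generating them by index.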