Fix refactoring params embeddings
alessandroragano authored Sep 17, 2024
1 parent 4bdfb4b · commit 7a9bdde
Showing 1 changed file with 3 additions and 2 deletions.
src/training/train_triplet.py: 5 changes (3 additions & 2 deletions)
@@ -96,7 +96,8 @@ def __init__(self, config_file):
 
         # Create optimizer with adaptive learning rate
         if self.config['freeze_convnet']:
-            params_names_embeddings = [f'embeddings.{j}.weight' for j in range(7)] + [f'embeddings.{j}.bias' for j in range(7)]
+            # params_names_embeddings = [f'embeddings.{j}.weight' for j in range(7)] + [f'embeddings.{j}.bias' for j in range(7)]
+            params_names_embeddings = ['embedding_layer.1.weight', 'embedding_layer.1.bias']
             params_pt = [param for name, param in self.model.named_parameters() if name not in params_names_embeddings]
             params_embeddings = [param for name, param in self.model.named_parameters() if name in params_names_embeddings]
             # Overwrite optimizer
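
Why the one-line fix matters: named_parameters() matches names exactly, so after the refactor the old embeddings.{j}.* names evidently no longer matched any module, every parameter silently fell into params_pt, and the embedding group was left empty. Below is a minimal sketch of how such a name-based split typically feeds a two-group optimizer; the model layout (convnet, embedding_layer) and the learning rates are illustrative assumptions, not code from this repository.

# Sketch only: module names and learning rates are assumptions for illustration,
# not taken from train_triplet.py.
import torch
import torch.nn as nn

model = nn.Sequential()
model.add_module('convnet', nn.Sequential(nn.Conv1d(1, 16, kernel_size=3), nn.ReLU()))
model.add_module('embedding_layer', nn.Sequential(nn.Flatten(), nn.Linear(16 * 98, 128)))

# Same selection logic as the commit: split parameters by exact name.
# With this layout, named_parameters() yields 'embedding_layer.1.weight'
# and 'embedding_layer.1.bias' for the Linear head, matching the fixed list.
params_names_embeddings = ['embedding_layer.1.weight', 'embedding_layer.1.bias']
params_pt = [p for n, p in model.named_parameters() if n not in params_names_embeddings]
params_embeddings = [p for n, p in model.named_parameters() if n in params_names_embeddings]

# One optimizer, two parameter groups: the pretrained convnet gets a small
# learning rate, the embedding head a larger one (rates are made up).
optimizer = torch.optim.Adam([
    {'params': params_pt, 'lr': 1e-5},
    {'params': params_embeddings, 'lr': 1e-3},
])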
@@ -485,4 +486,4 @@ def get_nmr_embeddings(self):
     # Use this function to check if cdist gives the same result
     def euclidean_dist(self, emb_a, emb_b):
         dist = np.sqrt(np.dot(emb_a - emb_b, (emb_a - emb_b).T))
-        return dist
\ No newline at end of file
+        return dist
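
The euclidean_dist helper above is described as a cross-check for cdist. A short sketch of that comparison, assuming the cdist in question is scipy.spatial.distance.cdist and using made-up embeddings:

import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
emb_a = rng.standard_normal(128)
emb_b = rng.standard_normal(128)

# Manual Euclidean distance, as in euclidean_dist: sqrt((a - b) . (a - b))
dist_manual = np.sqrt(np.dot(emb_a - emb_b, (emb_a - emb_b).T))

# cdist expects 2-D inputs (one row per embedding) and returns a distance matrix.
dist_cdist = cdist(emb_a[None, :], emb_b[None, :], metric='euclidean')[0, 0]

assert np.isclose(dist_manual, dist_cdist)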
