Skip to content

Commit

Permalink
DiffAbXL
Browse files Browse the repository at this point in the history
  • Loading branch information
talipucar committed Oct 22, 2024
1 parent fb06487 commit 6474a72
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 2 deletions.
47 changes: 47 additions & 0 deletions compute_loglikelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Author: Talip Ucar
email: [email protected]
Description: Sample script to compute log-likehood.
"""

import torch


def compute_log_likelihood(sequence_tokens_list, posterior_list, parent_aa_list=None):
"""
Compute the log-likelihood for each sequence in the batch.
Args:
sequence_tokens_list (list of Tensors): List of tensors of size (batch_size_i, sequence_length) containing sequence tokens.
posterior_list (list of Tensors): List of tensors of size (batch_size_i, sequence_length, 20) containing posterior probabilities over amino acids.
parent_aa_list (list of Tensors, optional): List of tensors containing parent amino acid tokens.
Returns:
log_likelihoods (Tensor): Tensor of log-likelihood values for the batch.
log_likelihood_per_position (Tensor): Tensor of log-likelihood per position.
"""
# Concatenate the list of tensors along the batch dimension
sequence_tokens = torch.cat(sequence_tokens_list, dim=0)
posterior = torch.cat(posterior_list, dim=0)

# Compute log probabilities from posterior
log_posterior = torch.log(posterior + 1e-9) # Avoid log(0) by adding a small epsilon
log_posterior = log_posterior.sum(0).unsqueeze(0).repeat(sequence_tokens.size(0), 1, 1)

# Gather the log probabilities corresponding to the actual sequence tokens
log_likelihood_per_position = torch.gather(
log_posterior, dim=2, index=sequence_tokens.unsqueeze(-1)
).squeeze(-1)

if parent_aa_list is not None and len(parent_aa_list) > 0:
parent_aa_tokens = torch.cat(parent_aa_list, dim=0)
parent_log_likelihood_per_position = torch.gather(
log_posterior, dim=2, index=parent_aa_tokens.unsqueeze(-1)
).squeeze(-1)
log_likelihood_per_position = log_likelihood_per_position - parent_log_likelihood_per_position

# Sum the log-likelihood over the sequence length to get the total log-likelihood for each sequence
log_likelihoods = log_likelihood_per_position.sum(dim=1)

return log_likelihoods, log_likelihood_per_position
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: modelling-dev
channels:
- conda-forge
dependencies:
- python=3.7.*
- python=3.11.8
- numpy >= 1.16
- pandas >= 1
- matplotlib >= 3.1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ nvidia-nvtx-cu12==12.1.105
opt-einsum==3.3.0
pandas==2.2.1
pickleshare==0.7.5
pillow==9.6.0
pillow==10.2.0
progressbar2==4.4.2
python-utils==3.8.2
pytorch-lightning==1.8.6
Expand Down
7 changes: 7 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
Author: Talip Ucar
email: [email protected]
Description: Training script.
"""

# Standard library imports
import os
import traceback
Expand Down

0 comments on commit 6474a72

Please sign in to comment.