diff --git a/examples/variant-prediction/predict.py b/examples/variant-prediction/predict.py
index 81d72c40..b83b8285 100644
--- a/examples/variant-prediction/predict.py
+++ b/examples/variant-prediction/predict.py
@@ -115,7 +115,7 @@ def label_row(row, sequence, token_probs, alphabet, offset_idx):
     return score.item()
 
 
-def compute_pppl(row, sequence, model, alphabet, offset_idx):
+def compute_pppl(row, sequence, model, alphabet, offset_idx, nogpu=False):
     wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
     assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence"
 
@@ -139,7 +139,9 @@ def compute_pppl(row, sequence, model, alphabet, offset_idx):
         batch_tokens_masked = batch_tokens.clone()
         batch_tokens_masked[0, i] = alphabet.mask_idx
         with torch.no_grad():
-            token_probs = torch.log_softmax(model(batch_tokens_masked.cuda())["logits"], dim=-1)
+            if torch.cuda.is_available() and not nogpu:
+                batch_tokens_masked = batch_tokens_masked.cuda()
+            token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
         log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i])].item())  # vocab size
     return sum(log_probs)
 
@@ -171,8 +173,10 @@ def main(args):
                 batch_tokens_masked = batch_tokens.clone()
                 batch_tokens_masked[0, 0, i] = alphabet.mask_idx  # mask out first sequence
                 with torch.no_grad():
+                    if torch.cuda.is_available() and not args.nogpu:
+                        batch_tokens_masked = batch_tokens_masked.cuda()
                     token_probs = torch.log_softmax(
-                        model(batch_tokens_masked.cuda())["logits"], dim=-1
+                        model(batch_tokens_masked)["logits"], dim=-1
                     )
                 all_token_probs.append(token_probs[:, 0, i])  # vocab size
             token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
@@ -188,10 +192,12 @@ def main(args):
             ("protein1", args.sequence),
         ]
         batch_labels, batch_strs, batch_tokens = batch_converter(data)
+        if torch.cuda.is_available() and not args.nogpu:
+            batch_tokens = batch_tokens.cuda()
 
         if args.scoring_strategy == "wt-marginals":
             with torch.no_grad():
-                token_probs = torch.log_softmax(model(batch_tokens.cuda())["logits"], dim=-1)
+                token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)
             df[model_location] = df.apply(
                 lambda row: label_row(
                     row[args.mutation_col],
@@ -208,8 +214,10 @@ def main(args):
                 batch_tokens_masked = batch_tokens.clone()
                 batch_tokens_masked[0, i] = alphabet.mask_idx
                 with torch.no_grad():
+                    if torch.cuda.is_available() and not args.nogpu:
+                        batch_tokens_masked = batch_tokens_masked.cuda()
                     token_probs = torch.log_softmax(
-                        model(batch_tokens_masked.cuda())["logits"], dim=-1
+                        model(batch_tokens_masked)["logits"], dim=-1
                     )
                 all_token_probs.append(token_probs[:, i])  # vocab size
             token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
@@ -227,7 +235,8 @@ def main(args):
         tqdm.pandas()
        df[model_location] = df.progress_apply(
             lambda row: compute_pppl(
-                row[args.mutation_col], args.sequence, model, alphabet, args.offset_idx
+                row[args.mutation_col], args.sequence, model, alphabet, args.offset_idx,
+                args.nogpu
             ),
             axis=1,
         )
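
Note: the patch repeats the same `torch.cuda.is_available() and not nogpu` guard before every forward pass. A possible follow-up (not part of this patch) would be to hoist the check into a small helper; a minimal sketch, assuming nothing beyond the standard PyTorch API (the `maybe_cuda` name is hypothetical, not something the script defines):

    import torch

    def maybe_cuda(tensor: torch.Tensor, nogpu: bool = False) -> torch.Tensor:
        # Hypothetical helper: move the tensor to the GPU only when CUDA is
        # present and not explicitly disabled; otherwise return it unchanged.
        if torch.cuda.is_available() and not nogpu:
            return tensor.cuda()
        return tensor

Each added `if ... / ... = ....cuda()` pair above would then collapse to a single call such as `batch_tokens_masked = maybe_cuda(batch_tokens_masked, args.nogpu)`. Since the hunks read `args.nogpu`, the script's argument parser presumably already defines a `--nogpu` flag, so a CPU-only run would be forced with something like `python predict.py ... --nogpu`.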