Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yivlad/pretrain mlm #1

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions analysis.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ bash pretraining.sh "$DIR"/pretraining

# Then generate negative decoys
# This step is very CPU and RAM intensive
bash negative_decoys.sh "$DIR"/negative_decoys "$CPU"
# bash negative_decoys.sh "$DIR"/negative_decoys "$CPU"

# Finally perform supervised training and evaluate the model
# This step is faster with a GPU
bash train_and_evaluate.sh "$DIR"/negative_decoys/datasets "$DIR"/pretraining/model "$DIR"/training
# bash train_and_evaluate.sh "$DIR"/negative_decoys/datasets "$DIR"/pretraining/model "$DIR"/training
17 changes: 17 additions & 0 deletions bertrand-job.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
#SBATCH --mail-type=ALL # E-mail notifications. Options: NONE, BEGIN, END, FAIL, ALL
# NOTE(review): the option name on the next line looks redacted/mangled in this
# view; presumably it should read --mail-user=<address> -- confirm before submitting.
#SBATCH [email protected] # e-mail address
#SBATCH --ntasks=4 # request 4 tasks; the same count is passed as the 2nd argument to analysis.sh below
#SBATCH --mem=32gb
#SBATCH --gpus=a100:1
#SBATCH --time=72:00:00 # maximum wall-clock limit (HH:MM:SS; D-HH:MM:SS is also accepted)
#SBATCH --partition=long

# Record working directory, node and start time in the job log.
pwd; hostname; date

source /home2/sfglab/yvladyslav/anaconda3/etc/profile.d/conda.sh
# Stop here if the checkout is missing instead of continuing from the wrong directory.
cd bertrand || exit 1
conda activate bertrand
./analysis.sh "/home2/sfglab/yvladyslav/pretrain-mlm/bertrand_results" 4

# Record the finish time.
date
5 changes: 2 additions & 3 deletions bertrand/pretraining/peptide_tcr_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,10 @@ def read_peptides(fn: str) -> pd.DataFrame:
logging.info(f"{len(presented_peptides)} peptides read")
presented_unique = (
presented_peptides.reset_index()
.groupby("Peptide2")
.groupby("peptide_seq")
.agg(
{
"HLA_type": lambda x: "|".join(sorted(x)),
"index": lambda x: "|".join(sorted(x)),
}
)
.reset_index()
Expand Down Expand Up @@ -114,7 +113,7 @@ def sample_peptide_tcr_repertoire(
)

peptides_sampled.loc[:, "CDR3b"] = synthetic_tcrs.values
peptide_tcr_repertoire = peptides_sampled.rename(columns={"Peptide2": "Peptide"})
peptide_tcr_repertoire = peptides_sampled.rename(columns={"peptide_seq": "Peptide"})
return peptide_tcr_repertoire


Expand Down
1 change: 0 additions & 1 deletion bertrand/training/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import shutil
from copy import deepcopy
from functools import partial
from glob import glob
from typing import Union, List, Tuple, Dict

Expand Down
146 changes: 127 additions & 19 deletions env.yml
Original file line number Diff line number Diff line change
@@ -1,27 +1,135 @@
name: bertrand
channels:
- nvidia
- pytorch
- defaults
dependencies:
- biopython=1.78=py38h7b6447c_0
- h5py=2.10.0=py38hd6299e0_1
- hdf5=1.10.6=hb1b8bf9_0
- joblib=1.1.0=pyhd3eb1b0_0
- matplotlib=3.3.4=py38h06a4308_0
- numpy=1.21.2=py38h20f2e39_0
- pandas=1.4.1=py38h295c915_0
- pip=21.2.4=py38h06a4308_0
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- biopython=1.78=py38h7f8727e_0
- blas=1.0=mkl
- bottleneck=1.3.5=py38h7deecbd_0
- brotli=1.0.9=h5eee18b_7
- brotli-bin=1.0.9=h5eee18b_7
- ca-certificates=2023.01.10=h06a4308_0
- contourpy=1.0.5=py38hdb19cb5_0
- cudatoolkit=11.5.1=hcf5317a_9
- cycler=0.11.0=pyhd3eb1b0_0
- dbus=1.13.18=hb2f20db_0
- expat=2.4.9=h6a678d5_0
- fftw=3.3.9=h27cfd23_1
- fontconfig=2.14.1=h52c9d5c_1
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.12.1=h4a9f257_0
- giflib=5.2.1=h5eee18b_3
- glib=2.63.1=h5a9c865_0
- gst-plugins-base=1.14.0=hbbd80ab_1
- gstreamer=1.14.0=hb453b48_1
- h5py=3.7.0=py38h737f45e_0
- hdf5=1.10.6=h3ffc7dd_1
- icu=58.2=he6710b0_3
- importlib_resources=5.2.0=pyhd3eb1b0_1
- intel-openmp=2021.4.0=h06a4308_3561
- joblib=1.2.0=py38h06a4308_0
- jpeg=9e=h5eee18b_1
- kiwisolver=1.4.4=py38h6a678d5_0
- lcms2=2.12=h3be6417_0
- lerc=3.0=h295c915_0
- libbrotlicommon=1.0.9=h5eee18b_7
- libbrotlidec=1.0.9=h5eee18b_7
- libbrotlienc=1.0.9=h5eee18b_7
- libdeflate=1.17=h5eee18b_0
- libedit=3.1.20221030=h5eee18b_0
- libffi=3.2.1=hf484d3e_1007
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtiff=4.5.0=h6a678d5_2
- libuuid=1.41.5=h5eee18b_0
- libuv=1.44.2=h5eee18b_0
- libwebp=1.2.4=h11a3e52_1
- libwebp-base=1.2.4=h5eee18b_1
- libxcb=1.15=h7f8727e_0
- libxml2=2.9.14=h74e7548_0
- lz4-c=1.9.4=h6a678d5_0
- matplotlib=3.7.1=py38h06a4308_1
- matplotlib-base=3.7.1=py38h417a72b_1
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h7f8727e_0
- mkl_fft=1.3.1=py38hd3c417c_0
- mkl_random=1.2.2=py38h51133e4_0
- munkres=1.1.4=py_0
- ncurses=6.4=h6a678d5_0
- numexpr=2.8.4=py38he184ba9_0
- numpy=1.22.3=py38he7a7128_0
- numpy-base=1.22.3=py38hf524024_0
- openssl=1.1.1t=h7f8727e_0
- packaging=23.0=py38h06a4308_0
- pandas=1.5.3=py38h417a72b_0
- pcre=8.45=h295c915_0
- pillow=9.4.0=py38h6a678d5_0
- pip=23.0.1=py38h06a4308_0
- pyparsing=3.0.9=py38h06a4308_0
- pyqt=5.9.2=py38h05f1152_4
- python=3.8.0=h0371630_2
- pytorch=1.11.0=py3.8_cuda10.2_cudnn7.6.5_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- pytorch=1.11.0=py3.8_cuda11.5_cudnn8.3.2_0
- pytorch-mutex=1.0=cuda
- pytz=2022.7=py38h06a4308_0
- qt=5.9.7=h5867ecd_1
- readline=7.0=h7b6447c_5
- scikit-learn=0.24.2=py38ha9443f7_0
- scipy=1.7.3=py38hc147768_0
- seaborn=0.11.1=pyhd3eb1b0_0
- tokenizers=0.10.3=py38hb317417_1
- tqdm=4.62.3=pyhd3eb1b0_1
- scipy=1.7.3=py38h6c91a56_2
- seaborn=0.12.2=py38h06a4308_0
- setuptools=66.0.0=py38h06a4308_0
- sip=4.19.13=py38h295c915_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.33.0=h62c20be_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tk=8.6.12=h1ccaba5_0
- tokenizers=0.11.4=py38h3dcd8bd_1
- tornado=6.2=py38h5eee18b_0
- tqdm=4.65.0=py38hb070fc8_0
- typing_extensions=4.5.0=py38h06a4308_0
- wheel=0.38.4=py38h06a4308_0
- xz=5.4.2=h5eee18b_0
- zipp=3.11.0=py38h06a4308_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.5=hc292b87_0
- pip:
- datasets==1.18.3
- fastcluster==1.2.4
- leven==1.0.4
- pytorch-lightning==0.7.1
- transformers==4.16.2
prefix: /home/ardigen/miniconda3/envs/bertrand
- aiohttp==3.8.4
- aiosignal==1.3.1
- async-timeout==4.0.2
- attrs==23.1.0
- certifi==2023.5.7
- charset-normalizer==3.1.0
- click==8.1.3
- datasets==2.12.0
- dill==0.3.6
- fastcluster==1.2.6
- filelock==3.12.0
- frozenlist==1.3.3
- fsspec==2023.5.0
- huggingface-hub==0.14.1
- idna==3.4
- leven==1.0.4
- lightning-utilities==0.8.0
- multidict==6.0.4
- multiprocess==0.70.14
- nose==1.3.7
- pyarrow==12.0.0
- pytorch-lightning==2.0.2
- pyyaml==6.0
- regex==2023.5.5
- requests==2.31.0
- responses==0.18.0
- sacremoses==0.0.53
- torchmetrics==0.11.4
- transformers==4.16.2
- urllib3==2.0.2
- xxhash==3.2.0
- yarl==1.9.2
prefix: /home2/sfglab/yvladyslav/anaconda3/envs/bertrand
6 changes: 3 additions & 3 deletions train_and_evaluate.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
set -x
set -ex
DATA_DIR=$1
MODEL_DIR=$2
OUT_DIR=$3
Expand All @@ -9,9 +9,9 @@ python -m bertrand.training.train \
--input-dir=$DATA_DIR \
--model-ckpt=$MODEL_DIR \
--output-dir=$OUT_DIR \
--n-splits=21
--n-splits=1

python -m bertrand.training.evaluate \
--datasets-dir=$DATA_DIR \
--results-dir=$OUT_DIR \
--out=$OUT_DIR/results.csv
--out=$OUT_DIR/results.csv