Skip to content

Commit

Permalink
Set up code as a package and switch to relative imports
Browse files (browse the repository at this point in the history)
  • Loading branch information
samgelman committed Jan 24, 2025
1 parent f0df7d1 commit 9e5a695
Show file tree
Hide file tree
Showing 19 changed files with 53 additions and 56 deletions.
Empty file added code/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion code/compute_rosetta_standardization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging

import pandas as pd
-import split_dataset as sd
+from . import split_dataset as sd

logger = logging.getLogger("METL." + __name__)
logger.setLevel(logging.DEBUG)
Expand Down
8 changes: 3 additions & 5 deletions code/convert_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
import torch.nn as nn
import torchinfo

-import constants
-import models
-import tasks
-import utils
-import encode as enc
+from . import models
+from . import utils
+from . import encode as enc


def convert_checkpoint(ckpt_dict):
Expand Down
14 changes: 7 additions & 7 deletions code/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from torch.utils.data import DataLoader
import pytorch_lightning as pl

-import datasets
-import pdb_sampler
-import utils
-import constants
-import split_dataset as sd
-import encode as enc
-from datasets import RosettaDatasetSQL
+from . import datasets
+from . import pdb_sampler
+from . import utils
+from . import constants
+from . import split_dataset as sd
+from . import encode as enc
+from .datasets import RosettaDatasetSQL


class DMSDataModule(pl.LightningDataModule):
Expand Down
6 changes: 3 additions & 3 deletions code/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import torch.utils.data
from torch import Tensor

-import constants
-import split_dataset as sd
-import encode as enc
+from . import constants
+from . import split_dataset as sd
+from . import encode as enc


def load_standardization_params(split_dir, train_only=True):
Expand Down
6 changes: 3 additions & 3 deletions code/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import numpy as np
import pandas as pd

-import rosetta_data_utils as rd
-import constants
-import utils
+from . import rosetta_data_utils as rd
+from . import constants
+from . import utils


def is_seq_level_encoding(encoding: str):
Expand Down
4 changes: 2 additions & 2 deletions code/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import torch.nn.functional as F
from torch import Tensor

-import relative_attention as ra
-import tasks
+from . import relative_attention as ra
+from . import tasks


def reset_parameters_helper(m: nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion code/parse_raw_dms_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd

-import utils
+from . import utils


def sort_and_save_to_csv(df, out_fn, precision=7, sort_muts=True, sort_variants=True, na_rep=""):
Expand Down
7 changes: 3 additions & 4 deletions code/parse_rosetta_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
import sqlite3
from tqdm import tqdm

-import constants
-
-import utils
-import rosetta_data_utils as rd
+from . import constants
+from . import utils
+from . import rosetta_data_utils as rd

logger = logging.getLogger("METL." + __name__)
logger.setLevel(logging.DEBUG)
Expand Down
6 changes: 3 additions & 3 deletions code/relative_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from torch import Tensor
from torch.nn import Linear, Dropout, LayerNorm
import time

import structure
import networkx as nx
import models

from . import structure
from . import models


class RelativePosition3D(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion code/rosetta_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import sqlalchemy as sqla

-import utils
+from . import utils


def convert_dms_to_rosettafy_indexing(ds_name, variants, reverse=False):
Expand Down
2 changes: 1 addition & 1 deletion code/split_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pandas as pd
from sklearn.model_selection import train_test_split

-import utils
+from . import utils


logger = logging.getLogger("METL." + __name__)
Expand Down
2 changes: 1 addition & 1 deletion code/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import networkx as nx
from biopandas.pdb import PandasPdb

-import utils
+from . import utils


class GraphType(Enum):
Expand Down
6 changes: 3 additions & 3 deletions code/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import pytorch_lightning as pl
import torchmetrics

-import training_utils
-from training_utils import CosineWarmupScheduler, ConstantWarmupScheduler
-import models
+from . import training_utils
+from .training_utils import CosineWarmupScheduler, ConstantWarmupScheduler
+from . import models


class RosettaTask(pl.LightningModule):
Expand Down
3 changes: 1 addition & 2 deletions code/tests.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
""" testing code """
import metl
import torch
import utils

from argparse import ArgumentParser

from . import utils

def load_checkpoint_run_inference(checkpoint_path, variants, dataset):
""" loads a finetuned 3D model from a checkpoint and scores variants with the model """
Expand Down
10 changes: 5 additions & 5 deletions code/train_source_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import wandb

-import utils
-from training_utils import BestMetricLogger, save_metrics_ptl, CondorStopping, create_log_dir, get_next_version
-from datamodules import RosettaDataModule
-import models
-import tasks
+from . import utils
+from .training_utils import BestMetricLogger, save_metrics_ptl, CondorStopping, create_log_dir, get_next_version
+from .datamodules import RosettaDataModule
+from . import models
+from . import tasks


class ModelCheckpoint(pytorch_lightning.callbacks.ModelCheckpoint):
Expand Down
16 changes: 8 additions & 8 deletions code/train_target_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
LearningRateMonitor, Checkpoint, StochasticWeightAveraging, ModelSummary, RichProgressBar, RichModelSummary
import numpy as np

-import models
-import training_utils
-import utils
-from datamodules import DMSDataModule
-import finetuning_callbacks
-from finetuning_callbacks import AnyFinetuning
-from tasks import DMSTask
-import analysis_utils as an
+from . import models
+from . import training_utils
+from . import utils
+from .datamodules import DMSDataModule
+from . import finetuning_callbacks
+from .finetuning_callbacks import AnyFinetuning
+from .tasks import DMSTask
+from . import analysis_utils as an

logging.basicConfig(level=logging.INFO)

Expand Down
6 changes: 3 additions & 3 deletions code/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
from torch import optim, Tensor
from torch.optim.lr_scheduler import LambdaLR

-import utils
-import datamodules
-from metrics import compute_metrics
+from . import utils
+from . import datamodules
+from .metrics import compute_metrics


def save_scatterplots(dm, predictions_d, log_dir, suffix=""):
Expand Down
7 changes: 4 additions & 3 deletions code/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from Bio import PDB
from Bio.PDB.PDBParser import PDBParser

-import constants
+from . import constants


def mkdir(d):
Expand Down Expand Up @@ -81,13 +81,14 @@ def load_dataset_metadata(metadata_fn: str = "data/dms_data/datasets.yml"):
def load_dataset(ds_name: Optional[str] = None,
ds_fn: Optional[str] = None,
sort_mutations: bool = False,
-load_epistasis: bool = False):
+load_epistasis: bool = False,
+metadata_fn: str = "data/dms_data/datasets.yml"):
""" load a dataset as pandas dataframe """
if ds_name is None and ds_fn is None:
raise ValueError("must provide either ds_name or ds_fn to load a dataset")

if ds_fn is None:
-datasets = load_dataset_metadata()
+datasets = load_dataset_metadata(metadata_fn)
ds_fn = datasets[ds_name]["ds_fn"]

if not isfile(ds_fn):
Expand Down

0 comments on commit 9e5a695

Please sign in to comment.