Skip to content

Commit

Permalink
Update imports to support use as a package or standalone scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
samgelman committed Jan 28, 2025
1 parent 9e5a695 commit 02ddabd
Show file tree
Hide file tree
Showing 18 changed files with 137 additions and 56 deletions.
6 changes: 5 additions & 1 deletion code/compute_rosetta_standardization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import logging

import pandas as pd
from . import split_dataset as sd

try:
from . import split_dataset as sd
except ImportError:
import split_dataset as sd

logger = logging.getLogger("METL." + __name__)
logger.setLevel(logging.DEBUG)
Expand Down
11 changes: 8 additions & 3 deletions code/convert_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@
import torch.nn as nn
import torchinfo

from . import models
from . import utils
from . import encode as enc
try:
from . import models
from . import utils
from . import encode as enc
except ImportError:
import models
import utils
import encode as enc


def convert_checkpoint(ckpt_dict):
Expand Down
23 changes: 16 additions & 7 deletions code/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,22 @@
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from . import datasets
from . import pdb_sampler
from . import utils
from . import constants
from . import split_dataset as sd
from . import encode as enc
from .datasets import RosettaDatasetSQL
try:
from . import datasets
from . import pdb_sampler
from . import utils
from . import constants
from . import split_dataset as sd
from . import encode as enc
from .datasets import RosettaDatasetSQL
except ImportError:
import datasets
import pdb_sampler
import utils
import constants
import split_dataset as sd
import encode as enc
from datasets import RosettaDatasetSQL


class DMSDataModule(pl.LightningDataModule):
Expand Down
12 changes: 8 additions & 4 deletions code/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
import torch.utils.data
from torch import Tensor

from . import constants
from . import split_dataset as sd
from . import encode as enc

try:
from . import constants
from . import split_dataset as sd
from . import encode as enc
except ImportError:
import constants
import split_dataset as sd
import encode as enc

def load_standardization_params(split_dir, train_only=True):
# if train_only True, then will only load standardization params for the training set (filename energy_X_train)
Expand Down
12 changes: 8 additions & 4 deletions code/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
import numpy as np
import pandas as pd

from . import rosetta_data_utils as rd
from . import constants
from . import utils

try:
from . import rosetta_data_utils as rd
from . import constants
from . import utils
except ImportError:
import rosetta_data_utils as rd
import constants
import utils

def is_seq_level_encoding(encoding: str):
""" helper function to differentiate sequence-level vs. residue-level encodings """
Expand Down
9 changes: 6 additions & 3 deletions code/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
import torch.nn.functional as F
from torch import Tensor

from . import relative_attention as ra
from . import tasks

try:
from . import relative_attention as ra
from . import tasks
except ImportError:
import relative_attention as ra
import tasks

def reset_parameters_helper(m: nn.Module):
""" helper function for resetting model parameters, meant to be used with model.apply() """
Expand Down
5 changes: 4 additions & 1 deletion code/parse_raw_dms_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import numpy as np
import pandas as pd

from . import utils
try:
from . import utils
except ImportError:
import utils


def sort_and_save_to_csv(df, out_fn, precision=7, sort_muts=True, sort_variants=True, na_rep=""):
Expand Down
11 changes: 8 additions & 3 deletions code/parse_rosetta_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
import sqlite3
from tqdm import tqdm

from . import constants
from . import utils
from . import rosetta_data_utils as rd
try:
from . import constants
from . import utils
from . import rosetta_data_utils as rd
except ImportError:
import constants
import utils
import rosetta_data_utils as rd

logger = logging.getLogger("METL." + __name__)
logger.setLevel(logging.DEBUG)
Expand Down
9 changes: 6 additions & 3 deletions code/relative_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
import time
import networkx as nx

from . import structure
from . import models

try:
from . import structure
from . import models
except ImportError:
import structure
import models

class RelativePosition3D(nn.Module):
""" Contact map-based relative position embeddings """
Expand Down
5 changes: 4 additions & 1 deletion code/rosetta_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import pandas as pd
import sqlalchemy as sqla

from . import utils
try:
from . import utils
except ImportError:
import utils


def convert_dms_to_rosettafy_indexing(ds_name, variants, reverse=False):
Expand Down
5 changes: 4 additions & 1 deletion code/split_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import pandas as pd
from sklearn.model_selection import train_test_split

from . import utils
try:
from . import utils
except ImportError:
import utils


logger = logging.getLogger("METL." + __name__)
Expand Down
5 changes: 4 additions & 1 deletion code/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import networkx as nx
from biopandas.pdb import PandasPdb

from . import utils
try:
from . import utils
except ImportError:
import utils


class GraphType(Enum):
Expand Down
11 changes: 8 additions & 3 deletions code/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@
import pytorch_lightning as pl
import torchmetrics

from . import training_utils
from .training_utils import CosineWarmupScheduler, ConstantWarmupScheduler
from . import models
try:
from . import training_utils
from .training_utils import CosineWarmupScheduler, ConstantWarmupScheduler
from . import models
except ImportError:
import training_utils
from training_utils import CosineWarmupScheduler, ConstantWarmupScheduler
import models


class RosettaTask(pl.LightningModule):
Expand Down
9 changes: 6 additions & 3 deletions code/tests.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
""" testing code """
import metl
import code
import torch
from argparse import ArgumentParser

from . import utils
try:
from . import utils
except ImportError:
import utils

def load_checkpoint_run_inference(checkpoint_path, variants, dataset):
""" loads a finetuned 3D model from a checkpoint and scores variants with the model """
model, data_encoder = metl.get_from_checkpoint(checkpoint_path)
model, data_encoder = code.get_from_checkpoint(checkpoint_path)

# load the wild-type sequence and the PDB file (needed for 3D RPE) for the dataset
datasets = utils.load_dataset_metadata()
Expand Down
17 changes: 12 additions & 5 deletions code/train_source_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@

import wandb

from . import utils
from .training_utils import BestMetricLogger, save_metrics_ptl, CondorStopping, create_log_dir, get_next_version
from .datamodules import RosettaDataModule
from . import models
from . import tasks
try:
from . import utils
from .training_utils import BestMetricLogger, save_metrics_ptl, CondorStopping, create_log_dir, get_next_version
from .datamodules import RosettaDataModule
from . import models
from . import tasks
except ImportError:
import utils
from training_utils import BestMetricLogger, save_metrics_ptl, CondorStopping, create_log_dir, get_next_version
from datamodules import RosettaDataModule
import models
import tasks


class ModelCheckpoint(pytorch_lightning.callbacks.ModelCheckpoint):
Expand Down
26 changes: 18 additions & 8 deletions code/train_target_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,24 @@
LearningRateMonitor, Checkpoint, StochasticWeightAveraging, ModelSummary, RichProgressBar, RichModelSummary
import numpy as np

from . import models
from . import training_utils
from . import utils
from .datamodules import DMSDataModule
from . import finetuning_callbacks
from .finetuning_callbacks import AnyFinetuning
from .tasks import DMSTask
from . import analysis_utils as an
try:
from . import models
from . import training_utils
from . import utils
from .datamodules import DMSDataModule
from . import finetuning_callbacks
from .finetuning_callbacks import AnyFinetuning
from .tasks import DMSTask
from . import analysis_utils as an
except ImportError:
import models
import training_utils
import utils
from datamodules import DMSDataModule
import finetuning_callbacks
from finetuning_callbacks import AnyFinetuning
from tasks import DMSTask
import analysis_utils as an

logging.basicConfig(level=logging.INFO)

Expand Down
11 changes: 8 additions & 3 deletions code/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@
from torch import optim, Tensor
from torch.optim.lr_scheduler import LambdaLR

from . import utils
from . import datamodules
from .metrics import compute_metrics
try:
from . import utils
from . import datamodules
from .metrics import compute_metrics
except ImportError:
import utils
import datamodules
from metrics import compute_metrics


def save_scatterplots(dm, predictions_d, log_dir, suffix=""):
Expand Down
6 changes: 4 additions & 2 deletions code/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from Bio import PDB
from Bio.PDB.PDBParser import PDBParser

from . import constants

try:
from . import constants
except ImportError:
import constants

def mkdir(d):
""" creates given dir if it does not already exist """
Expand Down

0 comments on commit 02ddabd

Please sign in to comment.