Pass optimizer into training loop
golmschenk committed May 25, 2024
1 parent ab89d37 commit aa1b544
Showing 7 changed files with 168 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -225,3 +225,4 @@ logs/
/src/ramjet/photometric_database/microlensing_signal_meta_data/moa9yr_events_meta_oct2018.txt
/wandb/
!/ramjet.iml
/.idea/
117 changes: 117 additions & 0 deletions src/qusi/internal/chyrin_model.py
@@ -0,0 +1,117 @@
from __future__ import annotations

import math

from torch import permute
from torch.nn import (
BatchNorm1d,
Conv1d,
Dropout1d,
LeakyReLU,
MaxPool1d,
Module,
ModuleList, ConstantPad1d, Sigmoid,
)


class Chyrin(Module):
def __init__(self):
super().__init__()
self.blocks = ModuleList()
self.activation = LeakyReLU()
self.sigmoid = Sigmoid()
output_channels = 10
self.blocks.append(ResidualLightCurveNetworkBlock(
output_channels=output_channels, input_channels=1, dropout_rate=0.0,
batch_normalization=False))
input_channels = output_channels
for output_channels in [10, 10, 20, 20, 30, 30, 40, 40, 50, 50]:
self.blocks.append(ResidualLightCurveNetworkBlock(
output_channels=output_channels, input_channels=input_channels, pooling_scale_factor=2,
dropout_rate=0.0,
batch_normalization=False))
input_channels = output_channels
for _ in range(1):
self.blocks.append(ResidualLightCurveNetworkBlock(
input_channels=input_channels, output_channels=output_channels, dropout_rate=0.0,
batch_normalization=False))
input_channels = output_channels
self.end_conv = Conv1d(input_channels, 1, kernel_size=3)

def forward(self, x):
x = x.reshape([-1, 1, 3500])
for index, block in enumerate(self.blocks):
x = block(x)
x = self.end_conv(x)
x = self.sigmoid(x)
outputs = x.reshape([-1])
return outputs


class ResidualLightCurveNetworkBlock(Module):
def __init__(self, input_channels: int, output_channels: int, kernel_size: int = 3,
pooling_scale_factor: int = 1, batch_normalization: bool = False, dropout_rate: float = 0.0,
renorm: bool = False):
super().__init__()
self.activation = LeakyReLU()
dimension_decrease_factor = 4
if batch_normalization:
self.batch_normalization = BatchNorm1d(num_features=input_channels, track_running_stats=renorm)
else:
self.batch_normalization = None
reduced_channels = output_channels // dimension_decrease_factor
self.dimension_decrease_layer = Conv1d(
in_channels=input_channels, out_channels=reduced_channels, kernel_size=1)
self.convolutional_layer = Conv1d(
in_channels=reduced_channels, out_channels=reduced_channels, kernel_size=kernel_size,
padding=math.floor(kernel_size / 2)
)
self.dimension_increase_layer = Conv1d(
in_channels=reduced_channels, out_channels=output_channels, kernel_size=1)
if pooling_scale_factor > 1:
self.pooling_layer = MaxPool1d(kernel_size=pooling_scale_factor)
else:
self.pooling_layer = None
self.input_to_output_channel_difference = input_channels - output_channels
if output_channels != input_channels:
if output_channels < input_channels:
self.output_channels = output_channels
else:
self.dimension_change_layer = ConstantPad1d(padding=(0, -self.input_to_output_channel_difference),
value=0)
else:
self.dimension_change_layer = None
if dropout_rate > 0:
self.dropout_layer = Dropout1d(p=dropout_rate)
else:
self.dropout_layer = None

def forward(self, x):
"""
The forward pass of the block.
:param x: The input tensor.
:return: The output tensor of the layer.
"""
y = x
if self.batch_normalization is not None:
y = self.batch_normalization(y)
y = self.dimension_decrease_layer(y)
y = self.activation(y)
y = self.convolutional_layer(y)
y = self.activation(y)
y = self.dimension_increase_layer(y)
y = self.activation(y)
if self.pooling_layer is not None:
x = self.pooling_layer(x)
y = self.pooling_layer(y)
if self.input_to_output_channel_difference != 0:
x = permute(x, (0, 2, 1))
if self.input_to_output_channel_difference < 0:
x = self.dimension_change_layer(x)
else:
x = x[:, :, 0:self.output_channels]
x = permute(x, (0, 2, 1))
if self.dropout_layer is not None:
y = self.dropout_layer(y)
return x + y
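
A minimal usage sketch (not part of this commit) of the new Chyrin model, using dummy input data: it takes a batch of light curves of length 3500 and returns one score per light curve.

import torch

from qusi.internal.chyrin_model import Chyrin

model = Chyrin()
dummy_light_curves = torch.randn(8, 3500)  # batch of 8 fixed-length light curves
scores = model(dummy_light_curves)  # shape (8,), each value in (0, 1) after the sigmoid
print(scores.shape)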
2 changes: 1 addition & 1 deletion src/qusi/internal/infer_session.py
@@ -8,7 +8,7 @@


def infer_session(
infer_datasets: list[FiniteStandardLightCurveDataset],
infer_datasets: list[FiniteStandardLightCurveDataset],
model: Module,
*,
batch_size: int,
12 changes: 12 additions & 0 deletions src/qusi/internal/light_curve_transforms.py
@@ -80,3 +80,15 @@ def make_uniform_length(example: np.ndarray, length: int) -> np.ndarray:
else:
example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap")
return example


def remove_random_elements(array: np.ndarray, ratio: float = 0.01) -> np.ndarray:
"""Removes random values from an array."""
light_curve_length = array.shape[0]
max_values_to_remove = int(light_curve_length * ratio)
if max_values_to_remove != 0:
values_to_remove = np.random.randint(max_values_to_remove)
else:
values_to_remove = 0
random_indexes = np.random.choice(range(light_curve_length), values_to_remove, replace=False)
return np.delete(array, random_indexes, axis=0)
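
A short illustration (assumed usage, not part of the commit) of the new remove_random_elements transform on a fake light curve array; the 3500-row shape is only an example.

import numpy as np

from qusi.internal.light_curve_transforms import remove_random_elements

light_curve = np.random.rand(3500, 2)  # e.g., time and flux columns
trimmed = remove_random_elements(light_curve, ratio=0.01)
print(light_curve.shape, trimmed.shape)  # trimmed has up to ~1% fewer rows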
29 changes: 29 additions & 0 deletions src/qusi/internal/simple_model.py
@@ -0,0 +1,29 @@
from __future__ import annotations

from torch.nn import (
LeakyReLU,
Module,
Sigmoid, Linear,
)


class SimpleDense(Module):
def __init__(self):
super().__init__()
self.activation = LeakyReLU()
self.sigmoid = Sigmoid()
self.dense0 = Linear(in_features=3500, out_features=100)
self.dense1 = Linear(in_features=100, out_features=100)
self.dense2 = Linear(in_features=100, out_features=1)

def forward(self, x):
x = x.reshape([-1, 3500])
x = self.dense0(x)
x = self.activation(x)
x = self.dense1(x)
x = self.activation(x)
x = self.dense2(x)
x = self.activation(x)
x = self.sigmoid(x)
outputs = x.reshape([-1])
return outputs
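
A minimal sketch (not in the commit) showing that SimpleDense accepts the same (batch, 3500) input and produces the same (batch,) output as Chyrin, so it can serve as a quick baseline.

import torch

from qusi.internal.simple_model import SimpleDense

baseline = SimpleDense()
scores = baseline(torch.randn(8, 3500))
print(scores.shape)  # torch.Size([8])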
14 changes: 1 addition & 13 deletions src/qusi/internal/train_hyperparameter_configuration.py
@@ -24,9 +24,6 @@ class TrainHyperparameterConfiguration:
train_steps_per_cycle: int
validation_steps_per_cycle: int
batch_size: int
learning_rate: float
optimizer_epsilon: float
weight_decay: float
norm_based_gradient_clip: float

@classmethod
@@ -37,9 +34,6 @@ def new(
train_steps_per_cycle: int = 100,
validation_steps_per_cycle: int = 10,
batch_size: int = 100,
learning_rate: float = 1e-4,
optimizer_epsilon: float = 1e-7,
weight_decay: float = 0.0001,
norm_based_gradient_clip: float = 1.0,
):
"""
@@ -53,19 +47,13 @@
:param validation_steps_per_cycle: The number of validation steps per cycle.
:param batch_size: The size of the batch for each train process. Each training step will use a number of observations
equal to this value multiplied by the number of train processes.
:param learning_rate: The learning rate.
:param optimizer_epsilon: The epsilon to be used by the optimizer.
:param weight_decay: The weight decay of the optimizer.
:param norm_based_gradient_clip: The norm based gradient clipping value.
:return: The hyperparameter configuration.
"""
return cls(
learning_rate=learning_rate,
optimizer_epsilon=optimizer_epsilon,
weight_decay=weight_decay,
batch_size=batch_size,
cycles=cycles,
train_steps_per_cycle=train_steps_per_cycle,
validation_steps_per_cycle=validation_steps_per_cycle,
batch_size=batch_size,
norm_based_gradient_clip=norm_based_gradient_clip,
)
13 changes: 7 additions & 6 deletions src/qusi/internal/train_session.py
@@ -5,11 +5,10 @@

import numpy as np
import torch
import wandb
from torch.nn import BCELoss, Module
from torch.optim import AdamW
from torch.optim import AdamW, Optimizer
from torch.utils.data import DataLoader

import wandb
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset
@@ -25,6 +24,7 @@ def train_session(
train_datasets: list[LightCurveDataset],
validation_datasets: list[LightCurveDataset],
model: Module,
optimizer: Optimizer | None = None,
loss_function: Module | None = None,
metric_functions: list[Module] | None = None,
*,
@@ -91,12 +91,13 @@
)
validation_dataloaders.append(validation_dataloader)
if torch.cuda.is_available() and not debug:
device = torch.device("cuda")
device = torch.device('cuda')
else:
device = torch.device("cpu")
device = torch.device('cpu')
model = model.to(device, non_blocking=True)
loss_function = loss_function.to(device, non_blocking=True)
optimizer = AdamW(model.parameters())
if optimizer is None:
optimizer = AdamW(model.parameters())
metric_functions: list[Module] = [
metric_function.to(device, non_blocking=True)
for metric_function in metric_functions

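A hedged sketch of the workflow this commit enables: the caller builds the optimizer (with the learning rate, epsilon, and weight decay that are no longer part of TrainHyperparameterConfiguration) and passes it into train_session; omitting it falls back to the default AdamW. The dataset variables and the hyperparameter_configuration keyword name are assumptions, since the remaining keyword-only arguments of train_session are not visible in this diff.

from torch.optim import AdamW

from qusi.internal.chyrin_model import Chyrin
from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
from qusi.internal.train_session import train_session

model = Chyrin()
# Optimizer settings now live on the optimizer itself rather than in the configuration.
optimizer = AdamW(model.parameters(), lr=1e-4, eps=1e-7, weight_decay=1e-4)
configuration = TrainHyperparameterConfiguration.new(batch_size=100, cycles=5000)

train_session(
    train_datasets=[train_dataset],  # LightCurveDataset instances assumed to exist
    validation_datasets=[validation_dataset],
    model=model,
    optimizer=optimizer,  # omit (or pass None) to fall back to the default AdamW
    hyperparameter_configuration=configuration,  # assumed keyword name
)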