Torch channel-wise LSTM now working. Changed up some things so that progbar metrics equal history metrics.
ga84jog committed Sep 15, 2024
1 parent 20ccac3 commit 09bf4ec
Showing 6 changed files with 96 additions and 85 deletions.
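The gist of the change, sketched loosely below: metric values are computed once per batch, cached in a _current_metrics dict that feeds the progress bar, and that same dict is flushed into the history at epoch end, so both views report identical numbers. Everything in this sketch (the MetricSync class and its method names) is invented for illustration and is not the repository's actual API.

# Minimal sketch of the "progbar metrics equal history metrics" idea; illustrative names only.
from collections import defaultdict

class MetricSync:
    def __init__(self):
        self._current_metrics = {}        # the values the progress bar displays
        self.history = defaultdict(dict)  # per-epoch storage

    def on_batch_end(self, loss, metric_values):
        # Cache exactly the values shown on the progress bar.
        self._current_metrics = {"loss": loss, **metric_values}

    def on_epoch_end(self, epoch, prefix):
        # Flush the cached values into history, so both views agree.
        for key, value in self._current_metrics.items():
            self.history[f"{prefix}_{key}"][epoch] = value
        self._current_metrics = {}

sync = MetricSync()
sync.on_batch_end(loss=0.42, metric_values={"acc": 0.81})
sync.on_epoch_end(epoch=0, prefix="train")
print(dict(sync.history))  # {'train_loss': {0: 0.42}, 'train_acc': {0: 0.81}}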
2 changes: 1 addition & 1 deletion src/generators/__init__.py
@@ -217,7 +217,7 @@ def __generator(self):
bining=self._bining,
one_hot=self._one_hot):
yield X, y, M
# TODO! added because remainders seem to destabilize training
# Added because remainders seem to destabilize training
self._remainder_M = np.array([])
self._remainder_X = np.array([])
self._remainder_y = np.array([])
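As a side note on the remainder handling above, here is a loose sketch of the idea that samples which do not fill a complete batch are buffered and then discarded rather than trained on, matching the comment that remainders seem to destabilize training. make_batches is a hypothetical stand-in, not the project's generator.

import numpy as np

def make_batches(X, y, batch_size):
    # Yield only full batches; the trailing remainder is intentionally dropped,
    # mirroring how the generator resets its _remainder_* arrays each epoch.
    n_full = (len(X) // batch_size) * batch_size
    for start in range(0, n_full, batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]

X = np.random.rand(10, 3)
y = np.random.randint(0, 2, size=10)
print(len(list(make_batches(X, y, batch_size=4))))  # 2 full batches; the last 2 samples are skipped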
122 changes: 67 additions & 55 deletions src/models/pytorch/__init__.py
@@ -614,9 +614,7 @@ def _evaluate_with_arrays(self,
y = torch.tensor(y).to(self._device)

# Init counter epoch variables
self._on_epoch_start(prefix,
generator_size=len(y),
batch_size=batch_size if batch_size is not None else len(y))
self._on_epoch_start(prefix=prefix, generator_size=len(y), batch_size=batch_size)

# Evaluate only if necessary
self.eval()
@@ -626,10 +624,10 @@ def _evaluate_with_arrays(self,
aggr_labels = []

with torch.no_grad():
iter_len = len(x) - 1
for sample_idx, (val_inputs, val_labels) in enumerate(zip(x, y)):
val_labels: torch.Tensor
val_inputs: torch.Tensor

for sample_idx in range(self._iter_len):
val_labels: torch.Tensor = y[sample_idx]
val_inputs: torch.Tensor = x[sample_idx]

# Adjust dimensions
val_labels = val_labels.unsqueeze(0)
@@ -655,13 +653,12 @@ def _evaluate_with_arrays(self,
aggr_outputs.append(val_outputs)
aggr_labels.append(val_labels)

aggr_outputs, aggr_labels = self._evaluate_batch(aggr_outputs,
aggr_labels,
is_test=is_test,
finalize=iter_len == sample_idx)

if sample_idx == iter_len:
break
aggr_outputs, aggr_labels = self._evaluate_batch(
aggr_outputs,
aggr_labels,
is_test=is_test,
is_epoch_end=sample_idx == self._iter_len - 1,
)

# Only update history if test
if is_test:
@@ -705,7 +702,7 @@ def _evaluate_with_dataloader(self,

# Init counter epoch variables
generator_size = len(generator)
self._on_epoch_start(prefix,
self._on_epoch_start(prefix=prefix,
generator_size=generator_size,
batch_size=batch_size if batch_size is not None else len(y))

@@ -717,7 +714,6 @@ def _evaluate_with_dataloader(self,
self.eval()
with torch.no_grad():
# Unroll generator
iter_len = len(generator) - 1
for sample_idx, (val_inputs, val_labels) in enumerate(generator):

if masking_flag == None:
@@ -749,12 +745,14 @@ def _evaluate_with_dataloader(self,
aggr_outputs.append(val_outputs)
aggr_labels.append(val_labels)

aggr_outputs, aggr_labels = self._evaluate_batch(aggr_outputs,
aggr_labels,
is_test=is_test,
finalize=iter_len == sample_idx)
aggr_outputs, aggr_labels = self._evaluate_batch(
aggr_outputs,
aggr_labels,
is_test=is_test,
is_epoch_end=self._iter_len == sample_idx,
)

if sample_idx == iter_len:
if sample_idx == self._iter_len:
break

# Only update history if test
@@ -771,7 +769,7 @@ def _evaluate_batch(self,
outputs: list,
labels: list,
is_test: bool = False,
finalize: bool = False):
is_epoch_end: bool = False):
"""Evaluates the batch concatenated by higher level _evaluate_with_dataloader or
_evaluate_with_arrays method and updates metric state. Finalize will finalize
the epoch progbar.
@@ -826,8 +824,9 @@ def _evaluate_batch(self,
outputs,
labels,
prefix="test" if is_test else "val",
update_progbar=finalize and not is_test,
finalize=finalize)
is_epoch_end=is_epoch_end,
update_progbar=is_epoch_end and not is_test,
finalize=is_epoch_end)
self._sample_count = 0
self._batch_count += 1

@@ -954,22 +953,25 @@ def _latest_epoch(self, epochs: int, filepath: Path) -> int:
def _on_epoch_start(self,
prefix: str,
generator_size: int = 0,
batch_size: int = 0,
batch_size: int = None,
verbose: bool = True,
has_val: bool = False):
"""Setup epoch state and verbosity before start
"""
# Insecure about consistency of these here
self._current_metrics = dict()
if generator_size:
self._epoch_progbar = Progbar((generator_size // batch_size))
self._sample_size = (generator_size // batch_size) * batch_size
self._batch_size = generator_size if batch_size is None else batch_size
if generator_size and verbose:
self._epoch_progbar = Progbar((generator_size // self._batch_size))

# Batch iter variables
# Counter variables
self._sample_count = 0
self._batch_count = 1

# Iteration control variables
self._generator_size = generator_size
self._batch_size = batch_size
self._has_val = has_val
self._iter_len = (self._generator_size // self._batch_size) * self._batch_size

# Reset metric values
for metric in self._metrics[prefix].values():
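To make the new epoch bookkeeping in _on_epoch_start concrete: _iter_len rounds the sample count down to a whole number of batches, and the progress bar target is the matching batch count. A tiny arithmetic sketch with arbitrary numbers:

generator_size, batch_size = 1037, 32

n_batches = generator_size // batch_size                 # 32, the progress bar target
iter_len = (generator_size // batch_size) * batch_size   # 1024, samples actually iterated

print(n_batches, iter_len)  # 32 1024; the trailing 13 samples never enter the loop

When batch_size is None, the new default sets self._batch_size to the generator size, so the whole pass is treated as a single batch.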
@@ -1001,8 +1003,8 @@ def _on_epoch_end(self, epoch: int, prefix: str = ""):
# Update metric history
if hasattr(self._history, f"{prefix}_metrics"):
metric_history = getattr(self._history, f"{prefix}_metrics")
for key, metric in self._metrics[prefix].items():
metric_history[key][epoch] = metric.compute()
for key, metric_value in self._current_metrics.items():
metric_history[key][epoch] = metric_value

self._current_metrics = dict()
self._sample_count = 0
@@ -1011,7 +1013,11 @@ def _on_epoch_end(self, epoch: int, prefix: str = ""):
self._batch_size = 0
self._has_val = False

def _optimize_batch(self, outputs: list, labels: list, finalize: bool = False):
def _optimize_batch(self,
outputs: list,
labels: list,
is_epoch_end: int,
has_val: bool = False):
"""Applies optimizer to the concatenated outputs and labels, handed down by the higher level
train_with_dataloader or train_with_array methods and returns updated list of outputs
and labels. The list is reset to zero if the batch was complete and is unmodified if
@@ -1077,7 +1083,8 @@ def _optimize_batch(self, outputs: list, labels: list, finalize: bool = False):
labels,
prefix="train",
update_progbar=True,
finalize=finalize)
is_epoch_end=is_epoch_end,
finalize=not has_val and is_epoch_end)
self._sample_count = 0
self._batch_count += 1

@@ -1097,6 +1104,7 @@ def _prefixed_metrics(self, metrics: List[Tuple[str, float]],
def _predict_dataloader(self, generator: DataLoader, **kwargs):
""" Make predictions on a dataloader while handling masking and dimensions.
"""
self._on_epoch_start(prefix="", generator_size=len(generator), has_val=False, verbose=False)
self.eval()

aggr_outputs = []
@@ -1142,14 +1150,15 @@ def _predict_numpy(self, x: np.ndarray, **kwargs):
masking_flag = False

x = torch.tensor(x, dtype=torch.float32).to(self._device)
self._on_epoch_start(prefix="", generator_size=len(x), has_val=False, verbose=False)

# Evaluate only if necessary
self.eval()

# Batch iter variables
aggr_outputs = []

with torch.no_grad():
iter_len = len(x) - 1
for sample_idx, input in enumerate(x):
input: torch.Tensor
input = input.unsqueeze(0)
@@ -1175,7 +1184,7 @@ def _predict_numpy(self, x: np.ndarray, **kwargs):
output = placeholder
aggr_outputs.append(output.cpu().numpy())

if sample_idx == iter_len:
if sample_idx == self._iter_len - 1:
break

return np.concatenate(aggr_outputs)
@@ -1252,15 +1261,17 @@ def _train_with_arrays(self,
masks = torch.tensor(masks, dtype=torch.float32)[idx].bool().to(self._device)

# Init counter epoch variables
self._on_epoch_start("train", data_size, batch_size, has_val)
self._on_epoch_start(prefix="train",
generator_size=data_size,
batch_size=batch_size,
has_val=has_val)

# Batch iter variables
aggr_outputs = []
aggr_labels = []

# Main loop
iter_len = (len(x) // batch_size) * batch_size - 1
for sample_idx in range(x.shape[0]):
for sample_idx in range(self._iter_len):
label: torch.Tensor = y[sample_idx]
input: torch.Tensor = x[sample_idx]

@@ -1288,14 +1299,12 @@ def _train_with_arrays(self,
aggr_labels.append(label)

# Optimizer network on abtch
aggr_outputs, \
aggr_labels = self._optimize_batch(outputs=aggr_outputs, \
labels=aggr_labels, \
finalize=(not has_val and sample_idx == iter_len))

# Miss inclomplete batch
if sample_idx == iter_len:
break
aggr_outputs, aggr_labels = self._optimize_batch(
outputs=aggr_outputs,
labels=aggr_labels,
has_val=has_val,
is_epoch_end=sample_idx == self._iter_len - 1,
)

self._on_epoch_end(epoch, prefix="train")
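Putting the pieces together, the array training loop now has roughly this shape: iterate only over the full-batch portion of the data, accumulate per-sample outputs and labels, and flush them through the optimizer once a batch is complete. The toy below re-creates that structure with a throwaway linear model; the inline if-block stands in for _optimize_batch and none of these names come from the repository.

import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.BCEWithLogitsLoss()

x = torch.randn(10, 4)
y = torch.randint(0, 2, (10, 1)).float()
batch_size = 4
iter_len = (len(x) // batch_size) * batch_size  # 8, so the last 2 samples are skipped

outputs, labels = [], []
for sample_idx in range(iter_len):
    outputs.append(model(x[sample_idx].unsqueeze(0)))
    labels.append(y[sample_idx].unsqueeze(0))
    if len(outputs) == batch_size:                 # a full batch is collected
        loss = loss_fn(torch.cat(outputs), torch.cat(labels))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        is_epoch_end = sample_idx == iter_len - 1  # would finalize the progbar only if no val pass follows
        outputs, labels = [], []                   # reset, as _optimize_batch does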

@@ -1315,15 +1324,14 @@ def _train_with_dataloader(self,

# Tracking variables
generator_size = len(generator)
self._on_epoch_start("train",
self._on_epoch_start(prefix="train",
generator_size=generator_size,
batch_size=batch_size,
has_val=has_val)
aggr_outputs = []
aggr_labels = []

# Main loop
iter_len = (len(generator) // batch_size) * batch_size - 1
for sample_idx, (input, label) in enumerate(generator):
input: torch.Tensor
label: torch.Tensor
@@ -1360,10 +1368,11 @@ def _train_with_dataloader(self,
aggr_outputs, \
aggr_labels = self._optimize_batch(outputs=aggr_outputs, \
labels=aggr_labels, \
finalize=not has_val and iter_len == sample_idx)
has_val=has_val, \
is_epoch_end=self._iter_len == sample_idx)

# Miss inclomplete batch
if sample_idx == iter_len:
if sample_idx == self._iter_len:
break

self._on_epoch_end(epoch, prefix="train")
@@ -1373,6 +1382,7 @@ def _update_metrics(self,
outputs: torch.Tensor,
labels: torch.Tensor,
prefix: str,
is_epoch_end: bool = False,
finalize: bool = False,
update_progbar: bool = False):
""" Update the loss running average and the metrics base on the outputs and labels for
@@ -1396,12 +1406,14 @@ def _update_metrics(self,
labels.int() if hasattr(labels, "int") else labels.astype(int))

# Reset accumulators and count
self._current_metrics["loss"] = loss.item()
self._current_metrics.update(dict(self._get_metrics(self._metrics[prefix])))
epoch_loss = loss.item()
epoch_metrics = self._get_metrics(self._metrics[prefix])
self._current_metrics["loss"] = epoch_loss
self._current_metrics.update(dict(epoch_metrics))

if update_progbar:
self._epoch_progbar.update(
self._epoch_progbar.target if prefix == "val" else self._batch_count,
values=self._prefixed_metrics(self._get_metrics(self._metrics[prefix]),
values=self._prefixed_metrics(epoch_metrics,
prefix=prefix if prefix != "train" else None),
finalize=finalize) # and self._batch_count == self._epoch_progbar.target)
finalize=finalize)
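The values pushed to the progress bar here are the same epoch_metrics that were just cached in _current_metrics, only re-labelled with a phase prefix for validation and test. The body of _prefixed_metrics is not part of this diff, so the helper below is only an assumed, simplified version of what it might do:

def prefixed_metrics(metrics, prefix=None):
    # Assumed behaviour: prepend the phase name so progbar labels match history keys.
    if prefix is None:
        return metrics
    return [(f"{prefix}_{name}", value) for name, value in metrics]

print(prefixed_metrics([("loss", 0.31), ("acc", 0.88)], prefix="val"))
# [('val_loss', 0.31), ('val_acc', 0.88)]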
24 changes: 12 additions & 12 deletions src/models/pytorch/lstm.py
@@ -83,8 +83,8 @@ def __init__(self,
p.data.fill_(0)

def forward(self, x, masks=None) -> torch.Tensor:
masking_falg = masks is not None
if masking_falg:
masking_flag = masks is not None
if masking_flag:
masks = masks.to(self._device)
x = x.to(self._device)

@@ -94,19 +94,19 @@ def forward(self, x, masks=None) -> torch.Tensor:
x, _ = self._lstm_final(x)

# Case 1: deep supervision
if masking_falg:
if masking_flag:
# Apply the linear layer to each LSTM output at each timestep (ts)
B, T, hidden_size = x.shape
x = x.reshape(B * T, hidden_size)
x = x.view(B * T, hidden_size)
x = self._output_layer(x)
x = x.reshape(B, T, -1)
x = x.view(B, T, -1)

# Case 2: standard LSTM or target replication
else:
# Apply linear layer only to the last output of the LSTM
x = x[:, -1, :]
x = x.reshape(x.shape[0], 1, x.shape[1])
x = self._output_layer(x)
x = x.unsqueeze(1)

# Apply final activation if specified
if self._final_activation and self._apply_activation:
@@ -229,8 +229,8 @@ def __init__(self,
p.data.fill_(0)

def forward(self, x, masks=None) -> torch.Tensor:
masking_falg = masks is not None
if masking_falg:
masking_flag = masks is not None
if masking_flag:
masks = masks.to(self._device)

x = x.to(self._device)
@@ -251,19 +251,19 @@ def forward(self, x, masks=None) -> torch.Tensor:
x, _ = lstm(x)

# Case 1: deep supervision
if masking_falg:
if masking_flag:
# Apply the linear layer to each LSTM output at each timestep (ts)
B, T, hidden_size = x.shape
x = x.reshape(B * T, hidden_size)
x = x.view(B * T, hidden_size)
x = self._output_layer(x)
x = x.reshape(B, T, -1)
x = x.view(B, T, -1)

# Case 2: standard LSTM or target replication
else:
# Apply linear layer only to the last output of the LSTM
x = x[:, -1, :]
x = x.reshape(x.shape[0], 1, x.shape[1])
x = self._output_layer(x)
x = x.unsqueeze(1)

# Apply final activation if specified
if self._final_activation and self._apply_activation:
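To recap the two forward paths touched in this file: with deep-supervision masks the output layer is applied at every timestep by viewing (B, T, H) as (B*T, H) and back, whereas without masks only the last timestep is kept and a length-1 time dimension is restored with unsqueeze(1). A standalone shape check with an invented hidden size and output layer, not the model's actual code:

import torch

B, T, H, n_out = 2, 5, 8, 1
output_layer = torch.nn.Linear(H, n_out)
x = torch.randn(B, T, H)          # stand-in for the final LSTM output

# Case 1: deep supervision, one prediction per timestep
per_step = output_layer(x.view(B * T, H)).view(B, T, -1)
print(per_step.shape)             # torch.Size([2, 5, 1])

# Case 2: standard LSTM / target replication, last timestep only
last = output_layer(x[:, -1, :]).unsqueeze(1)
print(last.shape)                 # torch.Size([2, 1, 1])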