
Commit

Had it working for torch as well, but adding target repl had me change some of the mechanics. Feels like DECOMP DS performance degraded. Have to get it all running again, then repeat for tf2. Then I'll have pretty much all I need!
ga84jog committed Jul 30, 2024
1 parent 9258d8e commit 7c4040a
Showing 4 changed files with 214 additions and 54 deletions.
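For context: the "target repl" in the message is target replication (the idea goes back to Lipton et al., "Learning to Diagnose with LSTM Recurrent Neural Networks"), where the sequence-level label is applied as an auxiliary target at every intermediate timestep. As implemented in this commit, intermediate timesteps are weighted by target_repl_coef while the final timestep keeps full weight, and the per-timestep losses are then averaged.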
114 changes: 88 additions & 26 deletions src/models/pytorch/__init__.py
@@ -27,9 +27,11 @@ def __init__(self,
                 final_activation,
                 output_dim: int,
                 model_path: Path = None,
                 task: Literal["multilabel", "multiclass", "binary"] = None,
                 target_repl_coef: float = 0.0):
        super(AbstractTorchNetwork, self).__init__()
        self._model_path = model_path
        self._target_repl_coef = target_repl_coef
        if final_activation is None:
            if output_dim == 1:
                self._final_activation = nn.Sigmoid()
@@ -203,12 +205,16 @@ def compile(self,
            self._apply_activation = loss != "categorical_crossentropy"

            if loss in loss_mapping:
                # Per-element losses are needed so timestep weights can be applied later
                reduction = "none" if self._target_repl_coef else "mean"

                self._loss = loss_mapping[loss](weight=class_weight, reduction=reduction)
            else:
                raise ValueError(f"Loss {loss} not supported. "
                                 f"Supported losses are {list(loss_mapping.keys())}")
        else:
            self._apply_activation = not isinstance(loss, nn.CrossEntropyLoss)
            if self._target_repl_coef and loss.reduction != "none":
                loss = loss.__class__(reduction="none")
            self._loss = loss
        self._metrics = self._init_metrics(metrics)
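For intuition, a minimal standalone sketch of why reduction="none" matters here (illustrative tensors and values, not this project's API): the unreduced criterion returns one loss per element, which can then be weighted per timestep before averaging, whereas reduction="mean" would collapse that information immediately.

import torch
import torch.nn as nn

criterion = nn.BCELoss(reduction="none")              # per-element losses
preds = torch.rand(4, 1)                              # 4 timesteps, 1 output
targets = torch.ones(4, 1)
per_step = criterion(preds, targets)                  # shape (4, 1), one loss per timestep
weights = torch.tensor([[0.5], [0.5], [0.5], [1.0]])  # intermediate steps downweighted
loss = (per_step * weights).mean()                    # scalar training loss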

@@ -406,25 +412,33 @@ def _train_with_arrays(self,
        input = self._remove_end_padding(input)
        mask = None

        if self._target_repl_coef:
            # Every timestep of the unpadded sample is a valid replication target
            mask = torch.ones(1, input.shape[1], 1)

        # Adjust dimensions
        if len(input.shape) < 3:
            input = input.unsqueeze(0)

        # Create predictions
        output = self(input, masks=mask)
        label = label.T

        # Apply masking
        if masking_flag:
            output = output[:, mask.squeeze()]
            if len(label.shape) < 3:
                label = label.unsqueeze(-1)
            label = label[mask]
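_remove_end_padding itself is not shown in this diff; a hedged sketch of what such a helper might do, assuming zero-padded sequence tails (hypothetical implementation):

def remove_end_padding(seq: torch.Tensor) -> torch.Tensor:
    """Trim trailing all-zero timesteps from a (T, F) sequence."""
    valid = (seq != 0).any(dim=-1)            # True where any feature is set
    if not valid.any():
        return seq
    last = int(valid.nonzero().max().item())  # index of last valid timestep
    return seq[:last + 1]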

        # Accumulate outputs and labels, either flat or with the multilabel dimension
        aggr_outputs.append(output)
        aggr_labels.append(label)

        # Optimize network on batch
        aggr_outputs, \
        aggr_labels = self._optimize_batch(outputs=aggr_outputs,
                                           labels=aggr_labels,
                                           finalize=(not has_val and sample_idx == iter_len))

        # Skip the incomplete final batch
        if sample_idx == iter_len:
@@ -465,11 +479,12 @@ def _train_with_dataloader(self,
            input = input.to(self._device)
            mask = None

            label = label.to(self._device)
            output = self(input, masks=mask)

            if masking_flag:
                # Keep the N class dimension intact
                output = output[mask.squeeze(-1), :]
                label = label[mask.squeeze(-1), :]

# Accumulate outputs and labels
@@ -478,7 +493,9 @@

            # Optimize network on batch
            aggr_outputs, \
            aggr_labels = self._optimize_batch(outputs=aggr_outputs,
                                               labels=aggr_labels,
                                               finalize=(not has_val and iter_len == idx))

            # Skip the incomplete final batch
            if idx == self._sample_size:
@@ -492,27 +509,49 @@ def _optimize_batch(self, outputs: list, labels: list, finalize: bool = False):

        if self._sample_count >= self._batch_size:
            # Concatenate accumulated outputs and labels
            if self._target_repl_coef:
                # Weight intermediate timesteps by the replication coefficient;
                # the final timestep keeps full weight
                replication_weights = [
                    torch.cat([
                        torch.ones(output.shape[1] - 1, output.shape[2]) * self._target_repl_coef,
                        torch.ones(1, output.shape[2])
                    ]) for output in outputs
                ]

                replication_weights = torch.cat(replication_weights,
                                                axis=0).to(self._device).squeeze()
                # Replicate each sequence-level label across all timesteps
                labels = [
                    label.unsqueeze(1).expand(-1, output.shape[1], -1)
                    for label, output in zip(labels, outputs)
                ]
                labels = torch.cat(labels, axis=1).squeeze()
                outputs = torch.cat(outputs, axis=1).squeeze()
            elif self._task == "binary":
                # T
                outputs = torch.cat(outputs, axis=1).view(-1)
            else:
                # T, N or B, T, N
                outputs = torch.cat(outputs, axis=0).squeeze()
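The unsqueeze/expand pair above broadcasts each sequence-level label across all timesteps without copying memory. A small illustration with made-up shapes:

label = torch.tensor([[1.]])                       # (B=1, N=1) sequence-level target
replicated = label.unsqueeze(1).expand(-1, 4, -1)  # (1, T=4, 1): same target at each step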

            # If multilabel, labels are one-hot, else they are sparse;
            # with target replication the labels were already built above
            if not self._target_repl_coef:
                if self._task == "multilabel":
                    # T, N
                    labels = torch.stack(labels).squeeze()
                else:
                    # T
                    labels = torch.cat(labels).view(-1)
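A quick illustration of the two label formats the comment refers to (illustrative values):

# Multilabel: multi-hot vectors per timestep, shape (T, N)
one_hot = torch.tensor([[1., 0., 1.],
                        [0., 1., 0.]])
# Binary/multiclass: sparse targets flattened to shape (T,)
sparse = torch.tensor([[0.], [1.], [1.]]).view(-1)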

            # Compute loss
            if self._target_repl_coef:
                # Per-element loss weighted by the replication scheme, then averaged
                loss = self._loss(outputs, labels)
                loss = (loss * replication_weights).mean()
            else:
                loss = self._loss(outputs, labels)
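How the replication weights come out for a single sample, under assumed shapes (one sample, T = 4 timesteps, N = 1 output, coefficient 0.5):

output = torch.rand(1, 4, 1)
w = torch.cat([
    torch.ones(output.shape[1] - 1, output.shape[2]) * 0.5,  # intermediate steps at coef 0.5
    torch.ones(1, output.shape[2])                           # final step at full weight
])
print(w.squeeze())  # tensor([0.5000, 0.5000, 0.5000, 1.0000])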

            # Backward pass and optimization
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()
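The surrounding loop feeds samples one at a time and only steps the optimizer once a full batch has accumulated. A minimal standalone sketch of that pattern (model, criterion, optimizer, and batch_size are hypothetical placeholders):

buffer_out, buffer_lab = [], []
for x, y in samples:                       # one variable-length sample at a time
    buffer_out.append(model(x))
    buffer_lab.append(y)
    if len(buffer_out) >= batch_size:      # full batch collected
        loss = criterion(torch.cat(buffer_out), torch.cat(buffer_lab))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        buffer_out, buffer_lab = [], []    # reset the accumulators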
@@ -678,25 +717,48 @@ def _evaluate_batch(self,

        if self._sample_count >= self._batch_size:
            # Concatenate accumulated outputs and labels
            if self._target_repl_coef:
                # Weight intermediate timesteps by the replication coefficient;
                # the final timestep keeps full weight
                replication_weights = [
                    torch.cat([
                        torch.ones(output.shape[1] - 1, output.shape[2]) * self._target_repl_coef,
                        torch.ones(1, output.shape[2])
                    ]) for output in outputs
                ]

                replication_weights = torch.cat(replication_weights,
                                                axis=0).to(self._device).squeeze()
                # Replicate each sequence-level label across all timesteps
                labels = [
                    label.unsqueeze(1).expand(-1, output.shape[1], -1)
                    for label, output in zip(labels, outputs)
                ]
                labels = torch.cat(labels, axis=1).squeeze()
                outputs = torch.cat(outputs, axis=1).squeeze()
            elif self._task == "binary":
                # T
                outputs = torch.cat(outputs, axis=outputs[0].dim() - 1).view(-1)
            else:
                # T, N or B, T, N
                outputs = torch.cat(outputs, axis=0).squeeze()

            # If multilabel, labels are one-hot, else they are sparse
            if self._task == "multilabel":
                # T, N
                labels = torch.stack(labels).squeeze()
            elif not self._target_repl_coef:
                # T
                labels = torch.cat(labels).view(-1)

            # Compute loss
            if self._target_repl_coef:
                # Per-element loss weighted by the replication scheme, then averaged
                loss = self._loss(outputs, labels)
                loss = (loss * replication_weights).mean()
            else:
                loss = self._loss(outputs, labels)
            # Reset count
            # TODO: update this logic
            self._update_metrics(loss,
13 changes: 8 additions & 5 deletions src/models/pytorch/lstm.py
@@ -21,8 +21,12 @@ def __init__(self,
                 final_activation: str = None,
                 output_dim: int = 1,
                 depth: int = 1,
                 target_repl_coef: float = 0.,
                 model_path: Path = None):
        super().__init__(final_activation=final_activation,
                         output_dim=output_dim,
                         model_path=model_path,
                         target_repl_coef=target_repl_coef)

        self._layer_size = layer_size
        self._dropout_rate = dropout
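A hedged usage sketch of the new constructor argument. The class name and any arguments not visible in this diff are assumptions, not the project's confirmed API:

model = LSTMNetwork(layer_size=64,
                    dropout=0.3,
                    final_activation="sigmoid",
                    output_dim=1,
                    depth=1,
                    target_repl_coef=0.5,   # enable target replication
                    model_path=Path("checkpoints/lstm"))
model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=metrics)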
@@ -97,18 +101,17 @@ def forward(self, x, masks=None) -> torch.Tensor:
            if self._task == "binary":
                # Concatenate along the time axis
                x = torch.cat(outputs, dim=1)
                if len(x.shape) < 3:
                    x = x.unsqueeze(-1)
            else:
                # Stack along the batch axis
                x = torch.cat(outputs, dim=0)
                if len(x.shape) < 3:
                    x = x.unsqueeze(0)

        else:
            # Only return the last prediction
            x = x[:, -1, :]
            x = self._output_layer(x).unsqueeze(-1)

        if self._final_activation and self._apply_activation:
            x = self._final_activation(x)
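For intuition, a standalone sketch (independent of this class, with made-up dimensions) of producing a prediction at every timestep, which is what the target-replication path consumes, versus only at the last step:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)

x = torch.rand(2, 5, 8)                  # B=2, T=5, F=8
hidden, _ = lstm(x)                      # (B, T, H): one hidden state per step
per_step = torch.sigmoid(head(hidden))   # (B, T, 1): prediction at every timestep
last_only = per_step[:, -1, :]           # (B, 1): standard last-step prediction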
(diffs for the remaining two changed files not loaded)
