Skip to content

Commit

Permalink
Passing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ga84jog committed Jul 29, 2024
1 parent c8aa268 commit 9258d8e
Show file tree
Hide file tree
Showing 10 changed files with 556 additions and 241 deletions.
2 changes: 1 addition & 1 deletion src/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __len__(self):
return self._steps

def __del__(self):
if self._cpu_count:
if hasattr(self, '_cpu_count') and self._cpu_count:
self._close()

def _create_workers(self):
Expand Down
6 changes: 3 additions & 3 deletions src/metrics/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@

class BinedMAE(MeanAbsoluteError):

def __init__(self, binning: Literal["log", "custom"], *args, **kwargs):
def __init__(self, bining: Literal["log", "custom"], *args, **kwargs):
super().__init__(*args, **kwargs)
self._binning = binning
self._binning = bining
if self._binning == "custom":
self._means = torch.tensor(CustomBins.means, dtype=torch.float32)
elif self._binning == "log":
self._means = torch.tensor(LogBins.means, dtype=torch.float32)
else:
raise ValueError(f"Binning must be one of 'log' or 'custom' but is {binning}.")
raise ValueError(f"Binning must be one of 'log' or 'custom' but is {bining}.")

def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
prediction_means = self._means[torch.argmax(preds, axis=1)]
Expand Down
18 changes: 8 additions & 10 deletions src/metrics/tf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,14 @@ def update_state(self, y_true, y_pred, sample_weight=None):
tf.print("y_true max:", tf.reduce_max(tf.reshape(y_true, [-1])))
tf.print("y_pred max:", tf.reduce_max(tf.reshape(y_pred, [-1])))

# Get the shapes
y_true_shape = tf.shape(y_true)
y_pred_shape = tf.shape(y_pred)

# Use tf.cond for more explicit control flow
y_true = tf.cond(tf.equal(y_true_shape[-1], 1), lambda: tf.squeeze(y_true, axis=-1),
lambda: y_true)

y_pred = tf.cond(tf.equal(y_pred_shape[-1], 1), lambda: tf.squeeze(y_pred, axis=-1),
lambda: y_pred)
def squeeze_equal_one(tensor):
shape = tf.shape(tensor)
is_one = tf.equal(shape, 1)
axes_to_squeeze = tf.where(is_one)[:, 0]
return tf.squeeze(tensor, axis=axes_to_squeeze.numpy().tolist())

y_true = squeeze_equal_one(y_true)
y_pred = squeeze_equal_one(y_pred)

if debug:
tf.print("---- after squeeze ----")
Expand Down
Loading

0 comments on commit 9258d8e

Please sign in to comment.