Skip to content

Commit

Permalink
All tf2 tests passed
Browse files Browse the repository at this point in the history
  • Loading branch information
ga84jog committed Sep 16, 2024
1 parent 0ec7d5e commit 4343483
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 28 deletions.
3 changes: 3 additions & 0 deletions src/models/tf2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def compile(self,
pss_evaluation_shards=pss_evaluation_shards,
**kwargs)

# Not well understood but this causes problems
self._metrics = list()

@tf.function
def train_step(self, data):
data = data_adapter.expand_1d(data)
Expand Down
26 changes: 8 additions & 18 deletions tests/test_models/test_tf2/test_tf2_cwlstm_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
"pr_auc": 0.9,
},
"LOS": {
"loss": np.inf,
"cohen_kappa": -np.inf,
"custom_mae": -np.inf
"loss": 0.3,
"cohen_kappa": 0.85,
"custom_mae": 3
},
"PHENO": {
"loss": 0.5,
Expand All @@ -50,11 +50,6 @@
}

TARGET_METRICS_DS = {
"IHM": {
"loss": np.nan,
"roc_auc": np.nan,
"pr_auc": np.nan
},
"DECOMP": {
"loss": 0.2,
"roc_auc": 0.8,
Expand All @@ -64,11 +59,6 @@
"loss": 2,
"cohen_kappa": 0.25,
"custom_mae": 110
},
"PHENO": {
"loss": 0.5,
"micro_roc_auc": 0.75,
"macro_roc_auc": 0.7
}
}

Expand Down Expand Up @@ -118,7 +108,7 @@

@pytest.mark.parametrize("data_flavour", ["generator", "numpy"])
@pytest.mark.parametrize("task_name", ["DECOMP", "LOS"])
# @retry(3)
@retry(3)
def test_tf2_cwlstm_with_deep_supvervision(
task_name: str,
data_flavour: str,
Expand Down Expand Up @@ -198,7 +188,7 @@ def test_tf2_cwlstm_with_deep_supvervision(

@pytest.mark.parametrize("data_flavour", ["generator", "numpy"])
@pytest.mark.parametrize("task_name", ["DECOMP", "LOS"])
# @retry(3)
@retry(3)
def test_tf2_cwlstm(
task_name: str,
data_flavour: str,
Expand Down Expand Up @@ -297,8 +287,8 @@ def assert_model_performance(history, task: str, target_metrics: Dict[str, float

if __name__ == "__main__":
disc_reader = dict()
for task_name in ["LOS", "PHENO"]: #["IHM", "DECOMP", "LOS", "PHENO"]:
'''
for task_name in ["IHM", "DECOMP", "LOS", "PHENO"]:

reader = datasets.load_data(chunksize=75836,
source_path=TEST_DATA_DEMO,
storage_path=SEMITEMP_DIR,
Expand All @@ -317,7 +307,7 @@ def assert_model_performance(history, task: str, target_metrics: Dict[str, float
deep_supervision=True,
impute_strategy='previous',
task=task_name)
'''

reader = ProcessedSetReader(Path(SEMITEMP_DIR, "discretized", task_name))
dataset = reader.to_numpy()
for flavour in ["generator", "numpy"]:
Expand Down
10 changes: 0 additions & 10 deletions tests/test_models/test_tf2/test_tf2_lstm_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@
}

TARGET_METRICS_DS = {
"IHM": {
"loss": np.nan,
"roc_auc": np.nan,
"pr_auc": np.nan
},
"DECOMP": {
"loss": 0.2,
"roc_auc": 0.8,
Expand All @@ -64,11 +59,6 @@
"loss": 2,
"cohen_kappa": 0.25,
"custom_mae": 110
},
"PHENO": {
"loss": 0.5,
"micro_roc_auc": 0.75,
"macro_roc_auc": 0.7
}
}

Expand Down

0 comments on commit 4343483

Please sign in to comment.