-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Working model test for tf2 lstm with deep supervision. but torch lear…
…ning too slow
- Loading branch information
ga84jog
committed
Jun 20, 2024
1 parent
e4dbb2d
commit 8d765c6
Showing
11 changed files
with
1,583 additions
and
1,648 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,43 @@ | ||
from models.tf2.mappings import metric_mapping | ||
from tensorflow import config | ||
from tensorflow.keras import Model | ||
from utils.IO import * | ||
|
||
import tensorflow as tf | ||
|
||
try: | ||
gpus = config.experimental.list_physical_devices('GPU') | ||
for gpu in gpus: | ||
config.experimental.set_memory_growth(gpu, True) | ||
except: | ||
warn_io("Could not set dynamic memory growth for GPUs. This may lead to memory errors.") | ||
|
||
from models.tf2.mappings import metric_mapping | ||
|
||
|
||
class AbstractTf2Model(Model): | ||
|
||
def compile(self, | ||
optimizer='rmsprop', | ||
loss=None, | ||
metrics=None, | ||
loss_weights=None, | ||
metrics=[], | ||
weighted_metrics=None, | ||
run_eagerly=False, | ||
steps_per_execution=1, | ||
jit_compile='auto', | ||
auto_scale_loss=True): | ||
for metric in metrics: | ||
if metric in metric_mapping: | ||
metrics[metrics.index(metric)] = metric_mapping[metric] | ||
run_eagerly=None, | ||
steps_per_execution=None, | ||
jit_compile=None, | ||
pss_evaluation_shards=0, | ||
**kwargs): | ||
if metrics is not None: | ||
for metric in metrics: | ||
if metric in metric_mapping: | ||
metrics[metrics.index(metric)] = metric_mapping[metric] | ||
super().compile(optimizer=optimizer, | ||
loss=loss, | ||
loss_weights=loss_weights, | ||
metrics=metrics, | ||
loss_weights=loss_weights, | ||
weighted_metrics=weighted_metrics, | ||
run_eagerly=run_eagerly, | ||
steps_per_execution=steps_per_execution, | ||
jit_compile=jit_compile) | ||
jit_compile=jit_compile, | ||
pss_evaluation_shards=pss_evaluation_shards, | ||
**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.