From 2bfd740dd9ca22824026ec7e2e101d8864da31ba Mon Sep 17 00:00:00 2001
From: edwardyehuang
Date: Mon, 7 Mar 2022 13:59:28 +0800
Subject: [PATCH] fix adamw

---
 core_optimizer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/core_optimizer.py b/core_optimizer.py
index 7885f2f..e13cefe 100644
--- a/core_optimizer.py
+++ b/core_optimizer.py
@@ -19,6 +19,7 @@ def get_optimizer(
     decay_strategy="poly",
     optimizer="sgd",
     sgd_momentum_rate=0.9,
+    adamw_weight_decay=0.0001,
 ):
 
     kwargs = {
@@ -31,6 +32,7 @@ def get_optimizer(
         "decay_strategy": decay_strategy,
         "optimizer": optimizer,
         "sgd_momentum_rate": sgd_momentum_rate,
+        "adamw_weight_decay": adamw_weight_decay,
     }
 
     keys = kwargs.keys()
@@ -98,6 +100,7 @@ def __get_optimizer(
     decay_strategy="poly",
     optimizer="sgd",
     sgd_momentum_rate=0.9,
+    adamw_weight_decay=0.0001,
 ):
 
     learning_rate = initial_lr
@@ -127,7 +130,7 @@ def __get_optimizer(
     elif optimizer == "amsgrad":
         _optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, amsgrad=True)
     elif optimizer == "adamw":
-        _optimizer = AdamW(weight_decay=0, learning_rate=learning_rate)
+        _optimizer = AdamW(weight_decay=adamw_weight_decay, learning_rate=learning_rate)
     else:
         raise ValueError(f"Unsupported optimizer {optimizer}")
 
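For context, a minimal usage sketch of the patched entry point follows. It is not part of the patch: only the keyword arguments visible in the diff are used, initial_lr is an inferred parameter name taken from the learning_rate = initial_lr assignment inside __get_optimizer, and AdamW is assumed to be an implementation with the signature AdamW(weight_decay=..., learning_rate=...) such as tensorflow_addons' tfa.optimizers.AdamW, consistent with the bare AdamW(...) constructor call in the last hunk.

    # Sketch only; assumptions noted above.
    from core_optimizer import get_optimizer

    optimizer = get_optimizer(
        initial_lr=0.007,           # assumed parameter name, inferred from the diff
        decay_strategy="poly",
        optimizer="adamw",
        adamw_weight_decay=0.0001,  # new argument; decay was previously hard-coded to 0
    )

Before this change, selecting "adamw" constructed AdamW with weight_decay=0, which degenerates to plain Adam; the new adamw_weight_decay argument makes the decoupled weight decay configurable and defaults to 1e-4.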