Skip to content
This repository has been archived by the owner on May 19, 2022. It is now read-only.

Commit

Permalink
Merge pull request #33 from will1792/more_options
Browse files Browse the repository at this point in the history
Add new option `echeck`, `fcheck`, and `stddev`
  • Loading branch information
Nanco-L authored Oct 11, 2018
2 parents 07eee02 + 68b69dd commit 0900e23
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion simple_nn/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.6"
__version__ = "0.5.0"
13 changes: 9 additions & 4 deletions simple_nn/models/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def __init__(self):
'intra_op_parallelism_threads': 0,
'print_structure_rmse': False,
'cache': False,
'stddev': 0.3,
'echeck': True,
'fcheck': True,
}
}
self.inputs = dict()
Expand Down Expand Up @@ -93,10 +96,11 @@ def _make_model(self):
else:
dtype = tf.float32

# TODO: input validation for stddev.
dense_basic_setting = {
'dtype': dtype,
'kernel_initializer': tf.initializers.truncated_normal(stddev=0.3, dtype=dtype),
'bias_initializer': tf.initializers.truncated_normal(stddev=0.3, dtype=dtype)
'kernel_initializer': tf.initializers.truncated_normal(stddev=self.inputs['stddev'], dtype=dtype),
'bias_initializer': tf.initializers.truncated_normal(stddev=self.inputs['stddev'], dtype=dtype)
}
dense_last_setting = copy.deepcopy(dense_basic_setting)

Expand Down Expand Up @@ -637,8 +641,9 @@ def train(self, user_optimizer=None, user_atomic_weights_function=None):

# Temp saving
#if (epoch+1) % self.inputs['save_interval'] == 0:
if save_stack > self.inputs['save_interval'] and prev_eloss > eloss and \
((prev_floss > floss) or floss == 0.):
if save_stack > self.inputs['save_interval'] and \
(prev_eloss > eloss or not self.inputs['echeck']) and \
(prev_floss > floss or not self.inputs['fcheck'] or floss == 0.):
temp_time = timeit.default_timer()
self._save(sess, saver)
prev_eloss = eloss
Expand Down

0 comments on commit 0900e23

Please sign in to comment.