Skip to content
This repository has been archived by the owner on May 19, 2022. It is now read-only.

Commit

Permalink
change save mechanism
Browse files · Browse the repository at this point in the history
  • Branch information
Nanco-L committed Aug 24, 2018
1 parent 4091c88 commit 4a70aa6
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions simple_nn/models/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,14 @@ def _save(self, sess, saver):
self.inputs['continue'] = True
self.parent.write_inputs()

self.parent.logfile.write("Save the weights and write the LAMMPS potential..\n")
cutline = '----------------------------------------------'
if self.inputs['use_force']:
cutline += '------------------------'

if not self.inputs['print_structure_rmse']:
self.parent.logfile.write(cutline + "\n")
self.parent.logfile.write("Save the weights and write the LAMMPS potential..\n")
self.parent.logfile.write(cutline + "\n")
saver.save(sess, './SAVER')
self._generate_lammps_potential(sess)

Expand Down Expand Up @@ -471,6 +478,10 @@ def train(self, user_optimizer=None, user_atomic_weights_function=None):
#options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
#run_metadata = tf.RunMetadata()

prev_eloss = float('inf')
prev_floss = float('inf')
save_stack = 1

if self.inputs['train']:
train_handle = sess.run(train_iter.string_handle())
train_fdict = {self.handle: train_handle}
Expand Down Expand Up @@ -550,6 +561,8 @@ def train(self, user_optimizer=None, user_atomic_weights_function=None):
# fil.write(chrome_trace)

# TODO: need to fix the calculation part for training loss
save_stack += self.inputs['show_interval']

result = "epoch {:7d}: ".format(sess.run(self.global_step)+1)

t_eloss, t_floss, t_str_eloss, t_str_floss, _, _, _, t_str_set = self._get_loss_for_print(
Expand Down Expand Up @@ -605,10 +618,14 @@ def train(self, user_optimizer=None, user_atomic_weights_function=None):
self.parent.logfile.write(result)

# Temp saving
if (epoch+1) % self.inputs['save_interval'] == 0:
#if (epoch+1) % self.inputs['save_interval'] == 0:
if save_stack > self.inputs['save_interval'] and prev_eloss > eloss and prev_floss > floss:
self._save(sess, saver)
prev_eloss = eloss
prev_floss = floss
save_stack = 1

self._save(sess, saver)
#self._save(sess, saver)

if self.inputs['test']:
test_handle = sess.run(test_iter.string_handle())
Expand Down

0 comments on commit 4a70aa6

Please sign in to comment.