Skip to content

Commit

Permalink
Merge pull request #39 from NREL/gb/better_random_seed
Browse files Browse the repository at this point in the history
fixed random seed setting to be compatible with all tf 2 versions
  • Loading branch information
grantbuster authored Oct 13, 2022
2 parents 2e06723 + 916c4df commit 7352fe2
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion phygnn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Custom Neural Network Infrastructure.
"""
from abc import ABC, abstractmethod
import random
import os
import pickle
import pprint
Expand Down Expand Up @@ -241,9 +242,9 @@ def seed(s=0):
s : int
Random seed
"""
random.seed(s)
np.random.seed(s)
tf.random.set_seed(s)
tf.keras.utils.set_random_seed(s)

@classmethod
def get_val_split(cls, *args, shuffle=True, validation_split=0.2):
Expand Down
3 changes: 2 additions & 1 deletion phygnn/model_interfaces/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Base Model Interface
"""
from abc import ABC
import random
import copy
import pprint
import logging
Expand Down Expand Up @@ -446,10 +447,10 @@ def seed(s=0):
s : int
Random number generator seed
"""
random.seed(s)
np.random.seed(s)
if TF2:
tf.random.set_seed(s)
tf.keras.utils.set_random_seed(s)
else:
tf.random.set_random_seed(s)

Expand Down
2 changes: 1 addition & 1 deletion phygnn/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Physics Guided Neural Network version."""

__version__ = '0.0.21'
__version__ = '0.0.22'

0 comments on commit 7352fe2

Please sign in to comment.