From 380d75670702b9fa8b51833f18aea80b043458d4 Mon Sep 17 00:00:00 2001 From: Yang Feng Date: Sun, 19 Feb 2017 17:26:52 -0500 Subject: [PATCH] add scope --- kaffe/tensorflow/network.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/kaffe/tensorflow/network.py b/kaffe/tensorflow/network.py index 6f3b153..bd645ba 100644 --- a/kaffe/tensorflow/network.py +++ b/kaffe/tensorflow/network.py @@ -1,5 +1,6 @@ import numpy as np import tensorflow as tf +from tensorflow.python.ops import variable_scope as vs DEFAULT_PADDING = 'SAME' @@ -44,6 +45,7 @@ def __init__(self, inputs, trainable=True): self.use_dropout = tf.placeholder_with_default(tf.constant(1.0), shape=[], name='use_dropout') + self.scope = vs.get_variable_scope() self.setup() def setup(self): @@ -57,15 +59,16 @@ def load(self, data_path, session, ignore_missing=False): ignore_missing: If true, serialized weights for missing layers are ignored. ''' data_dict = np.load(data_path).item() - for op_name in data_dict: - with tf.variable_scope(op_name, reuse=True): - for param_name, data in data_dict[op_name].iteritems(): - try: - var = tf.get_variable(param_name) - session.run(var.assign(data)) - except ValueError: - if not ignore_missing: - raise + with vs.variable_scope(self.scope): + for op_name in data_dict: + with tf.variable_scope(op_name, reuse=True): + for param_name, data in data_dict[op_name].iteritems(): + try: + var = tf.get_variable(param_name) + session.run(var.assign(data)) + except ValueError: + if not ignore_missing: + raise def feed(self, *args): '''Set the input(s) for the next operation by replacing the terminal nodes.