-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_tf_norm.py
18 lines (14 loc) · 888 Bytes
/
test_tf_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import config
import tensorflow as tf
def _build_norm_table(values, default_value):
    """Build an int64 -> int64 HashTable keyed by ``config.local_norm_key``.

    Args:
        values: Values paired element-wise with ``config.local_norm_key``.
        default_value: Value ``lookup`` returns for keys not in the table.

    Returns:
        A ``tf.contrib.lookup.HashTable``. It must be initialized with
        ``tf.tables_initializer()`` before its lookups can be evaluated.
    """
    initializer = tf.contrib.lookup.KeyValueTensorInitializer(
        tf.constant(config.local_norm_key, dtype=tf.int64),
        tf.constant(values, dtype=tf.int64),
    )
    return tf.contrib.lookup.HashTable(initializer, default_value)


# Both tables share the same keys; on a miss the "l" table yields 0 and the
# "r" table yields 1.
lnorm_table = _build_norm_table(config.local_norm_lvalues, 0)
rnorm_table = _build_norm_table(config.local_norm_rvalues, 1)

# Look up class 0 in both tables and pack the two scalar results into a
# length-2 vector [lnorm, rnorm].
class_id = tf.convert_to_tensor(0, dtype=tf.int64)
norm_gather_ind = tf.stack(
    [lnorm_table.lookup(class_id), rnorm_table.lookup(class_id)], axis=-1
)

# Use the session as a context manager so it is closed (and its resources
# released) once the result has been printed.
with tf.Session() as sess:
    # No tf.Variable is created in this script, so this initializer does no
    # real work; it is kept for parity with the original script.
    init = tf.global_variables_initializer()
    sess.run(init)
    # Lookup tables are not variables and need their own initializer.
    sess.run(tf.tables_initializer())
    print(sess.run(norm_gather_ind))