-
Notifications
You must be signed in to change notification settings - Fork 7
/
RadioMLtfrecorderRead.py
56 lines (48 loc) · 2.28 KB
/
RadioMLtfrecorderRead.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
# In[1]: read training and testing data from TFRecords files
def _decode_signal(raw_bytes):
    """Decode one serialized float64 signal into a float32 [1, 128, 2] tensor."""
    signal = tf.decode_raw(raw_bytes, tf.float64)
    signal = tf.reshape(signal, [-1, 128, 2])
    signal = tf.cast(signal, tf.float32)
    # The leading dimension is unknown after the reshape above; resizing to
    # [1, 128] yields a tensor whose static shape is fully defined
    # (presumably so it can be fed to tf.train.batch -- confirm).
    return tf.image.resize_images(signal, [1, 128])


def read_data(file_queue):
    """Build the graph ops that read and decode one example from a TFRecord queue.

    Args:
        file_queue: Queue of TFRecord filenames, as accepted by
            ``tf.TFRecordReader.read``.

    Returns:
        A tuple ``(sig_iq, sig_ap, label, snr)``:
        ``sig_iq`` and ``sig_ap`` are float32 tensors decoded from the
        ``'sig_iq'`` and ``'sig_ap'`` byte features; ``label`` and ``snr``
        are int32 scalars cast from the ``'label'`` and ``'snr'`` features.
    """
    reader = tf.TFRecordReader()
    # A single read op is enough; the record key is not used.
    _, serialized_example = reader.read(file_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'snr': tf.FixedLenFeature([], tf.int64),
            'sig_iq': tf.FixedLenFeature([], tf.string),
            'sig_ap': tf.FixedLenFeature([], tf.string),
        })
    sig_iq = _decode_signal(features['sig_iq'])
    sig_ap = _decode_signal(features['sig_ap'])
    label = tf.cast(features['label'], tf.int32)
    snr = tf.cast(features['snr'], tf.int32)
    return sig_iq, sig_ap, label, snr
def read_data_batch(file_queue, batch_size, num_threads=4):
    """Build ops that yield batches of decoded examples from a filename queue.

    Args:
        file_queue: Queue of TFRecord filenames, passed through to
            ``read_data``.
        batch_size: Number of examples per batch.
        num_threads: Number of threads enqueuing examples into the batching
            queue. Previously hard-coded to 1000, which spawns far more
            threads than one reader can keep busy.

    Returns:
        A tuple ``(sig_iq_batch, sig_ap_batch, label_batch)`` of batched
        tensors. The per-example ``snr`` value is not batched.
    """
    sig_iq, sig_ap, label, _ = read_data(file_queue)
    # Let the batching queue hold up to three batches' worth of examples.
    capacity = 3 * batch_size
    sig_iq_batch, sig_ap_batch, label_batch = tf.train.batch(
        [sig_iq, sig_ap, label],
        batch_size=batch_size,
        capacity=capacity,
        num_threads=num_threads)
    return sig_iq_batch, sig_ap_batch, label_batch
# string_input_producer is documented as taking a 1-D tensor of filenames,
# so each path is wrapped in a list rather than passed as a bare string.
train_data_filename_queue = tf.train.string_input_producer(['RMLtrainAll.tfrecords'])
train_sigiqs, train_sigaps, train_labels = read_data_batch(train_data_filename_queue, batch_size=1000)
test_data_filename_queue = tf.train.string_input_producer(['RMLtestAll.tfrecords'])
test_sigiqs, test_sigaps, test_labels = read_data_batch(test_data_filename_queue, 1000)

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # The input pipeline is queue-based: the queue-runner threads must be
    # started before any batch can be fetched.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        # Pull ten training batches as a smoke test of the reading pipeline.
        for i in range(10):
            exampleIQ, exampleAP, l = sess.run([train_sigiqs, train_sigaps, train_labels])
            # Flatten each batch to [N, 128, 2] signals and a 1-D label vector.
            exampleIQ = np.reshape(exampleIQ, [-1, 128, 2])
            exampleAP = np.reshape(exampleAP, [-1, 128, 2])
            l = np.reshape(l, [-1, ])
    finally:
        # Always stop and join the queue threads, even if a fetch fails.
        coord.request_stop()
        coord.join(threads)