-
Notifications
You must be signed in to change notification settings - Fork 8
/
imdb.py
68 lines (56 loc) · 2.03 KB
/
imdb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import absolute_import
from six.moves import cPickle
import gzip
from six.moves import zip
import numpy as np
def _as_sequence_array(seqs):
    """Convert a list of integer sequences to a NumPy array.

    Sequences of differing lengths are stored in a 1-D ``object`` array
    (one Python list per element), because recent NumPy releases refuse
    to build a ragged array implicitly. Equal-length (or empty) input is
    converted with a plain ``np.array`` call.
    """
    if len(set(len(s) for s in seqs)) > 1:
        return np.array(seqs, dtype=object)
    return np.array(seqs)


def load_data(path="imdb.pkl", nb_words=None, skip_top=0,
              maxlen=None, test_split=0.2, seed=113,
              start_char=1, oov_char=2, index_from=3):
    """Load a pickled ``(sequences, labels)`` dataset and split it.

    The file at `path` must already exist locally (automatic download
    from https://s3.amazonaws.com/text-datasets/imdb.pkl is disabled).

    # Arguments
        path: pickle file to read; opened with gzip if it ends in ".gz".
        nb_words: word indices (after the `index_from` offset) that are
            >= `nb_words` are out of vocabulary. If falsy, it is set to
            the largest index present, so that largest index itself is
            still treated as out of vocabulary.
        skip_top: word indices < `skip_top` are out of vocabulary.
        maxlen: if truthy, only sequences whose length (including the
            start character, when one is added) is < `maxlen` are kept.
        test_split: fraction of the samples placed in the test set
            (taken from the end of the shuffled data).
        seed: seed used for shuffling. NOTE: this reseeds NumPy's
            *global* random state as a side effect.
        start_char: if not None, prepended to every sequence.
        oov_char: if not None, replaces every out-of-vocabulary word;
            if None, out-of-vocabulary words are dropped instead.
        index_from: offset added to every word index.

    # Returns
        `(X_train, y_train), (X_test, y_test)` as NumPy arrays.

    # Raises
        ValueError: if `maxlen` filters out every sequence.
    """
    opener = gzip.open if path.endswith(".gz") else open
    # SECURITY: unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with opener(path, 'rb') as f:
        X, labels = cPickle.load(f)

    # Re-seeding before each shuffle applies the same permutation to
    # both lists, keeping every sequence aligned with its label.
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(labels)

    if start_char is not None:
        X = [[start_char] + [w + index_from for w in x] for x in X]
    elif index_from:
        X = [[w + index_from for w in x] for x in X]

    if maxlen:
        kept = [(x, y) for x, y in zip(X, labels) if len(x) < maxlen]
        if not kept:
            raise ValueError('After filtering for sequences shorter than maxlen=' +
                             str(maxlen) + ', no sequence was kept. '
                             'Increase maxlen.')
        X = [x for x, _ in kept]
        labels = [y for _, y in kept]

    if not nb_words:
        nb_words = max([max(x) for x in X])

    # by convention, use 2 as OOV word
    # reserve 'index_from' (=3 by default) characters: 0 (padding), 1 (start), 2 (OOV)
    if oov_char is not None:
        X = [[w if skip_top <= w < nb_words else oov_char for w in x]
             for x in X]
    else:
        # No OOV marker: drop out-of-vocabulary words, keep the rest.
        X = [[w for w in x if skip_top <= w < nb_words] for x in X]

    split_at = int(len(X) * (1 - test_split))
    X_train = _as_sequence_array(X[:split_at])
    y_train = np.array(labels[:split_at])
    X_test = _as_sequence_array(X[split_at:])
    y_test = np.array(labels[split_at:])
    return (X_train, y_train), (X_test, y_test)