Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
Germey committed Feb 2, 2019
0 parents commit 60f75e9
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 0 deletions.
111 changes: 111 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.idea/
model_zoo
scores/
# C extensions
*.so
checkpoints/
notebooks/.ipynb_checkpoints
debug.log
# Distribution / packaging
.Python
build/
examples/checkpoints/
examples/events/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
*.pyc
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache

nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
events/
checkpoints/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

79 changes: 79 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# TextClassification

TextClassification Model implemented by [ModelZoo](https://github.com/ModelZoo/ModelZoo).


## Installation

Firstly you need to clone this repository and install dependencies with pip:

```
pip3 install -r requirements.txt
```

## Dataset

We use IMDB dataset for example.

## Usage

We can run this model like this:

```
python3 train.py
```

Outputs like this:

```
...
```

It runs only 42 epochs and stopped early, because there are no more good evaluation results for 20 epochs.

When finished, we can find two folders generated named `checkpoints` and `events`.

Go to `events` and run TensorBoard:

```
cd events
tensorboard --logdir=.
```

TensorBoard like this:

![](https://ws4.sinaimg.cn/large/006tNbRwgy1fvxrcajse2j31kw0hkgnf.jpg)

There are training batch loss, epoch loss, eval loss.

And also we can find checkpoints in `checkpoints` dir.

It saved the best model named `model.ckpt` according to eval score, and it also saved checkpoints every 2 epochs.

Next we can predict using existing checkpoints and `infer.py`.

Now we've restored the specified model `model.ckpt-38` and prepared test data, outputs like this:

```python
[[ 9.637125 ]
[21.368305 ]
[20.898445 ]
[33.832504 ]
[25.756516 ]
[21.264557 ]
[29.069794 ]
[24.968184 ]
...
[36.027283 ]
[39.06852 ]
[25.728745 ]
[41.62165 ]
[34.340042 ]
[24.821484 ]]
```

OK, we've finished restoring and predicting. Just so quickly.

## License

MIT
23 changes: 23 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from model_zoo.model import BaseModel
import tensorflow as tf


class DenseModel(BaseModel):
def __init__(self, config):
super(DenseModel, self).__init__(config)
self.embedding = tf.keras.layers.Embedding(config['vocab_size'], config['embedding_size'])
self.pool = tf.keras.layers.GlobalAveragePooling1D()
self.dense1 = tf.keras.layers.Dense(16, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)

def call(self, inputs, training=None, mask=None):
o = self.embedding(inputs)
o = self.pool(o)
o = self.dense1(o)
o = self.dense2(o)
return o

def init(self):
self.compile(optimizer=tf.train.AdamOptimizer(),
loss='binary_crossentropy',
metrics=['accuracy'])
25 changes: 25 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
absl-py==0.5.0
astor==0.7.1
cycler==0.10.0
gast==0.2.0
grpcio==1.15.0
h5py==2.8.0
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
kiwisolver==1.0.1
Markdown==3.0.1
matplotlib==3.0.0
model-zoo
numpy>=1.15.2
pandas
protobuf==3.6.1
pyparsing==2.2.2
python-dateutil==2.7.3
scikit-learn==0.20.0
scipy>=1.1.0
six==1.11.0
sklearn
tensorboard>=1.11.0
tensorflow>=1.11.0
termcolor==1.1.0
Werkzeug==0.14.1
35 changes: 35 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import tensorflow as tf
from model_zoo.trainer import BaseTrainer
from tensorflow.python.keras.datasets import imdb
from tensorflow.python.keras.preprocessing.sequence import pad_sequences

tf.flags.DEFINE_integer('epochs', 50, 'Max epochs')
tf.flags.DEFINE_float('learning_rate', 0.001, 'Learning rate')
tf.flags.DEFINE_string('model_class', 'DenseModel', help='Model class name')
tf.flags.DEFINE_integer('vocab_size', 10000, help='Vocab size')
tf.flags.DEFINE_integer('embedding_size', 200, help='Embedding size')


class Trainer(BaseTrainer):

def build_word_index(self):
word_index = imdb.get_word_index()
word_index = {k: (v + 3) for k, v in word_index.items()}
word_index['<PAD>'] = 0
word_index['<START>'] = 1
word_index['<UNK>'] = 2
word_index['<UNUSED>'] = 3
return word_index

def prepare_data(self):
(x_train, y_train), (_, _) = imdb.load_data(num_words=self.flags.vocab_size)
word_index = self.build_word_index()
x_train = pad_sequences(x_train, maxlen=250, value=word_index['<PAD>'], padding='post')
(x_train, x_eval) = x_train[:20000], x_train[20000:]
(y_train, y_eval) = y_train[:20000], y_train[20000:]
train_data, eval_data = self.build_generator(x_train, y_train), self.build_generator(x_eval, y_eval)
return train_data, eval_data, len(x_train), len(x_eval)


if __name__ == '__main__':
Trainer().run()

0 comments on commit 60f75e9

Please sign in to comment.