diff --git a/README.md b/README.md index f8bb867..14cd9d8 100644 --- a/README.md +++ b/README.md @@ -17,21 +17,14 @@ Tested on: ## Usage -### Use pretrained model -Download pretrained model weights for TensorFlow backend: - -```sh -mkdir -p pretrained_models -wget -P pretrained_models https://www.dropbox.com/s/rf8hgoev8uqjv3z/weights.18-4.06.hdf5 -``` - -Run demo script (requires web cam) +### Use pretrained model for demo +Run demo the script (requires web cam) ```sh python3 demo.py ``` -Model weights for Theano backend is also available from [here](https://drive.google.com/file/d/0B_cG1nzvVZlQWGJMc2JjdzkwcVk/view?usp=sharing). +The pretrained model for TensorFlow backend will be automatically downloaded to the `pretrained_models` directory. ### Train a model using the IMDB-WIKI dataset diff --git a/demo.py b/demo.py index 4ef7d79..cbd3bd8 100644 --- a/demo.py +++ b/demo.py @@ -4,6 +4,10 @@ import numpy as np import argparse from wide_resnet import WideResNet +from keras.utils.data_utils import get_file + +pretrained_model = "https://www.dropbox.com/s/rf8hgoev8uqjv3z/weights.18-4.06.hdf5?dl=1" +modhash = '89f56a39a78454e96379348bddd78c0d' def get_args(): @@ -35,7 +39,8 @@ def main(): weight_file = args.weight_file if not weight_file: - weight_file = os.path.join("pretrained_models", "weights.18-4.06.hdf5") + weight_file = get_file("weights.18-4.06.hdf5", pretrained_model, cache_subdir="pretrained_models", + file_hash=modhash, cache_dir=os.path.dirname(os.path.abspath(__file__))) # for face detection detector = dlib.get_frontal_face_detector()