Skip to content

Commit

Permalink
make release-tag: Merge branch 'master' into stable
Browse files Browse the repository at this point in the history
  • Loading branch information
csala committed Jul 30, 2019
2 parents c98f492 + bfce1e8 commit f1d6d60
Show file tree
Hide file tree
Showing 11 changed files with 361 additions and 40 deletions.
7 changes: 7 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# History

## 0.2.2 (2019-07-30)

### New Features

* Curate dependencies - [Issue #152](https://github.com/HDI-Project/ATM/issues/152) by @csala
* POST request blocked by CORS policy - [Issue #151](https://github.com/HDI-Project/ATM/issues/151) by @pvk-developer

## 0.2.1 (2019-06-24)

### New Features
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<i>An open source project from Data to AI Lab at MIT.</i>
</p>

# ATM - Auto Tune Models


[![CircleCI](https://circleci.com/gh/HDI-Project/ATM.svg?style=shield)](https://circleci.com/gh/HDI-Project/ATM)
Expand All @@ -12,7 +13,7 @@
[![Downloads](https://pepy.tech/badge/atm)](https://pepy.tech/project/atm)


# ATM - Auto Tune Models


- License: MIT
- Documentation: https://HDI-Project.github.io/ATM/
Expand Down Expand Up @@ -143,7 +144,7 @@ For this demo we will be using the pollution csv from the atm-data bucket, which
[from here](https://atm-data.s3.amazonaws.com/pollution_1.csv), or using the following command:

```bash
wget https://atm-data.s3.amazonaws.com/pollution_1.csv
atm download_demo pollution_1.csv
```

## 2. Create an ATM instance
Expand Down
2 changes: 1 addition & 1 deletion atm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

__author__ = """MIT Data To AI Lab"""
__email__ = '[email protected]'
__version__ = '0.2.1'
__version__ = '0.2.2-dev'

# this defines which modules will be imported by "from atm import *"
__all__ = ['ATM', 'Model', 'config', 'constants', 'data', 'database',
Expand Down
5 changes: 4 additions & 1 deletion atm/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def create_app(atm, debug=False):
# Allow the CORS header
@app.after_request
def add_cors_headers(response):
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
Expand All @@ -28,7 +29,9 @@ def atm_run():
data = request.json
run_conf = RunConfig(data)

dataruns = atm.create_dataruns(run_conf)
dataruns = atm.add_datarun(**run_conf.to_dict())
if not isinstance(dataruns, list):
dataruns = [dataruns]

response = {
'status': 200,
Expand Down
31 changes: 27 additions & 4 deletions atm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from atm.api import create_app
from atm.config import AWSConfig, DatasetConfig, LogConfig, RunConfig, SQLConfig
from atm.core import ATM
from atm.data import copy_files, get_demos
from atm.data import copy_files, download_demo, get_demos

LOGGER = logging.getLogger(__name__)

Expand All @@ -25,7 +25,13 @@ def _get_atm(args):
aws_conf = AWSConfig(args)
log_conf = LogConfig(args)

return ATM(**sql_conf.to_dict(), **aws_conf.to_dict(), **log_conf.to_dict())
# Build params dictionary to pass to ATM.
# Needed because Python 2.7 does not support multiple star operators in a single statement.
atm_args = sql_conf.to_dict()
atm_args.update(aws_conf.to_dict())
atm_args.update(log_conf.to_dict())

return ATM(**atm_args)


def _work(args, wait=False):
Expand Down Expand Up @@ -209,7 +215,19 @@ def _make_config(args):


def _get_demos(args):
get_demos()
datasets = get_demos()
for dataset in datasets:
print(dataset)


def _download_demo(args):
    """CLI action: download the requested demo dataset(s) and report where they were saved.

    Args:
        args: parsed command line arguments; ``args.dataset`` and ``args.path``
            are forwarded to ``download_demo``.
    """
    saved = download_demo(args.dataset, args.path)

    # ``download_demo`` may hand back a single path instead of a list;
    # wrap it so that one loop handles both shapes.
    if not isinstance(saved, list):
        saved = [saved]

    for location in saved:
        print('Dataset has been saved to {}'.format(location))


def _get_parser():
Expand Down Expand Up @@ -330,8 +348,13 @@ def _get_parser():

# Get Demos
get_demos = subparsers.add_parser('get_demos', parents=[logging_args],
help='Create a demos folder and put the demo CSVs inside.')
help='Print a list with the available demo datasets.')
get_demos.set_defaults(action=_get_demos)
download_demo = subparsers.add_parser('download_demo', parents=[logging_args],
help='Downloads a demo dataset from AWS3.')
download_demo.set_defaults(action=_download_demo)
download_demo.add_argument('dataset', nargs='+', help='Name of the dataset to be downloaded.')
download_demo.add_argument('--path', help='Directory to be used to store the dataset.')

return parser

Expand Down
39 changes: 36 additions & 3 deletions atm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import boto3
import pandas as pd
import requests
from botocore import UNSIGNED
from botocore.client import Config
from botocore.exceptions import ClientError

LOGGER = logging.getLogger('atm')
Expand Down Expand Up @@ -66,9 +68,40 @@ def copy_files(extension, source, target=None):
return file_paths


def get_demos():
"""Copy the demo CSV files to the ``{cwd}/demos`` folder."""
return copy_files('csv', 'demos')
def download_demo(datasets, path=None):
    """Download one or more demo datasets from the ``atm-data`` S3 bucket.

    The bucket is accessed with an unsigned (anonymous) client. Datasets that
    fail to download with a ``ClientError`` are logged and skipped, so they
    are not included in the returned value.

    Args:
        datasets (str or list):
            Name, or list of names, of the objects to download.
        path (str):
            Directory where the files will be stored. If ``None``, the
            ``demos`` folder inside the current working directory is used.
            The directory is created if it does not exist.

    Returns:
        str or list:
            The local path when exactly one dataset was downloaded, otherwise
            the list of local paths (empty if nothing could be downloaded).
    """
    if not isinstance(datasets, list):
        datasets = [datasets]

    if path is None:
        path = os.path.join(os.getcwd(), 'demos')

    if not os.path.exists(path):
        os.makedirs(path)

    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))

    paths = list()

    for dataset in datasets:
        save_path = os.path.join(path, dataset)

        try:
            LOGGER.info('Downloading {}'.format(dataset))
            client.download_file('atm-data', dataset, save_path)
            paths.append(save_path)

        except ClientError as e:
            # Adjacent string literals are concatenated verbatim, so the
            # separating space has to be written explicitly.
            LOGGER.error('An error occurred trying to download from AWS3. '
                         'The following error has been returned: {}'.format(e))

    # Kept for backward compatibility: a single download returns a bare path.
    return paths[0] if len(paths) == 1 else paths


def get_demos(args=None):
    """Return the keys of the objects stored in the ``atm-data`` S3 bucket.

    The bucket is listed with an unsigned (anonymous) client.

    Args:
        args:
            Unused. Kept so existing callers that pass an argument keep working.

    Returns:
        list:
            Object keys found in the bucket listing; empty if the listing
            contains no objects.
    """
    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
    response = client.list_objects(Bucket='atm-data')

    # The ``Contents`` entry is absent from the response when no objects are
    # returned, so fall back to an empty list instead of raising ``KeyError``.
    # NOTE(review): a single ``list_objects`` call is not paginated here —
    # confirm the bucket stays within one page of results.
    return [obj['Key'] for obj in response.get('Contents', [])]


def _download_from_s3(path, local_path, aws_access_key=None, aws_secret_key=None, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.2.1
current_version = 0.2.2-dev
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+))?
Expand Down
43 changes: 20 additions & 23 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,25 @@
history = history_file.read()

install_requires = [
'baytune==0.2.5',
'boto3>=1.9.146',
'future>=0.16.0',
'joblib>=0.11',
'pymysql>=0.9.3',
'cryptography>=2.6.1',
'numpy>=1.13.1',
'pandas>=0.22.0',
'psutil>=5.6.1',
'python-daemon>=2.2.3',
'pyyaml>=3.12',
'requests>=2.18.4',
'scikit-learn>=0.18.2',
'scipy>=0.19.1',
'sklearn-pandas>=1.5.0',
'sqlalchemy>=1.1.14',
'flask>=1.0.2',
'flask-restless>=0.17.0',
'flask-sqlalchemy>=2.3.2',
'flask-restless-swagger-2>=0.0.3',
'simplejson>=3.16.0',
'tqdm>=4.31.1',
'baytune>=0.2.5,<0.3',
'boto3>=1.9.146,<2',
'future>=0.16.0,<0.18',
'pymysql>=0.9.3,<0.10',
'numpy>=1.13.1,<1.17',
'pandas>=0.22.0,<0.25',
'psutil>=5.6.1,<6',
'python-daemon>=2.2.3,<3',
'requests>=2.18.4,<3',
'scikit-learn>=0.18.2,<0.22',
'scipy>=0.19.1,<1.4',
'sqlalchemy>=1.1.14,<1.4',
'flask>=1.0.2,<2',
'flask-restless>=0.17.0,<0.18',
'flask-sqlalchemy>=2.3.2,<2.5',
'flask-restless-swagger-2==0.0.3',
'simplejson>=3.16.0,<4',
'tqdm>=4.31.1,<5',
'docutils>=0.10,<0.15',
]

setup_requires = [
Expand Down Expand Up @@ -113,6 +110,6 @@
test_suite='tests',
tests_require=tests_require,
url='https://github.com/HDI-project/ATM',
version='0.2.1',
version='0.2.2-dev',
zip_safe=False,
)
23 changes: 18 additions & 5 deletions tests/api/test___init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,38 +54,51 @@ def test_create_app_debug(atm):
assert app.config['DEBUG']


def test_home(client):
def test_get_home(client):
res = client.get('/', follow_redirects=False)

assert res.status == '302 FOUND'
assert res.location == 'http://localhost/static/swagger/swagger-ui/index.html'


def test_dataset(client):
def test_get_dataset(client):
res = client.get('api/datasets')
data = json.loads(res.data.decode('utf-8'))

assert res.status == '200 OK'
assert data.get('num_results') == 1


def test_datarun(client):
def test_options_dataset(client):
    """An OPTIONS request to the datasets endpoint must include the CORS headers."""
    response = client.options('api/datasets')

    required = {
        ('Content-Type', 'text/html; charset=utf-8'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization'),
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Credentials', 'true'),
    }
    received = set(response.headers.to_list())

    assert required.issubset(received)


def test_get_datarun(client):
res = client.get('api/dataruns')
data = json.loads(res.data.decode('utf-8'))

assert res.status == '200 OK'
assert data.get('num_results') == 2


def test_hyperpartition(client):
def test_get_hyperpartition(client):
res = client.get('api/hyperpartitions')
data = json.loads(res.data.decode('utf-8'))

assert res.status == '200 OK'
assert data.get('num_results') == 40


def test_classifier(client):
def test_get_classifier(client):
res = client.get('api/classifiers')
data = json.loads(res.data.decode('utf-8'))

Expand Down
94 changes: 94 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from mock import Mock, patch

from atm import cli


@patch('atm.cli.get_demos')
def test__get_demos(mock_get_demos):
    """``cli._get_demos`` must call ``get_demos`` exactly once, without arguments."""
    # The CLI action ignores its ``args`` parameter, so ``None`` is enough.
    cli._get_demos(None)

    mock_get_demos.assert_called_once_with()


@patch('atm.cli.download_demo')
def test__download_demo(mock_download_demo):
    """``cli._download_demo`` forwards a single dataset name and the path to ``download_demo``."""
    # Arguments as they would arrive for one dataset and no explicit path.
    cli_args = Mock(dataset='test.csv', path=None)

    cli._download_demo(cli_args)

    mock_download_demo.assert_called_once_with('test.csv', None)


@patch('atm.cli.download_demo')
def test__download_demo_array(mock_download_demo):
    """``cli._download_demo`` forwards a list of two datasets to ``download_demo``."""
    requested = ['test.csv', 'test2.csv']
    cli_args = Mock(dataset=requested, path=None)
    # Make the patched function return a list so the multi-path branch runs.
    mock_download_demo.return_value = ['test.csv', 'test2.csv']

    cli._download_demo(cli_args)

    mock_download_demo.assert_called_once_with(['test.csv', 'test2.csv'], None)


@patch('atm.cli.download_demo')
def test__download_demo_path(mock_download_demo):
    """``cli._download_demo`` forwards an explicitly given path to ``download_demo``."""
    requested = ['test.csv', 'test2.csv']
    cli_args = Mock(dataset=requested, path='my_test_path')
    mock_download_demo.return_value = ['test.csv', 'test2.csv']

    cli._download_demo(cli_args)

    mock_download_demo.assert_called_once_with(['test.csv', 'test2.csv'], 'my_test_path')


@patch('atm.cli._get_atm')
def test__work(mock__get_atm):
    """``cli._work`` builds the ATM from the args and calls ``work`` with the expected keywords."""
    cli_args = Mock(dataruns=[1], total_time=[1], save_files=False, cloud_mode=False)

    cli._work(cli_args)

    # The ATM instance must be built from the very same args object.
    mock__get_atm.assert_called_once_with(cli_args)

    expected_kwargs = dict(
        datarun_ids=[1],
        choose_randomly=False,
        save_files=False,
        cloud_mode=False,
        total_time=[1],
        wait=False,
    )
    atm_instance = mock__get_atm.return_value
    atm_instance.work.assert_called_once_with(**expected_kwargs)


@patch('atm.cli.create_app')
@patch('atm.cli._get_atm')
def test__serve(mock__get_atm, mock_create_app):
    """``cli._serve`` creates the app from the ATM instance and runs it on the given host/port."""
    cli_args = Mock(debug=False, host='1.2.3', port='456')

    cli._serve(cli_args)

    mock__get_atm.assert_called_once_with(cli_args)

    # The app is created from the ATM instance and the debug flag...
    atm_instance = mock__get_atm.return_value
    mock_create_app.assert_called_once_with(atm_instance, False)

    # ...and then started with the host and port taken from the args.
    app = mock_create_app.return_value
    app.run.assert_called_once_with(host='1.2.3', port='456')
Loading

0 comments on commit f1d6d60

Please sign in to comment.