
Commit

Merge pull request #3 from datarobot/andrius/add-default-endpoint-url
Enable project creation with a local training data file. Bump version to v0.0.3
andrius-senulis authored Apr 28, 2022
2 parents 6253854 + 8f7db90 commit c7efa43
Showing 7 changed files with 25 additions and 24 deletions.
18 changes: 10 additions & 8 deletions README.md
@@ -7,7 +7,9 @@ score predictions against the model deployment.

## Installation

-Prerequisites: [Apache Airflow](https://pypi.org/project/apache-airflow/)
+Prerequisites:
+- [Apache Airflow](https://pypi.org/project/apache-airflow/)
+- [DataRobot Python API client](https://pypi.org/project/datarobot/)

Install the DataRobot provider:
```
@@ -35,7 +37,7 @@ Operators and sensors use parameters from the [config](https://airflow.apache.or
which must be submitted when triggering the dag. Example config JSON with required parameters:

{
"training_data": "s3-pre-signed-url-of-training-data",
"training_data": "s3-presigned-url-or-local-path-to-training-data",
"project_name": "Project created from Airflow",
"autopilot_settings": {
"target": "readmitted"
@@ -73,10 +75,10 @@ in the `context["params"]` variable, e.g. getting a training data you would use

Required config params:

-training_data: str - pre-signed S3 URL to training dataset
+training_data: str - pre-signed S3 URL or local path to training dataset
project_name: str - project name

-The `training_data` value must be a [pre-signed AWS S3 URL](https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html).
+In case of an S3 input, the `training_data` value must be a [pre-signed AWS S3 URL](https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html).

For more [project settings](https://datarobot-public-api-client.readthedocs-hosted.com/en/v2.27.1/autodoc/api_reference.html#project) see the DataRobot docs.

@@ -96,7 +98,7 @@ in the `context["params"]` variable, e.g. getting a training data you would use

For more [autopilot settings](https://datarobot-public-api-client.readthedocs-hosted.com/en/v2.27.1/autodoc/api_reference.html#datarobot.models.Project.set_target) see the DataRobot docs.

-- `DeployModelOperator` - deploys a specified model to production and returns its ID
+- `DeployModelOperator` - deploys a specified model and returns the deployment ID

Parameters:

@@ -108,7 +110,7 @@ in the `context["params"]` variable, e.g. getting a training data you would use

For more [deployment settings](https://datarobot-public-api-client.readthedocs-hosted.com/en/v2.27.1/autodoc/api_reference.html#deployment) see the DataRobot docs.

-- `DeployRecommendedModelOperator` - deploys a recommended model to production and returns its ID
+- `DeployRecommendedModelOperator` - deploys a recommended model and returns the deployment ID

Parameters:

@@ -149,13 +151,13 @@ in the `context["params"]` variable, e.g. getting a training data you would use

### [Sensors](https://github.com/datarobot/airflow-provider-datarobot/blob/main/datarobot_provider/sensors/datarobot.py)

-- `AutopilotCompleteSensor` - check whether the Autopilot has completed
+- `AutopilotCompleteSensor` - checks whether the Autopilot has completed

Parameters:

project_id: str - DataRobot project ID

-- `ScoringCompleteSensor` - check whether batch scoring has completed
+- `ScoringCompleteSensor` - checks whether batch scoring has completed

Parameters:

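As a rough illustration of how these operators and sensors fit together once this change lands, here is a minimal DAG sketch. `CreateProjectOperator` and `AutopilotCompleteSensor` appear in the diff above; the `TrainModelsOperator` name and the XCom templating of `project_id` are assumptions based on the test module and common provider conventions, not something this commit confirms.

```
from datetime import datetime

from airflow import DAG
from datarobot_provider.operators.datarobot import CreateProjectOperator, TrainModelsOperator
from datarobot_provider.sensors.datarobot import AutopilotCompleteSensor

with DAG(
    dag_id="datarobot_pipeline",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
    # Default params; real values come from the config JSON submitted when triggering the DAG.
    params={
        "training_data": "/path/to/local/training_data.csv",  # or a pre-signed S3 URL
        "project_name": "Project created from Airflow",
        "autopilot_settings": {"target": "readmitted"},
        "unsupervised_mode": False,
        "use_feature_discovery": False,
    },
) as dag:
    create_project = CreateProjectOperator(task_id="create_project")

    # TrainModelsOperator and the project_id wiring below are assumptions:
    # CreateProjectOperator.execute() returns the project ID, so downstream
    # tasks can pull it from XCom via templating.
    train_models = TrainModelsOperator(
        task_id="train_models",
        project_id="{{ ti.xcom_pull(task_ids='create_project') }}",
    )
    autopilot_complete = AutopilotCompleteSensor(
        task_id="autopilot_complete",
        project_id="{{ ti.xcom_pull(task_ids='create_project') }}",
    )

    create_project >> train_models >> autopilot_complete
```
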
2 changes: 1 addition & 1 deletion datarobot_provider/__init__.py
@@ -10,7 +10,7 @@ def get_provider_info():
"package-name": "airflow-provider-datarobot",
"name": "DataRobot Airflow Provider",
"description": "DataRobot Airflow provider.",
"versions": ["0.0.2"],
"versions": ["0.0.3"],
"hook-class-names": ["datarobot_provider.hooks.datarobot.DataRobotHook"], # Deprecated in >=2.2.0
"connection-types": [
{"hook-class-name": "datarobot_provider.hooks.datarobot.DataRobotHook", "connection-type": "http"}
6 changes: 5 additions & 1 deletion datarobot_provider/hooks/datarobot.py
@@ -34,7 +34,11 @@ def get_connection_form_widgets() -> Dict[str, Any]:
from wtforms import PasswordField, StringField

return {
"extra__http__endpoint": StringField(lazy_gettext('DataRobot endpoint URL'), widget=BS3TextFieldWidget()),
"extra__http__endpoint": StringField(
lazy_gettext('DataRobot endpoint URL'),
widget=BS3TextFieldWidget(),
default='https://app.datarobot.com/api/v2',
),
"extra__http__api_key": PasswordField(lazy_gettext('API Key'), widget=BS3PasswordFieldWidget()),
}

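For context on the widget change above, a minimal sketch of creating the corresponding Airflow connection programmatically. The extra-field keys come from the form widgets in this hunk; the connection ID `datarobot_default` and the session-based registration are assumptions for illustration only.

```
import json

from airflow.models import Connection
from airflow.settings import Session

# Sketch only: an HTTP-type connection carrying the DataRobot endpoint and API key
# in the extras that DataRobotHook's form widgets expose.
conn = Connection(
    conn_id="datarobot_default",  # assumed connection ID
    conn_type="http",
    extra=json.dumps(
        {
            "extra__http__endpoint": "https://app.datarobot.com/api/v2",  # new default shown above
            "extra__http__api_key": "your-api-key",  # placeholder
        }
    ),
)
session = Session()
session.add(conn)
session.commit()
```
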
7 changes: 3 additions & 4 deletions datarobot_provider/operators/datarobot.py
@@ -50,10 +50,9 @@ def execute(self, context: Dict[str, Any]) -> str:
DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()

# Create DataRobot project
self.log.info("Creating training dataset in DataRobot AI Catalog")
dataset = dr.Dataset.create_from_url(context["params"]["training_data"])
self.log.info(f"Created dataset: dataset_id={dataset.id}")
project = dr.Project.create_from_dataset(dataset.id, project_name=context['params']['project_name'])
self.log.info("Creating DataRobot project")
# training_data may be a pre-signed URL to a file on S3 or a path to a local file
project = dr.Project.create(context["params"]["training_data"], context['params']['project_name'])
self.log.info(f"Project created: project_id={project.id}")
project.unsupervised_mode = context['params'].get('unsupervised_mode')
project.use_feature_discovery = context['params'].get('use_feature_discovery')
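
To make the new call explicit: `dr.Project.create` accepts either a local file path or a URL (such as a pre-signed S3 URL) as its source data, which is what lets the operator drop the intermediate AI Catalog dataset. A minimal sketch with hypothetical credentials and paths:

```
import datarobot as dr

dr.Client(endpoint="https://app.datarobot.com/api/v2", token="your-api-key")  # hypothetical credentials

# Source data can be a local file...
project = dr.Project.create("/tmp/training_data.csv", project_name="Project created from Airflow")

# ...or a URL, e.g. a pre-signed S3 URL (placeholder shown here).
project = dr.Project.create(
    "https://my-bucket.s3.amazonaws.com/training_data.csv?X-Amz-Signature=...",
    project_name="Project created from Airflow",
)
print(project.id)
```
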
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
apache-airflow>=2.0
-datarobot>=2.27.1
+datarobot>=2.28.0
flake8>=4.0.1
mypy>=0.931
pytest>=7.0.0
4 changes: 2 additions & 2 deletions setup.py
@@ -13,7 +13,7 @@
"""Perform the package airflow-provider-datarobot setup."""
setup(
name='airflow-provider-datarobot',
version="0.0.2",
version="0.0.3",
description='DataRobot Airflow provider.',
long_description=long_description,
long_description_content_type='text/markdown',
@@ -25,7 +25,7 @@
license='Apache License 2.0',
packages=['datarobot_provider', 'datarobot_provider.hooks',
'datarobot_provider.sensors', 'datarobot_provider.operators'],
-install_requires=['apache-airflow>=2.0'],
+install_requires=['apache-airflow>=2.0', 'datarobot>=2.28.0'],
setup_requires=['setuptools', 'wheel'],
author='Andrius Senulis',
author_email='andrius.senulis@datarobot.com',
10 changes: 3 additions & 7 deletions tests/unit/operators/test_datarobot.py
@@ -17,18 +17,15 @@


def test_operator_create_project(mocker):
-dataset_mock = mocker.Mock()
-dataset_mock.id = "dataset-id"
-create_dataset_mock = mocker.patch.object(dr.Dataset, "create_from_url", return_value=dataset_mock)
project_mock = mocker.Mock()
project_mock.id = "project-id"
create_project_mock = mocker.patch.object(dr.Project, "create_from_dataset", return_value=project_mock)
create_project_mock = mocker.patch.object(dr.Project, "create", return_value=project_mock)

operator = CreateProjectOperator(task_id='create_project')
project_id = operator.execute(
context={
"params": {
"training_data": "s3://path/to/training_data.csv",
"training_data": "/path/to/s3/or/local/file",
"project_name": "test project",
"unsupervised_mode": False,
"use_feature_discovery": False,
@@ -37,8 +34,7 @@ def test_operator_create_project(mocker):
)

assert project_id == "project-id"
-create_dataset_mock.assert_called_with("s3://path/to/training_data.csv")
-create_project_mock.assert_called_with("dataset-id", project_name="test project")
+create_project_mock.assert_called_with("/path/to/s3/or/local/file", "test project")


def test_operator_train_models(mocker):
