From 8f7db900d6aad99a397432c56b5b60d11f1095db Mon Sep 17 00:00:00 2001
From: Andrius Senulis <andrius.senulis@datarobot.com>
Date: Thu, 28 Apr 2022 11:14:52 +0200
Subject: [PATCH] Enable project creation with a local training data file.
 Bump version to v0.0.3

---
 README.md                                 | 18 ++++++++++--------
 datarobot_provider/__init__.py            |  2 +-
 datarobot_provider/hooks/datarobot.py     |  6 +++++-
 datarobot_provider/operators/datarobot.py |  7 +++----
 requirements.txt                          |  2 +-
 setup.py                                  |  4 ++--
 tests/unit/operators/test_datarobot.py    | 10 +++-------
 7 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/README.md b/README.md
index 1060d813..4bbea348 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,9 @@ score predictions against the model deployment.
 
 ## Installation
 
-Prerequisites: [Apache Airflow](https://pypi.org/project/apache-airflow/)
+Prerequisites:
+- [Apache Airflow](https://pypi.org/project/apache-airflow/)
+- [DataRobot Python API client](https://pypi.org/project/datarobot/)
 
 Install the DataRobot provider:
 ```
@@ -35,7 +37,7 @@ Operators and sensors use parameters from the [config](https://airflow.apache.or
 which must be submitted when triggering the dag. Example config JSON with required parameters:
 
     {
-        "training_data": "s3-pre-signed-url-of-training-data",
+        "training_data": "s3-presigned-url-or-local-path-to-training-data",
         "project_name": "Project created from Airflow",
         "autopilot_settings": {
             "target": "readmitted"
@@ -73,10 +75,10 @@ in the `context["params"]` variable, e.g. getting a training data you would use
 
     Required config params:
 
-        training_data: str - pre-signed S3 URL to training dataset
+        training_data: str - pre-signed S3 URL or local path to training dataset
         project_name: str - project name
 
-    The `training_data` value must be a [pre-signed AWS S3 URL](https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html).
+    In case of an S3 input, the `training_data` value must be a [pre-signed AWS S3 URL](https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html).
 
     For more [project settings](https://datarobot-public-api-client.readthedocs-hosted.com/en/v2.27.1/autodoc/api_reference.html#project) see the DataRobot docs.
 
@@ -96,7 +98,7 @@ in the `context["params"]` variable, e.g. getting a training data you would use
 
     For more [autopilot settings](https://datarobot-public-api-client.readthedocs-hosted.com/en/v2.27.1/autodoc/api_reference.html#datarobot.models.Project.set_target) see the DataRobot docs.
 
-- `DeployModelOperator` - deploys a specified model to production and returns its ID
+- `DeployModelOperator` - deploys a specified model and returns the deployment ID
 
     Parameters:
 
@@ -108,7 +110,7 @@ in the `context["params"]` variable, e.g. getting a training data you would use
 
     For more [deployment settings](https://datarobot-public-api-client.readthedocs-hosted.com/en/v2.27.1/autodoc/api_reference.html#deployment) see the DataRobot docs.
 
-- `DeployRecommendedModelOperator` - deploys a recommended model to production and returns its ID
+- `DeployRecommendedModelOperator` - deploys a recommended model and returns the deployment ID
 
     Parameters:
 
@@ -149,13 +151,13 @@ in the `context["params"]` variable, e.g. getting a training data you would use
 
 ### [Sensors](https://github.com/datarobot/airflow-provider-datarobot/blob/main/datarobot_provider/sensors/datarobot.py)
 
-- `AutopilotCompleteSensor` - check whether the Autopilot has completed
+- `AutopilotCompleteSensor` - checks whether the Autopilot has completed
 
     Parameters:
 
         project_id: str - DataRobot project ID
 
-- `ScoringCompleteSensor` - check whether batch scoring has completed
+- `ScoringCompleteSensor` - checks whether batch scoring has completed
 
     Parameters:
 
diff --git a/datarobot_provider/__init__.py b/datarobot_provider/__init__.py
index d99fb37b..860fe6fe 100644
--- a/datarobot_provider/__init__.py
+++ b/datarobot_provider/__init__.py
@@ -10,7 +10,7 @@ def get_provider_info():
         "package-name": "airflow-provider-datarobot",
         "name": "DataRobot Airflow Provider",
         "description": "DataRobot Airflow provider.",
-        "versions": ["0.0.2"],
+        "versions": ["0.0.3"],
         "hook-class-names": ["datarobot_provider.hooks.datarobot.DataRobotHook"],  # Deprecated in >=2.2.0
         "connection-types": [
             {"hook-class-name": "datarobot_provider.hooks.datarobot.DataRobotHook", "connection-type": "http"}
diff --git a/datarobot_provider/hooks/datarobot.py b/datarobot_provider/hooks/datarobot.py
index 92c6ff79..b86c8998 100644
--- a/datarobot_provider/hooks/datarobot.py
+++ b/datarobot_provider/hooks/datarobot.py
@@ -34,7 +34,11 @@ def get_connection_form_widgets() -> Dict[str, Any]:
         from wtforms import PasswordField, StringField
 
         return {
-            "extra__http__endpoint": StringField(lazy_gettext('DataRobot endpoint URL'), widget=BS3TextFieldWidget()),
+            "extra__http__endpoint": StringField(
+                lazy_gettext('DataRobot endpoint URL'),
+                widget=BS3TextFieldWidget(),
+                default='https://app.datarobot.com/api/v2',
+            ),
             "extra__http__api_key": PasswordField(lazy_gettext('API Key'), widget=BS3PasswordFieldWidget()),
         }
 
diff --git a/datarobot_provider/operators/datarobot.py b/datarobot_provider/operators/datarobot.py
index cbffe6ee..319eb7bc 100644
--- a/datarobot_provider/operators/datarobot.py
+++ b/datarobot_provider/operators/datarobot.py
@@ -50,10 +50,9 @@ def execute(self, context: Dict[str, Any]) -> str:
         DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
 
         # Create DataRobot project
-        self.log.info("Creating training dataset in DataRobot AI Catalog")
-        dataset = dr.Dataset.create_from_url(context["params"]["training_data"])
-        self.log.info(f"Created dataset: dataset_id={dataset.id}")
-        project = dr.Project.create_from_dataset(dataset.id, project_name=context['params']['project_name'])
+        self.log.info("Creating DataRobot project")
+        # training_data may be a pre-signed URL to a file on S3 or a path to a local file
+        project = dr.Project.create(context["params"]["training_data"], context['params']['project_name'])
         self.log.info(f"Project created: project_id={project.id}")
         project.unsupervised_mode = context['params'].get('unsupervised_mode')
         project.use_feature_discovery = context['params'].get('use_feature_discovery')
diff --git a/requirements.txt b/requirements.txt
index 3e9e0112..14b25009 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 apache-airflow>=2.0
-datarobot>=2.27.1
+datarobot>=2.28.0
 flake8>=4.0.1
 mypy>=0.931
 pytest>=7.0.0
diff --git a/setup.py b/setup.py
index b1ab60fa..dbc4adf3 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
 """Perform the package airflow-provider-datarobot setup."""
 setup(
     name='airflow-provider-datarobot',
-    version="0.0.2",
+    version="0.0.3",
     description='DataRobot Airflow provider.',
     long_description=long_description,
     long_description_content_type='text/markdown',
@@ -25,7 +25,7 @@
     license='Apache License 2.0',
     packages=['datarobot_provider', 'datarobot_provider.hooks', 'datarobot_provider.sensors',
               'datarobot_provider.operators'],
-    install_requires=['apache-airflow>=2.0'],
+    install_requires=['apache-airflow>=2.0', 'datarobot>=2.28.0'],
     setup_requires=['setuptools', 'wheel'],
     author='Andrius Senulis',
     author_email='andrius.senulis@datarobot.com',
diff --git a/tests/unit/operators/test_datarobot.py b/tests/unit/operators/test_datarobot.py
index df2ecaaf..92f3c866 100644
--- a/tests/unit/operators/test_datarobot.py
+++ b/tests/unit/operators/test_datarobot.py
@@ -17,18 +17,15 @@
 
 
 def test_operator_create_project(mocker):
-    dataset_mock = mocker.Mock()
-    dataset_mock.id = "dataset-id"
-    create_dataset_mock = mocker.patch.object(dr.Dataset, "create_from_url", return_value=dataset_mock)
     project_mock = mocker.Mock()
     project_mock.id = "project-id"
-    create_project_mock = mocker.patch.object(dr.Project, "create_from_dataset", return_value=project_mock)
+    create_project_mock = mocker.patch.object(dr.Project, "create", return_value=project_mock)
 
     operator = CreateProjectOperator(task_id='create_project')
     project_id = operator.execute(
         context={
             "params": {
-                "training_data": "s3://path/to/training_data.csv",
+                "training_data": "/path/to/s3/or/local/file",
                 "project_name": "test project",
                 "unsupervised_mode": False,
                 "use_feature_discovery": False,
@@ -37,8 +34,7 @@ def test_operator_create_project(mocker):
     )
 
     assert project_id == "project-id"
-    create_dataset_mock.assert_called_with("s3://path/to/training_data.csv")
-    create_project_mock.assert_called_with("dataset-id", project_name="test project")
+    create_project_mock.assert_called_with("/path/to/s3/or/local/file", "test project")
 
 
 def test_operator_train_models(mocker):
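
Review note on the operator change: the rewritten `CreateProjectOperator.execute` hands `training_data` straight to `dr.Project.create`, whose `sourcedata` argument accepts a local file path, a file object, or a URL, so one call now covers both the S3 and local-file cases and the intermediate AI Catalog dataset is no longer needed. A minimal sketch of the underlying client call (the endpoint, token, path, and pre-signed URL below are placeholders, not values from the patch):

```python
import datarobot as dr

# Placeholders: in the provider, DataRobotHook configures the client from
# the Airflow connection instead.
dr.Client(endpoint="https://app.datarobot.com/api/v2", token="<api-key>")

# dr.Project.create accepts either form of sourcedata, so the operator
# needs no branching between S3 and local inputs.
project_local = dr.Project.create("/tmp/readmissions_train.csv", "Local file project")
project_s3 = dr.Project.create(
    "https://my-bucket.s3.amazonaws.com/train.csv?X-Amz-Signature=...",  # pre-signed URL
    "S3 project",
)
```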
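And an end-to-end sketch of how the new local-file path could be exercised from a DAG. This example is not part of the patch: the DAG ID, file path, and param defaults are made up; it assumes the provider's `TrainModelsOperator` and `AutopilotCompleteSensor` take a `project_id` argument as the README documents; and trigger-time config only reaches `context["params"]` when `dag_run_conf_overrides_params` is enabled in `airflow.cfg`.

```python
from datetime import datetime

from airflow.decorators import dag

from datarobot_provider.operators.datarobot import (
    CreateProjectOperator,
    TrainModelsOperator,
)
from datarobot_provider.sensors.datarobot import AutopilotCompleteSensor


@dag(
    dag_id="datarobot_local_file_pipeline",  # hypothetical name
    schedule_interval=None,  # triggered manually with a config JSON
    start_date=datetime(2022, 4, 28),
    catchup=False,
    params={
        # With this patch, training_data may be a local path instead of a
        # pre-signed S3 URL; the value is passed straight to dr.Project.create.
        "training_data": "/data/readmissions_train.csv",  # hypothetical path
        "project_name": "Project created from Airflow",
        "unsupervised_mode": False,
        "use_feature_discovery": False,
        "autopilot_settings": {"target": "readmitted"},
    },
)
def datarobot_local_file_pipeline():
    create_project = CreateProjectOperator(task_id="create_project")

    # execute() returns the project ID, so downstream tasks can consume it
    # through the operator's XCom output.
    train_models = TrainModelsOperator(
        task_id="train_models",
        project_id=create_project.output,
    )
    autopilot_complete = AutopilotCompleteSensor(
        task_id="autopilot_complete",
        project_id=create_project.output,
        poke_interval=60,
    )

    train_models >> autopilot_complete


dag = datarobot_local_file_pipeline()
```

Because `create_project.output` is an XComArg, the project-creation task is wired upstream of both consumers automatically; only the train-to-sensor edge needs declaring. A manual run could then supply real values, e.g. `airflow dags trigger -c '{"training_data": "/data/train.csv", "project_name": "June refresh"}' datarobot_local_file_pipeline`.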