Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(datasets): add wrapped dataset to store the response from a POST/PUT request #905

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions kedro-datasets/kedro_datasets/api/api_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
"""
from __future__ import annotations

import json as json_ # make pylint happy
from copy import deepcopy
from typing import Any
from typing import Any, Type

import requests
from kedro.io.core import AbstractDataset, DatasetError
from kedro.io.memory_dataset import MemoryDataset
from requests import Session, sessions
from requests.auth import AuthBase

import json as json_ # make pylint happy
from kedro_datasets.json import JSONDataset
from kedro_datasets.pickle.pickle_dataset import PickleDataset
from kedro_datasets.text import TextDataset


class APIDataset(AbstractDataset[None, requests.Response]):
"""``APIDataset`` loads/saves data from/to HTTP(S) APIs.
Expand Down Expand Up @@ -97,6 +102,8 @@ def __init__( # noqa: PLR0913
save_args: dict[str, Any] | None = None,
credentials: tuple[str, str] | list[str] | AuthBase | None = None,
metadata: dict[str, Any] | None = None,
extension: str | None = None,
wrapped_dataset: dict[str, Any] | None = None,
) -> None:
"""Creates a new instance of ``APIDataset`` to fetch data from an API endpoint.

Expand Down Expand Up @@ -155,6 +162,9 @@ def __init__( # noqa: PLR0913
}

self.metadata = metadata
self._extension = extension
self._wrapped_dataset_args = wrapped_dataset
self._wrapped_dataset = None

@staticmethod
def _convert_type(value: Any):
Expand All @@ -171,6 +181,8 @@ def _describe(self) -> dict[str, Any]:
# prevent auth from logging
request_args_cp = self._request_args.copy()
request_args_cp.pop("auth", None)
if self._extension:
request_args_cp["wrapped_dataset"] = self.wrapped_dataset._describe()
return request_args_cp

def _execute_request(self, session: Session) -> requests.Response:
Expand All @@ -184,10 +196,12 @@ def _execute_request(self, session: Session) -> requests.Response:

return response

def load(self) -> requests.Response:
def load(self) -> requests.Response | str | Any:
if self._request_args["method"] == "GET":
with sessions.Session() as session:
return self._execute_request(session)
elif self._request_args["method"] in ["PUT", "POST"] and self.wrapped_dataset is not None:
return self.wrapped_dataset.load()

raise DatasetError("Only GET method is supported for load")

Expand Down Expand Up @@ -222,13 +236,52 @@ def _execute_save_request(self, json_data: Any) -> requests.Response:
def save(self, data: Any) -> requests.Response: # type: ignore[override]
if self._request_args["method"] in ["PUT", "POST"]:
if isinstance(data, list):
return self._execute_save_with_chunks(json_data=data)

return self._execute_save_request(json_data=data)
response: requests.Response = self._execute_save_with_chunks(json_data=data)
else:
response: requests.Response = self._execute_save_request(json_data=data)

if self._wrapped_dataset is None:
return response
if self._extension == "json":
self.wrapped_dataset.save(response.json()) #TODO(npfp): expose json loads arguments
elif self._extension == "text":
self.wrapped_dataset.save(response.text)
elif self._extension:
self.wrapped_dataset.save(response)
Comment on lines +245 to +250
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be better to check on the self.wrapped_dataset type:

Suggested change
if self._extension == "json":
self.wrapped_dataset.save(response.json()) #TODO(npfp): expose json loads arguments
elif self._extension == "text":
self.wrapped_dataset.save(response.text)
elif self._extension:
self.wrapped_dataset.save(response)
elif isinstance(self.wrapped_dataset, JSONDataset):
self.wrapped_dataset.save(response.json()) #TODO(npfp): expose json loads arguments
elif isinstance(self.wrapped_dataset, TextDataset)
self.wrapped_dataset.save(response.text)
elif self._extension:
self.wrapped_dataset.save(response)

return response

raise DatasetError("Use PUT or POST methods for save")

def _exists(self) -> bool:
with sessions.Session() as session:
response = self._execute_request(session)
return response.ok

@property
def _nested_dataset_type(
self,
) -> Type[JSONDataset | PickleDataset | MemoryDataset]:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There might be a better way to do this.

if self._extension == "json":
return JSONDataset
elif self._extension == "text":
return TextDataset
elif self._extension == "pickle":
return PickleDataset
elif self._extension == "memory":
#I'm not sure we need this
return MemoryDataset
else:
raise DatasetError(
f"Unknown extension for WrappedDataset: {self._extension}"
)

@property
def wrapped_dataset(
self,
) -> JSONDataset | PickleDataset | MemoryDataset | None:
"""The wrapped dataset where response data is stored."""
if self._wrapped_dataset is None and self._extension is not None:
self._wrapped_dataset = self._nested_dataset_type(
**self._wrapped_dataset_args
)
return self._wrapped_dataset
Loading