diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py
index a4290e0b9b..ed5bc4a896 100644
--- a/dlt/sources/helpers/rest_client/paginators.py
+++ b/dlt/sources/helpers/rest_client/paginators.py
@@ -665,3 +665,72 @@ def __str__(self) -> str:
             super().__str__()
             + f": cursor_path: {self.cursor_path} cursor_param: {self.cursor_param}"
         )
+
+
+class HeaderCursorPaginator(BaseReferencePaginator):
+    """A paginator that uses a cursor in the HTTP response header
+    for pagination.
+
+    For example, consider an API response that includes 'NextPageToken' header:
+
+        ...
+        Content-Type: application/json
+        NextPageToken: 123456
+
+        [
+            {"id": 1, "name": "item1"},
+            {"id": 2, "name": "item2"},
+            ...
+        ]
+
+    In this scenario, the parameter to construct the URL for the next page (`https://api.example.com/items?page=123456`)
+    is identified by the `NextPageToken` header. `HeaderCursorPaginator` extracts
+    this parameter from the header and uses it to fetch the next page of results:
+
+        from dlt.sources.helpers.rest_client import RESTClient
+        from dlt.sources.helpers.rest_client.paginators import HeaderCursorPaginator
+
+        client = RESTClient(
+            base_url="https://api.example.com",
+            paginator=HeaderCursorPaginator(cursor_key="NextPageToken", cursor_param="page"),
+        )
+
+        @dlt.resource
+        def get_issues():
+            for page in client.paginate("/items"):
+                yield page
+    """
+
+    def __init__(self, cursor_key: str = "next", cursor_param: str = "cursor") -> None:
+        """
+        Args:
+            cursor_key (str, optional): The name of the response header
+                that contains the next page cursor value. Defaults to 'next'.
+            cursor_param (str, optional): The query parameter name to pass the cursor
+                value in for the next request. Defaults to 'cursor'.
+        """
+        super().__init__()
+        self.cursor_key = cursor_key
+        self.cursor_param = cursor_param
+
+    def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None:
+        """Extracts the next page cursor from the header in the response.
+
+        A missing or empty header value is stored as `None`.
+        """
+        # Normalize an empty header value to None so that a blank cursor is
+        # never sent as the next page reference.
+        self._next_reference = response.headers.get(self.cursor_key) or None
+
+    def update_request(self, request: Request) -> None:
+        """Updates the request with the cursor query parameter."""
+        if request.params is None:
+            request.params = {}
+
+        request.params[self.cursor_param] = self._next_reference
+
+    def __str__(self) -> str:
+        return (
+            super().__str__()
+            + f": cursor_value: {self._next_reference} cursor_param: {self.cursor_param}"
+        )
diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py
index 85276a263f..d716d0858e 100644
--- a/tests/sources/helpers/rest_client/test_paginators.py
+++ b/tests/sources/helpers/rest_client/test_paginators.py
@@ -13,6 +13,7 @@
     HeaderLinkPaginator,
     JSONLinkPaginator,
     JSONResponseCursorPaginator,
+    HeaderCursorPaginator,
 )
 
 from .conftest import assert_pagination
@@ -552,3 +553,48 @@ def test_client_pagination(self, rest_client):
         pages = list(pages_iter)
 
         assert_pagination(pages)
+
+
+@pytest.mark.usefixtures("mock_api_server")
+class TestHeaderCursorPaginator:
+    def test_update_state(self):
+        paginator = HeaderCursorPaginator(cursor_key="next_cursor")
+        response = Mock(Response)
+        response.headers = {"next_cursor": "cursor-2"}
+        paginator.update_state(response)
+        assert paginator._next_reference == "cursor-2"
+        assert paginator.has_next_page is True
+
+    def test_update_state_when_header_is_missing(self):
+        paginator = HeaderCursorPaginator(cursor_key="next_cursor")
+        response = Mock(Response)
+        response.headers = {}
+        paginator.update_state(response)
+        assert paginator.has_next_page is False
+
+    def test_update_state_when_header_is_empty_string(self):
+        paginator = HeaderCursorPaginator(cursor_key="next_cursor")
+        response = Mock(Response)
+        response.headers = {"next_cursor": ""}
+        paginator.update_state(response)
+        assert paginator._next_reference is None
+        assert paginator.has_next_page is False
+
+    def test_update_request(self):
+        paginator = HeaderCursorPaginator(cursor_key="next_cursor", cursor_param="cursor")
+        response = Mock(Response)
+        response.headers = {"next_cursor": "cursor-2"}
+        paginator.update_state(response)
+        request = Request(method="GET", url="http://example.com/api/resource")
+        paginator.update_request(request)
+        assert request.params["cursor"] == "cursor-2"
+
+    def test_client_pagination(self, rest_client):
+        pages_iter = rest_client.paginate(
+            "/posts_header_cursor",
+            paginator=HeaderCursorPaginator(cursor_key="cursor", cursor_param="page"),
+        )
+
+        pages = list(pages_iter)
+
+        assert_pagination(pages)
diff --git a/tests/sources/rest_api/conftest.py b/tests/sources/rest_api/conftest.py
index a416c1d1d6..75cce2705a 100644
--- a/tests/sources/rest_api/conftest.py
+++ b/tests/sources/rest_api/conftest.py
@@ -94,6 +94,19 @@ def posts_header_link(request, context):
 
         return response
 
+    @router.get(r"/posts_header_cursor(\?page=\d+)?$")
+    def posts_header_cursor(request, context):
+        records = generate_posts()
+        page_number = get_page_number(request.qs)
+        paginator = PageNumberPaginator(records, page_number)
+
+        response = paginator.page_records
+
+        if paginator.next_page_url_params:
+            context.headers["cursor"] = str(page_number + 1)
+
+        return response
+
     @router.get(r"/posts_relative_next_url(\?page=\d+)?$")
     def posts_relative_next_url(request, context):
         return paginate_by_page_number(request, generate_posts(), use_absolute_url=False)