feat: implement limit/offset in adapters (#271)
* feat: implement limit/offset in adapters

* Add tests
betodealmeida authored Jul 24, 2022
1 parent 0639892 commit 8d2ebd7
Showing 17 changed files with 628 additions and 38 deletions.
3 changes: 3 additions & 0 deletions .pylintrc
@@ -5,3 +5,6 @@ disable =
[MASTER]
ignore=templates,docs
disable =

[TYPECHECK]
ignored-modules=apsw
18 changes: 16 additions & 2 deletions src/shillelagh/adapters/api/datasette.py
@@ -94,6 +94,9 @@ class DatasetteAPI(Adapter):

safe = True

supports_limit = True
supports_offset = True

@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
parsed = urllib.parse.urlparse(uri)
@@ -168,16 +171,24 @@ def get_data(
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
offset: Optional[int] = None,
**kwargs: Any,
) -> Iterator[Row]:
offset = 0
offset = offset or 0
while True:
if limit is None:
# request 1 more, so we know if there are more pages to be fetched
end = DEFAULT_LIMIT + 1
else:
end = min(limit, DEFAULT_LIMIT + 1)

sql = build_sql(
self.columns,
bounds,
order,
f'"{self.table}"',
limit=DEFAULT_LIMIT + 1,
limit=end,
offset=offset,
)
payload = self._run_query(sql)
@@ -199,4 +210,7 @@ def get_data(

if not payload["truncated"] and len(rows) <= DEFAULT_LIMIT:
break

offset += i + 1
if limit is not None:
limit -= i + 1
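The Datasette changes above follow a pattern worth calling out: each adapter now declares `supports_limit`/`supports_offset` so the layer above presumably knows whether it can push the clauses down or must apply them itself, and the paging loop requests one row more than the page size so that a full page signals another page exists, then advances the offset and shrinks the remaining limit by the rows consumed. The sketch below is a minimal, self-contained illustration of that loop only; `fake_run_query`, the 2500-row dataset, and `DEFAULT_LIMIT = 1000` are invented stand-ins, not the adapter's real internals (the real code also checks the response's `truncated` flag).

```python
from typing import Dict, Iterator, List, Optional

DEFAULT_LIMIT = 1000  # stand-in for the adapter's page size


def fake_run_query(page_size: int, page_offset: int) -> List[Dict[str, int]]:
    """Pretend backend holding 2500 rows; returns at most ``page_size`` rows."""
    total = 2500
    return [{"id": i} for i in range(page_offset, min(page_offset + page_size, total))]


def paginate(limit: Optional[int] = None, offset: Optional[int] = None) -> Iterator[Dict[str, int]]:
    offset = offset or 0
    while True:
        # Ask for one extra row so a full page tells us that more pages exist.
        end = DEFAULT_LIMIT + 1 if limit is None else min(limit, DEFAULT_LIMIT + 1)
        rows = fake_run_query(end, offset)
        yield from rows

        if len(rows) <= DEFAULT_LIMIT:  # no extra row came back: this was the last page
            break
        offset += len(rows)
        if limit is not None:
            limit -= len(rows)


assert sum(1 for _ in paginate(limit=1500, offset=10)) == 1500
```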
50 changes: 41 additions & 9 deletions src/shillelagh/adapters/api/github.py
@@ -38,7 +38,7 @@ class Column:
field: Field

# A default value for when the column is not specified. Eg, for ``pulls``
# the API defaults to show only PRs with an open state, so we need to
# the API defaults to showing only PRs with an open state, so we need to
# default the column to ``all`` to fetch all PRs when state is not
# specified in the query.
default: Optional[Filter] = None
@@ -75,6 +75,9 @@ class GitHubAPI(Adapter):

safe = True

supports_limit = True
supports_offset = True

@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
parsed = urllib.parse.urlparse(uri)
@@ -133,6 +136,8 @@ def get_data(
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
offset: Optional[int] = None,
**kwargs: Any,
) -> Iterator[Row]:
# apply default values
@@ -142,14 +147,22 @@

if "number" in bounds:
number = bounds.pop("number").value # type: ignore
return self._get_single_resource(number)
return self._get_single_resource(number, limit, offset)

return self._get_multiple_resources(bounds)
return self._get_multiple_resources(bounds, limit, offset)

def _get_single_resource(self, number: int) -> Iterator[Row]:
def _get_single_resource(
self,
number: int,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Iterator[Row]:
"""
Return a specific resource.
"""
if offset or (limit is not None and limit < 1):
return

headers = {"Accept": "application/vnd.github.v3+json"}
if self.access_token:
headers["Authorization"] = f"Bearer {self.access_token}"
@@ -171,7 +184,12 @@ def _get_single_resource(self, number: int) -> Iterator[Row]:
_logger.debug(row)
yield row

def _get_multiple_resources(self, bounds: Dict[str, Filter]) -> Iterator[Row]:
def _get_multiple_resources(
self,
bounds: Dict[str, Filter],
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Iterator[Row]:
"""
Return multiple resources.
"""
@@ -185,8 +203,12 @@ def _get_multiple_resources(self, bounds: Dict[str, Filter]) -> Iterator[Row]:
params = {name: filter_.value for name, filter_ in bounds.items()} # type: ignore
params["per_page"] = PAGE_SIZE

page = 1
while True:
offset = offset or 0
page = (offset // PAGE_SIZE) + 1
offset %= PAGE_SIZE

rowid = 0
while limit is None or rowid < limit:
_logger.info("GET %s (page %d)", url, page)
params["page"] = page
response = self._session.get(url, headers=headers, params=params)
@@ -198,13 +220,23 @@ def _get_multiple_resources(self, bounds: Dict[str, Filter]) -> Iterator[Row]:
if not response.ok:
raise ProgrammingError(payload["message"])

for i, resource in enumerate(payload):
if offset is not None:
payload = payload[offset:]
offset = None

for resource in payload:
if limit is not None and rowid == limit:
# this never happens because SQLite stops consuming from the generator
# as soon as the limit is hit
break

row = {
column.name: JSONPath(column.json_path).parse(resource)[0]
for column in TABLES[self.base][self.resource]
}
row["rowid"] = i + (page - 1) * PAGE_SIZE
row["rowid"] = rowid
_logger.debug(row)
yield row
rowid += 1

page += 1
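The interesting part of the GitHub change is the translation of a flat row offset into the REST API's page-based pagination: `page = (offset // PAGE_SIZE) + 1` picks the page containing the first wanted row, and `offset %= PAGE_SIZE` leaves the number of rows to discard within that page. A quick worked example, assuming a page size of 100 (illustrative only; the adapter defines its own `PAGE_SIZE` constant):

```python
PAGE_SIZE = 100  # illustrative value, not necessarily the adapter's constant
offset = 250     # skip the first 250 rows overall

page = (offset // PAGE_SIZE) + 1  # 3: rows 200-299 live on the third page
offset %= PAGE_SIZE               # 50: discard the first 50 rows of that page

assert (page, offset) == (3, 50)
```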
19 changes: 13 additions & 6 deletions src/shillelagh/adapters/api/gsheets/adapter.py
@@ -35,7 +35,7 @@
)
from shillelagh.fields import Field, Order
from shillelagh.filters import Filter
from shillelagh.lib import SimpleCostModel, build_sql
from shillelagh.lib import SimpleCostModel, apply_limit_and_offset, build_sql
from shillelagh.typing import RequestedOrder, Row

_logger = logging.getLogger(__name__)
@@ -79,6 +79,8 @@ class GSheetsAPI(Adapter): # pylint: disable=too-many-instance-attributes
"""

safe = True
supports_limit = True
supports_offset = True

@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
@@ -380,10 +382,12 @@ def _get_header_rows(self, values: List[List[Any]]) -> int:

return i + 1

def get_data(
def get_data( # pylint: disable=too-many-locals
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
offset: Optional[int] = None,
**kwargs: Any,
) -> Iterator[Row]:
"""
@@ -406,14 +410,15 @@ def get_data(
}:
values = self._get_values()
headers = self._get_header_rows(values)
rows = (
rows: Iterator[Row] = (
{
reverse_map[letter]: cell
for letter, cell in zip(gen_letters(), row)
if letter in reverse_map
}
for row in values[headers:]
)
rows = apply_limit_and_offset(rows, limit, offset)

# For ``BIDIRECTIONAL`` mode we continue using the Chart API to
# retrieve data. This will happen before every DML query.
@@ -425,7 +430,8 @@
order,
None,
self._column_map,
None,
limit,
offset,
)
except ImpossibleFilterError:
return
@@ -442,8 +448,9 @@
)

for i, row in enumerate(rows):
self._row_ids[i] = row
row["rowid"] = i
rowid = (offset or 0) + i
self._row_ids[rowid] = row
row["rowid"] = rowid
_logger.debug(row)
yield row

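For the synchronous write modes the adapter reads the whole sheet and trims it in Python with `apply_limit_and_offset` from `shillelagh.lib`. That helper's implementation is not part of this diff; a plausible minimal version, shown only as a sketch, would wrap `itertools.islice`:

```python
from itertools import islice
from typing import Iterator, Optional, TypeVar

T = TypeVar("T")


def apply_limit_and_offset(
    rows: Iterator[T],
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> Iterator[T]:
    """Lazily drop the first ``offset`` rows and stop after ``limit`` rows."""
    start = offset or 0
    stop = None if limit is None else start + limit
    return islice(rows, start, stop)
```

The loop after the call then numbers rows starting at `offset or 0`, presumably so the rowids stored in `_row_ids` still line up with sheet positions even when leading rows are skipped.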
8 changes: 6 additions & 2 deletions src/shillelagh/adapters/api/s3select.py
@@ -206,6 +206,9 @@ class S3SelectAPI(Adapter):

safe = True

supports_limit = True
supports_offset = False

@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
parsed = urllib.parse.urlparse(uri)
@@ -303,18 +306,19 @@ def get_data(
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
**kwargs: Any,
) -> Iterator[Row]:
try:
sql = build_sql(self.columns, bounds, order, table="s3object")
sql = build_sql(self.columns, bounds, order, table="s3object", limit=limit)
except ImpossibleFilterError:
return

rows = self._run_query(sql)
for i, row in enumerate(rows):
row["rowid"] = i
yield row
_logger.debug(row)
yield row

def drop_table(self) -> None:
self.s3_client.delete_object(Bucket=self.bucket, Key=self.key, **self.s3_kwargs)
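S3 Select's SQL dialect accepts `LIMIT` but has no `OFFSET` clause, which matches the flags above: the adapter advertises `supports_limit` only and forwards just `limit` to `build_sql`, leaving any offset to be applied outside the adapter. As a rough illustration only (the exact formatting of `build_sql` output is an assumption), an equality filter with `limit=10` would be sent as something like:

```python
# Hypothetical example of the query string handed to _run_query; quoting details may differ.
sql = """SELECT * FROM s3object WHERE "country" = 'BR' LIMIT 10"""
```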
9 changes: 7 additions & 2 deletions src/shillelagh/adapters/api/socrata.py
@@ -98,6 +98,9 @@ class SocrataAPI(Adapter):

safe = True

supports_limit = True
supports_offset = True

@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
"""https://data.cdc.gov/resource/unsk-b7fc.json"""
@@ -147,10 +150,12 @@ def get_data(
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
offset: Optional[int] = None,
**kwargs: Any,
) -> Iterator[Row]:
try:
sql = build_sql(self.columns, bounds, order)
sql = build_sql(self.columns, bounds, order, limit=limit, offset=offset)
except ImpossibleFilterError:
return

@@ -172,5 +177,5 @@

for i, row in enumerate(payload):
row["rowid"] = i
yield row
_logger.debug(row)
yield row
19 changes: 14 additions & 5 deletions src/shillelagh/adapters/api/system.py
@@ -5,6 +5,7 @@
See https://github.com/giampaolo/psutil for more information.
"""
import logging
import time
import urllib.parse
from datetime import datetime, timezone
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -30,6 +31,9 @@ class SystemAPI(Adapter):

safe = False

supports_limit = True
supports_offset = True

@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
parsed = urllib.parse.urlparse(uri)
@@ -76,22 +80,27 @@ def get_data(
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
offset: Optional[int] = None,
**kwargs: Any,
) -> Iterator[Row]:
i = 0
while True:
rowid = 0
while limit is None or rowid < limit:
if offset is not None:
time.sleep(self.interval * offset)

try:
values = psutil.cpu_percent(interval=self.interval, percpu=True)
except KeyboardInterrupt:
return

row = {
"rowid": i,
"rowid": rowid,
"timestamp": datetime.now(timezone.utc),
}
for i, value in enumerate(values):
row[f"cpu{i}"] = value / 100.0

yield row
_logger.debug(row)
i += 1
yield row
rowid += 1
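The system adapter synthesises rows by sampling CPU usage, so there is no stored data to skip; an offset is instead converted into a wait of `interval * offset` seconds, and the limit caps how many samples are produced. A back-of-the-envelope timing sketch, assuming a one-second interval (the adapter's interval is configurable):

```python
# Assumed values for illustration only.
interval = 1.0      # seconds per CPU sample
offset, limit = 5, 3

wait_for_offset = interval * offset   # 5.0 s of sleeping stands in for 5 skipped rows
sampling_time = interval * limit      # ~3.0 s to produce the 3 requested rows
```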
7 changes: 6 additions & 1 deletion src/shillelagh/adapters/api/weatherapi.py
@@ -76,6 +76,11 @@ class WeatherAPI(Adapter):

safe = True

# Since the adapter doesn't return exact data (see the time columns below)
# implementing limit/offset is not worth the trouble.
supports_limit = False
supports_offset = False

# These two columns can be used to filter the results from the API. We
# define them as inexact since we will retrieve data for the whole day,
# even if specific hours are requested. The post-filtering will be done
@@ -202,7 +207,7 @@ def get_data( # pylint: disable=too-many-locals
tzinfo=local_timezone,
)
row["rowid"] = int(row["time_epoch"])
yield row
_logger.debug(row)
yield row

start += timedelta(days=1)
9 changes: 7 additions & 2 deletions src/shillelagh/adapters/file/csvfile.py
@@ -85,6 +85,9 @@ class CSVFile(Adapter):
# the filesystem, or potentially overwrite existing files
safe = False

supports_limit = True
supports_offset = True

@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
path = Path(uri)
@@ -159,6 +162,8 @@ def get_data(
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
offset: Optional[int] = None,
**kwargs: Any,
) -> Iterator[Row]:
_logger.info("Opening file CSV file %s to load data", self.path)
@@ -177,9 +182,9 @@
# Filter and sort the data. It would probably be more efficient to simply
# declare the columns as having no filter and no sort order, and let the
# backend handle this; but it's nice to have an example of how to do this.
for row in filter_data(data, bounds, order):
yield row
for row in filter_data(data, bounds, order, limit, offset):
_logger.debug(row)
yield row

def insert_data(self, row: Row) -> int:
row_id: Optional[int] = row.pop("rowid")
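Here the new `limit`/`offset` arguments are simply forwarded to `filter_data` in `shillelagh.lib`, which already handles filtering and sorting in Python. Its real implementation is not shown in this diff; the sketch below only outlines the filter, sort, slice order the call implies, with plain predicates standing in for shillelagh's `Filter` objects.

```python
from itertools import islice
from operator import itemgetter
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

Row = Dict[str, Any]


def filter_sort_slice(
    rows: Iterator[Row],
    predicates: Dict[str, Callable[[Any], bool]],  # simplified stand-in for Filter objects
    order: List[Tuple[str, str]],                  # (column, "ASC" | "DESC")
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> Iterator[Row]:
    kept = [
        row for row in rows
        if all(predicate(row[column]) for column, predicate in predicates.items())
    ]
    # Apply sort keys from last to first so the primary key wins (Python sorts are stable).
    for column, direction in reversed(order):
        kept.sort(key=itemgetter(column), reverse=(direction == "DESC"))
    start = offset or 0
    stop = None if limit is None else start + limit
    yield from islice(kept, start, stop)
```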