Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue104; get valid key value combinations #234

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions terracotta/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def get_keys(self) -> OrderedDict:
"""
pass

@abstractmethod
def get_valid_values(self, where: Mapping[str, Union[str, List[str]]]) -> Dict[str, List[str]]:
pass

@abstractmethod
def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None,
page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]:
Expand Down
29 changes: 29 additions & 0 deletions terracotta/drivers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,35 @@ def _get_keys(self) -> OrderedDict:

return out

@requires_connection
@convert_exceptions('Could not retrieve valid key values')
def get_valid_values(self, where: Mapping[str, Union[str, List[str]]]) -> Dict[str, List[str]]:
cursor = self._cursor

if not all(key in self.key_names for key in where.keys()):
raise exceptions.InvalidKeyError('Encountered unrecognized keys in where clause')

conditions = []
values = []
for key, value in where.items():
if isinstance(value, str):
value = [value]
values.extend(value)
conditions.append(' OR '.join([f'{key}=%s'] * len(value)))
where_fragment = ' AND '.join([f'({condition})' for condition in conditions])
where_fragment = ' WHERE ' + where_fragment if where_fragment else ''

valid_values = {key: [val] if isinstance(val, str) else val for key, val in where.items()}

for key in set(self.key_names) - set(where.keys()):
cursor.execute(
f'SELECT DISTINCT {key} FROM datasets {where_fragment}',
values
)
valid_values[key] = list([row[key] for row in cursor.fetchall()])

return valid_values

@trace('get_datasets')
@requires_connection
@convert_exceptions('Could not retrieve datasets')
Expand Down
29 changes: 29 additions & 0 deletions terracotta/drivers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,35 @@ def get_keys(self) -> OrderedDict:
out[row['key']] = row['description']
return out

@requires_connection
@convert_exceptions('Could not retrieve valid key values')
def get_valid_values(self, where: Mapping[str, Union[str, List[str]]]) -> Dict[str, List[str]]:
conn = self._connection

if not all(key in self.key_names for key in where.keys()):
raise exceptions.InvalidKeyError('Encountered unrecognized keys in where clause')

conditions = []
values = []
for key, value in where.items():
if isinstance(value, str):
value = [value]
values.extend(value)
conditions.append(' OR '.join([f'{key}=?'] * len(value)))
where_fragment = ' AND '.join([f'({condition})' for condition in conditions])
where_fragment = ' WHERE ' + where_fragment if where_fragment else ''

valid_values = {key: [val] if isinstance(val, str) else val for key, val in where.items()}

for key in set(self.key_names) - set(where.keys()):
rows = conn.execute(
f'SELECT DISTINCT {key} FROM datasets {where_fragment}',
values
)
valid_values[key] = list([row[key] for row in rows])

return valid_values

@trace('get_datasets')
@requires_connection
@convert_exceptions('Could not retrieve datasets')
Expand Down
21 changes: 21 additions & 0 deletions terracotta/handlers/valid_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""handlers/valid_values.py

Handle /valid_values API endpoint.
"""

from typing import Dict, Mapping, List, Union

from terracotta import get_settings, get_driver
from terracotta.profile import trace


@trace('valid_values_handler')
def valid_values(some_keys: Mapping[str, Union[str, List[str]]] = None) -> Dict[str, List[str]]:
"""List all available valid values"""
settings = get_settings()
driver = get_driver(settings.DRIVER_PATH, provider=settings.DRIVER_PROVIDER)

with driver.connect():
valid_values = driver.get_valid_values(some_keys or {})

return valid_values
2 changes: 2 additions & 0 deletions terracotta/server/flask_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def create_app(debug: bool = False, profile: bool = False) -> Flask:
from terracotta import get_settings
import terracotta.server.datasets
import terracotta.server.keys
import terracotta.server.valid_values
import terracotta.server.colormap
import terracotta.server.metadata
import terracotta.server.rgb
Expand Down Expand Up @@ -97,6 +98,7 @@ def create_app(debug: bool = False, profile: bool = False) -> Flask:
with new_app.test_request_context():
SPEC.path(view=terracotta.server.datasets.get_datasets)
SPEC.path(view=terracotta.server.keys.get_keys)
SPEC.path(view=terracotta.server.valid_values.get_valid_values)
SPEC.path(view=terracotta.server.colormap.get_colormap)
SPEC.path(view=terracotta.server.metadata.get_metadata)
SPEC.path(view=terracotta.server.rgb.get_rgb)
Expand Down
72 changes: 72 additions & 0 deletions terracotta/server/valid_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""server/valid_values.py

Flask route to handle /valid_values calls.
"""

from typing import Any, Dict, List, Union
from flask import request, jsonify, Response
from marshmallow import Schema, fields, INCLUDE, post_load
import re

from terracotta.server.flask_api import METADATA_API


class KeyValueOptionSchema(Schema):
class Meta:
unknown = INCLUDE

# placeholder values to document keys
key1 = fields.String(example='value1', description='Value of key1', dump_only=True)
key2 = fields.String(example='value2', description='Value of key2', dump_only=True)

@post_load
def list_items(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Union[str, List[str]]]:
# Create lists of values supplied as stringified lists
for key, value in data.items():
if isinstance(value, str) and re.match(r'^\[.*\]$', value):
data[key] = value[1:-1].split(',')
return data


class KeyValueSchema(Schema):
valid_values = fields.Dict(
key=fields.String(example='key1'),
values=fields.List(fields.String(example='value1')),
required=True,
description='Array containing all available key combinations'
)


@METADATA_API.route('/valid_values', methods=['GET'])
def get_valid_values() -> Response:
"""Get all valid values combinations (possibly when given a value for some keys)
---
get:
summary: /datasets
description:
Get unique key values of all available datasets that match given key constraint.
Constraints may be combined freely. Returns all valid key values if no query parameters
are given.
parameters:
- in: query
schema: KeyValueOptionSchema
responses:
200:
description: All available key value combinations
schema:
type: KeyValueSchema
400:
description: Query parameters contain unrecognized keys
"""
from terracotta.handlers.valid_values import valid_values
option_schema = KeyValueOptionSchema()
options = option_schema.load(request.args)

keys = options or None

payload = {
'valid_values': valid_values(keys)
}

schema = KeyValueSchema()
return jsonify(schema.load(payload))
34 changes: 34 additions & 0 deletions tests/drivers/test_raster_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,40 @@ def test_path_override(driver_path, provider, raster_file):
assert bogus_path in exc.value


@pytest.mark.parametrize('provider', DRIVERS)
def test_valid_values(driver_path, provider, raster_file):
from terracotta import drivers, exceptions
db = drivers.get_driver(driver_path, provider=provider)
keys = ('some', 'keynames')

db.create(keys)
db.insert(['some', 'value'], str(raster_file))
db.insert(['some', 'other_value'], str(raster_file))
db.insert({'some': 'a', 'keynames': 'third_value'}, str(raster_file))

data = db.get_valid_values({})
assert len(data) == 2
assert len(data['some']) == 2
assert len(data['keynames']) == 3

data = db.get_valid_values(where=dict(some='some'))
assert len(data) == 2
assert data['some'] == ['some']
assert set(data['keynames']) == set(['value', 'other_value'])

data = db.get_valid_values(where=dict(some='some', keynames='value'))
assert set(data.keys()) == set(['some', 'keynames'])
assert data['some'] == ['some']
assert data['keynames'] == ['value']

data = db.get_valid_values(where=dict(some='unknown'))
assert data == {'some': ['unknown'], 'keynames': []}

with pytest.raises(exceptions.InvalidKeyError) as exc:
db.get_valid_values(where=dict(unknown='foo'))
assert 'unrecognized keys' in str(exc.value)


@pytest.mark.parametrize('provider', DRIVERS)
def test_where(driver_path, provider, raster_file):
from terracotta import drivers, exceptions
Expand Down
10 changes: 10 additions & 0 deletions tests/handlers/test_valid_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

def test_valid_values_handler(testdb, use_testdb):
import terracotta
from terracotta.handlers import valid_values

driver = terracotta.get_driver(str(testdb))

handler_response = valid_values.valid_values({})
assert handler_response
assert set(handler_response.keys()) == set(driver.key_names)
39 changes: 39 additions & 0 deletions tests/server/test_flask_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,45 @@ def test_get_metadata_nonexisting(client, use_testdb):
assert rv.status_code == 404


def test_get_valid_values(client, use_testdb):
rv = client.get('/valid_values')
assert rv.status_code == 200
valid_values = json.loads(rv.data, object_pairs_hook=OrderedDict)['valid_values']
assert len(valid_values) == 3
assert len(valid_values['key1']) == 2
assert 'val11' in valid_values['key1'] and 'val21' in valid_values['key1']
assert valid_values['akey'] == ['x']


def test_get_valid_values_selective(client, use_testdb):
rv = client.get('/valid_values?key1=val21')
assert rv.status_code == 200
valid_values = json.loads(rv.data, object_pairs_hook=OrderedDict)['valid_values']
assert len(valid_values) == 3
assert valid_values['key1'] == ['val21']
assert len(valid_values['key2']) == 3
assert 'val22' in valid_values['key2'] and 'val23' in valid_values['key2']
assert valid_values['akey'] == ['x']

rv = client.get('/valid_values?key1=[val21]')
assert rv.status_code == 200
valid_values = json.loads(rv.data, object_pairs_hook=OrderedDict)['valid_values']
assert len(valid_values) == 3
assert valid_values['key1'] == ['val21']
assert len(valid_values['key2']) == 3
assert 'val22' in valid_values['key2'] and 'val23' in valid_values['key2']
assert valid_values['akey'] == ['x']

rv = client.get('/valid_values?key1=val21&key2=[val23,val24]')
assert rv.status_code == 200
valid_values = json.loads(rv.data, object_pairs_hook=OrderedDict)['valid_values']
assert len(valid_values) == 3
assert valid_values['key1'] == ['val21']
assert len(valid_values['key2']) == 2
assert 'val23' in valid_values['key2'] and 'val24' in valid_values['key2']
assert valid_values['akey'] == ['x']


def test_get_datasets(client, use_testdb):
rv = client.get('/datasets')
assert rv.status_code == 200
Expand Down