Skip to content

Commit

Permalink
Merge pull request #7 from cinder-technologies/jamie/add-cache-key-fn
Browse files Browse the repository at this point in the history
Support cache key fn for dataloaders
  • Loading branch information
jamie-cndr authored Nov 20, 2024
2 parents 8d40934 + f63e5aa commit 3e6fafc
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
7 changes: 5 additions & 2 deletions graphql_sync_dataloaders/sync_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import contextvars
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional

from graphql.pyutils import is_collection

Expand Down Expand Up @@ -42,12 +42,15 @@ def run_all_callbacks(self):


class SyncDataLoader:
def __init__(self, batch_load_fn):
def __init__(self, batch_load_fn, cache_key_fn: Callable[[Any], str] | None = None):
self._batch_load_fn = batch_load_fn
self._cache_key_fn = cache_key_fn
self._cache = {}
self._queue = []

def load(self, key):
if self._cache_key_fn:
key = self._cache_key_fn(key)
try:
return self._cache[key]
except KeyError:
Expand Down
73 changes: 72 additions & 1 deletion tests/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from unittest import mock
from unittest.mock import Mock
from functools import partial

Expand Down Expand Up @@ -66,6 +65,78 @@ def resolve_name(_, __, key):
assert mock_load_fn.call_count == 1


def test_cache_key_fn():
NAMES = {}

class TestEntityClass:
def __init__(self, username: str, count: int):
self.username = username
self.count = count

def cache_key_fn(entity: TestEntityClass) -> str:
return f"{entity.username}-{entity.count}"

def load_fn(entities):
return [NAMES[entity] for entity in entities]

entities = [
TestEntityClass(username="Sarah", count=1),
TestEntityClass(username="Lucy", count=2),
]

for entity in entities:
NAMES[cache_key_fn(entity)] = entity.username

mock_load_fn = Mock(wraps=load_fn)
dataloader = SyncDataLoader(mock_load_fn, cache_key_fn)

def resolve_name(_, __, key):
entity = entities[int(key)]
return dataloader.load(entity)

schema = GraphQLSchema(
query=GraphQLObjectType(
name="Query",
fields={
"name": GraphQLField(
GraphQLString,
args={
"key": GraphQLArgument(GraphQLString),
},
resolve=resolve_name,
)
},
)
)

result = graphql_sync_deferred(
schema,
"""
query {
name1: name(key: "0")
name2: name(key: "1")
}
""",
)
assert not result.errors
assert result.data == {"name1": "Sarah", "name2": "Lucy"}
assert mock_load_fn.call_count == 1

# Ensure the cache is used for requests with the same key instead of the loader function
mock_load_fn.reset_mock()
result = graphql_sync_deferred(
schema,
"""
query {
name1: name(key: "0")
name2: name(key: "1")
}
""",
)
assert result.data == {"name1": "Sarah", "name2": "Lucy"}
assert mock_load_fn.call_count == 0


def test_nested_deferred_execution():
USERS = {
"1": {
Expand Down

0 comments on commit 3e6fafc

Please sign in to comment.