Skip to content

Commit

Permalink
Fix memoization of bodyJson
Browse files Browse the repository at this point in the history
The old cacheBodyJson was not clean
It used to replace the getBodyJson method which is very dangerous
Instead, add a kwargs named "memoizedBodyJson" to the function call
  • Loading branch information
bruyeret committed Dec 4, 2023
1 parent 8498d7c commit 0f172f1
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from girder.api.rest import Resource, loadmodel
from girder.constants import AccessType
from girder.exceptions import AccessException, RestException
from ..helpers.proxiedModel import recordable, cacheBodyJson
from ..helpers.proxiedModel import recordable, memoizeBodyJson
from ..models.annotation import Annotation as AnnotationModel

from bson.objectid import ObjectId
Expand All @@ -12,19 +12,19 @@
# Helper functions to get dataset ID for recordable endpoints

def getDatasetIdFromAnnotationInBody(self: 'Annotation', *args, **kwargs):
annotation = self.getBodyJson()
annotation = kwargs["memoizedBodyJson"]
return annotation['datasetId']

def getDatasetIdFromAnnotationListInBody(self: 'Annotation', *args, **kwargs):
annotations = self.getBodyJson()
annotations = kwargs["memoizedBodyJson"]
return None if len(annotations) <= 0 else annotations[0]['datasetId']

def getDatasetIdFromLoadedAnnotation(self: 'Annotation', *args, **kwargs):
annotation = kwargs['upenn_annotation']
return annotation['datasetId']

def getDatasetIdFromAnnotationIdListInBody(self: 'Annotation', *args, **kwargs):
annotationStringIds = self.getBodyJson()
annotationStringIds = kwargs["memoizedBodyJson"]
query = {
'_id': { '$in': [ObjectId(stringId) for stringId in annotationStringIds] },
}
Expand Down Expand Up @@ -57,25 +57,27 @@ def __init__(self):
# TODO(performance): use objectId whenever possible
# TODO: error handling and documentation

@cacheBodyJson
@access.user
@describeRoute(Description("Create a new annotation").param('body', 'Annotation Object', paramType='body'))
@memoizeBodyJson
@recordable('Create an annotation', getDatasetIdFromAnnotationInBody)
def create(self, params):
def create(self, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
currentUser = self.getCurrentUser()
if not currentUser:
raise AccessException('User not found', 'currentUser')
return self._annotationModel.create(currentUser, self.getBodyJson())
return self._annotationModel.create(currentUser, bodyJson)

@cacheBodyJson
@access.user
@describeRoute(Description("Create multiple new annotations").param('body', 'Annotation Object List', paramType='body'))
@memoizeBodyJson
@recordable('Create multiple annotations', getDatasetIdFromAnnotationListInBody)
def createMultiple(self, params):
def createMultiple(self, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
currentUser = self.getCurrentUser()
if not currentUser:
raise AccessException('User not found', 'currentUser')
return [self._annotationModel.create(currentUser, annotation) for annotation in self.getBodyJson()]
return [self._annotationModel.create(currentUser, annotation) for annotation in bodyJson]

@describeRoute(Description("Delete an existing annotation").param('id', 'The annotation\'s Id', paramType='path').errorResponse('ID was invalid.')
.errorResponse('Write access was denied for the annotation.', 403))
Expand All @@ -85,13 +87,14 @@ def createMultiple(self, params):
def delete(self, upenn_annotation, params):
self._annotationModel.delete(upenn_annotation)

@cacheBodyJson
@access.user
@describeRoute(Description("Delete all annotations in the id list")
.param('body', 'A list of all annotation ids to delete.', paramType='body'))
@memoizeBodyJson
@recordable('Delete multiple annotations', getDatasetIdFromAnnotationIdListInBody)
def deleteMultiple(self, params):
stringIds = [stringId for stringId in self.getBodyJson()]
def deleteMultiple(self, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
stringIds = [stringId for stringId in bodyJson]
self._annotationModel.deleteMultiple(stringIds)

@describeRoute(Description("Update an existing annotation")
Expand All @@ -103,9 +106,11 @@ def deleteMultiple(self, params):
.errorResponse("Validation Error: JSON doesn't follow schema."))
@access.user
@loadmodel(model='upenn_annotation', plugin='upenncontrast_annotation', level=AccessType.WRITE)
@memoizeBodyJson
@recordable('Update an annotation', getDatasetIdFromLoadedAnnotation)
def update(self, upenn_annotation, params):
upenn_annotation.update(self.getBodyJson())
def update(self, upenn_annotation, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
upenn_annotation.update(bodyJson)
self._annotationModel.update(upenn_annotation)

@access.user
Expand Down Expand Up @@ -141,8 +146,10 @@ def get(self, upenn_annotation, params):
@describeRoute(Description("Compute annotations from a worker tool")
.param('datasetId', 'The dataset Id', required=False)
.param('body', 'A JSON object containing the worker tool', paramType='body'))
def compute(self, params):
@memoizeBodyJson
def compute(self, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
datasetId = params.get('datasetId', None)
if not datasetId:
raise RestException(code=400, message="Missing datasetId parameter")
return self._annotationModel.compute(datasetId, self.getBodyJson(), self.getCurrentUser())
return self._annotationModel.compute(datasetId, bodyJson, self.getCurrentUser())
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from girder.api.rest import Resource, loadmodel
from girder.constants import AccessType
from girder.exceptions import AccessException
from ..helpers.proxiedModel import recordable, cacheBodyJson
from ..helpers.proxiedModel import recordable, memoizeBodyJson
from ..models.connections import AnnotationConnection as ConnectionModel
from ..models.annotation import Annotation as AnnotationModel

Expand All @@ -13,19 +13,19 @@
# Helper functions to get dataset ID for recordable endpoints

def getDatasetIdFromConnectionInBody(self: 'AnnotationConnection', *args, **kwargs):
connection = self.getBodyJson()
connection = kwargs["memoizedBodyJson"]
return connection['datasetId']

def getDatasetIdFromConnectionListInBody(self: 'AnnotationConnection', *args, **kwargs):
connections = self.getBodyJson()
connections = kwargs["memoizedBodyJson"]
return None if len(connections) <= 0 else connections[0]['datasetId']

def getDatasetIdFromLoadedConnection(self: 'AnnotationConnection', *args, **kwargs):
connection = kwargs['annotation_connection']
return connection['datasetId']

def getDatasetIdFromConnectionIdListInBody(self: 'AnnotationConnection', *args, **kwargs):
connectionStringIds = self.getBodyJson()
connectionStringIds = kwargs["memoizedBodyJson"]
query = {
'_id': { '$in': [ObjectId(stringId) for stringId in connectionStringIds] },
}
Expand All @@ -34,7 +34,7 @@ def getDatasetIdFromConnectionIdListInBody(self: 'AnnotationConnection', *args,
return None if connection is None else connection['datasetId']

def getDatasetIdFromInfoInBody(self: 'AnnotationConnection', *args, **kwargs):
info = self.getBodyJson()
info = kwargs["memoizedBodyJson"]
annotationsIdsToConnect = info['annotationsIds']
annotationModel: AnnotationModel = AnnotationModel()
for stringId in annotationsIdsToConnect:
Expand Down Expand Up @@ -67,25 +67,27 @@ def __init__(self):
# TODO: creation date, update date, creatorId
# TODO: error handling and documentation

@cacheBodyJson
@access.user
@describeRoute(Description("Create a new connection").param('body', 'Connection Object', paramType='body'))
@memoizeBodyJson
@recordable('Create a connection', getDatasetIdFromConnectionInBody)
def create(self, params):
def create(self, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
currentUser = self.getCurrentUser()
if not currentUser:
raise AccessException('User not found', 'currentUser')
return self._connectionModel.create(currentUser, self.getBodyJson())
return self._connectionModel.create(currentUser, bodyJson)

@cacheBodyJson
@access.user
@describeRoute(Description("Create multiple new connections").param('body', 'Connection Object List', paramType='body'))
@memoizeBodyJson
@recordable('Create multiple connections', getDatasetIdFromConnectionListInBody)
def multipleCreate(self, params):
def multipleCreate(self, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
currentUser = self.getCurrentUser()
if not currentUser:
raise AccessException('User not found', 'currentUser')
return [self._connectionModel.create(currentUser, connection) for connection in self.getBodyJson()]
return [self._connectionModel.create(currentUser, connection) for connection in bodyJson]

@describeRoute(Description("Delete an existing connection").param('id', 'The connection\'s Id', paramType='path').errorResponse('ID was invalid.')
.errorResponse('Write access was denied for the connection.', 403))
Expand All @@ -95,13 +97,14 @@ def multipleCreate(self, params):
def delete(self, annotation_connection, params):
self._connectionModel.remove(annotation_connection)

@cacheBodyJson
@access.user
@describeRoute(Description("Delete all annotation connections in the id list")
.param('body', 'A list of all annotation connection ids to delete.', paramType='body'))
@memoizeBodyJson
@recordable('Delete multiple connections', getDatasetIdFromConnectionIdListInBody)
def deleteMultiple(self, params):
stringIds = [stringId for stringId in self.getBodyJson()]
def deleteMultiple(self, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
stringIds = [stringId for stringId in bodyJson]
self._connectionModel.deleteMultiple(stringIds)

@describeRoute(Description("Update an existing connection")
Expand All @@ -113,9 +116,11 @@ def deleteMultiple(self, params):
.errorResponse("Validation Error: JSON doesn't follow schema."))
@access.user
@loadmodel(model='annotation_connection', plugin='upenncontrast_annotation', level=AccessType.WRITE)
@memoizeBodyJson
@recordable('Update a connection', getDatasetIdFromLoadedConnection)
def update(self, connection, params):
connection.update(self.getBodyJson())
def update(self, connection, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
connection.update(bodyJson)
self._connectionModel.update(connection)

@access.user
Expand Down Expand Up @@ -156,12 +161,13 @@ def get(self, annotation_connection):
return annotation_connection


@cacheBodyJson
@access.user
@describeRoute(Description("Create connections between annotations").param('body', 'Connection Object', paramType='body'))
@memoizeBodyJson
@recordable('Create connections with nearest', getDatasetIdFromInfoInBody)
def connectToNearest(self, params):
def connectToNearest(self, params, *args, **kwargs):
bodyJson = kwargs["memoizedBodyJson"]
currentUser = self.getCurrentUser()
if not currentUser:
raise AccessException('User not found', 'currentUser')
return self._connectionModel.connectToNearest(user=currentUser, info=self.getBodyJson())
return self._connectionModel.connectToNearest(user=currentUser, info=bodyJson)
Original file line number Diff line number Diff line change
Expand Up @@ -10,51 +10,23 @@
from functools import wraps
from bson.objectid import ObjectId

def cacheBodyJson(func):
def memoizeBodyJson(func):
'''
A decorator on rest.Resource methods to cache the result of self.getBodyJson()
This is usefull when some decorators and the decorated function use it
For example, when using @recordable
Use this decorator before any other decorator, for example:
For example, when using @recordable with a findDatasetIdFn that uses bodyJson
Use this decorator before any other decorator using memoizedBodyJson:
```
@cacheBodyJson
@memoizeBodyJson
@recordable('Foo', bar)
def f(self, *args, **kwargs):
pass
```
instead of
```
@recordable('Foo', bar)
@cacheBodyJson
def f(self, *args, **kwargs):
pass
```
'''
@wraps(func)
def wrapped(self: rest.Resource, *args, **kwargs):
# Get the result of getBodyJson
error = None
bodyJson = None
try:
bodyJson = self.getBodyJson()
except rest.RestException as err:
error = err

# Create memoised version of the method
def getCachedBodyJson(self):
return deepcopy(bodyJson)
def raiseBodyJsonError(self):
raise error
newFunction = getCachedBodyJson if error is None else raiseBodyJsonError
newMethod = MethodType(newFunction, self)

# Wrap the function call
originalMethod = self.getBodyJson
try:
self.getBodyJson = newMethod
return func(self, *args, **kwargs)
finally:
self.getBodyJson = originalMethod
return func(self, *args, **kwargs, memoizedBodyJson=self.getBodyJson())

return wrapped

Expand Down

0 comments on commit 0f172f1

Please sign in to comment.