diff --git a/label_studio/core/utils/db.py b/label_studio/core/utils/db.py index 21f19d6e7727..e8e0445a6b85 100644 --- a/label_studio/core/utils/db.py +++ b/label_studio/core/utils/db.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, TypeVar import logging from core.feature_flags import flag_set from django.db import models -from django.db.models import Subquery +from django.db.models import Model, QuerySet, Subquery if TYPE_CHECKING: from users.models import User @@ -16,14 +16,16 @@ class SQCount(Subquery): output_field = models.IntegerField() -def fast_first(queryset): +ModelType = TypeVar('ModelType', bound=Model) + +def fast_first(queryset: QuerySet[ModelType]) -> Optional[ModelType]: """Replacement for queryset.first() when you don't need ordering, queryset.first() works slowly in some cases """ - try: - return queryset.all()[0] - except IndexError: - return None + + if result := queryset[:1]: + return result[0] + return None def should_run_bulk_update_in_transaction(organization_created_by_user: "User") -> bool: diff --git a/label_studio/tasks/mixins.py b/label_studio/tasks/mixins.py index ab71dc807eda..8bd19caec2b4 100644 --- a/label_studio/tasks/mixins.py +++ b/label_studio/tasks/mixins.py @@ -19,4 +19,4 @@ def post_process_bulk_update_stats(cls, tasks) -> None: class AnnotationMixin: def has_permission(self, user: "User") -> bool: """Called by Annotation#has_permission""" - return True \ No newline at end of file + return True diff --git a/label_studio/tasks/models.py b/label_studio/tasks/models.py index df5102037edd..845b94cdfa6b 100644 --- a/label_studio/tasks/models.py +++ b/label_studio/tasks/models.py @@ -22,7 +22,7 @@ string_is_url, temporary_disconnect_list_signal, ) -from core.utils.db import should_run_bulk_update_in_transaction +from core.utils.db import fast_first, should_run_bulk_update_in_transaction from core.utils.params import get_env from data_import.models import FileUpload from data_manager.managers import PreparedTaskManager, TaskManager @@ -189,13 +189,15 @@ def get_locked_by(cls, user, project=None, tasks=None): """Retrieve the task locked by specified user. Returns None if the specified user didn't lock anything.""" lock = None if project is not None: - lock = TaskLock.objects.filter( - user=user, expire_at__gt=now(), task__project=project - ).first() + lock = fast_first( + TaskLock.objects.filter( + user=user, expire_at__gt=now(), task__project=project + ) + ) elif tasks is not None: - locked_task = tasks.filter( - locks__user=user, locks__expire_at__gt=now() - ).first() + locked_task = fast_first( + tasks.filter(locks__user=user, locks__expire_at__gt=now()) + ) if locked_task: return locked_task else: @@ -356,10 +358,11 @@ def resolve_uri(self, task_data, project): prepared_filename ): # permission check: resolve uploaded files to the project only - file_upload = None - file_upload = FileUpload.objects.filter( - project=project, file=prepared_filename - ).first() + file_upload = fast_first( + FileUpload.objects.filter( + project=project, file=prepared_filename + ) + ) if file_upload is not None: if flag_set( "ff_back_dev_2915_storage_nginx_proxy_26092022_short",