diff --git a/label_studio/data_manager/actions/cache_labels.py b/label_studio/data_manager/actions/cache_labels.py index fc1fe1f54d4..af8aa893f76 100644 --- a/label_studio/data_manager/actions/cache_labels.py +++ b/label_studio/data_manager/actions/cache_labels.py @@ -18,7 +18,7 @@ def cache_labels_job(project, queryset, **kwargs): source_class = Annotation if source == 'annotations' else Prediction control_tag = request_data.get('custom_control_tag') or request_data.get('control_tag') with_counters = request_data.get('with_counters', 'Yes').lower() == 'yes' - + if source == 'annotations': column_name = 'cache' else: @@ -83,7 +83,7 @@ def cache_labels(project, queryset, request, **kwargs): queryset, organization_id=project.organization_id, request_data=request.data, - job_timeout=60*60*5 # max allowed duration is 5 hours + job_timeout=60 * 60 * 5, # max allowed duration is 5 hours ) return {'response_code': 200} diff --git a/label_studio/tests/data_manager/actions/test_cache_labels.py b/label_studio/tests/data_manager/actions/test_cache_labels.py index 7df31c67b73..92258d2cd63 100644 --- a/label_studio/tests/data_manager/actions/test_cache_labels.py +++ b/label_studio/tests/data_manager/actions/test_cache_labels.py @@ -1,15 +1,15 @@ """Tests for the cache_labels action.""" import pytest -from tasks.models import Task, Annotation, Prediction -from projects.models import Project from data_manager.actions.cache_labels import cache_labels_job from django.contrib.auth import get_user_model +from projects.models import Project +from tasks.models import Annotation, Prediction, Task @pytest.mark.django_db @pytest.mark.parametrize( - "source, control_tag, with_counters, expected_cache_column, use_predictions", + 'source, control_tag, with_counters, expected_cache_column, use_predictions', [ # Test case 1: Annotations, control tag 'ALL', with counters ('annotations', 'ALL', 'Yes', 'cache_all', False), @@ -19,7 +19,7 @@ ('annotations', 'ALL', 'No', 'cache_all', False), # Test case 4: Predictions, control tag 'ALL', with counters ('predictions', 'ALL', 'Yes', 'cache_predictions_all', True), - ] + ], ) def test_cache_labels_job(source, control_tag, with_counters, expected_cache_column, use_predictions): # Initialize a test user and project @@ -30,10 +30,7 @@ def test_cache_labels_job(source, control_tag, with_counters, expected_cache_col # Create a few tasks tasks = [] for i in range(3): - task = Task.objects.create( - project=project, - data={'text': f'This is task {i}'} - ) + task = Task.objects.create(project=project, data={'text': f'This is task {i}'}) tasks.append(task) # Add a few annotations or predictions to these tasks @@ -43,30 +40,16 @@ def test_cache_labels_job(source, control_tag, with_counters, expected_cache_col 'from_name': 'label', # Control tag used in the result 'to_name': 'text', 'type': 'labels', - 'value': {'labels': [f'Label_{i%2+1}']} + 'value': {'labels': [f'Label_{i%2+1}']}, } ] if use_predictions: - Prediction.objects.create( - task=task, - project=project, - result=result, - model_version='v1' - ) + Prediction.objects.create(task=task, project=project, result=result, model_version='v1') else: - Annotation.objects.create( - task=task, - project=project, - completed_by=test_user, - result=result - ) + Annotation.objects.create(task=task, project=project, completed_by=test_user, result=result) # Prepare the request data - request_data = { - 'source': source, - 'control_tag': control_tag, - 'with_counters': with_counters - } + request_data = {'source': source, 'control_tag': control_tag, 'with_counters': with_counters} # Get the queryset of tasks to process queryset = Task.objects.filter(project=project) @@ -96,18 +79,12 @@ def test_cache_labels_job(source, control_tag, with_counters, expected_cache_col if control_tag == 'ALL' or control_tag == from_name: value = result.get('value', {}) for key in value: - if ( - isinstance(value[key], list) - and value[key] - and isinstance(value[key][0], str) - ): + if isinstance(value[key], list) and value[key] and isinstance(value[key][0], str): all_labels.extend(value[key]) break if with_counters.lower() == 'yes': - expected_cache = ', '.join( - sorted([f'{label}: {all_labels.count(label)}' for label in set(all_labels)]) - ) + expected_cache = ', '.join(sorted([f'{label}: {all_labels.count(label)}' for label in set(all_labels)])) else: expected_cache = ', '.join(sorted(list(set(all_labels))))