Skip to content

Commit

Permalink
Blue
Browse files Browse the repository at this point in the history
  • Loading branch information
makseq committed Nov 11, 2024
1 parent faa7c28 commit 19d83e9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 36 deletions.
4 changes: 2 additions & 2 deletions label_studio/data_manager/actions/cache_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def cache_labels_job(project, queryset, **kwargs):
source_class = Annotation if source == 'annotations' else Prediction
control_tag = request_data.get('custom_control_tag') or request_data.get('control_tag')
with_counters = request_data.get('with_counters', 'Yes').lower() == 'yes'

if source == 'annotations':
column_name = 'cache'
else:
Expand Down Expand Up @@ -83,7 +83,7 @@ def cache_labels(project, queryset, request, **kwargs):
queryset,
organization_id=project.organization_id,
request_data=request.data,
job_timeout=60*60*5 # max allowed duration is 5 hours
job_timeout=60 * 60 * 5, # max allowed duration is 5 hours
)
return {'response_code': 200}

Expand Down
45 changes: 11 additions & 34 deletions label_studio/tests/data_manager/actions/test_cache_labels.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Tests for the cache_labels action."""

import pytest
from tasks.models import Task, Annotation, Prediction
from projects.models import Project
from data_manager.actions.cache_labels import cache_labels_job
from django.contrib.auth import get_user_model
from projects.models import Project
from tasks.models import Annotation, Prediction, Task


@pytest.mark.django_db
@pytest.mark.parametrize(
"source, control_tag, with_counters, expected_cache_column, use_predictions",
'source, control_tag, with_counters, expected_cache_column, use_predictions',
[
# Test case 1: Annotations, control tag 'ALL', with counters
('annotations', 'ALL', 'Yes', 'cache_all', False),
Expand All @@ -19,7 +19,7 @@
('annotations', 'ALL', 'No', 'cache_all', False),
# Test case 4: Predictions, control tag 'ALL', with counters
('predictions', 'ALL', 'Yes', 'cache_predictions_all', True),
]
],
)
def test_cache_labels_job(source, control_tag, with_counters, expected_cache_column, use_predictions):
# Initialize a test user and project
Expand All @@ -30,10 +30,7 @@ def test_cache_labels_job(source, control_tag, with_counters, expected_cache_col
# Create a few tasks
tasks = []
for i in range(3):
task = Task.objects.create(
project=project,
data={'text': f'This is task {i}'}
)
task = Task.objects.create(project=project, data={'text': f'This is task {i}'})
tasks.append(task)

# Add a few annotations or predictions to these tasks
Expand All @@ -43,30 +40,16 @@ def test_cache_labels_job(source, control_tag, with_counters, expected_cache_col
'from_name': 'label', # Control tag used in the result
'to_name': 'text',
'type': 'labels',
'value': {'labels': [f'Label_{i%2+1}']}
'value': {'labels': [f'Label_{i%2+1}']},
}
]
if use_predictions:
Prediction.objects.create(
task=task,
project=project,
result=result,
model_version='v1'
)
Prediction.objects.create(task=task, project=project, result=result, model_version='v1')
else:
Annotation.objects.create(
task=task,
project=project,
completed_by=test_user,
result=result
)
Annotation.objects.create(task=task, project=project, completed_by=test_user, result=result)

# Prepare the request data
request_data = {
'source': source,
'control_tag': control_tag,
'with_counters': with_counters
}
request_data = {'source': source, 'control_tag': control_tag, 'with_counters': with_counters}

# Get the queryset of tasks to process
queryset = Task.objects.filter(project=project)
Expand Down Expand Up @@ -96,18 +79,12 @@ def test_cache_labels_job(source, control_tag, with_counters, expected_cache_col
if control_tag == 'ALL' or control_tag == from_name:
value = result.get('value', {})
for key in value:
if (
isinstance(value[key], list)
and value[key]
and isinstance(value[key][0], str)
):
if isinstance(value[key], list) and value[key] and isinstance(value[key][0], str):
all_labels.extend(value[key])
break

if with_counters.lower() == 'yes':
expected_cache = ', '.join(
sorted([f'{label}: {all_labels.count(label)}' for label in set(all_labels)])
)
expected_cache = ', '.join(sorted([f'{label}: {all_labels.count(label)}' for label in set(all_labels)]))
else:
expected_cache = ', '.join(sorted(list(set(all_labels))))

Expand Down

0 comments on commit 19d83e9

Please sign in to comment.