Skip to content

Commit

Permalink
Add tests for search engine-based similarity search
Browse files Browse the repository at this point in the history
  • Loading branch information
ffont committed Feb 9, 2024
1 parent 820685e commit 199ce0f
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 12 deletions.
6 changes: 5 additions & 1 deletion sounds/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from utils.mail import send_mail_template
from utils.search import get_search_engine, SearchEngineException
from utils.search.search_sounds import delete_sounds_from_search_engine
from utils.similarity_utilities import delete_sound_from_gaia
from utils.similarity_utilities import delete_sound_from_gaia, get_similarity_search_target_vector
from utils.sound_upload import get_csv_lines, validate_input_csv_file, bulk_describe_from_csv

web_logger = logging.getLogger('web')
Expand Down Expand Up @@ -1372,6 +1372,10 @@ def ready_for_similarity(self):
else:
# If not using search engine based similarity, then use the old similarity_state DB field
return self.similarity_state == "OK"

def get_similarity_search_target_vector(self, analyzer=settings.SEARCH_ENGINE_DEFAULT_SIMILARITY_ANALYZER):
    """Return the vector to be used as similarity search target for this sound.

    If the sound has been analyzed for similarity, the vector is returned.
    This is a thin wrapper that forwards this sound's ID and the requested
    analyzer to the module-level ``get_similarity_search_target_vector``
    helper (imported from ``utils.similarity_utilities``) and returns its
    result unchanged.
    """
    # Inside the method body this name resolves to the module-level helper,
    # not to this method (class attributes are not in the method's scope).
    target_vector = get_similarity_search_target_vector(self.id, analyzer=analyzer)
    return target_vector

class Meta:
ordering = ("-created", )
Expand Down
17 changes: 7 additions & 10 deletions utils/search/backends/solr555pysolr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from utils.text import remove_control_chars
from utils.search import SearchEngineBase, SearchResults, SearchEngineException
from utils.search.backends.solr_common import SolrQuery, SolrResponseInterpreter
from utils.similarity_utilities import get_similarity_search_target_vector


SOLR_FORUM_URL = f"{settings.SOLR5_BASE_URL}/forum"
Expand Down Expand Up @@ -518,9 +519,9 @@ def search_sounds(self, textual_query='', query_fields=None, query_filter='', of

else:
# Similarity search!
query.set_query('')

# We first set an empty query that will return no results and will be used by default if similarity can't be performed
query.set_query('')
if similar_to_analyzer in settings.SEARCH_ENGINE_SIMILARITY_ANALYZERS:
# Similarity search will find documents close to a target vector. This will match "child" sound documents (of content_type "similarity vector")
vector = None
Expand All @@ -529,22 +530,18 @@ def search_sounds(self, textual_query='', query_fields=None, query_filter='', of
vector_field_name = SOLR_VECTOR_FIELDS_DIMENSIONS_MAP.get(len(vector), None)
else:
# similar_to should be a sound_id
sa = SoundAnalysis.objects.filter(sound_id=similar_to, analyzer=similar_to_analyzer, analysis_status="OK")
sound = Sound.objects.get(id=similar_to)
config_options = settings.SEARCH_ENGINE_SIMILARITY_ANALYZERS[similar_to_analyzer]
vector_field_name = SOLR_VECTOR_FIELDS_DIMENSIONS_MAP.get(config_options['vector_size'], None)
if sa.exists():
data = sa.first().get_analysis_data_from_file()
if data is not None:
vector_raw = data[config_options['vector_property_name']]
if vector_raw is not None:
vector = vector_raw[0:config_options['vector_size']]
vector = get_similarity_search_target_vector(sound.id, analyzer=similar_to_analyzer)
if vector is not None:
vector = vector[0:config_options['vector_size']] # Make sure the vector has the right size (just in case)

if vector is not None and vector_field_name is not None:
max_similar_sounds = similar_to_max_num_sounds # Max number of results for similarity search search. Filters are applied before the similarity search, so this number will usually be the total number of results (unless filters are more restrictive)
serialized_vector = ','.join([str(n) for n in vector])
query.set_query(f'{{!knn f={vector_field_name} topK={max_similar_sounds}}}[{serialized_vector}]')


# Process filter
query_filter = self.search_process_filter(query_filter,
only_sounds_within_ids=only_sounds_within_ids,
Expand All @@ -559,7 +556,7 @@ def search_sounds(self, textual_query='', query_fields=None, query_filter='', of
query_filter_modified.append(f'-_nest_parent_:{int(similar_to)}')
# Update the top_similar_sounds_as_filter so we compensate for the fact that we are removing the target sound from the results
top_similar_sounds_as_filter=top_similar_sounds_as_filter.replace(f'topK={similar_to_max_num_sounds}', f'topK={similar_to_max_num_sounds + 1}')
except ValueError:
except TypeError:
# Target is not a sound id, so we don't need to add the filter
pass

Expand Down
55 changes: 54 additions & 1 deletion utils/search/backends/test_search_engine_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
import time

from django.conf import settings
from tags.models import TaggedItem
from unittest import mock

import utils.search
from forum.models import Post
from sounds.models import Sound, Download
from tags.models import TaggedItem
from utils.search import get_search_engine


Expand Down Expand Up @@ -318,7 +319,58 @@ def sound_check_get_pack_tags(self, sounds):
if self.output_file:
self.output_file.write(f'\n* PACK "{pack.id}" TOP TAGS FROM SEARCH ENGINE: {search_engine_tags}\n')

@mock.patch('utils.search.backends.solr555pysolr.get_similarity_search_target_vector')
def sound_check_similarity_search(self, sounds, get_similarity_search_target_vector):
    """Check similarity search results against the fake per-sound vectors.

    The test setup indexes every sound with a 100-dimensional vector filled
    with the sound's own ID, so the sounds closest to a given sound are the
    ones with the nearest IDs. The helper that loads a sound's stored vector
    is patched (at the name used by the Solr backend module) so no analysis
    files are needed.

    :param sounds: sound objects that have been indexed for the test
    :param get_similarity_search_target_vector: mock injected by ``mock.patch``
    """
    # Sort by ID *before* deriving anything from sounds[0]: the mocked vector, the
    # query target and the expected results must all refer to the same sound (the one
    # with the lowest ID), otherwise the checks only pass if the caller pre-sorted.
    sounds = sorted(sounds, key=lambda x: x.id)
    target_sound = sounds[0]
    target_sound_vector = [target_sound.id] * 100
    sorted_ids = [s.id for s in sounds]

    # The backend asks this helper for the vector of the target sound ID
    get_similarity_search_target_vector.return_value = list(target_sound_vector)

    # Make a query for the target sound and check that results are sorted by ID (as expected
    # because sound similarity vectors are set to their ID). The target sound itself is
    # removed from the results.
    results = self.run_sounds_query_and_save_results(
        dict(similar_to=target_sound.id, similar_to_max_num_sounds=10, similar_to_analyzer='test_analyzer'))
    results_ids = [r['id'] for r in results.docs]
    expected_ids = sorted_ids[1:11]  # target sound is not expected to be in results
    assert_and_continue(results_ids == expected_ids, 'Similarity search did not return sounds sorted as expected when searching with a target sound ID')

    # Now make the same query but passing an arbitrary vector (which happens to be the same as
    # the one for the target sound). Now the target sound should also be included in the
    # results as the closest one.
    results = self.run_sounds_query_and_save_results(
        dict(similar_to=target_sound_vector, similar_to_max_num_sounds=10, similar_to_analyzer='test_analyzer'))
    results_ids = [r['id'] for r in results.docs]
    expected_ids = sorted_ids[0:10]  # target sound is expected to be in results
    assert_and_continue(results_ids == expected_ids, 'Similarity search did not return sounds sorted as expected when searching with a target vector')

    # Requesting sounds for a nonexistent analyzer should return 0 results
    results = self.run_sounds_query_and_save_results(
        dict(similar_to=target_sound_vector, similar_to_max_num_sounds=10, similar_to_analyzer='test_analyzer2'))
    assert_and_continue(len(results.docs) == 0, 'Similarity search returned results for an unexsiting analyzer')

    # Check the similar_to_max_num_sounds parameter
    results = self.run_sounds_query_and_save_results(
        dict(similar_to=target_sound_vector, similar_to_max_num_sounds=5, similar_to_analyzer='test_analyzer'))
    assert_and_continue(len(results.docs) == 5, 'Similarity search returned unexpected number of results')


def test_search_enginge_backend_sounds(self):
# Monkey patch 'add_similarity_vectors_to_documents' from search engine so we add fake similarity vectors
# to our testing core. Also override some settings so similarity search works in the test environment.
def patched_add_similarity_vectors_to_documents(sound_objects, documents):
    """Attach one fake similarity vector entry to each document, in place.

    Test-only replacement for the search engine method of the same name.
    ``sound_objects`` is accepted to keep the same call signature but is not
    used. Every document gets a single-element ``similarity_vectors`` list
    whose vector is the document's own ID repeated 100 times, so distances
    between documents are easy to predict in later checks.
    """
    for doc in documents:
        fake_vector = [doc['id']] * 100  # Fake vector built from the sound ID
        vector_entry = {
            'content_type': 'v',  # Content type for similarity vectors
            'analyzer': 'test_analyzer',
            'timestamp_start': 0,
            'timestamp_end': -1,
            'sim_vector100': fake_vector,
        }
        doc['similarity_vectors'] = [vector_entry]
self.search_engine.add_similarity_vectors_to_documents = patched_add_similarity_vectors_to_documents
settings.SEARCH_ENGINE_SIMILARITY_ANALYZERS = {
'test_analyzer': {
'vector_property_name': 'embeddings',
'vector_size': 100,
}
}
settings.SEARCH_ENGINE_DEFAULT_SIMILARITY_ANALYZER = 'test_analyzer'

# Get sounds for testing
test_sound_ids = list(Sound.public
.filter(is_index_dirty=False, num_ratings__gt=settings.MIN_NUMBER_RATINGS)
Expand Down Expand Up @@ -404,6 +456,7 @@ def test_search_enginge_backend_sounds(self):
self.sound_check_extra_queries()
self.sound_check_get_user_tags(sounds[0])
self.sound_check_get_pack_tags(sounds)
self.sound_check_similarity_search(sounds)

console_logger.info('Testing of sound search methods finished!')

Expand Down
15 changes: 15 additions & 0 deletions utils/similarity_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from similarity.client import Similarity
from similarity.similarity_settings import PRESETS, DEFAULT_PRESET, SIMILARITY_CACHE_TIME
import sounds
from utils.encryption import create_hash

web_logger = logging.getLogger('web')
Expand Down Expand Up @@ -167,3 +168,17 @@ def delete_sound_from_gaia(sound_id):

def hash_cache_key(key):
return create_hash(key, limit=32)


def get_similarity_search_target_vector(sound_id, analyzer=settings.SEARCH_ENGINE_DEFAULT_SIMILARITY_ANALYZER):
    """Return the vector to be used for similarity search for a sound, or None.

    Looks for a ``SoundAnalysis`` of the given sound and analyzer with status
    "OK", loads its analysis data from file and returns the value stored under
    the analyzer's configured ``vector_property_name``.

    NOTE: the default value of ``analyzer`` is evaluated once, when this module
    is imported; changing the setting afterwards (e.g. in tests) does not change
    the default, so pass ``analyzer`` explicitly in that case.

    :param sound_id: ID of the sound whose vector is requested
    :param analyzer: analyzer name, a key of settings.SEARCH_ENGINE_SIMILARITY_ANALYZERS
    :return: the raw vector as stored in the analysis data, or None if there is
        no "OK" analysis, no analysis data, or the stored value is None
    :raises KeyError: if an analysis exists but the analyzer is not configured in
        settings, or the configured property is missing from the analysis data
    """
    # A single first() query replaces the previous exists() + exists() + first()
    # sequence; first() already returns None when nothing matches.
    analysis = sounds.models.SoundAnalysis.objects.filter(
        sound_id=sound_id, analyzer=analyzer, analysis_status="OK").first()
    if analysis is None:
        return None

    # Looked up before reading the file, as before, so a misconfigured analyzer
    # fails early without touching the analysis data.
    config_options = settings.SEARCH_ENGINE_SIMILARITY_ANALYZERS[analyzer]
    data = analysis.get_analysis_data_from_file()
    if data is None:
        return None

    # May itself be None, which callers treat as "no vector available"
    return data[config_options['vector_property_name']]

0 comments on commit 199ce0f

Please sign in to comment.