From 1bc914ab71caa254ccc1e4c9c259fdd3ba4129d8 Mon Sep 17 00:00:00 2001 From: Edward Hope-Morley Date: Thu, 30 Jan 2025 20:47:06 +0000 Subject: [PATCH] performance improvements --- requirements.txt | 1 + searchkit/result.py | 15 +-- searchkit/search.py | 205 +++++++++++++++++++++++++------------- searchkit/task.py | 35 ++++--- tests/unit/test_search.py | 14 +-- 5 files changed, 170 insertions(+), 100 deletions(-) diff --git a/requirements.txt b/requirements.txt index a063978..3fa14b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ fasteners +psutil diff --git a/searchkit/result.py b/searchkit/result.py index f40bc38..7a7388e 100644 --- a/searchkit/result.py +++ b/searchkit/result.py @@ -57,12 +57,15 @@ def get(self, field): def __iter__(self): """ Only return part values when iterating over this object. """ for part in self.data: - yield self.results_store[part[self.PART_OFFSET_VALUE]] + yield self.results_store.get(part[self.PART_OFFSET_VALUE]) def __repr__(self): - r_list = [f"{rp[self.PART_OFFSET_IDX]}=" - f"'{self.results_store[rp[self.PART_OFFSET_VALUE]]}'" - for rp in self.data] + if self.results_store is None: + r_list = [] + else: + r_list = [f"{rp[self.PART_OFFSET_IDX]}=" + f"'{self.results_store.get(rp[self.PART_OFFSET_VALUE])}'" + for rp in self.data] return (f"ln:{self.linenumber} {', '.join(r_list)} " f"(section={self.section_id})") @@ -111,7 +114,7 @@ def tag(self): if idx is None: return None - return self.results_store.tag_store[idx] + return self.results_store.get(idx) @property def sequence_id(self): @@ -119,7 +122,7 @@ def sequence_id(self): if idx is None: return None - return self.results_store.sequence_id_store[idx] + return self.results_store.get(idx) def register_results_store(self, store): """ diff --git a/searchkit/search.py b/searchkit/search.py index 41500ad..ff9d1dd 100755 --- a/searchkit/search.py +++ b/searchkit/search.py @@ -17,8 +17,11 @@ import subprocess import threading import time -from collections import namedtuple, UserDict, UserList +from collections import namedtuple, UserDict +from datetime import datetime, timedelta +from functools import cached_property +import psutil from searchkit.log import log from searchkit.constraints import CouldNotApplyConstraint from searchkit.exception import FileSearchException @@ -32,6 +35,7 @@ RESULTS_QUEUE_TIMEOUT = 60 MAX_QUEUE_RETRIES = 10 +RESULTS_LOCK = multiprocessing.Lock() RS_LOCK = multiprocessing.Lock() NUM_BUFFERED_RESULTS = 100 @@ -44,36 +48,31 @@ def _rs_locked_inner(*args, **kwargs): return _rs_locked_inner -class ResultStoreBase(UserList): +class ResultStoreBase(UserDict): """ This class is used to de-duplicate values stored in search results such that allowing their reference to be saved in the result for later lookup. 
""" - def __init__(self): + def __init__(self, f_allocator=None): super().__init__() - self.counters = {} - self.value_store = self.data = [] - self.tag_store = [] - self.sequence_id_store = [] + self.counter = {} + self.data = {} + self.value_store = {} + self.tag_store = {} + self.sequence_id_store = {} + self.f_allocator = f_allocator or self._get_next_available - def __getitem__(self, result_id): - if result_id >= len(self.value_store): - return None - - return self.value_store[result_id] - - @property - def parts_deduped(self): - counters = self.counters.values() - return sum(counters) - len(counters) + def _get_next_available(self, value): + if value in self.counter: + return self.counter[value] - @property - def parts_non_deduped(self): - return len(self.value_store) + current = len(self.counter) + self.counter[value] = current + # log.debug(f"{os.getpid()} allocated {current}") + return current - @staticmethod - def _get_store_index(value, store): + def _get_store_index(self, value, store, idx=None): """ Add value to the provided store and return its position. If the value is None do not save in the store and return None. @@ -85,10 +84,17 @@ def _get_store_index(value, store): return None if value in store: - return store.index(value) + return store[value] - store.append(value) - return len(store) - 1 + if idx is None: + idx = self.f_allocator(value) + + store[value] = idx + self.data[idx] = value + return idx + + def sync(self): + """ Only required for parallel. """ def add(self, tag, sequence_id, value): """ @@ -104,13 +110,6 @@ def add(self, tag, sequence_id, value): @param value: search result value """ value_idx = self._get_store_index(value, self.value_store) - if value_idx is not None: - # increment global counter - if value_idx not in self.counters: - self.counters[value_idx] = 1 - else: - self.counters[value_idx] += 1 - tag_idx = self._get_store_index(tag, self.tag_store) sequence_id_idx = self._get_store_index(sequence_id, self.sequence_id_store) @@ -127,28 +126,33 @@ class ResultStoreParallel(ResultStoreBase): def __init__(self, mgr): super().__init__() # Replace super attributes with MP-safe equivalents - self.counters = mgr.dict() - self.value_store = self.data = mgr.list() - self.tag_store = mgr.list() - self.sequence_id_store = mgr.list() + self.counter = mgr.dict() + self.data = mgr.dict() + self.value_store = mgr.dict() + self.tag_store = mgr.dict() + self.sequence_id_store = mgr.dict() - @rs_locked - def __getitem__(self, result_id): - return super().__getitem__(result_id) + def _get_next_available(self, value): + with RS_LOCK: + return super()._get_next_available(value) + + @cached_property + def local(self): + return ResultStoreSimple(f_allocator=self._get_next_available) - @rs_locked def add(self, *args, **kwargs): - return super().add(*args, **kwargs) + return self.local.add(*args, **kwargs) - @property @rs_locked - def parts_deduped(self): - return super().parts_deduped + def sync(self): + for value, idx in self.local.value_store.items(): + self._get_store_index(value, self.value_store, idx=idx) - @property - @rs_locked - def parts_non_deduped(self): - return super().parts_non_deduped + for value, idx in self.local.tag_store.items(): + self._get_store_index(value, self.tag_store, idx=idx) + + for value, idx in self.local.sequence_id_store.items(): + self._get_store_index(value, self.sequence_id_store, idx=idx) @rs_locked def unproxy_results(self): @@ -160,7 +164,6 @@ def unproxy_results(self): self.value_store = self.data = copy.deepcopy(self.data) 
self.tag_store = copy.deepcopy(self.tag_store) self.sequence_id_store = copy.deepcopy(self.sequence_id_store) - self.counters = self.counters.copy() class ResultFieldInfo(UserDict): @@ -663,31 +666,71 @@ def stats(self): return self._stats @staticmethod - def _get_results(results, results_queue, event, stats): + def _get_info(results, results_store, event, stats): """ Collect results from all search task processes. @param results: SearchResultsCollection object. - @param results_queue: results queue used for this search session. @param event: event object used to notify this thread to stop. @param stats: SearchTaskStats object """ + proc = psutil.Process() + while True: + then = datetime.now() + if event.is_set(): + log.debug("exiting info thread") + break + + then = None + with RESULTS_LOCK: + if not then or datetime.now() == then + timedelta(seconds=5): + with RS_LOCK: + allocations = len(results_store.counter) + + rss = int(proc.memory_info().rss / 1024 ** 2) + log.debug("total %s results received (rss=%sM, " + "store_allocations=%s), " + "%s/%s jobs completed - " + "waiting for more", len(results), rss, + allocations, + stats['jobs_completed'], stats['total_jobs']) + + then = datetime.now() + + time.sleep(0.1) + + @staticmethod + def _get_results(results, results_queue, event): + """ + Collect results from all search task processes. + + @param results: SearchResultsCollection object. + @param results_queue: results queue used for this search session. + @param event: event object used to notify this thread to stop. + """ log.debug("fetching results from worker queues") while True: if not results_queue.empty(): - results.add(results_queue.get()) + _results = results_queue.get() + # log.debug(f"received {sys.getsizeof(_results) / 1024 ** 2} " + # "Mbytes") + with RESULTS_LOCK: + for r in _results: + results.add(r) + + # if len(results) % 1000 == 0: + # log.debug("received %s results", len(results)) elif event.is_set(): log.debug("exiting results thread") break else: - log.debug("total %s results received, %s/%s jobs completed - " - "waiting for more", len(results), - stats['jobs_completed'], stats['total_jobs']) # yield time.sleep(0.1) - log.debug("stopped fetching results (total received=%s)", len(results)) + with RESULTS_LOCK: + log.debug("stopped fetching results (total received=%s)", + len(results)) @staticmethod def _purge_results(results, results_queue, expected): @@ -704,11 +747,15 @@ def _purge_results(results, results_queue, expected): while True: if not results_queue.empty(): - results.add(results_queue.get()) + _results = results_queue.get() + for r in _results: + results.add(r) + elif expected > len(results): try: - r = results_queue.get(timeout=RESULTS_QUEUE_TIMEOUT) - results.add(r) + _results = results_queue.get(timeout=RESULTS_QUEUE_TIMEOUT) + for r in _results: + results.add(r) except queue.Empty: log.info("timeout waiting > %s secs to receive results - " "expected=%s, actual=%s", RESULTS_QUEUE_TIMEOUT, @@ -719,20 +766,28 @@ def _purge_results(results, results_queue, expected): log.debug("stopped purging results (total received=%s)", len(results)) - def _create_results_thread(self, results, results_queue, stats): + def _create_results_thread(self, results, results_queue): log.debug("creating results queue consumer thread") event = threading.Event() event.clear() t = threading.Thread(target=self._get_results, - args=[results, results_queue, event, stats]) + args=[results, results_queue, event]) + return t, event + + def _create_info_thread(self, results, results_store, 
stats): + log.debug("creating info thread") + event = threading.Event() + event.clear() + t = threading.Thread(target=self._get_info, + args=[results, results_store, event, stats]) return t, event @staticmethod - def _stop_results_thread(thread, event): - log.debug("joining/stopping queue consumer thread") + def _stop_thread(ttype, thread, event): + log.debug("joining/stopping queue %s thread", ttype) event.set() thread.join() - log.debug("consumer thread stopped successfully") + log.debug("%s thread stopped successfully", ttype) @staticmethod def _ensure_worker_processes_killed(): @@ -792,9 +847,12 @@ def _run_mp(self, mgr, results, results_store): # noqa,pylint: disable=too-many @param results: SearchResultsCollection object """ results_queue = mgr.Queue() + info_thread, info_event = self._create_info_thread(results, + results_store, + self.stats) results_thread, event = self._create_results_thread(results, - results_queue, - self.stats) + results_queue) + info_thread_started = False results_thread_started = False results_manager = SearchTaskResultsManager( results_store, @@ -815,8 +873,10 @@ def _run_mp(self, mgr, results, results_store): # noqa,pylint: disable=too-many self.stats['total_jobs'] += 1 log.debug("filesearcher: syncing %s job(s)", len(jobs)) + info_thread.start() results_thread.start() results_thread_started = True + info_thread_started = True try: for future in concurrent.futures.as_completed(jobs): self.stats.update(future.result()) @@ -835,7 +895,9 @@ def _run_mp(self, mgr, results, results_store): # noqa,pylint: disable=too-many log.info("job for path '%s' still running when " "not expected to be", path) - self._stop_results_thread(results_thread, event) + self._stop_thread('info', info_thread, info_event) + info_thread = None + self._stop_thread('results', results_thread, event) results_thread = None log.debug("purging remaining results (expected=%s, " "remaining=%s)", self.stats['results'], @@ -847,7 +909,10 @@ def _run_mp(self, mgr, results, results_store): # noqa,pylint: disable=too-many log.debug("terminating pool") finally: if results_thread is not None and results_thread_started: - self._stop_results_thread(results_thread, event) + self._stop_thread('results', results_thread, event) + + if info_thread is not None and info_thread_started: + self._stop_thread('info', info_thread, info_event) def run(self): """ Run all searches. @@ -871,16 +936,12 @@ def run(self): rs = ResultStoreParallel(mgr) results = SearchResultsCollection(self.catalog, rs) self._run_mp(mgr, results, rs) - self.stats['parts_deduped'] = rs.parts_deduped - self.stats['parts_non_deduped'] = rs.parts_non_deduped rs.unproxy_results() else: log.debug("running searches (parallel=False)") rs = ResultStoreSimple() results = SearchResultsCollection(self.catalog, rs) self._run_single(results, rs) - self.stats['parts_deduped'] = rs.parts_deduped - self.stats['parts_non_deduped'] = rs.parts_non_deduped log.debug("filesearcher: completed (%s)", self.stats) return results diff --git a/searchkit/task.py b/searchkit/task.py index 3e4083f..eb50453 100644 --- a/searchkit/task.py +++ b/searchkit/task.py @@ -1,12 +1,12 @@ """ Search task implementations. 
""" import gzip -import multiprocessing import os import queue import uuid from functools import cached_property from collections import UserDict +import psutil from searchkit.log import log from searchkit.result import ( SearchResult, @@ -17,8 +17,7 @@ RESULTS_QUEUE_TIMEOUT = 60 MAX_QUEUE_RETRIES = 10 -RS_LOCK = multiprocessing.Lock() -NUM_BUFFERED_RESULTS = 100 +NUM_BUFFERED_RESULTS = 100000 class SearchTaskError(Exception): @@ -76,6 +75,7 @@ def __init__(self, info, constraints_manager, results_manager, @param results_manager: SearchTaskResultsManager object @param decode_errors: unicode decode error handling. """ + self.proc = None self.info = info self.stats = SearchTaskStats() self.constraints_manager = constraints_manager @@ -103,20 +103,22 @@ def search_defs(self): return alldefs - def put_result(self, result): - self.stats['results'] += 1 + def put_result(self, results): + self.stats['results'] += len(results) if self.results_manager.results_collection is not None: - self.results_manager.results_collection.add(result) + for result in results: + self.results_manager.results_collection.add(result) + return max_tries = MAX_QUEUE_RETRIES while max_tries > 0: try: if max_tries == MAX_QUEUE_RETRIES: - self.results_manager.results_queue.put_nowait(result) + self.results_manager.results_queue.put_nowait(results) else: self.results_manager.results_queue.put( - result, + results, timeout=RESULTS_QUEUE_TIMEOUT) break @@ -137,12 +139,19 @@ def put_result(self, result): log.error("exceeded max number of retries (%s) to put results " "data on the queue", MAX_QUEUE_RETRIES) + def show_mem_usage(self, label): + if self.proc is None: + self.proc = psutil.Process() + log.debug('%s (rss=%sM)', label, + int(self.proc.memory_info().rss / 1024 ** 2)) + def _flush_results_buffer(self): # log.debug("flushing results buffer (%s)", len(self.buffered_results)) - for result in self.buffered_results: - self.put_result(result) + # self.show_mem_usage('before') + self.put_result(self.buffered_results) self.buffered_results = [] + # self.show_mem_usage('after') def _simple_search(self, search_def, line, ln): """ Perform a simple search on line. 
@@ -286,10 +295,10 @@ def _run_search(self, fd): for s, _runnable in self.search_defs.items()} ln = 0 # NOTE: line numbers start at 1 hence offset + 1 - for ln, line in enumerate(fd, start=offset + 1): + for ln, line in enumerate(fd, start=1): # This could be helpful to show progress for large files if ln % 100000 == 0: - log.debug("%s lines searched in %s", ln, fd.name) + self.show_mem_usage(f"{ln} lines searched in {fd.name}") self.stats['lines_searched'] += 1 line = line.decode("utf-8", **self.decode_kwargs) @@ -366,6 +375,8 @@ def execute(self): f"- {e}") raise FileSearchException(msg) from e + self.results_manager.results_store.sync() + log.debug("finished execution on path %s", path) return stats diff --git a/tests/unit/test_search.py b/tests/unit/test_search.py index 45650e3..562684f 100644 --- a/tests/unit/test_search.py +++ b/tests/unit/test_search.py @@ -356,9 +356,6 @@ def test_large_sequence_search(self): finally: shutil.rmtree(dtmp) - self.assertEqual(f.stats['parts_deduped'], 40037) - self.assertEqual(f.stats['parts_non_deduped'], 3) - self.assertEqual(len(results), 40040) self.assertEqual(len(results.find_by_tag('simple')), 20000) self.assertEqual(len(results.find_sequence_by_tag('myseq')), 20) @@ -924,15 +921,12 @@ def test_search_result_index(self): for val in ['foo', 'bar', 'foo']: sri.add('atag', None, val) - self.assertEqual(sri, ['foo', 'bar']) - self.assertEqual(sri.counters, {0: 2, 1: 1}) - self.assertEqual(sri.tag_store, ['atag']) + self.assertEqual(list(sri.values()), ['foo', 'atag', 'bar']) + self.assertEqual(list(sri.tag_store.values()), [1]) self.assertEqual(sri[0], 'foo') - self.assertEqual(sri[1], 'bar') - self.assertEqual(sri[2], None) - self.assertEqual(sri.parts_deduped, 1) - self.assertEqual(sri.parts_non_deduped, 2) + self.assertEqual(sri[2], 'bar') + self.assertEqual(sri.get(109238), None) def test_search_unicode_decode_w_error(self): f = FileSearcher()
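
The core of the change above is how result parts are de-duplicated. Previously ResultStoreBase was list-backed, so every add() paid a linear scan of the store (and, in the parallel case, a manager-proxy round trip under RS_LOCK). The patch switches to dict-backed stores keyed by value, and in ResultStoreParallel each worker writes to a private `local` store, only taking the shared lock to allocate an id for a value it has not seen before; the local mappings are pushed into the shared manager dicts once per task by sync(), which SearchTask.execute() now calls when it finishes. Below is a minimal, self-contained sketch of that pattern, assuming only the standard multiprocessing module; the names (DedupStore, SharedDedupStore, by_value, by_id) are illustrative and are not the searchkit API.

    # Sketch of a per-worker de-duplicating store with deferred sync.
    import multiprocessing


    class DedupStore:
        """Map values to stable integer ids; each distinct value is stored once."""

        def __init__(self, allocator=None):
            self.by_value = {}               # value -> id
            self.by_id = {}                  # id -> value
            self.allocator = allocator or self._next_id

        def _next_id(self, value):
            # Default allocator: ids are positions in insertion order.
            return len(self.by_value)

        def add(self, value):
            if value in self.by_value:
                return self.by_value[value]  # dict lookup, not a list scan

            idx = self.allocator(value)
            self.by_value[value] = idx
            self.by_id[idx] = value
            return idx


    class SharedDedupStore:
        """Parallel variant: ids come from a shared counter, values are kept
        locally and only copied into the shared (manager) dict on sync()."""

        def __init__(self, mgr):
            self.lock = multiprocessing.Lock()
            self.counter = mgr.dict()        # value -> id (global allocations)
            self.shared = mgr.dict()         # id -> value (global result store)
            # In the real patch each worker process ends up with its own
            # local store; here a single one is enough to show the flow.
            self.local = DedupStore(allocator=self._allocate)

        def _allocate(self, value):
            # The shared lock is only taken when this worker sees a value for
            # the first time; repeats are resolved from the local dict.
            with self.lock:
                if value in self.counter:
                    return self.counter[value]

                idx = len(self.counter)
                self.counter[value] = idx
                return idx

        def add(self, value):
            return self.local.add(value)

        def sync(self):
            # One bulk update per task instead of a proxy round trip per part.
            with self.lock:
                self.shared.update(self.local.by_id)


    if __name__ == "__main__":
        mgr = multiprocessing.Manager()
        store = SharedDedupStore(mgr)
        ids = [store.add(v) for v in ("foo", "bar", "foo")]
        store.sync()
        print(ids, dict(store.shared))       # [0, 1, 0] {0: 'foo', 1: 'bar'}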
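
The second part of the change batches result transfer: SearchTask accumulates results in buffered_results and put_result() now sends the whole list in a single queue operation (NUM_BUFFERED_RESULTS goes from 100 to 100000), while _get_results() on the consumer side iterates over each received batch; a separate info thread periodically logs progress and process RSS via psutil. A small sketch of the batching idea follows, with hypothetical names (BufferedSender, consume) and an illustrative flush threshold rather than the patch's value.

    # Sketch of batched producer/consumer transfer over a results queue.
    import queue

    NUM_BUFFERED_RESULTS = 1000              # illustrative flush threshold


    class BufferedSender:
        """Buffer results in the worker and send them in batches so the
        pickling/IPC cost of put() is paid per batch, not per result."""

        def __init__(self, results_queue):
            self.results_queue = results_queue
            self.buffer = []

        def add(self, result):
            self.buffer.append(result)
            if len(self.buffer) >= NUM_BUFFERED_RESULTS:
                self.flush()

        def flush(self):
            if self.buffer:
                self.results_queue.put(self.buffer)   # one put per batch
                self.buffer = []


    def consume(results_queue, sink, stop_event):
        """Consumer: drain batches until the queue is empty and we are told
        to stop. results_queue can be a multiprocessing or manager queue."""
        while True:
            try:
                sink.extend(results_queue.get(timeout=0.1))
            except queue.Empty:
                if stop_event.is_set():
                    break


    if __name__ == "__main__":
        import threading

        q = queue.Queue()                    # stands in for the manager queue
        stop = threading.Event()
        out = []
        consumer = threading.Thread(target=consume, args=(q, out, stop))
        consumer.start()

        sender = BufferedSender(q)
        for i in range(2500):
            sender.add(i)
        sender.flush()                       # final partial batch
        stop.set()
        consumer.join()
        print(len(out))                      # 2500

Batching trades a little extra memory in each worker for far fewer queue operations, which is where a put-per-result design spends much of its time on pickling and synchronisation.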