diff --git a/CHANGELOG.md b/CHANGELOG.md index 92b4c5240..10c759edd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Changed - Migrate docs from `https://docs.datajoint.org/python` to `https://datajoint.com/docs/core/datajoint-python` - Fixed - Updated set_password to work on MySQL 8 - PR [#1106](https://github.com/datajoint/datajoint-python/pull/1106) - Added - Missing tests for set_password - PR [#1106](https://github.com/datajoint/datajoint-python/pull/1106) +- Changed - Returning success count after the .populate() call - PR [#1050](https://github.com/datajoint/datajoint-python/pull/1050) ### 0.14.1 -- Jun 02, 2023 - Fixed - Fix altering a part table that uses the "master" keyword - PR [#991](https://github.com/datajoint/datajoint-python/pull/991) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 1efd557ff..ccd436554 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -180,6 +180,9 @@ def populate( to be passed down to each ``make()`` call. Computation arguments should be specified within the pipeline e.g. using a `dj.Lookup` table. :type make_kwargs: dict, optional + :return: a dict with two keys + "success_count": the count of successful ``make()`` calls in this ``populate()`` call + "error_list": the error list that is filled if `suppress_errors` is True """ if self.connection.in_transaction: raise DataJointError("Populate cannot be called during a transaction.") @@ -222,49 +225,62 @@ def handler(signum, frame): keys = keys[:max_calls] nkeys = len(keys) - if not nkeys: - return - - processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _) error_list = [] - populate_kwargs = dict( - suppress_errors=suppress_errors, - return_exception_objects=return_exception_objects, - make_kwargs=make_kwargs, - ) + success_list = [] - if processes == 1: - for key in ( - tqdm(keys, desc=self.__class__.__name__) if display_progress else keys - ): - error = self._populate1(key, jobs, **populate_kwargs) - if error is not None: - error_list.append(error) - else: - # spawn multiple processes - self.connection.close() # disconnect parent process from MySQL server - del self.connection._conn.ctx # SSLContext is not pickleable - with mp.Pool( - processes, _initialize_populate, (self, jobs, populate_kwargs) - ) as pool, ( - tqdm(desc="Processes: ", total=nkeys) - if display_progress - else contextlib.nullcontext() - ) as progress_bar: - for error in pool.imap(_call_populate1, keys, chunksize=1): - if error is not None: - error_list.append(error) - if display_progress: - progress_bar.update() - self.connection.connect() # reconnect parent process to MySQL server + if nkeys: + processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _) + + populate_kwargs = dict( + suppress_errors=suppress_errors, + return_exception_objects=return_exception_objects, + make_kwargs=make_kwargs, + ) + + if processes == 1: + for key in ( + tqdm(keys, desc=self.__class__.__name__) + if display_progress + else keys + ): + status = self._populate1(key, jobs, **populate_kwargs) + if status is True: + success_list.append(1) + elif isinstance(status, tuple): + error_list.append(status) + else: + assert status is False + else: + # spawn multiple processes + self.connection.close() # disconnect parent process from MySQL server + del self.connection._conn.ctx # SSLContext is not pickleable + with mp.Pool( + processes, _initialize_populate, (self, jobs, populate_kwargs) + ) as pool, ( + tqdm(desc="Processes: ", total=nkeys) + if display_progress + else contextlib.nullcontext() + ) as progress_bar: + for status in pool.imap(_call_populate1, keys, chunksize=1): + if status is True: + success_list.append(1) + elif isinstance(status, tuple): + error_list.append(status) + else: + assert status is False + if display_progress: + progress_bar.update() + self.connection.connect() # reconnect parent process to MySQL server # restore original signal handler: if reserve_jobs: signal.signal(signal.SIGTERM, old_handler) - if suppress_errors: - return error_list + return { + "success_count": sum(success_list), + "error_list": error_list, + } def _populate1( self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None @@ -275,55 +291,60 @@ def _populate1( :param key: dict specifying job to populate :param suppress_errors: bool if errors should be suppressed and returned :param return_exception_objects: if True, errors must be returned as objects - :return: (key, error) when suppress_errors=True, otherwise None + :return: (key, error) when suppress_errors=True, + True if successfully invoke one `make()` call, otherwise False """ make = self._make_tuples if hasattr(self, "_make_tuples") else self.make - if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)): - self.connection.start_transaction() - if key in self.target: # already populated + if jobs is not None and not jobs.reserve( + self.target.table_name, self._job_key(key) + ): + return False + + self.connection.start_transaction() + if key in self.target: # already populated + self.connection.cancel_transaction() + if jobs is not None: + jobs.complete(self.target.table_name, self._job_key(key)) + return False + + logger.debug(f"Making {key} -> {self.target.full_table_name}") + self.__class__._allow_insert = True + try: + make(dict(key), **(make_kwargs or {})) + except (KeyboardInterrupt, SystemExit, Exception) as error: + try: self.connection.cancel_transaction() - if jobs is not None: - jobs.complete(self.target.table_name, self._job_key(key)) + except LostConnectionError: + pass + error_message = "{exception}{msg}".format( + exception=error.__class__.__name__, + msg=": " + str(error) if str(error) else "", + ) + logger.debug( + f"Error making {key} -> {self.target.full_table_name} - {error_message}" + ) + if jobs is not None: + # show error name and error message (if any) + jobs.error( + self.target.table_name, + self._job_key(key), + error_message=error_message, + error_stack=traceback.format_exc(), + ) + if not suppress_errors or isinstance(error, SystemExit): + raise else: - logger.debug(f"Making {key} -> {self.target.full_table_name}") - self.__class__._allow_insert = True - try: - make(dict(key), **(make_kwargs or {})) - except (KeyboardInterrupt, SystemExit, Exception) as error: - try: - self.connection.cancel_transaction() - except LostConnectionError: - pass - error_message = "{exception}{msg}".format( - exception=error.__class__.__name__, - msg=": " + str(error) if str(error) else "", - ) - logger.debug( - f"Error making {key} -> {self.target.full_table_name} - {error_message}" - ) - if jobs is not None: - # show error name and error message (if any) - jobs.error( - self.target.table_name, - self._job_key(key), - error_message=error_message, - error_stack=traceback.format_exc(), - ) - if not suppress_errors or isinstance(error, SystemExit): - raise - else: - logger.error(error) - return key, error if return_exception_objects else error_message - else: - self.connection.commit_transaction() - logger.debug( - f"Success making {key} -> {self.target.full_table_name}" - ) - if jobs is not None: - jobs.complete(self.target.table_name, self._job_key(key)) - finally: - self.__class__._allow_insert = False + logger.error(error) + return key, error if return_exception_objects else error_message + else: + self.connection.commit_transaction() + logger.debug(f"Success making {key} -> {self.target.full_table_name}") + if jobs is not None: + jobs.complete(self.target.table_name, self._job_key(key)) + return True + finally: + self.__class__._allow_insert = False def progress(self, *restrictions, display=False): """ diff --git a/tests_old/test_autopopulate.py b/tests_old/test_autopopulate.py index bc0c9bb18..7a0a58e39 100644 --- a/tests_old/test_autopopulate.py +++ b/tests_old/test_autopopulate.py @@ -53,6 +53,23 @@ def test_populate(self): assert_true(self.ephys) assert_true(self.channel) + def test_populate_with_success_count(self): + # test simple populate + assert_true(self.subject, "root tables are empty") + assert_false(self.experiment, "table already filled?") + ret = self.experiment.populate() + success_count = ret["success_count"] + assert_equal(len(self.experiment.key_source & self.experiment), success_count) + + # test restricted populate + assert_false(self.trial, "table already filled?") + restriction = self.subject.proj(animal="subject_id").fetch("KEY")[0] + d = self.trial.connection.dependencies + d.load() + ret = self.trial.populate(restriction, suppress_errors=True) + success_count = ret["success_count"] + assert_equal(len(self.trial.key_source & self.trial), success_count) + def test_populate_exclude_error_and_ignore_jobs(self): # test simple populate assert_true(self.subject, "root tables are empty")