From 9a57ec8d205e5e9561dcad4c55c2284c95d16854 Mon Sep 17 00:00:00 2001
From: James Mathews
Date: Fri, 28 Jul 2023 16:48:52 -0400
Subject: [PATCH] Normalize database config file vs. cursor handling (#184)

* Start normalizing database access.
* Refactor FeatureMatrixExtractor and all its internal dependencies to use
  cursors, and remove config-file plumbing.
* Update tests for the database config file vs. cursor changes.
* Correct None handling in a keyword argument.
* Ensure the cursor remains available for repeated use of a
  FeatureMatrixExtractor object.
* Update syntax for builtin type hints.
* Change database-source selection behavior.
* Add a database-source state variable.
* Separate the logging step in the extractor.
* Fix additional type hints.
* Fix another type hint.
* Fix None typing for the cursor.
* Add a type alias for stratification.
* Update an occurrence of Stratification.
* Dispatch on the database source with a match statement.
---
 spatialprofilingtoolbox/cggnn/scripts/run.py  |   9 +-
 .../db/expressions_table_indexer.py           |  10 +-
 .../db/feature_matrix_extractor.py            | 186 ++++++++++-----
 .../db/scripts/retrieve_feature_matrices.py   |   3 +-
 .../db/stratification_puller.py               | 117 +++++-----
 .../scripts/cache_expressions_data_array.py   |  16 +-
 .../workflow/common/sparse_matrix_puller.py   | 214 ++++++++++--------
 .../common/structure_centroids_puller.py      |  56 ++---
 .../workflow/phenotype_proximity/core.py      |   4 +-
 .../workflow/phenotype_proximity/ondemand.py  |   4 +-
 .../unit_tests/test_centroid_pulling.py       |  10 +-
 .../test_feature_matrix_extraction.py         |   9 +-
 .../unit_tests/test_stratification_pulling.py |  14 +-
 13 files changed, 379 insertions(+), 273 deletions(-)

diff --git a/spatialprofilingtoolbox/cggnn/scripts/run.py b/spatialprofilingtoolbox/cggnn/scripts/run.py
index ff360f399..deb6387e9 100644
--- a/spatialprofilingtoolbox/cggnn/scripts/run.py
+++ b/spatialprofilingtoolbox/cggnn/scripts/run.py
@@ -1,6 +1,5 @@
 "Run through the entire SPT CG-GNN pipeline using a local db config."
 from argparse import ArgumentParser
-from typing import Dict, Tuple
 from os.path import join
 
 from pandas import DataFrame
@@ -121,7 +120,7 @@ def parse_arguments():
     return parser.parse_args()
 
 
-def _create_cell_df(cell_dfs: Dict[str, DataFrame], feature_names: Dict[str, str]) -> DataFrame:
+def _create_cell_df(cell_dfs: dict[str, DataFrame], feature_names: dict[str, str]) -> DataFrame:
     "Find chemical species, phenotypes, and locations and merge into a DataFrame."
     for specimen, df_specimen in cell_dfs.items():
@@ -140,7 +139,7 @@ def _create_label_df(df_assignments: DataFrame,
 
 
 def _create_label_df(df_assignments: DataFrame,
-                     df_strata: DataFrame) -> Tuple[DataFrame, Dict[int, str]]:
+                     df_strata: DataFrame) -> tuple[DataFrame, dict[int, str]]:
     """Get slide-level results."""
     df = merge(df_assignments, df_strata, on='stratum identifier', how='left')[
         ['specimen', 'subject diagnosed result']].rename(
@@ -164,8 +163,8 @@ def save_importances(_args):
 
 if __name__ == "__main__":
     args = parse_arguments()
-    study_data: Dict[str, Dict] = FeatureMatrixExtractor.extract(
-        args.spt_db_config_location)[args.study]
+    extractor = FeatureMatrixExtractor(database_config_file=args.spt_db_config_location)
+    study_data: dict[str, dict] = extractor.extract(study=args.study)
 
     df_cell = _create_cell_df(
         {slide: data['dataframe']
diff --git a/spatialprofilingtoolbox/db/expressions_table_indexer.py b/spatialprofilingtoolbox/db/expressions_table_indexer.py
index 104e27a8f..43aaf4dd9 100644
--- a/spatialprofilingtoolbox/db/expressions_table_indexer.py
+++ b/spatialprofilingtoolbox/db/expressions_table_indexer.py
@@ -14,11 +14,15 @@ def ensure_indexed_expressions_table(connection):
                 ExpressionsTableIndexer.create_index(cursor)
             connection.commit()
 
+    @staticmethod
+    def expressions_table_is_indexed_cursor(cursor):
+        columns = ExpressionsTableIndexer.get_expression_quantification_columns(cursor)
+        return 'source_specimen' in columns
+
     @staticmethod
     def expressions_table_is_indexed(connection):
         with connection.cursor() as cursor:
-            columns = ExpressionsTableIndexer.get_expression_quantification_columns(cursor)
-            return 'source_specimen' in columns
+            return ExpressionsTableIndexer.expressions_table_is_indexed_cursor(cursor)
 
     @staticmethod
     def get_expression_quantification_columns(cursor):
@@ -31,7 +35,7 @@ def get_expression_quantification_columns(cursor):
 
     @staticmethod
     def create_index(cursor):
-        ETI = ExpressionsTableIndexer()
+        ETI = ExpressionsTableIndexer()  # pylint: disable=invalid-name
         ExpressionsTableIndexer.log_current_indexes(cursor)
         logger.debug('Will create extra index column "source_specimen".')
         ETI.create_extra_column(cursor)
diff --git a/spatialprofilingtoolbox/db/feature_matrix_extractor.py b/spatialprofilingtoolbox/db/feature_matrix_extractor.py
index 64c34c620..a80d9316a 100644
--- a/spatialprofilingtoolbox/db/feature_matrix_extractor.py
+++ b/spatialprofilingtoolbox/db/feature_matrix_extractor.py
@@ -1,10 +1,13 @@
 """
-Convenience provision of a feature matrix for each study, the data retrieved
-from the SPT database.
+Convenience provision of a feature matrix for each study, the data retrieved from the SPT database.
 """
 import sys
+from enum import Enum
+from enum import auto
+from typing import cast
 
 import pandas as pd
+from psycopg2.extensions import cursor as Psycopg2Cursor
 
 from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker
 from spatialprofilingtoolbox.db.stratification_puller import StratificationPuller
@@ -16,26 +19,89 @@
 logger = colorized_logger(__name__)
 
 
+class DBSource(Enum):
+    """Indicator of intended database source."""
+    CURSOR = auto()
+    CONFIG_FILE = auto()
+    UNKNOWN = auto()
+
+
 class FeatureMatrixExtractor:
     """
-    Pull from the database and create convenience bundle of feature matrices
-    and metadata.
+    Pull from the database and create convenience bundle of feature matrices and metadata.
""" - @staticmethod - def extract(database_config_file, specimen: str=None, study: str=None, continuous_also=False): - E = FeatureMatrixExtractor - data_arrays = E.retrieve_expressions_from_database(database_config_file, - specimen=specimen, - study=study, - continuous_also=continuous_also) - centroid_coordinates = E.retrieve_structure_centroids_from_database(database_config_file, - specimen=specimen, - study=study) - stratification = E.retrieve_derivative_stratification_from_database(database_config_file) - study_component_lookup = E.retrieve_study_component_lookup(database_config_file) - merged = E.merge_dictionaries( - E.create_feature_matrices(data_arrays, centroid_coordinates), - E.create_channel_information(data_arrays), + + cursor: Psycopg2Cursor + database_config_file: str | None + db_source: DBSource + + def __init__(self, + cursor: Psycopg2Cursor | None=None, + database_config_file: str | None=None, + ): + self.cursor = cast(Psycopg2Cursor, cursor) + self.database_config_file = database_config_file + if cursor is not None: + self.db_source = DBSource.CURSOR + elif database_config_file is not None: + self.db_source = DBSource.CONFIG_FILE + else: + self.db_source = DBSource.UNKNOWN + self._report_on_arguments() + + def _report_on_arguments(self): + if self.cursor is None and self.database_config_file is None: + logger.error('Must supply either cursor or database_config_file.') + if self.cursor is not None and self.database_config_file is not None: + message = 'A cursor and database configuration file were both specified. Using the '\ + 'cursor.' + logger.warning(message) + + def extract(self, + specimen: str | None=None, + study: str | None=None, + continuous_also: bool=False, + ): + extraction = None + match self.db_source: + case DBSource.CURSOR: + extraction = self._extract( + specimen=specimen, + study=study, + continuous_also=continuous_also, + ) + case DBSource.CONFIG_FILE: + with DatabaseConnectionMaker(self.database_config_file) as dcm: + with dcm.get_connection().cursor() as cursor: + self.cursor = cursor + extraction = self._extract( + specimen=specimen, + study=study, + continuous_also=continuous_also, + ) + case DBSource.UNKNOWN: + logger.error('The database source can not be determined.') + return extraction + + def _extract(self, + specimen: str | None=None, + study: str | None=None, + continuous_also: bool=False, + ): + data_arrays = self._retrieve_expressions_from_database( + specimen=specimen, + study=study, + continuous_also=continuous_also, + ) + centroid_coordinates = self._retrieve_structure_centroids_from_database( + specimen=specimen, + study=study, + ) + stratification = self._retrieve_derivative_stratification_from_database() + study_component_lookup = self._retrieve_study_component_lookup() + merged = self._merge_dictionaries( + self._create_feature_matrices(data_arrays, centroid_coordinates), + self._create_channel_information(data_arrays), stratification, new_keys=['feature matrices','channel symbols by column name', 'sample cohorts'], study_component_lookup=study_component_lookup, @@ -57,52 +123,47 @@ def redact_dataframes(extraction): extraction[study_name]['sample cohorts']['assignments'] = None extraction[study_name]['sample cohorts']['strata'] = None - @staticmethod - def retrieve_expressions_from_database(database_config_file, specimen: str=None, - study: str=None, continuous_also=False): + def _retrieve_expressions_from_database(self, + specimen: str | None=None, + study: str | None=None, + continuous_also: bool=False, + ): logger.info('Retrieving 
-        with SparseMatrixPuller(database_config_file) as puller:
-            puller.pull(specimen=specimen, study=study, continuous_also=continuous_also)
-            data_arrays = puller.get_data_arrays()
+        puller = SparseMatrixPuller(self.cursor)
+        puller.pull(specimen=specimen, study=study, continuous_also=continuous_also)
+        data_arrays = puller.get_data_arrays()
         logger.info('Done retrieving expression data from database.')
         return data_arrays.get_studies()
 
-    @staticmethod
-    def retrieve_structure_centroids_from_database(database_config_file, specimen: str=None,
-                                                   study: str=None):
+    def _retrieve_structure_centroids_from_database(self,
+        specimen: str | None=None,
+        study: str | None=None,
+    ):
         logger.info('Retrieving polygon centroids from shapefiles in database.')
-        with StructureCentroidsPuller(database_config_file) as puller:
-            puller.pull(specimen=specimen, study=study)
-            structure_centroids = puller.get_structure_centroids()
+        puller = StructureCentroidsPuller(self.cursor)
+        puller.pull(specimen=specimen, study=study)
+        structure_centroids = puller.get_structure_centroids()
         logger.info('Done retrieving centroids.')
         return structure_centroids.get_studies()
 
-    @staticmethod
-    def retrieve_derivative_stratification_from_database(database_config_file):
+    def _retrieve_derivative_stratification_from_database(self):
        logger.info('Retrieving stratification from database.')
-        with StratificationPuller(database_config_file=database_config_file) as puller:
-            puller.pull()
-            stratification = puller.get_stratification()
+        puller = StratificationPuller(self.cursor)
+        puller.pull()
+        stratification = puller.get_stratification()
         logger.info('Done retrieving stratification.')
         return stratification
 
-    @staticmethod
-    def retrieve_study_component_lookup(database_config_file):
-        with DatabaseConnectionMaker(database_config_file=database_config_file) as maker:
-            connection = maker.get_connection()
-            cursor = connection.cursor()
-            cursor.execute('SELECT * FROM study_component ; ')
-            rows = cursor.fetchall()
-            cursor.close()
+    def _retrieve_study_component_lookup(self):
+        self.cursor.execute('SELECT * FROM study_component ; ')
+        rows = self.cursor.fetchall()
         lookup = {}
         for row in rows:
             lookup[row[1]] = row[0]
         return lookup
 
-    @staticmethod
-    def create_feature_matrices(data_arrays, centroid_coordinates):
-        logger.info(
-            'Creating feature matrices from binary data arrays and centroids.')
+    def _create_feature_matrices(self, data_arrays, centroid_coordinates):
+        logger.info('Creating feature matrices from binary data arrays and centroids.')
         matrices = {}
         for k, study_name in enumerate(sorted(list(data_arrays.keys()))):
             study = data_arrays[study_name]
@@ -112,7 +173,7 @@ def create_feature_matrices(data_arrays, centroid_coordinates):
                 expressions = study['data arrays by specimen'][specimen]
                 number_channels = len(study['target index lookup'])
                 rows = [
-                    FeatureMatrixExtractor.create_feature_matrix_row(
+                    self._create_feature_matrix_row(
                         centroid_coordinates[study_name][specimen][i],
                         expressions[i],
                         number_channels,
@@ -144,40 +205,41 @@ def create_feature_matrices(data_arrays, centroid_coordinates):
         return matrices
 
     @staticmethod
-    def create_feature_matrix_row(centroid, binary, number_channels):
+    def _create_feature_matrix_row(centroid, binary, number_channels):
         template = '{0:0%sb}' % number_channels  # pylint: disable=consider-using-f-string
         feature_vector = [int(value) for value in list(template.format(binary)[::-1])]
         return [centroid[0], centroid[1]] + feature_vector
 
-    @staticmethod
-    def create_channel_information(data_arrays):
+    def _create_channel_information(self, data_arrays):
         return {
-            study_name: FeatureMatrixExtractor.create_channel_information_for_study(study)
+            study_name: self._create_channel_information_for_study(study)
             for study_name, study in data_arrays.items()
         }
 
-    @staticmethod
-    def create_channel_information_for_study(study):
+    def _create_channel_information_for_study(self, study):
         logger.info('Aggregating channel information for one study.')
-        targets = {int(index): target for target,
-                   index in study['target index lookup'].items()}
-        symbols = {target: symbol for symbol,
-                   target in study['target by symbol'].items()}
+        targets = {
+            int(index): target
+            for target, index in study['target index lookup'].items()
+        }
+        symbols = {
+            target: symbol
+            for symbol, target in study['target by symbol'].items()
+        }
         logger.info('Done aggregating channel information.')
         return {
             f'F{i}': symbols[targets[i]]
             for i in sorted([int(index) for index in targets.keys()])
         }
 
-    @staticmethod
-    def merge_dictionaries(*args, new_keys: list, study_component_lookup: dict):
+    def _merge_dictionaries(self, *args, new_keys: list, study_component_lookup: dict):
         if not len(args) == len(new_keys):
             logger.error(
                 "Can not match up dictionaries to be merged with the list of key names to be "
                 "issued for them.")
             sys.exit(1)
-        merged = {}
+        merged: dict = {}
         for i in range(len(new_keys)):
             for substudy, value in args[i].items():
                 merged[study_component_lookup[substudy]] = {}
diff --git a/spatialprofilingtoolbox/db/scripts/retrieve_feature_matrices.py b/spatialprofilingtoolbox/db/scripts/retrieve_feature_matrices.py
index 83cae39a3..cc1d4ee9f 100644
--- a/spatialprofilingtoolbox/db/scripts/retrieve_feature_matrices.py
+++ b/spatialprofilingtoolbox/db/scripts/retrieve_feature_matrices.py
@@ -41,7 +41,8 @@
     except ModuleNotFoundError as e:
         SuggestExtrasException(e, 'db')
 
-    bundle = FeatureMatrixExtractor.extract(database_config_file)
+    extractor = FeatureMatrixExtractor(database_config_file=database_config_file)
+    bundle: dict = extractor.extract()
 
     for study_name, study in bundle.items():
         for specimen, specimen_data in study['feature matrices']:
diff --git a/spatialprofilingtoolbox/db/stratification_puller.py b/spatialprofilingtoolbox/db/stratification_puller.py
index ecc5830e7..17a8ed638 100644
--- a/spatialprofilingtoolbox/db/stratification_puller.py
+++ b/spatialprofilingtoolbox/db/stratification_puller.py
@@ -1,67 +1,80 @@
 """Retrieve outcome data for all studies."""
-import pandas as pd
+from typing import cast
+
+from pandas import DataFrame
+from psycopg2.extensions import cursor as Psycopg2Cursor
 
-from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker
 from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger
 
 logger = colorized_logger(__name__)
 
+Stratification = dict[str, dict[str, DataFrame]]
 
-class StratificationPuller(DatabaseConnectionMaker):
+
+class StratificationPuller:
     """Retrieve sample cohort data for all studies."""
-    def __init__(self, database_config_file):
-        super().__init__(database_config_file=database_config_file)
+
+    cursor: Psycopg2Cursor
+    stratification: dict | None
+
+    def __init__(self, cursor: Psycopg2Cursor):
+        self.cursor = cursor
         self.stratification = None
 
-    def pull(self):
-        self.stratification = self.retrieve_stratification()
+    def pull(self) -> None:
+        self.stratification = self._retrieve_stratification()
 
-    def get_stratification(self):
-        return self.stratification
+    def get_stratification(self) -> Stratification:
+        return cast(Stratification, self.stratification)
 
-    def retrieve_stratification(self):
-        study_names = self.get_study_names()
-        stratification = {}
-        with self.get_connection().cursor() as cursor:
-            for study_name in study_names:
-                cursor.execute('''
-                SELECT
-                    scp.study,
-                    sample,
-                    stratum_identifier,
-                    local_temporal_position_indicator,
-                    subject_diagnosed_condition,
-                    subject_diagnosed_result
-                FROM
-                    sample_strata
-                JOIN
-                    specimen_collection_process scp ON sample=scp.specimen
-                JOIN
-                    study_component sc ON sc.component_study=scp.study
-                WHERE
-                    sc.primary_study=%s
-                ;
-                ''', (study_name,))
-                rows = cursor.fetchall()
-                if len(rows) == 0:
-                    continue
-                df = pd.DataFrame(rows, columns=['specimen collection study',
-                                                 'specimen',
-                                                 'stratum identifier',
-                                                 'local temporal position indicator',
-                                                 'subject diagnosed condition', 'subject diagnosed result'])
-                substudy_name = list(df['specimen collection study'])[0]
-                stratification[substudy_name] = {}
-                assignments_columns = ['specimen', 'stratum identifier']
-                stratification[substudy_name]['assignments'] = df[assignments_columns]
-                metadata_columns = ['stratum identifier', 'local temporal position indicator',
-                                    'subject diagnosed condition', 'subject diagnosed result']
-                stratification[substudy_name]['strata'] = df[metadata_columns].drop_duplicates()
+    def _retrieve_stratification(self) -> Stratification:
+        study_names = self._get_study_names()
+        stratification: Stratification = {}
+        for study_name in study_names:
+            self.cursor.execute('''
+            SELECT
+                scp.study,
+                sample,
+                stratum_identifier,
+                local_temporal_position_indicator,
+                subject_diagnosed_condition,
+                subject_diagnosed_result
+            FROM
+                sample_strata
+            JOIN
+                specimen_collection_process scp ON sample=scp.specimen
+            JOIN
+                study_component sc ON sc.component_study=scp.study
+            WHERE
+                sc.primary_study=%s
+            ;
+            ''', (study_name,))
+            rows = self.cursor.fetchall()
+            if len(rows) == 0:
+                continue
+            columns = [
+                'specimen collection study',
+                'specimen',
+                'stratum identifier',
+                'local temporal position indicator',
+                'subject diagnosed condition',
+                'subject diagnosed result',
+            ]
+            df = DataFrame(rows, columns=columns)
+            substudy_name = list(df['specimen collection study'])[0]
+            stratification[substudy_name] = {}
+            assignments_columns = ['specimen', 'stratum identifier']
+            stratification[substudy_name]['assignments'] = df[assignments_columns]
+            metadata_columns = [
+                'stratum identifier',
+                'local temporal position indicator',
+                'subject diagnosed condition',
+                'subject diagnosed result',
+            ]
+            stratification[substudy_name]['strata'] = df[metadata_columns].drop_duplicates()
         return stratification
 
-    def get_study_names(self):
-        with self.get_connection().cursor() as cursor:
-            cursor.execute('SELECT study_specifier FROM study ;')
-            rows = cursor.fetchall()
+    def _get_study_names(self) -> tuple[str, ...]:
+        self.cursor.execute('SELECT study_specifier FROM study ;')
+        rows = self.cursor.fetchall()
         study_names = [row[0] for row in rows]
-        return sorted(study_names)
+        return tuple(sorted(study_names))
diff --git a/spatialprofilingtoolbox/ondemand/scripts/cache_expressions_data_array.py b/spatialprofilingtoolbox/ondemand/scripts/cache_expressions_data_array.py
index 18497f46f..86c0cb827 100644
--- a/spatialprofilingtoolbox/ondemand/scripts/cache_expressions_data_array.py
+++ b/spatialprofilingtoolbox/ondemand/scripts/cache_expressions_data_array.py
@@ -35,9 +35,11 @@ def main():
         database_config_file = abspath(expanduser(database_config_file))
 
     if not StructureCentroids.already_exists(getcwd()):
-        with StructureCentroidsPuller(database_config_file) as puller:
-            puller.pull()
-            puller.get_structure_centroids().write_to_file(getcwd())
+        with DatabaseConnectionMaker(database_config_file) as dcm:
+            with dcm.get_connection().cursor() as cursor:
+                puller = StructureCentroidsPuller(cursor)
+                puller.pull()
+                puller.get_structure_centroids().write_to_file(getcwd())
     else:
         logger.info('%s already exists, skipping shapefile pull.', CENTROIDS_FILENAME)
@@ -54,10 +56,10 @@ def main():
     with DatabaseConnectionMaker(database_config_file) as dcm:
         connection = dcm.get_connection()
         ExpressionsTableIndexer.ensure_indexed_expressions_table(connection)
-
-    with SparseMatrixPuller(database_config_file) as puller:
-        puller.pull()
-        data_arrays = puller.get_data_arrays()
+        with connection.cursor() as cursor:
+            puller = SparseMatrixPuller(cursor)
+            puller.pull()
+            data_arrays = puller.get_data_arrays()
 
     writer = CompressedMatrixWriter()
     writer.write(data_arrays)
diff --git a/spatialprofilingtoolbox/workflow/common/sparse_matrix_puller.py b/spatialprofilingtoolbox/workflow/common/sparse_matrix_puller.py
index d3e320739..118727d38 100644
--- a/spatialprofilingtoolbox/workflow/common/sparse_matrix_puller.py
+++ b/spatialprofilingtoolbox/workflow/common/sparse_matrix_puller.py
@@ -1,9 +1,13 @@
 """
-Retrieve the "feature matrix" for a given study from the database, and store
-it in a special (in-memory) binary compressed format.
+Retrieve the "feature matrix" for a given study from the database, and store it in a special
+(in-memory) binary compressed format.
 """
+
+from typing import cast
+
+from psycopg2.extensions import cursor as Psycopg2Cursor
+
 from spatialprofilingtoolbox.db.expressions_table_indexer import ExpressionsTableIndexer
-from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker
 from spatialprofilingtoolbox.workflow.common.logging.fractional_progress_reporter \
     import FractionalProgressReporter
 from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger
@@ -40,13 +44,13 @@ def get_studies(self):
         return self.studies
 
     def add_study_data(
-            self,
-            study_name,
-            data_arrays_by_specimen,
-            target_index_lookup,
-            target_by_symbol,
-            continuous_data_arrays_by_specimen=None,
-            ):
+        self,
+        study_name,
+        data_arrays_by_specimen,
+        target_index_lookup,
+        target_by_symbol,
+        continuous_data_arrays_by_specimen=None,
+    ):
         self.check_target_index_lookup(study_name, target_index_lookup)
         self.check_target_by_symbol(study_name, target_by_symbol)
         if not study_name in self.studies:
@@ -66,11 +70,11 @@ def add_study_data(
         )
 
     def add_more_data_arrays(
-            self,
-            study_name,
-            data_arrays_by_specimen,
-            continuous_data_arrays_by_specimen=None,
-            ):
+        self,
+        study_name,
+        data_arrays_by_specimen,
+        continuous_data_arrays_by_specimen=None,
+    ):
         for key, integers_list in data_arrays_by_specimen.items():
             self.studies[study_name]['data arrays by specimen'][key] = integers_list
         if continuous_data_arrays_by_specimen is not None:
@@ -96,41 +100,49 @@ def check_dicts_equal(dict1, dict2):
             raise ValueError(f'Dictionary values not equal: {value}, {dict2[key]}')
 
 
-class SparseMatrixPuller(DatabaseConnectionMaker):
+class SparseMatrixPuller:
     """"Get sparse matrix representation of cell x channel data in database."""
+
+    cursor: Psycopg2Cursor
     data_arrays: CompressedDataArrays
 
-    def __init__(self, database_config_file):
-        super().__init__(database_config_file=database_config_file)
+    def __init__(self, cursor: Psycopg2Cursor):
+        self.cursor = cursor
 
-    def pull(self, specimen: str=None, study: str=None, continuous_also=False):
-        self.data_arrays = self.retrieve_data_arrays(
-            specimen=specimen, study=study, continuous_also=continuous_also)
+    def pull(self, specimen: str | None=None, study: str | None=None, continuous_also: bool=False):
+        self.data_arrays = self._retrieve_data_arrays(
+            specimen=specimen,
+            study=study,
+            continuous_also=continuous_also,
+        )
 
     def get_data_arrays(self):
         return self.data_arrays
 
-    def retrieve_data_arrays(self,
-                             specimen: str=None,
-                             study: str=None,
-                             continuous_also=False,
-                             ) -> CompressedDataArrays:
-        study_names = self.get_study_names(self.get_connection(), study=study)
+    def _retrieve_data_arrays(self,
+        specimen: str | None=None,
+        study: str | None=None,
+        continuous_also: bool=False,
+    ) -> CompressedDataArrays:
+        study_names = self._get_study_names(study=study)
         data_arrays = CompressedDataArrays()
         for study_name in study_names:
-            self.fill_data_arrays_for_study(
-                data_arrays, study_name, specimen=specimen, continuous_also=continuous_also)
+            self._fill_data_arrays_for_study(
+                data_arrays,
+                study_name,
+                specimen=specimen,
+                continuous_also=continuous_also,
+            )
         return data_arrays
 
-    def fill_data_arrays_for_study(
-            self,
-            data_arrays,
-            study_name,
-            specimen: str=None,
-            continuous_also=False,
-            ):
-        specimens = self.get_pertinent_specimens(study_name, specimen=specimen)
-        target_by_symbol = self.get_target_by_symbol(study_name, self.get_connection())
+    def _fill_data_arrays_for_study(self,
+        data_arrays: CompressedDataArrays,
+        study_name: str,
+        specimen: str | None=None,
+        continuous_also: bool=False,
+    ):
+        specimens = self._get_pertinent_specimens(study_name, specimen=specimen)
+        target_by_symbol = self._get_target_by_symbol(study_name)
         logger.debug('Pulling sparse entries for study "%s".', study_name)
         progress_reporter = FractionalProgressReporter(
             len(specimens),
@@ -138,10 +150,9 @@ def fill_data_arrays_for_study(
             task_and_done_message=('pulling sparse entries from the study', None),
             logger=logger,
         )
-        parse = self.parse_data_arrays_by_specimen
+        parse = self._parse_data_arrays_by_specimen
         for _specimen in specimens:
-            sparse_entries = self.get_sparse_entries(
-                self.get_connection(),
+            sparse_entries = self._get_sparse_entries(
                 study_name,
                 specimen=_specimen,
             )
@@ -161,65 +172,64 @@ def fill_data_arrays_for_study(
             progress_reporter.increment(iteration_details=_specimen)
         progress_reporter.done()
 
-    def get_pertinent_specimens(self, study_name, specimen: str=None):
+    def _get_pertinent_specimens(self,
+        study_name: str,
+        specimen: str | None=None,
+    ) -> tuple[str, ...]:
         if specimen is not None:
-            return [specimen]
-        with self.get_connection().cursor() as cursor:
-            cursor.execute('''
-            SELECT sdmp.specimen
-            FROM specimen_data_measurement_process sdmp
-            WHERE sdmp.study=%s
-            ORDER BY sdmp.specimen
-            ;
-            ''', (study_name,))
-            rows = cursor.fetchall()
-        return [row[0] for row in rows]
+            return (specimen,)
+        self.cursor.execute('''
+        SELECT sdmp.specimen
+        FROM specimen_data_measurement_process sdmp
+        WHERE sdmp.study=%s
+        ORDER BY sdmp.specimen
+        ;
+        ''', (study_name,))
+        rows = self.cursor.fetchall()
+        return tuple(cast(str, row[0]) for row in rows)
 
-    def get_study_names(self, connection, study=None):
+    def _get_study_names(self, study: str | None=None) -> tuple[str, ...]:
         if study is None:
-            with connection.cursor() as cursor:
-                cursor.execute('SELECT name FROM specimen_measurement_study ;')
-                rows = cursor.fetchall()
+            self.cursor.execute('SELECT name FROM specimen_measurement_study ;')
+            rows = self.cursor.fetchall()
         else:
-            with connection.cursor() as cursor:
-                cursor.execute('''
-                SELECT sms.name FROM specimen_measurement_study sms
-                JOIN study_component sc ON sc.component_study=sms.name
-                WHERE sc.primary_study=%s
-                ;
-                ''', (study,))
-                rows = cursor.fetchall()
+            self.cursor.execute('''
+            SELECT sms.name FROM specimen_measurement_study sms
+            JOIN study_component sc ON sc.component_study=sms.name
+            WHERE sc.primary_study=%s
+            ;
+            ''', (study,))
+            rows = self.cursor.fetchall()
         logger.info('Will pull feature matrices for studies:')
-        names = sorted([row[0] for row in rows])
+        names = tuple(sorted([row[0] for row in rows]))
         for name in names:
             logger.info(' %s', name)
         return names
 
-    def get_sparse_entries(self, connection, study_name, specimen):
-        sparse_entries = []
+    def _get_sparse_entries(self, study_name: str, specimen: str) -> list[tuple]:
+        sparse_entries: list[tuple] = []
         number_log_messages = 0
-        with connection.cursor() as cursor:
-            cursor.execute(
-                self.get_sparse_matrix_query_specimen_specific(),
-                (study_name, specimen),
-            )
-            total = cursor.rowcount
-            while cursor.rownumber < total - 1:
-                current_number_stored = len(sparse_entries)
-                sparse_entries.extend(cursor.fetchmany(size=self.get_batch_size()))
-                logger.debug('Received %s entries from DB.',
-                             len(sparse_entries) - current_number_stored)
-                number_log_messages = number_log_messages + 1
+        self.cursor.execute(
+            self._get_sparse_matrix_query_specimen_specific(),
+            (study_name, specimen),
+        )
+        total = self.cursor.rowcount
+        while self.cursor.rownumber < total - 1:
+            current_number_stored = len(sparse_entries)
+            sparse_entries.extend(self.cursor.fetchmany(size=self._get_batch_size()))
+            received = len(sparse_entries) - current_number_stored
+            logger.debug('Received %s entries from DB.', received)
+            number_log_messages = number_log_messages + 1
         if number_log_messages > 1:
             logger.debug('Received %s sparse entries total from DB.', len(sparse_entries))
         return sparse_entries
 
-    def get_sparse_matrix_query_specimen_specific(self):
-        if ExpressionsTableIndexer.expressions_table_is_indexed(self.get_connection()):
+    def _get_sparse_matrix_query_specimen_specific(self) -> str:
+        if ExpressionsTableIndexer.expressions_table_is_indexed_cursor(self.cursor):
             return self.sparse_entries_query_optimized()
         return self.sparse_entries_query_unoptimized()
 
-    def sparse_entries_query_optimized(self):
+    def sparse_entries_query_optimized(self) -> str:
         return '''
         -- absorb/ignore first string formatting argument: %s
         SELECT
@@ -234,7 +244,7 @@ def sparse_entries_query_optimized(self):
         ;
         '''
 
-    def sparse_entries_query_unoptimized(self):
+    def sparse_entries_query_unoptimized(self) -> str:
         return '''
         SELECT
             eq.histological_structure,
@@ -252,11 +262,14 @@ def sparse_entries_query_unoptimized(self):
         ;
         '''
 
-    def get_batch_size(self):
+    def _get_batch_size(self) -> int:
         return 10000000
 
-    def parse_data_arrays_by_specimen(self, sparse_entries, continuous_also=False):
-        target_index_lookup = self.get_target_index_lookup(sparse_entries)
+    def _parse_data_arrays_by_specimen(self,
+        sparse_entries: list[tuple],
+        continuous_also: bool=False,
+    ):
+        target_index_lookup = self._get_target_index_lookup(sparse_entries)
         sparse_entries.sort(key=lambda x: (x[3], x[0]))
         data_arrays_by_specimen = {}
         continuous_data_arrays_by_specimen = {}
@@ -276,7 +289,7 @@ def parse_data_arrays_by_specimen(self, sparse_entries, continuous_also=False):
                     continuous_data_arrays_by_specimen[specimen] = zerovector
                 else:
                     continuous_data_arrays_by_specimen[specimen] = None
-                self.fill_data_array(
+                self._fill_data_array(
                     data_arrays_by_specimen[specimen],
                     buffer,
                     target_index_lookup,
@@ -290,18 +303,16 @@ def parse_data_arrays_by_specimen(self, sparse_entries, continuous_also=False):
                 cell_count = 1
         return data_arrays_by_specimen, target_index_lookup, continuous_data_arrays_by_specimen
 
-    def get_target_index_lookup(self, sparse_entries):
-        targets = set([])
-        for i, entry in enumerate(sparse_entries):
-            targets.add(entry[1])
-        targets = sorted(list(targets))
+    def _get_target_index_lookup(self, sparse_entries: list[tuple]) -> dict[str, int]:
+        target_set = set(entry[1] for entry in sparse_entries)
+        targets = sorted(list(target_set))
         lookup = {
             target: i
             for i, target in enumerate(targets)
         }
         return lookup
 
-    def get_target_by_symbol(self, study_name, connection):
+    def _get_target_by_symbol(self, study_name: str) -> dict[str, str]:
         query = '''
         SELECT cs.identifier, cs.symbol
         FROM chemical_species cs
@@ -309,18 +320,21 @@ def get_target_by_symbol(self, study_name, connection):
         WHERE bms.study=%s
         ;
         '''
-        with connection.cursor() as cursor:
-            cursor.execute(query, (study_name,))
-            rows = cursor.fetchall()
+        self.cursor.execute(query, (study_name,))
+        rows = self.cursor.fetchall()
         if len(rows) != len(set(row[1] for row in rows)):
-            logger.error(
-                'The symbols are not unique identifiers of the targets. The symbols are: %s',
-                [row[1] for row in rows])
+            message = 'The symbols are not unique identifiers of the targets. The symbols are: %s'
+            logger.error(message, [row[1] for row in rows])
         target_by_symbol = {row[1]: row[0] for row in rows}
         logger.debug('Target by symbol: %s', target_by_symbol)
         return target_by_symbol
 
-    def fill_data_array(self, data_array, entries, target_index_lookup, continuous_data_array=None):
+    def _fill_data_array(self,
+        data_array,
+        entries,
+        target_index_lookup: dict[str, int],
+        continuous_data_array=None,
+    ) -> None:
         structure_index = 0
         for i, entry in enumerate(entries):
             if i > 0:
diff --git a/spatialprofilingtoolbox/workflow/common/structure_centroids_puller.py b/spatialprofilingtoolbox/workflow/common/structure_centroids_puller.py
index 2c0451427..c02d80108 100644
--- a/spatialprofilingtoolbox/workflow/common/structure_centroids_puller.py
+++ b/spatialprofilingtoolbox/workflow/common/structure_centroids_puller.py
@@ -1,9 +1,10 @@
 """Retrieves positional information for all cells in the SPT database."""
 import statistics
 
+from psycopg2.extensions import cursor as Psycopg2Cursor
+
 from spatialprofilingtoolbox.db.shapefile_polygon import extract_points
 from spatialprofilingtoolbox.workflow.common.structure_centroids import StructureCentroids
-from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker
 from spatialprofilingtoolbox.workflow.common.logging.fractional_progress_reporter \
     import FractionalProgressReporter
 from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger
@@ -11,40 +12,44 @@
 logger = colorized_logger(__name__)
 
 
-class StructureCentroidsPuller(DatabaseConnectionMaker):
+class StructureCentroidsPuller:
     """Retrieve positional information for all cells in single cell database."""
-    def __init__(self, database_config_file: str | None=None):
-        super().__init__(database_config_file=database_config_file)
+
+    cursor: Psycopg2Cursor
+    structure_centroids: StructureCentroids
+
+    def __init__(self, cursor: Psycopg2Cursor):
+        self.cursor = cursor
         self.structure_centroids = StructureCentroids()
 
     def pull(self, specimen: str | None=None, study: str | None=None):
-        study_names = self.get_study_names(study=study)
-        cursor = self.get_connection().cursor()
+        study_names = self._get_study_names(study=study)
         for study_name in study_names:
             if specimen is None:
-                specimen_count = self.get_specimen_count(study_name, cursor)
-                cursor.execute(self.get_shapefiles_query(), (study_name,))
+                specimen_count = self._get_specimen_count(study_name, self.cursor)
+                self.cursor.execute(self._get_shapefiles_query(), (study_name,))
             else:
                 specimen_count = 1
-                cursor.execute(self.get_shapefiles_query_specimen_specific(),
-                               (study_name, specimen))
-            rows = cursor.fetchall()
+                self.cursor.execute(
+                    self._get_shapefiles_query_specimen_specific(),
+                    (study_name, specimen),
+                )
+            rows = self.cursor.fetchall()
             if len(rows) == 0:
                 continue
             self.structure_centroids.add_study_data(
                 study_name,
-                self.create_study_data(rows, specimen_count, study_name)
+                self._create_study_data(rows, specimen_count, study_name)
             )
-        cursor.close()
 
-    def get_specimen_count(self, study_name, cursor):
+    def _get_specimen_count(self, study_name, cursor):
         cursor.execute('''
         SELECT COUNT(*)
         FROM specimen_data_measurement_process sdmp
         WHERE sdmp.study=%s
         ;
         ''', (study_name,))
         return cursor.fetchall()[0][0]
 
-    def get_shapefiles_query(self):
+    def _get_shapefiles_query(self):
         return '''
         SELECT
             hsi.histological_structure,
@@ -61,7 +66,7 @@
         ;
         '''
 
-    def get_shapefiles_query_specimen_specific(self):
+    def _get_shapefiles_query_specimen_specific(self):
         return '''
         SELECT
             hsi.histological_structure,
@@ -78,22 +83,21 @@
         ;
         '''
 
-    def get_study_names(self, study=None):
-        cursor = self.get_connection().cursor()
+    def _get_study_names(self, study: str | None=None):
         if study is None:
-            cursor.execute('SELECT name FROM specimen_measurement_study ;')
-            rows = cursor.fetchall()
+            self.cursor.execute('SELECT name FROM specimen_measurement_study ;')
+            rows = self.cursor.fetchall()
         else:
-            cursor.execute('''
+            self.cursor.execute('''
             SELECT sms.name FROM specimen_measurement_study sms
             JOIN study_component sc ON sc.component_study=sms.name
            WHERE sc.primary_study=%s
            ;
            ''', (study,))
-            rows = cursor.fetchall()
+            rows = self.cursor.fetchall()
         return sorted([row[0] for row in rows])
 
-    def create_study_data(self, rows, specimen_count, study):
+    def _create_study_data(self, rows, specimen_count, study):
         study_data = {}
         field = {'structure': 0, 'specimen': 1, 'base64_contents': 2}
         current_specimen = rows[0][field['specimen']]
@@ -110,19 +114,19 @@
                 progress_reporter.increment(iteration_details=current_specimen)
                 current_specimen = row[field['specimen']]
                 specimen_centroids = []
-            specimen_centroids.append(self.compute_centroid(
+            specimen_centroids.append(self._compute_centroid(
                 extract_points(row[field['base64_contents']])
             ))
         progress_reporter.done()
         study_data[current_specimen] = specimen_centroids
         return study_data
 
-    def compute_centroid(self, points):
+    def _compute_centroid(self, points):
         nonrepeating_points = points[0:(len(points)-1)]
         return (
             statistics.mean([point[0] for point in nonrepeating_points]),
             statistics.mean([point[1] for point in nonrepeating_points]),
         )
 
-    def get_structure_centroids(self):
+    def get_structure_centroids(self) -> StructureCentroids:
         return self.structure_centroids
diff --git a/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py b/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py
index 907b683a5..c8f10eb19 100644
--- a/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py
+++ b/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py
@@ -76,8 +76,8 @@ def log_job_info(self):
 
     def calculate_proximity(self):
         self.timer.record_timepoint('Start pulling data for one sample.')
-        bundle = FeatureMatrixExtractor.extract(database_config_file=self.database_config_file,
-                                                specimen=self.sample_identifier)
+        extractor = FeatureMatrixExtractor(database_config_file=self.database_config_file)
+        bundle: dict = extractor.extract(specimen=self.sample_identifier)
         self.timer.record_timepoint('Finished pulling data for one sample.')
         study_name = list(bundle.keys())[0]
         _, sample = list(bundle[study_name]['feature matrices'].items())[0]
diff --git a/spatialprofilingtoolbox/workflow/phenotype_proximity/ondemand.py b/spatialprofilingtoolbox/workflow/phenotype_proximity/ondemand.py
index 8080e18d5..6134f2585 100644
--- a/spatialprofilingtoolbox/workflow/phenotype_proximity/ondemand.py
+++ b/spatialprofilingtoolbox/workflow/phenotype_proximity/ondemand.py
@@ -24,8 +24,8 @@ def __init__(self, study, database_config_file):
         logger.info(
             'Start pulling feature matrix data for proximity on-demand calculator, study %s.',
             study)
-        bundle = FeatureMatrixExtractor.extract(database_config_file=database_config_file,
-                                                study=study)
+        extractor = FeatureMatrixExtractor(database_config_file=database_config_file)
+        bundle: dict = extractor.extract(study=study)
         logger.info('Finished pulling data for %s.', study)
         for identifier, sample in list(bundle[study]['feature matrices'].items()):
diff --git a/test/workflow/unit_tests/test_centroid_pulling.py b/test/workflow/unit_tests/test_centroid_pulling.py
index 59f503161..380b825a3 100644
--- a/test/workflow/unit_tests/test_centroid_pulling.py
+++ b/test/workflow/unit_tests/test_centroid_pulling.py
@@ -1,12 +1,16 @@
+"""Test pulling out of centroids of each structure (cell) in the database."""
 import sys
 
+from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker
 from spatialprofilingtoolbox.workflow.common.structure_centroids_puller import \
     StructureCentroidsPuller
 
 if __name__ == '__main__':
-    with StructureCentroidsPuller(database_config_file='../db/.spt_db.config.container') as puller:
-        puller.pull()
-        structure_centroids = puller.get_structure_centroids()
+    with DatabaseConnectionMaker(database_config_file='../db/.spt_db.config.container') as dcm:
+        with dcm.get_connection().cursor() as cursor:
+            puller = StructureCentroidsPuller(cursor)
+            puller.pull()
+            structure_centroids = puller.get_structure_centroids()
 
     for study_name, study in structure_centroids.studies.items():
         if study.keys() != set(['lesion 0_1', 'lesion 6_1']):
diff --git a/test/workflow/unit_tests/test_feature_matrix_extraction.py b/test/workflow/unit_tests/test_feature_matrix_extraction.py
index 7f7221cbc..adea04f6b 100644
--- a/test/workflow/unit_tests/test_feature_matrix_extraction.py
+++ b/test/workflow/unit_tests/test_feature_matrix_extraction.py
@@ -137,7 +137,8 @@ def test_stratification(study):
 
 if __name__ == '__main__':
-    matrix_bundle = FeatureMatrixExtractor.extract('../db/.spt_db.config.container')
+    extractor = FeatureMatrixExtractor(database_config_file='../db/.spt_db.config.container')
+    matrix_bundle = extractor.extract()
     test_study = get_study(matrix_bundle)
     test_sample_set(test_study)
     test_feature_matrix_schemas(test_study)
@@ -146,16 +147,14 @@ def test_stratification(study):
     test_expression_vectors(test_study)
     test_stratification(test_study)
 
-    one_sample_bundle = FeatureMatrixExtractor.extract('../db/.spt_db.config.container',
-                                                       specimen='lesion 6_1')
+    one_sample_bundle = extractor.extract(specimen='lesion 6_1')
     one_sample_study = get_study(one_sample_bundle)
     test_one_sample_set(one_sample_study)
     test_feature_matrix_schemas(one_sample_study)
     test_channels(one_sample_study)
     test_expression_vectors(one_sample_study)
 
-    one_sample_bundle_continuous = FeatureMatrixExtractor.extract('../db/.spt_db.config.container',
-                                                                  specimen='lesion 6_1', continuous_also=True)
+    one_sample_bundle_continuous = extractor.extract(specimen='lesion 6_1', continuous_also=True)
     one_sample_study_continuous = get_study(one_sample_bundle_continuous)
     test_one_sample_set(one_sample_study_continuous)
     test_feature_matrix_schemas(one_sample_study_continuous)
diff --git a/test/workflow/unit_tests/test_stratification_pulling.py b/test/workflow/unit_tests/test_stratification_pulling.py
index f96155559..e4e948c99 100644
--- a/test/workflow/unit_tests/test_stratification_pulling.py
+++ b/test/workflow/unit_tests/test_stratification_pulling.py
@@ -1,15 +1,19 @@
+"""Test pulling out of stratification for cohorts from database."""
 import pandas as pd
 
+from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker
 from spatialprofilingtoolbox.db.stratification_puller import StratificationPuller
 
 if __name__ == '__main__':
-    with StratificationPuller(database_config_file='../db/.spt_db.config.container') as puller:
-        puller.pull()
-        stratification = puller.get_stratification()
+    with DatabaseConnectionMaker(database_config_file='../db/.spt_db.config.container') as dcm:
+        with dcm.get_connection().cursor() as cursor:
+            puller = StratificationPuller(cursor)
+            puller.pull()
+            stratification = puller.get_stratification()
 
-    expected_assignments = pd.read_csv('unit_tests/expected_stratification_assignments.tsv',
-                                       sep='\t', dtype=object)
+    filename = 'unit_tests/expected_stratification_assignments.tsv'
+    expected_assignments = pd.read_csv(filename, sep='\t', dtype=object)
    assignment_rows = set(tuple(list(row)) for _, row in expected_assignments.iterrows())
    expected_strata = pd.read_csv('unit_tests/expected_strata.tsv', sep='\t', dtype=object)
    stratum_rows = set(tuple(list(row)) for _, row in expected_strata.iterrows())
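
Reviewer note: below is a minimal usage sketch, not part of the patch, of the refactored FeatureMatrixExtractor showing the two construction modes that the new DBSource dispatch supports. The config-file path and study name are placeholders; the bundle layout follows what the patched callers (run.py, core.py, ondemand.py) consume.

```python
"""Sketch: the two ways to construct the refactored FeatureMatrixExtractor."""
from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker
from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor

# Config-file mode: the extractor opens a short-lived connection inside extract().
extractor = FeatureMatrixExtractor(database_config_file='.spt_db.config')  # placeholder path
bundle = extractor.extract(study='Example study')  # placeholder study name

# Cursor mode: the caller owns the connection, and the same cursor is reused
# across repeated extract() calls (per the commit's stated intent).
with DatabaseConnectionMaker(database_config_file='.spt_db.config') as dcm:
    with dcm.get_connection().cursor() as cursor:
        extractor = FeatureMatrixExtractor(cursor=cursor)
        bundle = extractor.extract(study='Example study')

# The bundle maps study name to feature matrices keyed by specimen.
for study_name, study in bundle.items():
    for specimen, specimen_data in study['feature matrices'].items():
        print(study_name, specimen, specimen_data['dataframe'].shape)
```

Passing both a cursor and a config file logs a warning and the cursor wins; passing neither logs an error and extract() returns None.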
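
The pullers follow the same pattern: each now takes a live psycopg2 cursor instead of managing its own connection. A sketch mirroring the updated unit tests, again with a placeholder config path and study name:

```python
"""Sketch: cursor-based puller usage, mirroring the updated unit tests."""
from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker
from spatialprofilingtoolbox.db.stratification_puller import StratificationPuller
from spatialprofilingtoolbox.workflow.common.sparse_matrix_puller import SparseMatrixPuller

with DatabaseConnectionMaker(database_config_file='.spt_db.config') as dcm:  # placeholder path
    with dcm.get_connection().cursor() as cursor:
        # Sample-cohort (stratification) data, keyed by substudy name.
        stratification_puller = StratificationPuller(cursor)
        stratification_puller.pull()
        stratification = stratification_puller.get_stratification()

        # Sparse cell-by-channel data restricted to one study.
        matrix_puller = SparseMatrixPuller(cursor)
        matrix_puller.pull(study='Example study')  # placeholder study name
        data_arrays = matrix_puller.get_data_arrays()

print(sorted(stratification.keys()))
print(sorted(data_arrays.get_studies().keys()))
```

This keeps connection lifetime entirely in the caller's hands, which is the point of the refactor: one connection and cursor can be shared across several pullers.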