diff --git a/build/build_scripts/check_dockerfiles_consistency.py b/build/build_scripts/check_dockerfiles_consistency.py index f1d92d4e5..e1571991d 100644 --- a/build/build_scripts/check_dockerfiles_consistency.py +++ b/build/build_scripts/check_dockerfiles_consistency.py @@ -10,7 +10,7 @@ def load_dockerfile(submodule): def check_exists(dependency, dockerfile): for line in dockerfile.split('\n'): - if re.search(f'RUN python3? -m pip install "?{dependency}"?$', line): + if re.search(f'RUN python3?(.11)? -m pip install "?{dependency}"?$', line): return True if re.search(dependency, line): print(f'Dependency "{dependency}" is mentioned in Dockerfile, but something isn\'t quite right with installation command.') diff --git a/build/build_scripts/create_pyproject.py b/build/build_scripts/create_pyproject.py index 7239fd712..e0a045ff4 100644 --- a/build/build_scripts/create_pyproject.py +++ b/build/build_scripts/create_pyproject.py @@ -2,7 +2,7 @@ import toml def validate_dependencies_all(project): - modules = ['apiserver', 'db', 'ondemand', 'workflow'] + modules = ['apiserver', 'cggnn', 'db', 'ondemand', 'workflow'] dependencies = set() for module in modules: dependencies = dependencies.union(set(project['project']['optional-dependencies'][module])) diff --git a/build/cggnn/Dockerfile b/build/cggnn/Dockerfile index 9a45858d6..739a7efb9 100644 --- a/build/cggnn/Dockerfile +++ b/build/cggnn/Dockerfile @@ -1,13 +1,21 @@ -FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-runtime +FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime ENV DEBIAN_FRONTEND=noninteractive -RUN apt update && apt install -y gcc libpq-dev && rm -rf /var/lib/apt/lists/* +RUN apt update && apt install -y gcc libpq-dev WORKDIR /usr/src/app -RUN python -m pip install dgl-cu116 dglgo -f https://data.dgl.ai/wheels/repo.html -RUN python -m pip install psycopg2==2.9.6 -RUN python -m pip install adiscstudies==0.11.0 -RUN python -m pip install numba==0.57.0 -RUN python -m pip install attrs==23.1.0 -RUN python -m pip install cg-gnn +RUN apt install software-properties-common -y +RUN add-apt-repository ppa:deadsnakes/ppa +RUN apt update +RUN apt install python3.11 -y +RUN apt install python3.11-dev -y +RUN apt install python3.11-venv -y +RUN apt install python3.11-distutils -y +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11 && python3.11 -m ensurepip +RUN python3.11 -m pip install dgl-cu117 dglgo -f https://data.dgl.ai/wheels/repo.html +RUN python3.11 -m pip install psycopg2==2.9.6 +RUN python3.11 -m pip install adiscstudies==0.11.0 +RUN python3.11 -m pip install numba==0.57.0 +RUN python3.11 -m pip install attrs==23.1.0 +RUN python3.11 -m pip install cg-gnn ARG version ARG service_name ARG WHEEL_FILENAME @@ -15,4 +23,4 @@ LABEL version=$version LABEL service_name=$service_name ENV service_name $service_name COPY $WHEEL_FILENAME ./ -RUN python -m pip install "$WHEEL_FILENAME" +RUN python3.11 -m pip install "$WHEEL_FILENAME" diff --git a/pyproject.toml.unversioned b/pyproject.toml.unversioned index 688409c6c..9dc413808 100644 --- a/pyproject.toml.unversioned +++ b/pyproject.toml.unversioned @@ -41,6 +41,9 @@ apiserver = [ "Pillow==9.5.0", "pydantic==2.0.2" ] +cggnn = [ + "cg-gnn" +] db = [ "pandas==2.0.2", "pyshp==2.2.0", @@ -55,9 +58,6 @@ ondemand = [ "pydantic==2.0.2", "squidpy==1.3.0" ] -cggnn = [ - "cg-gnn" -] workflow = [ "matplotlib==3.7.1", "umap-learn==0.5.3", @@ -71,6 +71,7 @@ workflow = [ "Pillow==9.5.0" ] all = [ + "cg-gnn", "matplotlib==3.7.1", "umap-learn==0.5.3", "uvicorn>=0.15.0,<0.16.0", diff --git 
a/spatialprofilingtoolbox/cggnn/__init__.py b/spatialprofilingtoolbox/cggnn/__init__.py index 4be301fc3..d3dc32f9e 100644 --- a/spatialprofilingtoolbox/cggnn/__init__.py +++ b/spatialprofilingtoolbox/cggnn/__init__.py @@ -1,2 +1,4 @@ """Cell-graph graph neural network functionality.""" __version__ = '0.2.1' + +from spatialprofilingtoolbox.cggnn.extract import extract_cggnn_data diff --git a/spatialprofilingtoolbox/cggnn/extract.py b/spatialprofilingtoolbox/cggnn/extract.py new file mode 100644 index 000000000..be76310b4 --- /dev/null +++ b/spatialprofilingtoolbox/cggnn/extract.py @@ -0,0 +1,112 @@ +"""Extract information cg-gnn needs from SPT.""" + +from pandas import DataFrame, concat, merge # type: ignore +from numpy import sort # type: ignore + +from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor + + +def _create_cell_df(dfs_by_specimen: dict[str, DataFrame]) -> DataFrame: + """Find simple and complex phenotypes, and locations and merge into a DataFrame.""" + for specimen, df_specimen in dfs_by_specimen.items(): + df_specimen['specimen'] = specimen + + df = concat(dfs_by_specimen.values(), axis=0) + df.index.name = 'histological_structure' + # Reorder columns so it's specimen, xy, channels, and phenotypes + column_order = ['specimen', 'pixel x', 'pixel y'] + column_order.extend(df.columns[df.columns.str.startswith('C ')]) + column_order.extend(df.columns[df.columns.str.startswith('P ')]) + return df[column_order] + + +def _create_label_df( + df_assignments: DataFrame, + df_strata: DataFrame, + strata_to_use: list[int] | None, +) -> tuple[DataFrame, dict[int, str]]: + """Get slide-level results.""" + df_assignments = df_assignments.set_index('specimen') + df_strata = df_strata.set_index('stratum identifier') + df_strata = _filter_for_strata(strata_to_use, df_strata) + df_strata = _drop_unneeded_columns(df_strata) + df_strata = _compress_df(df_strata) + return _label(df_assignments, df_strata) + + +def _filter_for_strata(strata_to_use: list[int] | None, df_strata: DataFrame) -> DataFrame: + if strata_to_use is not None: + df_strata = df_strata.loc[sorted(strata_to_use)] + if df_strata.shape[0] < 2: + raise ValueError(f'Need at least 2 strata to classify, there are {df_strata.shape[0]}.') + return df_strata + + +def _drop_unneeded_columns(df_strata: DataFrame) -> DataFrame: + """Drop columns that have internally same contents.""" + for col in df_strata.columns.tolist(): + if df_strata[col].nunique() == 1: + df_strata = df_strata.drop(col, axis=1) + return df_strata + + +def _compress_df(df_strata: DataFrame) -> DataFrame: + """Compress remaining columns into a single string""" + df_strata['label'] = '(' + df_strata.iloc[:, 0].astype(str) + for i in range(1, df_strata.shape[1]): + df_strata['label'] += df_strata.iloc[:, i].astype(str) + df_strata['label'] += ')' + df_strata = df_strata[['label']] + return df_strata + + +def _label(df_assignments: DataFrame, df_strata: DataFrame) -> tuple[DataFrame, dict[int, str]]: + """Merge with specimen assignments, keeping only selected strata.""" + df = merge(df_assignments, df_strata, on='stratum identifier', how='inner')[['label']] + label_to_result = dict(enumerate(sort(df['label'].unique()))) + return df.replace({res: i for i, res in label_to_result.items()}), label_to_result + + +def extract_cggnn_data( + spt_db_config_location: str, + study: str, + strata_to_use: list[int] | None, +) -> tuple[DataFrame, DataFrame, dict[int, str]]: + """Extract information cg-gnn needs from SPT. 
+ + Parameters + ---------- + spt_db_config_location : str + Location of the SPT DB config file. + study : str + Name of the study to query data for. + strata_to_use : list[int] | None + Specimen strata to use as labels, identified according to the "stratum identifier" in + `explore_classes`. This should be given as space separated integers. + If not provided, all strata will be used. + + Returns + ------- + df_cell: DataFrame + Rows are individual cells, indexed by an integer ID. + Column or column groups are, named and in order: + 1. The 'specimen' the cell is from + 2. Cell centroid positions 'pixel x' and 'pixel y' + 3. Channel expressions starting with 'C ' and followed by human-readable symbol text + 4. Phenotype expressions starting with 'P ' followed by human-readable symbol text + df_label: DataFrame + Rows are specimens, the sole column 'label' is its class label as an integer. + label_to_result_text: dict[int, str] + Mapping from class integer label to human-interpretable result text. + """ + extractor = FeatureMatrixExtractor(database_config_file=spt_db_config_location) + df_cell = _create_cell_df({ + slide: data.dataframe for slide, data in extractor.extract(study=study).items() + }) + cohorts = extractor.extract_cohorts(study) + df_label, label_to_result_text = _create_label_df( + cohorts['assignments'], + cohorts['strata'], + strata_to_use, + ) + return df_cell, df_label, label_to_result_text diff --git a/spatialprofilingtoolbox/cggnn/scripts/explore_classes.py b/spatialprofilingtoolbox/cggnn/scripts/explore_classes.py new file mode 100644 index 000000000..0367ad200 --- /dev/null +++ b/spatialprofilingtoolbox/cggnn/scripts/explore_classes.py @@ -0,0 +1,33 @@ +"""Report the different strata available to classify with.""" + +from argparse import ArgumentParser + +from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor + + +def parse_arguments(): + """Process command line arguments.""" + parser = ArgumentParser( + prog='spt cggnn explore_classes', + description='See the strata available to classify on.' + ) + parser.add_argument( + '--spt_db_config_location', + type=str, + help='Location of the SPT DB config file.', + required=True + ) + parser.add_argument( + '--study', + type=str, + help='Name of the study to query data for.', + required=True + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + extractor = FeatureMatrixExtractor(args.spt_db_config_location) + strata = extractor.extract_cohorts(study=args.study)['strata'] + print(strata.to_string()) diff --git a/spatialprofilingtoolbox/cggnn/scripts/extract.py b/spatialprofilingtoolbox/cggnn/scripts/extract.py new file mode 100644 index 000000000..12752e8ea --- /dev/null +++ b/spatialprofilingtoolbox/cggnn/scripts/extract.py @@ -0,0 +1,63 @@ +"""Extract information cg-gnn needs from SPT and save to file.""" + +from argparse import ArgumentParser +from os.path import join, exists +from json import dump + +from spatialprofilingtoolbox.cggnn import extract_cggnn_data + + +def parse_arguments(): + """Process command line arguments.""" + parser = ArgumentParser( + prog='spt cggnn extract', + description='Extract information cg-gnn needs from SPT and save to file.' 
+ ) + parser.add_argument( + '--spt_db_config_location', + type=str, + help='Location of the SPT DB config file.', + required=True + ) + parser.add_argument( + '--study', + type=str, + help='Name of the study to query data for.', + required=True + ) + parser.add_argument( + '--strata', + nargs='+', + type=int, + help='Specimen strata to use as labels, identified according to the "stratum identifier" ' + 'in `explore_classes`. This should be given as space separated integers.\n' + 'If not provided, all strata will be used.', + required=False, + default=None + ) + parser.add_argument( + '--output_location', + type=str, + help='Directory to save extracted data to.', + required=True + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + df_cell, df_label, label_to_result = extract_cggnn_data( + args.spt_db_config_location, + args.study, + args.strata, + ) + + assert isinstance(args.output_location, str) + dict_filename = join(args.output_location, 'label_to_results.json') + cells_filename = join(args.output_location, 'cells.h5') + labels_filename = join(args.output_location, 'labels.h5') + if not (exists(dict_filename) and exists(cells_filename) and exists(labels_filename)): + df_cell.to_hdf(cells_filename, 'cells') + df_label.to_hdf(labels_filename, 'labels') + with open(dict_filename, 'w', encoding='utf-8') as f: + dump(label_to_result, f) diff --git a/spatialprofilingtoolbox/cggnn/scripts/run.py b/spatialprofilingtoolbox/cggnn/scripts/run.py index 06b618569..84292c491 100644 --- a/spatialprofilingtoolbox/cggnn/scripts/run.py +++ b/spatialprofilingtoolbox/cggnn/scripts/run.py @@ -1,25 +1,23 @@ -"Run through the entire SPT CG-GNN pipeline using a local db config." +"""Run through the entire SPT CG-GNN pipeline using a local db config.""" + from argparse import ArgumentParser from os.path import join -from typing import cast -from pandas import DataFrame from pandas import read_csv -from pandas import concat, merge -from numpy import sort -from spatialprofilingtoolbox import DatabaseConnectionMaker -from spatialprofilingtoolbox import DBCredentials +from spatialprofilingtoolbox.cggnn import extract_cggnn_data +from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker from spatialprofilingtoolbox.db.importance_score_transcriber import transcribe_importance -from spatialprofilingtoolbox.db.feature_matrix_extractor import ( - FeatureMatrixExtractor, - Bundle, -) +from spatialprofilingtoolbox.standalone_utilities.module_load_error import SuggestExtrasException +try: + from cggnn.run_all import run_with_dfs +except ModuleNotFoundError as e: + SuggestExtrasException(e, 'cggnn') from cggnn.run_all import run_with_dfs def parse_arguments(): - "Process command line arguments." + """Process command line arguments.""" parser = ArgumentParser( prog='spt cggnn run', description='Create cell graphs from SPT tables saved locally, train a graph neural ' @@ -120,77 +118,21 @@ def parse_arguments(): return parser.parse_args() -def _create_cell_df(cell_dfs: dict[str, DataFrame], feature_names: dict[str, str]) -> DataFrame: - "Find chemical species, phenotypes, and locations and merge into a DataFrame." 
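For illustration, a minimal sketch of reading back the three files that `spt cggnn extract` writes, assuming it was run with `--output_location .`; the file names and HDF keys ('cells', 'labels') come from the script above, everything else is illustrative:

from json import load
from pandas import read_hdf

df_cell = read_hdf('cells.h5', 'cells')      # one row per cell: 'specimen', 'pixel x'/'pixel y', then 'C ...' and 'P ...' columns
df_label = read_hdf('labels.h5', 'labels')   # one row per specimen with an integer 'label' column
with open('label_to_results.json', encoding='utf-8') as f:
    label_to_result = load(f)                # json.dump turns the integer keys into strings
readable = df_label['label'].map(lambda label: label_to_result[str(label)])
print(readable.head())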
- for specimen, df_specimen in cell_dfs.items(): -        df_specimen.rename( -            {ft_id: 'FT_' + ft_name for ft_id, ft_name in feature_names.items()}, -            axis=1, -            inplace=True, -        ) -        # TODO: Create phenotype columns -        df_specimen.rename( -            {'pixel x': 'center_x', 'pixel y': 'center_y'}, -            axis=1, -            inplace=True, -        ) -        df_specimen['specimen'] = specimen - -    # TODO: Reorder so that it's feature, specimen, phenotype, xy -    # TODO: Verify histological structure ID or recreate one -    df = concat(cell_dfs.values(), axis=0) -    df.index.name = 'histological_structure' -    return df - - -def _create_label_df( -    df_assignments: DataFrame, -    df_strata: DataFrame, -) -> tuple[DataFrame, dict[int, str]]: -    """Get slide-level results.""" -    df = merge(df_assignments, df_strata, on='stratum identifier', how='left')[ -        ['specimen', 'subject diagnosed result'] -    ].rename( -        {'specimen': 'slide', 'subject diagnosed result': 'result'}, -        axis=1, -    ) -    label_to_result = dict(enumerate(sort(df['result'].unique()))) -    return df.replace({res: i for i, res in label_to_result.items()}), label_to_result - - -def retrieve_importances() -> DataFrame: -    filename = join('out', 'importances.csv') -    return read_csv(filename, index_col=0) - - -def save_importances(_args): -    df = retrieve_importances() -    credentials = DBCredentials(_args.dbname, _args.host, _args.user, _args.password) -    connection = DatabaseConnectionMaker.make_connection(credentials) +def save_importance(spt_db_config_location: str) -> None: +    """Save cell importance scores as defined by cggnn to the database.""" +    df = read_csv(join('out', 'importances.csv'), index_col=0) +    connection = DatabaseConnectionMaker(spt_db_config_location).get_connection() transcribe_importance(df, connection) connection.close() if __name__ == "__main__": args = parse_arguments() -    extractor = FeatureMatrixExtractor(database_config_file=args.spt_db_config_location) -    study_data = cast(Bundle, extractor.extract(study=args.study)) -    df_cell = _create_cell_df( -        { -            slide: cast(DataFrame, data['dataframe']) -            for slide, data in study_data['feature matrices'].items() -        }, -        cast(dict[str, str], study_data['channel symbols by column name']), -    ) -    df_label, label_to_result_text = _create_label_df( -        cast(DataFrame, study_data['sample cohorts']['assignments']), -        cast(DataFrame, study_data['sample cohorts']['strata']), -    ) - +    df_cell, df_label, label_to_result = extract_cggnn_data(args.spt_db_config_location, args.study, None) run_with_dfs( df_cell, df_label, -        label_to_result_text, +        label_to_result, args.validation_data_percent, args.test_data_percent, args.roi_side_length, @@ -201,6 +143,6 @@ def save_importances(_args): args.k_folds, args.explainer, args.merge_rois, -        args.prune_misclassified, +        args.prune_misclassified ) -    save_importances(args) +    save_importance(args.spt_db_config_location) diff --git a/spatialprofilingtoolbox/cggnn/scripts/run_sql.py b/spatialprofilingtoolbox/cggnn/scripts/run_sql.py index a0ab4a078..8244df941 100644 --- a/spatialprofilingtoolbox/cggnn/scripts/run_sql.py +++ b/spatialprofilingtoolbox/cggnn/scripts/run_sql.py @@ -1,4 +1,5 @@ "Run through the entire SPT CG-GNN pipeline, starting from a SPT SQL database." + from argparse import ArgumentParser from spatialprofilingtoolbox.standalone_utilities.module_load_error import SuggestExtrasException try: @@ -8,7 +9,7 @@ def parse_arguments(): -    "Process command line arguments."
+    """Process command line arguments.""" parser = ArgumentParser( prog='spt cggnn run_sql', description='Create cell graphs from SPT SQL tables, train a graph neural network on ' diff --git a/spatialprofilingtoolbox/db/feature_matrix_extractor.py b/spatialprofilingtoolbox/db/feature_matrix_extractor.py index 1d30af422..0f30ab77d 100644 --- a/spatialprofilingtoolbox/db/feature_matrix_extractor.py +++ b/spatialprofilingtoolbox/db/feature_matrix_extractor.py @@ -1,16 +1,21 @@ -""" -Convenience provision of a feature matrix for each study, the data retrieved from the SPT database. -""" +"""Convenience provision of a feature matrix for each study, retrieved from the SPT database.""" from enum import Enum from enum import auto -from typing import cast +from typing import cast, Any +from dataclasses import dataclass from pandas import DataFrame from psycopg2.extensions import cursor as Psycopg2Cursor from spatialprofilingtoolbox import DatabaseConnectionMaker -from spatialprofilingtoolbox.db.stratification_puller import StratificationPuller +from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeCriteria +from spatialprofilingtoolbox.db.phenotypes import PhenotypesAccess +from spatialprofilingtoolbox.db.stratification_puller import ( +    StratificationPuller, +    Stratification, +) +from spatialprofilingtoolbox.db.study_access import StudyAccess from spatialprofilingtoolbox.workflow.common.structure_centroids_puller import \ StructureCentroidsPuller from spatialprofilingtoolbox.workflow.common.sparse_matrix_puller import SparseMatrixPuller @@ -18,10 +23,16 @@ logger = colorized_logger(__name__) -BundlePart = dict[str, DataFrame | str | dict[str, DataFrame | str ]] -Bundle = dict[str, dict[str, BundlePart]] -class DBSource(Enum): +@dataclass +class MatrixBundle: +    """Bundle of information for a specimen matrix.""" +    dataframe: DataFrame +    filename: str +    continuous_dataframe: DataFrame | None = None + + +class _DBSource(Enum): """Indicator of intended database source.""" CURSOR = auto() CONFIG_FILE = auto() @@ -32,20 +43,20 @@ class FeatureMatrixExtractor: """Pull from the database and create convenience bundle of feature matrices and metadata.""" cursor: Psycopg2Cursor database_config_file: str | None -    db_source: DBSource +    db_source: _DBSource def __init__(self, -        cursor: Psycopg2Cursor | None=None, -        database_config_file: str | None=None, -    ): +        cursor: Psycopg2Cursor | None = None, +        database_config_file: str | None = None, +    ) -> None: self.cursor = cast(Psycopg2Cursor, cursor) self.database_config_file = database_config_file if cursor is not None: -            self.db_source = DBSource.CURSOR +            self.db_source = _DBSource.CURSOR elif database_config_file is not None: -            self.db_source = DBSource.CONFIG_FILE +            self.db_source = _DBSource.CONFIG_FILE else: -            self.db_source = DBSource.UNKNOWN +            self.db_source = _DBSource.UNKNOWN self._report_on_arguments() def _report_on_arguments(self): @@ -57,19 +68,41 @@ def _report_on_arguments(self): logger.warning(message) def extract(self, -        specimen: str | None=None, -        study: str | None=None, -        continuous_also: bool=False, -    ) -> Bundle | None: -        extraction = None +        specimen: str | None = None, +        study: str | None = None, +        continuous_also: bool = False, +    ) -> dict[str, MatrixBundle]: +        """Extract feature matrices for a specimen or every specimen in a study. + +        Parameters +        ---------- +        specimen: str | None = None +        study: str | None = None +            Which specimen to extract features for, or which study to extract features for (for +            all of its specimens).
Exactly one of specimen or study must be provided. + continuous_also: bool = False + Whether to also calculate and return a DataFrame for each specimen with continuous + channel information in addition to the default DataFrame which provides binary cast + channel information. + + Returns + ------- + dict[str, MatrixBundle] + A dictionary of specimen names to a MatrixBundle dataclass instances, which contain: + 1. `dataframe`, a DataFrame with the feature matrix for the specimen, including + centroid location, channel information, and phenotype information. + 2. `filename`, a filename for the DataFrame. + 3. `continuous_dataframe`, a DataFrame with continuous channel information if + continuous_also is true, otherwise this property is None. + """ match self.db_source: - case DBSource.CURSOR: + case _DBSource.CURSOR: extraction = self._extract( specimen=specimen, study=study, continuous_also=continuous_also, ) - case DBSource.CONFIG_FILE: + case _DBSource.CONFIG_FILE: with DatabaseConnectionMaker(self.database_config_file) as dcm: with dcm.get_connection().cursor() as cursor: self.cursor = cursor @@ -78,15 +111,17 @@ def extract(self, study=study, continuous_also=continuous_also, ) - case DBSource.UNKNOWN: - logger.error('The database source can not be determined.') + case _DBSource.UNKNOWN: + raise RuntimeError('The database source can not be determined.') return extraction def _extract(self, - specimen: str | None=None, - study: str | None=None, - continuous_also: bool=False, - ) -> Bundle | None: + specimen: str | None = None, + study: str | None = None, + continuous_also: bool = False, + ) -> dict[str, MatrixBundle]: + if (specimen is None) == (study is None): + raise ValueError('Must specify exactly one of specimen or study.') data_arrays = self._retrieve_expressions_from_database( specimen=specimen, study=study, @@ -96,163 +131,159 @@ def _extract(self, specimen=specimen, study=study, ) - stratification = self._retrieve_derivative_stratification_from_database() - study_component_lookup = self._retrieve_study_component_lookup() - merged = self._merge_dictionaries( - self._create_feature_matrices(data_arrays, centroid_coordinates), + if study is None: + assert specimen is not None + study = StudyAccess(self.cursor).get_study_from_specimen(specimen) + + return self._create_feature_matrices( + data_arrays, + centroid_coordinates, + self._retrieve_phenotypes(study), self._create_channel_information(data_arrays), - stratification, - new_keys=['feature matrices','channel symbols by column name', 'sample cohorts'], - study_component_lookup=study_component_lookup, ) - if merged is None: - return None - if study is not None: - for key in list(merged.keys()): - if not key == study: - del merged[key] - return merged - - @staticmethod - def redact_dataframes(extraction): - for study_name, study in extraction.items(): - for specimen in study['feature matrices'].keys(): - extraction[study_name]['feature matrices'][specimen]['dataframe'] = None - key = 'continuous dataframe' - if key in extraction[study_name]['feature matrices'][specimen]: - extraction[study_name]['feature matrices'][specimen][key] = None - extraction[study_name]['sample cohorts']['assignments'] = None - extraction[study_name]['sample cohorts']['strata'] = None def _retrieve_expressions_from_database(self, - specimen: str | None=None, - study: str | None=None, - continuous_also: bool=False, - ): + specimen: str | None = None, + study: str | None = None, + continuous_also: bool = False, + ) -> dict[str, dict[str, Any]]: 
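For illustration, a minimal sketch of consuming the reworked `extract` return value; the config file path and study name are hypothetical placeholders, and the `MatrixBundle` fields are as defined above:

from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor

extractor = FeatureMatrixExtractor(database_config_file='.spt_db.config')  # hypothetical config file
bundles = extractor.extract(study='Melanoma intralesional IL2', continuous_also=True)
for specimen, bundle in bundles.items():
    # bundle.dataframe holds 'pixel x'/'pixel y' plus binary 'C <channel>' and 'P <phenotype>' columns.
    bundle.dataframe.to_csv(bundle.filename, sep='\t', index=False)
    if bundle.continuous_dataframe is not None:
        print(specimen, bundle.continuous_dataframe.shape)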
logger.info('Retrieving expression data from database.') puller = SparseMatrixPuller(self.cursor) puller.pull(specimen=specimen, study=study, continuous_also=continuous_also) data_arrays = puller.get_data_arrays() logger.info('Done retrieving expression data from database.') - return data_arrays.get_studies() + return list(data_arrays.get_studies().values())[0] def _retrieve_structure_centroids_from_database(self, - specimen: str | None=None, - study: str | None=None, - ): + specimen: str | None = None, + study: str | None = None, + ) -> dict[str, Any]: logger.info('Retrieving polygon centroids from shapefiles in database.') puller = StructureCentroidsPuller(self.cursor) puller.pull(specimen=specimen, study=study) structure_centroids = puller.get_structure_centroids() logger.info('Done retrieving centroids.') - return structure_centroids.get_studies() + return list(structure_centroids.get_studies().values())[0] - def _retrieve_derivative_stratification_from_database(self): - logger.info('Retrieving stratification from database.') - puller = StratificationPuller(self.cursor) - puller.pull() - stratification = puller.get_stratification() - logger.info('Done retrieving stratification.') - return stratification - - def _retrieve_study_component_lookup(self): - self.cursor.execute('SELECT * FROM study_component ; ') - rows = self.cursor.fetchall() - lookup = {} - for row in rows: - lookup[row[1]] = row[0] - return lookup + def _retrieve_phenotypes(self, study_name: str) -> dict[str, PhenotypeCriteria]: + logger.info('Retrieving phenotypes from database.') + phenotypes: dict[str, PhenotypeCriteria] = {} + phenotype_access = PhenotypesAccess(self.cursor) + for symbol_data in phenotype_access.get_phenotype_symbols(study_name): + symbol = symbol_data.handle_string + phenotypes[symbol] = phenotype_access.get_phenotype_criteria(study_name, symbol) + logger.info('Done retrieving phenotypes.') + return phenotypes - def _create_feature_matrices(self, data_arrays, centroid_coordinates): + def _create_feature_matrices(self, + study: dict[str, dict[str, Any]], + centroid_coordinates: dict[str, Any], + phenotypes: dict[str, PhenotypeCriteria], + channel_information: list[str], + ) -> dict[str, MatrixBundle]: logger.info('Creating feature matrices from binary data arrays and centroids.') - matrices = {} - for k, study_name in enumerate(sorted(list(data_arrays.keys()))): - study = data_arrays[study_name] - matrices[study_name] = {} - for j, specimen in enumerate(sorted(list(study['data arrays by specimen'].keys()))): + matrices: dict[str, MatrixBundle] = {} + for j, specimen in enumerate(sorted(list(study['data arrays by specimen'].keys()))): + logger.debug('Specimen %s .', specimen) + expressions = study['data arrays by specimen'][specimen] + rows = [ + self._create_feature_matrix_row( + centroid_coordinates[specimen][i], + expressions[i], + len(study['target index lookup']), + ) for i in range(len(expressions)) + ] + dataframe = DataFrame( + rows, + columns=['pixel x', 'pixel y'] + [f'C {cs}' for cs in channel_information], + ) + for symbol, criteria in phenotypes.items(): + dataframe[f'P {symbol}'] = ( + dataframe[[f'C {m}' for m in criteria.positive_markers]].all(axis=1) & + ~dataframe[[f'C {m}' for m in criteria.negative_markers]].any(axis=1) + ).astype(int) + matrices[specimen] = MatrixBundle(dataframe, f'{j}.tsv') + + if 'continuous data arrays by specimen' in study: + specimens = list(study['continuous data arrays by specimen'].keys()) + for j, specimen in enumerate(sorted(specimens)): 
logger.debug('Specimen %s .', specimen) - expressions = study['data arrays by specimen'][specimen] - number_channels = len(study['target index lookup']) - rows = [ - self._create_feature_matrix_row( - centroid_coordinates[study_name][specimen][i], - expressions[i], - number_channels, - ) - for i in range(len(expressions)) - ] + expression_vectors = study['continuous data arrays by specimen'][specimen] dataframe = DataFrame( - rows, - columns=['pixel x', 'pixel y'] + [f'F{i}' for i in range(number_channels)], + expression_vectors, + columns=[f'C {cs}' for cs in channel_information], ) - matrices[study_name][specimen] = { - 'dataframe': dataframe, - 'filename': f'{k}.{j}.tsv', - } - - if 'continuous data arrays by specimen' in study: - specimens = list(study['continuous data arrays by specimen'].keys()) - for j, specimen in enumerate(sorted(specimens)): - logger.debug('Specimen %s .', specimen) - expression_vectors = study['continuous data arrays by specimen'][specimen] - number_channels = len(study['target index lookup']) - dataframe = DataFrame( - expression_vectors, - columns=[f'F{i}' for i in range(number_channels)], - ) - matrices[study_name][specimen]['continuous dataframe'] = dataframe + matrices[specimen].continuous_dataframe = dataframe logger.info('Done creating feature matrices.') return matrices @staticmethod - def _create_feature_matrix_row(centroid, binary, number_channels): + def _create_feature_matrix_row( + centroid: tuple[float, float], + binary: list[str], + number_channels: int, + ) -> list[float | int]: template = '{0:0%sb}' % number_channels # pylint: disable=consider-using-f-string - feature_vector = [int(value) for value in list(template.format(binary)[::-1])] + feature_vector: list[int] = [int(value) for value in list(template.format(binary)[::-1])] return [centroid[0], centroid[1]] + feature_vector - def _create_channel_information(self, data_arrays): - return { - study_name: self._create_channel_information_for_study(study) - for study_name, study in data_arrays.items() - } - - def _create_channel_information_for_study(self, study): + def _create_channel_information(self, + study_information: dict[str, dict[str, Any]] + ) -> list[str]: logger.info('Aggregating channel information for one study.') targets = { int(index): target - for target, index in study['target index lookup'].items() + for target, index in study_information['target index lookup'].items() } symbols = { target: symbol - for symbol, target in study['target by symbol'].items() + for symbol, target in study_information['target by symbol'].items() } logger.info('Done aggregating channel information.') - return { - f'F{i}': symbols[targets[i]] - for i in sorted([int(index) for index in targets.keys()]) - } + return [ + symbols[targets[i]] for i in sorted([int(index) for index in targets.keys()]) + ] - def _merge_dictionaries(self, - *args, - new_keys: list, - study_component_lookup: dict - ) -> Bundle | None: - if not len(args) == len(new_keys): - logger.error( - "Can not match up dictionaries to be merged with the list of key names to be " - "issued for them." 
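To make the new 'P <symbol>' phenotype columns concrete, a toy sketch of the same membership rule used in `_create_feature_matrices` above (the marker names and the phenotype symbol are made up):

from pandas import DataFrame
from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeCriteria

cells = DataFrame({
    'pixel x': [0.0, 1.0, 2.0],
    'pixel y': [0.0, 1.0, 2.0],
    'C CD3': [1, 1, 0],
    'C CD8': [0, 1, 1],
})
criteria = PhenotypeCriteria(positive_markers=['CD3'], negative_markers=['CD8'])
cells['P CD3+CD8-'] = (
    cells[[f'C {m}' for m in criteria.positive_markers]].all(axis=1) &
    ~cells[[f'C {m}' for m in criteria.negative_markers]].any(axis=1)
).astype(int)
# Only the first cell is CD3-positive and CD8-negative, so the new column is [1, 0, 0].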
- ) - return None + def extract_cohorts(self, study: str) -> dict[str, DataFrame]: + """Extract specimen cohort information for every specimen in a study.""" + match self.db_source: + case _DBSource.CURSOR: + extraction = self._extract_cohorts(study) + case _DBSource.CONFIG_FILE: + with DatabaseConnectionMaker(self.database_config_file) as dcm: + with dcm.get_connection().cursor() as cursor: + self.cursor = cursor + extraction = self._extract_cohorts(study) + case _DBSource.UNKNOWN: + raise RuntimeError('The database source can not be determined.') + return extraction - merged: dict = {} - for i in range(len(new_keys)): - for substudy, value in args[i].items(): - merged[study_component_lookup[substudy]] = {} + def _extract_cohorts(self, study: str) -> dict[str, DataFrame]: + stratification = self._retrieve_derivative_stratification_from_database() + for substudy in self._retrieve_component_studies(study): + if substudy in stratification: + break + else: + raise RuntimeError('Stratification substudy not found for study.') + return stratification[substudy] - for i, key in enumerate(new_keys): - for substudy, value in args[i].items(): - merged[study_component_lookup[substudy]][key] = value + def _retrieve_derivative_stratification_from_database(self) -> Stratification: + logger.info('Retrieving stratification from database.') + puller = StratificationPuller(self.cursor) + puller.pull() + stratification = puller.get_stratification() + logger.info('Done retrieving stratification.') + return stratification - logger.info('Done merging into a single dictionary bundle.') - return merged + def _retrieve_component_studies(self, study: str) -> set[str]: + self.cursor.execute(f''' + SELECT component_study + FROM study_component + WHERE primary_study = '{study}'; + ''') + rows = self.cursor.fetchall() + lookup: set[str] = set() + for row in rows: + lookup.add(row[0]) + return lookup diff --git a/spatialprofilingtoolbox/db/scripts/retrieve_feature_matrices.py b/spatialprofilingtoolbox/db/scripts/retrieve_feature_matrices.py index 467f6544b..4a3711138 100644 --- a/spatialprofilingtoolbox/db/scripts/retrieve_feature_matrices.py +++ b/spatialprofilingtoolbox/db/scripts/retrieve_feature_matrices.py @@ -1,27 +1,20 @@ """Convenience CLI wrapper of FeatureMatrixExtractor functionality, writes to files.""" + import argparse -import json from os.path import exists from os.path import abspath from os.path import expanduser -from typing import cast - -from pandas import DataFrame from spatialprofilingtoolbox.standalone_utilities.module_load_error import \ SuggestExtrasException try: from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor - from spatialprofilingtoolbox.db.feature_matrix_extractor import Bundle except ModuleNotFoundError as e: SuggestExtrasException(e, 'db') from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor -from spatialprofilingtoolbox.db.feature_matrix_extractor import Bundle from spatialprofilingtoolbox.workflow.common.cli_arguments import add_argument -from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger -logger = colorized_logger('spt db create-schema') def retrieve(args: argparse.Namespace): database_config_file = None @@ -31,31 +24,26 @@ def retrieve(args: argparse.Namespace): message = f'Need to supply valid database config filename: {database_config_file}' raise FileNotFoundError(message) extractor = FeatureMatrixExtractor(database_config_file=database_config_file) - bundle = cast(Bundle, 
extractor.extract()) - for _, study in bundle.items(): - feature_matrices = cast(dict[str, dict[str, DataFrame | str ]], study['feature matrices']) - for _, specimen_data in feature_matrices.items(): - df = cast(DataFrame, specimen_data['dataframe']) - filename = cast(str, specimen_data['filename']) - df.to_csv(filename, sep='\t', index=False) - outcomes = cast(DataFrame, study['sample cohorts']['assignments']) - filename = 'assignments.tsv' - outcomes.to_csv(filename, sep='\t', index=False) - FeatureMatrixExtractor.redact_dataframes(bundle) - with open('features.json', 'wt', encoding='utf-8') as file: - file.write(json.dumps(bundle, indent=2)) + feature_matrices = extractor.extract(args.study_name) + for _, specimen_data in feature_matrices.items(): + specimen_data.dataframe.to_csv(specimen_data.filename, sep='\t', index=False) + outcomes = extractor.extract_cohorts(args.study_name)['assignments'] + filename = 'assignments.tsv' + outcomes.to_csv(filename, sep='\t', index=False) if __name__ == '__main__': parser = argparse.ArgumentParser( prog='spt db retrieve-feature-matrices', description=''' -Retrieve feature matrices for each sample of each study, an outcomes dataframe -for each study, and a column/channel names lookup for each study. -Retrieves from any database that conforms to "single cell ADI" database schema. -Writes TSV files to the current working directory, with filenames listed alongside -specimen and channel name information in: features.json -''' +Retrieve feature matrices for each sample of a study and corresponding outcomes +dataframe and column/channel names lookup from any database that conforms to +"single cell ADI" database schema and writes them as TSV files to the current +working directory, with filenames listed alongside specimen and channel name +information in: features.json +''', + ) add_argument(parser, 'database config') + add_argument(parser, 'study name') retrieve(parser.parse_args()) diff --git a/spatialprofilingtoolbox/db/squidpy_metrics.py b/spatialprofilingtoolbox/db/squidpy_metrics.py index aead67a07..7a7406e09 100644 --- a/spatialprofilingtoolbox/db/squidpy_metrics.py +++ b/spatialprofilingtoolbox/db/squidpy_metrics.py @@ -1,13 +1,10 @@ """Make squidpy metrics that don't require specific phenotype selection available.""" -from typing import cast - from pandas import DataFrame from psycopg2.extensions import cursor as Psycopg2Cursor from spatialprofilingtoolbox import DatabaseConnectionMaker from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor -from spatialprofilingtoolbox.db.feature_matrix_extractor import Bundle from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeCriteria from spatialprofilingtoolbox.db.create_data_analysis_study import DataAnalysisStudyFactory from spatialprofilingtoolbox.workflow.common.squidpy import ( @@ -27,10 +24,7 @@ def create_and_transcribe_squidpy_features( """Transcribe "off-demand" Squidpy feature(s) in features system.""" connection = database_connection_maker.get_connection() das = DataAnalysisStudyFactory(connection, study, 'spatial autocorrelation').create() - features_by_specimen, channel_symbols_by_column_name = _fetch_cells_and_phenotypes( - connection.cursor(), - study, - ) + features_by_specimen = _fetch_cells(connection.cursor(), study) with ADIFeaturesUploader( database_connection_maker, data_analysis_study=das, @@ -42,7 +36,6 @@ def create_and_transcribe_squidpy_features( create_and_transcribe_one_sample( sample, df, - channel_symbols_by_column_name, 
feature_uploader, ) @@ -50,31 +43,24 @@ def create_and_transcribe_squidpy_features( def create_and_transcribe_one_sample( sample: str, df: DataFrame, - channel_symbols_by_column_name: dict[str, str], feature_uploader: ADIFeaturesUploader, ) -> None: - for column, symbol in channel_symbols_by_column_name.items(): - criteria = PhenotypeCriteria(positive_markers=[column], negative_markers=[]) - value = compute_squidpy_metric_for_one_sample(df, [criteria], 'spatial autocorrelation') - if value is None: - continue - feature_uploader.stage_feature_value((symbol,), sample, value) + for column in df.columns: + if column.startswith('C '): + symbol = column[2:] + criteria = PhenotypeCriteria(positive_markers=[symbol], negative_markers=[]) + value = compute_squidpy_metric_for_one_sample(df, [criteria], 'spatial autocorrelation') + if value is None: + continue + feature_uploader.stage_feature_value((symbol,), sample, value) -def _fetch_cells_and_phenotypes( +def _fetch_cells( cursor: Psycopg2Cursor, study: str, -) -> tuple[dict[str, DataFrame], dict[str, str]]: - extractor = FeatureMatrixExtractor(cursor) - bundle = cast(Bundle, extractor.extract(study=study)) - FeatureMatrices = dict[str, dict[str, DataFrame | str]] - feature_matrices = cast(FeatureMatrices, bundle[study]['feature matrices']) +) -> dict[str, DataFrame]: + feature_matrices = FeatureMatrixExtractor(cursor).extract(study=study) features_by_specimen = { - specimen: cast(DataFrame, packet['dataframe']) - for specimen, packet in feature_matrices.items() + specimen: bundle.dataframe for specimen, bundle in feature_matrices.items() } - channel_symbols_by_columns_name = cast( - dict[str, str], - bundle[study]['channel symbols by column name'], - ) - return features_by_specimen, channel_symbols_by_columns_name + return features_by_specimen diff --git a/spatialprofilingtoolbox/workflow/common/cell_df_indexer.py b/spatialprofilingtoolbox/workflow/common/cell_df_indexer.py deleted file mode 100644 index fec073d40..000000000 --- a/spatialprofilingtoolbox/workflow/common/cell_df_indexer.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Convenience functions for working with cell DataFrames.""" - -import warnings - -from numpy import asarray, ndarray -from numpy.typing import NDArray -from pandas import DataFrame - -from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeCriteria -from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger - -logger = colorized_logger(__name__) - -ValueMultiIndex = tuple[int | tuple[int, ...], list[str]] - -def _get_value_and_multiindex(signature: PhenotypeCriteria) -> ValueMultiIndex: - value = (1,) * len(signature.positive_markers) + (0,) * len(signature.negative_markers) - _value: int | tuple[int, ...] - if len(value) == 1: - _value = value[0] - else: - _value = value - multiindex = [*signature.positive_markers, *signature.negative_markers] - return _value, multiindex - - -def get_mask(cells: DataFrame, signature: PhenotypeCriteria) -> NDArray[bool,]: - """Transform phenotype signature into a boolean mask for a DataFrame.""" - value, multiindex = _get_value_and_multiindex(signature) - try: - with warnings.catch_warnings(): - message = "indexing past lexsort depth may impact performance" - warnings.filterwarnings("ignore", message=message) - loc = cells.set_index(multiindex).index.get_loc(value) - except KeyError as exception: - logger.debug('Some KeyError. 
%s', str(exception)) - return asarray([False,] * cells.shape[0]) - if isinstance(loc, ndarray): - return loc - if isinstance(loc, slice): - range1 = [False,]*(loc.start - 0) - range2 = [True,]*(loc.stop - loc.start) - range3 = [False,]*(cells.shape[0] - loc.stop) - return asarray(range1 + range2 + range3) - if isinstance(loc, int): - return asarray([i == loc for i in range(cells.shape[0])]) - raise ValueError(f'Could not select by index: {multiindex}. Got: {loc}') diff --git a/spatialprofilingtoolbox/workflow/common/proximity.py b/spatialprofilingtoolbox/workflow/common/proximity.py index 84e5e036a..42ea892b4 100644 --- a/spatialprofilingtoolbox/workflow/common/proximity.py +++ b/spatialprofilingtoolbox/workflow/common/proximity.py @@ -1,14 +1,12 @@ """Low-level calculations of proximity metric.""" from math import isnan -import re from pandas import DataFrame from sklearn.neighbors import BallTree # type: ignore from numpy import logical_and from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeCriteria -from spatialprofilingtoolbox.workflow.common.cell_df_indexer import get_mask from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger logger = colorized_logger(__name__) @@ -21,12 +19,19 @@ def compute_proximity_metric_for_signature_pair( cells: DataFrame, tree: BallTree ) -> float | None: - mask1 = get_mask(cells, signature1) - mask2 = get_mask(cells, signature2) + cells = cells.rename({ + column: (column[2:] if (column.startswith('C ') or column.startswith('P ')) else column) + for column in cells.columns + }, axis=1) + mask1 = cells.astype(bool)[signature1.positive_markers].all(axis=1) & \ + (~(cells.astype(bool))[signature1.negative_markers]).all(axis=1) + mask2 = cells.astype(bool)[signature2.positive_markers].all(axis=1) & \ + (~(cells.astype(bool))[signature2.negative_markers]).all(axis=1) source_count = sum(mask1) + if source_count == 0: return None - source_cell_locations = cells.loc()[mask1][['pixel x', 'pixel y']] + source_cell_locations = cells.loc[mask1, ['pixel x', 'pixel y']] within_radius_indices_list = tree.query_radius( source_cell_locations, radius, @@ -50,25 +55,15 @@ def _validate_value(value) -> bool: return True -def _phenotype_identifier_lookup(handle, channel_symbols_by_column_name) -> str: - if re.match(r'^\d+$', handle): - return f'cell_phenotype {handle}' - if re.match(r'^F\d+$', handle): - channel_symbol = channel_symbols_by_column_name[handle] - return channel_symbol - raise ValueError(f'Did not understand meaning of specifier: {handle}') - - def stage_proximity_feature_values( feature_uploader, feature_values, - channel_symbols_by_column_name, sample_identifier, ) -> None: for _, row in feature_values.iterrows(): specifiers = ( - _phenotype_identifier_lookup(row['Phenotype 1'], channel_symbols_by_column_name), - _phenotype_identifier_lookup(row['Phenotype 2'], channel_symbols_by_column_name), + row['Phenotype 1'], + row['Phenotype 2'], row['Pixel radius'], ) value = row['Proximity'] diff --git a/spatialprofilingtoolbox/workflow/common/sparse_matrix_puller.py b/spatialprofilingtoolbox/workflow/common/sparse_matrix_puller.py index 20a9133fa..52e1f6956 100644 --- a/spatialprofilingtoolbox/workflow/common/sparse_matrix_puller.py +++ b/spatialprofilingtoolbox/workflow/common/sparse_matrix_puller.py @@ -2,7 +2,7 @@ (in-memory) binary compressed format. 
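For illustration, a minimal sketch of calling the updated proximity calculation directly on a prefixed cell DataFrame of the kind FeatureMatrixExtractor now produces; the marker names and coordinates are made up, and the interpretation in the final comment is approximate:

from pandas import DataFrame
from sklearn.neighbors import BallTree
from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeCriteria
from spatialprofilingtoolbox.workflow.common.proximity import compute_proximity_metric_for_signature_pair

cells = DataFrame({
    'pixel x': [0.0, 10.0, 20.0, 500.0],
    'pixel y': [0.0, 0.0, 0.0, 0.0],
    'C CD3': [1, 0, 1, 1],
    'C CD8': [0, 1, 1, 0],
})
tree = BallTree(cells[['pixel x', 'pixel y']].to_numpy())  # same construction as create_ball_tree below
signature1 = PhenotypeCriteria(positive_markers=['CD3'], negative_markers=[])
signature2 = PhenotypeCriteria(positive_markers=['CD8'], negative_markers=[])
metric = compute_proximity_metric_for_signature_pair(signature1, signature2, 60.0, cells, tree)
# Roughly: the average count of signature-2 cells within 60 pixels of each signature-1 cell,
# or None if no cell matches the first signature.
print(metric)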
""" -from typing import cast +from typing import cast, Any from psycopg2.extensions import cursor as Psycopg2Cursor @@ -36,9 +36,9 @@ class CompressedDataArrays: """ def __init__(self): - self.studies = {} + self.studies: dict[str, dict[str, dict[str, Any]]] = {} - def get_studies(self): + def get_studies(self) -> dict[str, dict[str, dict[str, Any]]]: return self.studies def add_study_data( diff --git a/spatialprofilingtoolbox/workflow/common/squidpy.py b/spatialprofilingtoolbox/workflow/common/squidpy.py index 305e01a6e..92f1316d2 100644 --- a/spatialprofilingtoolbox/workflow/common/squidpy.py +++ b/spatialprofilingtoolbox/workflow/common/squidpy.py @@ -2,10 +2,11 @@ from typing import Any from typing import cast +from warnings import warn from numpy.typing import NDArray from numpy import isnan -from pandas import DataFrame +from pandas import DataFrame, Series from squidpy.gr import ( # type: ignore spatial_neighbors, nhood_enrichment, @@ -18,7 +19,6 @@ from scipy.stats import norm # type: ignore from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeCriteria -from spatialprofilingtoolbox.workflow.common.cell_df_indexer import get_mask from spatialprofilingtoolbox import get_feature_description from spatialprofilingtoolbox import squidpy_feature_classnames @@ -37,8 +37,15 @@ def compute_squidpy_metric_for_one_sample( radius: float | None = None, ) -> float | None: """Compute Squidpy metrics for a tissue sample with a clustering of the given phenotypes.""" - df_cell.sort_index(inplace=True) - masks: list[NDArray[Any]] = [get_mask(df_cell, signature) for signature in phenotypes] + df_cell = df_cell.rename({ + column: (column[2:] if (column.startswith('C ') or column.startswith('P ')) else column) + for column in df_cell.columns + }, axis=1) + masks: list[Series] = [ + (df_cell.astype(bool)[signature.positive_markers].all(axis=1) & + (~(df_cell.astype(bool))[signature.negative_markers]).all(axis=1)) + for signature in phenotypes + ] adata = convert_df_to_anndata(df_cell, masks) match feature_class: case 'neighborhood enrichment': @@ -54,20 +61,30 @@ def compute_squidpy_metric_for_one_sample( return None -def _summarize_neighborhood_enrichment(unstructured_metrics) -> float | None: +def _summarize_neighborhood_enrichment( + unstructured_metrics: dict[str, NDArray[Any]] +) -> float | None: + if len(unstructured_metrics['zscore'].shape) != 2: + return None zscore = float(unstructured_metrics['zscore'][0][1]) return float(norm.cdf(zscore)) -def _summarize_co_occurrence(unstructured_metrics) -> float | None: +def _summarize_co_occurrence(unstructured_metrics: dict[str, NDArray[Any]] | None) -> float | None: + if unstructured_metrics is None: + return None occurrence_ratios = unstructured_metrics['occ'] + if len(occurrence_ratios.shape) != 3: + return None return float(occurrence_ratios[0][1][0]) -def _summarize_ripley(unstructured_metrics) -> float | None: +def _summarize_ripley(unstructured_metrics: dict[str, NDArray[Any]]) -> float | None: bins = unstructured_metrics['bins'] pvalues = unstructured_metrics['pvalues'] - pairs = list(zip(bins, pvalues)) + pairs = list(zip(bins.tolist(), pvalues.tolist())) + if len(bins) != len(pvalues) or len(pairs) == 0: + return None filtered = [pair for pair in pairs if pair[1] > 0] if len(filtered) == 0: return 1.0 @@ -75,7 +92,7 @@ def _summarize_ripley(unstructured_metrics) -> float | None: return float(sorted_pairs[0][1]) -def _summarize_spatial_autocorrelation(unstructured_metrics) -> float | None: +def 
_summarize_spatial_autocorrelation(unstructured_metrics: DataFrame) -> float | None: row = unstructured_metrics.iloc[0] pvalue = float(row['pval_norm']) if isnan(pvalue): @@ -88,9 +105,10 @@ def _summarize_spatial_autocorrelation(unstructured_metrics) -> float | None: def round10(value): return int(pow(10, 10) * value) / pow(10, 10) + def convert_df_to_anndata( df: DataFrame, - phenotypes_to_cluster_on: list[NDArray[Any]] | None = None, + phenotypes_to_cluster_on: list[Series] | None = None, ) -> AnnData: """Convert SPT DataFrame to AnnData object for use with Squidpy metrics. @@ -99,7 +117,7 @@ def convert_df_to_anndata( A dataframe with an arbitrary index, x and y locations of histological structures with column names 'pixel x' and 'pixel y', and several columns with arbitrary names each indicating the expression of a phenotype. - phenotypes_to_cluster_on: list[NDArray[Any]] | None + phenotypes_to_cluster_on: list[Series] | None Used to create a 'cluster' column in the AnnData object if provided. Each list is a mask of positive or negative features that indicate the phenotype. * If only one phenotype is provided, two clusters will be created mirroring the @@ -113,7 +131,7 @@ def convert_df_to_anndata( """ locations: NDArray[Any] = df[['pixel x', 'pixel y']].to_numpy() adata = AnnData( - df.to_numpy(), + df.drop(['pixel x', 'pixel y'], axis=1).to_numpy(), obsm={'spatial': locations}, # type: ignore ) spatial_neighbors(adata) @@ -123,20 +141,24 @@ def convert_df_to_anndata( for phenotype in phenotypes_to_cluster_on[1:]: clustering[(clustering == 0) & phenotype] = i_cluster i_cluster += 1 - adata.obs['cluster'] = clustering + adata.obs['cluster'] = clustering.to_numpy() adata.obs['cluster'] = adata.obs['cluster'].astype('category') + if adata.obs['cluster'].nunique() == 1: + warn('All phenotypes provided had identical values. 
Only one cluster could be made.') return adata -def _nhood_enrichment(adata: AnnData) -> dict[str, list[float] | list[int]]: +def _nhood_enrichment(adata: AnnData) -> dict[str, NDArray[Any]]: """Compute neighborhood enrichment by permutation test.""" result = nhood_enrichment(adata, 'cluster', copy=True, seed=128, show_progress_bar=False) zscore, count = cast(tuple[NDArray[Any], NDArray[Any]], result) - return {'zscore': zscore.tolist(), 'count': count.tolist()} + return {'zscore': zscore, 'count': count} -def _co_occurrence(adata: AnnData, radius: float) -> dict[str, list[float]]: +def _co_occurrence(adata: AnnData, radius: float) -> dict[str, NDArray[Any]] | None: """Compute co-occurrence probability of clusters.""" + if adata.obs['cluster'].nunique() < 2: + return None result = co_occurrence( adata, 'cluster', @@ -145,15 +167,17 @@ def _co_occurrence(adata: AnnData, radius: float) -> dict[str, list[float]]: show_progress_bar=False, ) occ, interval = cast(tuple[NDArray[Any], NDArray[Any]], result) - return {'occ': occ.tolist(), 'interval': interval.tolist()} + return {'occ': occ, 'interval': interval} -def _ripley(adata: AnnData) -> dict[str, list[list[float]] | list[float] | list[int]]: +def _ripley(adata: AnnData) -> dict[str, NDArray[Any]]: r"""Compute various Ripley\'s statistics for point processes.""" result = ripley(adata, 'cluster', copy=True) + bins = cast(NDArray[Any], result['bins']) + pvalues = cast(NDArray[Any], result['pvalues']) return { - 'bins': result['bins'].tolist(), - 'pvalues': result['pvalues'].tolist()[0], + 'bins': bins, + 'pvalues': pvalues[0,], } diff --git a/spatialprofilingtoolbox/workflow/common/structure_centroids.py b/spatialprofilingtoolbox/workflow/common/structure_centroids.py index 37a934459..d9d2adc1d 100644 --- a/spatialprofilingtoolbox/workflow/common/structure_centroids.py +++ b/spatialprofilingtoolbox/workflow/common/structure_centroids.py @@ -1,4 +1,6 @@ """An object for in-memory storage of summarized-location data for all cells of each study.""" + +from typing import Any from pickle import dump from pickle import load from os.path import join @@ -8,33 +10,36 @@ class StructureCentroids: - """An object for in-memory storage of summarized-location data for all cells of - each study. - - Member `studies` is a dictionary with keys the study names. The values are - dictionaries, providing for each specimen name (for specimens collected as - part of the given study) the list of pairs of pixel coordinate values - representing the centroid of the shape specification for a given cell. The - order is ascending lexicographical order of the corresponding "histological + """An object for in-memory storage of summarized-location data for all cells of each study. + + Member `studies` is a dictionary with keys the study names. The values are dictionaries, + providing for each specimen name (for specimens collected as part of the given study) the list + of pairs of pixel coordinate values representing the centroid of the shape specification for a + given cell. The order is ascending lexicographical order of the corresponding "histological structure" identifier strings. 
""" + def __init__(self): - self.studies = {} + self.studies: dict[str, dict[str, Any]] = {} - def get_studies(self): + def get_studies(self) -> dict[str, dict[str, Any]]: return self.studies - def add_study_data(self, study_name, structure_centroids_by_specimen): + def add_study_data( + self, + study_name: str, + structure_centroids_by_specimen: dict[str, Any] + ) -> None: self.studies[study_name] = structure_centroids_by_specimen - def write_to_file(self, data_directory): + def write_to_file(self, data_directory: str) -> None: with open(join(data_directory, CENTROIDS_FILENAME), 'wb') as file: dump(self.get_studies(), file) - def load_from_file(self, data_directory): + def load_from_file(self, data_directory: str) -> None: with open(join(data_directory, CENTROIDS_FILENAME), 'rb') as file: self.studies = load(file) @staticmethod - def already_exists(data_directory): + def already_exists(data_directory: str) -> bool: return isfile(join(data_directory, CENTROIDS_FILENAME)) diff --git a/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py b/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py index 5a42f5af1..78e3e84ff 100644 --- a/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py +++ b/spatialprofilingtoolbox/workflow/phenotype_proximity/core.py @@ -1,7 +1,8 @@ """The core calculator for the proximity calculation on a single source file.""" + import warnings import pickle -from typing import cast + import pandas as pd from pandas import DataFrame @@ -9,7 +10,6 @@ from spatialprofilingtoolbox.workflow.component_interfaces.core import CoreJob from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor -from spatialprofilingtoolbox.db.feature_matrix_extractor import Bundle from spatialprofilingtoolbox.workflow.common.logging.performance_timer import \ PerformanceTimerReporter from spatialprofilingtoolbox import DatabaseConnectionMaker @@ -29,16 +29,15 @@ class PhenotypeProximityCoreJob(CoreJob): """Core/parallelizable functionality for the phenotype proximity workflow.""" radii = [60, 120] - channel_symbols_by_column_name: dict[str, str] tree: BallTree def __init__(self, - study_name: str='', - database_config_file: str='', - performance_report_file: str='', - results_file: str='', + study_name: str = '', + database_config_file: str = '', + performance_report_file: str = '', + results_file: str = '', **kwargs # pylint: disable=unused-argument - ): + ) -> None: self.study_name = study_name self.database_config_file = database_config_file self.results_file = results_file @@ -64,48 +63,55 @@ def log_job_info(self): def calculate_proximity(self): self.reporter.record_timepoint('Start pulling data for one sample.') - extractor = FeatureMatrixExtractor(database_config_file=self.database_config_file) - bundle = cast(Bundle, extractor.extract(specimen=self.sample_identifier)) + bundle = FeatureMatrixExtractor(database_config_file=self.database_config_file).extract( + specimen=self.sample_identifier) self.reporter.record_timepoint('Finished pulling data for one sample.') - study_name = list(bundle.keys())[0] - identifier = list(bundle[study_name]['feature matrices'].keys())[0] - Sample = dict[str, DataFrame | str] - sample = cast(Sample, bundle[study_name]['feature matrices'][identifier]) - cells = cast(DataFrame, sample['dataframe']) + identifier = list(bundle.keys())[0] + cells = bundle[identifier].dataframe logger.info('Dataframe pulled: %s', cells.head()) self.create_ball_tree(cells) - self.channel_symbols_by_column_name = cast( - dict[str, str], 
- bundle[study_name]['channel symbols by column name'], + # Assemble phenotype signatures for every channel and phenotype + channels = sorted( + [col_name[2:] for col_name in cells.columns if col_name.startswith('C ')] + ) + phenotypes = sorted( + [col_name[2:] for col_name in cells.columns if col_name.startswith('P ')] ) - phenotype_identifiers, signatures = self.get_named_phenotype_signatures() + signatures = self.get_named_phenotype_signatures() + assert set(phenotypes) == signatures.keys() logger.info('Named phenotypes:') - logger.info(signatures) - - channels = sorted(self.channel_symbols_by_column_name.keys()) - singleton_signatures = [ - PhenotypeCriteria(positive_markers=[column_name], negative_markers=[]) - for column_name in channels - ] - all_signatures = singleton_signatures + signatures - cases = self.get_cases(all_signatures) - proximity_metrics=[ - compute_proximity_metric_for_signature_pair(s1, s2, r, cells, self.tree) - for s1, s2, r in cases - ] - self.write_table(proximity_metrics, self.get_cases(channels + phenotype_identifiers)) + logger.info(phenotypes) + all_features = channels + phenotypes + signatures.update({ + column_name: PhenotypeCriteria( + positive_markers=[column_name], + negative_markers=[], + ) for column_name in channels + }) + + # Calculate proximity metrics for every phenotype pair and write + proximity_metrics = { + (f1, f2, r): compute_proximity_metric_for_signature_pair( + signatures[f1], + signatures[f2], + r, + cells, + self.tree, + ) for f1, f2, r in self.get_cases(all_features) + } + self.write_table(proximity_metrics) def create_ball_tree(self, cells): self.tree = BallTree(cells[['pixel x', 'pixel y']].to_numpy()) - def get_named_phenotype_signatures(self) -> tuple[list[str], list[PhenotypeCriteria]]: + def get_named_phenotype_signatures(self) -> dict[str, PhenotypeCriteria]: with DatabaseConnectionMaker(self.database_config_file) as dcm: connection = dcm.get_connection() cursor = connection.cursor() cursor.execute(''' - SELECT cp.identifier, cp.symbol, cs.symbol, CASE cpc.polarity WHEN 'positive' THEN 1 WHEN 'negative' THEN 0 END coded_value + SELECT cp.symbol, cs.symbol, CASE cpc.polarity WHEN 'positive' THEN 1 WHEN 'negative' THEN 0 END coded_value FROM cell_phenotype cp JOIN cell_phenotype_criterion cpc ON cpc.cell_phenotype=cp.identifier JOIN chemical_species cs ON cs.identifier=cpc.marker @@ -116,44 +122,34 @@ def get_named_phenotype_signatures(self) -> tuple[list[str], list[PhenotypeCrite ''', (self.study_name,)) rows = cursor.fetchall() cursor.close() - lookup = {value : key for key, value in self.channel_symbols_by_column_name.items()} - criteria = DataFrame(rows, columns=['phenotype', 'name', 'channel', 'polarity']) - criteria = criteria[['phenotype', 'channel', 'polarity']] + criteria = DataFrame(rows, columns=['phenotype', 'channel', 'polarity']) def list_channels(df: DataFrame, polarity: int) -> list[str]: - return [lookup[r['channel']] for _, r in df.iterrows() if r['polarity'] == polarity] + return [r['channel'] for _, r in df.iterrows() if r['polarity'] == polarity] def make_signature(df) -> PhenotypeCriteria: return PhenotypeCriteria( - positive_markers = list_channels(df, 1), - negative_markers = list_channels(df, 0), + positive_markers=list_channels(df, 1), + negative_markers=list_channels(df, 0), ) by_identifier: dict[str, PhenotypeCriteria] = {} - for key, criteria in criteria.groupby(['phenotype']): - by_identifier[str(key[0])] = make_signature(criteria) - identifiers = sorted(by_identifier.keys()) - return identifiers, 
+        for phenotype, criteria in criteria.groupby('phenotype'):
+            by_identifier[phenotype] = make_signature(criteria)
+        return by_identifier

-    def get_cases(self, items):
+    def get_cases(self, items: list[str]) -> list[tuple[str, str, float]]:
         return [
-            (s1, s2, radius)
-            for s1 in items
-            for s2 in items
+            (f1, f2, radius)
+            for f1 in items
+            for f2 in items
             for radius in PhenotypeProximityCoreJob.radii
         ]

-    def write_table(self, proximity_metrics, cases):
-        if len(proximity_metrics) != len(cases):
-            raise ValueError('Number of computed features not equal to number of cases.')
-        rows = list(zip(
-            [case[0] for case in cases],
-            [case[1] for case in cases],
-            [case[2] for case in cases],
-            proximity_metrics,
-        ))
+    def write_table(self, proximity_metrics: dict[tuple[str, str, float], float | None]) -> None:
+        rows = [(f1, f2, r, m) for (f1, f2, r), m in proximity_metrics.items()]
         columns = ['Phenotype 1', 'Phenotype 2', 'Pixel radius', 'Proximity']
         df = DataFrame(rows, columns=columns)
-        bundle = [df, self.channel_symbols_by_column_name, self.sample_identifier]
+        bundle = [df, self.sample_identifier]
         with open(self.results_file, 'wb') as file:
             pickle.dump(bundle, file)
         logger.info('Computed metrics: %s', df.head())
diff --git a/spatialprofilingtoolbox/workflow/phenotype_proximity/integrator.py b/spatialprofilingtoolbox/workflow/phenotype_proximity/integrator.py
index 569f1b492..890a1e569 100644
--- a/spatialprofilingtoolbox/workflow/phenotype_proximity/integrator.py
+++ b/spatialprofilingtoolbox/workflow/phenotype_proximity/integrator.py
@@ -68,9 +68,9 @@ def export_feature_values(self, core_computation_results_files, data_analysis_st
     def send_features_to_uploader(self, feature_uploader, core_computation_results_files):
         for results_file in core_computation_results_files:
             with open(results_file, 'rb') as file:
-                feature_values, channel_symbols_by_column_name, sample_identifier= pickle.load(file)
+                feature_values, sample_identifier= pickle.load(file)
             stage_proximity_feature_values(
                 feature_uploader,
                 feature_values,
-                channel_symbols_by_column_name, sample_identifier,
+                sample_identifier,
             )
diff --git a/test/cggnn/module_tests/test_image_runs_properly.sh b/test/cggnn/module_tests/test_image_runs_properly.sh
index f99a041ad..d20b37f52 100644
--- a/test/cggnn/module_tests/test_image_runs_properly.sh
+++ b/test/cggnn/module_tests/test_image_runs_properly.sh
@@ -1,4 +1,4 @@
-pip freeze | grep dgl
+python3.11 -m pip freeze | grep dgl
 status=$?
 [ $status -eq 0 ] || echo "Docker image for cggnn did not build and run properly."
diff --git a/test/cggnn/unit_tests/test_explore.sh b/test/cggnn/unit_tests/test_explore.sh
new file mode 100644
index 000000000..184c0626d
--- /dev/null
+++ b/test/cggnn/unit_tests/test_explore.sh
@@ -0,0 +1,12 @@
+spt cggnn explore_classes \
+    --spt_db_config_location ../db/.spt_db.config.container \
+    --study "Melanoma intralesional IL2"
+status=$?
+[ $status -eq 0 ] || echo "cggnn explore_classes failed."
+
+if [ $status -eq 0 ];
+then
+    exit 0
+else
+    exit 1
+fi
diff --git a/test/cggnn/unit_tests/test_extract.sh b/test/cggnn/unit_tests/test_extract.sh
new file mode 100644
index 000000000..ced9a3e5e
--- /dev/null
+++ b/test/cggnn/unit_tests/test_extract.sh
@@ -0,0 +1,23 @@
+spt cggnn extract \
+    --spt_db_config_location ../db/.spt_db.config.container \
+    --study "Melanoma intralesional IL2" \
+    --output_location .
+$([ $? -eq 0 ] && [ -e "label_to_results.json" ] && [ -e "cells.h5" ] && [ -e "labels.h5" ])
+status="$?"
+echo "Status: $status" +[ $status -eq 0 ] || echo "cggnn extract failed." + +cat label_to_results.json +python3.11 -c 'import pandas as pd; print(pd.read_hdf("cells.h5"))' +python3.11 -c 'import pandas as pd; print(pd.read_hdf("labels.h5"))' + +rm label_to_results.json +rm cells.h5 +rm labels.h5 + +if [ $status -eq 0 ]; +then + exit 0 +else + exit 1 +fi diff --git a/test/cggnn/unit_tests/test_image_runs_properly.sh b/test/cggnn/unit_tests/test_image_runs_properly.sh index f99a041ad..d20b37f52 100644 --- a/test/cggnn/unit_tests/test_image_runs_properly.sh +++ b/test/cggnn/unit_tests/test_image_runs_properly.sh @@ -1,4 +1,4 @@ -pip freeze | grep dgl +python3.11 -m pip freeze | grep dgl status=$? [ $status -eq 0 ] || echo "Docker image for cggnn did not build and run properly." diff --git a/test/db/module_tests/test_autocomputed_squidpy.py b/test/db/module_tests/test_autocomputed_squidpy.py index 5a5e560b8..964e0b1e4 100644 --- a/test/db/module_tests/test_autocomputed_squidpy.py +++ b/test/db/module_tests/test_autocomputed_squidpy.py @@ -26,7 +26,8 @@ def check_records(feature_values): rows = [(row[0], row[1], round6(row[2])) for row in feature_values] missing = set(get_expected_records()).difference(rows) if len(missing) > 0: - raise ValueError(f'Expected to find records: {missing}') + raise ValueError(f'Expected to find records: {sorted(missing)}\nGot: {sorted(rows)}') + print('All expected records found.') unexpected = set(rows).difference(get_expected_records()) if len(unexpected) > 0: diff --git a/test/workflow/module_tests/check_proximity_metric_values.py b/test/workflow/module_tests/check_proximity_metric_values.py index 816c23eaf..6ced154e0 100644 --- a/test/workflow/module_tests/check_proximity_metric_values.py +++ b/test/workflow/module_tests/check_proximity_metric_values.py @@ -5,39 +5,47 @@ from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker -def lookup_all_cell_phenotype_signatures(database_config_file): +def lookup_all_cell_phenotype_signatures(database_config_file) -> dict[str, set[tuple[str, str]]]: with DatabaseConnectionMaker(database_config_file) as dcm: connection = dcm.get_connection() cursor = connection.cursor() cursor.execute(''' SELECT - cpc.cell_phenotype, cs.symbol, cpc.polarity + cp.symbol AS phenotype, + cs.symbol AS channel, + cpc.polarity FROM cell_phenotype_criterion cpc - JOIN chemical_species cs ON cs.identifier=cpc.marker - ORDER BY cpc.cell_phenotype + JOIN cell_phenotype cp + ON cp.identifier=cpc.cell_phenotype + JOIN chemical_species cs + ON cs.identifier=cpc.marker + ORDER BY phenotype ; ''') rows = cursor.fetchall() cursor.close() - df = pd.DataFrame(rows, columns=['cell phenotype', 'symbol', 'polarity']) - def extract_signature(group): + df = pd.DataFrame(rows, columns=['phenotype', 'channel', 'polarity']) + def extract_signature(group: pd.DataFrame) -> set[tuple[str, str]]: return set( - (row['symbol'], '+' if row['polarity'] == 'positive' else '-') - for i, row in group.iterrows() + (row['channel'], '+' if row['polarity'] == 'positive' else '-') + for _, row in group.iterrows() ) return { - str(cell_phenotype) : extract_signature(group) - for cell_phenotype, group in df.groupby('cell phenotype') + phenotype : extract_signature(group) + for phenotype, group in df.groupby('phenotype') } -def retrieve_cell_phenotype_identifier(description_string, lookup): +def retrieve_cell_phenotype_identifier( + description_string: str, + lookup: dict[str, set[tuple[str, str]]], +) -> str: signature = set((re.sub(r'[\+\-]', '', token), 
     if len(signature) == 1:
         return list(signature)[0][0]
-    for key, value in lookup.items():
+    for phenotype, value in lookup.items():
         if value == signature:
-            return f'cell_phenotype {key}'
+            return phenotype
     raise KeyError(f'Could not figure out {description_string}. '
                    f'Looked for {signature} in values of: {lookup}')
diff --git a/test/workflow/unit_tests/test_feature_matrix_extraction.py b/test/workflow/unit_tests/test_feature_matrix_extraction.py
index adea04f6b..61eb2c3c9 100644
--- a/test/workflow/unit_tests/test_feature_matrix_extraction.py
+++ b/test/workflow/unit_tests/test_feature_matrix_extraction.py
@@ -1,80 +1,73 @@
-import json
 import sys
-import pandas as pd
+from pandas import read_csv, DataFrame

-from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor
+from spatialprofilingtoolbox.db.feature_matrix_extractor import (
+    FeatureMatrixExtractor,
+    MatrixBundle,
+)

-def get_study(bundle):
-    study_name = 'Melanoma intralesional IL2'
-    if not study_name in bundle.keys():
-        print(f'Missing study: {study_name}')
+def test_sample_set(study: dict[str, MatrixBundle]):
+    if study.keys() != set(['lesion 0_1', 'lesion 6_1']):
+        print(f'Wrong sample set: {list(study.keys())}')
         sys.exit(1)
-    return bundle[study_name]

-def test_sample_set(study):
-    if study['feature matrices'].keys() != set(['lesion 0_1', 'lesion 6_1']):
-        print(f'Wrong sample set: {list(study["feature matrices"].keys())}')
+def test_one_sample_set(study: dict[str, MatrixBundle]):
+    if study.keys() != set(['lesion 6_1']):
+        print(f'Wrong sample set: {list(study.keys())}')
         sys.exit(1)

-def test_one_sample_set(study):
-    if study['feature matrices'].keys() != set(['lesion 6_1']):
-        print(f'Wrong sample set: {list(study["feature matrices"].keys())}')
-        sys.exit(1)
-
-
-def test_feature_matrix_schemas(study):
-    for specimen, sample in study['feature matrices'].items():
-        df = sample['dataframe']
-        if not all(f'F{i}' in df.columns for i in range(26)):
-            print(f'Missing some columns in dataframe (case "{specimen}"): ')
-            print(df.to_string(index=False))
-            sys.exit(1)
-        if df.shape != (100, 28):
-            print(f'Wrong number of rows or columns: {df.shape}')
+def test_feature_matrix_schemas(study: dict[str, MatrixBundle]):
+    for specimen, sample in study.items():
+        df = sample.dataframe
+        if df.shape != (100, 32):
+            print(f'Wrong number of rows or columns: {df.shape} != (100, 32)')
             sys.exit(1)

-def show_example_feature_matrix(study):
+def show_example_feature_matrix(study: dict[str, MatrixBundle]):
     specimen = 'lesion 0_1'
-    df = study['feature matrices'][specimen]['dataframe']
+    df = study[specimen].dataframe
     print(f'Example feature matrix, for specimen {specimen}:')
     print(df.to_string(index=False))
     print('')

-def test_channels(study):
-    channels = study['channel symbols by column name']
-    known = ['B2M', 'B7H3', 'CD14', 'CD163', 'CD20', 'CD25', 'CD27', 'CD3', 'CD4', 'CD56', 'CD68',
+def test_channels(study: dict[str, MatrixBundle]):
+    columns = list(study.values())[0].dataframe.columns
+    channels = set(name[2:] for name in columns[columns.str.startswith('C ')])
+    known = {'B2M', 'B7H3', 'CD14', 'CD163', 'CD20', 'CD25', 'CD27', 'CD3', 'CD4', 'CD56', 'CD68',
             'CD8', 'DAPI', 'FOXP3', 'IDO1', 'KI67', 'LAG3', 'MHCI', 'MHCII', 'MRC1', 'PD1',
-            'PDL1', 'S100B', 'SOX10', 'TGM2', 'TIM3']
-    if set(channels.values()) != set(known):
-        print(f'Wrong channel set: {list(channels.values())}')
+            'PDL1', 'S100B', 'SOX10', 'TGM2', 'TIM3'}
+    if channels != known:
+        print(f'Wrong channel set: {sorted(channels)}')
         sys.exit(1)

-def test_expression_vectors(study):
-    def create_column_name(channels, channel_num):
-        return channels[channel_num] + '_Positive'
+def test_expression_vectors(study: dict[str, MatrixBundle]):
+    for specimen in study.keys():
+        df = study[specimen].dataframe
+
+        print('Dataframe: ' + str(specimen))
+        print(df)

-    for specimen in study['feature matrices'].keys():
-        df = study['feature matrices'][specimen]['dataframe']
         expression_vectors = sorted([
-            tuple(row[f'F{i}'] for i in range(26))
+            tuple(row[row.index.str.startswith('C ')].tolist())
             for _, row in df.iterrows()
         ])
         filenames = {'lesion 0_1': '0.csv', 'lesion 6_1': '3.csv'}
         cells_filename = filenames[specimen]
-        reference = pd.read_csv(
+        reference = read_csv(
             f'../test_data/adi_preprocessed_tables/dataset1/{cells_filename}',
             sep=',')
-        channels = study['channel symbols by column name']
+        columns = list(study.values())[0].dataframe.columns
+        channels = [name[2:] for name in columns[columns.str.startswith('C ')]]
         expected_expression_vectors = sorted([
-            tuple(row[create_column_name(channels, f'F{i}')] for i in range(26))
+            tuple(row[f'{channel}_Positive'] for channel in channels)
             for _, row in reference.iterrows()
         ])
@@ -89,26 +82,24 @@ def create_column_name(channels, channel_num):
         print('Expression vector sets are as expected.')

-def test_expression_vectors_continuous(study):
-    def create_column_name(channels, channel_num):
-        return channels[channel_num] + '_Intensity'
-
-    for specimen in study['feature matrices'].keys():
-        df = study['feature matrices'][specimen]['continuous dataframe']
+def test_expression_vectors_continuous(study: dict[str, MatrixBundle]):
+    for specimen in study.keys():
+        df = study[specimen].continuous_dataframe
         print(df.head())
         expression_vectors = sorted([
-            tuple(row[f'F{i}'] for i in range(26))
+            tuple(row[row.index.str.startswith('C ')].tolist())
             for _, row in df.iterrows()
         ])
         filenames = {'lesion 0_1': '0.csv', 'lesion 6_1': '3.csv'}
         cells_filename = filenames[specimen]
-        reference = pd.read_csv(
+        reference = read_csv(
             f'../test_data/adi_preprocessed_tables/dataset1/{cells_filename}',
             sep=',')
-        channels = study['channel symbols by column name']
+        columns = list(study.values())[0].dataframe.columns
+        channels = [name[2:] for name in columns[columns.str.startswith('C ')]]
         expected_expression_vectors = sorted([
-            tuple(row[create_column_name(channels, f'F{i}')] for i in range(26))
+            tuple(row[f'{channel}_Intensity'] for channel in channels)
             for _, row in reference.iterrows()
         ])
@@ -123,9 +114,9 @@ def create_column_name(channels, channel_num):
         print('Expression vector sets are as expected.')

-def test_stratification(study):
-    df = study['sample cohorts']['assignments']
-    strata = study['sample cohorts']['strata']
+def test_stratification(study: dict[str, DataFrame]):
+    df = study['assignments']
+    strata = study['strata']
     print('Sample cohorts:')
     print(df.to_string(index=False))
     print(strata.to_string(index=False))
@@ -138,34 +129,23 @@ def test_stratification(study):
 if __name__ == '__main__':
     extractor = FeatureMatrixExtractor(database_config_file='../db/.spt_db.config.container')
-    matrix_bundle = extractor.extract()
-    test_study = get_study(matrix_bundle)
+    study_name = 'Melanoma intralesional IL2'
+    test_study = extractor.extract(study=study_name)
     test_sample_set(test_study)
     test_feature_matrix_schemas(test_study)
     show_example_feature_matrix(test_study)
     test_channels(test_study)
     test_expression_vectors(test_study)
-    test_stratification(test_study)
+    test_stratification(extractor.extract_cohorts(study_name))

-    one_sample_bundle = extractor.extract(specimen='lesion 6_1')
-    one_sample_study = get_study(one_sample_bundle)
+    one_sample_study = extractor.extract(specimen='lesion 6_1')
     test_one_sample_set(one_sample_study)
     test_feature_matrix_schemas(one_sample_study)
     test_channels(one_sample_study)
     test_expression_vectors(one_sample_study)

-    one_sample_bundle_continuous = extractor.extract(specimen='lesion 6_1', continuous_also=True)
-    one_sample_study_continuous = get_study(one_sample_bundle_continuous)
+    one_sample_study_continuous = extractor.extract(specimen='lesion 6_1', continuous_also=True)
     test_one_sample_set(one_sample_study_continuous)
     test_feature_matrix_schemas(one_sample_study_continuous)
     test_channels(one_sample_study_continuous)
     test_expression_vectors_continuous(one_sample_study_continuous)
-
-    FeatureMatrixExtractor.redact_dataframes(matrix_bundle)
-    print('\nMetadata "bundle" with dataframes removed:')
-    print(json.dumps(matrix_bundle, indent=2))
-
-    print('\n... and in the one-sample case:')
-    FeatureMatrixExtractor.redact_dataframes(one_sample_bundle)
-    print('\nMetadata "bundle" with dataframes removed:')
-    print(json.dumps(one_sample_bundle, indent=2))
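
Note (not part of the patch): a minimal usage sketch of the reworked FeatureMatrixExtractor API, as exercised by the updated test above. It assumes the same study name and database config path as the test fixtures; MatrixBundle attribute names follow the diff.

    # Sketch only: mirrors usage in test/workflow/unit_tests/test_feature_matrix_extraction.py.
    from spatialprofilingtoolbox.db.feature_matrix_extractor import FeatureMatrixExtractor

    extractor = FeatureMatrixExtractor(database_config_file='../db/.spt_db.config.container')

    # extract() now returns a dict keyed by specimen; each MatrixBundle exposes
    # .dataframe (and .continuous_dataframe) with 'C <channel>' and 'P <phenotype>' columns.
    study = extractor.extract(study='Melanoma intralesional IL2')
    for specimen, bundle in study.items():
        print(specimen, bundle.dataframe.shape)

    # Cohort assignments and strata are fetched separately from the feature matrices.
    cohorts = extractor.extract_cohorts('Melanoma intralesional IL2')
    print(cohorts['assignments'].head())
    print(cohorts['strata'].head())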