diff --git a/pyproject.toml.unversioned b/pyproject.toml.unversioned index e2a10e67d..e6c058400 100644 --- a/pyproject.toml.unversioned +++ b/pyproject.toml.unversioned @@ -116,6 +116,7 @@ packages = [ "spatialprofilingtoolbox.ondemand.providers", "spatialprofilingtoolbox.ondemand.scripts", "spatialprofilingtoolbox.db", + "spatialprofilingtoolbox.db.accessors", "spatialprofilingtoolbox.db.exchange_data_formats", "spatialprofilingtoolbox.db.scripts", "spatialprofilingtoolbox.db.data_model", diff --git a/spatialprofilingtoolbox/apiserver/app/main.py b/spatialprofilingtoolbox/apiserver/app/main.py index 0c2a5f967..29ae61c41 100644 --- a/spatialprofilingtoolbox/apiserver/app/main.py +++ b/spatialprofilingtoolbox/apiserver/app/main.py @@ -21,6 +21,7 @@ PhenotypeCriteria, PhenotypeCounts, UnivariateMetricsComputationResult, + CGGNNImportanceRank, ) from spatialprofilingtoolbox.db.exchange_data_formats.metrics import UMAPChannel from spatialprofilingtoolbox.db.querying import query @@ -222,6 +223,14 @@ async def request_spatial_metrics_computation_custom_phenotypes( # pylint: disa return get_squidpy_metrics(study, markers, feature_class, radius=radius) +@app.get("/request-cggnn-metrics/") +async def request_cggnn_metrics( + study: ValidStudy, +) -> list[CGGNNImportanceRank]: + """Importance scores as calculated by cggnn.""" + return query().get_cggnn_metrics(study) + + def get_proximity_metrics( study: str, markers: list[list[str]], diff --git a/spatialprofilingtoolbox/cggnn/scripts/run_sql.py b/spatialprofilingtoolbox/cggnn/scripts/run_sql.py deleted file mode 100644 index 8244df941..000000000 --- a/spatialprofilingtoolbox/cggnn/scripts/run_sql.py +++ /dev/null @@ -1,147 +0,0 @@ -"Run through the entire SPT CG-GNN pipeline, starting from a SPT SQL database." - -from argparse import ArgumentParser -from spatialprofilingtoolbox.standalone_utilities.module_load_error import SuggestExtrasException -try: - from cggnn.run_all import run_pipeline -except ModuleNotFoundError as e: - SuggestExtrasException(e, 'cggnn') - - -def parse_arguments(): - """Process command line arguments.""" - parser = ArgumentParser( - prog='spt cggnn run_sql', - description='Create cell graphs from SPT SQL tables, train a graph neural network on ' - 'them, and save resultant model, metrics, and visualizations (if requested) to file.' - ) - parser.add_argument( - '--study', - type=str, - help='Name of the study to query data for in SPT.', - required=True - ) - parser.add_argument( - '--host', - type=str, - help='Host SQL server IP.', - required=True - ) - parser.add_argument( - '--dbname', - type=str, - help='Database in SQL server to query.', - required=True - ) - parser.add_argument( - '--user', - type=str, - help='Server login username.', - required=True - ) - parser.add_argument( - '--password', - type=str, - help='Server login password.', - required=True - ) - parser.add_argument( - '--validation_data_percent', - type=int, - help='Percentage of data to use as validation data. Set to 0 if you want to do k-fold ' - 'cross-validation later. (Training percentage is implicit.) Default 15%.', - default=15, - required=False - ) - parser.add_argument( - '--test_data_percent', - type=int, - help='Percentage of data to use as the test set. (Training percentage is implicit.) ' - 'Default 15%.', - default=15, - required=False - ) - parser.add_argument( - '--roi_side_length', - type=int, - help='Side length in pixels of the ROI areas we wish to generate.', - default=600, - required=False - ) - parser.add_argument( - '--target_column', - type=str, - help='Phenotype column to use to build ROIs around.', - default=None, - required=False - ) - parser.add_argument( - '-b', - '--batch_size', - type=int, - help='batch size.', - default=1, - required=False - ) - parser.add_argument( - '--epochs', - type=int, - help='epochs.', - default=10, - required=False - ) - parser.add_argument( - '-l', - '--learning_rate', - type=float, - help='learning rate.', - default=10e-3, - required=False - ) - parser.add_argument( - '-k', - '--k_folds', - type=int, - help='Folds to use in k-fold cross validation. 0 means don\'t use k-fold cross validation ' - 'unless no validation dataset is provided, in which case k defaults to 3.', - required=False, - default=0 - ) - parser.add_argument( - '--explainer', - type=str, - help='Which explainer type to use.', - default='pp', - required=False - ) - parser.add_argument( - '--merge_rois', - help='Merge ROIs together by specimen.', - action='store_true' - ) - parser.add_argument( - '--prune_misclassified', - help='Remove entries for misclassified cell graphs when calculating separability scores.', - action='store_true' - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_arguments() - run_pipeline(args.study, - args.host, - args.dbname, - args.user, - args.password, - args.validation_data_percent, - args.test_data_percent, - args.roi_side_length, - args.target_column, - args.batch_size, - args.epochs, - args.learning_rate, - args.k_folds, - args.explainer, - args.merge_rois, - args.prune_misclassified) diff --git a/spatialprofilingtoolbox/cggnn/scripts/upload.py b/spatialprofilingtoolbox/cggnn/scripts/upload.py new file mode 100644 index 000000000..cfd79aed7 --- /dev/null +++ b/spatialprofilingtoolbox/cggnn/scripts/upload.py @@ -0,0 +1,44 @@ +"""Upload importance score output from a cg-gnn instance to the local db.""" + +from argparse import ArgumentParser + +from pandas import read_csv + +from spatialprofilingtoolbox.db.database_connection import DatabaseConnectionMaker +from spatialprofilingtoolbox.db.importance_score_transcriber import transcribe_importance + + +def parse_arguments(): + """Process command line arguments.""" + parser = ArgumentParser( + prog='spt cggnn upload-importances', + description='Save cell importance scores as defined by cggnn to the database.' + ) + parser.add_argument( + '--spt_db_config_location', + type=str, + help='File location for SPT DB config file.', + required=True + ) + parser.add_argument( + '--importances_csv_path', + type=str, + help='File location for the importances CSV.', + required=True + ) + parser.add_argument( + '--cohort_stratifier', + type=str, + help='Name of the classification cohort variable the GNN was trained on.', + default='', + required=False + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + df = read_csv(args.importances_csv_path, index_col=0) + connection = DatabaseConnectionMaker(args.spt_db_config_location).get_connection() + transcribe_importance(df, connection, cohort_stratifier=args.cohort_stratifier) + connection.close() diff --git a/spatialprofilingtoolbox/db/accessors/__init__.py b/spatialprofilingtoolbox/db/accessors/__init__.py new file mode 100644 index 000000000..49466aaa0 --- /dev/null +++ b/spatialprofilingtoolbox/db/accessors/__init__.py @@ -0,0 +1,7 @@ +"""Convenience classes for accessing various data in the database.""" + +from spatialprofilingtoolbox.db.accessors.cggnn import CGGNNAccess +from spatialprofilingtoolbox.db.accessors.fractions_and_associations import FractionsAccess +from spatialprofilingtoolbox.db.accessors.phenotypes import PhenotypesAccess +from spatialprofilingtoolbox.db.accessors.study import StudyAccess +from spatialprofilingtoolbox.db.accessors.umap import UMAPAccess diff --git a/spatialprofilingtoolbox/db/accessors/cggnn.py b/spatialprofilingtoolbox/db/accessors/cggnn.py new file mode 100644 index 000000000..aaf3e3e08 --- /dev/null +++ b/spatialprofilingtoolbox/db/accessors/cggnn.py @@ -0,0 +1,36 @@ +"""Convenience access of cg-gnn metrics.""" + +from spatialprofilingtoolbox import get_feature_description +from spatialprofilingtoolbox.db.accessors.study import StudyAccess +from spatialprofilingtoolbox.db.database_connection import SimpleReadOnlyProvider +from spatialprofilingtoolbox.db.exchange_data_formats.metrics import CGGNNImportanceRank + + +class CGGNNAccess(SimpleReadOnlyProvider): + """Access to cg-gnn features from database.""" + + def get_metrics(self, study: str) -> list[CGGNNImportanceRank]: + """Get cg-gnn metrics for this study. + + Returns + ------- + list[CGGNNImportanceRank] + List of (histological structure ID, importance rank) tuples. + """ + components = StudyAccess(self.cursor).get_study_components(study) + self.cursor.execute(f''' + SELECT + qfv.subject, + qfv.value + FROM quantitative_feature_value qfv + JOIN feature_specification fs + ON fs.identifier=qfv.feature + WHERE fs.derivation_method='{get_feature_description("gnn importance score")}' + AND fs.study='{components.analysis}' + ; + ''') + rows = self.cursor.fetchall() + return [CGGNNImportanceRank( + histological_structure_id=int(row[0]), + rank=int(row[1]) + ) for row in rows] diff --git a/spatialprofilingtoolbox/db/fractions_and_associations.py b/spatialprofilingtoolbox/db/accessors/fractions_and_associations.py similarity index 98% rename from spatialprofilingtoolbox/db/fractions_and_associations.py rename to spatialprofilingtoolbox/db/accessors/fractions_and_associations.py index e707353ad..65c937225 100644 --- a/spatialprofilingtoolbox/db/fractions_and_associations.py +++ b/spatialprofilingtoolbox/db/accessors/fractions_and_associations.py @@ -2,7 +2,7 @@ from spatialprofilingtoolbox.db.exchange_data_formats.metrics import CellFractionsSummary from spatialprofilingtoolbox.db.exchange_data_formats.metrics import CellFractionsAverage from spatialprofilingtoolbox.db.exchange_data_formats.metrics import FeatureAssociationTest -from spatialprofilingtoolbox.db.study_access import StudyAccess +from spatialprofilingtoolbox.db.accessors.study import StudyAccess from spatialprofilingtoolbox.db.cohorts import _replace_stratum_identifiers from spatialprofilingtoolbox import get_feature_description from spatialprofilingtoolbox.db.database_connection import SimpleReadOnlyProvider diff --git a/spatialprofilingtoolbox/db/phenotypes.py b/spatialprofilingtoolbox/db/accessors/phenotypes.py similarity index 98% rename from spatialprofilingtoolbox/db/phenotypes.py rename to spatialprofilingtoolbox/db/accessors/phenotypes.py index 6841f5eb1..7aa491ee5 100644 --- a/spatialprofilingtoolbox/db/phenotypes.py +++ b/spatialprofilingtoolbox/db/accessors/phenotypes.py @@ -1,7 +1,7 @@ """Convenience accessors/manipulators for phenotype data.""" from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeSymbol from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeCriteria -from spatialprofilingtoolbox.db.study_access import StudyAccess +from spatialprofilingtoolbox.db.accessors.study import StudyAccess from spatialprofilingtoolbox.db.database_connection import SimpleReadOnlyProvider diff --git a/spatialprofilingtoolbox/db/study_access.py b/spatialprofilingtoolbox/db/accessors/study.py similarity index 100% rename from spatialprofilingtoolbox/db/study_access.py rename to spatialprofilingtoolbox/db/accessors/study.py diff --git a/spatialprofilingtoolbox/db/umap.py b/spatialprofilingtoolbox/db/accessors/umap.py similarity index 100% rename from spatialprofilingtoolbox/db/umap.py rename to spatialprofilingtoolbox/db/accessors/umap.py diff --git a/spatialprofilingtoolbox/db/database_connection.py b/spatialprofilingtoolbox/db/database_connection.py index 16be0c6ba..dc8bc4d5f 100644 --- a/spatialprofilingtoolbox/db/database_connection.py +++ b/spatialprofilingtoolbox/db/database_connection.py @@ -166,6 +166,7 @@ class (QueryCursor) newly provides on each invocation. retrieve_signature_of_phenotype: Callable get_umaps_low_resolution: Callable get_umap: Callable + get_cggnn_metrics: Callable def __init__(self, query_handler: Type): self.query_handler = query_handler diff --git a/spatialprofilingtoolbox/db/exchange_data_formats/metrics.py b/spatialprofilingtoolbox/db/exchange_data_formats/metrics.py index c1c5173d5..ea50aab5f 100644 --- a/spatialprofilingtoolbox/db/exchange_data_formats/metrics.py +++ b/spatialprofilingtoolbox/db/exchange_data_formats/metrics.py @@ -85,3 +85,9 @@ class UMAPChannel(BaseModel): """ channel: str base64_png: str + + +class CGGNNImportanceRank(BaseModel): + """The importance ranking of histological structures in a study.""" + histological_structure_id: int + rank: int diff --git a/spatialprofilingtoolbox/db/feature_matrix_extractor.py b/spatialprofilingtoolbox/db/feature_matrix_extractor.py index 6e25775d5..9c73b7d5c 100644 --- a/spatialprofilingtoolbox/db/feature_matrix_extractor.py +++ b/spatialprofilingtoolbox/db/feature_matrix_extractor.py @@ -10,12 +10,14 @@ from spatialprofilingtoolbox import DatabaseConnectionMaker from spatialprofilingtoolbox.db.exchange_data_formats.metrics import PhenotypeCriteria -from spatialprofilingtoolbox.db.phenotypes import PhenotypesAccess +from spatialprofilingtoolbox.db.accessors import ( + StudyAccess, + PhenotypesAccess, +) from spatialprofilingtoolbox.db.stratification_puller import ( StratificationPuller, Stratification, ) -from spatialprofilingtoolbox.db.study_access import StudyAccess from spatialprofilingtoolbox.workflow.common.structure_centroids_puller import \ StructureCentroidsPuller from spatialprofilingtoolbox.workflow.common.sparse_matrix_puller import SparseMatrixPuller diff --git a/spatialprofilingtoolbox/db/importance_score_transcriber.py b/spatialprofilingtoolbox/db/importance_score_transcriber.py index 4b7017916..9cedf1af6 100644 --- a/spatialprofilingtoolbox/db/importance_score_transcriber.py +++ b/spatialprofilingtoolbox/db/importance_score_transcriber.py @@ -17,20 +17,20 @@ def transcribe_importance( df: DataFrame, connection: Connection, per_specimen_selection_number: int = 1000, - cohort_stratifier: str='default sample stratification', + cohort_stratifier: str = '', ) -> None: r"""Upload importance score output from a cg-gnn instance to the local db. Parameters: df: DataFrame One column, `importance_score`, indexed by `histological_structure`. - cohort_stratifier: str - Name of the classification cohort variable the GNN was trained on to produce - the importance score. connection: psycopg2.extensions.connection per_specimen_selection_number: int Grab this many of the most important cells from each specimen (or fewer if there aren't enough cells in the specimen). + cohort_stratifier: str = '' + Name of the classification cohort variable the GNN was trained on to produce + the importance score. """ study = _get_referenced_study(connection, df) indicator: str = 'cell importance' diff --git a/spatialprofilingtoolbox/db/querying.py b/spatialprofilingtoolbox/db/querying.py index 206e191d2..ebb1a8486 100644 --- a/spatialprofilingtoolbox/db/querying.py +++ b/spatialprofilingtoolbox/db/querying.py @@ -13,12 +13,17 @@ Channel, PhenotypeCriteria, UMAPChannel, + CGGNNImportanceRank, ) from spatialprofilingtoolbox.db.cohorts import get_cohort_identifiers -from spatialprofilingtoolbox.db.study_access import StudyAccess -from spatialprofilingtoolbox.db.fractions_and_associations import FractionsAccess -from spatialprofilingtoolbox.db.phenotypes import PhenotypesAccess -from spatialprofilingtoolbox.db.umap import UMAPAccess +from spatialprofilingtoolbox.db.accessors import ( + CGGNNAccess, + StudyAccess, + FractionsAccess, + PhenotypesAccess, + UMAPAccess, +) + class QueryHandler: """Handle simple queries to the database.""" @@ -108,6 +113,10 @@ def get_umaps_low_resolution(cls, cursor, study: str) -> list[UMAPChannel]: def get_umap(cls, cursor, study: str, channel: str) -> UMAPChannel: return UMAPAccess(cursor).get_umap_row_for_channel(study, channel) + @classmethod + def get_cggnn_metrics(cls, cursor, study: str) -> list[CGGNNImportanceRank]: + return CGGNNAccess(cursor).get_metrics(study) + def query() -> QueryCursor: return QueryCursor(QueryHandler) diff --git a/test/apiserver/unit_tests/test_lookup_study_from_specimen.py b/test/apiserver/unit_tests/test_lookup_study_from_specimen.py index 946b595e3..881d7e0e0 100644 --- a/test/apiserver/unit_tests/test_lookup_study_from_specimen.py +++ b/test/apiserver/unit_tests/test_lookup_study_from_specimen.py @@ -3,7 +3,7 @@ import os from spatialprofilingtoolbox.db.database_connection import DBCursor -from spatialprofilingtoolbox.db.study_access import StudyAccess +from spatialprofilingtoolbox.db.accessors import StudyAccess def test_lookup(): environment = {