From 2bb856ce68c72b71c9c01b294f9103d46524a11d Mon Sep 17 00:00:00 2001 From: shitohana <43905117+shitohana@users.noreply.github.com> Date: Fri, 27 Oct 2023 16:48:00 +0300 Subject: [PATCH] Clusters added into master (#6) * flank regions fix * new labels * draw to cluster * __init__.py Clustering added BismarkPlot.py Clustering done and confidence bands console_metagene.py confidence bands added index.rst new classes and methods added in documentation --- docs/index.rst | 1 + src/bismarkplot/BismarkPlot.py | 574 +++++++++++++++++++--------- src/bismarkplot/__init__.py | 2 +- src/bismarkplot/console_metagene.py | 9 +- 4 files changed, 410 insertions(+), 176 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 25945dc..190fe1b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,6 +12,7 @@ This page gives an overview of all public objects, functions and methods. bismarkplot.MetageneFiles bismarkplot.Genome bismarkplot.ChrLevels + bismarkplot.Clustering bismarkplot.BismarkPlot.BismarkBase bismarkplot.BismarkPlot.LinePlot diff --git a/src/bismarkplot/BismarkPlot.py b/src/bismarkplot/BismarkPlot.py index c542015..856165b 100644 --- a/src/bismarkplot/BismarkPlot.py +++ b/src/bismarkplot/BismarkPlot.py @@ -1,5 +1,6 @@ import gzip import re +from functools import cache from multiprocessing import cpu_count from os.path import getsize @@ -14,17 +15,20 @@ from scipy.signal import savgol_filter from scipy.spatial.distance import pdist -import scipy.cluster.hierarchy as hclust +from scipy.cluster.hierarchy import linkage, leaves_list +from scipy import stats from pandas import DataFrame as pdDataFrame from pyreadr import write_rds +from dynamicTreeCut import cutreeHybrid + def remove_extension(path): re.sub("\.[^./]+$", "", path) -def approx_batch_num(path, batch_size, check_lines=100000): +def approx_batch_num(path, batch_size, check_lines=10000): size = getsize(path) length = 0 @@ -72,6 +76,7 @@ def from_gff(cls, file: str): dtypes={'start': pl.Int32, 'end': pl.Int32, 'chr': pl.Utf8} ).select(['chr', 'type', 'start', 'end', 'strand']) + print(f"Genome read from {file}") return cls(genes) def gene_body(self, min_length: int = 4000, flank_length: int = 2000) -> pl.DataFrame: @@ -121,6 +126,18 @@ def near_TSS(self, min_length: int = 4000, flank_length: int = 2000): :param flank_length: length of the flanking region. :return: :class:`pl.LazyFrame` with genes and their flanking regions. 
""" + upstream_length = ( + # when before length is enough + # we set upstream length to specified + pl.when(pl.col('upstream') >= flank_length).then(flank_length) + # when genes are intersecting (current start < previous end) + # we don't take this as upstream region + .when(pl.col('upstream') < 0).then(0) + # when length between genes is not enough for full specified length + # we divide it into half + .otherwise((pl.col('upstream') - (pl.col('upstream') % 2)) // 2) + ) + gene_type = "gene" genes = self.__filter_genes( self.genome, gene_type, min_length, flank_length) @@ -135,13 +152,7 @@ def near_TSS(self, min_length: int = 4000, flank_length: int = 2000): ]) .explode(['start', 'upstream']) .with_columns([ - (pl.col('start') - pl.when( - pl.col('upstream') >= flank_length - ) - .then(flank_length) - .otherwise( - (pl.col('upstream') - pl.col('upstream') % 2) // 2 - )).alias('upstream'), + (pl.col('start') - upstream_length).alias('upstream'), (pl.col("start") + flank_length).alias("end") ]) .with_columns(pl.col("end").alias("downstream")) @@ -157,6 +168,19 @@ def near_TES(self, min_length: int = 4000, flank_length: int = 2000): :param flank_length: length of the flanking region. :return: :class:`pl.LazyFrame` with genes and their flanking regions. """ + + downstream_length = ( + # when before length is enough + # we set upstream length to specified + pl.when(pl.col('downstream') >= flank_length).then(flank_length) + # when genes are intersecting (current start < previous end) + # we don't take this as upstream region + .when(pl.col('downstream') < 0).then(0) + # when length between genes is not enough for full specified length + # we divide it into half + .otherwise((pl.col('downstream') - pl.col('downstream') % 2) // 2) + ) + gene_type = "gene" genes = self.__filter_genes( self.genome, gene_type, min_length, flank_length) @@ -171,13 +195,7 @@ def near_TES(self, min_length: int = 4000, flank_length: int = 2000): ]) .explode(['end', 'downstream']) .with_columns([ - (pl.col('end') + pl.when( - pl.col('downstream') >= flank_length - ) - .then(flank_length) - .otherwise( - (pl.col('downstream') - pl.col('downstream') % 2) // 2 - )).alias('downstream'), + (pl.col('end') + downstream_length).alias('downstream'), (pl.col("end") - flank_length).alias("start") ]) .with_columns(pl.col("start").alias("upstream")) @@ -203,45 +221,62 @@ def other(self, gene_type: str, min_length: int = 1000, flank_length: int = 100) def __filter_genes(genes, gene_type, min_length, flank_length): genes = genes.filter(pl.col('type') == gene_type).drop('type') + # filter genes, which start < flank_length if flank_length > 0: genes = genes.filter(pl.col('start') > flank_length) + # filter genes which don't pass length threshold if min_length > 0: - genes = genes.filter(pl.col('end') - pl.col('start') > min_length) + genes = genes.filter((pl.col('end') - pl.col('start')) > min_length) return genes @staticmethod def __trim_genes(genes, flank_length) -> pl.LazyFrame: + # upstream shift + # calculates length to previous gene on same chr_strand + length_before = (pl.col('start').shift(-1) - pl.col('end')).shift(1).fill_null(flank_length) + # downstream shift + # calculates length to next gene on same chr_strand + length_after = (pl.col('start').shift(-1) - pl.col('end')).fill_null(flank_length) + + upstream_length = ( + # when before length is enough + # we set upstream length to specified + pl.when(pl.col('upstream') >= flank_length).then(flank_length) + # when genes are intersecting (current start < previous end) + # we 
don't take this as upstream region + .when(pl.col('upstream') < 0).then(0) + # when length between genes is not enough for full specified length + # we divide it into half + .otherwise((pl.col('upstream') - (pl.col('upstream') % 2)) // 2) + ) + + downstream_length = ( + # when before length is enough + # we set upstream length to specified + pl.when(pl.col('downstream') >= flank_length).then(flank_length) + # when genes are intersecting (current start < previous end) + # we don't take this as upstream region + .when(pl.col('downstream') < 0).then(0) + # when length between genes is not enough for full specified length + # we divide it into half + .otherwise((pl.col('downstream') - pl.col('downstream') % 2) // 2) + ) + return ( genes .groupby(['chr', 'strand'], maintain_order=True).agg([ - pl.col('start'), pl.col('end'), - # upstream shift - (pl.col('start').shift(-1) - pl.col('end')).shift(1) - .fill_null(flank_length) - .alias('upstream'), - # downstream shift - (pl.col('start').shift(-1) - pl.col('end')) - .fill_null(flank_length) - .alias('downstream') + pl.col('start'), + pl.col('end'), + length_before.alias('upstream'), + length_after.alias('downstream') ]) .explode(['start', 'end', 'upstream', 'downstream']) .with_columns([ - (pl.col('start') - pl.when( - pl.col('upstream') >= flank_length - ) - .then(flank_length) - .otherwise( - (pl.col('upstream') - pl.col('upstream') % 2) // 2 - )).alias('upstream'), - - (pl.col('end') + pl.when( - pl.col('downstream') >= flank_length - ) - .then(flank_length) - .otherwise( - (pl.col('downstream') - pl.col('downstream') % 2) // 2 - )).alias('downstream') + # calculates length of region + (pl.col('start') - upstream_length).alias('upstream'), + # calculates length of region + (pl.col('end') + downstream_length).alias('downstream') ]) ) @@ -256,7 +291,7 @@ def __check_empty(genes): class BismarkBase: """ - Base class for :class:`Bismark` and plots. + Base class for :class:`Metagene` and plots. 
""" def __init__(self, bismark_df: pl.DataFrame, **kwargs): @@ -266,9 +301,9 @@ def __init__(self, bismark_df: pl.DataFrame, **kwargs): DataFrame Structure: +-----------------+-------------+---------------------+----------------------+------------------+----------------+-----------------------------------------+ - | chr | strand | context | start | fragment | sum | count | + | chr | strand | context | gene | fragment | sum | count | +=================+=============+=====================+======================+==================+================+=========================================+ - | Categorical | Categorical | Categorical | Int32 | Int32 | Int32 | Int32 | + | Categorical | Categorical | Categorical | Categorical | Int32 | Int32 | Int32 | +-----------------+-------------+---------------------+----------------------+------------------+----------------+-----------------------------------------+ | chromosome name | strand | methylation context | position of cytosine | fragment in gene | sum methylated | count of all cytosines in this position | +-----------------+-------------+---------------------+----------------------+------------------+----------------+-----------------------------------------+ @@ -284,12 +319,12 @@ def __init__(self, bismark_df: pl.DataFrame, **kwargs): """ self.bismark: pl.DataFrame = bismark_df - self.upstream_windows: int | None = kwargs.get("upstream_windows") - self.downstream_windows: int | None = kwargs.get("downstream_windows") - self.gene_windows: int | None = kwargs.get("gene_windows") - self.plot_data: pl.DataFrame | None = kwargs.get("plot_data") - self.context: str | None = kwargs.get("context") - self.strand: str | None = kwargs.get("strand") + self.upstream_windows: int = kwargs.get("upstream_windows") + self.downstream_windows: int = kwargs.get("downstream_windows") + self.gene_windows: int = kwargs.get("gene_windows") + self.plot_data: pl.DataFrame = kwargs.get("plot_data") + self.context: str = kwargs.get("context") + self.strand: str = kwargs.get("strand") @property def metadata(self) -> dict: @@ -309,8 +344,8 @@ def save_rds(self, filename, compress: bool = False): """ Save Bismark DataFrame in Rds. - :param filename: path for file. - :param compress: whether to compress to gzip or not. + :param filename: Path for file. + :param compress: Whether to compress to gzip or not. """ write_rds(filename, self.bismark.to_pandas(), compress="gzip" if compress else None) @@ -319,8 +354,8 @@ def save_tsv(self, filename, compress=False): """ Save Bismark DataFrame in TSV. - :param filename: path for file. - :param compress: whether to compress to gzip or not. + :param filename: Path for file. + :param compress: Whether to compress to gzip or not. 
""" if compress: with gzip.open(filename + ".gz", "wb") as file: @@ -337,6 +372,136 @@ def __len__(self): return len(self.bismark) +class Clustering(BismarkBase): + """ + Class for clustering genes within sample + """ + + def __init__(self, bismark_df: pl.DataFrame, count_threshold=5, dist_method="euclidean", clust_method="average", **kwargs): + """ + :param bismark_df: :class:polars.DataFrame with genes data + :param count_threshold: Minimum counts per fragment + :param dist_method: Method for evaluating distance + :param clust_method: Method for hierarchical clustering + """ + super().__init__(bismark_df, **kwargs) + + grouped = ( + self.bismark.lazy() + .with_columns((pl.col("sum") / pl.col("count")).alias("density")) + .group_by(["chr", "strand", "gene", "context"]) + .agg([pl.col("density"), + pl.col("fragment"), + pl.sum("count").alias("gene_count"), + pl.count("fragment").alias("count")]) + ).collect() + + print(f"Starting with:\t{len(grouped)}") + + by_count = grouped.filter(pl.col("gene_count") > (count_threshold * pl.col("count"))) + + print(f"Left after count theshold filtration:\t{len(by_count)}") + + by_count = by_count.filter(pl.col("count") == self.total_windows) + + print(f"Left after empty windows filtration:\t{len(by_count)}") + + by_count = by_count.explode(["density", "fragment"]).drop(["gene_count", "count"]).fill_nan(0) + + unpivot = by_count.pivot( + index=["chr", "strand", "gene"], + values="density", + columns="fragment", + aggregate_function="sum" + ).select( + ["chr", "strand", "gene"] + list(map(str, range(self.total_windows))) + ).with_columns( + pl.col("gene").alias("label") + ) + + self.gene_labels = unpivot.with_columns(pl.col("label").cast(pl.Utf8))["label"].to_numpy() + self.matrix = unpivot[list(map(str, range(self.total_windows)))].to_numpy() + self.matrix = self.matrix[~np.isnan(self.matrix).any(axis=1), :] + + # dist matrix + print("Distances calculation") + self.dist = pdist(self.matrix, metric=dist_method) + # linkage matrix + print("Linkage calculation and minimizing distances") + self.linkage = linkage(self.dist, method=clust_method, optimal_ordering=True) + + self.order = leaves_list(self.linkage) + + self.tree = None + + @cache + def dynamicTreeCut(self, **kwargs) -> dict: + """ + Method for asigning genes into modules with dynamic tree cut algorithm. + + :param kwargs: all arguements for dynamicTreeCut + + :return: tree dictionary + """ + print("WARNING: dynamicTreeCut can take very long time to run") + return cutreeHybrid(self.linkage, self.dist, **kwargs) + + def draw( + self, + fig_axes: tuple = None, + title: str = None + ) -> Figure: + """ + Draws heat-map on given :class:`matplotlib.Axes` or makes them itself. + + :param fig_axes: Tuple with (fig, axes) from :meth:`matplotlib.plt.subplots`. + :param title: Title of the plot. 
+ :return: + """ + if fig_axes is None: + plt.clf() + fig, axes = plt.subplots() + else: + fig, axes = fig_axes + + vmin = 0 + vmax = np.max(np.array(self.plot_data)) + + image = axes.imshow( + self.matrix[self.order, :], + interpolation="nearest", aspect='auto', + cmap=colormaps['cividis'], + vmin=vmin, vmax=vmax + ) + axes.set_title(title) + axes.set_xlabel('Position') + axes.set_ylabel('') + self.__add_flank_lines(axes) + axes.set_yticks([]) + plt.colorbar(image, ax=axes, label='Methylation density') + + return fig + + def __add_flank_lines(self, axes: plt.Axes): + """ + Add flank lines to the given axis (for line plot) + """ + x_ticks = [] + x_labels = [] + if self.upstream_windows > 0: + x_ticks.append(self.upstream_windows - .5) + x_labels.append('TSS') + if self.downstream_windows > 0: + x_ticks.append(self.gene_windows + self.upstream_windows - .5) + x_labels.append('TES') + + if x_ticks and x_labels: + axes.set_xticks(x_ticks) + axes.set_xticklabels(x_labels) + for tick in x_ticks: + axes.axvline(x=tick, linestyle='--', color='k', alpha=.3) + + class ChrLevels: def __init__(self, df: pl.DataFrame) -> None: self.bismark = df @@ -348,7 +513,7 @@ def __init__(self, df: pl.DataFrame) -> None: .agg([pl.sum("sum"), pl.sum("count")]) .with_columns((pl.col("sum") / pl.col("count")).alias("density")) ) - + @classmethod def from_file( cls, @@ -361,9 +526,9 @@ def from_file( """ Initialize ChrLevels with CX_report file - :param file: path to file - :param chr_min_length: minimum length of chromosome to be analyzed - :param window_length: length of windows in bp + :param file: Path to file + :param chr_min_length: Minimum length of chromosome to be analyzed + :param window_length: Length of windows in bp :param cpu: How many cores to use. Uses every physical core by default :param batch_size: Number of rows to read by one CPU core """ @@ -373,7 +538,8 @@ def from_file( new_columns=['chr', 'position', 'strand', 'count_m', 'count_um', 'context'], columns=[0, 1, 2, 3, 4, 5], - batch_size=batch_size + batch_size=batch_size, + n_threads=cpu ) read_approx = approx_batch_num(file, batch_size) read_batches = 0 @@ -431,12 +597,12 @@ def save_plot_rds(self, path, compress: bool = False): """ write_rds(path, self.plot_data.to_pandas(), compress="gzip" if compress else None) - + def filter(self, context: str = None, strand: str = None, chr: str = None): """ - :param context: methylation context (CG, CHG, CHH) to filter (only one). - :param strand: strand to filter (+ or -). - :param chr: chromosome name to filter. + :param context: Methylation context (CG, CHG, CHH) to filter (only one). + :param strand: Strand to filter (+ or -). + :param chr: Chromosome name to filter. :return: Filtered :class:`Bismark`. """ context_filter = self.bismark["context"] == context if context is not None else True @@ -447,7 +613,7 @@ def filter(self, context: str = None, strand: str = None, chr: str = None): return self else: return self.__class__(self.bismark.filter(context_filter & strand_filter & chr_filter)) - + def draw( self, fig_axes: tuple = None, @@ -456,7 +622,7 @@ def draw( linewidth: float = 1.0, linestyle: str = '-', ) -> Figure: - + if fig_axes is None: fig, axes = plt.subplots() else: @@ -499,7 +665,7 @@ def draw( fig.set_size_inches(12, 5) return fig - + class Metagene(BismarkBase): """ @@ -520,7 +686,7 @@ def from_file( Constructor from Bismark coverage2cytosine output. :param cpu: How many cores to use. 
Uses every physical core by default - :param file: path to bismark genomeWide report + :param file: Path to bismark genomeWide report :param genome: polars.Dataframe with gene ranges :param upstream_windows: Number of windows flank regions to split :param downstream_windows: Number of windows flank regions to split @@ -551,77 +717,110 @@ def __read_bismark_batches( gene_windows: int = 2000, downstream_windows: int = 500, batch_size: int = 10 ** 7, - cpu: int = cpu_count(), - columns=None + cpu: int = cpu_count() ) -> pl.DataFrame: # enable string cache for categorical comparison pl.enable_string_cache(True) - GENOME_COLUMNS = [pl.col('strand').cast(pl.Categorical), - pl.col('chr').cast(pl.Categorical)] - DF_COLUMNS = [pl.col('position').cast(pl.Int32), - pl.col('chr').cast(pl.Categorical), - pl.col('strand').cast(pl.Categorical), - pl.col('context').cast(pl.Categorical), - ((pl.col('count_m')) / (pl.col('count_m') + - pl.col('count_um'))).alias('density') - ] - - UPSTREAM_REGION = pl.col('position') < pl.col('start') - BODY_REGION = (pl.col('start') <= pl.col('position')) & ( - pl.col('position') <= pl.col('end')) - DOWNSTREAM_REGION = (pl.col('position') > pl.col('end')) - - UPSTREAM_FRAGMENT = (((pl.col('position') - pl.col('upstream')) / - (pl.col('start') - pl.col('upstream'))) * upstream_windows).floor() - BODY_FRAGMENT = (((pl.col('position') - pl.col('start')) / (pl.col('end') - - pl.col('start') + 1e-10)) * gene_windows).floor() + upstream_windows - DOWNSTREAM_FRAGMENT = (((pl.col('position') - pl.col('end')) / (pl.col('downstream') - pl.col( - 'end') + 1e-10)) * downstream_windows).floor() + upstream_windows + gene_windows + # POLARS EXPRESSIONS + # cast genome columns to type to join + gene_columns = [ + pl.col('strand').cast(pl.Categorical), + pl.col('chr').cast(pl.Categorical) + ] + # cast report columns to optimized type + df_columns = [ + pl.col('position').cast(pl.Int32), + pl.col('chr').cast(pl.Categorical), + pl.col('strand').cast(pl.Categorical), + pl.col('context').cast(pl.Categorical), + # density for CURRENT cytosine + ((pl.col('count_m')) / (pl.col('count_m') + pl.col('count_um'))).alias('density') + ] + + # upstream region position check + upstream_region = pl.col('position') < pl.col('start') + # body region position check + body_region = (pl.col('start') <= pl.col('position')) & (pl.col('position') <= pl.col('end')) + # downstream region position check + downstream_region = (pl.col('position') > pl.col('end')) + + upstream_fragment = (( + (pl.col('position') - pl.col('upstream')) / (pl.col('start') - pl.col('upstream')) + ) * upstream_windows).floor() + + # fragment even for position == end needs to be rounded by floor + # so 1e-10 is added (position is always < end) + body_fragment = (( + (pl.col('position') - pl.col('start')) / (pl.col('end') - pl.col('start') + 1e-10) + ) * gene_windows).floor() + upstream_windows + + downstream_fragment = (( + (pl.col('position') - pl.col('end')) / (pl.col('downstream') - pl.col('end') + 1e-10) + ) * downstream_windows).floor() + upstream_windows + gene_windows + + # batch approximation + read_approx = approx_batch_num(file, batch_size) + read_batches = 0 + # output dataframe + total = None + # initialize batched reader bismark = pl.read_csv_batched( file, separator='\t', has_header=False, new_columns=['chr', 'position', 'strand', 'count_m', 'count_um', 'context'], columns=[0, 1, 2, 3, 4, 5], - batch_size=batch_size + batch_size=batch_size, + n_threads=cpu ) - read_approx = approx_batch_num(file, batch_size) - read_batches = 0 
+ batches = bismark.next_batches(cpu) - total = None + def process_batch(df: pl.DataFrame): + return ( + df.lazy() + # filter empty rows + .filter((pl.col('count_m') + pl.col('count_um') != 0)) + # assign types + # calculate density for each cytosine + .with_columns(df_columns) + # drop redundant columns, because individual cytosine density has already been calculated + # individual counts do not matter because every cytosine is equal + .drop(['count_m', 'count_um']) + # sort by position for joining + .sort(['chr', 'strand', 'position']) + # join with nearest + .join_asof( + genome.lazy().with_columns(gene_columns), + left_on='position', right_on='upstream', by=['chr', 'strand'] + ) + # limit by end of region + .filter(pl.col('position') <= pl.col('downstream')) + # calculate fragment ids + .with_columns([ + pl.when(upstream_region).then(upstream_fragment) + .when(body_region).then(body_fragment) + .when(downstream_region).then(downstream_fragment) + .cast(pl.Int32).alias('fragment'), + pl.concat_str( + pl.col("chr"), + (pl.concat_str(pl.col("start"), pl.col("end"), separator="-")), + separator=":").alias("gene").cast(pl.Categorical) + ]) + # gather fragment stats + .groupby(by=['chr', 'strand', 'gene', 'context', 'fragment']) + .agg([ + pl.sum('density').alias('sum'), + pl.count('density').alias('count') + ]) + .drop_nulls(subset=['sum']) + ).collect() - batches = bismark.next_batches(cpu) print(f"Reading from {file}") while batches: for df in batches: - df = ( - df.lazy() - .filter((pl.col('count_m') + pl.col('count_um') != 0)) - # calculate density for each cytosine - .with_columns(DF_COLUMNS) - .drop(['count_m', 'count_um']) - .sort('position') - .join_asof( - genome.lazy().with_columns(GENOME_COLUMNS), - left_on='position', right_on='upstream', by=['chr', 'strand'] - ) # join on nearest start for every row - # limit by end of gene - .filter(pl.col('position') <= pl.col('downstream')) - .with_columns( - pl.when(UPSTREAM_REGION).then(UPSTREAM_FRAGMENT) - .when(BODY_REGION).then(BODY_FRAGMENT) - .when(DOWNSTREAM_REGION).then(DOWNSTREAM_FRAGMENT) - .cast(pl.Int32).alias('fragment') - ) - .groupby(by=['chr', 'strand', 'start', 'context', 'fragment']) - .agg([ - pl.sum('density').alias('sum'), - pl.count('density').alias('count') - ]) - .drop_nulls(subset=['sum']) - ).collect() + df = process_batch(df) if total is None and len(df) == 0: raise Exception( "Error reading Bismark file. Check format or genome. No joins on first batch.") @@ -632,16 +831,17 @@ def __read_bismark_batches( read_batches += 1 print( - f"\tRead {read_batches}/{read_approx} batch | Total size - {round(total.estimated_size('mb'), 1)}Mb RAM", end="\r") + f"\tRead {read_batches}/{read_approx} batch | Total size - {round(total.estimated_size('mb'), 1)}Mb RAM", + end="\r") batches = bismark.next_batches(cpu) print("DONE") return total def filter(self, context: str = None, strand: str = None, chr: str = None): """ - :param context: methylation context (CG, CHG, CHH) to filter (only one). - :param strand: strand to filter (+ or -). - :param chr: chromosome name to filter. + :param context: Methylation context (CG, CHG, CHH) to filter (only one). + :param strand: Strand to filter (+ or -). + :param chr: Chromosome name to filter. :return: Filtered :class:`Bismark`. """ context_filter = self.bismark["context"] == context if context is not None else True @@ -662,7 +862,7 @@ def resize(self, to_fragments: int = None): """ Modify DataFrame to fewer fragments. - :param to_fragments: number of final fragments. 
+ :param to_fragments: Number of final fragments. :return: Resized :class:`Bismark`. """ if self.upstream_windows is not None and self.gene_windows is not None and self.downstream_windows is not None: @@ -680,7 +880,7 @@ def resize(self, to_fragments: int = None): * to_fragments).floor().cast(pl.Int32) ) .group_by( - by=['chr', 'strand', 'start', 'context', 'fragment'] + by=['chr', 'strand', 'gene', 'context', 'fragment'] ).agg([ pl.sum('sum').alias('sum'), pl.sum('count').alias('count') @@ -701,8 +901,8 @@ def trim_flank(self, upstream=True, downstream=True): """ Trim fragments - :param upstream: keep upstream? - :param downstream: keep downstream? + :param upstream: Keep upstream region? + :param downstream: Keep downstream region? :return: Trimmed :class:`Bismark`. """ trimmed = self.bismark.lazy() @@ -724,7 +924,7 @@ def trim_flank(self, upstream=True, downstream=True): return self.__class__(trimmed.collect(), **metadata) - def dendrogram(self, dist_method="euclidean", clust_method="complete"): + def clustering(self, count_threshold = 5, dist_method="euclidean", clust_method="average"): """ Gives an order for genes in specified method. @@ -732,35 +932,10 @@ def dendrogram(self, dist_method="euclidean", clust_method="complete"): :param dist_method: Distance method to use. See :meth:`scipy.spatial.distance.pdist` :param clust_method: Clustering method to use. See :meth:`scipy.cluster.hierarchy.linkage` - :return: list of indexes of ordered rows. + :return: List of indexes of ordered rows. """ - template = ( - self.bismark.lazy() - .group_by(["strand", "context", "chr", "start"], maintain_order=True) - .agg() - .with_columns(pl.lit([list(range(self.bismark["fragment"].max() + 1))]).alias("fragment")) - .explode("fragment") - .with_columns(pl.col("fragment").cast(pl.Int32)) - ) - joined = ( - template - .join(self.bismark.lazy().with_columns((pl.col("sum") / pl.col("count")).alias("density")), - on=["strand", "context", "chr", "start", "fragment"], - how="left") - .fill_null(0) - .group_by(["strand", "context", "chr", "start"], maintain_order=True) - .agg(pl.col("density")) - ).collect() - data_matrix = np.matrix( - joined["density"].to_list(), - dtype=np.float32 - ) - - dist = pdist(data_matrix, metric=dist_method) - linkage = hclust.linkage(dist, method=clust_method) - ordering = hclust.optimal_leaf_ordering(linkage, dist) - return hclust.leaves_list(ordering) + return Clustering(self.bismark, count_threshold, dist_method, clust_method, **self.metadata) def line_plot(self, resolution: int = None): """ @@ -787,15 +962,38 @@ def __init__(self, bismark_df: pl.DataFrame, **kwargs): """ super().__init__(bismark_df, **kwargs) - self.plot_data = self.bismark.group_by("fragment").agg( + self.plot_data = self.bismark.group_by("fragment").agg([ + pl.col("sum"), pl.col("count"), (pl.sum("sum") / pl.sum("count")).alias("density") - ) + ]).sort("fragment") if self.strand == '-': max_fragment = self.plot_data["fragment"].max() self.plot_data = self.plot_data.with_columns( (max_fragment - pl.col("fragment")).alias("fragment")) + @staticmethod + def __interval(sum_density: list[int], sum_counts: list[int], alpha=.95): + """ + Evaluate confidence interval for point + + :param sum_density: Sums of methylated counts in fragment + :param sum_counts: Sums of all read cytosines in fragment + :param alpha: Probability for confidence band + """ + sum_density, sum_counts = np.array(sum_density), np.array(sum_counts) + average = sum_density.sum() / sum_counts.sum() + + normalized = np.divide(sum_density, 
sum_counts) + + variance = np.average((normalized - average) ** 2, weights=sum_counts) + + n = sum(sum_counts) - 1 + + i = stats.t.interval(alpha, df=n, loc=average, scale=np.sqrt(variance / n)) + + return {"lower": i[0], "upper": i[1]} + def save_plot_rds(self, path, compress: bool = False): """ Saves plot data in a rds DataFrame with columns: @@ -806,7 +1004,10 @@ def save_plot_rds(self, path, compress: bool = False): | Int | Float | +----------+---------+ """ - write_rds(path, self.plot_data.to_pandas(), + df = self.bismark.group_by("fragment").agg( + (pl.sum("sum") / pl.sum("count")).alias("density") + ) + write_rds(path, df.to_pandas(), compress="gzip" if compress else None) def draw( @@ -814,6 +1015,7 @@ def draw( fig_axes: tuple = None, smooth: int = 10, label: str = None, + confidence = 0, linewidth: float = 1.0, linestyle: str = '-', ) -> Figure: @@ -823,6 +1025,7 @@ def draw( :param fig_axes: Tuple with (fig, axes) from :meth:`matplotlib.plt.subplots` :param smooth: Window for SavGol filter. (see :meth:`scipy.signal.savgol`) :param label: Label of the plot + :param confidence: Probability for confidence bands. 0 for disabled. :param linewidth: See matplotlib documentation. :param linestyle: See matplotlib documentation. :return: @@ -832,7 +1035,21 @@ def draw( else: fig, axes = fig_axes - data = self.plot_data.sort("fragment")["density"] + if 0 < confidence < 1: + df = ( + self.plot_data + .with_columns( + pl.struct(["sum", "count"]).map_elements( + lambda x: self.__interval(x["sum"], x["count"], confidence) + ).alias("interval") + ) + .unnest("interval") + .select(["fragment", "lower", "density", "upper"]) + ) + else: + df = self.plot_data + + data = df["density"] polyorder = 3 window = smooth if smooth > polyorder else polyorder + 1 @@ -842,11 +1059,25 @@ def draw( x = np.arange(len(data)) data = data * 100 # convert to percents - axes.plot(x, data, label=label, + + axes.plot(x, data, + label=label if label is not None else "_", linestyle=linestyle, linewidth=linewidth) + + if 0 < confidence < 1: + upper = df["upper"].to_numpy() * 100 # convert to percents + lower = df["lower"].to_numpy() * 100 # convert to percents + + upper = savgol_filter(upper, window, 3, mode="nearest") if smooth else upper + lower = savgol_filter(lower, window, 3, mode="nearest") if smooth else lower + + axes.fill_between(x, lower, upper, alpha=.2) + self.__add_flank_lines(axes) - axes.legend() + if label is not None: + axes.legend() + axes.set_ylabel('Methylation density, %') axes.set_xlabel('Position') @@ -877,7 +1108,7 @@ def __init__(self, bismark_df: pl.DataFrame, nrow, order=None, **kwargs): order = ( self.bismark.lazy() - .groupby(['chr', 'strand', "start"]) + .groupby(['chr', 'strand', "gene"]) .agg( (pl.col('sum').sum() / pl.col('count').sum()).alias("order") ) @@ -886,7 +1117,7 @@ def __init__(self, bismark_df: pl.DataFrame, nrow, order=None, **kwargs): # sort by rows and add row numbers hm_data = ( self.bismark.lazy() - .groupby(['chr', 'strand', "start"]) + .groupby(['chr', 'strand', "gene"]) .agg( pl.col('fragment'), pl.col('sum'), pl.col('count') ) @@ -995,10 +1226,10 @@ def __add_flank_lines(self, axes: plt.Axes): x_ticks = [] x_labels = [] if self.upstream_windows > 0: - x_ticks.append(self.upstream_windows - 1) + x_ticks.append(self.upstream_windows - .5) x_labels.append('TSS') if self.downstream_windows > 0: - x_ticks.append(self.gene_windows + self.upstream_windows) + x_ticks.append(self.gene_windows + self.upstream_windows - .5) x_labels.append('TES') if x_ticks and x_labels: @@ 
-1009,7 +1240,7 @@ def __add_flank_lines(self, axes: plt.Axes): class BismarkFilesBase: - def __init__(self, samples, labels: list[str] | None): + def __init__(self, samples, labels: list[str] = None): self.samples = self.__check_metadata( samples if isinstance(samples, list) else [samples]) if samples is None: @@ -1124,7 +1355,7 @@ def merge(self): if len(upstream_windows) == len(downstream_windows) == len(gene_windows) == 1: merged = ( pl.concat([sample.bismark for sample in self.samples]).lazy() - .group_by(["strand", "context", "chr", "start", "fragment"]) + .group_by(["strand", "context", "chr", "gene", "fragment"]) .agg([pl.sum("sum").alias("sum"), pl.sum("count").alias("count")]) ).collect() @@ -1195,15 +1426,16 @@ def box_plot(self, fig_axes: tuple = None, showfliers=False): class LinePlotFiles(BismarkFilesBase): def draw( self, - smooth: float = .05, + smooth: int = 10, linewidth: float = 1.0, linestyle: str = '-', + confidence=0 ): plt.clf() fig, axes = plt.subplots() for lp, label in zip(self.samples, self.labels): assert isinstance(lp, LinePlot) - lp.draw((fig, axes), smooth, label, linewidth, linestyle) + lp.draw((fig, axes), smooth, label, confidence, linewidth, linestyle) return fig diff --git a/src/bismarkplot/__init__.py b/src/bismarkplot/__init__.py index 8123328..46a01a5 100644 --- a/src/bismarkplot/__init__.py +++ b/src/bismarkplot/__init__.py @@ -1,3 +1,3 @@ -from .BismarkPlot import Metagene, MetageneFiles, Genome, ChrLevels +from .BismarkPlot import Metagene, MetageneFiles, Genome, ChrLevels, Clustering __version__ = 1.2 diff --git a/src/bismarkplot/console_metagene.py b/src/bismarkplot/console_metagene.py index e1cd26d..68ca6d3 100644 --- a/src/bismarkplot/console_metagene.py +++ b/src/bismarkplot/console_metagene.py @@ -27,6 +27,7 @@ parser.add_argument('-S', '--smooth', help='windows for smoothing', type=float, default=10) parser.add_argument('-L', '--labels', help='labels for plots', nargs='+') +parser.add_argument('-C', '--confidence', help='probability for confidence bands for line-plot. 0 if disabled', type=float, default=0) parser.add_argument('-H', help='vertical resolution for heat-map', type=int, default=100) parser.add_argument('-V', help='vertical resolution for heat-map', type=int, default=100) parser.add_argument("--dpi", help="dpi of output plot", type=int, default=200) @@ -69,13 +70,13 @@ def main(): base_name = args.out + "_" + context + strand + "_{type}." + args.format if args.line_plot: - filtered.line_plot().draw(smooth=args.smooth).savefig(base_name.format(type = "line-plot"), dpi = args.dpi) + filtered.line_plot().draw(smooth=args.smooth, confidence=args.confidence).savefig(base_name.format(type="line-plot"), dpi = args.dpi) if args.heat_map: - filtered.heat_map(args.hresolution, args.vresolution).draw().savefig(base_name.format(type = "heat-map"), dpi = args.dpi) + filtered.heat_map(args.hresolution, args.vresolution).draw().savefig(base_name.format(type="heat-map"), dpi=args.dpi) if args.box_plot: - filtered.trim_flank().box_plot().savefig(base_name.format(type = "box-plot"), dpi = args.dpi) + filtered.trim_flank().box_plot().savefig(base_name.format(type="box-plot"), dpi=args.dpi) if args.violin_plot: - filtered.trim_flank().violin_plot().savefig(base_name.format(type = "violin-plot"), dpi = args.dpi) + filtered.trim_flank().violin_plot().savefig(base_name.format(type="violin-plot"), dpi=args.dpi) except Exception: filename = f'error{datetime.now().strftime("%m_%d_%H:%M")}.txt'
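The refactored flank logic above (the upstream_length / downstream_length expressions shared by near_TSS, near_TES and Genome.__trim_genes) follows one rule: take the full flank_length when the gap to the neighbouring gene is wide enough, take no flank when the genes overlap, and otherwise take half of the gap, rounded down to an even number first. A minimal standalone sketch of that rule (the gap values below are made up):

import polars as pl

flank_length = 2000  # same default as Genome.gene_body / near_TSS

# full flank if the gap allows it, none for overlapping genes, half the gap otherwise
upstream_length = (
    pl.when(pl.col("upstream") >= flank_length).then(flank_length)
    .when(pl.col("upstream") < 0).then(0)
    .otherwise((pl.col("upstream") - (pl.col("upstream") % 2)) // 2)
)

gaps = pl.DataFrame({"upstream": [5000, 1501, -300]})
print(gaps.with_columns(upstream_length.alias("used_flank")))
# 5000 -> 2000 (full flank), 1501 -> 750 (half of the even-rounded gap), -300 -> 0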
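The new Clustering class is reached through Metagene.clustering. A usage sketch, assuming a GFF annotation and a coverage2cytosine report are at hand; the file paths are placeholders and the keyword names are assumed from the signatures and docstrings in this patch:

from bismarkplot import Genome, Metagene

genome = Genome.from_gff("annotation.gff3")  # placeholder path
metagene = Metagene.from_file(
    file="sample.CX_report.txt",             # placeholder path
    genome=genome.gene_body(min_length=4000, flank_length=2000),
    upstream_windows=500, gene_windows=2000, downstream_windows=500,
)

# new in this patch: hierarchical clustering of per-gene methylation profiles
clusters = metagene.filter(context="CG", strand="+").clustering(
    count_threshold=5,         # minimum counts per window
    dist_method="euclidean",   # forwarded to scipy.spatial.distance.pdist
    clust_method="average",    # forwarded to scipy.cluster.hierarchy.linkage
)
clusters.draw(title="CG methylation clusters").savefig("clusters.png", dpi=200)

# optional module assignment; any cutreeHybrid argument can be passed (may be slow)
modules = clusters.dynamicTreeCut(minClusterSize=50)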
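In Metagene.__read_bismark_batches, the upstream_fragment / body_fragment / downstream_fragment expressions assign each cytosine to a window index: upstream windows first, then gene-body windows, then downstream windows. The same arithmetic restated per position as an ordinary function (the helper name and the coordinates below are illustrative, not part of the package):

def fragment_index(position, upstream, start, end, downstream,
                   upstream_windows=500, gene_windows=2000, downstream_windows=500):
    if position < start:   # upstream flank
        return int((position - upstream) / (start - upstream) * upstream_windows)
    if position <= end:    # gene body; the 1e-10 keeps position == end inside the last body window
        return upstream_windows + int(
            (position - start) / (end - start + 1e-10) * gene_windows)
    # downstream flank
    return upstream_windows + gene_windows + int(
        (position - end) / (downstream - end + 1e-10) * downstream_windows)

# a cytosine at 12,344 bp of a gene spanning 10,000-20,000 bp with 2,000 bp flanks
print(fragment_index(position=12_344, upstream=8_000, start=10_000,
                     end=20_000, downstream=22_000))  # 968 = 500 upstream windows + body window 468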
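LinePlot.draw (and the new -C/--confidence console flag) now shades a confidence band around the metagene line. Per fragment, the band is a count-weighted mean density, a count-weighted variance and a Student's t interval, as in LinePlot.__interval; restated outside the class for clarity (the function name and the numbers in the call are purely illustrative):

import numpy as np
from scipy import stats

def methylation_interval(sums, counts, confidence=0.95):
    # sums[i] / counts[i] is the methylation density contributed by gene i to this fragment
    sums, counts = np.asarray(sums, dtype=float), np.asarray(counts, dtype=float)
    mean = sums.sum() / counts.sum()                               # pooled density
    variance = np.average((sums / counts - mean) ** 2, weights=counts)
    dof = counts.sum() - 1
    return stats.t.interval(confidence, df=dof, loc=mean, scale=np.sqrt(variance / dof))

print(methylation_interval(sums=[5, 8, 2], counts=[10, 12, 6]))  # (lower, upper) band for one fragment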