diff --git a/docs/release-notes/1.9.0.md b/docs/release-notes/1.9.0.md index 8f0370e9ce..8a43cf1aff 100644 --- a/docs/release-notes/1.9.0.md +++ b/docs/release-notes/1.9.0.md @@ -7,3 +7,4 @@ - {func}`~scanpy.pl.embedding_density` now allows more than 10 groups {pr}`1936` {smaller}`A Wolf` - {func}`~scanpy.logging.print_versions` now uses `session_info` {pr}`2089` {smaller}`P Angerer` {smaller}`I Virshup` - `_choose_representation` now subsets the provided representation to n_pcs, regardless of the name of the provided representation (should affect mostly {func}`~scanpy.pp.neighbors`) {pr}`2179` {smaller}`I Virshup` {smaller}`PG Majev` +- Embedding plots now have a `dimensions` argument, which lets users select which dimensions of their embedding to plot and uses the same broadcasting rules as other arguments {pr}`1538` {smaller}`I Virshup` diff --git a/scanpy/plotting/_docs.py b/scanpy/plotting/_docs.py index 6c8bcf1484..947ed17bf3 100644 --- a/scanpy/plotting/_docs.py +++ b/scanpy/plotting/_docs.py @@ -50,6 +50,11 @@ groups Restrict to a few categories in categorical observation annotation. The default is not to restrict to any groups. +dimensions + 0-indexed dimensions of the embedding to plot as integers. E.g. [(0, 1), (1, 2)]. + Unlike `components`, this argument is used in the same way as `color`, i.e. it is + used to specify a single plot at a time. Will eventually replace the `components` + argument. components For instance, `['1,2', '2,3']`. To plot all available components use `components='all'`. 
diff --git a/scanpy/plotting/_tools/__init__.py b/scanpy/plotting/_tools/__init__.py index 40bd8e097d..d788e360d3 100644 --- a/scanpy/plotting/_tools/__init__.py +++ b/scanpy/plotting/_tools/__init__.py @@ -1484,7 +1484,7 @@ def embedding_density( ax = embedding( adata, basis, - components=components, + dimensions=np.array(components) - 1, # Saved with 1 based indexing color=density_col_name, color_map=color_map, size=dot_sizes, @@ -1515,7 +1515,7 @@ def embedding_density( fig_or_ax = embedding( adata, basis, - components=components, + dimensions=np.array(components) - 1, # Saved with 1 based indexing color=density_col_name, color_map=color_map, size=dot_sizes, diff --git a/scanpy/plotting/_tools/paga.py b/scanpy/plotting/_tools/paga.py index 5bd7eefece..6808ae4682 100644 --- a/scanpy/plotting/_tools/paga.py +++ b/scanpy/plotting/_tools/paga.py @@ -97,7 +97,7 @@ def paga_compare( else: basis = 'umap' - from .scatterplots import embedding, _get_data_points + from .scatterplots import embedding, _get_basis, _components_to_dimensions embedding( adata, @@ -123,9 +123,12 @@ def paga_compare( if pos is None: if color == adata.uns['paga']['groups']: - coords = _get_data_points( - adata, basis, projection="2d", components=components, scale_factor=None - )[0][0] + # TODO: Use dimensions here + _basis = _get_basis(adata, basis) + dims = _components_to_dimensions( + components=components, dimensions=None, total_dims=_basis.shape[1] + )[0] + coords = _basis[:, dims] pos = ( pd.DataFrame(coords, columns=["x", "y"], index=adata.obs_names) .groupby(adata.obs[color], observed=True) diff --git a/scanpy/plotting/_tools/scatterplots.py b/scanpy/plotting/_tools/scatterplots.py index 61c7d48f5e..ce79f80cfa 100644 --- a/scanpy/plotting/_tools/scatterplots.py +++ b/scanpy/plotting/_tools/scatterplots.py @@ -1,6 +1,18 @@ import collections.abc as cabc from copy import copy -from typing import Union, Optional, Sequence, Any, Mapping, List, Tuple +from numbers import Integral +from 
itertools import combinations, product +from typing import ( + Collection, + Union, + Optional, + Sequence, + Any, + Mapping, + List, + Tuple, +) +from warnings import warn import numpy as np import pandas as pd @@ -62,6 +74,7 @@ def embedding( arrows_kwds: Optional[Mapping[str, Any]] = None, groups: Optional[str] = None, components: Union[str, Sequence[str]] = None, + dimensions: Optional[Union[Tuple[int, int], Sequence[Tuple[int, int]]]] = None, layer: Optional[str] = None, projection: Literal['2d', '3d'] = '2d', scale_factor: Optional[float] = None, @@ -109,10 +122,38 @@ def embedding( ------- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. """ + ##################### + # Argument handling # + ##################### + check_projection(projection) sanitize_anndata(adata) - # Setting up color map for continuous values + basis_values = _get_basis(adata, basis) + dimensions = _components_to_dimensions( + components, dimensions, projection=projection, total_dims=basis_values.shape[1] + ) + args_3d = dict(projection='3d') if projection == '3d' else {} + + # Figure out if we're using raw + if use_raw is None: + # check if adata.raw is set + use_raw = layer is None and adata.raw is not None + if use_raw and layer is not None: + raise ValueError( + "Cannot use both a layer and the raw representation. Was passed:" + f"use_raw={use_raw}, layer={layer}." + ) + if use_raw and adata.raw is None: + raise ValueError( + "`use_raw` is set to True but AnnData object does not have raw. " + "Please check." 
+ ) + + if isinstance(groups, str): + groups = [groups] + + # Color map if color_map is not None: if cmap is not None: raise ValueError("Cannot specify both `color_map` and `cmap`.") @@ -121,84 +162,23 @@ def embedding( cmap = copy(get_cmap(cmap)) cmap.set_bad(na_color) kwargs["cmap"] = cmap - # Prevents warnings during legend creation na_color = colors.to_hex(na_color, keep_alpha=True) - if size is not None: - kwargs['s'] = size if 'edgecolor' not in kwargs: # by default turn off edge color. Otherwise, for # very small sizes the edge will not reduce its size # (https://github.com/theislab/scanpy/issues/293) kwargs['edgecolor'] = 'none' - if groups: - if isinstance(groups, str): - groups = [groups] - - args_3d = dict(projection='3d') if projection == '3d' else {} + # Vectorized arguments - # Deal with Raw - if use_raw is None: - # check if adata.raw is set - use_raw = layer is None and adata.raw is not None - if use_raw and layer is not None: - raise ValueError( - "Cannot use both a layer and the raw representation. Was passed:" - f"use_raw={use_raw}, layer={layer}." - ) - - if wspace is None: - # try to set a wspace that is not too large or too small given the - # current figure size - wspace = 0.75 / rcParams['figure.figsize'][0] + 0.02 - if adata.raw is None and use_raw: - raise ValueError( - "`use_raw` is set to True but AnnData object does not have raw. " - "Please check." - ) # turn color into a python list color = [color] if isinstance(color, str) or color is None else list(color) if title is not None: # turn title into a python list if not None title = [title] if isinstance(title, str) else list(title) - # get the points position and the components list - # (only if components is not None) - data_points, components_list = _get_data_points( - adata, basis, projection, components, scale_factor - ) - - # Setup layout. - # Most of the code is for the case when multiple plots are required - # 'color' is a list of names that want to be plotted. - # Eg. 
['Gene1', 'louvain', 'Gene2']. - # component_list is a list of components [[0,1], [1,2]] - if ( - not isinstance(color, str) - and isinstance(color, cabc.Sequence) - and len(color) > 1 - ) or len(components_list) > 1: - if ax is not None: - raise ValueError( - "Cannot specify `ax` when plotting multiple panels " - "(each for a given value of 'color')." - ) - if len(components_list) == 0: - components_list = [None] - - # each plot needs to be its own panel - num_panels = len(color) * len(components_list) - fig, grid = _panel_grid(hspace, wspace, ncols, num_panels) - else: - if len(components_list) == 0: - components_list = [None] - grid = None - if ax is None: - fig = pl.figure() - ax = fig.add_subplot(111, **args_3d) - # turn vmax and vmin into a sequence if isinstance(vmax, str) or not isinstance(vmax, cabc.Sequence): vmax = [vmax] @@ -209,28 +189,62 @@ def embedding( if isinstance(norm, Normalize) or not isinstance(norm, cabc.Sequence): norm = [norm] - if 's' in kwargs: + # Size + if 's' in kwargs and size is None: size = kwargs.pop('s') - if size is not None: # check if size is any type of sequence, and if so # set as ndarray - import pandas.core.series - if ( size is not None - and isinstance(size, (cabc.Sequence, pandas.core.series.Series, np.ndarray)) + and isinstance(size, (cabc.Sequence, pd.Series, np.ndarray)) and len(size) == adata.shape[0] ): size = np.array(size, dtype=float) else: size = 120000 / adata.shape[0] - # make the plots - axs = [] - import itertools + ########## + # Layout # + ########## + # Most of the code is for the case when multiple plots are required - idx_components = range(len(components_list)) + if wspace is None: + # try to set a wspace that is not too large or too small given the + # current figure size + wspace = 0.75 / rcParams['figure.figsize'][0] + 0.02 + + if components is not None: + color, dimensions = list(zip(*product(color, dimensions))) + + color, dimensions = _broadcast_args(color, dimensions) + + # 'color' is a list of 
names that want to be plotted. + # Eg. ['Gene1', 'louvain', 'Gene2']. + # component_list is a list of components [[0,1], [1,2]] + if ( + not isinstance(color, str) + and isinstance(color, cabc.Sequence) + and len(color) > 1 + ) or len(dimensions) > 1: + if ax is not None: + raise ValueError( + "Cannot specify `ax` when plotting multiple panels " + "(each for a given value of 'color')." + ) + + # each plot needs to be its own panel + fig, grid = _panel_grid(hspace, wspace, ncols, len(color)) + else: + grid = None + if ax is None: + fig = pl.figure() + ax = fig.add_subplot(111, **args_3d) + + ############ + # Plotting # + ############ + axs = [] # use itertools.product to make a plot for each color and for each component # For example if color=[gene1, gene2] and components=['1,2, '2,3']. @@ -238,9 +252,7 @@ def embedding( # color=gene1, components=[1,2], color=gene1, components=[2,3], # color=gene2, components = [1, 2], color=gene2, components=[2,3], # ] - for count, (value_to_plot, component_idx) in enumerate( - itertools.product(color, idx_components) - ): + for count, (value_to_plot, dims) in enumerate(zip(color, dimensions)): color_source_vector = _get_color_source_vector( adata, value_to_plot, @@ -270,7 +282,7 @@ def embedding( size = np.array(size)[order] color_source_vector = color_source_vector[order] color_vector = color_vector[order] - _data_points = data_points[component_idx][order, :] + coords = basis_values[:, dims][order, :] # if plotting multiple panels, get the ax from the grid spec # else use the ax value (either user given or created previously) @@ -310,9 +322,9 @@ def embedding( # make the scatter plot if projection == '3d': cax = ax.scatter( - _data_points[:, 0], - _data_points[:, 1], - _data_points[:, 2], + coords[:, 0], + coords[:, 1], + coords[:, 2], marker=".", c=color_vector, rasterized=settings._vector_friendly, @@ -320,11 +332,12 @@ def embedding( **kwargs, ) else: - scatter = ( partial(ax.scatter, s=size, plotnonfinite=True) if 
scale_factor is None - else partial(circles, s=size, ax=ax) # size in circles is radius + else partial( + circles, s=size, ax=ax, scale_factor=scale_factor + ) # size in circles is radius ) if add_outline: @@ -353,8 +366,8 @@ def embedding( alpha = kwargs.pop('alpha') if 'alpha' in kwargs else None ax.scatter( - _data_points[:, 0], - _data_points[:, 1], + coords[:, 0], + coords[:, 1], s=bg_size, marker=".", c=bg_color, @@ -363,8 +376,8 @@ def embedding( **kwargs, ) ax.scatter( - _data_points[:, 0], - _data_points[:, 1], + coords[:, 0], + coords[:, 1], s=gap_size, marker=".", c=gap_color, @@ -376,8 +389,8 @@ def embedding( kwargs['alpha'] = 0.7 if alpha is None else alpha cax = scatter( - _data_points[:, 0], - _data_points[:, 1], + coords[:, 0], + coords[:, 1], marker=".", c=color_vector, rasterized=settings._vector_friendly, @@ -393,12 +406,7 @@ def embedding( # set default axis_labels name = _basis2name(basis) - if components is not None: - axis_labels = [name + str(x + 1) for x in components_list[component_idx]] - elif projection == '3d': - axis_labels = [name + str(x + 1) for x in range(3)] - else: - axis_labels = [name + str(x + 1) for x in range(2)] + axis_labels = [name + str(d + 1) for d in dims] ax.set_xlabel(axis_labels[0]) ax.set_ylabel(axis_labels[1]) @@ -430,7 +438,7 @@ def embedding( ax, color_source_vector, palette=_get_palette(adata, value_to_plot), - scatter_array=_data_points, + scatter_array=coords, legend_loc=legend_loc, legend_fontweight=legend_fontweight, legend_fontsize=legend_fontsize, @@ -1021,115 +1029,39 @@ def spatial( # Helpers -def _get_data_points( - adata, basis, projection, components, scale_factor -) -> Tuple[List[np.ndarray], List[Tuple[int, int]]]: - """ - Returns the data points corresponding to the selected basis, projection and/or components. - - Because multiple components are given (eg components=['1,2', '2,3'] the - returned data are lists, containing each of the components. 
When only one component is plotted - the list length is 1. - - Returns - ------- - data_points - Each entry is a numpy array containing the data points - components - The cleaned list of components. Eg. [(0,1)] or [(0,1), (1,2)] - for components = [1,2] and components=['1,2', '2,3'] respectively - """ - - if basis in adata.obsm.keys(): - basis_key = basis - - elif f"X_{basis}" in adata.obsm.keys(): - basis_key = f"X_{basis}" - else: - raise KeyError( - f"Could not find entry in `obsm` for '{basis}'.\n" - f"Available keys are: {list(adata.obsm.keys())}." - ) - - n_dims = 2 - if projection == '3d': - # check if the data has a third dimension - if adata.obsm[basis_key].shape[1] == 2: - if settings._low_resolution_warning: - logg.warning( - 'Selected projections is "3d" but only two dimensions ' - 'are available. Only these two dimensions will be plotted' - ) - else: - n_dims = 3 - - if components == 'all': - from itertools import combinations - - r_value = 3 if projection == '3d' else 2 - _components_list = np.arange(adata.obsm[basis_key].shape[1]) + 1 - components = [ - ",".join(map(str, x)) for x in combinations(_components_list, r=r_value) - ] - - components_list = [] - offset = 0 - if basis == 'diffmap': - offset = 1 - if components is not None: - # components have different formats, either a list with integers, a string - # or a list of strings. 
def _components_to_dimensions(
    components: Optional[Union[str, Collection[str], Collection[int]]],
    dimensions: Optional[Union[Collection[int], Collection[Collection[int]]]],
    *,
    projection: Literal["2d", "3d"] = "2d",
    total_dims: int,
) -> List[Collection[int]]:
    """Normalize the `components`/`dimensions` arguments of embedding plots.

    Parameters
    ----------
    components
        Legacy, 1-indexed selection: a string such as ``'1,2'``, a collection
        of such strings (``['1,2', '2,3']``), ``'all'``, or a single collection
        of integers (``[1, 2]``).
    dimensions
        0-indexed selection: either one collection of integers (a single plot)
        or a collection of such collections (one entry per plot).
    projection
        ``'2d'`` or ``'3d'``. Determines how many dimensions every entry needs.
    total_dims
        Number of dimensions of the embedding. Only used to expand
        ``components='all'``.

    Returns
    -------
    One collection of 0-indexed dimensions per plot.

    Raises
    ------
    ValueError
        If both `components` and `dimensions` are given, or if an entry does
        not consist of exactly 2 (3 for ``'3d'``) integers.
    """
    # TODO: Deprecate components kwarg
    ndims = {"2d": 2, "3d": 3}[projection]
    if components is not None and dimensions is not None:
        raise ValueError("Cannot provide both dimensions and components")

    # TODO: Consider deprecating this
    # If components is not None, parse them and set dimensions
    if components is None:
        if dimensions is None:
            dimensions = [tuple(range(ndims))]
    elif isinstance(components, str) and components == "all":
        # The isinstance guard keeps array-like components away from `==`
        dimensions = list(combinations(range(total_dims), ndims))
    else:
        if isinstance(components, str):
            components = [components]
        # Components use 1 based indexing
        if len(components) > 0 and all(isinstance(c, Integral) for c in components):
            # A bare collection of integers, e.g. `components=[1, 2]`, describes
            # a single plot.
            dimensions = [[int(c) - 1 for c in components]]
        else:
            dimensions = [[int(dim) - 1 for dim in c.split(",")] for c in components]

    # A flat collection of integers is a single plot: wrap it so that the
    # result always holds one entry per plot.
    if all(isinstance(el, Integral) for el in dimensions):
        dimensions = [dimensions]
    for dims in dimensions:
        if not (
            isinstance(dims, Collection)
            and len(dims) == ndims
            and all(isinstance(d, Integral) for d in dims)
        ):
            raise ValueError(
                f"Each entry of `dimensions` must consist of {ndims} integers "
                f"for projection={projection!r}, but got: {dims!r}."
            )

    return dimensions


def _get_basis(adata: "AnnData", basis: str) -> np.ndarray:
    """Get the array for `basis` from `adata.obsm`.

    Looks up `basis` itself first and falls back to `X_{basis}`.

    Raises
    ------
    KeyError
        If neither key is present in `adata.obsm`.
    """
    for key in (basis, f"X_{basis}"):
        if key in adata.obsm:
            return adata.obsm[key]
    raise KeyError(
        f"Could not find '{basis}' or 'X_{basis}' in .obsm. "
        f"Available keys are: {list(adata.obsm.keys())}."
    )


def _broadcast_args(*args):
    """Broadcast sized arguments to a common length.

    Arguments of length 1 are repeated to the length of the longest argument;
    every other argument must already have that length and is returned as is.

    Returns
    -------
    A list with one entry per argument, all of the same length. Empty if no
    arguments were passed.

    Raises
    ------
    ValueError
        If an argument has a length that is neither 1 nor the longest length.
    """
    if not args:
        return []
    lens = [len(arg) for arg in args]
    longest = max(lens)
    if any(n not in (1, longest) for n in lens):
        raise ValueError(f"Could not broadcast together arguments with shapes: {lens}.")
    return [[arg[0]] * longest if len(arg) == 1 else arg for arg in args]
- + if scale_factor != 1.0: + x = x * scale_factor + y = y * scale_factor zipped = np.broadcast(x, y, s) patches = [Circle((x_, y_), s_) for x_, y_, s_ in zipped] collection = PatchCollection(patches, **kwargs) diff --git a/scanpy/tests/_images/master_umap_with_edges.png b/scanpy/tests/_images/master_umap_with_edges.png index 05319691f6..61c9eaa266 100644 Binary files a/scanpy/tests/_images/master_umap_with_edges.png and b/scanpy/tests/_images/master_umap_with_edges.png differ diff --git a/scanpy/tests/test_embedding_plots.py b/scanpy/tests/test_embedding_plots.py index 58db03d965..62eb9f6e0d 100644 --- a/scanpy/tests/test_embedding_plots.py +++ b/scanpy/tests/test_embedding_plots.py @@ -220,6 +220,58 @@ def test_enumerated_palettes(fixture_request, adata, tmpdir, plotfunc): check_images(dict_pth, list_pth, tol=15) +def test_dimension_broadcasting(adata, tmpdir, check_same_image): + tmpdir = Path(tmpdir) + + with pytest.raises(ValueError): + sc.pl.pca( + adata, color=["label", "1_missing"], dimensions=[(0, 1), (1, 2), (2, 3)] + ) + + dims_pth = tmpdir / "broadcast_dims.png" + color_pth = tmpdir / "broadcast_colors.png" + + sc.pl.pca(adata, color=["label", "label", "label"], dimensions=(2, 3), show=False) + plt.savefig(dims_pth, dpi=40) + plt.close() + sc.pl.pca(adata, color="label", dimensions=[(2, 3), (2, 3), (2, 3)], show=False) + plt.savefig(color_pth, dpi=40) + plt.close() + + check_same_image(dims_pth, color_pth, tol=5) + + +def test_dimensions_same_as_components(adata, tmpdir, check_same_image): + tmpdir = Path(tmpdir) + adata = adata.copy() + adata.obs["mean"] = np.ravel(adata.X.mean(axis=1)) + + comp_pth = tmpdir / "components_plot.png" + dims_pth = tmpdir / "dimension_plot.png" + + # TODO: Deprecate components kwarg + # with pytest.warns(FutureWarning, match=r"components .* deprecated"): + sc.pl.pca( + adata, + color=["mean", "label"], + components=["1,2", "2,3"], + show=False, + ) + plt.savefig(comp_pth, dpi=40) + plt.close() + + sc.pl.pca( + adata, + 
color=["mean", "mean", "label", "label"], + dimensions=[(0, 1), (1, 2), (0, 1), (1, 2)], + show=False, + ) + plt.savefig(dims_pth, dpi=40) + plt.close() + + check_same_image(dims_pth, comp_pth, tol=5) + + # Spatial specific diff --git a/scanpy/tests/test_plotting.py b/scanpy/tests/test_plotting.py index 83b4331a40..ca00ba3efb 100644 --- a/scanpy/tests/test_plotting.py +++ b/scanpy/tests/test_plotting.py @@ -939,8 +939,9 @@ def test_genes_symbols(image_comparer, id, fn): save_and_compare_images(f"master_{id}_gene_symbols") -@pytest.fixture(scope="module") -def pbmc_scatterplots(): +@pytest.fixture(scope="session") +def _pbmc_scatterplots(): + # Wrapped in another fixture to avoid any mutation pbmc = sc.datasets.pbmc68k_reduced() pbmc.layers["sparse"] = pbmc.raw.X / 2 pbmc.layers["test"] = pbmc.X.copy() + 100 @@ -951,6 +952,11 @@ def pbmc_scatterplots(): return pbmc +@pytest.fixture +def pbmc_scatterplots(_pbmc_scatterplots): + return _pbmc_scatterplots.copy() + + @pytest.mark.parametrize( 'id,fn', [