Simplify embeddings a bit (#1538)
* Remove dependency on execution order for scatter plot tests

Previously `test_scatterplots[umap_with_edges]` would fail if it ran before `test_scatterplots[umap]`, since the `umap` case changed the color palette the anndata object used.
Now the AnnData object is copied for each test that uses it.

This is important because it means we don't have state spilling between the tests and can parallelize them.

* Convert 'components' to 'dimensions' in sc.pl.embedding

Mostly this was to get rid of `_get_data_points`, which was a bit of a horrible function.

Now we get all values of the array, then index into it with the dimensions for each plot. This has removed a number of edge cases and special handling throughout, making the dataflow a bit easier to follow.

* Rearrange code in embedding to make a bit more sense

* Fix paga plots

* Implement broadcasting of dims and color

* Add simple test for broadcasting of dimensions and color

* Type dimensions

* Add test for components and dimensions being interoperable

* Added docs for the dimension argument

* Undo deprecation of components kwarg.

I still want to eventually do that, but it shows up in many other parts of the codebase which I don't have time to replace at the moment. This includes, but is not limited to, the other pca plotting methods.

* Release note

* Fix type annotation for python 3.7
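
The first item in the message above (copying the AnnData object for each test) can be illustrated with a minimal, dependency-free sketch. A plain dict stands in for the AnnData object, and `plot_umap`, `plot_umap_with_edges`, and `run_isolated` are hypothetical names for illustration, not functions from the scanpy test suite.

```python
import copy

# Stand-in for a shared, module-level test object. A plain dict is used so the
# sketch needs no third-party packages (assumption: this is not scanpy code).
SHARED = {"uns": {"palette": "default"}}


def plot_umap(adata):
    # Hypothetical plotting call that mutates its input as a side effect.
    adata["uns"]["palette"] = "custom"
    return adata["uns"]["palette"]


def plot_umap_with_edges(adata):
    # Hypothetical plotting call whose result depends on the stored palette.
    return adata["uns"]["palette"]


def run_isolated(test_fn):
    # Each test receives its own deep copy, so mutations never reach SHARED.
    return test_fn(copy.deepcopy(SHARED))


# Running the two "tests" in either order gives each the same result.
first = [run_isolated(plot_umap), run_isolated(plot_umap_with_edges)]
second = [run_isolated(plot_umap_with_edges), run_isolated(plot_umap)]
```

Because no test sees another test's mutations, the results do not depend on ordering, which is the property that makes parallel execution safe.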
ivirshup authored Mar 29, 2022
1 parent 3db321f commit 0728d55
Showing 9 changed files with 237 additions and 211 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1.9.0.md
@@ -7,3 +7,4 @@
- {func}`~scanpy.pl.embedding_density` now allows more than 10 groups {pr}`1936` {smaller}`A Wolf`
- {func}`~scanpy.logging.print_versions` now uses `session_info` {pr}`2089` {smaller}`P Angerer` {smaller}`I Virshup`
- `_choose_representation` now subsets the provided representation to n_pcs, regardless of the name of the provided representation (should affect mostly {func}`~scanpy.pp.neighbors`) {pr}`2179` {smaller}`I Virshup` {smaller}`PG Majev`
- Embedding plots now have a `dimensions` argument, which lets users select which dimensions of their embedding to plot and uses the same broadcasting rules as other arguments {pr}`1538` {smaller}`I Virshup`
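
The release note above mentions that `dimensions` follows the same broadcasting rules as other arguments. A rough sketch of what that could look like follows, assuming that a length-1 argument is repeated to match the longer one. The helper names `broadcast_args_sketch` and `select_dims` are hypothetical and are not scanpy's implementation; nested lists stand in for the embedding array.

```python
from typing import List, Sequence, Tuple

Dims = Tuple[int, ...]


def broadcast_args_sketch(
    color: Sequence[str], dimensions: Sequence[Dims]
) -> Tuple[List[str], List[Dims]]:
    """Pair up colors and dimensions, repeating whichever has length 1."""
    color, dimensions = list(color), list(dimensions)
    n = max(len(color), len(dimensions))
    for name, arg in (("color", color), ("dimensions", dimensions)):
        if len(arg) not in (1, n):
            raise ValueError(f"{name} has length {len(arg)}, expected 1 or {n}")
    if len(color) == 1:
        color = color * n
    if len(dimensions) == 1:
        dimensions = dimensions * n
    return color, dimensions


def select_dims(embedding: Sequence[Sequence[float]], dims: Dims) -> List[List[float]]:
    # Keep only the requested 0-indexed columns of every row.
    return [[row[d] for d in dims] for row in embedding]


# Two observations in a 3-dimensional embedding.
embedding = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]

# One color, two dimension pairs: the color is repeated for each panel.
colors, dims = broadcast_args_sketch(["louvain"], [(0, 1), (1, 2)])
panels = [(c, select_dims(embedding, d)) for c, d in zip(colors, dims)]
```

Each entry of `panels` then describes one plot: the color key and the two columns of the embedding to draw.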
5 changes: 5 additions & 0 deletions scanpy/plotting/_docs.py
@@ -50,6 +50,11 @@
groups
Restrict to a few categories in categorical observation annotation.
The default is not to restrict to any groups.
+ dimensions
+ 0-indexed dimensions of the embedding to plot as integers. E.g. [(0, 1), (1, 2)].
+ Unlike `components`, this argument is used in the same way as `colors`, e.g. is
+ used to specify a single plot at a time. Will eventually replace the components
+ argument.
components
For instance, `['1,2', '2,3']`. To plot all available components use
`components='all'`.
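
The docstring above contrasts 1-based `components` strings such as `'1,2'` with 0-indexed `dimensions` tuples such as `(0, 1)`. The relationship between the two can be sketched as below. `components_to_dims_sketch` is a hypothetical helper for illustration only; it is not the `_components_to_dimensions` function referenced in this commit, and it does not handle `components='all'`.

```python
from typing import List, Sequence, Tuple, Union


def components_to_dims_sketch(
    components: Union[str, Sequence[str]]
) -> List[Tuple[int, ...]]:
    """Convert 1-based component strings to 0-indexed dimension tuples."""
    if isinstance(components, str):
        # A single string such as '1,2' describes a single plot.
        components = [components]
    # Split each spec on commas and shift from 1-based to 0-based indexing.
    return [tuple(int(c) - 1 for c in spec.split(",")) for spec in components]


dims = components_to_dims_sketch(["1,2", "2,3"])
single = components_to_dims_sketch("1,3")
```

Here `['1,2', '2,3']` maps to `[(0, 1), (1, 2)]`, the same value used as the example in the `dimensions` documentation.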
4 changes: 2 additions & 2 deletions scanpy/plotting/_tools/__init__.py
@@ -1484,7 +1484,7 @@ def embedding_density(
ax = embedding(
adata,
basis,
- components=components,
+ dimensions=np.array(components) - 1, # Saved with 1 based indexing
color=density_col_name,
color_map=color_map,
size=dot_sizes,
@@ -1515,7 +1515,7 @@ def embedding_density(
fig_or_ax = embedding(
adata,
basis,
- components=components,
+ dimensions=np.array(components) - 1, # Saved with 1 based indexing
color=density_col_name,
color_map=color_map,
size=dot_sizes,
11 changes: 7 additions & 4 deletions scanpy/plotting/_tools/paga.py
@@ -97,7 +97,7 @@ def paga_compare(
else:
basis = 'umap'

- from .scatterplots import embedding, _get_data_points
+ from .scatterplots import embedding, _get_basis, _components_to_dimensions

embedding(
adata,
@@ -123,9 +123,12 @@ def paga_compare(

if pos is None:
if color == adata.uns['paga']['groups']:
- coords = _get_data_points(
- adata, basis, projection="2d", components=components, scale_factor=None
- )[0][0]
+ # TODO: Use dimensions here
+ _basis = _get_basis(adata, basis)
+ dims = _components_to_dimensions(
+ components=components, dimensions=None, total_dims=_basis.shape[1]
+ )[0]
+ coords = _basis[:, dims]
pos = (
pd.DataFrame(coords, columns=["x", "y"], index=adata.obs_names)
.groupby(adata.obs[color], observed=True)