Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ dependencies = [
"scanpy>=1.11",
"scikit-learn>=1.6",
"scipy>=1.14",
"scvelo>=0.3",
"seaborn>=0.13",
"session-info2>=0.2",
"wrapt>=1.16",
Expand All @@ -69,6 +68,7 @@ optional-dependencies.moscot = [ "moscot" ]
optional-dependencies.petsc = [ "mpi4py", "petsc4py", "slepc4py" ]
optional-dependencies.plot = [ "adjusttext" ]
optional-dependencies.r = [ "rpy2>=3.3" ]
optional-dependencies.scvelo = [ "scvelo>=0.3" ]
urls."Bug Tracker" = "https://github.com/theislab/cellrank/issues"
urls.Documentation = "https://cellrank.readthedocs.io"
urls.Download = "https://cellrank.readthedocs.io/en/latest/installation.html"
Expand Down
2 changes: 1 addition & 1 deletion src/cellrank/estimators/mixins/_lineage_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def prepare_format(
axes = np.ravel([axes])

basis = kwargs.pop("basis", "umap")
# scvelo compat: "right" means "right margin" in scanpy
# "right" means "right margin" in scanpy
if kwargs.get("legend_loc") == "right":
kwargs["legend_loc"] = "right margin"
_i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def _plot_discrete(

same_plot = same_plot or len(names) == 1
kwargs.setdefault("legend_loc", "on data")
# scvelo compat: "right" means "right margin" in scanpy
# "right" means "right margin" in scanpy
if kwargs.get("legend_loc") == "right":
kwargs["legend_loc"] = "right margin"
kwargs.pop("color_map", None)
Expand Down Expand Up @@ -580,13 +580,12 @@ def _plot_continuous(
)
elif mode == PlotMode.EMBEDDING:
kwargs.setdefault("legend_loc", "on data")
# scvelo compat: "right" means "right margin" in scanpy
# "right" means "right margin" in scanpy
if kwargs.get("legend_loc") == "right":
kwargs["legend_loc"] = "right margin"

if same_plot:
if color:
# https://github.com/theislab/scvelo/issues/673
logger.warning("Ignoring `color` when `mode='embedding'` and `same_plot=True`")
title = [_title] if title is None else title
_plot_color_gradients(self.adata, _data, basis=basis, title=title,
Expand Down
2 changes: 1 addition & 1 deletion src/cellrank/kernels/_base_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def plot_random_walks(
%(parallel)s
%(plotting)s
kwargs
Keyword arguments for :func:`~scvelo.pl.scatter`.
Keyword arguments for :func:`~scanpy.pl.embedding`.

Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion src/cellrank/kernels/utils/_random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def plot(
# scVelo accepted color="none" to skip coloring; scanpy does not.
if kwargs.get("color") == "none":
kwargs.pop("color")
# scvelo compat: "right" means "right margin" in scanpy
# "right" means "right margin" in scanpy
if kwargs.get("legend_loc") == "right":
kwargs["legend_loc"] = "right margin"
sc.pl.embedding(self._adata, basis=basis, show=False, ax=ax, **kwargs)
Expand Down
13 changes: 12 additions & 1 deletion tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
import scanpy as sc
import scipy.sparse as sp
import scvelo as scv
from anndata import AnnData
from pandas.testing import assert_frame_equal, assert_series_equal
from PIL import Image
Expand All @@ -26,6 +25,15 @@ def _jax_not_installed() -> bool:
return True


def _scvelo_not_installed() -> bool:
try:
import scvelo # noqa

return False
except ImportError:
return True


def _rpy2_mgcv_not_installed() -> bool:
try:
import rpy2
Expand Down Expand Up @@ -362,6 +370,8 @@ def _create_dummy_adata(n_obs: int) -> AnnData:
-------
The created adata object.
"""
import scvelo as scv

np.random.seed(42) # noqa: NPY002
adata = scv.datasets.toy_data(n_obs=n_obs)
adata.obs_names_make_unique()
Expand Down Expand Up @@ -396,6 +406,7 @@ def _create_dummy_adata(n_obs: int) -> AnnData:

jax_not_installed_skip = pytest.mark.skipif(_jax_not_installed(), reason="JAX is not installed.")
gamr_skip = pytest.mark.skipif(_rpy2_mgcv_not_installed(), reason="Cannot import `rpy2` or R's `mgcv` package.")
scvelo_skip = pytest.mark.skipif(_scvelo_not_installed(), reason="scVelo is not installed.")

if __name__ == "__main__":
for size in [50, 100, 200]:
Expand Down
3 changes: 3 additions & 0 deletions tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
create_kernels,
jax_not_installed_skip,
random_transition_matrix,
scvelo_skip,
)

_rtol = 1e-6
Expand Down Expand Up @@ -950,6 +951,7 @@ def test_save_to_anndata(self, adata: AnnData, tmpdir):
assert vk.params == bdata.uns[f"{key}_params"]["params"]


@scvelo_skip
class TestComputeProjection:
def test_no_transition_matrix(self, adata: AnnData):
with pytest.raises(RuntimeError, match=r"Compute transition matrix first as"):
Expand Down Expand Up @@ -1476,6 +1478,7 @@ def test_precomputed_autodetection(self, adata: AnnData, backward: bool):
assert key in pk.params["origin"]
np.testing.assert_array_equal(mat, pk.transition_matrix)

@scvelo_skip
def test_projection_explicit_connectivities(self, adata: AnnData):
mat = random_transition_matrix(adata.n_obs)
pk = PrecomputedKernel(mat, adata=adata)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
create_model,
gamr_skip,
resize_images_to_same_sizes,
scvelo_skip,
)

setup()
Expand Down Expand Up @@ -2253,6 +2254,7 @@ def test_gpcca_fate_probs_cont_same_clusters(self, mc: GPCCA, fpath: str):
def test_gpcca_fate_probs_cont_not_same(self, mc: GPCCA, fpath: str):
mc.plot_fate_probabilities(color="clusters", same_plot=False, dpi=DPI, save=fpath)

@scvelo_skip
@compare(kind="gpcca")
def test_scvelo_transition_matrix_projection(self, mc: GPCCA, fpath: str):
mc.kernel.plot_projection(
Expand Down Expand Up @@ -3335,6 +3337,7 @@ def test_msc_legend_loc(self, mc: GPCCA, fpath: str):
mc.plot_macrostate_composition("clusters_enlarged", dpi=DPI, save=fpath, legend_loc="upper left out")


@scvelo_skip
class TestProjectionEmbedding:
@compare()
def test_scvelo_connectivity_kernel_emb_stream(self, adata: AnnData, fpath: str):
Expand Down
Loading