diff --git a/pyproject.toml b/pyproject.toml index 8b2d1a50..2c51baf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,6 @@ dependencies = [ "scanpy>=1.11", "scikit-learn>=1.6", "scipy>=1.14", - "scvelo>=0.3", "seaborn>=0.13", "session-info2>=0.2", "wrapt>=1.16", @@ -69,6 +68,7 @@ optional-dependencies.moscot = [ "moscot" ] optional-dependencies.petsc = [ "mpi4py", "petsc4py", "slepc4py" ] optional-dependencies.plot = [ "adjusttext" ] optional-dependencies.r = [ "rpy2>=3.3" ] +optional-dependencies.scvelo = [ "scvelo>=0.3" ] urls."Bug Tracker" = "https://github.com/theislab/cellrank/issues" urls.Documentation = "https://cellrank.readthedocs.io" urls.Download = "https://cellrank.readthedocs.io/en/latest/installation.html" diff --git a/src/cellrank/estimators/mixins/_lineage_drivers.py b/src/cellrank/estimators/mixins/_lineage_drivers.py index c063457c..e15c1b75 100644 --- a/src/cellrank/estimators/mixins/_lineage_drivers.py +++ b/src/cellrank/estimators/mixins/_lineage_drivers.py @@ -286,7 +286,7 @@ def prepare_format( axes = np.ravel([axes]) basis = kwargs.pop("basis", "umap") - # scvelo compat: "right" means "right margin" in scanpy + # "right" means "right margin" in scanpy if kwargs.get("legend_loc") == "right": kwargs["legend_loc"] = "right margin" _i = 0 diff --git a/src/cellrank/estimators/terminal_states/_term_states_estimator.py b/src/cellrank/estimators/terminal_states/_term_states_estimator.py index b63de011..99434309 100644 --- a/src/cellrank/estimators/terminal_states/_term_states_estimator.py +++ b/src/cellrank/estimators/terminal_states/_term_states_estimator.py @@ -443,7 +443,7 @@ def _plot_discrete( same_plot = same_plot or len(names) == 1 kwargs.setdefault("legend_loc", "on data") - # scvelo compat: "right" means "right margin" in scanpy + # "right" means "right margin" in scanpy if kwargs.get("legend_loc") == "right": kwargs["legend_loc"] = "right margin" kwargs.pop("color_map", None) @@ -580,13 +580,12 @@ def _plot_continuous( ) elif mode == PlotMode.EMBEDDING: kwargs.setdefault("legend_loc", "on data") - # scvelo compat: "right" means "right margin" in scanpy + # "right" means "right margin" in scanpy if kwargs.get("legend_loc") == "right": kwargs["legend_loc"] = "right margin" if same_plot: if color: - # https://github.com/theislab/scvelo/issues/673 logger.warning("Ignoring `color` when `mode='embedding'` and `same_plot=True`") title = [_title] if title is None else title _plot_color_gradients(self.adata, _data, basis=basis, title=title, diff --git a/src/cellrank/kernels/_base_kernel.py b/src/cellrank/kernels/_base_kernel.py index ec36fb84..e463c658 100644 --- a/src/cellrank/kernels/_base_kernel.py +++ b/src/cellrank/kernels/_base_kernel.py @@ -226,7 +226,7 @@ def plot_random_walks( %(parallel)s %(plotting)s kwargs - Keyword arguments for :func:`~scvelo.pl.scatter`. + Keyword arguments for :func:`~scanpy.pl.embedding`. Returns ------- diff --git a/src/cellrank/kernels/utils/_random_walk.py b/src/cellrank/kernels/utils/_random_walk.py index aaafd773..aab655c9 100644 --- a/src/cellrank/kernels/utils/_random_walk.py +++ b/src/cellrank/kernels/utils/_random_walk.py @@ -239,7 +239,7 @@ def plot( # scVelo accepted color="none" to skip coloring; scanpy does not. if kwargs.get("color") == "none": kwargs.pop("color") - # scvelo compat: "right" means "right margin" in scanpy + # "right" means "right margin" in scanpy if kwargs.get("legend_loc") == "right": kwargs["legend_loc"] = "right margin" sc.pl.embedding(self._adata, basis=basis, show=False, ax=ax, **kwargs) diff --git a/tests/_helpers.py b/tests/_helpers.py index e87907c3..680e78cc 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -6,7 +6,6 @@ import pytest import scanpy as sc import scipy.sparse as sp -import scvelo as scv from anndata import AnnData from pandas.testing import assert_frame_equal, assert_series_equal from PIL import Image @@ -26,6 +25,15 @@ def _jax_not_installed() -> bool: return True +def _scvelo_not_installed() -> bool: + try: + import scvelo # noqa + + return False + except ImportError: + return True + + def _rpy2_mgcv_not_installed() -> bool: try: import rpy2 @@ -362,6 +370,8 @@ def _create_dummy_adata(n_obs: int) -> AnnData: ------- The created adata object. """ + import scvelo as scv + np.random.seed(42) # noqa: NPY002 adata = scv.datasets.toy_data(n_obs=n_obs) adata.obs_names_make_unique() @@ -396,6 +406,7 @@ def _create_dummy_adata(n_obs: int) -> AnnData: jax_not_installed_skip = pytest.mark.skipif(_jax_not_installed(), reason="JAX is not installed.") gamr_skip = pytest.mark.skipif(_rpy2_mgcv_not_installed(), reason="Cannot import `rpy2` or R's `mgcv` package.") +scvelo_skip = pytest.mark.skipif(_scvelo_not_installed(), reason="scVelo is not installed.") if __name__ == "__main__": for size in [50, 100, 200]: diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 8aa4b581..d6edc566 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -41,6 +41,7 @@ create_kernels, jax_not_installed_skip, random_transition_matrix, + scvelo_skip, ) _rtol = 1e-6 @@ -950,6 +951,7 @@ def test_save_to_anndata(self, adata: AnnData, tmpdir): assert vk.params == bdata.uns[f"{key}_params"]["params"] +@scvelo_skip class TestComputeProjection: def test_no_transition_matrix(self, adata: AnnData): with pytest.raises(RuntimeError, match=r"Compute transition matrix first as"): @@ -1476,6 +1478,7 @@ def test_precomputed_autodetection(self, adata: AnnData, backward: bool): assert key in pk.params["origin"] np.testing.assert_array_equal(mat, pk.transition_matrix) + @scvelo_skip def test_projection_explicit_connectivities(self, adata: AnnData): mat = random_transition_matrix(adata.n_obs) pk = PrecomputedKernel(mat, adata=adata) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index c46b9043..f96c3716 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -24,6 +24,7 @@ create_model, gamr_skip, resize_images_to_same_sizes, + scvelo_skip, ) setup() @@ -2253,6 +2254,7 @@ def test_gpcca_fate_probs_cont_same_clusters(self, mc: GPCCA, fpath: str): def test_gpcca_fate_probs_cont_not_same(self, mc: GPCCA, fpath: str): mc.plot_fate_probabilities(color="clusters", same_plot=False, dpi=DPI, save=fpath) + @scvelo_skip @compare(kind="gpcca") def test_scvelo_transition_matrix_projection(self, mc: GPCCA, fpath: str): mc.kernel.plot_projection( @@ -3335,6 +3337,7 @@ def test_msc_legend_loc(self, mc: GPCCA, fpath: str): mc.plot_macrostate_composition("clusters_enlarged", dpi=DPI, save=fpath, legend_loc="upper left out") +@scvelo_skip class TestProjectionEmbedding: @compare() def test_scvelo_connectivity_kernel_emb_stream(self, adata: AnnData, fpath: str):