100 changes: 74 additions & 26 deletions src/rapids_singlecell/preprocessing/_scrublet/__init__.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import gc
import warnings
from typing import TYPE_CHECKING

import cupy as cp
@@ -12,6 +14,9 @@
from scanpy.get import _get_obs_rep

from rapids_singlecell import preprocessing as pp
from rapids_singlecell._compat import DaskArray
from rapids_singlecell.get import X_to_GPU
from rapids_singlecell.preprocessing._utils import _check_gpu_X

from . import pipeline
from .core import Scrublet
@@ -74,6 +79,7 @@ def scrublet(
preprocessing, simulate doublets with
:func:`~rapids_singlecell.pp.scrublet_simulate_doublets`, and run the core scrublet
function :func:`~rapids_singlecell.pp.scrublet` with ``adata_sim`` set.
Scrublet can also be run with a `dask array` if a batch key is provided. Please make sure that each batch fits into memory. In this case Scrublet does not return the full results: only the `doublet_score` and `predicted_doublet` columns are written, not `.uns['scrublet']`. `adata_sim` is not supported for `dask arrays`.

Parameters
----------
@@ -89,6 +95,7 @@ def scrublet(
:func:`~rapids_singlecell.pp.scrublet_simulate_doublets`, with same number of vars
as adata. This should have been built from adata_obs after
filtering genes and cells and selecting highly-variable genes.
Not supported for dask arrays.
batch_key
Optional :attr:`~anndata.AnnData.obs` column name discriminating between batches.
sim_doublet_ratio
@@ -180,11 +187,10 @@ def scrublet(

start = logg.info("Running Scrublet")

adata_obs = adata.copy()

def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
# With no adata_sim we assume the regular use case, starting with raw
# counts and simulating doublets
ad_obs.X = X_to_GPU(ad_obs.X)

if ad_sim is None:
pp.filter_genes(ad_obs, min_cells=3, verbose=False)
@@ -241,7 +247,8 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
verbose=verbose,
)

return {"obs": ad_obs.obs, "uns": ad_obs.uns["scrublet"]}
out = {"obs": ad_obs.obs, "uns": ad_obs.uns["scrublet"]}
return out

if batch_key is not None:
if batch_key not in adata.obs.keys():
@@ -252,33 +259,74 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
# Run Scrublet independently on batches and return just the
# scrublet-relevant parts of the objects to add to the input object

batches = np.unique(adata.obs[batch_key])
scrubbed = [
_run_scrublet(
adata_obs[adata_obs.obs[batch_key] == batch].copy(),
adata_sim,
if isinstance(adata.X, DaskArray):
# Define function to process each batch chunk
def _process_batch_chunk(X_chunk):
"""Process a single batch chunk through Scrublet."""
batch_adata = AnnData(X_chunk)
batch_results = _run_scrublet(batch_adata, None)
return np.array(
batch_results["obs"][["doublet_score", "predicted_doublet"]]
).astype(np.float64)

# Get batch information and sort data by batch
batch_codes = adata.obs[batch_key].astype("category").cat.codes
sort_indices = np.argsort(batch_codes)
X_sorted = adata.X[sort_indices]

# Calculate chunk sizes based on batch sizes
batch_sizes = np.bincount(batch_codes.iloc[sort_indices])
X_rechunked = X_sorted.rechunk((tuple(batch_sizes), adata.X.shape[1]))

# Process all batches in parallel using map_blocks
batch_results = X_rechunked.map_blocks(
_process_batch_chunk,
meta=np.array([], dtype=np.float64),
dtype=np.float64,
chunks=(X_rechunked.chunks[0], 2),
)
for batch in batches
]
scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed])

# Now reset the obs to get the scrublet scores

adata.obs = scrubbed_obs.loc[adata.obs_names.values]

# Save the .uns from each batch separately

adata.uns["scrublet"] = {}
adata.uns["scrublet"]["batches"] = dict(
zip(batches, [scrub["uns"] for scrub in scrubbed])
)

# Record that we've done batched analysis, so e.g. the plotting
# function knows what to do.

adata.uns["scrublet"]["batched_by"] = batch_key
# Convert results to DataFrame and restore original order
results_df = pd.DataFrame(
batch_results.compute(), columns=["doublet_score", "predicted_doublet"]
)
final_results = results_df.iloc[np.argsort(sort_indices)]

# Update the original AnnData object with results
adata.obs["doublet_score"] = final_results["doublet_score"].values
adata.obs["predicted_doublet"] = final_results[
"predicted_doublet"
].values.astype(bool)
adata.uns["scrublet"] = {"batched_by": batch_key}

else:
batches = np.unique(adata.obs[batch_key])
scrubbed = [
_run_scrublet(
adata[adata.obs[batch_key] == batch].copy(),
adata_sim,
)
for batch in batches
]

scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed])
# Now reset the obs to get the scrublet scores
adata.obs = scrubbed_obs.loc[adata.obs_names.values]

# Save the .uns from each batch separately

adata.uns["scrublet"] = {}
adata.uns["scrublet"]["batches"] = dict(
zip(batches, [scrub["uns"] for scrub in scrubbed])
)
adata.uns["scrublet"]["batched_by"] = batch_key

else:
adata_obs = adata.copy()
if isinstance(adata_obs.X, DaskArray):
raise ValueError(
"Dask arrays are not supported for Scrublet without a batch key. Please provide a batch key."
)
scrubbed = _run_scrublet(adata_obs, adata_sim)

# Copy outcomes to input object from our processed version
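
The batched Dask path above boils down to a sort/rechunk/map_blocks pattern: sort cells so each batch is contiguous, give each batch its own chunk, score every chunk independently, then restore the original row order. A minimal sketch of that pattern (mine, not part of the PR), assuming a NumPy-backed Dask array `X`, an integer `batch_codes` vector, and a hypothetical `score_fn` that maps one batch to an `(n_cells, 2)` array of scores:

import dask.array as da
import numpy as np


def per_batch_scores(X, batch_codes, score_fn):
    # Sort rows so each batch is contiguous, then make one dask chunk per batch.
    order = np.argsort(batch_codes)
    sizes = np.bincount(batch_codes[order])
    X_batched = X[order].rechunk((tuple(sizes), X.shape[1]))
    # Each block now holds exactly one batch; score_fn turns an (n_cells, n_genes)
    # block into an (n_cells, 2) array of [doublet_score, predicted_doublet].
    scores = X_batched.map_blocks(
        score_fn,
        dtype=np.float64,
        meta=np.empty((0, 2), dtype=np.float64),
        chunks=(X_batched.chunks[0], (2,)),
    )
    # Undo the sort so the rows line up with the original cell order again.
    return scores.compute()[np.argsort(order)]


# Hypothetical usage: per-cell mean expression repeated twice as a stand-in "score".
X = da.from_array(np.random.rand(6, 4), chunks=(3, 4))
codes = np.array([0, 1, 0, 1, 1, 0])
out = per_batch_scores(
    X, codes, lambda b: np.repeat(b.mean(axis=1, keepdims=True), 2, axis=1)
)
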
49 changes: 49 additions & 0 deletions tests/dask/test_dask_scrublet.py
@@ -0,0 +1,49 @@
from __future__ import annotations

import cupy as cp
import numpy as np
import pytest
from cupyx.scipy import sparse as cusparse
from scanpy.datasets import paul15, pbmc3k

import rapids_singlecell as rsc
from rapids_singlecell._testing import (
as_dense_cupy_dask_array,
as_sparse_cupy_dask_array,
)


@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
def test_dask_scrublet(data_kind):
if data_kind == "sparse":
adata_1 = pbmc3k()[200:400].copy()
adata_2 = pbmc3k()[200:400].copy()
adata_2.X = cusparse.csr_matrix(adata_2.X.astype(np.float64))
adata_1.X = as_sparse_cupy_dask_array(adata_1.X.astype(np.float64))
elif data_kind == "dense":
adata_1 = paul15()[200:400].copy()
adata_2 = paul15()[200:400].copy()
adata_2.X = cp.array(adata_2.X.astype(np.float64))
adata_1.X = as_dense_cupy_dask_array(adata_1.X.astype(np.float64))
else:
raise ValueError(f"Unknown data_kind {data_kind}")

batch = np.random.randint(0, 2, size=adata_1.shape[0])
adata_1.obs["batch"] = batch
adata_2.obs["batch"] = batch
rsc.pp.scrublet(adata_1, batch_key="batch", verbose=False)

# sort adata_2 to compare results
batch_codes = adata_2.obs["batch"].astype("category").cat.codes
order = np.argsort(batch_codes)
adata_2 = adata_2[order]

rsc.pp.scrublet(adata_2, batch_key="batch", verbose=False)
adata_2 = adata_2[np.argsort(order)]

np.testing.assert_allclose(
adata_1.obs["doublet_score"], adata_2.obs["doublet_score"]
)
np.testing.assert_array_equal(
adata_1.obs["predicted_doublet"], adata_2.obs["predicted_doublet"]
)
Member:
also check that both .uns["scrublet"]["batched_by"] and .uns["scrublet"]["batches"] match

Member Author:
At the moment I don't return .uns because of the way dask handles things. I can only return NDArrays
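
To make that constraint concrete, here is a small sketch (mine, not from the PR): the callable passed to `map_blocks` has to hand back an array matching the declared dtype and chunks, so dict-style results such as `.uns["scrublet"]` have no natural way to travel through the task graph; only per-cell numeric columns like the score and the prediction do.

import dask.array as da
import numpy as np

X = da.ones((4, 3), chunks=(2, 3))


def block_scores(block):
    # map_blocks expects an ndarray back; a dict of thresholds/statistics would
    # break graph construction, so only per-cell numeric columns are returned.
    return np.full((block.shape[0], 2), 0.5)


scores = X.map_blocks(block_scores, dtype=np.float64, chunks=((2, 2), (2,)))
print(scores.compute().shape)  # (4, 2)
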
