Source code for sclkme.tools._sketch

import warnings
from typing import Any, Dict, Literal, Optional, Union

import numpy as np
import scipy
from anndata import AnnData
from scanpy._utils import AnyRandom
from scanpy.tools._utils import _choose_representation
from sklearn.utils import check_random_state

from ._utils import random_feats


def sketch(
    adata: AnnData,
    n_sketch: int,
    use_rep: Optional[str] = None,
    n_pcs: Optional[int] = None,
    random_state: AnyRandom = 0,
    method: Literal["gs", "kernel_herding", "random"] = "gs",
    replace: bool = False,
    inplace: bool = False,
    sketch_kwds: Optional[Dict[str, Any]] = None,
    key_added: Optional[str] = None,
    copy: bool = False,
) -> Optional[AnnData]:
    """\
    Cell sketching :cite:`Hie19gs` :cite:`Vishal22`

    *sketch* is a function that can be used for summarizing the landscape of
    a large single-cell dataset by selecting a compact subset of cells.
    This function supports three different methods to perform cell sketching:
    geometric sketching :cite:`Hie19gs`, kernel herding :cite:`Vishal22`, and
    simple random sampling.

    .. note::
        More information and bug reports about the geometric sketching method
        `here <https://github.com/brianhie/geosketch>`__.

    Parameters
    ----------
    adata
        Annotated data matrix.
    n_sketch
        Number of subsampled cells.
    use_rep
        Use the indicated representation. `'X'` or any key for `.obsm` is valid.
        If `None`, the representation is chosen automatically:
        For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
        If 'X_pca' is not present, it's computed with default parameters.
    n_pcs
        Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`.
    random_state
        A numpy random seed.
    method
        Use 'gs' :cite:`Hie19gs`, 'kernel_herding' :cite:`Vishal22` or 'random'
        for summarizing the landscape of large single-cell datasets.
    replace
        Whether to sample with replacement (default=False). It is forced to
        `True` (with a warning) when `n_sketch >= adata.n_obs`.
    inplace
        If `True`, subsample the cells of `adata` itself instead of adding a
        boolean column to `.obs`.
    sketch_kwds
        A dict that holds other specific parameters used by the sketching
        methods. Only the kernel herding method reads it: the scaling
        parameter `gamma` (default 1) and the dimensionality of the random
        Fourier frequency features `D` (default 2000).
    key_added
        If not specified, the subsampled index in boolean format is stored in
        .obs['sketch'], and the sketching parameters are stored in
        .uns['sketch'].
        If specified, the subsampled index in boolean format is stored in
        .obs[key_added+'_sketch'], and the sketching parameters are stored in
        .uns[key_added+'_sketch'].
    copy
        Return a copy instead of writing to adata.

    Returns
    -------
    Depending on `copy`, updates or returns `adata` with the following:

    See `key_added` parameter description for the storage path of
    subsampled indices.

    Raises
    ------
    ValueError
        If `method` is not one of the supported methods, or if kernel herding
        is requested together with sampling with replacement.
    ImportError
        If `method == 'gs'` and the `geosketch` package is not installed.

    Example
    -------
    >>> import scanpy as sc
    >>> import sclkme

    Load annotated dataset:

    >>> adata = sc.datasets.pbmc3k_processed()

    Run the cell sketching using geometric sketching:

    >>> sclkme.tl.sketch(adata, n_sketch=128, use_rep="X_pca", method="gs", key_added = "gs")

    Run the cell sketching using kernel herding:

    >>> sclkme.tl.sketch(adata, n_sketch=128, use_rep="X_pca", method="kernel_herding", key_added = "kh")

    Visualize the sketched dataset:

    >>> adata_sketch_gs = adata[adata.obs['gs_sketch']]
    >>> sc.pl.umap(adata_sketch_gs, color="louvain", size=100)

    >>> adata_sketch_kh = adata[adata.obs['kh_sketch']]
    >>> sc.pl.umap(adata_sketch_kh, color="louvain", size=100)
    """
    # Reject unknown methods up front so a typo does not leave `adata`
    # half-annotated (uns entry written, no obs column).
    if method not in ("gs", "kernel_herding", "random"):
        raise ValueError(
            f"method: {method} is not supported, use `gs`, `kernel_herding` or `random` instead."
        )

    if method == "gs":
        # geosketch is an optional dependency, only needed for this method.
        try:
            from geosketch import gs
        except ImportError as err:
            raise ImportError(
                "Please install the package of geometric sketching:\n\t `pip install geosketch`."
            ) from err

    if sketch_kwds is None:
        sketch_kwds = {}

    adata = adata.copy() if copy else adata
    if adata.is_view:
        # Turn the view into an actual object so that it can be written to.
        adata._init_as_actual(adata.copy())

    key_added = "sketch" if key_added is None else key_added + "_sketch"

    # Record the parameters of this run next to the result.
    adata.uns[key_added] = {"params": {"n_sketch": n_sketch}, "method": method}
    params = adata.uns[key_added]["params"]
    if use_rep is not None:
        params["use_rep"] = use_rep
    if n_pcs is not None:
        params["n_pcs"] = n_pcs
    params["random_state"] = random_state

    random_state = check_random_state(random_state)

    if n_sketch >= adata.n_obs:
        replace = True
        warnings.warn(
            "n_sketch too large, adjusting to sampling with replacement.",
            stacklevel=2,
        )

    # find the representation for sketching
    X = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs, silent=True)

    # main routine
    if method == "gs":
        # geosketch takes an integer seed; derive it from the random state
        random_seed = random_state.get_state()[1][0]
        sketch_index = gs(X, n_sketch, seed=random_seed, replace=replace)
    elif method == "kernel_herding":
        if replace:
            raise ValueError(
                "`kernel herding` cannot perform sampling with replacement, "
                "set `replace=False` or decrease the size of `n_sketch`."
            )

        # The herding loop indexes rows of a dense array, so densify any
        # sparse format, not only CSR.
        if scipy.sparse.issparse(X):
            X = X.toarray()

        # User-supplied keywords are recorded verbatim; `gamma` and `D` are
        # always recorded, falling back to their defaults.
        params.update(sketch_kwds)
        params["gamma"] = sketch_kwds.get("gamma", 1)
        params["D"] = sketch_kwds.get("D", 2000)

        phi_X = random_feats(
            X,
            gamma=params["gamma"],
            D=params["D"],
            random_state=random_state,
        )
        sketch_index, _ = kernel_herding(X, phi_X, n_sketch)
    else:  # method == "random"
        sketch_index = random_state.choice(
            np.arange(adata.n_obs), size=n_sketch, replace=replace
        )

    # sort the indices
    sketch_index = np.sort(sketch_index)

    if inplace:
        # Rebinding the local name would leave the caller's object untouched
        # when `copy=False`; re-initialise the object itself with the subset.
        adata._init_as_actual(adata[sketch_index, :].copy())
    else:
        # Positional mask: stays correct even if obs names are not unique.
        sketch_mask = np.zeros(adata.n_obs, dtype=bool)
        sketch_mask[sketch_index] = True
        adata.obs[key_added] = sketch_mask

    return adata if copy else None
def kernel_herding(X: np.ndarray, phi: np.ndarray, num_samples: int):
    """\
    Select a subset of cells by kernel herding: https://arxiv.org/pdf/2207.00584.pdf

    Starting from the mean feature vector of `phi`, each iteration picks the
    row of `phi` with the largest inner product with the current weight
    vector, then updates the weights by adding the mean and subtracting the
    picked row. A row that is picked again is not re-added, so the returned
    indices are unique and listed in the order they were first selected.

    Parameters
    ----------
    X: np.ndarray
        array of input data (dimensions = cells x features)
    phi: np.ndarray
        random fourier features generated from the input data: `X`
    num_samples: int
        number of cells to subsample

    Returns
    -------
    indices: np.ndarray
        indices of sampled cells (dimensions = (`num_samples`, ))
    subsample: np.ndarray
        subsampled `num_samples` cells (dimensions = `num_samples` x features)

    Raises
    ------
    ValueError
        If `num_samples` is negative or larger than the number of rows of
        `phi`; the selection loop could never finish in the latter case.

    Notes
    -----
    `np.argmax` returns the first maximum on ties, so among identical rows of
    `phi` only the first can ever be selected. Requesting more samples than
    there are distinct rows may therefore not terminate.
    """
    X = np.asarray(X)
    phi = np.asarray(phi)
    n_rows = phi.shape[0]

    if num_samples < 0:
        raise ValueError(f"`num_samples` must be non-negative, got {num_samples}.")
    if num_samples > n_rows:
        # Only `n_rows` distinct indices exist, so the loop below would spin forever.
        raise ValueError(
            f"`num_samples` ({num_samples}) cannot exceed the number of "
            f"available cells ({n_rows})."
        )
    if num_samples == 0:
        return np.empty(0, dtype=np.intp), X[:0]

    w_0 = np.mean(phi, axis=0)
    w_t = w_0
    indices = []
    seen = set()

    while len(indices) < num_samples:
        new_ind = np.argmax(np.dot(phi, w_t))
        # The weights are updated on every iteration, including repeated picks.
        w_t = w_t + w_0 - phi[new_ind]
        if new_ind not in seen:
            seen.add(new_ind)
            indices.append(new_ind)

    indices = np.asarray(indices)
    subsample = X[indices]

    return indices, subsample