import anndata as ad
import scanpy as sc
import squidpy as sq
from squidpy._constants._constants import CoordType

import numpy as np
from scipy.sparse import spmatrix
from scipy.interpolate import griddata

import pandas as pd

from PIL import Image
from matplotlib import colors
from sklearn.cluster import KMeans

from typing import List, Union, Optional, Tuple, Dict
import numbers

from . import models as m
from . import utils as ut
from pathlib import Path

def get_landmark_distance(
    adata: "ad.AnnData",
    landmark_position_key: str = "curated_landmarks",
    landmark_distance_key: str = "landmark_distances",
    reference: Optional[Union["m.Reference", np.ndarray]] = None,
    **kwargs,
) -> None:
    """compute landmark distances

    For every observation the Euclidean distance to every landmark is
    computed and stored as an ``n_obs x n_landmarks`` array in
    ``adata.obsm[landmark_distance_key]``.

    :param adata: AnnData object where distance between landmarks and
     observations should be measured
    :type adata: ad.AnnData
    :param landmark_position_key: key of landmark coordinates in ``.uns``,
     defaults to "curated_landmarks"
    :type landmark_position_key: str
    :param landmark_distance_key: key to use for landmark distances in
     ``.obsm``, defaults to "landmark_distances"
    :type landmark_distance_key: str
    :param reference: provide reference if non-homogeneous distortions should
     be corrected for using TPS (thin plate splines). Observation and landmark
     coordinates are then warped with ``morphops.tps_warp`` onto the
     (name-matched) reference landmarks before distances are measured.
    :type reference: Optional[Union[m.Reference, pd.DataFrame, np.ndarray]]
    :raises NotImplementedError: if ``reference`` is of an unsupported type
    """
    assert "spatial" in adata.obsm, "no coordinates for the data"
    assert landmark_position_key in adata.uns, "landmarks not found in data"

    from scipy.spatial.distance import cdist

    # np.array(...) copies, so the stored coordinates are never modified
    obs_crd = np.array(adata.obsm["spatial"], dtype=float)

    lmk_crd = adata.uns[landmark_position_key]
    if isinstance(lmk_crd, pd.DataFrame):
        lmk_crd_names = list(lmk_crd.index)
        lmk_crd = lmk_crd.values
    else:
        lmk_crd_names = None
    lmk_crd = np.array(lmk_crd, dtype=float)

    if reference is not None:
        import morphops as mops

        if isinstance(reference, m.Reference):
            # detach (as in match_scales) in case the tensor tracks gradients
            ref_lmk_crd = reference.landmarks.detach().numpy()
            ref_lmk_crd_names = list(reference.lmk_to_pos.keys())
        elif isinstance(reference, pd.DataFrame):
            ref_lmk_crd = reference.values
            ref_lmk_crd_names = list(reference.index)
        elif isinstance(reference, np.ndarray):
            ref_lmk_crd = reference
            ref_lmk_crd_names = None
        else:
            raise NotImplementedError(
                "reference of type : {} is not supported".format(type(reference))
            )

        ref_lmk_crd, lmk_crd = ut.match_arrays_by_names(
            ref_lmk_crd,
            lmk_crd,
            ref_lmk_crd_names,
            lmk_crd_names,
        )
        obs_crd = mops.tps_warp(lmk_crd, ref_lmk_crd, obs_crd)
        lmk_crd = mops.tps_warp(lmk_crd, ref_lmk_crd, lmk_crd)

    # One vectorized call instead of a Python double loop; the number of
    # columns follows the (possibly name-matched) landmark array.
    adata.obsm[landmark_distance_key] = cdist(
        np.asarray(obs_crd, dtype=float), np.asarray(lmk_crd, dtype=float)
    )
def reference_to_grid(
    ref_img: Union["Image.Image", str],
    n_approx_points: int = 1e3,
    background_color: Union[str, Union[np.ndarray, tuple]] = "white",
    n_regions: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """convert image to grid of observations

    when creating a reference we will discretize the domain into fixed
    locations where feature values will be predicted.

    The image is resized to a width of 500 pixels, pixels belonging to the
    background are discarded and a regular grid is laid over the remaining
    area. Returned coordinates are rescaled back to the input image size.

    :param ref_img: PIL.Image or path of/to reference image
    :type ref_img: Union[Image.Image, str]
    :param n_approx_points: approximate number of points to include in the
     discretized grid. The number of grid points will be in the magnitude of
     the provided number, defaults to 1000.
    :type n_approx_points: int
    :param background_color: background color of reference image, all
     elements with this color will be excluded. For color images this can be
     a matplotlib color string or an array/tuple of RGB(A) values given either
     in [0, 1] or in [0, 255]. Grayscale images only support "white" and
     "black". Defaults to "white".
    :type background_color: Union[str, np.ndarray, tuple]
    :param n_regions: number of regions (indicated by different colors)
     contained in the reference.
    :type n_regions: int
    :returns: A tuple where the first element is an n_obs x 2 array
     representing the (x, y) coordinates (image column, image row) of each
     grid point. Second element is a n_obs integer vector where the i:th
     element indicates the region that the i:th observation belongs to;
     regions are numbered 0, 1, ... by decreasing number of grid points.
    :rtype: Tuple[np.ndarray,np.ndarray]
    :raises FileNotFoundError: if ``ref_img`` is a path that does not exist
    :raises ValueError: if the background color is not supported or the image
     contains no non-background pixels
    """
    if isinstance(ref_img, str):
        ref_img_pth = Path(ref_img)
        if ref_img_pth.exists():
            ref_img = Image.open(ref_img_pth)
        else:
            raise FileNotFoundError(
                f"The file {ref_img_pth} cannot be found."
                " Please enter a different image path."
            )

    # resize to a fixed width of 500 px (aspect ratio kept); w_ratio maps the
    # grid back to the input scale at the end
    w, h = ref_img.size
    new_w = 500
    w_ratio = new_w / w
    new_h = int(round(h * w_ratio))
    ref_img = ref_img if ref_img.mode == "L" else ref_img.convert("RGBA")
    img = np.asarray(ref_img.resize((new_w, new_h)))
    if img.max() > 1:
        img = img / 255

    if img.ndim == 3:
        if isinstance(background_color, str):
            background_color = np.asarray(colors.to_rgba(background_color))
        elif isinstance(background_color, (numbers.Number, tuple, list, np.ndarray)):
            background_color = np.atleast_1d(np.asarray(background_color, dtype=float))
            # bring 8-bit colors onto the same [0, 1] scale as the image
            if background_color.max() > 1:
                background_color = background_color / 255
        else:
            raise ValueError(f"Color format {background_color} not supported.")

        # cluster pixel colors into the regions plus one background cluster
        km = KMeans(n_clusters=n_regions + 1, random_state=1)
        nw, nh, nc = img.shape
        idx = km.fit_predict(img.reshape(nw * nh, nc))
        centers = km.cluster_centers_[:, 0:3]
        # background = cluster whose RGB center is closest to background_color
        bg_id = np.argmin(np.linalg.norm(centers - background_color[0:3], axis=1))

        labels = idx.reshape(nw, nh)
        is_bg = labels == bg_id
        img = np.where(is_bg, 0.0, 1.0)
        # region map: cluster id on reference pixels, -1 on background
        reg_img = np.where(is_bg, -1.0, labels.astype(float))
    elif img.ndim == 2:
        color_map = dict(
            black=0,
            white=1,
        )
        if not (isinstance(background_color, str) and background_color in color_map):
            raise ValueError(
                f"Grayscale images only support background colors {list(color_map)}."
            )
        # reference pixels are those that do NOT match the background color
        is_ref = img.round(0) != color_map[background_color]
        img = is_ref.astype(float)
        reg_img = np.where(is_ref, 1.0, -1.0)
    else:
        raise Exception("Wrong image format, must be grayscale or color")

    # fraction of the image covered by the reference
    f_ref = img.sum() / img.size
    if f_ref == 0:
        raise ValueError("No non-background pixels found in the reference image.")

    # choose grid dimensions so that ~n_approx_points fall inside the reference
    f_ratio = img.shape[1] / img.shape[0]
    n_points = n_approx_points / f_ref
    size_x = np.sqrt(n_points / f_ratio)
    size_y = size_x * f_ratio
    xx, yy = np.meshgrid(
        np.linspace(0, img.shape[0], int(round(size_x))),
        np.linspace(0, img.shape[1], int(round(size_y))),
    )
    crd = np.column_stack((xx.ravel(), yy.ravel()))

    # give every grid point the region label of its nearest pixel
    pix_xx, pix_yy = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]))
    pix_crd = np.column_stack((pix_xx.ravel(), pix_yy.ravel()))
    ww = griddata(pix_crd, reg_img.T.ravel(), (xx, yy), method="nearest").ravel()

    keep = ww >= 0
    crd = crd[keep] / w_ratio
    meta = ww[keep].round(0).astype(int)

    # relabel regions 0, 1, ... by decreasing number of grid points
    uni, mem = np.unique(meta, return_counts=True)
    srt = np.argsort(mem)[::-1]
    rordr = {old: new for new, old in enumerate(uni[srt])}
    meta = np.array([rordr[x] for x in meta], dtype=int)

    return crd[:, [1, 0]], meta
def match_scales(
    adata: ad.AnnData,
    reference: Union[np.ndarray, "m.Reference"],
) -> None:
    """match scale between observed and spatial domains

    Simple scaling with a single value based on the distances between
    landmarks. Up to 100 randomly chosen (name-matched) landmarks are passed
    to ``utils.average_distance_ratio``. The resulting ratio multiplies
    ``adata.obsm["spatial"]`` and ``adata.uns["curated_landmarks"]`` in place.
    Image scale factors stored in ``adata.uns["spatial"][<sample>]["scalefactors"]``
    (keys containing "scalef") are divided by the same ratio, so the images
    stay aligned with the rescaled coordinates. With more than 100 landmarks
    the subset is drawn with ``np.random.choice``, so results depend on the
    global NumPy random state.

    :param adata: AnnData object holding observed data
    :type adata: ad.AnnData
    :param reference: Reference to which observed data will be transferred
    :type reference: Union[np.ndarray, pd.DataFrame, "m.Reference"]
    :raises Exception: if ``adata.uns`` has no "curated_landmarks" entry
    :raises NotImplementedError: for unsupported landmark or reference types
    """
    n_lmk_thrs = 100

    if "curated_landmarks" not in adata.uns.keys():
        raise Exception("curated_landmarks key not found in the adata.uns slot")
    obs_lmk = adata.uns["curated_landmarks"]
    if hasattr(obs_lmk, "copy"):
        obs_lmk = obs_lmk.copy()

    if isinstance(obs_lmk, pd.DataFrame):
        obs_lmk_names = list(obs_lmk.index)
        obs_lmk = obs_lmk.values
    elif isinstance(obs_lmk, np.ndarray):
        obs_lmk_names = None
    else:
        raise NotImplementedError(
            "landmarks of type : {} is not supported".format(type(obs_lmk))
        )

    if isinstance(reference, m.Reference):
        ref_lmk = reference.landmarks.detach().numpy()
        ref_lmk_names = list(reference.lmk_to_pos.keys())
    elif isinstance(reference, pd.DataFrame):
        ref_lmk = reference.values
        ref_lmk_names = list(reference.index)
    elif isinstance(reference, np.ndarray):
        ref_lmk = reference
        ref_lmk_names = None
    else:
        raise NotImplementedError(
            "reference of type : {} is not supported".format(type(reference))
        )

    ref_lmk, obs_lmk = ut.match_arrays_by_names(
        ref_lmk,
        obs_lmk,
        ref_lmk_names,
        obs_lmk_names,
    )

    n_lmk = len(ref_lmk)
    n_use_lmk = min(n_lmk, n_lmk_thrs)
    lmk_idx = np.random.choice(n_lmk, replace=False, size=n_use_lmk)
    av_ratio = ut.average_distance_ratio(ref_lmk, obs_lmk, lmk_idx)

    adata.obsm["spatial"] = adata.obsm["spatial"] * av_ratio
    adata.uns["curated_landmarks"] = adata.uns["curated_landmarks"] * av_ratio

    # Best effort: data without image metadata has nothing further to update.
    # The scale factors live in the per-sample "scalefactors" dict (the old
    # code scanned the sample dict itself and so only matched the key
    # "scalefactors"); every sample shares the rescaled coordinates.
    spatial_uns = adata.uns.get("spatial", {})
    for sample_info in spatial_uns.values():
        scalefactors = sample_info.get("scalefactors", {})
        for name in [key for key in scalefactors if "scalef" in key]:
            scalefactors[name] = scalefactors[name] / av_ratio
def join_adatas(
    adatas: List[ad.AnnData],
    **kwargs,
) -> ad.AnnData:
    """join together a set of AnnData objects

    Features (``.var`` index) are outer-joined; features missing from an
    object are filled with zeros for its observations. Observations are
    stacked in input order, and ``.obs`` gains a "split_id" column holding
    the position of the source object in ``adatas``. ``.obsm`` entries that
    are present in *every* object are concatenated; others are dropped.

    :param adatas: AnnData objects to be merged
    :type adatas: List[ad.AnnData]
    :returns: the joined AnnData object (dense ``X``)
    :rtype: ad.AnnData
    :raises ValueError: if ``adatas`` is empty
    """
    if len(adatas) == 0:
        raise ValueError("at least one AnnData object is required")

    features = pd.Index([])
    for a in adatas:
        features = features.union(a.var.index)

    # starts[k] is the first row of object k in the joint matrix
    starts = np.cumsum([0] + [a.shape[0] for a in adatas]).astype(int)
    n_obs = starts[-1]
    joint_matrix = pd.DataFrame(
        np.zeros((n_obs, len(features))),
        columns=features,
    )

    obs_frames = []
    for k, a in enumerate(adatas):
        inter_features = features.intersection(a.var.index)
        # .loc slicing on the default RangeIndex is end-inclusive
        joint_matrix.loc[starts[k] : (starts[k + 1] - 1), inter_features] = (
            a.to_df().loc[:, inter_features].values
        )
        tmp_obs = a.obs.copy()
        tmp_obs["split_id"] = k
        obs_frames.append(tmp_obs)
    # a single concat avoids quadratic copying and concatenating empty frames
    joint_obs = pd.concat(obs_frames)

    shared_obsm_keys = [
        key
        for key in adatas[0].obsm.keys()
        if all(key in a.obsm.keys() for a in adatas)
    ]
    joint_obsm = {
        key: np.concatenate([a.obsm[key] for a in adatas]) for key in shared_obsm_keys
    }

    var = pd.DataFrame(
        features.values,
        index=features,
        columns=["features"],
    )
    adata = ad.AnnData(
        joint_matrix,
        obs=joint_obs,
        var=var,
    )
    adata.obsm = joint_obsm
    return adata
def spatial_smoothing(
    adata: "ad.AnnData",
    distance_key: str = "spatial",
    n_neigh: int = 4,
    coord_type: Union[str, "CoordType"] = "generic",
    sigma: float = 50,
    **kwargs,
) -> None:
    """spatial smoothing function

    Every observation is replaced by a weighted average of itself and its
    spatial neighbors, with weights ``exp(-d / sigma)`` (the observation
    itself gets ``d = 0``), normalized to sum to one per observation. The
    result is stored in ``adata.layers["smoothed"]``; if ``adata.X`` is
    sparse, the layer has the same sparse type.

    :param adata: AnnData object holding data to be smoothed
    :type adata: ad.AnnData,
    :param distance_key: key of a neighbor distance graph in ``.obsp``. If it
     is not present, ``squidpy.gr.spatial_neighbors`` is run with
     ``key_added=distance_key`` and the graph stored under
     ``distance_key + "_distances"`` is used. Defaults to "spatial".
    :type distance_key: str
    :param n_neigh: number of neighbors to use for smoothing, defaults to 4
    :type n_neigh: int
    :param coord_type: type of coordinates, see squidpy documentation for more
     information, defaults to "generic".
    :type coord_type: Union[str, CoordType],
    :param sigma: sigma value to use in smoothing, higher values gives higher
     influence to far away points on a given grid point.
    :type sigma: float = 50,
    :param kwargs: ``spatial_key`` (key of coordinates in ``.obsm``, defaults
     to "spatial"); everything else is forwarded to
     ``squidpy.gr.spatial_neighbors``.
    :raises Exception: if the spatial key is not present in ``.obsm``
    """
    from scipy.sparse import issparse

    spatial_key = kwargs.pop("spatial_key", "spatial")

    if spatial_key not in adata.obsm.keys():
        raise Exception("Spatial key not present in AnnData object")

    # TODO: n_neigh -> n_neighs in new squidpy
    if distance_key not in adata.obsp.keys():
        sq.gr.spatial_neighbors(
            adata,
            spatial_key=spatial_key,
            coord_type=coord_type,
            n_neighs=n_neigh,
            key_added=distance_key,
            **kwargs,
        )
        distance_key = distance_key + "_distances"

    # row-wise access via indptr below requires CSR layout
    gr = adata.obsp[distance_key].tocsr()
    n_obs, n_features = adata.shape

    old_X = adata.X
    if issparse(old_X):
        sp_type = type(old_X)
        old_X = np.asarray(old_X.todense())
    else:
        sp_type = None
        old_X = np.asarray(old_X)

    new_X = np.zeros((n_obs, n_features))
    for obs in range(n_obs):
        row = slice(gr.indptr[obs], gr.indptr[obs + 1])
        # neighbors of obs, followed by obs itself at distance 0
        nbr_idx = np.append(gr.indices[row], obs)
        nbr_dist = np.append(gr.data[row], 0)
        ws = np.exp(-nbr_dist / sigma)
        ws /= ws.sum()
        new_X[obs, :] = ws @ old_X[nbr_idx, :]

    if sp_type is not None:
        new_X = sp_type(new_X)

    adata.layers["smoothed"] = new_X
def intersect_features(
    adatas: Union[List[ad.AnnData], Dict[str, ad.AnnData]],
) -> None:
    """restrict a set of AnnData objects to their shared features

    Every entry of ``adatas`` is replaced by its subset to the features
    (``.var`` index entries) that occur in *all* objects. Each object keeps
    its own feature order. Nothing is returned; ``adatas`` itself is updated.

    :param adatas: list or dict of AnnData objects
    :type adatas: Union[List[ad.AnnData], Dict[str, ad.AnnData]]
    """
    indexed = dict(enumerate(adatas)) if isinstance(adatas, list) else adatas

    shared = None
    for obj in indexed.values():
        names = set(obj.var.index.values)
        shared = names if shared is None else shared.intersection(names)

    for key, obj in indexed.items():
        mask = np.array([name in shared for name in obj.var.index.values])
        adatas[key] = adatas[key][:, mask]
def joint_highly_variable_genes(
    adatas: Union[List[ad.AnnData], Dict[str, ad.AnnData]],
    **kwargs,
) -> Union[List[ad.AnnData], Dict[str, ad.AnnData]]:
    """flag genes that are highly variable across a set of AnnData objects

    The objects are concatenated with ``ad.concat`` (inner join on genes,
    source recorded in an "origin" column). The concatenated object is
    log1p-transformed and passed to :func:`scanpy.pp.highly_variable_genes`.
    The resulting flags are written to ``.var["highly_variable"]`` of every
    input object.

    :param adatas: list or dict of AnnData objects; their ``.var`` is updated
    :type adatas: Union[List[ad.AnnData], Dict[str, ad.AnnData]]
    :param kwargs: forwarded to :func:`scanpy.pp.highly_variable_genes`
    :returns: the same ``adatas`` container that was passed in
    :rtype: Union[List[ad.AnnData], Dict[str, ad.AnnData]]
    """
    _adatas = dict(enumerate(adatas)) if isinstance(adatas, list) else adatas

    joint_adatas = ad.concat(_adatas, label="origin", join="inner")
    sc.pp.log1p(joint_adatas)
    sc.pp.highly_variable_genes(joint_adatas, **kwargs)

    # a set gives O(1) membership tests (an ndarray lookup scans every gene)
    hvg_genes = set(
        joint_adatas.var.index.values[joint_adatas.var["highly_variable"].values]
    )
    for key in _adatas.keys():
        adatas[key].var["highly_variable"] = adatas[key].var.index.isin(hvg_genes)

    return adatas
def default_normalization(
    adata: ad.AnnData,
    min_cells: float = 0.1,
    total_counts: float = 1e4,
    exclude_highly_expressed: bool = False,
    compute_highly_variable_genes: bool = False,
    n_top_genes: int = 2000,
) -> None:
    """default normalization recipe

    the normalization strategy that was applied for a majority of the analyses
    presented in the original manuscript. We abstain from calling it a
    recommended strategy, as the best strategy depends on your data. However,
    this strategy has worked well with several data types. The recipe is based
    on preprocessing functions from the :mod:`scanpy.pp` module and is given
    as follows (``adata`` is modified in place):

    .. code-block:: python

        sc.pp.filter_genes(adata, min_cells=min_cells)
        sc.pp.normalize_total(
            adata,
            target_sum=total_counts,
            exclude_highly_expressed=exclude_highly_expressed,
        )
        sc.pp.log1p(adata)
        # only if compute_highly_variable_genes is True
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes)
        sc.pp.scale(adata)

    :param adata: anndata object to normalize
    :type adata: ad.AnnData,
    :param min_cells: argument to :func:`scanpy.pp.filter_genes`. Values below
     1 are interpreted as a fraction of the number of observations, defaults
     to 0.1
    :type min_cells: float
    :param total_counts: ``target_sum`` argument to
     :func:`scanpy.pp.normalize_total`, default is `1e4`
    :type total_counts: float
    :param exclude_highly_expressed: argument to
     :func:`scanpy.pp.normalize_total`, default False
    :type exclude_highly_expressed: bool
    :param compute_highly_variable_genes: whether to run
     :func:`scanpy.pp.highly_variable_genes` after log-transformation,
     default False
    :type compute_highly_variable_genes: bool
    :param n_top_genes: ``n_top_genes`` argument to
     :func:`scanpy.pp.highly_variable_genes`, default 2000
    :type n_top_genes: int
    """
    if min_cells < 1:
        # fraction of observations -> absolute cell count
        min_cells = int(adata.shape[0] * min_cells)

    sc.pp.filter_genes(adata, min_cells=min_cells)
    # target_sum passed by keyword: it is keyword-only in recent scanpy
    sc.pp.normalize_total(
        adata,
        target_sum=total_counts,
        exclude_highly_expressed=exclude_highly_expressed,
    )
    sc.pp.log1p(adata)
    if compute_highly_variable_genes:
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes)
    sc.pp.scale(adata)