Source code for spaTrack.single_time.lap

import anndata as ad
from typing import Callable, Union
import networkx as nx
import numpy as np
import scipy.sparse as sp
from warnings import warn
from scipy import interpolate
from scipy.optimize import minimize
from scipy.interpolate import interp1d
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import scanpy as sc
from sklearn.neighbors import NearestNeighbors

# from .plot_least_action_path import *
from .utils import nearest_neighbors

# We calculate and visualize the least action path (LAP) along the trajectory using the corresponding functions implemented in Dynamo.
# Ref: Qiu X, Zhang Y, Martin-Rufino JD, Weng C, Hosseinzadeh S, Yang D, et al. Mapping transcriptomic vector fields of single cells. Cell. 2022 Feb 17;185(4):690-711.e45. doi: 10.1016/j.cell.2021.12.045. 
# dynamo: https://github.com/aristoteleo/dynamo-release.

def distance_point_to_segment(point, segment_start, segment_end):
    """
    Compute the shortest distance from a point to a line segment.

    Parameters
    ----------
    point
        Coordinates of the query point (1-D array).
    segment_start
        Coordinates of the first end of the segment (1-D array).
    segment_end
        Coordinates of the second end of the segment (1-D array).

    Returns
    -----------
    tuple
        ``(distance, intersection)`` where ``intersection`` is the point of the
        segment closest to ``point`` and ``distance`` is the Euclidean distance
        between the two.
    """
    segment_vector = segment_end - segment_start
    point_vector = point - segment_start
    squared_length = np.dot(segment_vector, segment_vector)

    # A zero-length segment is a single point; projecting onto it would
    # divide by zero and produce NaN.
    if squared_length == 0:
        return np.linalg.norm(point_vector), segment_start

    # Position of the orthogonal projection along the segment
    # (0 -> segment_start, 1 -> segment_end).
    projection = np.dot(point_vector, segment_vector) / squared_length

    if projection < 0:
        intersection = segment_start
    elif projection > 1:
        intersection = segment_end
    else:
        intersection = segment_start + projection * segment_vector

    # Measuring the distance to the closest point works in any dimension,
    # unlike a cross product, which is only defined for 2-D/3-D vectors.
    distance = np.linalg.norm(point - intersection)

    return distance, intersection


def map_cell_to_LAP(adata,basis='spatial',cell_neighbors=150):
    """
    Assign a new pseudotime value to each of these cells based on their position along the LAP.

    Cells are selected as the neighbors of the LAP points, projected onto the
    closest LAP segment, and given the arc length from the start of the LAP to
    that projection, divided by the largest such arc length.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object.
    basis
        The embedding data.
        (Default: 'spatial')
    cell_neighbors
        The number of cell neighbors.
        (Default: 150)

    Returns
    -----------
    tuple
        The ptime of cells map to the LAP and selected neighbor cells.
    """
    embedding_key = "X_" + basis
    lap_path = adata.uns["LAP_" + basis]["prediction"][0]

    # Candidate cells: the union of the neighbors of all LAP points.
    neighbor_idx = nearest_neighbors(
        lap_path, adata.obsm[embedding_key], n_neighbors=cell_neighbors
    )
    neighbor_idx = np.unique(neighbor_idx.flatten())

    segment_count = len(lap_path) - 1

    # Arc length measured from the LAP start to the END of each segment.
    arc_length_at_end = {}
    running_length = 0
    for seg in range(segment_count):
        running_length += np.linalg.norm(lap_path[seg] - lap_path[seg + 1])
        arc_length_at_end[seg] = running_length

    arc_positions = []
    for cell in neighbor_idx:
        cell_pos = adata.obsm[embedding_key][cell]

        # Start from the first LAP point and keep the strictly closest segment.
        best_distance = np.linalg.norm(cell_pos - lap_path[0])
        best_foot = lap_path[0]
        best_segment = 0
        for seg in range(segment_count):
            seg_distance, foot = distance_point_to_segment(
                cell_pos, lap_path[seg], lap_path[seg + 1]
            )
            if seg_distance < best_distance:
                best_distance, best_foot, best_segment = seg_distance, foot, seg

        # Length up to the end of the segment minus the part beyond the projection.
        remainder = np.linalg.norm(best_foot - lap_path[best_segment + 1])
        arc_positions.append(arc_length_at_end[best_segment] - remainder)

    LAP_ptime = arc_positions / max(arc_positions)

    return LAP_ptime, neighbor_idx
def log1p_(adata, X_data):
    """
    Apply ``log1p`` to ``X_data`` when ``adata.uns["pp"]["norm_method"]`` is ``None``.

    ``X_data`` is returned untouched when the ``"norm_method"`` key is missing
    or holds anything other than ``None``. Sparse input is transformed in
    place through its ``.data`` attribute.
    """
    if "norm_method" not in adata.uns["pp"]:
        return X_data

    if adata.uns["pp"]["norm_method"] is None:
        if sp.issparse(X_data):
            X_data.data = np.log1p(X_data.data)
        else:
            X_data = np.log1p(X_data)

    return X_data


def fetch_states(adata, init_states, init_cells, basis, layer, average, t_end):
    """
    Collect the initial states, the vector field dictionary and an end time.

    Parameters
    ----------
    adata
        Object exposing ``uns``, ``obsm`` and ``obs_names``.
    init_states
        Explicit initial states, or ``None`` to look them up from ``init_cells``.
    init_cells
        A cell name or a collection of cell names.
    basis
        Embedding name; ``"VecFld_" + basis`` is read from ``adata.uns`` and
        ``"X_" + basis`` from ``adata.obsm``. With ``None`` the key ``"VecFld"`` is used.
    layer
        Unused by this implementation.
    average
        When one of ``"origin"``, ``"trajectory"`` or ``True`` and several
        states are present, the states are averaged into a single row.
    t_end
        End time; estimated with :func:`getTend` when ``None``.

    Returns
    -----------
    tuple
        ``(init_states, VecFld, t_end, valid_genes)``.

    Raises
    ------
    Exception
        If both ``init_states`` and ``init_cells`` are ``None``.
    """
    vf_key = "VecFld" if basis is None else "VecFld_" + basis
    VecFld = adata.uns[vf_key]
    X = VecFld["X"]
    valid_genes = None

    if init_states is None and init_cells is None:
        raise Exception("Either init_state or init_cells should be provided.")
    elif init_states is None:
        # isinstance also recognizes str subclasses such as numpy.str_, which
        # would otherwise be iterated character by character.
        if isinstance(init_cells, str):
            init_cells = [init_cells]

        # Keep only the names present in adata, in the order they were given.
        intersect_cell_names = sorted(
            set(init_cells).intersection(adata.obs_names),
            key=lambda x: list(init_cells).index(x),
        )
        _cell_names = init_cells if len(intersect_cell_names) == 0 else intersect_cell_names

        if basis is not None:
            init_states = adata[_cell_names].obsm["X_" + basis].copy()
            if len(_cell_names) == 1:
                init_states = init_states.reshape((1, -1))
            VecFld = adata.uns["VecFld_" + basis]
            X = adata.obsm["X_" + basis]

            valid_genes = [basis + "_" + str(i) for i in np.arange(init_states.shape[1])]

    if init_states.shape[0] > 1 and average in ["origin", "trajectory", True]:
        init_states = init_states.mean(0).reshape((1, -1))

    if t_end is None:
        t_end = getTend(X, VecFld["V"])

    if sp.issparse(init_states):
        # .toarray() works for both sparse matrices and sparse arrays; the
        # short-hand `.A` attribute is deprecated in SciPy.
        init_states = init_states.toarray()

    return init_states, VecFld, t_end, valid_genes


def getTend(X, V):
    """
    Estimate an end time as the largest coordinate range of ``X`` divided by
    the 1st percentile of the non-zero absolute entries of ``V``.
    """
    xmin, xmax = X.min(0), X.max(0)
    V_abs = np.abs(V)
    t_end = np.max(xmax - xmin) / np.percentile(V_abs[V_abs > 0], 1)
    return t_end


def _nearest_neighbors(coord, coords, k=5):
    """Return the indices of the ``k`` rows of ``coords`` nearest to ``coord`` (ball tree search)."""
    nbrs = NearestNeighbors(n_neighbors=k, algorithm="ball_tree").fit(coords)
    _, neighs = nbrs.kneighbors(np.atleast_2d(coord))
    return neighs


def arclength_sampling_n(X, num, t=None):
    """
    Resample the polyline ``X`` at ``num`` points equally spaced in arc length.

    Returns
    -----------
    tuple
        ``(X_, total_length)`` or, when ``t`` is given, ``(X_, total_length, t_)``
        where ``t_`` is ``t`` linearly interpolated at the same positions.
    """
    arclen = np.cumsum(np.linalg.norm(np.diff(X, axis=0), axis=1))
    arclen = np.hstack((0, arclen))

    z = np.linspace(arclen[0], arclen[-1], num)
    X_ = interpolate.interp1d(arclen, X, axis=0)(z)
    if t is not None:
        t_ = interpolate.interp1d(arclen, t)(z)
        return X_, arclen[-1], t_
    else:
        return X_, arclen[-1]


def get_init_path(G, start, end, coords, interpolation_num=20):
    """
    Build an initial path from ``start`` to ``end`` along the graph ``G``.

    The graph nodes nearest to ``start`` and ``end`` are connected through
    ``nx.shortest_path``, the node coordinates are resampled to
    ``interpolation_num`` points by arc length, and ``start``/``end`` are added
    as the first and last rows (so the result has ``interpolation_num + 2`` rows).
    """
    source_ind = _nearest_neighbors(start, coords, k=1)[0][0]
    target_ind = _nearest_neighbors(end, coords, k=1)[0][0]

    path = nx.shortest_path(G, source_ind, target_ind)
    init_path = coords[path, :]

    init_path_final, _, _ = arclength_sampling_n(init_path, interpolation_num, t=np.arange(len(init_path)))

    # add the beginning and end point
    init_path_final = np.vstack((start, init_path_final, end))

    return init_path_final


def least_action_path(start, end, vf_func, jac_func, n_points=20, init_path=None, D=1, dt_0=1, EM_steps=2):
    """
    Compute a least action path between ``start`` and ``end``.

    Parameters
    ----------
    start, end
        End points of the path. Only used to build a straight initial path
        when ``init_path`` is ``None``.
    vf_func
        Vector field function evaluated on an array of points.
    jac_func
        Jacobian function of the vector field.
    n_points
        Number of segments of the straight initial path (``n_points + 1`` points).
    init_path
        Optional initial path; it is copied, not modified.
    D
        Constant dividing the action.
    dt_0
        Initial guess of the time step.
    EM_steps
        Number of alternating path / time-step optimization rounds.

    Returns
    -----------
    tuple
        ``(path, dt, action_opt)``.
    """
    if init_path is None:
        fractions = np.linspace(0, 1, n_points + 1, endpoint=True)
        path = np.asarray(start) + fractions[:, None] * (end - start)
    else:
        path = np.array(init_path, copy=True)

    # initial dt estimation:
    t_dict = minimize(lambda t: action(path, vf_func, D=D, dt=t), dt_0)
    dt = t_dict["x"][0]
    # Defined up front so the function still returns a value when EM_steps <= 0.
    action_opt = t_dict["fun"]

    while EM_steps > 0:
        EM_steps -= 1
        path, dt, action_opt = lap_T(path, dt * len(path), vf_func, jac_func, D=D)

    return path, dt, action_opt


def action_aux(path_flatten, vf_func, dim, start=None, end=None, **kwargs):
    """Evaluate :func:`action` on a flattened path (see :func:`reshape_path`)."""
    path = reshape_path(path_flatten, dim, start=start, end=end)
    return action(path, vf_func, **kwargs)


def action_grad_aux(path_flatten, vf_func, jac_func, dim, start=None, end=None, **kwargs):
    """Evaluate :func:`action_grad` on a flattened path and return the flattened gradient."""
    path = reshape_path(path_flatten, dim, start=start, end=end)
    return action_grad(path, vf_func, jac_func, **kwargs).flatten()


def reshape_path(path_flatten, dim, start=None, end=None):
    """
    Reshape a flat vector into rows of ``dim`` coordinates and optionally
    add ``start`` as the first row and ``end`` as the last row.
    """
    path = path_flatten.reshape(int(len(path_flatten) / dim), dim)
    if start is not None:
        path = np.vstack((start, path))
    if end is not None:
        path = np.vstack((path, end))
    return path


def action_grad(path, vf_func, jac_func, D=1, dt=1):
    """
    Gradient of the action with respect to the interior points of ``path``.

    ``jac_func`` is evaluated at the segment centers and indexed as
    ``J[:, :, s]`` for segment ``s``.
    """
    x = (path[:-1] + path[1:]) * 0.5
    v = np.diff(path, axis=0) / dt

    dv = v - vf_func(x)
    J = jac_func(x)
    z = np.zeros(dv.shape)
    for s in range(dv.shape[0]):
        z[s] = dv[s] @ J[:, :, s]
    grad = (dv[:-1] - dv[1:]) / D - dt / (2 * D) * (z[:-1] + z[1:])
    return grad


def lap_T(path_0, T, vf_func, jac_func, D=1):
    """
    Optimize the interior points of ``path_0`` for a fixed total time ``T``,
    then optimize the time step for the resulting path.

    Returns
    -----------
    tuple
        ``(path_sol, dt_sol, action_opt)``.
    """
    n = len(path_0)
    dt = T / (n - 1)
    dim = len(path_0[0])

    def fun(x):
        return action_aux(x, vf_func, dim, start=path_0[0], end=path_0[-1], D=D, dt=dt)

    def jac(x):
        return action_grad_aux(x, vf_func, jac_func, dim, start=path_0[0], end=path_0[-1], D=D, dt=dt)

    # `minimize` requires a 1-D starting point (recent SciPy versions raise a
    # ValueError otherwise), so the interior points are flattened explicitly.
    x0 = np.asarray(path_0[1:-1]).ravel()
    sol_dict = minimize(fun, x0, jac=jac)
    path_sol = reshape_path(sol_dict["x"], dim, start=path_0[0], end=path_0[-1])

    # further optimization by varying dt
    t_dict = minimize(lambda t: action(path_sol, vf_func, D=D, dt=t), dt)
    action_opt = t_dict["fun"]
    dt_sol = t_dict["x"][0]

    return path_sol, dt_sol, action_opt


def action(path, vf_func, D=1, dt=1):
    """
    Action of ``path``: ``0.5 * sum(|v - vf_func(x)|^2) * dt / D`` where ``x``
    are the segment centers and ``v`` the finite-difference velocities.
    """
    # centers
    x = (path[:-1] + path[1:]) * 0.5
    v = np.diff(path, axis=0) / dt

    s = (v - vf_func(x)).flatten()
    s = 0.5 * s.dot(s) * dt / D

    return s


def minimize_lap_time(path_0, t0, t_min, vf_func, jac_func, D=1, num_t=20, elbow_method="hessian", hes_tol=3):
    """
    Run :func:`lap_T` for ``num_t`` total times between ``t_min`` and ``t0``
    and locate the elbow of the action-vs-time curve.

    Returns
    -----------
    tuple
        ``(i_elbow, laps, A, opt_T)``: elbow index (or ``None``), the optimized
        paths, their actions and their optimized total times.
    """
    T = np.linspace(t_min, t0, num_t)
    A = np.zeros(num_t)
    opt_T = np.zeros(num_t)
    laps = []

    for i, t in enumerate(T):
        path, dt, action_opt = lap_T(path_0, t, vf_func, jac_func, D=D)
        A[i] = action_opt
        opt_T[i] = dt * (len(path_0) - 1)
        laps.append(path)

    i_elbow = find_elbow(opt_T, A, method=elbow_method, order=-1, tol=hes_tol)

    return i_elbow, laps, A, opt_T


def normalize(x):
    """Rescale ``x`` linearly so that its minimum maps to 0 and its maximum to 1."""
    x_min = np.min(x)
    return (x - x_min) / (np.max(x) - x_min)


def interp_second_derivative(t, f, num=5e2, interp_kind="cubic", **interp_kwargs):
    """
    interpolate f(t) and calculate the discrete second derivative using:
        d^2 f / dt^2 = (f(x+h1) - 2f(x) + f(x-h2)) / (h1 * h2)

    Returns the interior grid points and the second derivative at those points.
    """
    t_ = np.linspace(t[0], t[-1], int(num))
    f_ = interpolate.interp1d(t, f, kind=interp_kind, **interp_kwargs)(t_)

    dt = np.diff(t_)
    df = np.diff(f_)
    t_ = t_[1:-1]

    # Same arithmetic as looping over interior points, done in one array operation.
    d2fdt2 = np.diff(df) / (dt[1:] * dt[:-1])

    return t_, d2fdt2


def interp_curvature(t, f, num=5e2, interp_kind="cubic", **interp_kwargs):
    """
    Interpolate f(t) and return the interior grid points together with the
    curvature ``f'' / (1 + f'^2)^1.5`` computed from finite differences.
    """
    t_ = np.linspace(t[0], t[-1], int(num))
    f_ = interpolate.interp1d(t, f, kind=interp_kind, **interp_kwargs)(t_)

    dt = np.diff(t_)
    df = np.diff(f_)
    dfdt_ = df / dt

    t_ = t_[1:-1]

    # First derivative at an interior point: mean of the two adjacent slopes.
    dfdt = (dfdt_[:-1] + dfdt_[1:]) / 2
    d2fdt2 = np.diff(df) / (dt[1:] * dt[:-1])

    cur = d2fdt2 / (1 + dfdt * dfdt) ** 1.5

    return t_, cur


def kneedle_difference(t, f, type="decrease"):
    """
    Absolute difference between the normalized curve and the diagonal of the
    unit square (``1 - x`` for ``"decrease"``, ``x`` for ``"increase"``).

    Raises
    ------
    NotImplementedError
        If ``type`` is neither ``"decrease"`` nor ``"increase"``.
    """
    if type == "decrease":
        diagonal = 1 - normalize(t)
    elif type == "increase":
        diagonal = normalize(t)
    else:
        raise NotImplementedError(f"Unsupported function type {type}")

    return np.abs(normalize(f) - diagonal)


def find_elbow(T, F, method="kneedle", order=1, **kwargs):
    """
    Find the index in ``T`` of the elbow of the curve ``F(T)``.

    Parameters
    ----------
    T, F
        Abscissa and ordinate of the curve.
    method
        ``"hessian"`` (first interpolated point, scanned in direction ``order``,
        whose second derivative exceeds ``tol``), ``"curvature"`` (largest
        curvature) or ``"kneedle"`` (largest distance to the diagonal).
    order
        ``1`` scans forward, ``-1`` backward; with ``"kneedle"``, ``-1`` selects
        a decreasing curve and anything else an increasing curve.
    **kwargs
        ``tol`` (default 2) for ``"hessian"``; the remaining entries are passed
        to the interpolation helper.

    Returns
    -----------
    The elbow index, or ``None`` when the ``"hessian"`` method finds none.

    Raises
    ------
    NotImplementedError
        If ``method`` is not supported.
    """
    i_elbow = None
    if method == "hessian":
        T_ = normalize(T)
        F_ = normalize(F)
        tol = kwargs.pop("tol", 2)
        t_, der = interp_second_derivative(T_, F_, **kwargs)

        for t, d in zip(t_[::order], der[::order]):
            if d > tol:
                i_elbow = np.argmin(np.abs(T_ - t))
                break
        else:
            warn("The elbow was not found.")

    elif method == "curvature":
        T_ = normalize(T)
        F_ = normalize(F)
        t_, cur = interp_curvature(T_, F_, **kwargs)

        # Map the interpolated location back to an index of T, like the
        # "hessian" branch does; argmax(cur) alone indexes the interpolation grid.
        i_elbow = np.argmin(np.abs(T_ - t_[np.argmax(cur)]))
    elif method == "kneedle":
        curve_type = "decrease" if order == -1 else "increase"
        res = kneedle_difference(T, F, type=curve_type)
        i_elbow = np.argmax(res)
    else:
        raise NotImplementedError(f"The method {method} is not supported.")

    return i_elbow
def least_action(
    adata: ad.AnnData,
    init_cells: Union[str, list],
    target_cells: Union[str, list],
    basis: str = "umap",
    vf_key: str = "VecFld",
    vecfld: Union[None, Callable] = None,
    adj_key: str = "pearson_transition_matrix",
    n_points: int = 25,
    n_neighbors: int =100,
    **kwargs,
):
    """
    Calculate the optimal paths between any two cell states.

    The results are also stored in ``adata.uns["LAP_" + basis]`` (keys
    ``"init_states"``, ``"init_cells"``, ``"t"``, ``"mftp"``, ``"prediction"``,
    ``"action"``, ``"exprs"`` and ``"vf_key"``), and ``sc.pp.neighbors`` is run
    on ``adata`` with ``key_added='X_' + basis``.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object.
    init_cells
        Cell name or indices of the initial cell states.
    target_cells
        Cell name or indices of the terminal cell states.
    basis
        The embedding data used to predict the least action path.
        (Default: "umap")
    vf_key
        A key to the vector field functions in adata.uns.
        (Default: "VecFld")
    vecfld
        The vector field function.
        (Default: None)
    adj_key
        The key to the adjacency matrix in adata.obsp.
        (Default: "pearson_transition_matrix")
    n_points
        The number of points on the least action path.
        (Default: 25)
    n_neighbors
        The number of neighbors.
        (Default: 100)
    **kwargs
        Forwarded to :func:`least_action_path` (e.g. ``dt_0``, ``EM_steps``).

    Returns
    -----------
    LeastActionPath
        A trajectory class containing the least action paths information.
        A list of them is returned when several pairs of states are processed.
    """
    # Fixed settings of this implementation (not exposed as parameters).
    init_states, target_states = None, None
    paired = True
    min_lap_t = False
    elbow_method = "hessian"
    num_t = 20
    init_paths = None
    D = 10
    PCs = None
    expr_func: callable = np.expm1
    add_key = None

    sc.pp.neighbors(adata, use_rep='X_' + basis, key_added='X_' + basis, n_neighbors=n_neighbors)

    if vecfld is None:
        # NOTE(review): SvcVectorField is not imported in the visible part of
        # this module -- confirm it is in scope, otherwise this raises NameError.
        vf = SvcVectorField()
        vf.from_adata(adata, basis=basis, vf_key=vf_key)
    else:
        vf = vecfld

    coords = adata.obsm["X_" + basis]

    T = adata.obsp[adj_key]
    G = nx.from_scipy_sparse_array(T)

    init_states, _, _, _ = fetch_states(
        adata,
        init_states,
        init_cells,
        basis,
        "X",
        False,
        None,
    )
    target_states, _, _, _ = fetch_states(
        adata,
        target_states,
        target_cells,
        basis,
        "X",
        False,
        None,
    )

    init_states = np.atleast_2d(init_states)
    target_states = np.atleast_2d(target_states)

    if paired:
        if init_states.shape[0] != target_states.shape[0]:
            warn("The numbers of initial and target states are not equal. The longer one is trimmed")
            num = min(init_states.shape[0], target_states.shape[0])
            init_states = init_states[:num]
            target_states = target_states[:num]
        pairs = [(init_states[i], target_states[i]) for i in range(init_states.shape[0])]
    else:
        pairs = [(pi, pt) for pi in init_states for pt in target_states]
        warn(
            f"A total of {len(pairs)} pairs of initial and target states will be calculated."
            "To reduce the number of LAP calculations, please use the `paired` mode."
        )

    # `actions` (not `action`) so the module-level `action` function is not shadowed.
    t, prediction, actions, exprs, mftp, trajectory = [], [], [], [], [], []
    if min_lap_t:
        i_elbow = []
        laps = []
        opt_T = []
        A = []

    for path_ind, (init_state, target_state) in enumerate(pairs):
        if init_paths is None:
            init_path = get_init_path(G, init_state, target_state, coords, interpolation_num=n_points)
        else:
            init_path = init_paths if isinstance(init_paths, np.ndarray) else init_paths[path_ind]

        path_sol, dt_sol, action_opt = least_action_path(
            init_state, target_state, vf.func, vf.get_Jacobian(), n_points=n_points, init_path=init_path, D=D, **kwargs
        )

        # The actual #points due to arclength resampling. Kept in its own
        # variable: overwriting `n_points` would make every following pair
        # start from a longer initial path than the previous one.
        n_path_points = len(path_sol)

        if min_lap_t:
            t_sol = dt_sol * (n_path_points - 1)
            t_min = 0.3 * t_sol
            i_elbow_, laps_, A_, opt_T_ = minimize_lap_time(
                path_sol, t_sol, t_min, vf.func, vf.get_Jacobian(), D=D, num_t=num_t, elbow_method=elbow_method
            )
            if i_elbow_ is None:
                i_elbow_ = 0
            path_sol = laps_[i_elbow_]
            dt_sol = opt_T_[i_elbow_] / (n_path_points - 1)

            i_elbow.append(i_elbow_)
            laps.append(laps_)
            A.append(A_)
            opt_T.append(opt_T_)

        traj = LeastActionPath(X=path_sol, vf_func=vf.func, D=D, dt=dt_sol)
        trajectory.append(traj)
        t.append(np.arange(path_sol.shape[0]) * dt_sol)
        prediction.append(path_sol)
        actions.append(traj.action())
        mftp.append(traj.mfpt())

        if basis == "pca":
            pc_keys = "PCs" if PCs is None else PCs
            if pc_keys not in adata.uns.keys():
                warn("Expressions along the trajectories cannot be retrieved, due to lack of `PCs` in .uns.")
            else:
                # A missing (or None) mean is treated as zero; forwarding None
                # would make the back-projection fail on `+ None`.
                pca_mean = adata.uns["pca_mean"] if "pca_mean" in adata.uns.keys() else None
                if pca_mean is None:
                    pca_mean = 0
                exprs.append(pca_to_expr(traj.X, adata.uns[pc_keys], pca_mean, func=expr_func))

    if add_key is None:
        LAP_key = "LAP" if basis is None else "LAP_" + basis
    else:
        LAP_key = add_key

    adata.uns[LAP_key] = {
        "init_states": init_states,
        "init_cells": init_cells,
        "t": t,
        "mftp": mftp,
        "prediction": prediction,
        "action": actions,
        # "genes": adata.var_names[adata.var.use_for_pca],
        "exprs": exprs,
        "vf_key": vf_key,
    }

    if min_lap_t:
        adata.uns[LAP_key]["min_t"] = {"A": A, "T": opt_T, "i_elbow": i_elbow, "paths": laps, "method": elbow_method}

    return trajectory[0] if len(trajectory) == 1 else trajectory
def pca_to_expr(X, PCs, mean=0, func=None):
    """
    Reverse project from PCA space back to the raw expression space.

    Parameters
    ----------
    X
        Coordinates in PCA space, one row per state.
    PCs
        Loading matrix whose number of columns equals the number of columns of ``X``.
    mean
        Offset added after the projection; ``None`` is treated as 0.
        (Default: 0)
    func
        Optional function applied to the projected values.
        (Default: None)

    Returns
    -----------
    The projected (and optionally transformed) expression values.

    Raises
    ------
    Exception
        If the second dimensions of ``PCs`` and ``X`` differ.
    """
    if PCs.shape[1] != X.shape[1]:
        raise Exception("PCs dim 1 (%d) does not match X dim 1 (%d)." % (PCs.shape[1], X.shape[1]))

    # Callers may pass None when no mean was stored; adding None would raise.
    if mean is None:
        mean = 0

    exprs = X @ PCs.T + mean
    if func is not None:
        exprs = func(exprs)
    return exprs


class Trajectory:
    def __init__(self, X, t=None) -> None:
        """
        Base class for handling trajectory interpolation, resampling, etc.
        """
        self.X = X
        self.t = t


class LeastActionPath(Trajectory):
    """Least action path together with the cumulative action along it."""

    def __init__(self, X, vf_func, D=1, dt=1) -> None:
        """
        Parameters
        ----------
        X
            Points of the path, one row per point.
        vf_func
            Vector field function evaluated on an array of points.
        D
            Constant dividing the action.
        dt
            Time step between consecutive points.
        """
        super().__init__(X, t=np.arange(X.shape[0]) * dt)
        self.func = vf_func
        self.D = D
        # _action[i] is the action of the path truncated after point i
        # (0 for the first point). Per-segment actions are computed once and
        # accumulated, so the vector field is not re-evaluated for every prefix.
        self._action = np.zeros(X.shape[0])
        if X.shape[0] > 1:
            centers = (X[:-1] + X[1:]) * 0.5
            residual = np.diff(X, axis=0) / dt - vf_func(centers)
            segment_action = 0.5 * np.sum(residual * residual, axis=1) * dt / D
            self._action[1:] = np.cumsum(segment_action)

    def get_t(self):
        """Return the time points of the path."""
        return self.t

    def get_dt(self):
        """Return the mean time step between consecutive points."""
        return np.mean(np.diff(self.t))

    def action(self, t=None, **interp_kwargs):
        """
        Return the cumulative action at every path point, or interpolated at
        the times ``t`` (``interp_kwargs`` are passed to ``interp1d``).
        """
        if t is None:
            return self._action
        else:
            return interp1d(self.t, self._action, **interp_kwargs)(t)

    def mfpt(self, action=None):
        """Eqn. 7 of Epigenetics as a first exit problem."""
        action = self._action if action is None else action
        return 1 / np.exp(-action)

    def optimize_dt(self):
        """Find the time step minimizing the action of the path, update ``self.t`` and return it."""
        dt_0 = self.get_dt()
        t_dict = minimize(lambda t: action(self.X, self.func, D=self.D, dt=t), dt_0)
        dt_sol = t_dict["x"][0]
        self.t = np.arange(self.X.shape[0]) * dt_sol
        return dt_sol
def plot_least_action_path(adata,basis='spatial',ax=None,
                           linewidth=3,
                           point_size=6,
                           linestyle='solid'
                           ):
    """
    Plot the LAP stored in ``adata.uns['LAP_' + basis]``.

    Every second point of the first predicted path is drawn, connected by a
    black line and colored by its action value.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object.
    basis
        The embedding data used to predict the least action path.
        (Default: 'spatial')
    ax
        Figure axes. A new figure and axes are created when None.
        (Default: None)
    linewidth
        Linewidth of the LAP.
        (Default: 3)
    point_size
        Point size of the LAP.
        (Default: 6)
    linestyle
        Linestyle of the LAP.
        (Default: 'solid')

    Returns
    -----------
    ax
        The axes containing the plot of the LAP.
    """
    # The documented default is None, so axes must be created on demand;
    # drawing on None would raise an AttributeError.
    if ax is None:
        _, ax = plt.subplots()

    lap_dict = adata.uns['LAP_' + basis]

    # Keep every second point of the first predicted path.
    id_array = np.arange(0, len(lap_dict['prediction'][0]), 2)
    lap_point_pos = lap_dict['prediction'][0][id_array]
    lap_value = lap_dict['action'][0][id_array]

    # Map the action values onto the 'hsv' colormap.
    minima = np.min(lap_value)
    maxima = np.max(lap_value)
    norm = matplotlib.colors.Normalize(vmin=minima, vmax=maxima, clip=True)
    mapper = cm.ScalarMappable(norm=norm, cmap=plt.get_cmap('hsv'))
    cols = [mapper.to_rgba(v) for v in lap_value]

    ax.plot(*lap_point_pos.T, c="k", linewidth=linewidth, linestyle=linestyle, zorder=3)
    ax.scatter(*lap_point_pos.T, c=cols, s=point_size, zorder=2)

    return ax