# Source code for spaTrack.single_time.vectorfield

import anndata as ad
from typing import Union
import sys
import numpy as np
import numpy.matlib
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
from scipy.spatial.distance import cdist
from scipy.linalg.blas import dgemm


def update_n_merge_dict(dict1, dict2):
    """Return a new dict that merges ``dict1`` and ``dict2``.

    Keys present in both inputs take their value from ``dict2``. Neither input
    is modified.
    """
    return {**dict1, **dict2}


def norm(X, V, T, fix_velocity=True):
    """Center and scale positions (and an optional grid) before fitting.

    The start points ``X`` and end points ``Y = X + V`` are each centered on
    their own mean and divided by their own root-mean-square distance to that
    mean. ``T`` (if given) is shifted by the midpoint of the two means and
    divided by the average of the two scales.

    Parameters
    ----------
    X : np.ndarray
        Start coordinates, one row per sample.
    V : np.ndarray
        Velocities, same shape as ``X``.
    T : np.ndarray or None
        Optional extra points (e.g. a grid) transformed alongside ``X``.
    fix_velocity : bool
        If True, ``V`` is returned unchanged; otherwise it is replaced by the
        difference of the normalized end and start points.

    Returns
    -------
    tuple
        ``(X_normalized, V_out, T_normalized_or_None, norm_dict)``, where
        ``norm_dict`` stores the means, scales and the ``fix_velocity`` flag.
    """
    end_pts = X + V
    n_start = X.shape[0]
    n_end = V.shape[0]

    start_mean = np.mean(X, 0)
    end_mean = np.mean(end_pts, 0)

    start_centered = X - start_mean[None, :]
    end_centered = end_pts - end_mean[None, :]

    start_scale = np.sqrt(np.sum(np.sum(start_centered**2, 1)) / n_start)
    end_scale = np.sqrt(np.sum(np.sum(end_centered**2, 1)) / n_end)

    X_scaled = start_centered / start_scale
    Y_scaled = end_centered / end_scale

    if T is None:
        T_scaled = None
    else:
        midpoint = 1 / 2 * (start_mean[None, :] + end_mean[None, :])
        T_scaled = (T - midpoint) / (1 / 2 * (start_scale + end_scale))

    if fix_velocity:
        V_out = V
    else:
        V_out = Y_scaled - X_scaled

    transform_info = {
        "xm": start_mean,
        "ym": end_mean,
        "xscale": start_scale,
        "yscale": end_scale,
        "fix_velocity": fix_velocity,
    }
    return X_scaled, V_out, T_scaled, transform_info


def sample_by_velocity(V, n, seed=19491001):
    """Draw ``n`` distinct row indices of ``V`` with probability proportional to speed.

    Parameters
    ----------
    V : np.ndarray
        Velocity matrix, one row per sample.
    n : int
        Number of indices to draw (without replacement).
    seed : int
        Seed for a private random generator; the global NumPy random state is
        left untouched.

    Returns
    -------
    np.ndarray
        Integer array of ``n`` sampled row indices.

    Notes
    -----
    If every row has zero speed the proportional probabilities are undefined,
    so sampling falls back to uniform.
    """
    # A private RandomState draws exactly what np.random.seed(seed) followed by
    # np.random.choice would, without clobbering the caller's global RNG.
    rng = np.random.RandomState(seed)
    speed = np.linalg.norm(V, axis=1)
    total = np.sum(speed)
    p = None if total == 0 else speed / total
    return rng.choice(np.arange(len(V)), size=n, p=p, replace=False)


def bandwidth_selector(X):
    """
    Estimate an empirical bandwidth for a Gaussian kernel from the rows of ``X``.

    For every point, ``max(2, int(0.2 * n))`` nearest neighbours are queried.
    The first neighbour column (normally the query point itself) is dropped;
    the mean of the remaining distances is divided by 1.5 and multiplied by
    ``sqrt(2)``.
    """
    n_obs, n_dim = X.shape
    k = max(2, int(0.2 * n_obs))

    if n_obs > 200000 and n_dim > 2:
        # Approximate nearest-neighbour search for very large inputs.
        from pynndescent import NNDescent

        index = NNDescent(
            X,
            metric="euclidean",
            n_neighbors=k,
            n_jobs=-1,
            random_state=19491001,
        )
        dist = index.query(X, k=k)[1]
    else:
        tree_kind = "kd_tree" if X.shape[1] <= 10 else "ball_tree"
        searcher = NearestNeighbors(n_neighbors=k, algorithm=tree_kind, n_jobs=-1)
        dist = searcher.fit(X).kneighbors(X)[0]

    mean_dist = np.mean(dist[:, 1:]) / 1.5
    return np.sqrt(2) * mean_dist


def con_K(x, y, beta, method="cdist", return_d=False):
    """Gaussian kernel matrix between the rows of ``x`` and ``y``.

    ``K[i, j] = exp(-beta * ||x_i - y_j||^2)``.

    Parameters
    ----------
    x : np.ndarray
        ``(n, d)`` query points.
    y : np.ndarray
        ``(m, d)`` reference (e.g. control) points.
    beta : float
        Kernel parameter multiplying the squared distance.
    method : str
        ``"cdist"`` uses :func:`scipy.spatial.distance.cdist`; any other value
        computes squared distances from the explicit difference tensor.
    return_d : bool
        If True, also return the difference tensor ``D`` of shape ``(n, d, m)``
        with ``D[i, k, j] = x[i, k] - y[j, k]`` (this forces the explicit path).

    Returns
    -------
    np.ndarray or tuple
        ``K``, or ``(K, D)`` when ``return_d`` is True. On the ``cdist`` path a
        single query row yields a 1-D ``(m,)`` kernel; on the explicit path all
        singleton axes are squeezed out.
    """
    if method == "cdist" and not return_d:
        sq_dist = cdist(x, y, "sqeuclidean")
        if len(sq_dist) == 1:
            sq_dist = sq_dist.flatten()
    else:
        # Broadcast (n, d, 1) against (1, d, m) rather than tiling both inputs
        # into full (n, d, m) copies; the resulting values are identical.
        D = x[:, :, None] - y.T[None, :, :]
        sq_dist = np.squeeze(np.sum(D**2, 1))

    K = np.exp(-beta * sq_dist)

    if return_d:
        return K, D
    return K


def con_K_div_cur_free(x, y, sigma=0.8, eta=0.5):
    """Build matrix-valued block kernels between the rows of ``x`` and ``y``.

    The result has shape ``(d * m, d * n)``: block ``(a, b)`` of size ``d x d``
    couples row ``a`` of ``x`` with row ``b`` of ``y``. With
    ``r = x_a - y_b``, ``g = exp(-|r|^2 / (2 sigma^2)) / sigma^2`` and
    ``R = r r^T / sigma^2``, the blocks computed below are::

        df_kernel block = (1 - eta) * g * (R + ((d - 1) - |r|^2 / sigma^2) * I)
        cf_kernel block = eta * g * (I - R)

    (presumably the divergence-free / curl-free kernels suggested by the
    function name -- verify against the reference before relying on it).

    Parameters
    ----------
    x : np.ndarray
        ``(m, d)`` query points.
    y : np.ndarray
        ``(n, d)`` reference (control) points.
    sigma : float
        Gaussian width.
    eta : float
        Weight of the ``cf_kernel`` part; ``df_kernel`` is weighted by
        ``1 - eta``.

    Returns
    -------
    tuple
        ``(G, df_kernel, cf_kernel)`` with ``G = df_kernel + cf_kernel``.

    Notes
    -----
    A tqdm progress bar is displayed while looping over dimension pairs.
    NOTE(review): the ``np.squeeze`` below turns the ``(m, n)`` distance matrix
    into 1-D when ``m == 1`` or ``n == 1``; for ``n == 1`` the subsequent
    ``np.kron`` shapes look inconsistent -- confirm before using a single
    control point.
    """
    m, d = x.shape
    n, d = y.shape
    sigma2 = sigma**2
    # Difference tensor (m, d, n): entry [a, k, b] = x[a, k] - y[b, k].
    G_tmp = np.matlib.tile(x[:, :, None], [1, 1, n]) - np.transpose(
        np.matlib.tile(y[:, :, None], [1, 1, m]), [2, 1, 0]
    )
    # Pairwise squared distances |x_a - y_b|^2, shape (m, n) before squeezing.
    G_tmp = np.squeeze(np.sum(G_tmp**2, 1))
    # Keep -|r|^2 / sigma^2 for the df_kernel diagonal term below.
    G_tmp3 = -G_tmp / sigma2
    G_tmp = -G_tmp / (2 * sigma2)
    # Scalar Gaussian factor g, then expanded so each scalar fills a d x d block.
    G_tmp = np.exp(G_tmp) / sigma2
    G_tmp = np.kron(G_tmp, np.ones((d, d)))

    # Rows of xminusy enumerate all (a, b) pairs; reshaping one column to
    # (m, n) with order="F" yields entry [a, b] = (x_a - y_b) in that coordinate.
    x_tmp = np.matlib.tile(x, [n, 1])
    y_tmp = np.matlib.tile(y, [1, m]).T
    y_tmp = y_tmp.reshape((d, m * n), order="F").T
    xminusy = x_tmp - y_tmp
    G_tmp2 = np.zeros((d * m, d * n))

    # Accumulate the outer products r r^T block-wise: for each coordinate pair
    # (i, j) with i <= j, place r_i * r_j at positions (i, j) and (j, i) of
    # every block.
    tmp4_ = np.zeros((d, d))
    for i in tqdm(range(d), desc="Iterating each dimension in con_K_div_cur_free:"):
        for j in np.arange(i, d):
            tmp1 = xminusy[:, i].reshape((m, n), order="F")
            tmp2 = xminusy[:, j].reshape((m, n), order="F")
            tmp3 = tmp1 * tmp2
            tmp4 = tmp4_.copy()
            tmp4[i, j] = 1
            tmp4[j, i] = 1
            G_tmp2 = G_tmp2 + np.kron(tmp3, tmp4)

    # R = r r^T / sigma^2 per block.
    G_tmp2 = G_tmp2 / sigma2
    # ((d - 1) - |r|^2 / sigma^2) * I per block.
    G_tmp3 = np.kron((G_tmp3 + d - 1), np.eye(d))
    # I - R per block.
    G_tmp4 = np.kron(np.ones((m, n)), np.eye(d)) - G_tmp2
    df_kernel, cf_kernel = (1 - eta) * G_tmp * (G_tmp2 + G_tmp3), eta * G_tmp * G_tmp4
    G = df_kernel + cf_kernel

    return G, df_kernel, cf_kernel


def get_P(Y, V, sigma2, gamma, a, div_cur_free_kernels=False):
    """E-step of SparseVFC: posterior inlier probabilities and energy term.

    Each sample's residual ``Y - V`` is modelled as Gaussian with variance
    ``sigma2`` (inlier, weight ``gamma``) versus a uniform outlier density
    controlled by ``a``.

    Parameters
    ----------
    Y : np.ndarray
        Observed velocities ``(N, D)``; or, when ``div_cur_free_kernels`` is
        True, an interleaved ``(2 * N, 1)`` column (2-D data only).
    V : np.ndarray
        Current field estimate, same layout as ``Y``.
    sigma2 : float
        Current residual variance.
    gamma : float
        Current inlier fraction.
    a : float
        Outlier uniform-density parameter.
    div_cur_free_kernels : bool
        Whether ``Y``/``V`` use the flattened layout described above.

    Returns
    -------
    tuple
        ``(P, E)``: ``P`` is an ``(N, 1)`` array of inlier probabilities and
        ``E`` the data term of the energy.
    """
    if div_cur_free_kernels:
        # Un-flatten the interleaved (x0, y0, x1, y1, ...) column into (N, 2).
        Y = Y.reshape((2, int(Y.shape[0] / 2)), order="F").T
        V = V.reshape((2, int(V.shape[0] / 2)), order="F").T

    D = Y.shape[1]
    sq_resid = np.sum((Y - V) ** 2, 1)
    temp1 = np.exp(-sq_resid / (2 * sigma2))
    temp2 = (2 * np.pi * sigma2) ** (D / 2) * (1 - gamma) / (gamma * a)
    # Lift underflowed Gaussian terms to the smallest positive one so those
    # samples don't get a hard-zero probability. If *every* term underflowed
    # there is nothing to substitute; leave them at zero rather than letting
    # np.min raise on an empty array.
    positive = temp1[temp1 != 0]
    if positive.size:
        temp1[temp1 == 0] = np.min(positive)
    P = temp1 / (temp1 + temp2)
    E = P.T.dot(sq_resid) / (2 * sigma2) + np.sum(P) * np.log(sigma2) * D / 2

    return (P[:, None], E) if P.ndim == 1 else (P, E)


def lstsq_solver(lhs, rhs, method="drouin"):
    """Solve the least-squares system ``lhs @ C ~= rhs`` and return ``C``.

    ``method`` is accepted for API compatibility but ignored; the solve is
    always delegated to :func:`linear_least_squares`.
    """
    return linear_least_squares(lhs, rhs)


def linear_least_squares(a, b, residuals=False):
    """Least-squares solution of ``a @ x ~= b`` via the normal equations.

    ``a.T @ a`` and ``a.T @ b`` are formed with BLAS ``dgemm`` and the square
    system is solved with :func:`numpy.linalg.solve`.

    Parameters
    ----------
    a : array_like
        Coefficient matrix.
    b : array_like
        Right-hand side.
    residuals : bool
        If True, also return the 2-norm of ``a @ x - b``.

    Returns
    -------
    np.ndarray or tuple
        ``x``, or ``(x, residual_norm)`` when ``residuals`` is True.
    """
    a = np.asarray(a, order="c")
    a_t = a.T
    gram = dgemm(alpha=1.0, a=a_t, b=a_t, trans_b=True)
    projected_rhs = dgemm(alpha=1.0, a=a_t, b=b)
    x = np.linalg.solve(gram, projected_rhs)

    if not residuals:
        return x
    return x, np.linalg.norm(np.dot(a, x) - b)


def SparseVFC(
    X: np.ndarray,
    Y: np.ndarray,
    Grid: np.ndarray,
    M: int = 100,
    a: float = 5,
    beta: float = None,
    ecr: float = 1e-5,
    gamma: float = 0.9,
    lambda_: float = 3,
    minP: float = 1e-5,
    MaxIter: int = 500,
    theta: float = 0.75,
    div_cur_free_kernels: bool = False,
    velocity_based_sampling: bool = True,
    sigma: float = 0.8,
    eta: float = 0.5,
    seed=0,
    lstsq_method: str = "drouin",
    verbose: int = 1,
) -> dict:
    """Learn a smooth vector field from noisy velocity samples (SparseVFC).

    The field is ``f(x) = sum_k K(x, c_k) C_k`` over control points ``c_k``.
    An EM loop alternates between inlier probabilities (:func:`get_P`) and a
    weighted, regularized least-squares solve for ``C``, updating ``sigma2``
    and ``gamma`` after each step.

    Parameters
    ----------
    X : np.ndarray
        ``(N, D)`` sample positions.
    Y : np.ndarray
        ``(N, D)`` observed velocities; rows with non-finite values are
        excluded from fitting.
    Grid : np.ndarray or None
        Optional positions at which the learned field is also evaluated.
    M : int
        Maximum number of control points (capped by the number of unique
        positions).
    a : float
        Outlier uniform-density parameter.
    beta : float, optional
        Gaussian kernel parameter; estimated with :func:`bandwidth_selector`
        from the control points when None.
    ecr : float
        Stop when the relative change of the energy drops to this value.
    gamma : float
        Initial inlier fraction (clipped to ``[0.05, 0.95]`` after each step).
    lambda_ : float
        Regularization weight.
    minP : float
        Lower bound applied to the inlier probabilities.
    MaxIter : int
        Maximum number of EM iterations.
    theta : float
        Probability above which a sample counts as an inlier.
    div_cur_free_kernels : bool
        Use :func:`con_K_div_cur_free` kernels instead of the Gaussian kernel.
        The flattening in :func:`get_P` and the ``Sp / 2`` normalization below
        assume 2-D data in this mode.
    velocity_based_sampling : bool
        Sample control points proportionally to speed instead of uniformly.
    sigma, eta : float
        Parameters of the div/curl-free kernels.
    seed : int
        Seed for control-point sampling.
    lstsq_method : str
        Forwarded to :func:`lstsq_solver`.
    verbose : int
        Currently unused.

    Returns
    -------
    dict
        Inputs, control points ``X_ctrl`` and their row indices in ``X``
        (``ctrl_idx``), ``beta``, coefficients ``C``, fitted velocities ``V``
        at the valid rows, inlier probabilities ``P``, ``VFCIndex``,
        ``sigma2``, grid predictions and the convergence trajectories.

    Raises
    ------
    Exception
        If no EM iteration runs because ``ecr`` is not below the initial
        ``tecr`` of 1 or ``sigma2`` is not above ``1e-8``.
    """
    X_ori, Y_ori = X.copy(), Y.copy()
    valid_ind = np.where(np.isfinite(Y.sum(1)))[0]
    X, Y = X[valid_ind], Y[valid_ind]
    N, D = Y.shape
    grid_U = None

    # Pick control points among the unique positions; `uid` maps each unique
    # row to its first occurrence in the (valid-filtered) X.
    tmp_X, uid = np.unique(X, axis=0, return_index=True)
    M = min(M, tmp_X.shape[0])
    if velocity_based_sampling:
        idx = sample_by_velocity(Y[uid], M, seed=seed)
    else:
        idx = np.random.RandomState(seed=seed).permutation(tmp_X.shape[0])[:M]
    ctrl_pts = tmp_X[idx, :]
    # `idx` indexes the unique rows; translate it into row indices of the
    # caller's unfiltered X so it can label the corresponding observations.
    ctrl_idx = valid_ind[uid[idx]]

    if beta is None:
        h = bandwidth_selector(ctrl_pts)
        beta = 1 / h**2

    # Kernel construction. The kernel choice and the size of C must follow the
    # same flag test, otherwise the shapes disagree.
    if div_cur_free_kernels:
        K = con_K_div_cur_free(ctrl_pts, ctrl_pts, sigma, eta)[0]
        U = con_K_div_cur_free(X, ctrl_pts, sigma, eta)[0]
        if Grid is not None:
            grid_U = con_K_div_cur_free(Grid, ctrl_pts, sigma, eta)[0]
        M = ctrl_pts.shape[0] * D
        X = X.flatten()[:, None]
        Y = Y.flatten()[:, None]
    else:
        K = con_K(ctrl_pts, ctrl_pts, beta)
        U = con_K(X, ctrl_pts, beta)
        if Grid is not None:
            grid_U = con_K(Grid, ctrl_pts, beta)
        M = ctrl_pts.shape[0]

    # Initialization
    if div_cur_free_kernels:
        V = X.copy()
        C = np.zeros((M, 1))
        sigma2 = np.sum((Y - X) ** 2) / (N * D)
    else:
        V = np.zeros((N, D))
        C = np.zeros((M, D))
        sigma2 = np.sum((Y - V) ** 2) / (N * D)
    sigma2 = 1e-7 if sigma2 < 1e-8 else sigma2

    i, tecr, E = 0, 1, 1
    tecr_vec = np.full(MaxIter, np.nan)
    E_vec = np.full(MaxIter, np.nan)
    P = None
    while i < MaxIter and tecr > ecr and sigma2 > 1e-8:
        # E-step
        E_old = E
        P, E = get_P(Y, V, sigma2, gamma, a, div_cur_free_kernels)

        E = E + lambda_ / 2 * np.trace(C.T.dot(K).dot(C))
        E_vec[i] = E
        tecr = abs((E - E_old) / E)
        tecr_vec[i] = tecr

        # M-step: weighted, regularized least squares for C.
        P = np.maximum(P, minP)
        if div_cur_free_kernels:
            # Repeat each per-sample probability for every coordinate of the
            # flattened layout.
            P = np.kron(P, np.ones((int(U.shape[0] / P.shape[0]), 1)))
        # Broadcasting P.T scales column n of U.T by P[n]; identical to tiling
        # P.T M times, without materializing the copy.
        UP = U.T * P.T
        lhs = UP.dot(U) + lambda_ * sigma2 * K
        rhs = UP.dot(Y)

        C = lstsq_solver(lhs, rhs, method=lstsq_method)

        # Update V and sigma**2
        V = U.dot(C)
        Sp = np.sum(P) / 2 if div_cur_free_kernels else np.sum(P)
        sigma2 = np.sum(P.T.dot(np.sum((Y - V) ** 2, 1))) / (Sp * D)

        # Update gamma: fraction of entries deemed inliers, clipped.
        numcorr = len(np.where(P > theta)[0])
        gamma = min(max(numcorr / X.shape[0], 0.05), 0.95)

        i += 1

    if i == 0 and not (tecr > ecr and sigma2 > 1e-8):
        raise Exception(
            "please check your input parameters, "
            f"tecr: {tecr}, ecr {ecr} and sigma2 {sigma2},"
            f"tecr must larger than ecr and sigma2 must larger than 1e-8"
        )

    grid_V = None
    if Grid is not None:
        grid_V = np.dot(grid_U, C)

    VecFld = {
        "X": X_ori,
        "valid_ind": valid_ind,
        "X_ctrl": ctrl_pts,
        "ctrl_idx": ctrl_idx,
        "Y": Y_ori,
        "beta": beta,
        "V": V.reshape((N, D)) if div_cur_free_kernels else V,
        "C": C,
        "P": P,
        # P stays None only when no EM iteration ran (MaxIter == 0).
        "VFCIndex": (
            np.where(P > theta)[0] if P is not None else np.array([], dtype=int)
        ),
        "sigma2": sigma2,
        "grid": Grid,
        "grid_V": grid_V,
        "iteration": i - 1,
        "tecr_traj": tecr_vec[:i],
        "E_traj": E_vec[:i],
    }

    return VecFld


def get_vf_dict(adata, basis="", vf_key="VecFld"):
    """Fetch a stored vector-field dictionary from ``adata.uns``.

    The lookup key is ``vf_key``, suffixed with ``"_<basis>"`` when ``basis``
    is not None and non-empty.

    Raises
    ------
    ValueError
        If the resulting key is not present in ``adata.uns``.
    """
    key = vf_key
    if basis is not None and len(basis) > 0:
        key = "%s_%s" % (vf_key, basis)

    uns = adata.uns
    if key in uns.keys():
        return uns[key]

    raise ValueError(
        f"Vector field function {key} is not included in the adata object! "
        f"Try firstly running dyn.vf.VectorField(adata, basis='{basis}')"
    )


def vecfld_from_adata(adata, basis="", vf_key="VecFld"):
    """Load a stored vector field and wrap it as a callable.

    Returns
    -------
    tuple
        ``(vf_dict, func)`` where ``func(x)`` evaluates
        :func:`vector_field_function` at ``x`` with the stored ``vf_dict``.

    Raises
    ------
    ValueError
        Propagated from :func:`get_vf_dict` when the key is missing.
    KeyError
        If the stored dict has no ``"method"`` entry.
    """
    vf_dict = get_vf_dict(adata, basis=basis, vf_key=vf_key)
    # Looked up but not used; a missing "method" entry raises here.
    _ = vf_dict["method"]

    def func(x):
        return vector_field_function(x, vf_dict)

    return vf_dict, func


class BaseVectorField:
    """Container for vector-field data and an optional fitted function.

    Parameters
    ----------
    X, V, Grid : optional
        Positions, velocities and evaluation grid, stored in ``self.data``.
    **kwargs
        ``vf_dict`` (default ``{}``), ``func`` (default None) and
        ``fixed_points`` (default None) are consumed here; any remaining
        keyword arguments are forwarded to ``super().__init__``.
    """

    def __init__(
        self,
        X=None,
        V=None,
        Grid=None,
        *args,
        **kwargs,
    ):
        self.data = {"X": X, "V": V, "Grid": Grid}
        stored_dict = kwargs.pop("vf_dict", {})
        callable_ = kwargs.pop("func", None)
        fixed = kwargs.pop("fixed_points", None)
        self.vf_dict, self.func, self.fixed_points = stored_dict, callable_, fixed
        super().__init__(**kwargs)

    def from_adata(self, adata, basis="", vf_key="VecFld"):
        """Populate this object from a vector field stored in ``adata.uns``.

        ``self.data["X"]`` receives the stored positions and ``self.data["V"]``
        the stored raw velocities (``vf_dict["Y"]``), not the fitted ones.
        """
        stored, evaluator = vecfld_from_adata(adata, basis=basis, vf_key=vf_key)
        self.data["X"] = stored["X"]
        self.data["V"] = stored["Y"]
        self.vf_dict = stored
        self.func = evaluator


class DifferentiableVectorField(BaseVectorField):
    """Vector field that can expose a Jacobian.

    ``get_Jacobian`` here is a placeholder that returns None; subclasses are
    expected to override it.
    """

    def get_Jacobian(self, method=None):
        """Placeholder Jacobian accessor; returns None in this base class."""
        return None


class SvcVectorField(DifferentiableVectorField):
    """Vector field learned with :func:`SparseVFC`.

    When both ``X`` and ``V`` are given, SparseVFC hyper-parameters are taken
    from ``kwargs`` (falling back to defaults) and stored in
    ``self.parameters``. ``M`` defaults to
    ``max(min(50, len(X)), int(0.05 * len(X)) + 1)``; a falsy ``M`` (e.g.
    None) also selects that default.
    """

    def __init__(self, X=None, V=None, Grid=None, *args, **kwargs):
        super().__init__(X, V, Grid)
        if X is not None and V is not None:
            # self.parameters *is* kwargs at this point; each pop below strips
            # a known key from it, so the merge yields the leftover keys
            # followed by the resolved hyper-parameters.
            self.parameters = kwargs
            resolved = {}
            resolved["M"] = kwargs.pop("M", None) or max(
                min([50, len(X)]), int(0.05 * len(X)) + 1
            )
            for name, default in (
                ("a", 5),
                ("beta", None),
                ("ecr", 1e-5),
                ("gamma", 0.9),
                ("lambda_", 3),
                ("minP", 1e-5),
                ("MaxIter", 500),
                ("theta", 0.75),
                ("div_cur_free_kernels", False),
                ("velocity_based_sampling", True),
                ("sigma", 0.8),
                ("eta", 0.5),
                ("seed", 0),
            ):
                resolved[name] = kwargs.pop(name, default)
            self.parameters = update_n_merge_dict(self.parameters, resolved)
        self.norm_dict = {}

    def train(self, normalize=False, **kwargs):
        """Fit the field with :func:`SparseVFC` and return the result dict.

        Parameters
        ----------
        normalize : bool
            If True, ``self.data`` is replaced by its :func:`norm`-alized
            version before fitting and the transform is kept in
            ``self.norm_dict``.
        **kwargs
            ``verbose`` (default 0) and ``lstsq_method`` (default "drouin")
            are forwarded to :func:`SparseVFC`; other keys are ignored.

        Returns
        -------
        dict
            The SparseVFC result with ``"V"`` re-evaluated at every row of
            ``self.data["X"]`` and a ``"normalize"`` flag added.
        """
        if normalize:
            X_n, V_n, G_n, transform = norm(
                self.data["X"], self.data["V"], self.data["Grid"]
            )
            self.data["X"] = X_n
            self.data["V"] = V_n
            self.data["Grid"] = G_n
            self.norm_dict = transform

        verbosity = kwargs.pop("verbose", 0)
        solver = kwargs.pop("lstsq_method", "drouin")
        fitted = SparseVFC(
            self.data["X"],
            self.data["V"],
            self.data["Grid"],
            **self.parameters,
            verbose=verbosity,
            lstsq_method=solver,
        )
        self.parameters = update_dict(self.parameters, fitted)

        self.vf_dict = fitted

        def evaluate(x):
            return vector_field_function(x, fitted)

        self.func = evaluate
        self.vf_dict["V"] = self.func(self.data["X"])
        self.vf_dict["normalize"] = normalize

        return self.vf_dict

    def get_Jacobian(
        self, method="analytical", input_vector_convention="row", **kwargs
    ):
        """Return ``x -> Jacobian_rkhs_gaussian(x, self.vf_dict, **kwargs)``.

        ``method`` and ``input_vector_convention`` are accepted but ignored.
        ``self.vf_dict`` is read when the returned function is called.
        """

        def jacobian(x):
            return Jacobian_rkhs_gaussian(x, self.vf_dict, **kwargs)

        return jacobian


def Jacobian_rkhs_gaussian(x, vf_dict, vectorize=False):
    """Analytical Jacobian of a Gaussian-kernel SparseVFC vector field.

    For ``f_i(x) = sum_m C[m, i] * exp(-beta * ||x - c_m||^2)`` this returns
    ``J[i, j] = -2 * beta * sum_m C[m, i] * K_m(x) * (x_j - c_m[j])``.

    Parameters
    ----------
    x : np.ndarray
        A single point ``(d,)`` or a batch ``(n, d)``.
    vf_dict : dict
        Needs ``"X_ctrl"`` (control points, one per row), ``"beta"`` and
        ``"C"`` (one coefficient row per control point).
    vectorize : bool
        For a batch, compute all Jacobians with one ``einsum`` instead of a
        Python loop over points.

    Returns
    -------
    np.ndarray
        ``(d, d)`` for a single point, ``(d, d, n)`` for a batch, with
        ``J[i, j] = d f_i / d x_j``.
    """
    X_ctrl, beta, C = vf_dict["X_ctrl"], vf_dict["beta"], vf_dict["C"]
    if x.ndim == 1:
        K, D = con_K(x[None, :], X_ctrl, beta, return_d=True)
        J = (C.T * K) @ D[0].T
    elif not vectorize:
        n, d = x.shape
        J = np.zeros((d, d, n))
        for i, xi in enumerate(x):
            K, D = con_K(xi[None, :], X_ctrl, beta, return_d=True)
            J[:, :, i] = (C.T * K) @ D[0].T
    else:
        K, D = con_K(x, X_ctrl, beta, return_d=True)
        # con_K squeezes singleton axes, so K may come back as (M,), (n,) or a
        # scalar. Restore the (n, M) layout explicitly; K[None, :] alone is
        # wrong when there is a single control point.
        K = np.reshape(K, (x.shape[0], X_ctrl.shape[0]))
        J = np.einsum("nm, mi, njm -> ijn", K, C, D)

    return -2 * beta * J


def vector_field_function(
    x, vf_dict, dim=None, kernel="full", X_ctrl_ind=None, **kernel_kwargs
):
    """Evaluate a vector field learned by :func:`SparseVFC` at ``x``.

    Reference: Regularized vector field learning with sparse approximation
    for mismatch removal, Ma, Jiayi, et al., Pattern Recognition.

    Parameters
    ----------
    x : np.ndarray
        A single point ``(d,)`` or a batch ``(n, d)``.
    vf_dict : dict
        Fitted field with ``"X_ctrl"`` and ``"C"``, plus ``"beta"`` for the
        Gaussian kernel, or ``"sigma"``/``"eta"`` when
        ``vf_dict["div_cur_free_kernels"]`` is truthy.
    dim : int or index, optional
        Gaussian kernel only: an int keeps the first ``dim`` output components;
        any other index selects those components.
    kernel : {"full", "df_kernel", "cf_kernel"}
        Which div/curl-free kernel part to use (ignored for the Gaussian
        kernel).
    X_ctrl_ind : index, optional
        If given, only the coefficients of these control points are used.
    **kernel_kwargs
        Forwarded to the kernel constructor.

    Returns
    -------
    np.ndarray
        Velocities at ``x``.

    Raises
    ------
    ValueError
        If ``kernel`` is not a supported name (div/curl-free mode only).
    """
    # Honor the flag's *value*: a dict that records div_cur_free_kernels=False
    # must still be evaluated with the Gaussian kernel.
    has_div_cur_free_kernels = bool(vf_dict.get("div_cur_free_kernels", False))

    if x.ndim == 1:
        x = x[None, :]

    if has_div_cur_free_kernels:
        kernel_positions = {"full": 0, "df_kernel": 1, "cf_kernel": 2}
        if kernel not in kernel_positions:
            raise ValueError(
                f"the kernel can only be one of {'full', 'df_kernel', 'cf_kernel'}!"
            )
        K = con_K_div_cur_free(
            x,
            vf_dict["X_ctrl"],
            vf_dict["sigma"],
            vf_dict["eta"],
            **kernel_kwargs,
        )[kernel_positions[kernel]]
    else:
        K = con_K(x, vf_dict["X_ctrl"], vf_dict["beta"], **kernel_kwargs)

    if X_ctrl_ind is not None:
        # Zero out all coefficients except those of the selected control points.
        C = np.zeros_like(vf_dict["C"])
        C[X_ctrl_ind, :] = vf_dict["C"][X_ctrl_ind, :]
    else:
        C = vf_dict["C"]

    K = K.dot(C)

    if dim is not None and not has_div_cur_free_kernels:
        # Index the last axis so a single-point (1-D) result works as well as
        # a batch (2-D) result.
        K = K[..., :dim] if np.isscalar(dim) else K[..., dim]

    return K


def VectorField(
    adata: ad.AnnData,
    basis: Union[None, str] = None,
    normalize: bool = False,
    result_key: Union[str, None] = None,
    **kwargs,
):
    """
    Learn the function of vector field.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object.
    basis
        The label of cell coordinates, for example, `umap` or `spatial`.
        Positions are read from ``adata.obsm["X_" + basis]`` and velocities
        from ``adata.obsm["velocity_" + basis]``. (Default: None)
    normalize
        Whether to center the positions and end points and scale them to unit
        root-mean-square norm before fitting. (Default: False)
    result_key
        Prefix of the ``adata.uns`` key holding the learned field; ``"VecFld"``
        when None. ``"_" + basis`` is appended. (Default: None)
    **kwargs
        SparseVFC hyper-parameters (see :func:`SparseVFC`); also forwarded to
        ``SvcVectorField.train``.

    Returns
    -----------
    BaseVectorfield
        A vector field class object.

    Raises
    ------
    ValueError
        If ``basis`` is None.
    """
    method = "SparseVFC"
    if basis is None:
        raise ValueError("please give the right basis like 'umap' or 'spatial...'")

    X = adata.obsm["X_" + basis].copy()
    V = adata.obsm["velocity_" + basis].copy()
    Grid = None

    vf_kwargs = {
        "M": None,
        "a": 5,
        "beta": None,
        "ecr": 1e-5,
        "gamma": 0.9,
        "lambda_": 3,
        "minP": 1e-5,
        "MaxIter": 30,
        "theta": 0.75,
        "div_cur_free_kernels": False,
        "velocity_based_sampling": True,
        "sigma": 0.8,
        "eta": 0.5,
        "seed": 0,
    }
    # Only known hyper-parameters are overridden by the caller's kwargs.
    vf_kwargs = update_dict(vf_kwargs, kwargs)

    VecFld = SvcVectorField(X, V, Grid, **vf_kwargs)
    vf_dict = VecFld.train(normalize=normalize, **kwargs)

    prefix = "VecFld" if result_key is None else result_key
    vf_key = prefix + "_" + basis
    vf_dict["method"] = method

    adata.obsm["velocity_" + basis + "_" + method] = vf_dict["V"]
    adata.obsm["X_" + basis + "_" + method] = vf_dict["X"]
    adata.uns[vf_key] = vf_dict

    control_point = "control_point_" + basis
    inlier_prob = "inlier_prob_" + basis
    valid_ids = vf_dict["valid_ind"]
    adata.obs[control_point], adata.obs[inlier_prob] = False, np.nan
    adata.obs.loc[adata.obs_names[vf_dict["ctrl_idx"]], control_point] = True
    adata.obs.loc[adata.obs_names[valid_ids], inlier_prob] = vf_dict["P"].flatten()

    # Angles between observed velocity and that predicted by the vector field.
    # train() re-evaluates vf_dict["V"] at *every* row of X, so it must be
    # indexed by valid_ids to stay aligned with the observed velocities.
    cell_angles = np.zeros(adata.n_obs, dtype=float)
    for i, u, v in zip(valid_ids, V[valid_ids], vf_dict["V"][valid_ids]):
        # angle() returns NaN when either vector has zero norm.
        cell_angles[i] = angle(u.astype("float64"), v.astype("float64"))

    adata.obs["obs_vf_angle_" + basis] = cell_angles

    return VecFld
def update_dict(dict1, dict2):
    """Overwrite in place the values of ``dict1`` whose keys also occur in ``dict2``.

    Keys only present in ``dict2`` are ignored. Returns the mutated ``dict1``.
    """
    for key in dict1.keys() & dict2.keys():
        dict1[key] = dict2[key]
    return dict1


def angle(vector1, vector2):
    """Return the signed angle in radians between two vectors.

    The sign is the negated sign of the 2x2 determinant formed by the last two
    components of the unit vectors (a zero determinant gives a positive sign).
    NaN is returned when either vector has zero norm.
    """
    norm1, unit1 = unit_vector(vector1)
    norm2, unit2 = unit_vector(vector2)
    if norm1 == 0 or norm2 == 0:
        return np.nan

    minor = np.linalg.det(np.stack((unit1[-2:], unit2[-2:])))
    sign = 1 if minor == 0 else -np.sign(minor)
    # Clamp to [-1, 1] so rounding error cannot push arccos out of its domain.
    cos_theta = min(max(np.dot(unit1, unit2), -1.0), 1.0)
    return sign * np.arccos(cos_theta)


def unit_vector(vector):
    """Return ``(norm, unit_vector)``; a zero vector is returned as-is with norm 0."""
    length = np.linalg.norm(vector)
    if length == 0:
        return length, vector
    return length, vector / length