Source code for spaTrack.single_time.Pgene

import pandas as pd
import numpy as np
import math
import seaborn as sns
import matplotlib.pyplot as plt
import scanpy as sc
from scipy import stats
from sklearn import preprocessing
import multiprocessing as mp
from pygam import LinearGAM, s, f
import statsmodels.stats as stat
from scipy.signal import savgol_filter
import statsmodels.formula.api as smf
import scipy.stats
from scipy import sparse
import gc
from anndata import AnnData
from matplotlib.axes import Axes


"""
Calculate a JS score to determine whether the data trend is increasing or decreasing

"""


def js_score(gam_fit, grid_X):
    """
    Classify a fitted GAM curve as increasing or decreasing.

    The fitted curve is compared with a linearly increasing ramp and a
    linearly decreasing ramp (both scaled to ``[0, 1]`` and as long as the
    prediction); the label of the ramp with the smaller JS score is returned.

    Parameters
    ----------
    gam_fit :
        Fitted model by pyGAM. Only its ``predict`` method is used here.

    grid_X: array
        An array value grided by pyGAM's generate_X_grid function


    Returns
    -------
    trend: str
        ``"increase"`` or ``"decrease"``. On a tie (or when the scores are
        not comparable) ``"decrease"`` is returned.

    """

    def JS_divergence(p, q):
        """
        Parameters
        ----------
        p,q :
                Two same length arrays
                p: array fitted by model
                q: standard distribution array

        Returns
        -------
            JS score. A smaller value indicates that the input is more
            similar to the standard distribution.
        """
        M = (p + q) / 2
        return 0.5 * scipy.stats.entropy(p, M) + 0.5 * scipy.stats.entropy(q, M)

    # Predict once and reuse the curve for both comparisons.
    fitted_curve = np.asarray(gam_fit.predict(grid_X), dtype=float).ravel()

    # Negative predictions make the entropy terms non-finite, which would make
    # both scores tie and silently yield "decrease". Shifting the curve so its
    # minimum is zero keeps its shape; non-negative curves are left untouched.
    if fitted_curve.size and fitted_curve.min() < 0:
        fitted_curve = fitted_curve - fitted_curve.min()

    # Reference ramps follow the prediction length instead of assuming 100
    # grid points. arange / (n - 1) equals max-abs scaling of 0..n-1.
    n_points = fitted_curve.size
    increase_trend = np.arange(n_points) / max(n_points - 1, 1)
    decrease_trend = increase_trend[::-1]

    decrease_score = JS_divergence(fitted_curve, decrease_trend)
    increase_score = JS_divergence(fitted_curve, increase_trend)

    # Strict comparison keeps the original tie-breaking ("decrease" wins).
    return "increase" if increase_score < decrease_score else "decrease"


## Fit gene expression against ptime with a generalized additive model (GAM)
## Identify pseudotime-dependent genes that may drive cell transition

"""
Filter genes by minimum expression proportion and cluster differential expression.
Cluster differential expression is used as a reference to order genes.
"""


def filter_gene(adata: AnnData, min_exp_prop: float, hvg_gene: int = 2000) -> AnnData:
    """
    Filter genes by minimum expression proportion and flag highly variable genes.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object whose cells are already sorted by
        ``adata.obs["ptime"]``. ``adata.obs["cluster"]`` is used to report the
        cluster order. Highly-variable-gene annotations are written to
        ``adata.var`` by scanpy as a side effect.
    min_exp_prop
        Minimum expression proportion. A gene is kept when the number of cells
        with expression > 0 exceeds ``int(n_cells * min_exp_prop)``.
    hvg_gene
        Number of highly variable genes requested from
        ``sc.pp.highly_variable_genes``.

    Returns
    ----------
    :class:`~anndata.AnnData`
        The :class:`~anndata.AnnData` object formed by filtered genes.
        ``uns["gene_list_lm"]`` holds the kept genes that are also highly
        variable.

    Raises
    ------
    ValueError
        If the cells are not sorted by ``ptime``.
    """
    if not adata.obs["ptime"].is_monotonic_increasing:
        raise ValueError("error: Please sort adata by ptime.")

    # Average only the ptime column: averaging every obs column fails on
    # non-numeric columns with recent pandas versions.
    cluster_order = adata.obs.groupby("cluster")["ptime"].mean().sort_values().index
    print("clusters ordered by ptime: ", list(cluster_order))

    ## minimum expression proportion
    # Count expressing cells directly on the matrix; this works for dense
    # arrays and any scipy sparse format without densifying the data.
    expressed_cells = np.asarray((adata.X > 0).sum(axis=0)).ravel()
    min_cells = int(adata.n_obs * min_exp_prop)
    kept_genes = adata.var_names[expressed_cells > min_cells]

    sc.pp.highly_variable_genes(adata, n_top_genes=hvg_gene)
    hvg_names = adata.var_names[adata.var["highly_variable"].to_numpy()]

    # NOTE: a cluster-level differential-expression filter
    # (sc.tl.rank_genes_groups, Wilcoxon) used to be applied here; it is
    # disabled and candidate genes are the kept genes that are also HVGs.
    gene_list_lm = np.intersect1d(kept_genes, hvg_names)

    # Copy explicitly so that writing to .uns does not modify a view.
    adata_filter = adata[:, kept_genes].copy()
    adata_filter.uns["gene_list_lm"] = gene_list_lm

    print("Cell number" + "\t" + str(len(adata_filter)))
    print("Gene number" + "\t" + str(len(gene_list_lm)))
    return adata_filter
def GAM_gene_fit(exp_gene_list):
    """
    Fit a GAM of one gene's expression against ptime and summarise the fit.

    Parameters
    ----------
    exp_gene_list : multi layer list
        exp_gene_list[0]: dataframe
            columns : ptime,gene_expression
        exp_gene_list[1]: gene_name

    Returns
    -------
    A one-row dataframe with the columns ``gene``, ``pvalue``, ``model_fit``
    (explained deviance) and ``pattern`` (trend label from ``js_score``).
    """
    frame, gene_name = exp_gene_list[0], exp_gene_list[1]

    ptime_values = frame[["ptime"]].values
    expression = frame[gene_name]

    model = LinearGAM(s(0, n_splines=8))
    fitted = model.gridsearch(ptime_values, expression, progress=False)
    grid = fitted.generate_X_grid(term=0)
    fit_statistics = fitted.statistics_

    return pd.DataFrame(
        {
            "gene": [gene_name],
            "pvalue": [fit_statistics["p_values"][0]],
            "model_fit": [fit_statistics["pseudo_r2"]["explained_deviance"]],
            "pattern": [js_score(fitted, grid)],
        }
    )


""" function: Call ptime_gene_GAM() by multi-process computing to improve operational speed """
def ptime_gene_GAM(adata: AnnData, core_number: int = 3) -> pd.DataFrame:
    """
    Fit GAM model by formula gene_exp ~ Ptime.

    Call GAM_gene_fit() by multi-process computing to improve operational speed.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object. The genes to fit are read from
        ``adata.uns["gene_list_lm"]`` and pseudotime from ``adata.obs["ptime"]``.
    core_number
        Number of processes for calculating. Must be at least 1.

    Returns
    -------
    :class:`~pandas.DataFrame`
        An :class:`~pandas.DataFrame` object indexed by gene, with one column per metric.

        - gene: gene name
        - pvalue: calculated from GAM
        - model_fit: a goodness-of-fit measure (explained deviance). larger value means better fit
        - pattern: increase or decrease. direction of gene expression changes across time
        - fdr: BH fdr

    Raises
    ------
    ValueError
        If ``core_number`` is smaller than 1.
    """
    # Import the submodule explicitly: attribute access on the
    # ``statsmodels.stats`` package only works if something else already
    # imported ``multitest``.
    from statsmodels.stats.multitest import fdrcorrection

    if core_number < 1:
        raise ValueError("core_number must be at least 1.")

    # perform GAM model on each gene
    gene_list_for_gam = adata.uns["gene_list_lm"]
    print("Genes number fitted by GAM model: ", len(gene_list_for_gam))

    if len(gene_list_for_gam) == 0:
        # Nothing to fit: return an empty table instead of failing in concat.
        return pd.DataFrame(columns=["gene", "pvalue", "model_fit", "pattern", "fdr"])

    expression = adata.X
    if sparse.issparse(expression):
        # pandas cannot build a regular DataFrame from a scipy sparse matrix.
        expression = expression.toarray()
    df_exp_filter = pd.DataFrame(
        data=expression, index=adata.obs.index, columns=adata.var.index
    )

    ptime_values = list(adata.obs["ptime"])  # shared by every per-gene frame
    para_list = [
        (
            pd.DataFrame({"ptime": ptime_values, gene: list(df_exp_filter[gene])}),
            gene,
        )
        for gene in gene_list_for_gam
    ]

    # The context manager releases the worker processes even if a fit raises.
    with mp.Pool(core_number) as pool:
        per_gene_results = pool.map(GAM_gene_fit, para_list)

    df_res = pd.concat(per_gene_results)
    df_res["fdr"] = fdrcorrection(np.array(df_res["pvalue"]))[1]
    df_res.index = list(df_res["gene"])
    return df_res
""" function: Split cells sorted by ptime into widonws. Order genes according number id of the maximum expression window """
def order_trajectory_genes(adata: AnnData, df_sig_res: pd.DataFrame, cell_number: int) -> pd.DataFrame:
    """
    Split cells sorted by ptime into windows.

    Order genes according to the window in which their mean expression is highest.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object. Rows are taken in their current
        order, so cells are expected to be sorted by ptime already.
    df_sig_res
        Return dataframe by ptime_gene_GAM() after filtering as significant gene dataframe.
        Only its index (gene names) is used.
    cell_number
        Cell number within each window. The last window holds the remaining cells.

    Returns
    -------
    :class:`~pandas.DataFrame`
        - columns: significant genes, sorted by the mean ptime of the window
          in which each gene reaches its highest mean expression
        - index: cell_id

    Raises
    ------
    ValueError
        If ``cell_number`` is smaller than 1.
    """
    if cell_number < 1:
        raise ValueError("cell_number must be a positive integer.")

    exp_matrix = adata.X
    if sparse.issparse(exp_matrix):
        # pandas cannot build a regular DataFrame from a scipy sparse matrix.
        exp_matrix = exp_matrix.toarray()
    df_exp_filter = pd.DataFrame(
        data=np.asarray(exp_matrix), index=adata.obs.index, columns=adata.var.index
    )
    df_one_cell_exp_sig = df_exp_filter.loc[:, df_sig_res.index]
    sig_genes = df_one_cell_exp_sig.columns
    n_cells = len(df_one_cell_exp_sig)
    n_genes = len(sig_genes)

    # windows number
    window_number = math.ceil(n_cells / cell_number)
    window_labels = ["window_" + str(i) for i in range(window_number)]
    padded_cells = window_number * cell_number

    # cell number in each window; only the last window may be partially filled
    cell_in_window = np.full((window_number, 1), cell_number)
    if window_number:
        cell_in_window[-1, 0] = n_cells - cell_number * (window_number - 1)

    # Zero-pad the cell axis to a whole number of windows; the padding adds
    # nothing to the window sums and the true counts above fix the means.
    exp_values = df_one_cell_exp_sig.to_numpy()
    padded_exp = np.zeros((padded_cells, n_genes), dtype=exp_values.dtype)
    padded_exp[:n_cells] = exp_values

    # divide block and take the mean expression in each window
    window_sums = padded_exp.reshape((window_number, cell_number, n_genes)).sum(axis=1)
    window_exp = pd.DataFrame(
        data=window_sums / cell_in_window, index=window_labels, columns=sig_genes
    )

    # label of the window with the highest mean expression, per gene
    peak_window = window_exp.idxmax(axis=0)

    # mean ptime of each window, padded the same way as the expression
    padded_ptime = np.zeros(padded_cells, dtype=float)
    padded_ptime[:n_cells] = np.asarray(adata.obs["ptime"], dtype=float)
    window_ptime = pd.Series(
        padded_ptime.reshape((window_number, cell_number)).sum(axis=1)
        / cell_in_window[:, 0],
        index=window_labels,
    )

    # order genes by the mean ptime of their peak window
    peak_ptime = window_ptime.loc[peak_window.to_numpy()].to_numpy()
    gene_sort_list = sig_genes[np.argsort(peak_ptime)]
    print("Finally selected", len(gene_sort_list), "genes.")

    ## return gene order
    df_one_cell_exp_sort = df_one_cell_exp_sig[gene_sort_list]
    return df_one_cell_exp_sort
""" function: Plot ordered gene expression heatmap of the selected candidate trajectory genes """
def plot_trajectory_gene_heatmap(
    sig_gene_exp_order: pd.DataFrame,
    smooth_length: int,
    cmap_name: str = "twilight_shifted",
    gene_label_size: int = 30,
    fig_width=8,
    fig_height=10,
):
    """
    Plot a smoothed, z-scored expression heatmap of the ordered trajectory genes.

    Parameters
    ----------
    sig_gene_exp_order
        Gene ordered expression dataframe (rows: cells, columns: genes).
    smooth_length
        Length of the smoothing window passed to ``savgol_filter``.
    cmap_name
        Color palette
    gene_label_size
        Font size of the gene labels on the y axis.
    fig_width,fig_height
        The width and height of figure

    Returns
    -------
    None
        The heatmap (columns: cells, rows: genes) is drawn on a new
        matplotlib figure; the function returns the result of
        ``fig.tight_layout()``, i.e. ``None``.
    """
    # z-score every gene across cells. Work on the bare array and rebuild the
    # labelled frame ourselves: depending on the SciPy version, zscore may
    # return a plain ndarray for DataFrame input, which has no index/columns.
    z_values = np.asarray(stats.zscore(sig_gene_exp_order.to_numpy(), axis=0))
    last_pd = pd.DataFrame(
        data=z_values.T,
        columns=sig_gene_exp_order.index,
        index=sig_gene_exp_order.columns,
    )

    # smooth data along the cell axis (savgol_filter works on the last axis)
    last_pd_smooth = pd.DataFrame(
        savgol_filter(last_pd.to_numpy(), smooth_length, 1),
        index=last_pd.index,
        columns=last_pd.columns,
    )

    fig = plt.figure(figsize=(fig_width, fig_height))
    pseudotime_gene_heatmap = sns.heatmap(
        last_pd_smooth,
        cmap=cmap_name,
        cbar_kws={"shrink": 0.3, "label": "normalized expression"},
    )

    cbar = pseudotime_gene_heatmap.collections[0].colorbar
    cbar.ax.tick_params(labelsize=22)
    # the colorbar axes is the last axes of the figure
    pseudotime_gene_heatmap.figure.axes[-1].yaxis.label.set_size(25)

    pseudotime_gene_heatmap.xaxis.tick_top()
    pseudotime_gene_heatmap.set_xticks([])  # one tick per cell would be unreadable
    pseudotime_gene_heatmap.yaxis.set_tick_params(labelsize=gene_label_size)
    plt.xticks(rotation=90)
    return fig.tight_layout()
""" function: Plot one trajectory gene """
def plot_trajectory_gene(
    adata: AnnData,
    gene_name: str,
    line_width: int = 5,
    show_cell_type: bool = False,
    point_size=20,
) -> Axes:
    """
    Plot the expression of one gene against ptime with its fitted GAM curve.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object with ``obs["ptime"]``;
        ``obs["cluster"]`` is only needed when ``show_cell_type`` is true.
    gene_name
        Gene used to plot.
    line_width
        Width of fitting line.
    show_cell_type
        Whether to show cell type in plot. When false, points are colored by ptime.
    point_size
        The size of point

    Returns
    -------
    :class:`~matplotlib.axes.Axes`
        An :class:`~matplotlib.axes.Axes` object. X axis indicates pseudotime and y axis indicates gene expression value.
    """
    columns = {
        "ptime": list(adata.obs["ptime"]),
        gene_name: _gene_expression_vector(adata, gene_name),
    }
    if show_cell_type:
        # Cluster labels are only required when they are actually drawn.
        columns["cell_type"] = list(adata.obs["cluster"])
    df_new = pd.DataFrame(columns)

    gam_res = _fit_ptime_gam(df_new[["ptime"]].values, df_new[gene_name])

    fig, axs = plt.subplots(figsize=(10, 6))
    XX = gam_res.generate_X_grid(term=0)
    axs.plot(XX, gam_res.predict(XX), color="#aa4d3d", linewidth=line_width)

    if show_cell_type:
        sns.scatterplot(
            x="ptime",
            y=gene_name,
            palette="deep",
            ax=axs,
            data=df_new,
            s=point_size,
            hue="cell_type",
        )
        # Rebuild the legend without a title and place it right of the axes.
        axs.legend(fontsize="xx-large", loc=(1.01, 0.5))
    else:
        sns.scatterplot(
            x="ptime",
            y=gene_name,
            cmap="plasma",
            ax=axs,
            s=point_size,
            data=df_new,
            c=df_new["ptime"].to_numpy(),
        )

    axs.set_title(gene_name, fontsize=30)
    axs.set_xlabel("ptime", fontsize=30)
    axs.set_ylabel("expression", fontsize=30)
    axs.tick_params(axis="both", labelsize=18)

    # ptime colorbar built from a standalone mappable.
    # NOTE(review): it is also drawn when show_cell_type is true, although the
    # points are then colored by cell type — confirm this is intended.
    norm = plt.Normalize(df_new["ptime"].min(), df_new["ptime"].max())
    sm = plt.cm.ScalarMappable(cmap="plasma", norm=norm)
    sm.set_array([])
    axs.figure.colorbar(sm, ax=axs)

    fig.tight_layout()
    return axs


def _gene_expression_vector(adata, gene_name):
    """Return one gene's expression across all cells as a dense 1-D array."""
    values = adata[:, gene_name].X
    if sparse.issparse(values):
        values = values.toarray()
    return np.asarray(values).ravel()


def _fit_ptime_gam(ptime_column, expression, n_splines=10):
    """Grid-search a one-term spline GAM of expression on ptime and return it."""
    return LinearGAM(s(0, n_splines=n_splines)).gridsearch(
        ptime_column, expression, progress=False
    )


""" function: Plot a group of trajectory genes """


def plot_trajectory_gene_list(
    adata,
    gene_name_list,
    col_num=4,
    title_fontsize=25,
    label_fontsize=22,
    line_width=5,
    fig_legnth=10,
    fig_width=8,
):
    """
    Plot expression against ptime, with fitted GAM curves, for several genes.

    Parameters
    ----------
    adata : AnnData object.

    gene_name_list: List object
        gene list used to plot

    col_num: int
        Number of genes displayed per row in the picture
        (Default: 4)

    title_fontsize: int
        title fontsize of picture
        (Default: 25)

    label_fontsize: int
        x and y label fontsize
        (Default: 22)

    line_width: int
        Width of the fitted line
        (Default: 5)

    fig_legnth,fig_width:
        The length and width of picture size
        (Default: 10,8)

    Returns
    -------
    ax: Axes object
        The axes of the last gene drawn.
        x axis indicates pseudotime; y axis indicates gene expression value

    Raises
    ------
    ValueError
        If ``gene_name_list`` is empty.
    """
    gene_number = len(gene_name_list)
    if gene_number == 0:
        raise ValueError("gene_name_list must contain at least one gene.")
    row_num = math.ceil(gene_number / col_num)

    # squeeze=False keeps ``axs`` two-dimensional even for a single row or
    # column, so the grid can always be walked the same way.
    fig, axs = plt.subplots(
        ncols=col_num,
        nrows=row_num,
        figsize=(fig_legnth, fig_width),
        sharey=False,
        sharex=True,
        squeeze=False,
    )

    ptime_values = np.asarray(adata.obs["ptime"])
    ptime_column = ptime_values.reshape(-1, 1)

    ax = None
    # axs.flat walks the grid row by row, matching the order of the genes.
    for ax, gene_name in zip(axs.flat, gene_name_list):
        y_exp = _gene_expression_vector(adata, gene_name)
        gam_res = _fit_ptime_gam(ptime_column, y_exp)
        XX = gam_res.generate_X_grid(term=0)

        ax.plot(XX, gam_res.predict(XX), color="#aa4d3d", linewidth=line_width)
        ax.scatter(ptime_values, y_exp, cmap="plasma", c=ptime_values)
        ax.set_title(gene_name, fontsize=title_fontsize)

    # hide the unused panels of the last row
    for unused_ax in axs.flat[gene_number:]:
        unused_ax.set_visible(False)

    fig.text(0.5, -0.04, "ptime", ha="center", fontsize=label_fontsize)
    fig.text(
        0.01,
        0.5,
        "expression ",
        va="center",
        rotation="vertical",
        fontsize=label_fontsize,
    )

    fig.tight_layout()
    fig.subplots_adjust(left=0.06)
    return ax


# 01 filter gene by expression
# sub_adata=sti.Pgene.filter_gene(sub_adata,min_exp_prop=0.1,hvg_gene=3000)

# 02 fit GAM model
# df_res = sti.Pgene.ptime_gene_GAM(sub_adata,core_number=5)

# 03 filter gene by GAM model indicators
# df_sig_res = df_res.loc[(df_res['model_fit']>0.05) & (df_res['fdr']<0.05)]

# 04 order trajectory genes
# sort_exp_sig = sti.Pgene.order_trajectory_genes(sub_adata,df_sig_res,cell_number=20)

# 05 plot trajectory gene heatmap
# sti.Pgene.plot_trajectory_gene_heatmap(sort_exp_sig,smooth_length=100,gene_label_size=20)

# 06 plot one or multiple trajectory genes
# sti.Pgene.plot_trajectory_gene(sub_adata,gene_name='APOE',show_cell_type=False)
# sti.Pgene.plot_trajectory_gene_list(sub_adata,gene_name_list=['COL1A1','ACTB','TNC','AQP1'],col_num=2)