Source code for spaTrack.single_time.gene_regulation

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import scanpy as sc
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
import random
import matplotlib.pyplot as plt
import sys
from scipy.sparse import issparse
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm

from typing import Literal, Union, List


class Model(nn.Module):
    """
    Single linear layer mapping gene changes to TF expression.

    Parameters
    ----------
    n_gene
        Size of the input vector, i.e. the number of genes.
    n_TF
        Size of the output vector, i.e. the number of TFs.
    """

    def __init__(self, n_gene: int, n_TF: int) -> None:
        super().__init__()
        self.n_gene, self.n_TF = n_gene, n_TF

        # One weight per (gene, TF) pair plus one bias per TF.
        self.linear = nn.Linear(in_features=n_gene, out_features=n_TF)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the linear layer to the gene changes of cells.

        Parameters
        ----------
        x
            The input data (gene changes).

        Returns
        -------
        :class:`torch.Tensor`
            The output of the linear layer (predicted TF expression).
        """
        return self.linear(x)


class Trainer:
    """
    Class for implementing the training process.

    Parameters
    ----------
    data_type
        The data type: ``"2_time"`` for dual time point data of two slices,
        ``"p_time"`` for pseudotime data of one slice.
    expression_matrix_path
        The path of the expression matrix file (a list of two paths in
        ``"2_time"`` mode).
    tfs_path
        The path of the tf names file.
    cell_mapping_path
        The path of the cell mapping file, where column `slice1` indicates the
        start cell and column `slice2` indicates the end cell.
    ptime_path
        The path of the ptime file, used to determine the sequence of the
        ptime data.
    min_cells, optional
        The minimum number of cells for gene filtration. Defaults to 30% of
        the cells of each expression matrix.
    cell_divide_per_time, optional
        The number of cells in each pseudo-time group, by default 80.
    cell_select_per_time, optional
        The number of randomly selected cells averaged into one meta cell,
        by default 10.
    cell_generate_per_time, optional
        The number of meta cells generated at each time point, by default 500.
    train_ratio
        Ratio of training data.
    use_gpu, optional
        Whether to use gpu, by default True.
    random_state
        Random seed of numpy, random and torch.
    """

    def __init__(
        self,
        data_type: Literal["2_time", "p_time"],
        expression_matrix_path: Union[str, List[str]],
        tfs_path: str,
        cell_mapping_path: str = None,
        ptime_path: str = None,
        min_cells: Union[int, List[int]] = None,
        cell_divide_per_time: int = 80,
        cell_select_per_time: int = 10,
        cell_generate_per_time: int = 500,
        train_ratio: float = 0.8,
        use_gpu: bool = True,
        random_state: int = 0,
    ) -> None:
        self.train_ratio = train_ratio

        gpu = torch.cuda.is_available() and use_gpu
        if gpu:
            torch.cuda.manual_seed(random_state)
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        np.random.seed(random_state)
        random.seed(random_state)
        torch.manual_seed(random_state)

        # read gene expression, cell ptime and tfs files.
        if data_type == "p_time":
            if isinstance(expression_matrix_path, list):
                expression_matrix_path = expression_matrix_path[0]
            self.adata = sc.read(expression_matrix_path)
            self.adata.obs["ptime"] = pd.read_table(ptime_path, index_col=0)
            self.tfs_path = tfs_path

            if min_cells is None:
                min_cells = round(self.adata.n_obs * 0.3)

            # filter out genes expressed in few cells
            sc.pp.filter_genes(self.adata, min_cells=min_cells, inplace=True)
            print(f"Genes expressed in less than {min_cells} cells have been filtered.")

            all_tfs = pd.read_csv(self.tfs_path, header=None)
            self.genes = self.adata.var_names
            self.tfs = self.genes.intersection(all_tfs[0].tolist())

            self.generate_data_p_time(
                cell_divide_per_time, cell_select_per_time, cell_generate_per_time
            )

        elif data_type == "2_time":
            # 2_time pattern needs two adata files, the cell mapping file and
            # the tfs file.
            if (
                not isinstance(expression_matrix_path, list)
                or len(expression_matrix_path) != 2
            ):
                sys.exit(
                    "In 2_time mode, a list of gene expression matrix paths for two slices must be provided."
                )
            self.adata1 = sc.read(expression_matrix_path[0])
            self.adata2 = sc.read(expression_matrix_path[1])
            self.cell_mapping = pd.read_csv(cell_mapping_path, index_col=0)
            self.tfs_path = tfs_path

            # filter genes
            if min_cells is None:
                min_cells = [
                    round(self.adata1.n_obs * 0.3),
                    round(self.adata2.n_obs * 0.3),
                ]
            if not isinstance(min_cells, list) or len(min_cells) != 2:
                sys.exit(
                    "In 2_time mode, you should provide a list contains two min_cells parameters to filter two adata files."
                )
            sc.pp.filter_genes(self.adata1, min_cells=min_cells[0], inplace=True)
            sc.pp.filter_genes(self.adata2, min_cells=min_cells[1], inplace=True)

            # only use the mapping cells
            self.adata1 = self.adata1[self.cell_mapping.slice1]
            self.adata2 = self.adata2[self.cell_mapping.slice2]

            # get same genes; Index.intersection is used because the `&`
            # operator no longer performs a set intersection on a pandas Index.
            same_genes = list(
                self.adata1.var_names.intersection(self.adata2.var_names)
            )
            self.adata1 = self.adata1[:, same_genes]
            self.adata2 = self.adata2[:, same_genes]
            self.genes = self.adata1.var_names

            # get tfs
            all_tfs = pd.read_table(self.tfs_path, header=None)
            self.tfs = self.genes.intersection(all_tfs[0].tolist())

            self.generate_data_2_time(cell_generate_per_time, cell_select_per_time)

        else:
            # Without this branch an unknown mode produced an object with no
            # data and failed much later with an AttributeError.
            raise ValueError(
                f"Unknown data_type {data_type!r}; expected 'p_time' or '2_time'."
            )

    @staticmethod
    def _scale_columns(data, offset):
        """
        Return ``(data - offset) / (max - min)`` computed per column.

        Columns whose range is zero are divided by 1 instead of 0, so a
        constant column does not introduce NaN/inf values into the
        training data.
        """
        span = data.max(axis=0) - data.min(axis=0)
        span[span == 0] = 1
        return (data - offset) / span

    @staticmethod
    def _fit_line(x, y):
        """
        Ordinary least-squares fit of ``y = intercept + slope * x``.

        Returns ``(slope, intercept)``. When ``x`` has no variance the slope
        is 0 and the line is the horizontal line through the mean of ``y``.
        """
        x_mean = float(np.mean(x))
        y_mean = float(np.mean(y))
        x_centered = x - x_mean
        var_x = float(np.sum(x_centered**2))
        if var_x > 0:
            slope = float(np.sum(x_centered * (y - y_mean))) / var_x
        else:
            slope = 0.0
        return slope, y_mean - slope * x_mean

    def getMetaData(self, cell_generate_per_time, cell_select_per_time) -> None:
        """
        Randomly sample cell expression and generate meta data.

        ``self.input_data`` and ``self.output_data`` are replaced by the
        averaged meta cells.

        Parameters
        ----------
        cell_generate_per_time
            The amount of data generated at each time point.
        cell_select_per_time
            The amount of data randomly selected at each time point.
        """
        new_data_in = []
        new_data_out = []
        n_cells = self.input_data.shape[0]
        for _ in range(cell_generate_per_time):
            # The same cells are averaged on the input and the output side so
            # that the two meta cells stay paired.
            cell_indexes = random.sample(range(n_cells), cell_select_per_time)
            new_data_in.append(np.mean(self.input_data[cell_indexes], axis=0))
            new_data_out.append(np.mean(self.output_data[cell_indexes], axis=0))
        self.input_data = np.stack(new_data_in)
        self.output_data = np.stack(new_data_out)

    def run(
        self,
        training_times: int = 10,
        iter_times: int = 30,
        mapping_num: int = 3000,
        filename: str = "weights.csv",
        lr_ratio: float = 0.1,
    ) -> None:
        """
        Run the trainer.

        The averaged weights are stored in ``self.network_df`` and written to
        ``filename``.

        Parameters
        ----------
        training_times
            Number of times to randomly initialize the model and retrain.
            (Default: 10)
        iter_times
            The number of iterations for each training model. (Default: 30)
        mapping_num
            The number of top weight pairs you want to extract. (Default: 3000)
        filename
            The saved file name. (Default: 'weights.csv')
        lr_ratio
            Learning rate passed to the SGD optimizer. (Default: 0.1)

        Raises
        ------
        ValueError
            If ``training_times`` is smaller than 1.
        """
        if training_times < 1:
            raise ValueError("training_times must be at least 1.")

        self.mapping_num = mapping_num
        self.all_gtf = []
        with tqdm(total=training_times * iter_times) as pbar:
            for i in range(training_times):
                pbar.set_description(f"Train {i + 1}")
                self.model = Model(len(self.genes), len(self.tfs)).to(self.device)
                loss_fn = nn.MSELoss()
                optimizer = torch.optim.SGD(self.model.parameters(), lr=lr_ratio)
                for _ in range(iter_times):
                    train_loss = self.train(
                        self.train_dl, self.model, loss_fn, optimizer
                    )
                    test_loss = self.test(self.test_dl, self.model, loss_fn)
                    pbar.set_postfix(
                        {"train_loss": train_loss, "test_loss": test_loss},
                        refresh=True,
                    )
                    pbar.update()
                # Keep a detached copy: the weights are only read from here
                # on, so neither the autograd graph nor a live view of the
                # model parameters is needed.
                self.all_gtf.append(self.model.linear.weight.detach().T.clone())

        # set the highest weighted map (TF itself) to 0
        # NOTE(review): this assumes the largest weight of every TF column is
        # the TF's own gene; nothing here checks that - confirm.
        sum_all_gtf = torch.mean(torch.stack(self.all_gtf), dim=0)
        _, self_idx = torch.max(sum_all_gtf, dim=0)
        tf_cols = torch.arange(sum_all_gtf.shape[1], device=sum_all_gtf.device)
        sum_all_gtf[self_idx, tf_cols] = 0

        # save most important maps
        flat_tensor = torch.flatten(sum_all_gtf)
        n_pairs = flat_tensor.numel()
        if n_pairs <= mapping_num:
            # Fewer pairs than requested: keep all of them, as the message
            # says (a strict threshold on the minimum would drop one).
            print(f"Only got {n_pairs} pairs of weights.")
            self.max_TF_idx = torch.nonzero(
                torch.ones_like(sum_all_gtf, dtype=torch.bool)
            )
        else:
            sorted_tensor, _ = torch.sort(flat_tensor, descending=True)
            max_value = sorted_tensor[mapping_num]
            self.max_TF_idx = torch.nonzero(sum_all_gtf > max_value)

        network_rows = [
            [self.tfs[tf_i], self.genes[gene_i], sum_all_gtf[gene_i, tf_i].item()]
            for gene_i, tf_i in self.max_TF_idx.tolist()
        ]

        columns = ["TF", "gene", "weight"]
        self.network_df = pd.DataFrame(data=network_rows, columns=columns)
        self.network_df.sort_values(by="weight", ascending=False, inplace=True)
        self.network_df.to_csv(filename, index=False)
        print(f"Weight relationships of tfs and genes are stored in {filename}.")

    def lm(self, gene_name, TF_name):
        """
        Plot one TF against one gene together with a least-squares line.

        The scatter and the fitted line are drawn on the current matplotlib
        axes.

        Parameters
        ----------
        gene_name
            Name of the gene, looked up in ``self.genes``.
        TF_name
            Name of the TF, looked up in ``self.tfs``.

        Returns
        -------
        The ``matplotlib.pyplot`` module.
        """
        gene_loc = self.genes.get_loc(gene_name)
        TF_loc = self.tfs.get_loc(TF_name)
        train_x = self.output_data[:, TF_loc]  # TF
        train_y = self.input_data[:, gene_loc]  # gene

        # Closed-form fit: deterministic, and it cannot stop early or diverge
        # the way a fixed-step gradient descent from a random start can.
        slope, intercept = self._fit_line(train_x, train_y)

        x = np.linspace(0, 1, 100)
        plt.plot(train_x, train_y, "o", markersize=3, color="black", alpha=0.5)
        plt.plot(x, intercept + slope * x, color="#e87d72", linewidth=3)
        plt.xlabel(TF_name, fontsize=14)
        plt.ylabel("Dynamics of " + gene_name, fontsize=14)
        plt.axis("equal")
        return plt

    def plot_scatter(
        self,
        num_rows: int = 3,
        num_cols: int = 3,
        fig_width: float = 10,
        fig_height: float = 9.5,
        filename: str = "output.png",
    ) -> None:
        """
        Show the relationship between TF and gene changes through scatter plot.

        The first ``num_rows * num_cols`` rows of ``self.network_df`` are
        plotted, one pair per panel; fewer panels are drawn when the table
        has fewer rows.

        Parameters
        ----------
        num_rows
            The number of rows in the graph. (Default: 3)
        num_cols
            The number of columns in the graph. (Default: 3)
        fig_width
            The width of the image. (Default: 10)
        fig_height
            The height of the image. (Default: 9.5)
        filename
            The file the figure is saved to. (Default: 'output.png')
        """
        fig = plt.figure(figsize=(fig_width, fig_height))
        n_panels = min(num_rows * num_cols, len(self.network_df))
        for i in range(n_panels):
            row = self.network_df.iloc[i]
            # add_subplot makes the new axes current, which is where lm draws.
            fig.add_subplot(num_rows, num_cols, i + 1)
            self.lm(row["gene"], row["TF"])
        fig.savefig(filename, dpi=300)
        plt.show()

    def generate_data_2_time(self, cell_generate_per_time, cell_select_per_time):
        """
        Generate data in the 2_time mode.

        Parameters
        ----------
        cell_generate_per_time
            The amount of data generated at each time point.
        cell_select_per_time
            The amount of data randomly selected at each time point.
        """
        delta_gene = self.adata2.X - self.adata1.X
        if issparse(delta_gene):
            delta_gene = delta_gene.toarray()

        self.get_one_hot()
        self.input_data = np.array(delta_gene, dtype=np.float32)
        self.output_data = np.array(self.adata2.X @ self.T.T, dtype=np.float32)

        # The meta step samples cells at random, so no extra shuffle is done.
        self.getMetaData(cell_generate_per_time, cell_select_per_time)

        # normalize data: inputs are mean-centred, outputs are min-max scaled
        self.input_data = self._scale_columns(
            self.input_data, self.input_data.mean(axis=0)
        )
        self.output_data = self._scale_columns(
            self.output_data, self.output_data.min(axis=0)
        )

        self.get_dataloader()

    def generate_data_p_time(
        self, cell_divide_per_time, cell_select_per_time, cell_generate_per_time
    ) -> None:
        """
        Generate data in the p_time mode.

        Parameters
        ----------
        cell_divide_per_time
            The number of divided cells per time point.
        cell_select_per_time
            The amount of data randomly selected at each time point.
        cell_generate_per_time
            The amount of data generated at each time point.

        Raises
        ------
        ValueError
            If fewer than two pseudo-time groups are available.
        """
        sub_index = self.sort_idx(cell_divide_per_time)
        if len(sub_index) < 2:
            raise ValueError(
                "At least two pseudo-time groups are needed; "
                "reduce cell_divide_per_time or provide more cells."
            )

        self.mean_data = []  # mean expression of genes at each time point
        self.origin_data = []  # time_point * cell * gene expression
        for cells in sub_index:
            expr = self.adata[cells].X
            # np.array on a sparse matrix gives a 0-d object array, so
            # densify first.
            if issparse(expr):
                expr = expr.toarray()
            expr = np.array(expr)
            self.origin_data.append(expr)
            self.mean_data.append(expr.mean(axis=0))
        self.mean_data = np.array(self.mean_data)
        self.origin_data = np.stack(self.origin_data)

        self.get_one_hot()
        input_data = []
        output_data = []
        for i in range(1, len(self.origin_data)):
            group = self.origin_data[i]
            previous_mean = self.mean_data[i - 1]
            for _ in range(cell_generate_per_time):
                random_idxs = random.sample(range(len(group)), cell_select_per_time)
                meta_expr = np.array(group[random_idxs].mean(axis=0))
                # Gene change relative to the mean of the previous group.
                input_data.append(meta_expr - previous_mean)
                output_data.append(meta_expr @ self.T.T)

        print(
            f"{(len(self.origin_data)-1)} groups of new data were generated, each with {cell_generate_per_time} meta cells."
        )

        input_data = np.array(input_data, dtype=np.float32)
        output_data = np.array(output_data, dtype=np.float32)

        nor_input_data = self._scale_columns(input_data, input_data.mean(axis=0))
        nor_output_data = self._scale_columns(output_data, output_data.min(axis=0))

        # Shuffle so the train/test split does not follow the time order.
        permuted_idxs = np.random.permutation(input_data.shape[0])
        self.input_data = nor_input_data[permuted_idxs]
        self.output_data = nor_output_data[permuted_idxs]

        self.get_dataloader()

    def get_dataloader(self, batch_size: int = 32) -> None:
        """
        Convert data to dataloader form.

        The first ``train_ratio`` share of the samples goes to
        ``self.train_dl``, the rest to ``self.test_dl``.

        Parameters
        ----------
        batch_size
            Batch size of both dataloaders. (Default: 32)
        """
        num_samples = self.input_data.shape[0]
        train_size = int(num_samples * self.train_ratio)

        train_set = TensorDataset(
            torch.from_numpy(self.input_data[:train_size, :]),
            torch.from_numpy(self.output_data[:train_size, :]),
        )
        test_set = TensorDataset(
            torch.from_numpy(self.input_data[train_size:, :]),
            torch.from_numpy(self.output_data[train_size:, :]),
        )

        self.train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        self.test_dl = DataLoader(test_set, batch_size=batch_size)

    def plot_gene_regulation(self, min_weight, min_node_num, cmap="coolwarm") -> None:
        """
        Draw the gene regulation network graph.

        Parameters
        ----------
        min_weight
            Filter relation pairs whose absolute weight is not larger than
            min_weight.
        min_node_num
            Minimum number of pairs a TF needs to be kept.
        cmap
            Palette used. (Default: 'coolwarm')
        """
        df = self.network_df
        df = df.loc[df["weight"].abs() > min_weight]
        print(f"num of weight pairs after weight filtering: {len(df)}")

        tf_counts = df["TF"].value_counts()
        kept_tfs = tf_counts[tf_counts >= min_node_num].index.tolist()
        df = df.loc[df["TF"].isin(kept_tfs)]
        print(f"num of weight pairs after node_count filtering: {len(df)}")

        G = nx.from_pandas_edgelist(df, "TF", "gene", create_using=nx.Graph())
        pos = nx.kamada_kawai_layout(G)
        # Nodes are coloured by degree; only TFs get a text label.
        node_color = [2000.0 * G.degree(v) for v in G]
        node_size = 5
        tf_labels = {tf: tf for tf in kept_tfs}

        plt.figure(figsize=(6, 6), dpi=200)
        nx.draw_networkx_nodes(
            G, pos, node_color=node_color, node_size=node_size, cmap=cmap, alpha=0.8
        )
        nx.draw_networkx_edges(G, pos, alpha=0.1, width=1)
        nx.draw_networkx_labels(
            G,
            pos,
            font_size=8,
            font_color="orange",
            labels=tf_labels,
            bbox={"boxstyle": "round", "facecolor": "white", "edgecolor": "orange"},
        )

    def get_one_hot(
        self,
    ) -> None:
        """
        Generate one-hot matrix from TFs to genes.

        The matrix is stored in ``self.T`` with one row per TF and one column
        per gene; row ``i`` has a 1 in the column of the gene named
        ``self.tfs[i]``. A TF that is not in ``self.genes`` gets an all-zero
        row.
        """
        # Direct index lookup: gene names are matched exactly, so names that
        # are a single character or contain '-' / '.' are handled correctly
        # (a text tokenizer would split or drop them).
        gene_pos = np.asarray(self.genes.get_indexer(self.tfs))
        found = gene_pos >= 0
        one_hot = np.zeros((len(self.tfs), len(self.genes)), dtype=np.int64)
        one_hot[np.arange(len(self.tfs))[found], gene_pos[found]] = 1
        self.T = one_hot

    def sort_idx(self, cell_divide_per_time) -> List[list]:
        """
        Divide the cells in the pseudo-time series to obtain cell sets of different time segments.

        Cells are ordered by ``ptime`` (stable sort, so cells that are already
        in order keep their positions). As many cells as have ptime 0 are
        dropped from the start, as many as have ptime 1 from the end, and the
        rest is cut into consecutive groups. The last group is always dropped.

        Parameters
        ----------
        cell_divide_per_time
            The number of divided cells per time point.

        Returns
        -----------
        list
            A list of groups, each a list of cell names.
        """
        ptime = self.adata.obs["ptime"]
        n_cells = len(ptime)
        n_start = int((ptime == 0).sum())
        n_end = int((ptime == 1).sum())

        # NOTE(review): trimming by count assumes ptime lies in [0, 1] so the
        # 0 and 1 cells end up at the two ends after sorting - confirm.
        ordered_cells = ptime.sort_values(kind="stable").index
        middle_idx = list(ordered_cells[n_start : n_cells - n_end])

        sub_length = cell_divide_per_time
        sub_index = [
            middle_idx[i : i + sub_length]
            for i in range(0, len(middle_idx), sub_length)
        ][:-1]
        print(
            f"{n_cells} cells were divided into {len(sub_index)} groups according to the pseudo-time, with {cell_divide_per_time} cells in each group. The first and last cells are discarded."
        )
        return sub_index

    def train(self, dataloader, model, loss_fn, optimizer):
        """
        Train step.

        Parameters
        ----------
        dataloader
            Dataloader that contains the input and output data.
        model
            Model used to infer the relationship.
        loss_fn
            Loss function.
        optimizer
            The optimizer used to reduce the loss value.

        Returns
        -----------
        float
            Mean training loss over the batches, or NaN when the dataloader
            yields no batch.
        """
        model.train()
        total_loss = 0.0
        num_batches = 0
        for X, y in dataloader:
            X, y = X.to(self.device), y.to(self.device)
            # compute the prediction error
            pred = model(X)
            loss = loss_fn(pred, y)
            # backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            return float("nan")
        return total_loss / num_batches

    def test(self, dataloader, model, loss_fn):
        """
        Test step

        Parameters
        ----------
        dataloader
            Dataloader that contains the input and output data.
        model
            Model used to infer the relationship.
        loss_fn
            Loss function.

        Returns
        -----------
        float
            Mean testing loss over the batches, or NaN when the dataloader
            is empty (e.g. ``train_ratio=1.0``).
        """
        num_batches = len(dataloader)
        if num_batches == 0:
            return float("nan")

        model.eval()
        test_loss = 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(self.device), y.to(self.device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
        return test_loss / num_batches