# Source code for snaf.gtex

#!/data/salomonis2/LabFiles/Frank-Li/refactor/neo_env/bin/python3.7

import numpy as np
import pandas as pd
import os
import sys
import pickle
import h5py
import matplotlib.pyplot as plt
import anndata as ad
from scipy.optimize import minimize, minimize_scalar
from scipy import stats
from scipy.sparse import csr_matrix, find
from tqdm import tqdm
import re

try:
    import pymc3 as pm   # conda install -c conda-forge pymc3 mkl-service
    import theano
    import arviz as az
except ImportError:
    print('''
        Optional package pymc3 is not installed, it is for calculating tumor specificity using hirerarchical bayesian model
        For Linux: https://github.com/pymc-devs/pymc/wiki/Installation-Guide-(Linux)
        For MacOS: https://github.com/pymc-devs/pymc/wiki/Installation-Guide-(MacOS)
        For PC:    https://github.com/pymc-devs/pymc/wiki/Installation-Guide-(Windows)
    ''')

'''
this module queries the tumor specificity of each junction against GTEx (and optional added control cohorts)
'''


def _control_to_dataframe(id_,control,tested_junctions,known_samples):
    '''
    Deduplicate and subset one added control cohort, and label each of its samples with a tissue.

    :param id_: key of the cohort in add_control, used as the tissue label when none is provided
    :param control: DataFrame (junctions as index, samples as columns) or AnnData (junctions as obs, samples as var)
    :param tested_junctions: set of junction uids to keep
    :param known_samples: collection of sample ids already present in the database
    :return: (frame, tissue_dict_right) where frame is a junction x sample DataFrame and
             tissue_dict_right maps each sample id of the cohort to a tissue label
    :raises ValueError: if any sample id of the cohort is already in the database
    :raises Exception: if control is neither a DataFrame nor an AnnData
    '''
    if isinstance(control,pd.DataFrame):
        sample_ids = control.columns
    elif isinstance(control,ad.AnnData):
        sample_ids = control.var_names
    else:
        raise Exception('control must be either in dataframe or anndata format')
    # sample id can not be ambiguous; an explicit raise (not assert) so the check also runs under `python -O`
    if len(set(sample_ids).intersection(known_samples)) > 0:
        raise ValueError('sample ids of control cohort {} are already present in the database'.format(id_))

    if isinstance(control,pd.DataFrame):
        control = control.loc[np.logical_not(control.index.duplicated()),:]
        control = control.loc[list(set(control.index).intersection(tested_junctions)),:]
        print('Adding cohort {} with shape {} to the database'.format(id_,control.shape))
        tissue_dict_right = {k:id_ for k in control.columns}
        frame = control
    else:
        control = control[np.logical_not(control.obs_names.duplicated()),:]
        control = control[list(set(control.obs_names).intersection(tested_junctions)),:]
        print('Adding cohort {} with shape {} to the database'.format(id_,control.shape))
        if 'tissue' in control.var.columns:   # if tissue is in var columns, it will be used
            tissue_dict_right = control.var['tissue'].to_dict()
        else:
            tissue_dict_right = {k:id_ for k in control.var_names}
        frame = control.to_df()
    return frame,tissue_dict_right


def gtex_configuration(df,gtex_db,t_min_arg,n_max_arg,normal_cutoff_arg,tumor_cutoff_arg,normal_prevalance_cutoff_arg,tumor_prevalance_cutoff_arg,add_control=None):
    '''
    Load the GTEx database, merge optional control cohorts into it, and store the filtering cutoffs.

    Sets module-level globals used by the other functions of this module:
    ``adata_gtex`` (GTEx only, restricted to the junctions in df), ``adata`` (GTEx plus every added control),
    ``t_min``, ``n_max``, ``normal_cutoff``, ``tumor_cutoff``, ``normal_prevalance_cutoff`` and
    ``tumor_prevalance_cutoff``.

    :param df: DataFrame whose index holds the junction uids to be tested
    :param gtex_db: string, path to the GTEx h5ad file
    :param t_min_arg, n_max_arg, normal_cutoff_arg, tumor_cutoff_arg, normal_prevalance_cutoff_arg, tumor_prevalance_cutoff_arg: cutoffs stored as the globals of the same name without ``_arg``
    :param add_control: None or dict mapping a cohort id to a DataFrame (junctions as index, samples as columns)
                        or an AnnData (junctions as obs, samples as var)
    :return adata: the combined AnnData. After each merge, var 'tissue', obs 'mean' and var 'total_count'
                   (per-sample count sum / 1e6) are recomputed; without controls the annotations are
                   whatever the h5ad file provides
    :raises ValueError: if a control cohort reuses a sample id already in the database
    '''
    global adata_gtex
    global adata
    global t_min
    global n_max
    global normal_cutoff
    global tumor_cutoff
    global normal_prevalance_cutoff
    global tumor_prevalance_cutoff
    tested_junctions = set(df.index)
    adata = ad.read_h5ad(gtex_db)
    adata = adata[np.logical_not(adata.obs_names.duplicated()),:]
    adata = adata[list(set(adata.obs_names).intersection(tested_junctions)),:]
    print('Current loaded gtex cohort with shape {}'.format(adata.shape))
    tissue_dict = adata.var['tissue'].to_dict()
    adata_gtex = adata   # already has mean and tissue variables
    if add_control is not None:
        for id_, control in add_control.items():
            df_right,tissue_dict_right = _control_to_dataframe(id_,control,tested_junctions,tissue_dict)
            tissue_dict.update(tissue_dict_right)
            df_combine = adata.to_df().join(other=df_right,how='outer').fillna(0)
            adata = ad.AnnData(X=csr_matrix(df_combine.values),obs=pd.DataFrame(index=df_combine.index),var=pd.DataFrame(index=df_combine.columns))
            print('now the shape of control db is {}'.format(adata.shape))
            # the rebuilt AnnData has empty obs/var, so the annotations are derived again
            adata.var['tissue'] = adata.var_names.map(tissue_dict).values
            adata.obs['mean'] = np.array(adata.X.mean(axis=1)).squeeze()
            adata.var['total_count'] = np.array(adata.X.sum(axis=0)).squeeze() / 1e6

    t_min = t_min_arg
    n_max = n_max_arg
    normal_cutoff = normal_cutoff_arg
    tumor_cutoff = tumor_cutoff_arg
    normal_prevalance_cutoff = normal_prevalance_cutoff_arg
    tumor_prevalance_cutoff = tumor_prevalance_cutoff_arg

    return adata


def multiple_crude_sifting(junction_count_matrix,add_control,dict_exonlist,outdir,filter_mode):
    '''
    Dispatch to the prevalence-based or the max-min based junction filter.

    :param junction_count_matrix: DataFrame, junction uids as index and tumor samples as columns
    :param add_control: None or dict of additional control cohorts
    :param dict_exonlist: None or dict used for the Ensembl filter
    :param outdir: string, folder where the statistics table is written
    :param filter_mode: string, either 'prevalance' or 'maxmin'
    :return: (valid, invalid, cond_df) as returned by the selected filter
    :raises ValueError: if filter_mode is not one of the supported modes
    '''
    if filter_mode == 'prevalance':
        return multiple_crude_sifting_prevalance(junction_count_matrix,add_control,dict_exonlist,outdir)
    if filter_mode == 'maxmin':
        return multiple_crude_sifting_maxmin(junction_count_matrix,add_control,dict_exonlist,outdir)
    # previously fell through to an UnboundLocalError on `valid`
    raise ValueError("filter_mode must be 'prevalance' or 'maxmin', got {!r}".format(filter_mode))

def multiple_crude_sifting_prevalance(junction_count_matrix,add_control=None,dict_exonlist=None,outdir='.'):
    '''
    Keep junctions that are prevalent in tumor but rare in GTEx (and in every added control cohort).

    Reads module-level globals set by gtex_configuration: adata_gtex, tumor_cutoff, normal_cutoff,
    tumor_prevalance_cutoff and normal_prevalance_cutoff.

    :param junction_count_matrix: DataFrame, junction uids as index and tumor samples as columns
    :param add_control: None or dict mapping a cohort id to a DataFrame or AnnData of control counts
    :param dict_exonlist: None or dict mapping a gene id to a list of exon strings; when given, a junction
                          whose two exons appear adjacent in the '|'-joined list is dropped from valid
    :param outdir: string, folder for NeoJunction_statistics_prevalance.txt (created if missing)
    :return valid: list of junction uids passing every filter
    :return invalid: list of the remaining junction uids
    :return cond_df: boolean DataFrame shaped like junction_count_matrix, True where the junction is
                     valid and that sample's count exceeds tumor_cutoff
    :raises Exception: if a control is neither a DataFrame nor an AnnData
    '''
    os.makedirs(outdir,exist_ok=True)

    df_to_write = []
    df = pd.DataFrame(index=junction_count_matrix.index)
    # fraction of tumor samples whose count exceeds tumor_cutoff
    df['prevalance_tumor'] = np.count_nonzero((junction_count_matrix > tumor_cutoff).values,axis=1) / junction_count_matrix.shape[1]
    # consider gtex; summing the boolean matrix works for sparse and dense X without densifying it
    prevalance_normal = np.asarray((adata_gtex.X > normal_cutoff).sum(axis=1)).ravel() / adata_gtex.shape[1]
    df['prevalance_normal'] = df.index.map(dict(zip(adata_gtex.obs_names,prevalance_normal))).fillna(value=0)   # absent from GTEx -> 0
    df['cond'] = (df['prevalance_tumor'] > tumor_prevalance_cutoff) & (df['prevalance_normal'] < normal_prevalance_cutoff)
    valid = df.loc[df['cond']].index.tolist()
    df_to_write.append(df.copy())
    print('reduce valid NeoJunction from {} to {} because they are present in GTEx'.format(df.shape[0],len(valid)))
    if dict_exonlist is not None:   # a valid junction can not be present in any ensembl documented transcript
        updated_valid = []
        for uid in tqdm(valid):
            ensg,_,exons = uid.partition(':')
            if '_' in exons or 'U' in exons or 'ENSG' in exons or 'I' in exons:
                updated_valid.append(uid)   # these exon strings skip the Ensembl lookup and are kept
                continue
            exonlist = dict_exonlist[ensg]
            e1,e2 = exons.split('-')
            # '|'-padded substring test instead of a regex: '.' in exon ids is matched literally, and an
            # adjacent pair at the start, at the end, or spanning the whole string is found alike.
            # NOTE(review): if each exonlist entry is a whole transcript, joining entries with '|' also
            # pairs the last exon of one entry with the first exon of the next -- confirm intended.
            if '|{}|{}|'.format(e1,e2) in '|{}|'.format('|'.join(exonlist)):
                continue
            updated_valid.append(uid)
        print('reduce valid Neojunction from {} to {} because they are present in Ensembl db'.format(len(valid),len(updated_valid)))
        valid = updated_valid
    # consider add_control
    if add_control is not None:
        for id_,control in add_control.items():
            n_previous_valid = len(valid)
            if isinstance(control,pd.DataFrame):
                prevalance_normal = np.count_nonzero((control > normal_cutoff).values,axis=1) / control.shape[1]
                prevalance_normal_dict = dict(zip(control.index,prevalance_normal))
            elif isinstance(control,ad.AnnData):
                prevalance_normal = np.asarray((control.X > normal_cutoff).sum(axis=1)).ravel() / control.shape[1]
                prevalance_normal_dict = dict(zip(control.obs_names,prevalance_normal))
            else:
                raise Exception('control must be either in dataframe or anndata format')
            df['prevalance_normal_add'] = df.index.map(prevalance_normal_dict).fillna(value=0)
            df['cond_add'] = (df['prevalance_tumor'] > tumor_prevalance_cutoff) & (df['prevalance_normal_add'] < normal_prevalance_cutoff)
            valid = list(set(valid).intersection(df.loc[df['cond_add']].index))
            tmp = df.drop(columns=['prevalance_tumor','prevalance_normal','cond']).rename(columns=lambda x:x+'_{}'.format(id_))
            df_to_write.append(tmp)
            print('reduce valid Neojunction from {} to {} because they are present in added control {}'.format(n_previous_valid,len(valid),id_))
    invalid = list(set(junction_count_matrix.index).difference(valid))
    # now, consider each entry: valid junction AND this sample's count above tumor_cutoff
    is_valid = junction_count_matrix.index.isin(set(valid))
    above_cutoff = (junction_count_matrix > tumor_cutoff).values
    cond_df = pd.DataFrame(is_valid[:,None] & above_cutoff,index=junction_count_matrix.index,columns=junction_count_matrix.columns)
    # write the df
    df_to_write = pd.concat(df_to_write,axis=1)
    df_to_write.to_csv(os.path.join(outdir,'NeoJunction_statistics_prevalance.txt'),sep='\t')
    return valid,invalid,cond_df



def multiple_crude_sifting_maxmin(junction_count_matrix,add_control=None,dict_exonlist=None,outdir='.'):   # used by the JunctionCountMatrixQuery class
    '''
    Keep junctions whose tumor maximum clearly exceeds their mean in GTEx (and in every added control cohort).

    A junction passes a reference cohort when the cohort mean is below n_max and the junction's maximum
    count across tumor samples exceeds that mean by more than t_min. Reads module-level globals set by
    gtex_configuration: adata_gtex, n_max and t_min.

    :param junction_count_matrix: DataFrame, junction uids as index and tumor samples as columns
    :param add_control: None or dict mapping a cohort id to a DataFrame or AnnData of control counts
    :param dict_exonlist: None or dict mapping a gene id to a list of exon strings; when given, a junction
                          whose two exons appear adjacent in the '|'-joined list is dropped from valid
    :param outdir: string, folder for NeoJunction_statistics_maxmin.txt (created if missing)
    :return valid: list of junction uids passing every filter
    :return invalid: list of the remaining junction uids
    :return cond_df: boolean DataFrame shaped like junction_count_matrix, True where every reference mean
                     (GTEx and each control) is below n_max and that entry's count exceeds it by more than t_min
    :raises Exception: if a control is neither a DataFrame nor an AnnData
    '''
    os.makedirs(outdir,exist_ok=True)

    df = pd.DataFrame(index=junction_count_matrix.index,data={'max':junction_count_matrix.max(axis=1).values})
    df_to_write = []
    # consider gtex; junctions absent from GTEx get a mean of 0
    junction_to_mean = adata_gtex.obs.loc[adata_gtex.obs_names.isin(junction_count_matrix.index),'mean'].to_dict()
    df['mean'] = df.index.map(junction_to_mean).fillna(value=0)
    df['diff'] = df['max'] - df['mean']
    df['cond'] = (df['mean'] < n_max) & (df['diff'] > t_min)
    valid = df.loc[df['cond']].index.tolist()
    df_to_write.append(df.copy())
    print('reduce valid NeoJunction from {} to {} because they are present in GTEx'.format(df.shape[0],len(valid)))
    if dict_exonlist is not None:   # a valid junction can not be present in any ensembl documented transcript
        updated_valid = []
        for uid in tqdm(valid):
            ensg,_,exons = uid.partition(':')
            if '_' in exons or 'U' in exons or 'ENSG' in exons or 'I' in exons:
                updated_valid.append(uid)   # these exon strings skip the Ensembl lookup and are kept
                continue
            exonlist = dict_exonlist[ensg]
            e1,e2 = exons.split('-')
            # '|'-padded substring test instead of a regex: '.' in exon ids is matched literally, and an
            # adjacent pair at the start, at the end, or spanning the whole string is found alike.
            # NOTE(review): if each exonlist entry is a whole transcript, joining entries with '|' also
            # pairs the last exon of one entry with the first exon of the next -- confirm intended.
            if '|{}|{}|'.format(e1,e2) in '|{}|'.format('|'.join(exonlist)):
                continue
            updated_valid.append(uid)
        print('reduce valid Neojunction from {} to {} because they are present in Ensembl db'.format(len(valid),len(updated_valid)))
        valid = updated_valid
    # consider add_control; reference_means holds GTEx first, then one array per control, aligned to df.index
    reference_means = [np.array(df['mean'].values,dtype=float)]
    if add_control is not None:
        for id_,control in add_control.items():
            n_previous_valid = len(valid)
            if isinstance(control,pd.DataFrame):
                junction_to_mean = control.mean(axis=1).to_dict()
            elif isinstance(control,ad.AnnData):
                junction_to_mean = control.to_df().mean(axis=1).to_dict()
            else:
                raise Exception('control must be either in dataframe or anndata format')
            df['mean_add'] = df.index.map(junction_to_mean).fillna(value=0)
            df['diff_add'] = df['max'] - df['mean_add']
            df['cond_add'] = (df['mean_add'] < n_max) & (df['diff_add'] > t_min)
            # independent copy: the df['mean_add'] column is overwritten by the next control
            reference_means.append(np.array(df['mean_add'].values,dtype=float))
            valid = list(set(valid).intersection(df.loc[df['cond_add']].index))
            tmp = df.drop(columns=['mean','diff','cond']).rename(columns=lambda x:x+'_{}'.format(id_))
            df_to_write.append(tmp)
            print('reduce valid Neojunction from {} to {} because they are present in added control {}'.format(n_previous_valid,len(valid),id_))
    invalid = list(set(junction_count_matrix.index).difference(valid))
    # now, consider each entry against every reference mean (broadcast per row instead of building copies)
    # NOTE(review): unlike the prevalance mode, cond_df is not restricted to `valid` here, so junctions
    # removed by the Ensembl filter can still be True -- confirm this is intended.
    counts = junction_count_matrix.values.astype(float)
    entry_cond = np.ones(counts.shape,dtype=bool)
    for means in reference_means:
        entry_cond &= (means < n_max)[:,None] & ((counts - means[:,None]) > t_min)
    cond_df = pd.DataFrame(entry_cond,index=junction_count_matrix.index,columns=junction_count_matrix.columns)
    # write the df
    df_to_write = pd.concat(df_to_write,axis=1)
    df_to_write.to_csv(os.path.join(outdir,'NeoJunction_statistics_maxmin.txt'),sep='\t')
    return valid,invalid,cond_df


def crude_tumor_specificity(uid,count):    # for NeoJunction class, since we normally start from Jcmq with check_gtex=False, rarely being called.
    '''
    Quick max-min check of a single junction count against the loaded database.

    Reads module-level globals set by gtex_configuration: adata, n_max and t_min.

    :param uid: string, junction uid looked up in adata.obs_names
    :param count: numeric, the junction count to test
    :return identity: bool, True when the database mean is below n_max and count - mean >= t_min
    :return mean_value: the junction's 'mean' in adata.obs, or 0 when the junction is absent
    '''
    # Index membership is a hash lookup; the old code rebuilt a set of every obs name on each call
    mean_value = adata.obs.loc[uid,'mean'] if uid in adata.obs_names else 0
    diff = count - mean_value
    # NOTE(review): multiple_crude_sifting_maxmin uses a strict `> t_min`; this check uses `>=` -- confirm which is intended
    identity = bool(mean_value < n_max and diff >= t_min)
    return identity,mean_value


def mle_func(parameters,y):
    '''
    Objective for minimize_scalar: negative log-likelihood of y under a half-normal distribution
    with location 0 and scale ``parameters``.

    :param parameters: float, candidate scale (sigma) of the half-normal
    :param y: array-like of observations
    :return: float, the summed log-density of y, negated
    '''
    log_densities = stats.halfnorm.logpdf(y,0,parameters)
    return -np.sum(log_densities)

def split_df_to_chunks(df,cores=None):
    '''
    Split the rows of a DataFrame into consecutive, roughly equal chunks.

    :param df: DataFrame to split
    :param cores: int, number of chunks; default None uses multiprocessing.cpu_count()
    :return sub_dfs: list of positional row slices of df (some are empty when cores exceeds the row count)
    '''
    if cores is None:
        import multiprocessing as mp   # not imported at module level; cores=None used to raise NameError
        cores = mp.cpu_count()
    return [df.iloc[sub_index,:] for sub_index in np.array_split(np.arange(df.shape[0]),cores)]


def split_array_to_chunks(array,cores=None):
    '''
    Split a list into consecutive, roughly equal sub-lists.

    :param array: list to split (ndarray and other sequences are rejected)
    :param cores: int, number of chunks; default None uses multiprocessing.cpu_count()
    :return sub_arrays: list of lists (some are empty when cores exceeds len(array))
    :raises Exception: if array is not a list
    '''
    if not isinstance(array,list):
        raise Exception('split_array_to_chunks function works for list, not ndarray')
    if cores is None:
        import multiprocessing as mp   # not imported at module level; cores=None used to raise NameError
        cores = mp.cpu_count()
    return [[array[i] for i in sub_index] for sub_index in np.array_split(np.arange(len(array)),cores)]


def add_tumor_specificity_frequency_table(df,method='mean',remove_quote=True,cores=None):
    '''
    add tumor specificity to each neoantigen-uid in the frequency table produced by SNAF T pipeline

    :param df: DataFrame, the frequency table produced by SNAF T pipeline; it is not modified
    :param method: string, either 'mean', or 'mle', or 'bayesian'
    :param remove_quote: boolean, whether to remove the quotation or not, as one column in frequency table df is list, when loaded in memory using pandas, it will be added a quote, we can remove it
    :param cores: int, how many cpu cores to use for this computation, default None and use all the cpu the program detected ('bayesian' always runs serially)
    :return new_df: a copy of df with one added column ``tumor_specificity_<method>`` containing tumor specificity score

    Example::

        snaf.add_tumor_specificity_frequency_table(df,'mle',remove_quote=True)
    '''
    from ast import literal_eval
    import multiprocessing as mp
    # work on a copy so the caller's df is untouched (and can safely be passed again)
    new_df = df.copy()
    if remove_quote:
        new_df['samples'] = [literal_eval(item) for item in new_df['samples']]
    if cores is None:
        cores = mp.cpu_count()
    # the junction uid is the second comma-separated field of each index entry
    all_unique_junctions = list(set(item.split(',')[1] for item in new_df.index))
    if method != 'bayesian':
        sub_arrays = split_array_to_chunks(all_unique_junctions,cores=cores)
        score_dict = {}
        with mp.Pool(processes=cores) as pool:
            print('{} subprocesses have been spawned'.format(cores))
            async_results = [pool.apply_async(func=add_tumor_specificity_frequency_table_atomic_func,args=(sub_array,method)) for sub_array in sub_arrays]
            # collect inside the with-block: leaving it calls pool.terminate()
            for async_result in async_results:
                score_dict.update(async_result.get())
    else:
        # seems like bayesian doesn't work well with multiprocessing, so it runs serially
        score_dict = {uid:tumor_specificity(uid,'bayesian') for uid in tqdm(all_unique_junctions,total=len(all_unique_junctions))}
    # the bayesian branch previously appended to an undefined `col` list (NameError)
    new_df['tumor_specificity_{}'.format(method)] = [score_dict[item.split(',')[1]] for item in new_df.index]
    return new_df
def add_tumor_specificity_frequency_table_atomic_func(sub_array,method):
    '''
    Worker used by add_tumor_specificity_frequency_table: score one chunk of junction uids.

    :param sub_array: list of junction uids
    :param method: string, forwarded to tumor_specificity
    :return score_dict: dict mapping each uid in sub_array to its tumor specificity score
    '''
    return {uid:tumor_specificity(uid,method) for uid in tqdm(sub_array,total=len(sub_array))}


def tumor_specificity(uid,method,return_df=False):
    '''
    Score how specific a junction is to tumor, based on its counts in the loaded database (global ``adata``).

    A junction missing from the database is scored as if all of its counts were zero.

    :param uid: string, junction uid
    :param method: string, one of
                   'mean' -- the junction's 'mean' in adata.obs (0 if absent);
                   'mle' -- half-normal scale fitted to the counts divided by each sample's total_count,
                   searched within (0,1) (0 if the optimisation fails);
                   'bayesian' -- posterior mean of sigma from a pymc3 model (None if sampling fails)
    :param return_df: boolean, also return the per-sample DataFrame (for a successful 'bayesian' fit,
                      the arviz summary table is returned instead)
    :return sigma: the score, or (sigma, df) when return_df is True
    :raises ValueError: if method is not one of the supported values
    '''
    if method not in ('mean','mle','bayesian'):
        raise ValueError("method must be 'mean', 'mle' or 'bayesian', got {!r}".format(method))
    try:
        info = adata[[uid],:]
    except Exception:
        print('{} not detected in gtex, impute as zero'.format(uid))
        # build a fresh all-zero row (weird: anndata 0.7.6 can not modify X in place, anndata 0.7.2 can in scTriangulate)
        info = ad.AnnData(X=csr_matrix(np.full((1,adata.shape[1]),0)),obs=pd.DataFrame(data={'mean':[0]},index=[uid]),var=adata.var)
    df = pd.DataFrame(data={'value':info.X.toarray().squeeze(),'tissue':info.var['tissue'].values},index=info.var_names)

    if method == 'mean':
        try:
            sigma = adata.obs.loc[uid,'mean']
        except KeyError:
            sigma = 0
    else:
        # 'mle' and 'bayesian' both work on counts divided by each sample's total_count
        scale_factor_dict = adata.var['total_count'].to_dict()
        df['value_cpm'] = df['value'].values / df.index.map(scale_factor_dict).values
        y = df['value_cpm'].values
        if method == 'mle':
            mle_model = minimize_scalar(mle_func,bounds=(0,1),args=(y,),method='bounded')
            if mle_model.success:
                sigma = mle_model.x
            else:
                sigma = 0
                print(uid,y,mle_model)  # debug purpose
        else:
            # per tissue: number of samples with a non-zero count, rescaled to a 25-sample tissue.
            # Taken from `df` (not adata) so a junction imputed as zero works instead of raising.
            values = df['value'].values
            tissues = df['tissue'].values
            x = []
            for tissue in adata.var['tissue'].unique():
                in_tissue = np.asarray(tissues == tissue)
                n_samples = np.count_nonzero(in_tissue)
                n_expressed = np.count_nonzero(values[in_tissue])
                x.append(round(n_expressed * (25/n_samples),0))
            x = np.array(x)
            try:
                with pm.Model():
                    sigma_rv = pm.Uniform('sigma',lower=0,upper=1)
                    nc = pm.HalfNormal('nc',sigma=sigma_rv,observed=y)
                    nc_hat = pm.Deterministic('nc_hat',pm.math.sum(nc)/len(y))
                    psi = pm.Beta('psi',alpha=2,beta=nc_hat*20)
                    mu = pm.Gamma('mu',alpha=nc_hat*50,beta=1)
                    pm.ZeroInflatedPoisson('c',psi,mu,observed=x)
                    # cores=1: this pymc3 build only worked single-core ("No model on context stack" otherwise), see
                    # https://stackoverflow.com/questions/69888492/sampling-of-pymc3-in-python-gets-runtime-error-of-bootstrapping-phase
                    trace = pm.sample(draws=1000,step=pm.NUTS(target_accept=0.95),tune=1000,return_inferencedata=False,cores=1)
                df = az.summary(trace,round_to=2)
                # summary rows follow the model variables (sigma, nc_hat, mu); the first row is 'sigma'
                sigma = df.iloc[0]['mean']
            except Exception:
                sigma = None
                print(uid,x,y)
    if return_df:
        return sigma,df
    return sigma