|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +import sys |
| 4 | +import time |
| 5 | +import loompy as lp |
| 6 | +import pandas as pd |
| 7 | +from multiprocessing import Pool, cpu_count |
| 8 | +import argparse |
| 9 | +import tqdm |
| 10 | + |
| 11 | +from arboreto.utils import load_tf_names |
| 12 | +from arboreto.algo import genie3, grnboost2, _prepare_input |
| 13 | +from arboreto.core import SGBM_KWARGS, RF_KWARGS, EARLY_STOP_WINDOW_LENGTH |
| 14 | +from arboreto.core import to_tf_matrix, target_gene_indices, infer_partial_network |
| 15 | + |
| 16 | +from pyscenic.cli.utils import load_exp_matrix |
| 17 | + |
| 18 | + |
| 19 | +################################################################################ |
| 20 | +################################################################################ |
| 21 | + |
| 22 | +parser_grn = argparse.ArgumentParser(description='Run Arboreto using a multiprocessing pool') |
| 23 | + |
| 24 | +parser_grn.add_argument('expression_mtx_fname', |
| 25 | + type=argparse.FileType('r'), |
| 26 | + help='The name of the file that contains the expression matrix for the single cell experiment.' |
| 27 | + ' Two file formats are supported: csv (rows=cells x columns=genes) or loom (rows=genes x columns=cells).') |
| 28 | +parser_grn.add_argument('tfs_fname', |
| 29 | + type=argparse.FileType('r'), |
| 30 | + help='The name of the file that contains the list of transcription factors (TXT; one TF per line).') |
| 31 | +parser_grn.add_argument('-m', '--method', choices=['genie3', 'grnboost2'], |
| 32 | + default='grnboost2', |
| 33 | + help='The algorithm for gene regulatory network reconstruction (default: grnboost2).') |
| 34 | +parser_grn.add_argument('-o', '--output', |
| 35 | + type=argparse.FileType('w'), default=sys.stdout, |
| 36 | + help='Output file/stream, i.e. a table of TF-target genes (TSV).') |
| 37 | +parser_grn.add_argument('--num_workers', |
| 38 | + type=int, default=cpu_count(), |
| 39 | + help='The number of workers to use. (default: {}).'.format(cpu_count())) |
| 40 | +parser_grn.add_argument('--seed', type=int, required=False, default=None, |
| 41 | + help='Seed value for regressor random state initialization (optional)') |
| 42 | + |
| 43 | +parser_grn.add_argument('--cell_id_attribute', |
| 44 | + type=str, default='CellID', |
| 45 | + help='The name of the column attribute that specifies the identifiers of the cells in the loom file.') |
| 46 | +parser_grn.add_argument('--gene_attribute', |
| 47 | + type=str, default='Gene', |
| 48 | + help='The name of the row attribute that specifies the gene symbols in the loom file.') |
| 49 | +parser_grn.add_argument('--sparse', action='store_const', const=True, default=False, |
| 50 | + help='If set, load the expression data as a sparse (CSC) matrix.') |
| 51 | +parser_grn.add_argument('-t', '--transpose', action='store_const', const = 'yes', |
| 52 | + help='Transpose the expression matrix (rows=genes x columns=cells).') |
| 53 | + |
| 54 | +args = parser_grn.parse_args() |
| 55 | + |
| 56 | + |
| 57 | +################################################################################ |
| 58 | +################################################################################ |
| 59 | + |
| 60 | + |
| 61 | +if(args.method == 'grnboost2'): |
| 62 | + method_params = [ |
| 63 | + 'GBM', # regressor_type |
| 64 | + SGBM_KWARGS # regressor_kwargs |
| 65 | + ] |
| 66 | +elif(args.method == 'genie3'): |
| 67 | + method_params = [ |
| 68 | + 'RF', # regressor_type |
| 69 | + RF_KWARGS # regressor_kwargs |
| 70 | + ] |
| 71 | + |
| 72 | + |
| 73 | +def run_infer_partial_network(target_gene_index): |
| 74 | + target_gene_name = gene_names[target_gene_index] |
| 75 | + target_gene_expression = ex_matrix[:, target_gene_index] |
| 76 | + |
| 77 | + n = infer_partial_network( |
| 78 | + regressor_type=method_params[0], |
| 79 | + regressor_kwargs=method_params[1], |
| 80 | + tf_matrix=tf_matrix, |
| 81 | + tf_matrix_gene_names=tf_matrix_gene_names, |
| 82 | + target_gene_name=target_gene_name, |
| 83 | + target_gene_expression=target_gene_expression, |
| 84 | + include_meta=False, |
| 85 | + early_stop_window_length=EARLY_STOP_WINDOW_LENGTH, |
| 86 | + seed=args.seed) |
| 87 | + return( n ) |
| 88 | + |
| 89 | + |
| 90 | +if __name__ == '__main__': |
| 91 | + |
| 92 | + start_time = time.time() |
| 93 | + ex_matrix = load_exp_matrix(args.expression_mtx_fname.name, |
| 94 | + (args.transpose == 'yes'), |
| 95 | + args.sparse, |
| 96 | + args.cell_id_attribute, |
| 97 | + args.gene_attribute) |
| 98 | + |
| 99 | + if args.sparse: |
| 100 | + gene_names = ex_matrix[1] |
| 101 | + ex_matrix = ex_matrix[0] |
| 102 | + else: |
| 103 | + gene_names = ex_matrix.columns |
| 104 | + |
| 105 | + end_time = time.time() |
| 106 | + print(f'Loaded expression matrix of {ex_matrix.shape[0]} cells and {ex_matrix.shape[1]} genes in {end_time - start_time} seconds...', file=sys.stdout) |
| 107 | + |
| 108 | + tf_names = load_tf_names(args.tfs_fname.name) |
| 109 | + print(f'Loaded {len(tf_names)} TFs...', file=sys.stdout) |
| 110 | + |
| 111 | + ex_matrix, gene_names, tf_names = _prepare_input(ex_matrix, gene_names, tf_names) |
| 112 | + tf_matrix, tf_matrix_gene_names = to_tf_matrix(ex_matrix, gene_names, tf_names) |
| 113 | + |
| 114 | + print(f'starting {args.method} using {args.num_workers} processes...', file=sys.stdout) |
| 115 | + start_time = time.time() |
| 116 | + |
| 117 | + with Pool(args.num_workers) as p: |
| 118 | + adjs = list(tqdm.tqdm(p.imap(run_infer_partial_network, |
| 119 | + target_gene_indices(gene_names, target_genes='all'), |
| 120 | + chunksize=1 |
| 121 | + ), |
| 122 | + total=len(gene_names))) |
| 123 | + |
| 124 | + adj = pd.concat(adjs).sort_values(by='importance', ascending=False) |
| 125 | + |
| 126 | + end_time = time.time() |
| 127 | + print(f'Done in {end_time - start_time} seconds.', file=sys.stdout) |
| 128 | + |
| 129 | + adj.to_csv(args.output, index=False, sep="\t") |
| 130 | + |
0 commit comments