biyelunwen/99.scripts/trinity_utils/util/misc/TPM_weighted_gene_length.py

158 lines
4.7 KiB
Python

#!/usr/bin/env python
# encoding: utf-8
from __future__ import (absolute_import, division,
print_function, unicode_literals)
import os, sys, re
import logging
import argparse
import collections
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
logger = logging.getLogger(__file__)
def main():
parser = argparse.ArgumentParser(description="estimates gene length as isoform lengths weighted by TPM expression values")
parser.add_argument("--gene_trans_map", dest="gene_trans_map_file", type=str, default="",
required=True, help="gene-to-transcript mapping file, format: gene_id(tab)transcript_id")
parser.add_argument("--trans_lengths", dest="trans_lengths_file", type=str, required=True,
help="transcript length file, format: trans_id(tab)length")
parser.add_argument("--TPM_matrix", dest="TPM_matrix_file", type=str, default="",
required=True, help="isoform TPM expression matrix")
parser.add_argument("--debug", required=False, action="store_true", default=False, help="debug mode")
args = parser.parse_args()
if args.debug:
logger.setLevel(logging.DEBUG)
trans_to_gene_id_dict = parse_gene_trans_map(args.gene_trans_map_file)
trans_lengths_dict = parse_trans_lengths_file(args.trans_lengths_file)
trans_to_TPM_vals_dict = parse_TPM_matrix(args.TPM_matrix_file)
weighted_gene_lengths = compute_weighted_gene_lengths(trans_to_gene_id_dict,
trans_lengths_dict,
trans_to_TPM_vals_dict)
print("#gene_id\tlength")
for gene_id,length in weighted_gene_lengths.items():
print("\t".join([gene_id,str(length)]))
sys.exit(0)
def compute_weighted_gene_lengths(trans_to_gene_id_dict, trans_lengths_dict, trans_to_TPM_vals_dict):
gene_id_to_trans_list = collections.defaultdict(list)
gene_id_to_length = {}
pseudocount = 1
for trans_id,gene_id in trans_to_gene_id_dict.items():
gene_id_to_trans_list[gene_id].append(trans_id)
for gene_id,trans_list in gene_id_to_trans_list.items():
if len(trans_list) == 1:
gene_id_to_length[gene_id] = trans_lengths_dict[ trans_list[0] ]
else:
sum_length_x_expr = 0
sum_expr = 0
trans_expr_lengths = []
for trans_id in trans_list:
trans_len = trans_lengths_dict[trans_id]
expr_vals = trans_to_TPM_vals_dict[trans_id]
trans_sum_expr = sum(expr_vals) + pseudocount
trans_expr_lengths.append((trans_len, trans_sum_expr))
sum_length_x_expr += trans_sum_expr * trans_len
sum_expr += trans_sum_expr
weighted_gene_length = sum_length_x_expr / sum_expr
gene_id_to_length[gene_id] = int(round(weighted_gene_length))
logger.debug("Computing weighted length of {0}: {1} => {2}".format(gene_id,
trans_expr_lengths,
weighted_gene_length))
return gene_id_to_length
def parse_TPM_matrix(TPM_matrix_file):
trans_to_TPM_vals_dict = {}
with open(TPM_matrix_file) as f:
header = next(f)
for line in f:
line = line.rstrip()
vals = line.split("\t")
trans_id = vals[0]
expr_vals_list = vals[1:]
expr_vals_list = [float(x) for x in expr_vals_list]
trans_to_TPM_vals_dict[trans_id] = expr_vals_list
return trans_to_TPM_vals_dict
def parse_trans_lengths_file(trans_lengths_file):
trans_id_to_length = {}
with open(trans_lengths_file) as f:
for line in f:
line = line.rstrip()
if line[0] == '#':
continue
(trans_id, length) = line.split("\t")
if re.match("^\d+$", length):
trans_id_to_length[trans_id] = int(length)
else:
print("Warning - ignoring line: [{0}] since not parsing length value as number".format(line), file=sys.stderr)
return trans_id_to_length
def parse_gene_trans_map(gene_trans_map_file):
trans_to_gene_id = {}
with open(gene_trans_map_file) as f:
for line in f:
line = line.rstrip()
(gene_id, trans_id) = line.split("\t")
trans_to_gene_id[trans_id] = gene_id;
return trans_to_gene_id
####################
if __name__ == "__main__":
main()