#!/usr/bin/python3
# encoding: utf-8

from __future__ import (absolute_import, division,
                        print_function, unicode_literals)
import os, sys, re
import logging
import argparse
import collections

# Send log records to stderr (INFO level by default); main() prints its
# results table to stdout.
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
logger = logging.getLogger(__file__)
 
def main():
    """Command-line entry point.

    Parses the command-line options, loads the gene/transcript map, the
    transcript lengths and the isoform TPM matrix, computes one length per
    gene and prints a two-column, tab-delimited table (``#gene_id``,
    ``length``) to stdout. Exits the interpreter with status 0 when done.
    """
    arg_parser = argparse.ArgumentParser(
        description="estimates gene length as isoform lengths weighted by TPM expression values")

    arg_parser.add_argument(
        "--gene_trans_map", dest="gene_trans_map_file", type=str, default="", required=True,
        help="gene-to-transcript mapping file, format: gene_id(tab)transcript_id")
    arg_parser.add_argument(
        "--trans_lengths", dest="trans_lengths_file", type=str, required=True,
        help="transcript length file, format: trans_id(tab)length")
    arg_parser.add_argument(
        "--TPM_matrix", dest="TPM_matrix_file", type=str, default="", required=True,
        help="isoform TPM expression matrix")
    arg_parser.add_argument(
        "--debug", required=False, action="store_true", default=False, help="debug mode")

    opts = arg_parser.parse_args()

    # --debug lowers the logger threshold so DEBUG records are emitted.
    if opts.debug:
        logger.setLevel(logging.DEBUG)

    # Inputs are loaded in this order: gene map, lengths, TPM matrix
    # (call arguments are evaluated left to right).
    weighted_gene_lengths = compute_weighted_gene_lengths(
        parse_gene_trans_map(opts.gene_trans_map_file),
        parse_trans_lengths_file(opts.trans_lengths_file),
        parse_TPM_matrix(opts.TPM_matrix_file))

    print("#gene_id\tlength")
    for gene_id in weighted_gene_lengths:
        row = (gene_id, str(weighted_gene_lengths[gene_id]))
        print("\t".join(row))

    sys.exit(0)
    


def compute_weighted_gene_lengths(trans_to_gene_id_dict, trans_lengths_dict, trans_to_TPM_vals_dict):
    """Compute one length per gene from its isoform lengths and expression.

    Args:
        trans_to_gene_id_dict: maps transcript id -> gene id.
        trans_lengths_dict: maps transcript id -> transcript length.
        trans_to_TPM_vals_dict: maps transcript id -> list of TPM values.

    Returns:
        dict mapping gene id -> length. A gene with exactly one transcript
        gets that transcript's length unchanged (no TPM lookup is done).
        A gene with several transcripts gets the mean of their lengths
        weighted by (sum of the transcript's TPM values + 1), rounded to
        the nearest integer.

    Raises:
        KeyError: a transcript has no entry in ``trans_lengths_dict`` or,
            for multi-isoform genes, in ``trans_to_TPM_vals_dict``.
    """
    # Added to every isoform's summed expression, so the weight total is
    # never zero for non-negative TPM values.
    pseudocount = 1

    # Invert the transcript -> gene mapping into gene -> [transcripts].
    isoforms_by_gene = collections.defaultdict(list)
    for isoform_id, owner_gene in trans_to_gene_id_dict.items():
        isoforms_by_gene[owner_gene].append(isoform_id)

    lengths_by_gene = {}

    for gene_id, isoform_ids in isoforms_by_gene.items():

        # Single-isoform genes need no weighting.
        if len(isoform_ids) == 1:
            lengths_by_gene[gene_id] = trans_lengths_dict[isoform_ids[0]]
            continue

        len_expr_pairs = []   # (length, weight) per isoform, kept for the debug log
        weighted_total = 0
        weight_total = 0

        for isoform_id in isoform_ids:
            isoform_len = trans_lengths_dict[isoform_id]
            isoform_weight = sum(trans_to_TPM_vals_dict[isoform_id]) + pseudocount

            len_expr_pairs.append((isoform_len, isoform_weight))
            weighted_total += isoform_weight * isoform_len
            weight_total += isoform_weight

        weighted_gene_length = weighted_total / weight_total
        lengths_by_gene[gene_id] = int(round(weighted_gene_length))

        debug_msg = "Computing weighted length of {0}: {1} => {2}".format(
            gene_id, len_expr_pairs, weighted_gene_length)
        logger.debug(debug_msg)

    return lengths_by_gene
            


def parse_TPM_matrix(TPM_matrix_file):
    """Read an isoform TPM expression matrix from a tab-delimited file.

    The first line is treated as a header and discarded. On every other
    non-blank line the first field is the transcript id and the remaining
    fields are parsed as floats.

    Args:
        TPM_matrix_file: path to the matrix file.

    Returns:
        dict mapping transcript id -> list of float expression values.
        An empty file yields an empty dict.

    Raises:
        ValueError: an expression field cannot be converted to float.
    """
    trans_to_TPM_vals_dict = {}

    with open(TPM_matrix_file) as f:
        # Discard the header row; the default keeps an empty file from
        # raising StopIteration.
        next(f, None)

        for line in f:
            line = line.rstrip()
            if not line:
                # A blank line would otherwise register a bogus '' transcript.
                continue

            vals = line.split("\t")
            trans_to_TPM_vals_dict[vals[0]] = [float(x) for x in vals[1:]]

    return trans_to_TPM_vals_dict


def parse_trans_lengths_file(trans_lengths_file):
    """Read transcript lengths from a tab-delimited file.

    Each data line must be ``trans_id<TAB>length``. Blank lines and lines
    starting with ``#`` are skipped. A line whose length field is not made
    up solely of digits is reported on stderr and ignored.

    Args:
        trans_lengths_file: path to the transcript length file.

    Returns:
        dict mapping transcript id -> length as int.

    Raises:
        ValueError: a data line does not have exactly two tab-separated
            fields.
    """
    trans_id_to_length = {}

    with open(trans_lengths_file) as f:
        for line in f:
            line = line.rstrip()

            # Test for emptiness first: indexing line[0] on a blank line
            # would raise IndexError.
            if not line or line.startswith('#'):
                continue

            (trans_id, length) = line.split("\t")

            # Raw string: "\d" in a plain literal is an invalid escape sequence.
            if re.match(r"^\d+$", length):
                trans_id_to_length[trans_id] = int(length)
            else:
                print("Warning - ignoring line: [{0}] since not parsing length value as number".format(line), file=sys.stderr)

    return trans_id_to_length



def parse_gene_trans_map(gene_trans_map_file):
    """Read a gene-to-transcript mapping file.

    Each non-blank line must be ``gene_id<TAB>transcript_id``. Blank lines
    are skipped. If a transcript id appears more than once, the last gene
    id seen for it wins.

    Args:
        gene_trans_map_file: path to the mapping file.

    Returns:
        dict mapping transcript id -> gene id.

    Raises:
        ValueError: a non-blank line does not have exactly two
            tab-separated fields.
    """
    trans_to_gene_id = {}

    with open(gene_trans_map_file) as f:
        for line in f:
            line = line.rstrip()
            if not line:
                # A blank line (e.g. at the end of the file) would otherwise
                # break the two-field unpacking below.
                continue

            (gene_id, trans_id) = line.split("\t")
            trans_to_gene_id[trans_id] = gene_id

    return trans_to_gene_id


 
####################
 
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
