#!/usr/bin/python

"""
Writes a BED-like translation of a CDS in which each record corresponds to
a single site in the CDS and includes additional fields for the base at that
site, the site's degeneracy, and the amino acid encoded.

usage: %prog nibdir genefile [options]
    -o, --outfile=o:      output file
    -f, --format=f:       format bed (default), or gtf|gff
    -a, --allpositions: 1st, 2nd and 3rd positions are evaluated for degeneracy given the sequence at the other two positions.  Many 1d sites in 1st codon positions become 2d sites when considered this way.
    -n, --include_name: include the 'name' or 'id' field from the source file on every line of output
"""
from __future__ import print_function

import os
import re
import string
import sys

from bx.cookbook import doc_optparse
from bx.gene_reader import CDSReader
from bx.seq import nib

# Standard genetic code, one codon per line. Amino-acid lines have the form
# "<codon> (<Abc>/<X>)<Name>[, annotation]" and stop codons the form
# "<codon> <Name> (Stop)". The parser that builds GEN_CODE splits each line
# on whitespace, parentheses and '/', and keeps only the codon and the fourth
# field (the one-letter code, or 'Stop'); trailing annotations such as
# ", Start" are ignored. Edit with care: the exact layout is load-bearing.
GENETIC_CODE = """
TTT (Phe/F)Phenylalanine
TTC (Phe/F)Phenylalanine
TTA (Leu/L)Leucine
TTG (Leu/L)Leucine, Start
TCT (Ser/S)Serine
TCC (Ser/S)Serine
TCA (Ser/S)Serine
TCG (Ser/S)Serine
TAT (Tyr/Y)Tyrosine
TAC (Tyr/Y)Tyrosine
TAA Ochre (Stop)
TAG Amber (Stop)
TGT (Cys/C)Cysteine
TGC (Cys/C)Cysteine
TGA Opal (Stop)
TGG (Trp/W)Tryptophan
CTT (Leu/L)Leucine
CTC (Leu/L)Leucine
CTA (Leu/L)Leucine
CTG (Leu/L)Leucine, Start
CCT (Pro/P)Proline
CCC (Pro/P)Proline
CCA (Pro/P)Proline
CCG (Pro/P)Proline
CAT (His/H)Histidine
CAC (His/H)Histidine
CAA (Gln/Q)Glutamine
CAG (Gln/Q)Glutamine
CGT (Arg/R)Arginine
CGC (Arg/R)Arginine
CGA (Arg/R)Arginine
CGG (Arg/R)Arginine
ATT (Ile/I)Isoleucine, Start2
ATC (Ile/I)Isoleucine
ATA (Ile/I)Isoleucine
ATG (Met/M)Methionine, Start1
ACT (Thr/T)Threonine
ACC (Thr/T)Threonine
ACA (Thr/T)Threonine
ACG (Thr/T)Threonine
AAT (Asn/N)Asparagine
AAC (Asn/N)Asparagine
AAA (Lys/K)Lysine
AAG (Lys/K)Lysine
AGT (Ser/S)Serine
AGC (Ser/S)Serine
AGA (Arg/R)Arginine
AGG (Arg/R)Arginine
GTT (Val/V)Valine
GTC (Val/V)Valine
GTA (Val/V)Valine
GTG (Val/V)Valine, Start2
GCT (Ala/A)Alanine
GCC (Ala/A)Alanine
GCA (Ala/A)Alanine
GCG (Ala/A)Alanine
GAT (Asp/D)Aspartic acid
GAC (Asp/D)Aspartic acid
GAA (Glu/E)Glutamic acid
GAG (Glu/E)Glutamic acid
GGT (Gly/G)Glycine
GGC (Gly/G)Glycine
GGA (Gly/G)Glycine
GGG (Gly/G)Glycine
"""


def translate(codon, genetic_code):
    """Return the amino acid that ``genetic_code`` assigns to ``codon``.

    ``codon`` must be a sequence of exactly three bases (a ValueError is
    raised otherwise). ``genetic_code`` is a nested mapping indexed by the
    first base, then the second, then the third; a KeyError propagates for
    any base it does not contain.
    """
    first, second, third = codon
    by_second_base = genetic_code[first]
    by_third_base = by_second_base[second]
    return by_third_base[third]


def _parse_genetic_code(text):
    """Parse a GENETIC_CODE-style table into a nested codon lookup.

    Returns a dict such that ``table[b1][b2][b3]`` is the fourth field of the
    codon's line after splitting on whitespace, '(', ')' and '/': the
    one-letter amino-acid code for lines like "TTT (Phe/F)Phenylalanine", or
    'Stop' for lines like "TAA Ochre (Stop)". Blank lines are skipped.
    """
    table = {}
    for line in text.split('\n'):
        if not line.strip():
            continue
        fields = re.split(r'\s|\(|\)|\/', line)
        c1, c2, c3 = fields[0]
        table.setdefault(c1, {}).setdefault(c2, {})[c3] = fields[3]
    return table


# Hash the genetic code: GEN_CODE[first][second][third] -> amino acid / 'Stop'.
# Parsing happens inside a function so loop temporaries do not leak into
# module globals.
GEN_CODE = _parse_genetic_code(GENETIC_CODE)


def getnib(nibdir):
    """Map chromosome name -> ``nib.NibFile`` for every ``*.nib`` in ``nibdir``.

    The chromosome name is the file name with its trailing ``.nib`` suffix
    removed; files with any other extension are ignored. Files are opened in
    binary mode because nib is a binary format (text mode breaks reading on
    Python 3). The handles are deliberately not closed here: the returned
    NibFile objects are queried later, presumably reading from these handles.
    """
    seqs = {}
    suffix = '.nib'
    for nibf in os.listdir(nibdir):
        if not nibf.endswith(suffix):
            continue
        # Strip only the trailing suffix (str.replace would drop every '.nib').
        chrom = nibf[:-len(suffix)]
        path = os.path.join(nibdir, nibf)
        seqs[chrom] = nib.NibFile(open(path, 'rb'))

    return seqs


# Complement table for upper- and lower-case DNA bases; any other character
# (e.g. 'N') passes through translate() unchanged. str.maketrans exists only
# on Python 3 and string.maketrans only on Python 2, so pick whichever exists.
try:
    _maketrans = str.maketrans  # Python 3
except AttributeError:
    _maketrans = string.maketrans  # Python 2
REVMAP = _maketrans("ACGTacgt", "TGCAtgca")


def revComp(seq):
    """Return the reverse complement of ``seq``."""
    return seq[::-1].translate(REVMAP)


def Comp(seq):
    """Return the complement of ``seq`` without reversing it."""
    return seq.translate(REVMAP)


# Bases the genetic-code table is keyed on; a codon containing any other
# character (e.g. 'N' in the nib sequence) cannot be translated.
_BASES = ('A', 'C', 'G', 'T')


def _cds_sequence(nibs, chrom, cds_exons):
    """Return ``(cds_seq, genome_seq_index)`` for one CDS.

    ``cds_seq`` is the upper-cased concatenation of the exon sequences in the
    order ``cds_exons`` lists them; ``genome_seq_index[i]`` is the genomic
    position of ``cds_seq[i]``.
    """
    pieces = []
    genome_seq_index = []
    for c_start, c_end in cds_exons:
        pieces.append(nibs[chrom].get(c_start, c_end - c_start))
        genome_seq_index.extend(range(c_start, c_end))
    return ''.join(pieces).upper(), genome_seq_index


def _codon_first_positions(length, strand):
    """Return the CDS indices of each codon's first base.

    Codons are anchored at the gene's 5' end: index 0 on the '+' strand, and
    index ``length - 1`` otherwise (the sequence stays in genomic order, so a
    minus-strand gene is read from high to low indices). When ``length`` is
    not a multiple of 3, one returned index starts a codon that runs off the
    sequence; the caller reports and skips it.
    """
    if strand == '+':
        return range(0, length, 3)
    return range((length - 1) % 3, length, 3)


def _write_gene_sites(out, nibs, chrom, strand, cds_exons, name, allpositions, include_name):
    """Write one space-separated record per site of a single CDS to ``out``.

    Fields are: chrom, genomic position, position + 1, the coding-strand base
    (complemented when ``strand`` is not '+'), the degeneracy label (e.g.
    "2d": how many of the four bases at that position, with the codon's other
    two bases fixed, encode the same amino acid), the amino acid (one-letter
    code or 'Stop'), and the name ('' unless ``include_name``). Without
    ``allpositions`` the 1st and 2nd codon positions are always labelled "1d".
    Sites are written in ascending genomic order within each codon.
    """
    cds_seq, genome_seq_index = _cds_sequence(nibs, chrom, cds_exons)
    if strand == '+':
        step = 1
    else:
        # Complement without reversing so indices still line up with
        # genome_seq_index; codons are then read from high to low index.
        cds_seq = Comp(cds_seq)
        step = -1
    name_text = name.replace(' ', '_') if include_name else ''

    for first_pos in _codon_first_positions(len(cds_seq), strand):
        c1 = first_pos
        c2 = first_pos + step
        c3 = first_pos + 2 * step
        # Explicit check instead of `assert`, which is stripped under -O.
        if not 0 <= c3 < len(cds_seq):
            print("out of sequence at %d for %s, %d" % (c3, chrom, genome_seq_index[first_pos]), file=sys.stderr)
            continue
        codon = cds_seq[c1], cds_seq[c2], cds_seq[c3]
        if not all(base in _BASES for base in codon):
            print("skipping codon with non-ACGT base at %d for %s" % (genome_seq_index[first_pos], chrom), file=sys.stderr)
            continue

        aa = translate(codon, GEN_CODE)
        first, second, third = codon
        degeneracy3 = str(list(GEN_CODE[first][second].values()).count(aa)) + "d"
        if allpositions:
            degeneracy1 = str([GEN_CODE[k][second][third] for k in _BASES].count(aa)) + "d"
            degeneracy2 = str([GEN_CODE[first][k][third] for k in _BASES].count(aa)) + "d"
        else:
            degeneracy1 = degeneracy2 = "1d"

        sites = [(c1, degeneracy1), (c2, degeneracy2), (c3, degeneracy3)]
        if strand != '+':
            sites.reverse()  # c3 < c2 < c1 here; keep genomic order ascending
        for pos, degeneracy in sites:
            print(chrom, genome_seq_index[pos], genome_seq_index[pos] + 1, cds_seq[pos], degeneracy, aa, name_text, file=out)


def main():
    """Command-line entry point; see the module docstring for usage.

    Reads CDS records from the gene file (second argument) with CDSReader,
    fetches their sequence from the ``.nib`` files in the directory given as
    the first argument, and writes per-site records to ``--outfile`` (stdout
    when not given). Option/argument errors fall back to
    ``doc_optparse.exit()``. Only a file opened here is closed here; stdout is
    left open.
    """
    options, args = doc_optparse.parse(__doc__)
    try:
        out = open(options.outfile, "w") if options.outfile else sys.stdout
        gene_format = options.format or 'bed'
        allpositions = bool(options.allpositions)
        include_name = bool(options.include_name)
        nibdir = args[0]
        bedfile = args[1]
    except Exception:
        doc_optparse.exit()

    try:
        nibs = getnib(nibdir)
        with open(bedfile) as genes:
            for chrom, strand, cds_exons, name in CDSReader(genes, format=gene_format):
                _write_gene_sites(out, nibs, chrom, strand, cds_exons, name, allpositions, include_name)
    finally:
        if out is not sys.stdout:
            out.close()


if __name__ == '__main__':
    main()
