#!/usr/bin/python

"""
Returns a bed-like translation of a CDS in which each record corresponds to
a single site in the CDS and includes additional fields for site degeneracy,
position in CDS, and amino acid encoded.

usage: %prog nibdir genefile [options]
    -o, --outfile=o:      output file
    -f, --format=f:       format bed (default), or gtf|gff
    -a, --allpositions: 1st, 2nd and 3rd positions are evaluated for degeneracy given the sequence at the other two positions.  Many 1d sites in 1st codon positions become 2d sites when considered this way.
    -n, --include_name: include the 'name' or 'id' field from the source file on every line of output
"""

import os
import re
import sys

from bx.cookbook import doc_optparse
from bx.gene_reader import CDSReader
from bx.seq import nib

# DNA codon table, one codon per line: the three-base codon followed by the
# amino acid written as "(Abbrev/Letter)Name", or, for stop codons, a name
# followed by "(Stop)".  Some entries carry a trailing ", Start" annotation.
# This text is parsed into the GEN_CODE lookup when the module is loaded, so
# its layout (field order and separators) must be kept intact.
GENETIC_CODE = """
TTT (Phe/F)Phenylalanine
TTC (Phe/F)Phenylalanine
TTA (Leu/L)Leucine
TTG (Leu/L)Leucine, Start
TCT (Ser/S)Serine
TCC (Ser/S)Serine
TCA (Ser/S)Serine
TCG (Ser/S)Serine
TAT (Tyr/Y)Tyrosine
TAC (Tyr/Y)Tyrosine
TAA Ochre (Stop)
TAG Amber (Stop)
TGT (Cys/C)Cysteine
TGC (Cys/C)Cysteine
TGA Opal (Stop)
TGG (Trp/W)Tryptophan
CTT (Leu/L)Leucine
CTC (Leu/L)Leucine
CTA (Leu/L)Leucine
CTG (Leu/L)Leucine, Start
CCT (Pro/P)Proline
CCC (Pro/P)Proline
CCA (Pro/P)Proline
CCG (Pro/P)Proline
CAT (His/H)Histidine
CAC (His/H)Histidine
CAA (Gln/Q)Glutamine
CAG (Gln/Q)Glutamine
CGT (Arg/R)Arginine
CGC (Arg/R)Arginine
CGA (Arg/R)Arginine
CGG (Arg/R)Arginine
ATT (Ile/I)Isoleucine, Start2
ATC (Ile/I)Isoleucine
ATA (Ile/I)Isoleucine
ATG (Met/M)Methionine, Start1
ACT (Thr/T)Threonine
ACC (Thr/T)Threonine
ACA (Thr/T)Threonine
ACG (Thr/T)Threonine
AAT (Asn/N)Asparagine
AAC (Asn/N)Asparagine
AAA (Lys/K)Lysine
AAG (Lys/K)Lysine
AGT (Ser/S)Serine
AGC (Ser/S)Serine
AGA (Arg/R)Arginine
AGG (Arg/R)Arginine
GTT (Val/V)Valine
GTC (Val/V)Valine
GTA (Val/V)Valine
GTG (Val/V)Valine, Start2
GCT (Ala/A)Alanine
GCC (Ala/A)Alanine
GCA (Ala/A)Alanine
GCG (Ala/A)Alanine
GAT (Asp/D)Aspartic acid
GAC (Asp/D)Aspartic acid
GAA (Glu/E)Glutamic acid
GAG (Glu/E)Glutamic acid
GGT (Gly/G)Glycine
GGC (Gly/G)Glycine
GGA (Gly/G)Glycine
GGG (Gly/G)Glycine
"""


def translate(codon, genetic_code):
    """Return the entry of *genetic_code* addressed by a three-base codon.

    *codon* is any iterable of exactly three bases (a string or a tuple);
    *genetic_code* is a three-level mapping keyed by first, second and third
    base in turn.  A base missing at any level raises ``KeyError``; a codon
    that does not unpack into three items raises ``ValueError``.
    """
    first, second, third = codon
    by_second_base = genetic_code[first]
    by_third_base = by_second_base[second]
    return by_third_base[third]


# Build GEN_CODE, a three-level lookup (first base -> second base -> third
# base -> amino acid), from the text table held in GENETIC_CODE.
GEN_CODE = {}
for line in GENETIC_CODE.split("\n"):
    if line.strip():
        # Splitting on whitespace, parentheses and "/" leaves the codon in
        # field 0 and the one-letter amino-acid code (or "Stop") in field 3.
        f = re.split(r"\s|\(|\)|\/", line)
        codon = f[0]
        c1, c2, c3 = codon
        aminoacid = f[3]
        GEN_CODE.setdefault(c1, {}).setdefault(c2, {})[c3] = aminoacid


def getnib(nibdir):
    """Open every ``.nib`` file in *nibdir* and index it by sequence name.

    Args:
        nibdir: Path of the directory to scan (not searched recursively).

    Returns:
        A dict mapping each file name, minus its trailing ``.nib`` suffix, to
        a ``nib.NibFile`` built on the opened file.  Directory entries that do
        not end in ``.nib`` are ignored, so the result may be empty.

    Raises:
        OSError: if the directory cannot be listed or a file cannot be opened.
    """
    suffix = ".nib"
    seqs = {}
    for entry in os.listdir(nibdir):
        if not entry.endswith(suffix):
            continue
        # Strip only the trailing suffix; str.replace would also delete any
        # ".nib" that happens to occur earlier in the file name.
        seq_name = entry[: -len(suffix)]
        # Nib data is binary, so the file has to be opened in "rb" mode;
        # a text-mode handle would yield str (or fail decoding) on Python 3.
        # The handle is intentionally left open: it is handed to NibFile,
        # which is queried later by the caller (presumably reading on demand).
        seqs[seq_name] = nib.NibFile(open(os.path.join(nibdir, entry), "rb"))

    return seqs


# Translation table that swaps each DNA base with its complement
# (A<->T, C<->G), keeping upper/lower case; used by revComp and Comp.
REVMAP = str.maketrans("ACGTacgt", "TGCAtgca")


def revComp(seq):
    """Return the reverse complement of the DNA string *seq*.

    Bases are complemented through REVMAP and the order is reversed;
    characters that REVMAP does not cover are carried over unchanged.
    """
    complemented = seq.translate(REVMAP)
    return complemented[::-1]


def Comp(seq):
    """Return the complement of the DNA string *seq* without reversing it.

    Each base is mapped through REVMAP; characters that REVMAP does not
    cover are carried over unchanged.
    """
    complemented = seq.translate(REVMAP)
    return complemented


def _write_cds_sites(out, chrom, strand, cds_seq, genome_seq_index, name, allpositions, include_name):
    """Write one record per base for every complete codon of a CDS.

    Args:
        out: Writable text stream receiving the records.
        chrom: Chromosome name, echoed as the first field of each record.
        strand: "+" for the forward strand; any other value is handled as
            the reverse strand.
        cds_seq: Upper-case genomic sequence of the concatenated CDS exons.
        genome_seq_index: Genome coordinate of each position of *cds_seq*.
        name: Gene name; only used when *include_name* is true.
        allpositions: If true, compute degeneracy for codon positions 1 and
            2 as well; otherwise those positions are reported as "1d".
        include_name: If true, append *name* (spaces replaced by "_") to
            every record; otherwise the last field is empty.

    Each record is: chrom, start, start + 1, base, degeneracy, amino acid,
    name.  Records of a codon are written in ascending index order of
    *cds_seq* on both strands.
    """
    bases = ("A", "C", "G", "T")
    forward = strand == "+"
    if forward:
        frsts = range(0, len(cds_seq), 3)
        offsign = 1
    else:
        # The sequence is complemented but NOT reversed, so on this strand
        # every codon is read right-to-left: its first base is the last
        # index of each triplet.
        cds_seq = Comp(cds_seq)
        frsts = range(2, len(cds_seq), 3)
        offsign = -1

    for first_pos in frsts:
        c1 = first_pos
        c2 = first_pos + offsign
        c3 = first_pos + 2 * offsign
        # Explicit bounds check rather than `assert`: assertions are removed
        # under `python -O`, which would turn a trailing partial codon into
        # an IndexError instead of a warning.
        if c3 >= len(cds_seq):
            print("out of sequence at %d for %s, %d" % (c3, chrom, genome_seq_index[first_pos]), file=sys.stderr)
            continue

        # NOTE(review): GEN_CODE only has entries for A/C/G/T, so a codon
        # containing any other character (e.g. N) raises KeyError here.
        codon = cds_seq[c1], cds_seq[c2], cds_seq[c3]
        aa = translate(codon, GEN_CODE)
        # Number of codons sharing the first two bases that encode `aa`.
        degeneracy3 = str(list(GEN_CODE[codon[0]][codon[1]].values()).count(aa)) + "d"

        name_text = name.replace(" ", "_") if include_name else ""

        if allpositions:
            # Same count, varying the first (resp. second) base while the
            # other two bases are held fixed.
            try:
                degeneracy1 = str([GEN_CODE[k][codon[1]][codon[2]] for k in bases].count(aa)) + "d"
                degeneracy2 = str([GEN_CODE[codon[0]][k][codon[2]] for k in bases].count(aa)) + "d"
            except TypeError as s:
                print(list(GEN_CODE.values()), file=sys.stderr)
                raise TypeError(s)
        else:
            degeneracy1 = degeneracy2 = "1d"

        sites = [(c1, degeneracy1), (c2, degeneracy2), (c3, degeneracy3)]
        if not forward:
            # c3 < c2 < c1 on this strand; emit in ascending index order.
            sites.reverse()
        for pos, degeneracy in sites:
            print(
                chrom,
                genome_seq_index[pos],
                genome_seq_index[pos] + 1,
                cds_seq[pos],
                degeneracy,
                aa,
                name_text,
                file=out,
            )


def main():
    """Command-line entry point: report the degeneracy of every CDS site.

    Options and the two positional arguments (nib directory, gene file) are
    parsed from the module docstring by ``doc_optparse``.  For each CDS
    yielded by ``CDSReader`` the exon sequences are fetched from the nib
    files, concatenated, and one record per base is written to the output
    file, or to standard output when no output file was given.

    Any problem while reading the options or arguments is routed to
    ``doc_optparse.exit()``.
    """
    options, args = doc_optparse.parse(__doc__)
    try:
        nibdir = args[0]
        bedfile = args[1]
        gene_format = options.format or "bed"
        allpositions = bool(options.allpositions)
        include_name = bool(options.include_name)
        # Opened last so that a usage error cannot leave a truncated output
        # file or a dangling handle behind.
        out = open(options.outfile, "w") if options.outfile else sys.stdout
    except Exception:
        doc_optparse.exit()

    try:
        nibs = getnib(nibdir)
        with open(bedfile) as gene_file:
            for chrom, strand, cds_exons, name in CDSReader(gene_file, format=gene_format):
                pieces = []
                # genome_seq_index maps the position in CDS to position on the genome
                genome_seq_index = []
                for c_start, c_end in cds_exons:
                    pieces.append(nibs[chrom].get(c_start, c_end - c_start))
                    genome_seq_index.extend(range(c_start, c_end))
                cds_seq = "".join(pieces).upper()

                _write_cds_sites(
                    out, chrom, strand, cds_seq, genome_seq_index, name, allpositions, include_name
                )
    finally:
        # Close only a file opened here; sys.stdout is left open.
        if out is not sys.stdout:
            out.close()


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
