# Copyright 2013 by Anthony Mathelier and David Arenillas. All rights reserved.
# This code is part of the Biopython distribution and governed by its
# license. Please see the LICENSE file that should have been included
# as part of this package.

"""JASPAR2014 module."""

from Bio.Seq import Seq
from Bio.Alphabet.IUPAC import unambiguous_dna as dna
import re
import math

from Bio._py3k import range

from Bio import motifs


class Motif(motifs.Motif):
    """A subclass of Bio.motifs.Motif used to represent a JASPAR profile.

    Additional metadata information are stored if available. The metadata
    availability depends on the source of the JASPAR motif (a 'pfm' format
    file, a 'jaspar' format file or a JASPAR database).
    """

    def __init__(self, matrix_id, name, alphabet=dna, instances=None,
                 counts=None, collection=None, tf_class=None, tf_family=None,
                 species=None, tax_group=None, acc=None, data_type=None,
                 medline=None, pazar_id=None, comment=None):
        """Construct a JASPAR Motif instance."""

        motifs.Motif.__init__(self, alphabet, instances, counts)
        self.name = name
        self.matrix_id = matrix_id
        self.collection = collection
        self.tf_class = tf_class
        self.tf_family = tf_family
        # May have multiple so species is a list.
        # The species are actually specified as
        # taxonomy IDs.
        self.species = species
        self.tax_group = tax_group
        self.acc = acc              # May have multiple so acc is a list.
        self.data_type = data_type
        self.medline = medline
        self.pazar_id = pazar_id
        self.comment = comment

    @property
    def base_id(self):
        """Return the JASPAR base matrix ID."""
        (base_id, __) = split_jaspar_id(self.matrix_id)
        return base_id

    @property
    def version(self):
        """Return the JASPAR matrix version."""
        (__, version) = split_jaspar_id(self.matrix_id)
        return version

    def __str__(self):
        """Return a string represention of the JASPAR profile.

        We choose to provide only the filled metadata information.
        """
        tf_name_str = "TF name\t{0}\n".format(self.name)
        matrix_id_str = "Matrix ID\t{0}\n".format(self.matrix_id)
        the_string = "".join([tf_name_str, matrix_id_str])
        if self.collection:
            collection_str = "Collection\t{0}\n".format(self.collection)
            the_string = "".join([the_string, collection_str])
        if self.tf_class:
            tf_class_str = "TF class\t{0}\n".format(self.tf_class)
            the_string = "".join([the_string, tf_class_str])
        if self.tf_family:
            tf_family_str = "TF family\t{0}\n".format(self.tf_family)
            the_string = "".join([the_string, tf_family_str])
        if self.species:
            species_str = "Species\t{0}\n".format(",".join(self.species))
            the_string = "".join([the_string, species_str])
        if self.tax_group:
            tax_group_str = "Taxonomic group\t{0}\n".format(self.tax_group)
            the_string = "".join([the_string, tax_group_str])
        if self.acc:
            acc_str = "Accession\t{0}\n".format(self.acc)
            the_string = "".join([the_string, acc_str])
        if self.data_type:
            data_type_str = "Data type used\t{0}\n".format(self.data_type)
            the_string = "".join([the_string, data_type_str])
        if self.medline:
            medline_str = "Medline\t{0}\n".format(self.medline)
            the_string = "".join([the_string, medline_str])
        if self.pazar_id:
            pazar_id_str = "PAZAR ID\t{0}\n".format(self.pazar_id)
            the_string = "".join([the_string, pazar_id_str])
        if self.comment:
            comment_str = "Comments\t{0}\n".format(self.comment)
            the_string = "".join([the_string, comment_str])
        matrix_str = "Matrix:\n{0}\n\n".format(self.counts)
        the_string = "".join([the_string, matrix_str])
        return the_string

    def __hash__(self):
        """Return the hash key corresponding to the JASPAR profile.

        :note: We assume the unicity of matrix IDs
        """
        return self.matrix_id.__hash__()

    def __eq__(self, other):
        return self.matrix_id == other.matrix_id


class Record(list):
    """Represent a list of jaspar motifs.

    Attributes:

     - version: The JASPAR version used

    """

    def __init__(self):
        self.version = None

    def __str__(self):
        return "\n".join(str(the_motif) for the_motif in self)

    def to_dict(self):
        """Return the list of matrices as a dictionary of matrices."""
        dic = {}
        for motif in self:
            dic[motif.matrix_id] = motif
        return dic


def read(handle, format):
    """Read motif(s) from a file in one of several different JASPAR formats.

    Return the record of PFM(s).
    Call the appropriate routine based on the format passed.
    """
    format = format.lower()
    if format == "pfm":
        record = _read_pfm(handle)
        return record
    elif format == "sites":
        record = _read_sites(handle)
        return record
    elif format == "jaspar":
        record = _read_jaspar(handle)
        return record
    else:
        raise ValueError("Unknown JASPAR format %s" % format)


def write(motifs, format):
    """Return the representation of motifs in "pfm" or "jaspar" format."""
    letters = "ACGT"
    lines = []
    if format == 'pfm':
        motif = motifs[0]
        counts = motif.counts
        for letter in letters:
            terms = ["{0:6.2f}".format(value) for value in counts[letter]]
            line = "{0}\n".format(" ".join(terms))
            lines.append(line)
    elif format == 'jaspar':
        for m in motifs:
            counts = m.counts
            try:
                matrix_id = m.matrix_id
            except AttributeError:
                matrix_id = None
            line = ">{0} {1}\n".format(matrix_id, m.name)
            lines.append(line)
            for letter in letters:
                terms = ["{0:6.2f}".format(value) for value in counts[letter]]
                line = "{0} [{1}]\n".format(letter, " ".join(terms))
                lines.append(line)
    else:
        raise ValueError("Unknown JASPAR format %s" % format)

    # Finished; glue the lines together
    text = "".join(lines)

    return text


def _read_pfm(handle):
    """Read the motif from a JASPAR .pfm file (PRIVATE)."""
    alphabet = dna
    counts = {}

    letters = "ACGT"
    for letter, line in zip(letters, handle):
        words = line.split()
        # if there is a letter in the beginning, ignore it
        if words[0] == letter:
            words = words[1:]
        counts[letter] = [float(x) for x in words]

    motif = Motif(matrix_id=None, name=None, alphabet=alphabet, counts=counts)
    motif.mask = "*" * motif.length
    record = Record()
    record.append(motif)

    return record


def _read_sites(handle):
    """Read the motif from JASPAR .sites file (PRIVATE)."""
    alphabet = dna
    instances = []

    for line in handle:
        if not line.startswith(">"):
            break
        # line contains the header ">...."
        # now read the actual sequence
        line = next(handle)
        instance = ""
        for c in line.strip():
            if c == c.upper():
                instance += c
        instance = Seq(instance, alphabet)
        instances.append(instance)

    instances = motifs.Instances(instances, alphabet)
    motif = Motif(
        matrix_id=None, name=None, alphabet=alphabet, instances=instances
    )
    motif.mask = "*" * motif.length
    record = Record()
    record.append(motif)

    return record


def _read_jaspar(handle):
    """Read motifs from a JASPAR formatted file (PRIVATE).

    Format is one or more records of the form, e.g.::

      - JASPAR 2010 matrix_only format::

                >MA0001.1 AGL3
                A  [ 0  3 79 40 66 48 65 11 65  0 ]
                C  [94 75  4  3  1  2  5  2  3  3 ]
                G  [ 1  0  3  4  1  0  5  3 28 88 ]
                T  [ 2 19 11 50 29 47 22 81  1  6 ]

      - JASPAR 2010-2014 PFMs format::

                >MA0001.1 AGL3
                0	3	79	40	66	48	65	11	65	0
                94	75	4	3	1	2	5	2	3	3
                1	0	3	4	1	0	5	3	28	88
                2	19	11	50	29	47	22	81	1	6

    """

    alphabet = dna
    counts = {}

    record = Record()

    head_pat = re.compile(r"^>\s*(\S+)(\s+(\S+))?")
    row_pat_long = re.compile(r"\s*([ACGT])\s*\[\s*(.*)\s*\]")
    row_pat_short = re.compile(r"\s*(.+)\s*")

    identifier = None
    name = None
    row_count = 0
    nucleotides = ['A', 'C', 'G', 'T']
    for line in handle:
        line = line.strip()

        head_match = head_pat.match(line)
        row_match_long = row_pat_long.match(line)
        row_match_short = row_pat_short.match(line)

        if head_match:
            identifier = head_match.group(1)
            if head_match.group(3):
                name = head_match.group(3)
            else:
                name = identifier
        elif row_match_long:
            (letter, counts_str) = row_match_long.group(1, 2)
            words = counts_str.split()
            counts[letter] = [float(x) for x in words]
            row_count += 1
            if row_count == 4:
                record.append(Motif(identifier, name, alphabet=alphabet,
                                    counts=counts))
                identifier = None
                name = None
                counts = {}
                row_count = 0
        elif row_match_short:
            words = row_match_short.group(1).split()
            counts[nucleotides[row_count]] = [float(x) for x in words]
            row_count += 1
            if row_count == 4:
                record.append(Motif(identifier, name, alphabet=alphabet,
                                    counts=counts))
                identifier = None
                name = None
                counts = {}
                row_count = 0

    return record


def calculate_pseudocounts(motif):
    alphabet = motif.alphabet
    background = motif.background

    # It is possible to have unequal column sums so use the average
    # number of instances.
    total = 0
    for i in range(motif.length):
        total += sum(float(motif.counts[letter][i])
                     for letter in alphabet.letters)

    avg_nb_instances = total / motif.length
    sq_nb_instances = math.sqrt(avg_nb_instances)

    if background:
        background = dict(background)
    else:
        background = dict.fromkeys(sorted(alphabet.letters), 1.0)

    total = sum(background.values())
    pseudocounts = {}

    for letter in alphabet.letters:
        background[letter] /= total
        pseudocounts[letter] = sq_nb_instances * background[letter]

    return pseudocounts


def split_jaspar_id(id):
    """Utility function to split a JASPAR matrix ID into its component.

    Components are base ID and version number, e.g. 'MA0047.2' is returned as
    ('MA0047', 2).
    """

    id_split = id.split('.')

    base_id = None
    version = None
    if len(id_split) == 2:
        base_id = id_split[0]
        version = id_split[1]
    else:
        base_id = id

    return (base_id, version)
