# Copyright (C) 2002, Thomas Hamelryck (thamelry@binf.ku.dk)
# This code is part of the Biopython distribution and governed by its
# license.  Please see the LICENSE file that should have been included
# as part of this package.

"""Polypeptide-related classes (construction and representation).

Simple example with multiple chains,

    >>> from Bio.PDB.PDBParser import PDBParser
    >>> from Bio.PDB.Polypeptide import PPBuilder
    >>> structure = PDBParser().get_structure('2BEG', 'PDB/2BEG.pdb')
    >>> ppb=PPBuilder()
    >>> for pp in ppb.build_peptides(structure):
    ...     print(pp.get_sequence())
    LVFFAEDVGSNKGAIIGLMVGGVVIA
    LVFFAEDVGSNKGAIIGLMVGGVVIA
    LVFFAEDVGSNKGAIIGLMVGGVVIA
    LVFFAEDVGSNKGAIIGLMVGGVVIA
    LVFFAEDVGSNKGAIIGLMVGGVVIA

Example with non-standard amino acids using HETATM lines in the PDB file,
in this case selenomethionine (MSE):

    >>> from Bio.PDB.PDBParser import PDBParser
    >>> from Bio.PDB.Polypeptide import PPBuilder
    >>> structure = PDBParser().get_structure('1A8O', 'PDB/1A8O.pdb')
    >>> ppb=PPBuilder()
    >>> for pp in ppb.build_peptides(structure):
    ...     print(pp.get_sequence())
    DIRQGPKEPFRDYVDRFYKTLRAEQASQEVKNW
    TETLLVQNANPDCKTILKALGPGATLEE
    TACQG

If you want to, you can include non-standard amino acids in the peptides:

    >>> for pp in ppb.build_peptides(structure, aa_only=False):
    ...     print(pp.get_sequence())
    ...     print("%s %s" % (pp.get_sequence()[0], pp[0].get_resname()))
    ...     print("%s %s" % (pp.get_sequence()[-7], pp[-7].get_resname()))
    ...     print("%s %s" % (pp.get_sequence()[-6], pp[-6].get_resname()))
    MDIRQGPKEPFRDYVDRFYKTLRAEQASQEVKNWMTETLLVQNANPDCKTILKALGPGATLEEMMTACQG
    M MSE
    M MSE
    M MSE

In this case the selenomethionines (the first and also seventh and sixth from
last residues) have been shown as M (methionine) by the get_sequence method.
"""

from __future__ import print_function
from Bio._py3k import basestring

import warnings

from Bio.Alphabet import generic_protein
from Bio.Data import SCOPData
from Bio.Seq import Seq
from Bio.PDB.PDBExceptions import PDBException
from Bio.PDB.Vector import calc_dihedral, calc_angle


# Three letter codes of the 20 standard amino acids. The order matters:
# position k here corresponds to position k in the one letter string aa1.
standard_aa_names = [
    "ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS", "LEU",
    "MET", "ASN", "PRO", "GLN", "ARG", "SER", "THR", "VAL", "TRP", "TYR",
]


aa1 = "ACDEFGHIKLMNPQRSTVWY"
aa3 = standard_aa_names

# Lookup tables between one letter code, three letter code and the
# position (0-19) shared by aa1 and aa3.
d1_to_index = {}
dindex_to_1 = {}
d3_to_index = {}
dindex_to_3 = {}

# Fill all four tables in a single pass over the paired codes
for i, (n1, n3) in enumerate(zip(aa1, aa3)):
    d1_to_index[n1] = i
    dindex_to_1[i] = n1
    d3_to_index[n3] = i
    dindex_to_3[i] = n3


def index_to_one(index):
    """Translate an index into the matching one letter amino acid code.

    The lookup goes through the module level ``dindex_to_1`` table, so an
    index that is not in that table raises a KeyError.

    >>> index_to_one(0)
    'A'
    >>> index_to_one(19)
    'Y'
    """
    letter = dindex_to_1[index]
    return letter


def one_to_index(s):
    """Translate a one letter amino acid code into its index.

    The lookup goes through the module level ``d1_to_index`` table, so a
    letter that is not in that table raises a KeyError.

    >>> one_to_index('A')
    0
    >>> one_to_index('Y')
    19
    """
    position = d1_to_index[s]
    return position


def index_to_three(i):
    """Translate an index into the matching three letter amino acid name.

    The lookup goes through the module level ``dindex_to_3`` table, so an
    index that is not in that table raises a KeyError.

    >>> index_to_three(0)
    'ALA'
    >>> index_to_three(19)
    'TYR'
    """
    name = dindex_to_3[i]
    return name


def three_to_index(s):
    """Translate a three letter amino acid name into its index.

    The lookup goes through the module level ``d3_to_index`` table, so a
    name that is not in that table raises a KeyError.

    >>> three_to_index('ALA')
    0
    >>> three_to_index('TYR')
    19
    """
    position = d3_to_index[s]
    return position


def three_to_one(s):
    """Translate a three letter amino acid name into its one letter code.

    >>> three_to_one('ALA')
    'A'
    >>> three_to_one('TYR')
    'Y'

    For non-standard amino acids, you get a KeyError:

    >>> three_to_one('MSE')
    Traceback (most recent call last):
       ...
    KeyError: 'MSE'
    """
    # Go via the shared index: three letter name -> index -> one letter code
    return dindex_to_1[d3_to_index[s]]


def one_to_three(s):
    """Translate a one letter amino acid code into its three letter name.

    A letter missing from the module level ``d1_to_index`` table raises a
    KeyError.

    >>> one_to_three('A')
    'ALA'
    >>> one_to_three('Y')
    'TYR'
    """
    # Go via the shared index: one letter code -> index -> three letter name
    return dindex_to_3[d1_to_index[s]]


def is_aa(residue, standard=False):
    """Tell whether a residue object or residue name is an amino acid.

    Anything that is not a string is asked for its name via get_resname().
    The name is upper-cased before it is looked up.

    @param residue: a L{Residue} object OR a three letter amino acid code
    @type residue: L{Residue} or string

    @param standard: if true, accept only the 20 standard amino acids;
        otherwise accept every name known to SCOPData.protein_letters_3to1
        (default false)
    @type standard: boolean

    >>> is_aa('ALA')
    True

    Known three letter codes for modified amino acids are supported,

    >>> is_aa('FME')
    True
    >>> is_aa('FME', standard=True)
    False
    """
    # TODO - What about special cases like XXX, can they appear in PDB files?
    name = residue if isinstance(residue, basestring) else residue.get_resname()
    name = name.upper()
    known_names = d3_to_index if standard else SCOPData.protein_letters_3to1
    return name in known_names


class Polypeptide(list):
    """A polypeptide is simply a list of L{Residue} objects."""

    def get_ca_list(self):
        """Get list of C-alpha atoms in the polypeptide.

        A residue without a "CA" entry makes the lookup fail; the resulting
        exception is not caught here.

        @return: the list of C-alpha atoms
        @rtype: [L{Atom}, L{Atom}, ...]
        """
        return [res["CA"] for res in self]

    def get_phi_psi_list(self):
        """Return the list of phi/psi dihedral angles.

        One (phi, psi) tuple is produced per residue. An angle is None when
        it cannot be calculated: phi for the first residue, psi for the last
        residue, and either angle when a required backbone atom is missing.
        As a side effect the values are also stored in each residue's xtra
        dict under the keys "PHI" and "PSI".

        @return: list of (phi, psi) tuples, in residue order
        @rtype: [(float or None, float or None), ...]
        """
        ppl = []
        lng = len(self)
        for i in range(0, lng):
            res = self[i]
            try:
                n = res['N'].get_vector()
                ca = res['CA'].get_vector()
                c = res['C'].get_vector()
            except Exception:
                # Some atoms are missing
                # Phi/Psi cannot be calculated for this residue
                ppl.append((None, None))
                res.xtra["PHI"] = None
                res.xtra["PSI"] = None
                continue
            # Phi needs the C atom of the previous residue
            if i > 0:
                rp = self[i - 1]
                try:
                    cp = rp['C'].get_vector()
                    phi = calc_dihedral(cp, n, ca, c)
                except Exception:
                    phi = None
            else:
                # No phi for residue 0!
                phi = None
            # Psi needs the N atom of the next residue
            if i < (lng - 1):
                rn = self[i + 1]
                try:
                    nn = rn['N'].get_vector()
                    psi = calc_dihedral(n, ca, c, nn)
                except Exception:
                    psi = None
            else:
                # No psi for last residue!
                psi = None
            ppl.append((phi, psi))
            # Add Phi/Psi to xtra dict of residue
            res.xtra["PHI"] = phi
            res.xtra["PSI"] = psi
        return ppl

    def get_tau_list(self):
        """List of tau torsions angles for all 4 consecutive Calpha atoms.

        Each angle is also stored under the key "TAU" in the xtra dict of
        the residue owning the third C-alpha atom of its window.

        @return: one tau value per window of 4 consecutive C-alpha atoms
        """
        ca_list = self.get_ca_list()
        tau_list = []
        for i in range(0, len(ca_list) - 3):
            atom_list = (ca_list[i], ca_list[i + 1], ca_list[i + 2], ca_list[i + 3])
            v1, v2, v3, v4 = [a.get_vector() for a in atom_list]
            tau = calc_dihedral(v1, v2, v3, v4)
            tau_list.append(tau)
            # Put tau in xtra dict of residue
            res = ca_list[i + 2].get_parent()
            res.xtra["TAU"] = tau
        return tau_list

    def get_theta_list(self):
        """List of theta angles for all 3 consecutive Calpha atoms.

        Each angle is also stored under the key "THETA" in the xtra dict of
        the residue owning the middle C-alpha atom of its window.

        @return: one theta value per window of 3 consecutive C-alpha atoms
        """
        theta_list = []
        ca_list = self.get_ca_list()
        for i in range(0, len(ca_list) - 2):
            atom_list = (ca_list[i], ca_list[i + 1], ca_list[i + 2])
            v1, v2, v3 = [a.get_vector() for a in atom_list]
            theta = calc_angle(v1, v2, v3)
            theta_list.append(theta)
            # Put theta in xtra dict of residue
            res = ca_list[i + 1].get_parent()
            res.xtra["THETA"] = theta
        return theta_list

    def get_sequence(self):
        """Return the AA sequence as a Seq object.

        Residue names are translated with SCOPData.protein_letters_3to1;
        names missing from that table are shown as 'X'.

        @return: polypeptide sequence
        @rtype: L{Seq}
        """
        to_one_letter = SCOPData.protein_letters_3to1.get
        # Join once instead of growing the string residue by residue
        s = "".join(to_one_letter(res.get_resname(), 'X') for res in self)
        return Seq(s, generic_protein)

    def __repr__(self):
        """Return string representation of the polypeptide.

        Return <Polypeptide start=START end=END>, where START
        and END are sequence identifiers of the outer residues.
        An empty polypeptide has no outer residues, so both are
        reported as None instead of raising an IndexError.
        """
        if not self:
            start = end = None
        else:
            start = self[0].get_id()[1]
            end = self[-1].get_id()[1]
        return "<Polypeptide start=%s end=%s>" % (start, end)


class _PPBuilder(object):
    """Base class to extract polypeptides.

    Consecutive residues of a chain are grouped into a polypeptide as long
    as they are connected. The connectivity test itself (_is_connected) is
    supplied by a subclass.

    This assumes you want both standard and non-standard amino acids.
    """

    def __init__(self, radius):
        """Store the distance threshold used by the connectivity test.

        @param radius: distance
        @type radius: float
        """
        self.radius = radius

    def _accept(self, residue, standard_aa_only):
        """Check if the residue is an amino acid (PRIVATE)."""
        if is_aa(residue, standard=standard_aa_only):
            return True
        if standard_aa_only or "CA" not in residue.child_dict:
            # not a standard AA so skip
            return False
        # It has an alpha carbon...
        # We probably need to update the hard coded list of
        # non-standard residues, see function is_aa for details.
        warnings.warn("Assuming residue %s is an unknown modified amino acid"
                      % residue.get_resname())
        return True

    def build_peptides(self, entity, aa_only=1):
        """Build and return a list of Polypeptide objects.

        For a Structure only its first model (index 0) is searched. A
        polypeptide is started once two consecutive accepted residues are
        connected, so every polypeptide returned has at least two residues.

        @param entity: polypeptides are searched for in this object
        @type entity: L{Structure}, L{Model} or L{Chain}

        @param aa_only: if 1, the residue needs to be a standard AA
        @type aa_only: int

        @return: the polypeptides found, in chain order
        @rtype: [L{Polypeptide}, ...]
        """
        connected = self._is_connected
        wanted = self._accept
        level = entity.get_level()
        # Work out which chains to scan from the kind of entity given
        if level == "S":
            chain_list = entity[0].get_list()
        elif level == "M":
            chain_list = entity.get_list()
        elif level == "C":
            chain_list = [entity]
        else:
            raise PDBException("Entity should be Structure, Model or Chain.")
        pp_list = []
        for chain in chain_list:
            residues = iter(chain)
            # Skip ahead to the first residue worth looking at
            try:
                previous = next(residues)
                while not wanted(previous, aa_only):
                    previous = next(residues)
            except StopIteration:
                # No interesting residues at all in this chain
                continue
            current = None
            for residue in residues:
                linked = (wanted(previous, aa_only)
                          and wanted(residue, aa_only)
                          and connected(previous, residue))
                if not linked:
                    # Either too far apart, or one of the residues is unwanted.
                    # End the current peptide
                    current = None
                else:
                    if current is None:
                        # First connected pair: open a new peptide with
                        # the earlier residue of the pair
                        current = Polypeptide()
                        current.append(previous)
                        pp_list.append(current)
                    current.append(residue)
                previous = residue
        return pp_list


class CaPPBuilder(_PPBuilder):
    """Use CA--CA distance to find polypeptides."""

    def __init__(self, radius=4.3):
        """Set the CA--CA distance threshold.

        @param radius: two residues count as connected when their CA atoms
            are closer than this distance
        @type radius: float
        """
        _PPBuilder.__init__(self, radius)

    def _is_connected(self, prev_res, next_res):
        """Return True if the CA atoms of both residues are in range (PRIVATE).

        For disordered CA atoms every alternative position is tried; a
        single pair closer than the radius is enough.
        """
        if not (prev_res.has_id("CA") and next_res.has_id("CA")):
            return False
        next_ca = next_res["CA"]
        prev_ca = prev_res["CA"]
        # Unpack disordered
        if next_ca.is_disordered():
            next_atoms = next_ca.disordered_get_list()
        else:
            next_atoms = [next_ca]
        if prev_ca.is_disordered():
            prev_atoms = prev_ca.disordered_get_list()
        else:
            prev_atoms = [prev_ca]
        return any((next_atom - prev_atom) < self.radius
                   for next_atom in next_atoms
                   for prev_atom in prev_atoms)


class PPBuilder(_PPBuilder):
    """Use C--N distance to find polypeptides."""

    def __init__(self, radius=1.8):
        """Set the C--N distance threshold.

        @param radius: two residues count as connected when the C atom of
            the first and the N atom of the second are closer than this
        @type radius: float
        """
        _PPBuilder.__init__(self, radius)

    def _is_connected(self, prev_res, next_res):
        """Return True if prev_res C and next_res N can be bonded (PRIVATE).

        Every combination of disordered positions is tried. When a bonded
        pair is found, the matching altloc is selected on each disordered
        atom before returning.
        """
        if not (prev_res.has_id("C") and next_res.has_id("N")):
            return False
        in_range = self._test_dist
        c = prev_res["C"]
        n = next_res["N"]
        # Test all disordered atom positions!
        c_positions = c.disordered_get_list() if c.is_disordered() else [c]
        n_positions = n.disordered_get_list() if n.is_disordered() else [n]
        for n_atom in n_positions:
            for c_atom in c_positions:
                # To form a peptide bond, N and C must be
                # within radius and have the same altloc
                # identifier or one altloc blank
                n_altloc = n_atom.get_altloc()
                c_altloc = c_atom.get_altloc()
                if n_altloc != c_altloc and n_altloc != " " and c_altloc != " ":
                    continue
                if not in_range(n_atom, c_atom):
                    continue
                # Select the disordered atoms that
                # are indeed bonded
                if c.is_disordered():
                    c.disordered_select(c_altloc)
                if n.is_disordered():
                    n.disordered_select(n_altloc)
                return True
        return False

    def _test_dist(self, c, n):
        """Return 1 if distance between atoms<radius (PRIVATE)."""
        return 1 if (c - n) < self.radius else 0


if __name__ == "__main__":
    # Command line demo: parse the PDB file given as first argument and
    # print the peptides found by both builders.
    import sys
    from Bio.PDB.PDBParser import PDBParser

    if len(sys.argv) < 2:
        # Fail with a readable message instead of an IndexError
        sys.exit("Usage: %s <PDB file>" % sys.argv[0])

    def _format_angle(angle):
        """Format an angle for printing; undefined angles stay 'None'."""
        return "None" if angle is None else "%f" % angle

    p = PDBParser(PERMISSIVE=True)

    s = p.get_structure("scr", sys.argv[1])

    ppb = PPBuilder()

    print("C-N")
    # The same search at structure, model and chain level
    for pp in ppb.build_peptides(s):
        print(pp.get_sequence())
    for pp in ppb.build_peptides(s[0]):
        print(pp.get_sequence())
    # NOTE: this lookup fails if the file has no chain with id "A"
    for pp in ppb.build_peptides(s[0]["A"]):
        print(pp.get_sequence())

    for pp in ppb.build_peptides(s):
        for phi, psi in pp.get_phi_psi_list():
            # phi of the first and psi of the last residue are always None,
            # so "%f" cannot be applied to the raw values
            print("%s %s" % (_format_angle(phi), _format_angle(psi)))

    ppb = CaPPBuilder()

    print("CA-CA")
    for pp in ppb.build_peptides(s):
        print(pp.get_sequence())
    for pp in ppb.build_peptides(s[0]):
        print(pp.get_sequence())
    for pp in ppb.build_peptides(s[0]["A"]):
        print(pp.get_sequence())
