# Copyright 2018 by Adhemar Zerlotini. All rights reserved.
#
# This file is part of the Biopython distribution and governed by your
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
# Please see the LICENSE file that should have been included as part of this
# package.

"""Bio.SearchIO parser for InterProScan XML output formats."""
# for more info: https://github.com/ebi-pf-team/interproscan/wiki/OutputFormats

import sys
import re
from Bio.Alphabet import generic_protein
from Bio.SearchIO._model import QueryResult, Hit, HSP, HSPFragment

# For speed try to use cElementTree rather than ElementTree
try:
    if (3, 0) <= sys.version_info[:2] <= (3, 1):
        # Workaround for bug in python 3.0 and 3.1,
        # see http://bugs.python.org/issue9257
        from xml.etree import ElementTree as ElementTree
    else:
        from xml.etree import cElementTree as ElementTree
except ImportError:
    from xml.etree import ElementTree as ElementTree


# Each table below maps an XML attribute name to a pair of
# (attribute name to set on the model object, callable used to convert the
# raw XML string).

# attributes copied onto a hit
_ELEM_HIT = {
    "name": ("accession", str),
    "ac": ("id", str),
    "desc": ("description", str),
}
# attributes copied onto an hsp
_ELEM_HSP = {
    "score": ("bitscore", float),
    "evalue": ("evalue", float),
}
# attributes copied onto an hsp fragment
_ELEM_FRAG = {
    "start": ("query_start", int),
    "end": ("query_end", int),
    "hmm-start": ("hit_start", int),
    "hmm-end": ("hit_end", int),
}


class InterproscanXmlParser(object):
    """Parser for the InterProScan XML format.

    Iterating over an instance yields one QueryResult for every ``protein``
    element encountered in the XML stream.
    """

    def __init__(self, handle):
        """Initialize the class.

        :param handle: source of the InterProScan XML; it is handed
            unchanged to ``ElementTree.iterparse``.
        """
        self.xml_iter = iter(ElementTree.iterparse(
            handle, events=('start', 'end')))
        self._meta = self._parse_header()

    def __iter__(self):
        """Iterate qresults."""
        for qresult in self._parse_qresult():
            yield qresult

    def _parse_header(self):
        """Parse the header for the InterProScan version (PRIVATE).

        Consumes the first event of the XML iterator (the start of the root
        element), stores the XML namespace prefix in ``self.NS`` and returns
        a dict with the ``target``, ``program`` and ``version`` values that
        are later copied onto every QueryResult.

        Raises KeyError if the root element has no ``interproscan-version``
        attribute.
        """
        event, elem = next(self.xml_iter)
        meta = {
            'target': 'InterPro',
            'program': 'InterProScan',
            'version': elem.attrib['interproscan-version'],
        }
        # ElementTree spells namespaced tags as '{namespace-uri}local-name'.
        # Keep only the '{namespace-uri}' prefix so the result does not
        # depend on the name of the root element.
        tag = elem.tag
        if tag.startswith('{') and '}' in tag:
            self.NS = tag[:tag.index('}') + 1]
        else:
            self.NS = ''
        return meta

    def _parse_qresult(self):
        """Parse query results (PRIVATE).

        Yields one QueryResult for each completed ``protein`` element. The
        query id comes from the ``xref`` child and the query sequence from
        the ``sequence`` child.
        """
        protein_tag = self.NS + 'protein'
        for event, elem in self.xml_iter:
            # only completely parsed protein elements are of interest
            if event != 'end' or elem.tag != protein_tag:
                continue

            # store the query sequence
            query_seq = elem.find(self.NS + 'sequence').text

            # store the query id and description
            xref = elem.find(self.NS + 'xref')
            query_id = xref.attrib['id']
            query_desc = xref.attrib.get('name')

            # interproscan results contain duplicate hits rather than a
            # single hit with multiple hsps. The hsps of a duplicate hit
            # are appended to the first hit seen with the same id; the
            # dict gives constant-time lookup while hit_list keeps the
            # first-seen order.
            hit_list = []
            hits_by_id = {}
            for hit_new in self._parse_hit(
                    elem.find(self.NS + 'matches'), query_id, query_seq):
                hit = hits_by_id.get(hit_new.id)
                if hit is None:
                    hits_by_id[hit_new.id] = hit_new
                    hit_list.append(hit_new)
                else:
                    hit.hsps.extend(hit_new.hsps)

            # create qresult and assign attributes
            qresult = QueryResult(hit_list, query_id)
            # the description is left untouched when xref has no name
            if query_desc is not None:
                qresult.description = query_desc
            for key, value in self._meta.items():
                setattr(qresult, key, value)

            # everything needed has been copied out of the element, so drop
            # its content to keep the tree built by iterparse from growing
            elem.clear()
            yield qresult

    def _parse_hit(self, root_hit_elem, query_id, query_seq=None):
        """Parse hit (PRIVATE).

        :param root_hit_elem: element whose children are the individual
            match elements, or None (in which case nothing is yielded).
        :param query_id: id of the query, passed on to the hsps.
        :param query_seq: query sequence, passed on to the hsps.

        Yields one Hit per child element.
        """
        if root_hit_elem is None:
            return

        # The namespace is plain text, not a pattern: escape it, and build
        # the pattern once instead of once per match element.
        match_tag = re.compile(r"%s(\w+)-match" % re.escape(self.NS))

        for hit_elem in root_hit_elem:
            # store the match/location type
            hit_type = match_tag.sub(r"\1", hit_elem.tag)
            # store the hit id
            signature = hit_elem.find(self.NS + 'signature')
            hit_id = signature.attrib['ac']

            # store xrefs and alt_descs
            xrefs = self._parse_xrefs(
                signature.find(self.NS + 'entry'))

            # parse each hsp
            hsps = list(self._parse_hsp(
                hit_elem.find(self.NS + 'locations'),
                query_id, hit_id, query_seq))

            # create hit and assign attributes
            hit = Hit(hsps, hit_id)
            hit.dbxrefs = xrefs
            for key, (attr, caster) in _ELEM_HIT.items():
                value = signature.attrib.get(key)
                if value is not None:
                    setattr(hit, attr, caster(value))
            # format specific attributes
            hit.attributes['Hit type'] = str(hit_type)
            signature_lib = signature.find(
                self.NS + 'signature-library-release')
            hit.attributes['Target'] = str(
                signature_lib.attrib.get('library'))
            hit.attributes['Target version'] = str(
                signature_lib.attrib.get('version'))

            yield hit

    def _parse_hsp(self, root_hsp_elem, query_id, hit_id, query_seq=None):
        """Parse hsp (PRIVATE).

        :param root_hsp_elem: element whose children are the individual
            location elements, or None (in which case nothing is yielded).
        :param query_id: id stored on each fragment and hsp.
        :param hit_id: id stored on each fragment and hsp.
        :param query_seq: if not None, stored as the fragment's query.

        Yields one HSP, holding a single HSPFragment, per child element.
        """
        if root_hsp_elem is None:
            return

        for hsp_elem in root_hsp_elem:
            # create frag and assign attributes
            frag = HSPFragment(hit_id, query_id)
            frag.alphabet = generic_protein
            if query_seq is not None:
                frag.query = query_seq

            # coordinates found on *this* location only, so that a missing
            # value can never be taken from a previous location
            coords = {}
            for key, (attr, caster) in _ELEM_FRAG.items():
                value = hsp_elem.attrib.get(key)
                if value is None:
                    continue
                value = caster(value)
                # start should be 0-based
                if attr.endswith('start'):
                    value -= 1
                coords[attr] = value
                setattr(frag, attr, value)

            # calculate aln_span and store; it can only be derived when
            # both query coordinates are present
            if 'query_start' in coords and 'query_end' in coords:
                frag.aln_span = coords['query_end'] - coords['query_start']

            # create hsp and assign attributes
            hsp = HSP([frag])
            hsp.query_id = query_id
            hsp.hit_id = hit_id
            for key, (attr, caster) in _ELEM_HSP.items():
                value = hsp_elem.attrib.get(key)
                if value is not None:
                    setattr(hsp, attr, caster(value))
            yield hsp

    def _parse_xrefs(self, root_entry_elem):
        """Parse xrefs (PRIVATE).

        :param root_entry_elem: the ``entry`` element, or None.

        Returns a list of strings: ``'IPR:' + ac`` of the entry first, then
        the ids of its ``go-xref`` children, then those of its
        ``pathway-xref`` children. An id without a ``':'`` is prefixed with
        the child's ``db`` attribute. Returns an empty list when
        ``root_entry_elem`` is None.
        """
        if root_entry_elem is None:
            return []

        # store entry id
        xrefs = ['IPR:' + root_entry_elem.attrib['ac']]

        # store go-xrefs and pathway-xrefs ids, in that order
        for tag in ('go-xref', 'pathway-xref'):
            for entry in root_entry_elem.findall(self.NS + tag):
                xref = entry.attrib['id']
                if ':' not in xref:
                    xref = entry.attrib['db'] + ':' + xref
                xrefs.append(xref)
        return xrefs


# when this file is executed directly (rather than imported), run its doctests
if __name__ == "__main__":
    from Bio import _utils
    _utils.run_doctest()
