# Copyright 2012 by Wibowo Arindrarto.  All rights reserved.
# This code is part of the Biopython distribution and governed by its
# license.  Please see the LICENSE file that should have been included
# as part of this package.

"""Tests for SearchIO writing."""

import os
import unittest

from Bio import BiopythonExperimentalWarning

import warnings


# Import SearchIO with BiopythonExperimentalWarning ignored.  The filter is
# installed inside catch_warnings(), so it only applies for the duration of
# this block and the previous warning filters are restored afterwards.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', BiopythonExperimentalWarning)
    from Bio import SearchIO

from search_tests_common import compare_search_obj


class WriteCases(unittest.TestCase):
    """Shared helpers for the SearchIO write round-trip tests.

    This base class defines no tests itself.  ``tearDown`` reads ``self.out``,
    so every subclass has to provide an ``out`` attribute holding the path of
    the file its tests write to.
    """

    def tearDown(self):
        """Remove the output file named by ``self.out`` if it exists."""
        if os.path.exists(self.out):
            os.remove(self.out)

    def parse_write_and_compare(self, source_file, source_format, out_file,
            out_format, **kwargs):
        """Compares parsed QueryResults after they have been written to a file.

        All results are parsed from ``source_file`` (as ``source_format``),
        written to ``out_file`` (as ``out_format``), and parsed back from
        ``out_file``.  The test fails if the number of results differs or if
        ``compare_search_obj`` returns a false value for any pair.

        The extra keyword arguments are forwarded unchanged to both
        ``SearchIO.parse`` calls and to ``SearchIO.write``.
        """
        source_qresults = list(SearchIO.parse(source_file, source_format,
                **kwargs))
        SearchIO.write(source_qresults, out_file, out_format, **kwargs)
        out_qresults = list(SearchIO.parse(out_file, out_format, **kwargs))
        # zip() stops at the shorter of the two lists, so without this check
        # results missing from the written file (or an empty written file)
        # would go completely unnoticed by the loop below.
        self.assertEqual(len(source_qresults), len(out_qresults))
        for source, out in zip(source_qresults, out_qresults):
            self.assertTrue(compare_search_obj(source, out))

    def read_write_and_compare(self, source_file, source_format, out_file,
            out_format, **kwargs):
        """Compares read QueryResults after it has been written to a file.

        A single result is read from ``source_file`` (as ``source_format``),
        written to ``out_file`` (as ``out_format``), and read back from
        ``out_file``.  The test fails if ``compare_search_obj`` returns a
        false value for the two results.

        The extra keyword arguments are forwarded unchanged to both
        ``SearchIO.read`` calls and to ``SearchIO.write``.
        """
        source_qresult = SearchIO.read(source_file, source_format, **kwargs)
        SearchIO.write(source_qresult, out_file, out_format, **kwargs)
        out_qresult = SearchIO.read(out_file, out_format, **kwargs)
        self.assertTrue(compare_search_obj(source_qresult, out_qresult))


class BlastXmlWriteCases(WriteCases):
    """Write round-trip tests for the 'blast-xml' format."""

    # File written by the tests; WriteCases.tearDown reads this attribute.
    out = os.path.join('Blast', 'test_write.xml')
    # Format name used on both the reading and the writing side.
    fmt = 'blast-xml'

    def test_write_single_from_blastxml(self):
        """Test blast-xml writing from blast-xml, BLAST 2.2.26+, single query (xml_2226_blastp_004.xml)"""
        xml_path = os.path.join('Blast', 'xml_2226_blastp_004.xml')
        # A single-query file is checked through the parse() helper first and
        # then through the read() helper.
        for roundtrip in (self.parse_write_and_compare,
                self.read_write_and_compare):
            roundtrip(xml_path, self.fmt, self.out, self.fmt)

    def test_write_multiple_from_blastxml(self):
        """Test blast-xml writing from blast-xml, BLAST 2.2.26+, multiple queries (xml_2226_blastp_001.xml)"""
        xml_path = os.path.join('Blast', 'xml_2226_blastp_001.xml')
        self.parse_write_and_compare(xml_path, self.fmt, self.out, self.fmt)


class BlastTabWriteCases(WriteCases):
    """Write round-trip tests for the 'blast-tab' format.

    Covers the plain variant as well as the variant selected with
    ``comments=True``, including one case with an explicit ``fields`` list.
    """

    # File written by the tests; WriteCases.tearDown reads this attribute.
    out = os.path.join('Blast', 'test_write.txt')
    # Format name used on both the reading and the writing side.
    fmt = 'blast-tab'

    def test_write_single_from_blasttab(self):
        """Test blast-tab writing from blast-tab, BLAST 2.2.26+, single query (tab_2226_tblastn_004.txt)"""
        tab_path = os.path.join('Blast', 'tab_2226_tblastn_004.txt')
        # parse() helper first, then the read() helper.
        for roundtrip in (self.parse_write_and_compare,
                self.read_write_and_compare):
            roundtrip(tab_path, self.fmt, self.out, self.fmt)

    def test_write_multiple_from_blasttab(self):
        """Test blast-tab writing from blast-tab, BLAST 2.2.26+, multiple queries (tab_2226_tblastn_001.txt)"""
        tab_path = os.path.join('Blast', 'tab_2226_tblastn_001.txt')
        self.parse_write_and_compare(tab_path, self.fmt, self.out, self.fmt)

    def test_write_single_from_blasttabc(self):
        """Test blast-tabc writing from blast-tabc, BLAST 2.2.26+, single query (tab_2226_tblastn_008.txt)"""
        tab_path = os.path.join('Blast', 'tab_2226_tblastn_008.txt')
        # parse() helper first, then the read() helper, both with comments on.
        for roundtrip in (self.parse_write_and_compare,
                self.read_write_and_compare):
            roundtrip(tab_path, self.fmt, self.out, self.fmt, comments=True)

    def test_write_multiple_from_blasttabc(self):
        """Test blast-tabc writing from blast-tabc, BLAST 2.2.26+, multiple queries (tab_2226_tblastn_005.txt)"""
        tab_path = os.path.join('Blast', 'tab_2226_tblastn_005.txt')
        options = {'comments': True}
        self.parse_write_and_compare(tab_path, self.fmt, self.out, self.fmt,
                **options)

    def test_write_multiple_from_blasttabc_allfields(self):
        """Test blast-tabc writing from blast-tabc, BLAST 2.2.28+, multiple queries (tab_2228_tblastx_001.txt)"""
        tab_path = os.path.join('Blast', 'tab_2228_tblastx_001.txt')
        # Column names handed to the parser and the writer via ``fields``.
        all_fields = [
            'qseqid', 'qgi', 'qacc', 'qaccver', 'qlen',
            'sseqid', 'sallseqid', 'sgi', 'sallgi', 'sacc',
            'saccver', 'sallacc', 'slen', 'qstart', 'qend',
            'sstart', 'send', 'qseq', 'sseq', 'evalue',
            'bitscore', 'score', 'length', 'pident', 'nident',
            'mismatch', 'positive', 'gapopen', 'gaps', 'ppos',
            'frames', 'qframe', 'sframe', 'btop', 'staxids',
            'sscinames', 'scomnames', 'sblastnames', 'sskingdoms', 'stitle',
            'salltitles', 'sstrand', 'qcovs', 'qcovhsp',
        ]
        options = {'comments': True, 'fields': all_fields}
        self.parse_write_and_compare(tab_path, self.fmt, self.out, self.fmt,
                **options)


class HmmerTabWriteCases(WriteCases):
    """Write round-trip tests for the 'hmmer3-tab' format."""

    # File written by the tests; WriteCases.tearDown reads this attribute.
    out = os.path.join('Hmmer', 'test_write.txt')
    # Format name used on both the reading and the writing side.
    fmt = 'hmmer3-tab'

    def test_write_single_from_hmmertab(self):
        """Test hmmer3-tab writing from hmmer3-tab, HMMER 3.0, single query (tab_30_hmmscan_004.out)"""
        table_path = os.path.join('Hmmer', 'tab_30_hmmscan_004.out')
        # parse() helper first, then the read() helper.
        for roundtrip in (self.parse_write_and_compare,
                self.read_write_and_compare):
            roundtrip(table_path, self.fmt, self.out, self.fmt)

    def test_write_multiple_from_hmmertab(self):
        """Test hmmer3-tab writing from hmmer3-tab, HMMER 3.0, multiple queries (tab_30_hmmscan_001.out)"""
        table_path = os.path.join('Hmmer', 'tab_30_hmmscan_001.out')
        self.parse_write_and_compare(table_path, self.fmt, self.out, self.fmt)


class HmmerDomtabWriteCases(WriteCases):
    """Write round-trip tests for the HMMER3 domain table formats.

    No class-level ``fmt`` is defined here: each test selects either
    'hmmscan3-domtab' or 'hmmsearch3-domtab' itself.
    """

    # File written by the tests; WriteCases.tearDown reads this attribute.
    out = os.path.join('Hmmer', 'test_write.txt')

    # NOTE: unittest prints the test docstrings as test descriptions, so the
    # file name in parentheses must match the file the test actually reads.

    def test_write_single_from_hmmscandomtab(self):
        """Test hmmscan-domtab writing from hmmscan-domtab, HMMER 3.0, single query (domtab_30_hmmscan_004.out)"""
        source = os.path.join('Hmmer', 'domtab_30_hmmscan_004.out')
        fmt = 'hmmscan3-domtab'
        self.parse_write_and_compare(source, fmt, self.out, fmt)
        self.read_write_and_compare(source, fmt, self.out, fmt)

    def test_write_multiple_from_hmmscandomtab(self):
        """Test hmmscan-domtab writing from hmmscan-domtab, HMMER 3.0, multiple queries (domtab_30_hmmscan_001.out)"""
        source = os.path.join('Hmmer', 'domtab_30_hmmscan_001.out')
        fmt = 'hmmscan3-domtab'
        self.parse_write_and_compare(source, fmt, self.out, fmt)

    def test_write_single_from_hmmsearchdomtab(self):
        """Test hmmsearch-domtab writing from hmmsearch-domtab, HMMER 3.0, single query (domtab_30_hmmsearch_001.out)"""
        source = os.path.join('Hmmer', 'domtab_30_hmmsearch_001.out')
        fmt = 'hmmsearch3-domtab'
        self.parse_write_and_compare(source, fmt, self.out, fmt)
        self.read_write_and_compare(source, fmt, self.out, fmt)


class BlatPslWriteCases(WriteCases):
    """Write round-trip tests for the 'blat-psl' format.

    Covers both the plain variant and the variant selected with
    ``pslx=True``.
    """

    # File written by the tests; WriteCases.tearDown reads this attribute.
    out = os.path.join('Blat', 'test_write.txt')
    # Format name used on both the reading and the writing side.
    fmt = 'blat-psl'

    def test_write_single_from_blatpsl(self):
        """Test blat-psl writing from blat-psl, single query (psl_34_004.psl)"""
        psl_path = os.path.join('Blat', 'psl_34_004.psl')
        # parse() helper first, then the read() helper.
        for roundtrip in (self.parse_write_and_compare,
                self.read_write_and_compare):
            roundtrip(psl_path, self.fmt, self.out, self.fmt)

    def test_write_multiple_from_blatpsl(self):
        """Test blat-psl writing from blat-psl, multiple queries (psl_34_001.psl)"""
        psl_path = os.path.join('Blat', 'psl_34_001.psl')
        self.parse_write_and_compare(psl_path, self.fmt, self.out, self.fmt)

    def test_write_single_from_blatpslx(self):
        """Test blat-pslx writing from blat-pslx, single query (pslx_34_004.pslx)"""
        pslx_path = os.path.join('Blat', 'pslx_34_004.pslx')
        # parse() helper first, then the read() helper, both with pslx on.
        for roundtrip in (self.parse_write_and_compare,
                self.read_write_and_compare):
            roundtrip(pslx_path, self.fmt, self.out, self.fmt, pslx=True)

    def test_write_multiple_from_blatpslx(self):
        """Test blat-pslx writing from blat-pslx, multiple queries (pslx_34_001.pslx)"""
        pslx_path = os.path.join('Blat', 'pslx_34_001.pslx')
        options = {'pslx': True}
        self.parse_write_and_compare(pslx_path, self.fmt, self.out, self.fmt,
                **options)


if __name__ == "__main__":
    # When run as a script, report every test individually (verbosity 2).
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
