# Copyright 2005 by Iddo Friedberg.  All rights reserved.
# Revisions copyright 2006-2013,2017 by Peter Cock. All rights reserved.
# Revisions copyright 2008 by Frank Kauff. All rights reserved.
# Revisions copyright 2009 by Michiel de Hoon. All rights reserved.
# Revisions copyright 2015 by Joe Cora. All rights reserved.
#
# This file is part of the Biopython distribution and governed by your
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
# Please see the LICENSE file that should have been included as part of this
# package.

"""Tests for Nexus module."""

import os.path
import sys
import tempfile
import unittest
from io import StringIO

from Bio import SeqIO
from Bio.Align import Alignment
from Bio.Align import MultipleSeqAlignment
from Bio.Align.nexus import AlignmentIterator
from Bio.Align.nexus import AlignmentWriter
from Bio.AlignIO.NexusIO import NexusIterator
from Bio.AlignIO.NexusIO import NexusWriter
from Bio.Nexus import Nexus
from Bio.Nexus import Trees
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord


class OldSelfTests(unittest.TestCase):
    """Test cases originally in Nexus.py via __main__."""

    def test_trees_and_taxa_block(self):
        """Basic tree file with TREES and TAXA block."""
        nexus1 = Nexus.Nexus()
        nexus1.read("Nexus/bats.nex")

    def test_data_and_codons_block(self):
        """Simple sequence data file with DATA and CODONS block."""
        nexus2 = Nexus.Nexus()
        nexus2.read("Nexus/codonposset.nex")

    def test_data_sets_trees_unknown_block(self):
        """Sequence data file with DATA, SETS, TREES and an unknown block."""
        nexus3 = Nexus.Nexus()
        nexus3.read("Nexus/test_Nexus_input.nex")

    def test_taxa_and_characters_block(self):
        """Taxa and characters multi-state block."""
        nexus4 = Nexus.Nexus()
        nexus4.read("Nexus/vSysLab_Ganaspidium_multistate.nex")

    def test_taxa_and_characters_with_many_codings_one_without_state(self):
        """Taxa and chr blocks, over 9 codings, 1 character without states."""
        nexus5 = Nexus.Nexus()
        nexus5.read("Nexus/vSysLab_Heptascelio_no-states_10+chars.nex")

    def test_taxa_and_characters_with_many_codings_two_without_state(self):
        """Taxa and chr blocks, over 9 codings, 2 character without states."""
        nexus6 = Nexus.Nexus()
        # TODO: Implement continuous datatype:
        # Bio.Nexus.Nexus.NexusError: Unsupported datatype: continuous
        self.assertRaises(
            Nexus.NexusError,
            nexus6.read,
            "Nexus/vSysLab_Oreiscelio_discrete+continuous.nex",
        )


class NexusTest1(unittest.TestCase):
    def setUp(self):
        """Open the main Nexus input file and keep the handle on the test case."""
        self.testfile_dir = "Nexus"
        input_path = os.path.join(self.testfile_dir, "test_Nexus_input.nex")
        self.handle = open(input_path)

    def tearDown(self):
        """Close the input file handle stored on the test case."""
        opened = self.handle
        opened.close()

    def test_WriteToFileName(self):
        """Test writing to a given filename.

        The output file is removed in a ``finally`` clause so that a failing
        write or assertion does not leave ``Nexus/test_temp.nex`` behind.
        """
        filename = "Nexus/test_temp.nex"
        # Start from a clean slate in case an earlier run left the file behind.
        if os.path.isfile(filename):
            os.remove(filename)
        n = Nexus.Nexus(self.handle)
        try:
            n.write_nexus_data(filename)
            self.assertTrue(os.path.isfile(filename))
        finally:
            # Guarded removal: the file may not exist if the write itself failed.
            if os.path.isfile(filename):
                os.remove(filename)

    def test_write_with_dups(self):
        """Write four records that all share the same id in NEXUS format."""
        # see issue: biopython/Bio/Nexus/Nexus.py _unique_label() eval error #633
        duplicates = []
        for _ in range(4):
            record = SeqRecord(
                Seq("ATGCTGCTGAT"), id="foo", annotations={"molecule_type": "DNA"}
            )
            duplicates.append(record)
        buffer = StringIO()
        written = SeqIO.write(duplicates, buffer, "nexus")
        self.assertEqual(4, written)

    def test_NexusTest1(self):
        """Test Nexus module."""
        # check data of main nexus file
        n = Nexus.Nexus(self.handle)
        self.assertEqual(
            os.path.normpath(n.filename), os.path.normpath("Nexus/test_Nexus_input.nex")
        )
        self.assertEqual(n.ntax, 9)
        self.assertEqual(n.nchar, 48)
        self.assertEqual(n.datatype, "dna")
        self.assertTrue(n.interleave)
        self.assertEqual(n.missing, "?")
        self.assertEqual(n.gap, "-")
        self.assertEqual(
            n.taxlabels,
            [
                "t1",
                "t2 the name",
                "isn'that [a] strange name?",
                "one should be punished, for (that)!",
                "t5",
                "t6",
                "t7",
                "t8",
                "t9",
            ],
        )
        self.assertEqual(
            n.charlabels,
            {
                0: "a",
                1: "b",
                2: "c",
                4: "f",
                9: "A",
                10: "B",
                22: "x",
                23: "y",
                29: "1,2,3 can't decide for a name?!",
                47: "final",
            },
        )
        self.assertEqual(
            n.charsets,
            {
                "big": [0, 2, 4, 6],
                "bigchunk": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30,
                    31,
                    32,
                    33,
                    34,
                    35,
                    36,
                    37,
                    38,
                    39,
                    40,
                    41,
                    42,
                    43,
                    44,
                    45,
                    46,
                ],
                "byname": [
                    0,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                ],
                "c1": [0, 1, 2, 3, 4, 5, 6, 7],
                "c2": [8, 9, 10, 11, 12, 13, 14, 15],
                "c3": [16, 17, 18, 19, 20, 21, 22, 23],
                "firsthalf": [
                    0,
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                ],
                "mix": [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46],
                "mux": [
                    0,
                    1,
                    4,
                    7,
                    8,
                    10,
                    13,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    25,
                    28,
                    31,
                    34,
                    37,
                    40,
                    43,
                    46,
                ],
                "pos1": [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45],
                "pos2": [1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46],
                "pos3": [2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47],
                "secondhalf": [
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30,
                    31,
                    32,
                    33,
                    34,
                    35,
                    36,
                    37,
                    38,
                    39,
                    40,
                    41,
                    42,
                    43,
                    44,
                    45,
                    46,
                    47,
                ],
            },
        )
        self.assertEqual(
            n.taxsets,
            {
                "normal": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t1",
                    "t5",
                    "t6",
                    "t8",
                ],
                "reference": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t1",
                    "t2 the name",
                    "t5",
                    "t6",
                ],
                "tbyname1": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t1",
                    "t2 the name",
                    "t5",
                    "t6",
                ],
                "tbyname2": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t2 the name",
                    "t5",
                    "t6",
                    "t7",
                ],
                "tbyname3": ["t1", "t2 the name"],
            },
        )
        self.assertEqual(len(n.charpartitions), 2)
        self.assertIn("codons", n.charpartitions)
        self.assertIn("part", n.charpartitions)
        self.assertEqual(
            n.charpartitions["codons"],
            {
                "a": [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45],
                "b": [1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46],
                "c": [2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47],
            },
        )
        self.assertEqual(
            n.charpartitions["part"],
            {
                "one": [0, 1, 2, 3, 4, 5, 6, 7],
                "three": [16, 17, 18, 19, 20, 21, 22, 23],
                "two": [8, 9, 10, 11, 12, 13, 14, 15],
            },
        )
        self.assertEqual(list(n.taxpartitions), ["taxpart"])
        self.assertEqual(
            n.taxpartitions["taxpart"],
            {
                "badnames": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t2 the name",
                ],
                "goodnames": ["t1", "t5", "t6", "t7", "t8", "t9"],
            },
        )

        # now we check excluding characters, deleting taxa,
        # and exporting adjusted sets
        f1 = tempfile.NamedTemporaryFile("w+")
        n.write_nexus_data(f1, delete=["t1", "t7"], exclude=n.invert(n.charsets["big"]))
        f1.seek(0)
        nf1 = Nexus.Nexus(f1)
        self.assertEqual(os.path.normpath(nf1.filename), os.path.normpath(f1.name))
        self.assertEqual(nf1.ntax, 7)
        self.assertEqual(nf1.nchar, 4)
        self.assertEqual(nf1.datatype, "dna")
        self.assertFalse(nf1.interleave)
        self.assertEqual(nf1.missing, "?")
        self.assertEqual(nf1.gap, "-")
        self.assertEqual(
            nf1.taxlabels,
            [
                "t2 the name",
                "isn'that [a] strange name?",
                "one should be punished, for (that)!",
                "t5",
                "t6",
                "t8",
                "t9",
            ],
        )
        self.assertEqual(nf1.charlabels, {0: "a", 1: "c", 2: "f"})
        self.assertEqual(
            nf1.charsets,
            {
                "big": [0, 1, 2, 3],
                "bigchunk": [1, 2, 3],
                "byname": [0, 2, 3],
                "c1": [0, 1, 2, 3],
                "firsthalf": [0, 1, 2, 3],
                "mix": [0, 2],
                "mux": [0, 2],
                "pos1": [0, 3],
                "pos2": [2],
                "pos3": [1],
            },
        )
        self.assertEqual(
            nf1.taxsets,
            {
                "normal": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t5",
                    "t6",
                    "t8",
                ],
                "reference": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t2 the name",
                    "t5",
                    "t6",
                ],
                "tbyname1": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t2 the name",
                    "t5",
                    "t6",
                ],
                "tbyname2": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t2 the name",
                    "t5",
                    "t6",
                ],
                "tbyname3": ["t2 the name"],
            },
        )
        self.assertEqual(len(nf1.charpartitions), 2)
        self.assertIn("codons", nf1.charpartitions)
        self.assertIn("part", nf1.charpartitions)
        self.assertEqual(
            nf1.charpartitions["codons"], {"a": [0, 3], "b": [2], "c": [1]}
        )
        self.assertEqual(nf1.charpartitions["part"], {"one": [0, 1, 2, 3]})

        self.assertEqual(list(nf1.taxpartitions), ["taxpart"])
        self.assertEqual(
            nf1.taxpartitions["taxpart"],
            {
                "badnames": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t2 the name",
                ],
                "goodnames": ["t5", "t6", "t8", "t9"],
            },
        )

        f2 = tempfile.NamedTemporaryFile("w+")
        n.write_nexus_data(f2, delete=["t2_the_name"], exclude=list(range(3, 40, 4)))
        f2.seek(0)
        nf2 = Nexus.Nexus(f2)
        self.assertEqual(os.path.normpath(nf2.filename), os.path.normpath(f2.name))
        self.assertEqual(nf2.ntax, 9)
        self.assertEqual(nf2.nchar, 38)
        self.assertEqual(nf2.datatype, "dna")
        self.assertFalse(nf2.interleave)
        self.assertEqual(nf2.missing, "?")
        self.assertEqual(nf2.gap, "-")
        self.assertEqual(
            nf2.taxlabels,
            [
                "t1",
                "t2 the name",
                "isn'that [a] strange name?",
                "one should be punished, for (that)!",
                "t5",
                "t6",
                "t7",
                "t8",
                "t9",
            ],
        )
        self.assertEqual(
            nf2.charlabels,
            {
                0: "a",
                1: "b",
                2: "c",
                3: "f",
                7: "A",
                8: "B",
                17: "x",
                22: "1,2,3 can't decide for a name?!",
                37: "final",
            },
        )
        self.assertEqual(
            nf2.charsets,
            {
                "big": [0, 2, 3, 5],
                "bigchunk": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30,
                    31,
                    32,
                    33,
                    34,
                    35,
                    36,
                ],
                "byname": [
                    0,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                ],
                "c1": [0, 1, 2, 3, 4, 5],
                "c2": [6, 7, 8, 9, 10, 11],
                "c3": [12, 13, 14, 15, 16, 17],
                "firsthalf": [
                    0,
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                ],
                "mix": [0, 1, 3, 8, 10, 12, 17, 19, 21, 26, 28, 30, 33, 36],
                "mux": [
                    0,
                    1,
                    3,
                    6,
                    8,
                    10,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    19,
                    21,
                    26,
                    28,
                    30,
                    33,
                    36,
                ],
                "pos1": [0, 5, 7, 9, 14, 16, 18, 23, 25, 27, 32, 35],
                "pos2": [1, 3, 8, 10, 12, 17, 19, 21, 26, 28, 30, 33, 36],
                "pos3": [2, 4, 6, 11, 13, 15, 20, 22, 24, 29, 31, 34, 37],
                "secondhalf": [
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30,
                    31,
                    32,
                    33,
                    34,
                    35,
                    36,
                    37,
                ],
            },
        )

        self.assertEqual(
            nf2.taxsets,
            {
                "normal": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t1",
                    "t5",
                    "t6",
                    "t8",
                ],
                "reference": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t1",
                    "t2 the name",
                    "t5",
                    "t6",
                ],
                "tbyname1": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t1",
                    "t2 the name",
                    "t5",
                    "t6",
                ],
                "tbyname2": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t2 the name",
                    "t5",
                    "t6",
                    "t7",
                ],
                "tbyname3": ["t1", "t2 the name"],
            },
        )
        self.assertEqual(len(nf2.charpartitions), 2)
        self.assertIn("codons", nf2.charpartitions)
        self.assertIn("part", nf2.charpartitions)
        self.assertEqual(
            nf2.charpartitions["codons"],
            {
                "a": [0, 5, 7, 9, 14, 16, 18, 23, 25, 27, 32, 35],
                "b": [1, 3, 8, 10, 12, 17, 19, 21, 26, 28, 30, 33, 36],
                "c": [2, 4, 6, 11, 13, 15, 20, 22, 24, 29, 31, 34, 37],
            },
        )
        self.assertEqual(
            nf2.charpartitions["part"],
            {
                "one": [0, 1, 2, 3, 4, 5],
                "three": [12, 13, 14, 15, 16, 17],
                "two": [6, 7, 8, 9, 10, 11],
            },
        )
        self.assertEqual(list(nf2.taxpartitions), ["taxpart"])
        self.assertEqual(
            nf2.taxpartitions["taxpart"],
            {
                "badnames": [
                    "isn'that [a] strange name?",
                    "one should be punished, for (that)!",
                    "t2 the name",
                ],
                "goodnames": ["t1", "t5", "t6", "t7", "t8", "t9"],
            },
        )
        # check the stepmatrix
        self.assertEqual(
            n.weighted_stepmatrix(name="matrix_test"),
            """\
usertype matrix_test stepmatrix=5
        A        C        G        T        -
[A]     .       2.40     2.57     2.43     2.43     
[C]    2.40      .       2.28     2.12     2.14     
[G]    2.57     2.28      .       2.31     2.31     
[T]    2.43     2.12     2.31      .       2.14     
[-]    2.43     2.14     2.31     2.14      .       
;
""",  # noqa : W291
        )

    def test_write_alignment(self):
        # Default causes no interleave (columns <= 1000)
        records = [
            SeqRecord(
                Seq("ATGCTGCTGA" * 90), id=_id, annotations={"molecule_type": "DNA"}
            )
            for _id in ["foo", "bar", "baz"]
        ]
        a = Alignment(records)

        handle = StringIO()
        AlignmentWriter(handle).write([a])
        handle.seek(0)
        data = handle.read()
        self.assertEqual(
            data,
            """\
#NEXUS
begin data;
dimensions ntax=3 nchar=900;
format datatype=dna missing=? gap=-;
matrix
foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
;
end;
""",
        )
        # Default causes interleave (columns > 1000)
        records = [
            SeqRecord(
                Seq("ATGCTGCTGA" * 110), id=_id, annotations={"molecule_type": "DNA"}
            )
            for _id in ["foo", "bar", "baz"]
        ]
        a = Alignment(records)
        handle = StringIO()
        AlignmentWriter(handle).write([a])
        handle.seek(0)
        data = handle.read()
        self.assertEqual(
            data,
            """\
#NEXUS
begin data;
dimensions ntax=3 nchar=1100;
format datatype=dna missing=? gap=- interleave;
matrix
foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

;
end;
""",
        )

        # Override interleave: True
        records = [
            SeqRecord(
                Seq("ATGCTGCTGA" * 9), id=_id, annotations={"molecule_type": "DNA"}
            )
            for _id in ["foo", "bar", "baz"]
        ]
        a = Alignment(records)
        handle = StringIO()
        AlignmentWriter(handle, interleave=True).write([a])
        handle.seek(0)
        data = handle.read()
        self.assertEqual(
            data,
            """\
#NEXUS
begin data;
dimensions ntax=3 nchar=90;
format datatype=dna missing=? gap=- interleave;
matrix
foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA

foo ATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGA

;
end;
""",
        )

        # Override interleave: False
        records = [
            SeqRecord(
                Seq("ATGCTGCTGA" * 110), id=_id, annotations={"molecule_type": "DNA"}
            )
            for _id in ["foo", "bar", "baz"]
        ]
        a = Alignment(records)
        handle = StringIO()
        AlignmentWriter(handle, interleave=False).write([a])
        handle.seek(0)
        data = handle.read()
        self.assertEqual(
            data,
            """\
#NEXUS
begin data;
dimensions ntax=3 nchar=1100;
format datatype=dna missing=? gap=-;
matrix
foo ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
bar ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
baz ATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGAATGCTGCTGA
;
end;
""",
        )

    def test_write_alignment_msa(self):
        """Check interleaving of NexusWriter.write_alignment output.

        Exercises the default behaviour for alignments of 900 and 1100
        columns, then the explicit ``interleave=True`` and
        ``interleave=False`` overrides.
        """
        unit = "ATGCTGCTGA"

        def render(repeats, **options):
            # Write three identical rows of ``unit * repeats`` and return the text.
            rows = [
                SeqRecord(
                    Seq(unit * repeats), id=name, annotations={"molecule_type": "DNA"}
                )
                for name in ("foo", "bar", "baz")
            ]
            stream = StringIO()
            NexusWriter(stream).write_alignment(MultipleSeqAlignment(rows), **options)
            return stream.getvalue()

        # Default causes no interleave (columns <= 1000)
        self.assertIn(unit * 90, render(90))

        # Default causes interleave (columns > 1000)
        text = render(110)
        self.assertNotIn(unit * 90, text)
        self.assertIn(unit * 7, text)

        # Override interleave: True
        text = render(9, interleave=True)
        self.assertNotIn(unit * 9, text)
        self.assertIn(unit * 7, text)

        # Override interleave: False
        self.assertIn(unit * 110, render(110, interleave=False))

    def test_TreeTest1(self):
        """Test Tree module."""
        n = Nexus.Nexus(self.handle)
        t3 = n.trees[2]
        t2 = n.trees[2]
        t3.root_with_outgroup(["t1", "t5"])
        self.assertEqual(
            str(t3),
            "tree tree1 = (((((('one should be punished, for (that)!','isn''that [a] strange name?'),'t2 the name'),t8,t9),t6),t7),(t5,t1));",
        )
        self.assertEqual(t3.is_monophyletic(["t8", "t9", "t6", "t7"]), -1)
        self.assertEqual(t3.is_monophyletic(["t1", "t5"]), 13)
        t3.split(parent_id=t3.search_taxon("t9"))
        stdout = sys.stdout
        try:
            sys.stdout = StringIO()
            t3.display()
            output = sys.stdout.getvalue()
        finally:
            sys.stdout = stdout
        expected = """\
  #                            taxon            prev            succ    brlen blen (sum)  support              comment
  1    'isn''that [a] strange name?'               2              []   100.00     119.84    10.00                    -
  2                                -               4          [3, 1]     0.40      19.84     0.30                    -
  3 'one should be punished, for (that)!'               2              []     0.50      20.34        -                    -
  4                                -               6          [2, 5]     4.00      19.44     3.00                    -
  5                    't2 the name'               4              []     0.30      19.74        -                    -
  6                                -               9       [4, 7, 8]     2.00      15.44     1.00                    -
  7                               t8               6              []     1.20      16.64        -                    -
  8                               t9               6        [17, 18]     3.40      18.84        -                    -
  9                                -              11         [6, 10]     0.44      13.44    33.00                    -
 10                               t6               9              []     1.00      14.44        -                    -
 11                                -              16         [9, 12]    13.00      13.00    12.00                    -
 12                               t7              11              []    99.90     112.90        -                    -
 13                                -              16        [14, 15]     0.00       0.00     0.00                    -
 14                               t5              13              []    99.00      99.00        -                    -
 15                               t1              13              []     0.98       0.98        -                    -
 16                                -            None        [11, 13]     0.00       0.00        -                    -
 17                              t90               8              []     1.00      19.84        -                    -
 18                              t91               8              []     1.00      19.84        -                    -

Root:  16
"""
        self.assertEqual(len(output.split("\n")), len(expected.split("\n")))
        for l1, l2 in zip(output.split("\n"), expected.split("\n")):
            self.assertEqual(l1, l2)
        self.assertEqual(output, expected)
        self.assertEqual(t3.is_compatible(t2, threshold=0.3), [])

    def test_TreeTest2(self):
        """Handle text labels on internal nodes."""
        # Tree whose internal nodes carry text labels.
        labelled = (
            "(Cephalotaxus:125.000000,(Taxus:100.000000,Torreya:100.000000)"
            "TT1:25.000000)Taxaceae:90.000000;"
        )
        tree = Trees.Tree(labelled)
        expected_nodes = [
            ("Taxaceae", 90.0, None, None),
            ("Cephalotaxus", 125.0, None, None),
            ("TT1", 25.0, None, None),
            ("Taxus", 100.0, None, None),
            ("Torreya", 100.0, None, None),
        ]
        self.assertEqual(self._get_flat_nodes(tree), expected_nodes)
        tree.prune("Torreya")
        self.assertEqual(tree.all_ids(), [0, 1, 3])

        # Same topology with the internal labels removed.
        unlabelled = (
            "(Cephalotaxus:125.000000,(Taxus:100.000000,Torreya:100.000000)"
            "25.000000)90.000000;"
        )
        tree = Trees.Tree(unlabelled)
        expected_nodes = [
            (None, 90.0, None, None),
            ("Cephalotaxus", 125.0, None, None),
            (None, 25.0, None, None),
            ("Taxus", 100.0, None, None),
            ("Torreya", 100.0, None, None),
        ]
        self.assertEqual(self._get_flat_nodes(tree), expected_nodes)
        self.assertFalse(tree.has_support())

        # randomize() without arguments must be rejected.
        with self.assertRaises(Exception) as raised:
            tree.randomize()
        self.assertIn(
            "Either number of taxa or list of taxa must be specified.",
            str(raised.exception),
        )
        shuffled = Trees.Tree(unlabelled)
        shuffled.randomize(ntax=4)
        self.assertEqual(
            sorted(shuffled.get_taxa()), ["taxon1", "taxon2", "taxon3", "taxon4"]
        )

        tree.branchlength2support()
        tree.convert_absolute_support(2)
        expected_nodes = [
            (None, 0.0, 90.0, None),
            ("Cephalotaxus", 0.0, 62.5, None),
            (None, 0.0, 12.5, None),
            ("Taxus", 0.0, 50.0, None),
            ("Torreya", 0.0, 50.0, None),
        ]
        self.assertEqual(self._get_flat_nodes(tree), expected_nodes)

        larger = (
            "(((t9:0.385832, (t8:0.445135,t4:0.41401)C:0.024032)B:0.041436,"
            "t6:0.392496)A:0.0291131, t2:0.497673, ((t0:0.301171,"
            "t7:0.482152)E:0.0268148, ((t5:0.0984167,t3:0.488578)G:0.0349662,"
            "t1:0.130208)F:0.0318288)D:0.0273876);"
        )
        tree = Trees.Tree(larger)
        tree.branchlength2support()
        # Collect the support value of every node, in all_ids() order.
        supports = [tree.node(nid).get_data().support for nid in tree.all_ids()]
        self.assertEqual(
            supports,
            [
                0.0,
                0.0291131,
                0.041436,
                0.385832,
                0.024032,
                0.445135,
                0.41401,
                0.392496,
                0.497673,
                0.0273876,
                0.0268148,
                0.301171,
                0.482152,
                0.0318288,
                0.0349662,
                0.0984167,
                0.488578,
                0.130208,
            ],
        )

        spaced_names = (
            "(((B 9:0.385832, (C 8:0.445135, C4:0.41401)C:0.024032)B:0.041436,"
            "A 6:0.392496)A:0.0291131, t2:0.497673, ((E 0:0.301171,"
            "E 7:0.482152)E:0.0268148, ((G 5:0.0984167,G 3:0.488578)G:0.0349662,"
            "F 1:0.130208)F:0.0318288)D:0.0273876);"
        )
        self.assertFalse(tree.is_identical(Trees.Tree(spaced_names)))
        tree = Trees.Tree(spaced_names)
        self.assertTrue(tree.is_bifurcating())
        self.assertTrue(tree.is_bifurcating(1))
        # Distance from node 0 to every node, in all_ids() order.
        distances = [tree.distance(0, nid) for nid in tree.all_ids()]
        self.assertEqual(
            distances,
            [
                0.0,
                0.0291131,
                0.0705491,
                0.4563811,
                0.0945811,
                0.5397161,
                0.5085911,
                0.4216091,
                0.497673,
                0.0273876,
                0.0542024,
                0.3553734,
                0.5363544,
                0.0592164,
                0.09418259999999999,
                0.1925993,
                0.5827606,
                0.1894244,
            ],
        )

        clade_taxa = tree.set_subtree(10)
        self.assertEqual(sorted(clade_taxa), ["E 0", "E 7"])
        tree.collapse_genera()
        self.assertEqual(tree.all_ids(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 13, 14, 17])

    def test_merge_with_support(self):
        """Test merge_with_support and consensus method."""
        reference_newick = (
            "(((B 9:0.385832, (C 8:0.445135, C 4:0.41401)C:0.024032)B:0.041436,"
            "A 6:0.392496)A:0.0291131, t2:0.497673, ((E 0:0.301171,"
            "E 7:0.482152)E:0.0268148, ((G 5:0.0984167,G 3:0.488578)G:0.0349662,"
            "F 1:0.130208)F:0.0318288)D:0.0273876);"
        )
        first_bootstrap_newick = (
            "(((B 9:0.385832, (C 8:0.445135, C 4:0.41401)C:0.024032)B:0.041436,"
            "A 6:0.392496)A:0.0291131, t2:0.497673, ((G 5:0.0984167,"
            "G 3:0.488578)E:0.0268148, ((E 0:0.301171, E 7:0.482152)G:0.0349662,"
            "F 1:0.130208)F:0.0318288)D:0.0273876);"
        )
        second_bootstrap_newick = (
            "(((B 9:0.385832,A 6:0.392496 C:0.024032)B:0.041436, (C 8:0.445135,"
            "C 4:0.41401))A:0.0291131, t2:0.497673, ((E 0:0.301171, E 7:0.482152)"
            "E:0.0268148, ((G 5:0.0984167,G 3:0.488578)G:0.0349662,F 1:0.130208)"
            "F:0.0318288)D:0.0273876);"
        )

        # Parse all three trees first, then convert branch lengths to support.
        parsed = [
            Trees.Tree(newick)
            for newick in (
                reference_newick,
                first_bootstrap_newick,
                second_bootstrap_newick,
            )
        ]
        for parsed_tree in parsed:
            parsed_tree.branchlength2support()

        target, *bootstraps = parsed
        target.merge_with_support(bstrees=bootstraps, threshold=0.2)

        # Support value of every node of the merged tree, in all_ids() order.
        supports = [target.node(nid).get_data().support for nid in target.all_ids()]
        # NOTE(review): assertTrue() takes (expr, msg), so the list below is only
        # used as the failure message and this check passes whenever ``supports``
        # is non-empty.  The values look like intended expectations, but they
        # have not been verified against merge_with_support() - confirm them
        # before switching to assertEqual().
        self.assertTrue(
            supports,
            [
                0.0,
                1.0,
                0.04,
                1.0,
                0.5,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                0.5,
                1.0,
                1.0,
                0.5,
                1.0,
                1.0,
                1.0,
                1.0,
            ],
        )

    def test_to_string(self):
        """Test to_string method."""
        newick = (
            "(((B 9:0.385832, (C 8:0.445135, C 4:0.41401)C:0.024032)B:0.041436,"
            "A 6:0.392496)A:0.0291131, t2:0.497673, ((E 0:0.301171,"
            "E 7:0.482152)E:0.0268148, ((G 5:0.0984167,G 3:0.488578)G:0.0349662,"
            "F 1:0.130208)F:0.0318288)D:0.0273876);"
        )
        parsed = Trees.Tree(newick)
        # Expected text for the left-ladderized tree requested below.
        expected = (
            "tree a_tree = ((A 6,(B 9,(C 8,C 4))),t2,((E 0,E 7),(F 1,(G 5,G 3))));"
        )
        self.assertEqual(parsed.to_string(ladderize="LEFT"), expected)

    def test_large_newick(self):
        """Build a tree from the ``int_node_labels.nwk`` test file."""
        newick_path = os.path.join(self.testfile_dir, "int_node_labels.nwk")
        with open(newick_path) as large_ex_handle:
            newick_text = large_ex_handle.read()
        # There are no assertions: the test passes if construction does not raise.
        Trees.Tree(newick_text)

    def _get_flat_nodes(self, tree):
        cur_nodes = [tree.node(tree.root)]
        nodedata = []
        while len(cur_nodes) > 0:
            new_nodes = []
            for cur_node in cur_nodes:
                nodedata.append(
                    (
                        cur_node.data.taxon,
                        cur_node.data.branchlength,
                        cur_node.data.support,
                        cur_node.data.comment,
                    )
                )
                new_nodes.extend([tree.node(nid) for nid in cur_node.get_succ()])
            cur_nodes = new_nodes
        return nodedata

    def test_NexusComments(self):
        """Test the ability to parse nexus comments at internal and leaf nodes."""
        # A tree with simple comments throughout the tree.
        simple_newick = "((12:0.13,19[&comment1]:0.13)[&comment2]:0.1,(20:0.171,11:0.171):0.13)[&comment3];"
        simple_expected = [
            (None, 0.0, None, "[&comment3]"),
            (None, 0.1, None, "[&comment2]"),
            (None, 0.13, None, None),
            ("12", 0.13, None, None),
            ("19", 0.13, None, "[&comment1]"),
            ("20", 0.171, None, None),
            ("11", 0.171, None, None),
        ]
        self.assertEqual(
            self._get_flat_nodes(Trees.Tree(simple_newick)), simple_expected
        )

        # A tree with more complex comments throughout the tree.
        # This is typical of the MCC trees produced by `treeannotator` in the beast-mcmc suite of phylogenetic tools
        # The key difference being tested here is the ability to parse internal node comments that include ','.
        complex_newick = "(((9[&rate_range={1.3E-5,0.10958320752991428},height_95%_HPD={0.309132419999969,0.3091324199999691},length_range={3.513906814545109E-4,0.4381986285528381},height_median=0.309132419999969,length_95%_HPD={0.003011577063374571,0.08041621647998398}]:0.055354097721950546,5[&rate_range={1.3E-5,0.10958320752991428},height_95%_HPD={0.309132419999969,0.3091324199999691},length_range={3.865051168833178E-5,0.4391594442572986},height_median=0.309132419999969,length_95%_HPD={0.003011577063374571,0.08041621647998398}]:0.055354097721950546)[&height_95%_HPD={0.3110921040545068,0.38690865205576275},length_range={0.09675588357303178,0.4332959544380489},length_95%_HPD={0.16680375169879613,0.36500804261814374}]:0.20039426358269385)[&height_95%_HPD={0.5289500597932948,0.6973881165460601},length_range={0.02586430194846201,0.29509451958008265},length_95%_HPD={0.0840287249314221,0.2411078625957056}]:0.23042678598484334)[&height_95%_HPD={0.7527502510685965,0.821862094763501},height_median=0.8014438411766163,height=0.795965080422763,posterior=1.0,height_range={0.49863013698599995,0.821862094763501},length=0.0];"
        complex_expected = [
            (
                None,
                0.0,
                None,
                "[&height_95%_HPD={0.7527502510685965,0.821862094763501},height_median=0.8014438411766163,height=0.795965080422763,posterior=1.0,height_range={0.49863013698599995,0.821862094763501},length=0.0]",
            ),
            (
                None,
                0.23042678598484334,
                None,
                "[&height_95%_HPD={0.5289500597932948,0.6973881165460601},length_range={0.02586430194846201,0.29509451958008265},length_95%_HPD={0.0840287249314221,0.2411078625957056}]",
            ),
            (
                None,
                0.20039426358269385,
                None,
                "[&height_95%_HPD={0.3110921040545068,0.38690865205576275},length_range={0.09675588357303178,0.4332959544380489},length_95%_HPD={0.16680375169879613,0.36500804261814374}]",
            ),
            (
                "9",
                0.055354097721950546,
                None,
                "[&rate_range={1.3E-5,0.10958320752991428},height_95%_HPD={0.309132419999969,0.3091324199999691},length_range={3.513906814545109E-4,0.4381986285528381},height_median=0.309132419999969,length_95%_HPD={0.003011577063374571,0.08041621647998398}]",
            ),
            (
                "5",
                0.055354097721950546,
                None,
                "[&rate_range={1.3E-5,0.10958320752991428},height_95%_HPD={0.309132419999969,0.3091324199999691},length_range={3.865051168833178E-5,0.4391594442572986},height_median=0.309132419999969,length_95%_HPD={0.003011577063374571,0.08041621647998398}]",
            ),
        ]
        self.assertEqual(
            self._get_flat_nodes(Trees.Tree(complex_newick)), complex_expected
        )


class TestSelf(unittest.TestCase):
    """Tests for repeated taxon names, empty input and multi-alignment output."""

    # NEXUS input whose matrix lists the name CYS1_DICDI twice (no taxa block).
    _NEXUS_NO_TAXA = """#NEXUS
        [TITLE: NoName]
        begin data;
        dimensions ntax=4 nchar=50;
        format interleave datatype=protein   gap=- symbols="FSTNKEYVQMCLAWPHDRIG";
        matrix
        CYS1_DICDI          -----MKVIL LFVLAVFTVF VSS------- --------RG IPPEEQ----
        ALEU_HORVU          MAHARVLLLA LAVLATAAVA VASSSSFADS NPIRPVTDRA ASTLESAVLG
        CATH_HUMAN          ------MWAT LPLLCAGAWL LGV------- -PVCGAAELS VNSLEK----
        CYS1_DICDI          -----MKVIL LFVLAVFTVF VSS------- --------RG IPPEEQ---X
        ;
        end;
        """

    # The same matrix preceded by a taxa block that also repeats CYS1_DICDI.
    _NEXUS_WITH_TAXA = """#NEXUS
        [TITLE: NoName]
        begin taxa
        CYS1_DICDI
        ALEU_HORVU
        CATH_HUMAN
        CYS1_DICDI;
        end;
        begin data;
        dimensions ntax=4 nchar=50;
        format interleave datatype=protein   gap=- symbols="FSTNKEYVQMCLAWPHDRIG";
        matrix
        CYS1_DICDI          -----MKVIL LFVLAVFTVF VSS------- --------RG IPPEEQ----
        ALEU_HORVU          MAHARVLLLA LAVLATAAVA VASSSSFADS NPIRPVTDRA ASTLESAVLG
        CATH_HUMAN          ------MWAT LPLLCAGAWL LGV------- -PVCGAAELS VNSLEK----
        CYS1_DICDI          -----MKVIL LFVLAVFTVF VSS------- --------RG IPPEEQ---X
        ;
        end;
        """

    # Expected rows: (name, id reported by the MSA parser, aligned sequence).
    _ROWS = (
        (
            "CYS1_DICDI",
            "CYS1_DICDI",
            "-----MKVILLFVLAVFTVFVSS---------------RGIPPEEQ----",
        ),
        (
            "ALEU_HORVU",
            "ALEU_HORVU",
            "MAHARVLLLALAVLATAAVAVASSSSFADSNPIRPVTDRAASTLESAVLG",
        ),
        (
            "CATH_HUMAN",
            "CATH_HUMAN",
            "------MWATLPLLCAGAWLLGV--------PVCGAAELSVNSLEK----",
        ),
        (
            "CYS1_DICDI",
            "CYS1_DICDI.copy",
            "-----MKVILLFVLAVFTVFVSS---------------RGIPPEEQ---X",
        ),
    )

    # Expected str() of the Alignment parsed by AlignmentIterator.
    _ALIGNMENT_TEXT = """\
CYS1_DICD         0 -----MKVILLFVLAVFTVFVSS---------------RGIPPEEQ---- 26
ALEU_HORV         0 MAHARVLLLALAVLATAAVAVASSSSFADSNPIRPVTDRAASTLESAVLG 50
CATH_HUMA         0 ------MWATLPLLCAGAWLLGV--------PVCGAAELSVNSLEK---- 32
CYS1_DICD         0 -----MKVILLFVLAVFTVFVSS---------------RGIPPEEQ---X 27
"""

    # Expected str() of the alignment parsed by NexusIterator.
    _MSA_TEXT = """\
Alignment with 4 rows and 50 columns
-----MKVILLFVLAVFTVFVSS---------------RGIPPEEQ---- CYS1_DICDI
MAHARVLLLALAVLATAAVAVASSSSFADSNPIRPVTDRAASTLESAVLG ALEU_HORVU
------MWATLPLLCAGAWLLGV--------PVCGAAELSVNSLEK---- CATH_HUMAN
-----MKVILLFVLAVFTVFVSS---------------RGIPPEEQ---X CYS1_DICDI.copy"""

    # Expected output of AlignmentWriter for the records from _dna_records().
    _WRITTEN_TEXT = """\
#NEXUS
begin data;
dimensions ntax=3 nchar=11;
format datatype=dna missing=? gap=-;
matrix
foo ATGCTGCTGAT
bar ATGCTGCAGAT
baz ATGCTGCGGAT
;
end;
"""

    def _assert_alignment_rows(self, alignment):
        """Check the printed form, rows and sequence ids of an Alignment."""
        self.assertEqual(str(alignment), self._ALIGNMENT_TEXT)
        for index, (name, _msa_id, sequence) in enumerate(self._ROWS):
            self.assertEqual(alignment[index], sequence)
            self.assertEqual(alignment.sequences[index].id, name)

    def _assert_msa_records(self, alignment):
        """Check the printed form and every record of an MSA-style alignment."""
        self.assertEqual(str(alignment), self._MSA_TEXT)
        records = iter(alignment)
        for name, msa_id, sequence in self._ROWS:
            record = next(records)
            self.assertEqual(record.seq, Seq(sequence))
            self.assertEqual(record.name, name)
            self.assertEqual(record.id, msa_id)
        self.assertRaises(StopIteration, next, records)

    @staticmethod
    def _dna_records():
        """Return three short DNA SeqRecords with ids foo, bar and baz."""
        return [
            SeqRecord(Seq(bases), id=name, annotations={"molecule_type": "DNA"})
            for name, bases in (
                ("foo", "ATGCTGCTGAT"),
                ("bar", "ATGCTGCAGAT"),
                ("baz", "ATGCTGCGGAT"),
            )
        ]

    def test_repeated_names_no_taxa(self):
        """Parse repeated names without a taxa block using AlignmentIterator."""
        alignments = AlignmentIterator(StringIO(self._NEXUS_NO_TAXA))
        alignment = next(alignments)
        self.assertEqual(alignment.shape, (4, 50))
        self._assert_alignment_rows(alignment)
        self.assertRaises(StopIteration, next, alignments)

    def test_repeated_names_no_taxa_msa(self):
        """Parse repeated names without a taxa block using NexusIterator."""
        alignments = NexusIterator(StringIO(self._NEXUS_NO_TAXA))
        alignment = next(alignments)
        self.assertEqual(len(alignment), 4)
        self._assert_msa_records(alignment)
        self.assertRaises(StopIteration, next, alignments)

    def test_repeated_names_with_taxa(self):
        """Parse repeated names with a taxa block using AlignmentIterator."""
        alignments = AlignmentIterator(StringIO(self._NEXUS_WITH_TAXA))
        alignment = next(alignments)
        self.assertEqual(len(alignment), 4)
        self._assert_alignment_rows(alignment)
        self.assertRaises(StopIteration, next, alignments)

    def test_repeated_names_with_taxa_msa(self):
        """Parse repeated names with a taxa block using NexusIterator."""
        alignments = NexusIterator(StringIO(self._NEXUS_WITH_TAXA))
        alignment = next(alignments)
        self.assertEqual(len(alignment), 4)
        self._assert_msa_records(alignment)
        self.assertRaises(StopIteration, next, alignments)

    def test_empty_file_read(self):
        """AlignmentIterator raises ValueError("Empty file.") on empty input."""
        with self.assertRaises(ValueError) as cm:
            AlignmentIterator(StringIO())
        self.assertEqual(str(cm.exception), "Empty file.")

    def test_empty_file_read_msa(self):
        """NexusIterator yields nothing for empty input."""
        self.assertEqual([], list(NexusIterator(StringIO())))

    def test_multiple_output(self):
        """AlignmentWriter writes one alignment and rejects two."""
        alignment = Alignment(self._dna_records())

        stream = StringIO()
        AlignmentWriter(stream).write([alignment])
        self.assertEqual(stream.getvalue(), self._WRITTEN_TEXT)

        stream = StringIO()
        with self.assertRaises(ValueError):
            AlignmentWriter(stream).write([alignment, alignment])

    def test_multiple_output_msa(self):
        """NexusWriter.write_file writes one alignment and rejects two."""
        alignment = MultipleSeqAlignment(self._dna_records())

        stream = StringIO()
        NexusWriter(stream).write_file([alignment])
        data = stream.getvalue()
        self.assertTrue(data.startswith("#NEXUS\nbegin data;\n"), data)
        self.assertTrue(data.endswith("end;\n"), data)

        stream = StringIO()
        with self.assertRaises(ValueError):
            NexusWriter(stream).write_file([alignment, alignment])


if __name__ == "__main__":
    # Run this module's tests with a verbosity-2 text runner.
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
