# Copyright 2002 by Jeffrey Chang.  All rights reserved.
# Revisions copyright 2008 by Brad Chapman. All rights reserved.
# Revisions copyright 2008 by Michiel de Hoon. All rights reserved.
# Revisions copyright 2008-2010,2013-2014 by Peter Cock. All rights reserved.
# Revisions copyright 2012 by Christian Brueffer. All rights reserved.
# Revisions copyright 2017 by Maximilian Greil. All rights reserved.
# This file is part of the Biopython distribution and governed by your
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
# Please see the LICENSE file that should have been included as part of this
# package.

"""Tests for MarkovModel module."""

import warnings
import unittest

from io import StringIO

try:
    from numpy import array
    from numpy import random  # missing in PyPy's micronumpy
    from numpy import array_equal
    from numpy import around
    from numpy import log
except ImportError:
    from Bio import MissingPythonDependencyError

    raise MissingPythonDependencyError(
        "Install NumPy if you want to use Bio.MarkovModel."
    ) from None

with warnings.catch_warnings():
    # Silence this warning:
    # For optimal speed, please update to Numpy version 1.3 or later
    warnings.simplefilter("ignore", UserWarning)
    from Bio import MarkovModel


class TestMarkovModel(unittest.TestCase):
    """Tests for Bio.MarkovModel training, decoding, save/load and helpers."""

    def test_train_visible(self):
        """Train on labelled sequences, then check decoding and parameters."""
        states = ["0", "1", "2", "3"]
        alphabet = ["A", "C", "G", "T"]
        training_data = [
            ("AACCCGGGTTTTTTT", "001112223333333"),
            ("ACCGTTTTTTT", "01123333333"),
            ("ACGGGTTTTTT", "01222333333"),
            ("ACCGTTTTTTTT", "011233333333"),
        ]
        markov_model = MarkovModel.train_visible(states, alphabet, training_data)
        states = MarkovModel.find_states(markov_model, "AACGTT")
        self.assertEqual(len(states), 1)
        state_list, state_float = states[0]
        self.assertEqual(state_list, ["0", "0", "1", "2", "3", "3"])
        self.assertAlmostEqual(state_float, 0.0082128906)
        self.assertEqual(markov_model.states, ["0", "1", "2", "3"])
        self.assertEqual(markov_model.alphabet, ["A", "C", "G", "T"])
        self.assertEqual(len(markov_model.p_initial), 4)
        self.assertAlmostEqual(markov_model.p_initial[0], 1.0)
        self.assertAlmostEqual(markov_model.p_initial[1], 0.0)
        self.assertAlmostEqual(markov_model.p_initial[2], 0.0)
        self.assertAlmostEqual(markov_model.p_initial[3], 0.0)
        self.assertEqual(len(markov_model.p_transition), 4)
        self.assertEqual(len(markov_model.p_transition[0]), 4)
        self.assertEqual(len(markov_model.p_transition[1]), 4)
        self.assertEqual(len(markov_model.p_transition[2]), 4)
        self.assertEqual(len(markov_model.p_transition[3]), 4)
        self.assertAlmostEqual(markov_model.p_transition[0][0], 0.2)
        self.assertAlmostEqual(markov_model.p_transition[0][1], 0.8)
        self.assertAlmostEqual(markov_model.p_transition[0][2], 0.0)
        self.assertAlmostEqual(markov_model.p_transition[0][3], 0.0)
        self.assertAlmostEqual(markov_model.p_transition[1][0], 0.0)
        self.assertAlmostEqual(markov_model.p_transition[1][1], 0.5)
        self.assertAlmostEqual(markov_model.p_transition[1][2], 0.5)
        self.assertAlmostEqual(markov_model.p_transition[1][3], 0.0)
        self.assertAlmostEqual(markov_model.p_transition[2][0], 0.0)
        self.assertAlmostEqual(markov_model.p_transition[2][1], 0.0)
        self.assertAlmostEqual(markov_model.p_transition[2][2], 0.5)
        self.assertAlmostEqual(markov_model.p_transition[2][3], 0.5)
        self.assertAlmostEqual(markov_model.p_transition[3][0], 0.0)
        self.assertAlmostEqual(markov_model.p_transition[3][1], 0.0)
        self.assertAlmostEqual(markov_model.p_transition[3][2], 0.0)
        self.assertAlmostEqual(markov_model.p_transition[3][3], 1.0)
        self.assertEqual(len(markov_model.p_emission), 4)
        self.assertEqual(len(markov_model.p_emission[0]), 4)
        self.assertEqual(len(markov_model.p_emission[1]), 4)
        self.assertEqual(len(markov_model.p_emission[2]), 4)
        self.assertEqual(len(markov_model.p_emission[3]), 4)
        self.assertAlmostEqual(markov_model.p_emission[0][0], 0.666667, places=4)
        self.assertAlmostEqual(markov_model.p_emission[0][1], 0.111111, places=4)
        self.assertAlmostEqual(markov_model.p_emission[0][2], 0.111111, places=4)
        self.assertAlmostEqual(markov_model.p_emission[0][3], 0.111111, places=4)
        self.assertAlmostEqual(markov_model.p_emission[1][0], 0.083333, places=4)
        self.assertAlmostEqual(markov_model.p_emission[1][1], 0.750000, places=4)
        self.assertAlmostEqual(markov_model.p_emission[1][2], 0.083333, places=4)
        self.assertAlmostEqual(markov_model.p_emission[1][3], 0.083333, places=4)
        self.assertAlmostEqual(markov_model.p_emission[2][0], 0.083333, places=4)
        self.assertAlmostEqual(markov_model.p_emission[2][1], 0.083333, places=4)
        self.assertAlmostEqual(markov_model.p_emission[2][2], 0.750000, places=4)
        self.assertAlmostEqual(markov_model.p_emission[2][3], 0.083333, places=4)
        self.assertAlmostEqual(markov_model.p_emission[3][0], 0.031250, places=4)
        self.assertAlmostEqual(markov_model.p_emission[3][1], 0.031250, places=4)
        self.assertAlmostEqual(markov_model.p_emission[3][2], 0.031250, places=4)
        self.assertAlmostEqual(markov_model.p_emission[3][3], 0.906250, places=4)

    def test_baum_welch(self):
        """Run one Baum-Welch fit from given starting parameters and check it."""
        states = ["CP", "IP"]
        alphabet = ["cola", "ice_t", "lem"]
        outputs = [(2, 1, 0)]
        # A tiny non-zero initial probability for "IP" rather than exactly 0.
        p_initial = [1.0, 0.0000001]
        p_transition = [[0.7, 0.3], [0.5, 0.5]]
        p_emission = [[0.6, 0.1, 0.3], [0.1, 0.7, 0.2]]
        N, M = len(states), len(alphabet)
        x = MarkovModel._baum_welch(
            N,
            M,
            outputs,
            p_initial=p_initial,
            p_transition=p_transition,
            p_emission=p_emission,
        )
        p_initial, p_transition, p_emission = x
        markov_model = MarkovModel.MarkovModel(
            states, alphabet, p_initial, p_transition, p_emission
        )
        self.assertEqual(markov_model.states, ["CP", "IP"])
        self.assertEqual(markov_model.alphabet, ["cola", "ice_t", "lem"])
        self.assertEqual(len(markov_model.p_initial), 2)
        self.assertAlmostEqual(markov_model.p_initial[0], 1.0, places=4)
        self.assertAlmostEqual(markov_model.p_initial[1], 0.0, places=4)
        self.assertEqual(len(markov_model.p_transition), 2)
        self.assertEqual(len(markov_model.p_transition[0]), 2)
        self.assertEqual(len(markov_model.p_transition[1]), 2)
        self.assertAlmostEqual(markov_model.p_transition[0][0], 0.02460365, places=4)
        self.assertAlmostEqual(markov_model.p_transition[0][1], 0.97539634, places=4)
        self.assertAlmostEqual(markov_model.p_transition[1][0], 1.0, places=4)
        self.assertAlmostEqual(markov_model.p_transition[1][1], 0.0, places=4)
        self.assertEqual(len(markov_model.p_emission), 2)
        self.assertEqual(len(markov_model.p_emission[0]), 3)
        self.assertEqual(len(markov_model.p_emission[1]), 3)
        self.assertAlmostEqual(markov_model.p_emission[0][0], 0.5)
        self.assertAlmostEqual(markov_model.p_emission[0][1], 0.0)
        self.assertAlmostEqual(markov_model.p_emission[0][2], 0.5)
        self.assertAlmostEqual(markov_model.p_emission[1][0], 0.0)
        self.assertAlmostEqual(markov_model.p_emission[1][1], 1.0)
        self.assertAlmostEqual(markov_model.p_emission[1][2], 0.0)

    # Do some tests from the topcoder competition.

    def test_topcoder1(self):
        """Decode "TGCC" with a fixed two-state model (expected NNNN)."""
        # NNNN
        states = "NR"
        alphabet = "AGTC"
        p_initial = array([1.0, 0.0])
        p_transition = array([[0.90, 0.10], [0.20, 0.80]])
        p_emission = array([[0.30, 0.20, 0.30, 0.20], [0.10, 0.40, 0.10, 0.40]])
        markov_model = MarkovModel.MarkovModel(
            states, alphabet, p_initial, p_transition, p_emission
        )
        states = MarkovModel.find_states(markov_model, "TGCC")
        self.assertEqual(len(states), 1)
        state_list, state_float = states[0]
        self.assertEqual(state_list, ["N", "N", "N", "N"])

    def test_topcoder2(self):
        """Decode a 14-symbol sequence with a fixed two-state model."""
        # NNNRRRNNRRNRRN
        states = "NR"
        alphabet = "AGTC"
        p_initial = array([1.0, 0.0])
        p_transition = array([[0.56, 0.44], [0.25, 0.75]])
        p_emission = array([[0.04, 0.14, 0.62, 0.20], [0.39, 0.15, 0.04, 0.42]])
        markov_model = MarkovModel.MarkovModel(
            states, alphabet, p_initial, p_transition, p_emission
        )
        states = MarkovModel.find_states(markov_model, "CCTGAGTTAGTCGT")
        self.assertEqual(len(states), 1)
        state_list, state_float = states[0]
        self.assertEqual(
            state_list,
            ["N", "N", "N", "R", "R", "R", "N", "N", "R", "R", "N", "R", "R", "N"],
        )

    def test_topcoder3(self):
        """Decode a 25-symbol sequence with a fixed two-state model."""
        # NRRRRRRRRRRRNNNNRRRRRRRRR
        states = "NR"
        alphabet = "AGTC"
        p_initial = array([1.0, 0.0])
        p_transition = array([[0.75, 0.25], [0.25, 0.75]])
        p_emission = array([[0.45, 0.36, 0.06, 0.13], [0.24, 0.18, 0.12, 0.46]])
        markov_model = MarkovModel.MarkovModel(
            states, alphabet, p_initial, p_transition, p_emission
        )
        states = MarkovModel.find_states(markov_model, "CCGTACTTACCCAGGACCGCAGTCC")
        self.assertEqual(len(states), 1)
        state_list, state_float = states[0]
        # 1 N, 11 R, 4 N, 9 R: 25 states, one per input symbol.
        expected = ["N"] + ["R"] * 11 + ["N"] * 4 + ["R"] * 9
        self.assertEqual(state_list, expected)

    def test_topcoder4(self):
        """Decode an 11-symbol sequence with a fixed two-state model."""
        # NRRRRRRRRRR
        states = "NR"
        alphabet = "AGTC"
        p_initial = array([1.0, 0.0])
        p_transition = array([[0.55, 0.45], [0.15, 0.85]])
        p_emission = array([[0.75, 0.03, 0.01, 0.21], [0.34, 0.11, 0.39, 0.16]])
        markov_model = MarkovModel.MarkovModel(
            states, alphabet, p_initial, p_transition, p_emission
        )
        states = MarkovModel.find_states(markov_model, "TTAGCAGTGCG")
        self.assertEqual(len(states), 1)
        state_list, state_float = states[0]
        self.assertEqual(
            state_list, ["N", "R", "R", "R", "R", "R", "R", "R", "R", "R", "R"]
        )

    def test_topcoder5(self):
        """Decode a single-symbol sequence with a fixed two-state model."""
        # N
        states = "NR"
        alphabet = "AGTC"
        p_initial = array([1.0, 0.0])
        p_transition = array([[0.84, 0.16], [0.25, 0.75]])
        p_emission = array([[0.26, 0.37, 0.08, 0.29], [0.31, 0.13, 0.33, 0.23]])
        markov_model = MarkovModel.MarkovModel(
            states, alphabet, p_initial, p_transition, p_emission
        )
        states = MarkovModel.find_states(markov_model, "T")
        self.assertEqual(len(states), 1)
        state_list, state_float = states[0]
        self.assertEqual(state_list, ["N"])

    def test_readline_and_check_start(self):
        """Check that _readline_and_check_start returns the matching first line."""
        line = "This is a \n string with two lines \n"
        handle = StringIO(line)
        start = "This is a \n"
        self.assertEqual(start, MarkovModel._readline_and_check_start(handle, start))

    def test_save_and_load(self):
        """Round-trip a model through save() and load() on an in-memory handle."""
        states = "NR"
        alphabet = "AGTC"
        p_initial = array([1.0, 0.0])
        p_transition = array([[0.75, 0.25], [0.25, 0.75]])
        p_emission = array([[0.45, 0.36, 0.06, 0.13], [0.24, 0.18, 0.12, 0.46]])
        markov_model_save = MarkovModel.MarkovModel(
            states, alphabet, p_initial, p_transition, p_emission
        )

        handle = StringIO()
        MarkovModel.save(markov_model_save, handle)
        handle.seek(0)
        markov_model_load = MarkovModel.load(handle)

        self.assertEqual("".join(markov_model_load.states), states)
        self.assertEqual("".join(markov_model_load.alphabet), alphabet)
        self.assertTrue(array_equal(markov_model_load.p_initial, p_initial))
        self.assertTrue(array_equal(markov_model_load.p_transition, p_transition))
        self.assertTrue(array_equal(markov_model_load.p_emission, p_emission))

    def test_train_bw(self):
        """Train with Baum-Welch from a seeded random start and compare results."""
        # Seed NumPy's global RNG so the random starting point is reproducible.
        random.seed(0)
        states = ["0", "1", "2", "3"]
        alphabet = ["A", "C", "G", "T"]
        training_data = [
            "AACCCGGGTTTTTTT",
            "ACCGTTTTTTT",
            "ACGGGTTTTTT",
            "ACCGTTTTTTTT",
        ]

        output_p_initial = array([0.2275677, 0.29655611, 0.24993822, 0.22593797])
        output_p_transition = array(
            [
                [5.16919807e-001, 3.65825814e-033, 4.83080193e-001, 9.23220689e-042],
                [3.65130247e-001, 1.00000000e-300, 6.34869753e-001, 1.00000000e-300],
                [8.68776164e-001, 1.02254304e-034, 1.31223836e-001, 6.21835051e-047],
                [3.33333333e-301, 3.33333333e-001, 3.33333333e-301, 6.66666667e-001],
            ]
        )
        output_p_emission = array(
            [
                [2.02593570e-301, 2.02593570e-301, 2.02593570e-301, 1.00000000e000],
                [1.00000000e-300, 1.00000000e-300, 1.00000000e000, 1.09629016e-259],
                [3.26369779e-301, 3.26369779e-301, 3.26369779e-301, 1.00000000e000],
                [3.33333333e-001, 6.66666667e-001, 3.33333333e-301, 3.33333333e-301],
            ]
        )

        markov_model = MarkovModel.train_bw(states, alphabet, training_data)
        self.assertEqual("".join(markov_model.states), "".join(states))
        self.assertEqual("".join(markov_model.alphabet), "".join(alphabet))
        self.assertTrue(
            array_equal(
                around(markov_model.p_initial, decimals=3),
                around(output_p_initial, decimals=3),
            )
        )
        self.assertTrue(
            array_equal(
                around(markov_model.p_transition, decimals=3),
                around(output_p_transition, decimals=3),
            )
        )
        self.assertTrue(
            array_equal(
                around(markov_model.p_emission, decimals=3),
                around(output_p_emission, decimals=3),
            )
        )

    def test_forward(self):
        """Check the log-space forward matrix for a small two-state model."""
        states = ["CP", "IP"]
        outputs = [2, 1, 0]
        lp_initial = log([1.0, 0.0000001])
        lp_transition = log([[0.7, 0.3], [0.5, 0.5]])
        lp_emission = log([[0.6, 0.1, 0.3], [0.1, 0.7, 0.2]])

        matrix = array(
            [
                [0.0, -1.5606477, -3.07477539, -3.84932984],
                [-16.11809565, -2.4079455, -3.27544608, -4.5847794],
            ]
        )
        self.assertTrue(
            array_equal(
                around(
                    MarkovModel._forward(
                        len(states),
                        len(outputs),
                        lp_initial,
                        lp_transition,
                        lp_emission,
                        outputs,
                    ),
                    decimals=3,
                ),
                around(matrix, decimals=3),
            )
        )

    def test_backward(self):
        """Check the log-space backward matrix for a small two-state model."""
        states = ["CP", "IP"]
        outputs = [2, 1, 0]
        lp_transition = log([[0.7, 0.3], [0.5, 0.5]])
        lp_emission = log([[0.6, 0.1, 0.3], [0.1, 0.7, 0.2]])

        matrix = array(
            [
                [-3.45776773, -3.10109279, -0.51082562, 0.0],
                [-3.54045945, -1.40649707, -2.30258509, 0.0],
            ]
        )
        self.assertTrue(
            array_equal(
                around(
                    MarkovModel._backward(
                        len(states), len(outputs), lp_transition, lp_emission, outputs
                    ),
                    decimals=3,
                ),
                around(matrix, decimals=3),
            )
        )

    def test_mle(self):
        """Check maximum-likelihood estimates from index-encoded training data."""
        states = ["0", "1", "2", "3"]
        alphabet = ["A", "C", "G", "T"]
        # Index-encoded forms of the (sequence, states) pairs used in
        # test_train_visible, with A/C/G/T -> 0/1/2/3 and "0".."3" -> 0..3.
        # The rows have different lengths, so dtype=object is required:
        # NumPy >= 1.24 raises ValueError for ragged input without it.
        training_outputs = array(
            [
                [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3],
                [0, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3],
                [0, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3],
                [0, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3],
            ],
            dtype=object,
        )
        training_states = array(
            [
                [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3],
                [0, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3],
                [0, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3],
                [0, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3],
            ],
            dtype=object,
        )

        p_initial = array([1.0, 0.0, 0.0, 0.0])
        p_transition = array(
            [
                [0.2, 0.8, 0.0, 0.0],
                [0.0, 0.5, 0.5, 0.0],
                [0.0, 0.0, 0.5, 0.5],
                [0.0, 0.0, 0.0, 1.0],
            ]
        )
        p_emission = array(
            [
                [0.66666667, 0.11111111, 0.11111111, 0.11111111],
                [0.08333333, 0.75, 0.08333333, 0.08333333],
                [0.08333333, 0.08333333, 0.75, 0.08333333],
                [0.03125, 0.03125, 0.03125, 0.90625],
            ]
        )
        p_initial_out, p_transition_out, p_emission_out = MarkovModel._mle(
            len(states),
            len(alphabet),
            training_outputs,
            training_states,
            None,
            None,
            None,
        )
        self.assertTrue(
            array_equal(
                around(p_initial_out, decimals=3), around(p_initial, decimals=3)
            )
        )
        self.assertTrue(
            array_equal(
                around(p_transition_out, decimals=3), around(p_transition, decimals=3)
            )
        )
        self.assertTrue(
            array_equal(
                around(p_emission_out, decimals=3), around(p_emission, decimals=3)
            )
        )

    def test_argmaxes(self):
        """Check that _argmaxes returns the single flat index of the maximum."""
        matrix = array([[4, 5, 6], [9, 7, 8], [1, 2, 3]])
        output = [3]
        argmaxes = MarkovModel._argmaxes(matrix)
        self.assertEqual(len(argmaxes), len(output))
        self.assertEqual(argmaxes[0], output[0])

    def test_viterbi(self):
        """Check the Viterbi path and its log score for a small model."""
        states = ["CP", "IP"]
        outputs = [2, 1, 0]
        lp_initial = log([1.0, 0.0000001])
        lp_transition = log([[0.7, 0.3], [0.5, 0.5]])
        lp_emission = log([[0.6, 0.1, 0.3], [0.1, 0.7, 0.2]])

        output1 = [0, 1, 0]
        output2 = -3.968593356916541

        viterbi_output = MarkovModel._viterbi(
            len(states), lp_initial, lp_transition, lp_emission, outputs
        )
        self.assertEqual(len(viterbi_output[0][0]), 3)
        self.assertEqual(viterbi_output[0][0][0], output1[0])
        self.assertEqual(viterbi_output[0][0][1], output1[1])
        self.assertEqual(viterbi_output[0][0][2], output1[2])
        self.assertEqual(float("%.3f" % viterbi_output[0][1]), float("%.3f" % output2))

    def test_normalize_and_copy_and_check(self):
        """Check row normalisation and shape-checked copying of 1D/2D arrays."""
        matrix_in1 = array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]])
        matrix_in2 = array([1, 2, 3])

        matrix_out1 = array(
            [
                [0.16666667, 0.33333333, 0.5],
                [0.26666667, 0.33333333, 0.4],
                [0.29166667, 0.33333333, 0.375],
            ]
        )
        matrix_out2 = array([0.16666667, 0.33333333, 0.5])
        self.assertTrue(
            array_equal(
                around(MarkovModel._normalize(matrix_in1), decimals=3),
                around(matrix_out1, decimals=3),
            )
        )
        self.assertTrue(
            array_equal(
                around(MarkovModel._normalize(matrix_in2), decimals=3),
                around(matrix_out2, decimals=3),
            )
        )

        shape1 = (3, 3)
        shape2 = (3,)
        self.assertTrue(
            array_equal(
                around(MarkovModel._copy_and_check(matrix_out1, shape1), decimals=3),
                around(matrix_out1, decimals=3),
            )
        )
        self.assertTrue(
            array_equal(
                around(MarkovModel._copy_and_check(matrix_out2, shape2), decimals=3),
                around(matrix_out2, decimals=3),
            )
        )

    def test_uniform_norm(self):
        """Check that _uniform_norm gives equal probabilities in every row."""
        shape = (4, 3)
        matrix = array(
            [
                [0.33333333, 0.33333333, 0.33333333],
                [0.33333333, 0.33333333, 0.33333333],
                [0.33333333, 0.33333333, 0.33333333],
                [0.33333333, 0.33333333, 0.33333333],
            ]
        )
        self.assertTrue(
            array_equal(
                around(MarkovModel._uniform_norm(shape), decimals=3),
                around(matrix, decimals=3),
            )
        )

    def test_random_norm(self):
        """Check _random_norm against values produced from a fixed seed."""
        random.seed(0)
        shape = (4, 3)
        matrix = array(
            [
                [0.29399155, 0.38311672, 0.32289173],
                [0.33750765, 0.26241723, 0.40007512],
                [0.1908342, 0.38890714, 0.42025866],
                [0.22501625, 0.46461061, 0.31037314],
            ]
        )
        self.assertTrue(
            array_equal(
                around(MarkovModel._random_norm(shape), decimals=3),
                around(matrix, decimals=3),
            )
        )

    def test_logsum_and_exp_logsum(self):
        """Check _logsum and _exp_logsum on 1D and 2D arrays."""
        matrix = array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]])
        matrix1 = array([1, 2, 3])

        output = 10.304721798
        output1 = 3.40760596444
        self.assertEqual(
            float("%.3f" % MarkovModel._logsum(matrix)), float("%.3f" % output)
        )
        self.assertEqual(
            float("%.3f" % MarkovModel._logsum(matrix1)), float("%.3f" % output1)
        )

        output2 = 29873.342245
        output3 = 30.1928748506
        self.assertEqual(
            float("%.3f" % MarkovModel._exp_logsum(matrix)), float("%.3f" % output2)
        )
        self.assertEqual(
            float("%.3f" % MarkovModel._exp_logsum(matrix1)), float("%.3f" % output3)
        )

    def test_logvecadd(self):
        """Check that _logvecadd adds two log-space vectors element-wise."""
        vec1 = log(array([1, 2, 3, 4]))
        vec2 = log(array([5, 6, 7, 8]))

        # log(1+5), log(2+6), log(3+7), log(4+8)
        sumvec = array([1.79175947, 2.07944154, 2.30258509, 2.48490665])
        self.assertTrue(
            array_equal(
                around(MarkovModel._logvecadd(vec1, vec2), decimals=3),
                around(sumvec, decimals=3),
            )
        )


if __name__ == "__main__":
    # When executed directly, run the suite with per-test (verbose) output.
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
