# -*- coding: utf-8 -*-
# Copyright (c) Vispy Development Team. All Rights Reserved.
# Distributed under the (new) BSD License. See LICENSE.txt for more info.

import numpy as np
from numpy.testing import assert_allclose, assert_equal

from vispy.visuals.graphs.layouts import get_layout
from vispy.testing import (run_tests_if_main, assert_raises)


# Fixture graph shared by the layout tests below: 10 nodes joined by six
# undirected edges (0-1, 1-4, 1-6, 1-7, 3-8, 8-9). Nodes 2 and 5 have no
# edges at all. The matrix is symmetric, so every edge is stored twice.
adjacency_mat = np.zeros((10, 10), dtype=int)
adjacency_mat[[0, 1, 1, 1, 3, 8], [1, 4, 6, 7, 8, 9]] = 1  # upper triangle
adjacency_mat[[1, 4, 6, 7, 8, 9], [0, 1, 1, 1, 3, 8]] = 1  # mirrored entries


def test_get_layout():
    """Exercise ``get_layout``: plain lookup, keyword arguments, bad names."""
    from vispy.visuals.graphs.layouts.random import random as random_layout

    # Looking up 'random' must give back the object exported by the
    # random layout module.
    assert_equal(random_layout, get_layout('random'))

    # Keyword arguments given to get_layout should be visible on the
    # layout that comes back.
    n_iterations = 100
    force_layout = get_layout('force_directed', iterations=n_iterations)
    assert_equal(force_layout.iterations, n_iterations)

    # A name that matches no layout must raise KeyError.
    assert_raises(KeyError, get_layout, 'fdgdfgs_non_existent')


def test_random_layout():
    """Compare the seeded 'random' layout with recorded reference values.

    The layout is called on ``adjacency_mat`` with ``random_state=0xDEADBEEF``
    and only the first item it produces is inspected.
    """
    layout = get_layout('random')

    # Reference coordinates, one (x, y) row per node.
    want_pos = np.array([
        [0.22270932715095826, 0.7728936927702302],
        [0.6298054094517744, 0.21851589821484974],
        [0.75002099889163, 0.5592076821676369],
        [0.1786754307911973, 0.6442165368790972],
        [0.5979199081208609, 0.615159318836822],
        [0.46328431255222746, 0.3582897386994869],
        [0.9595461883180398, 0.2350580044144016],
        [0.094482942129406, 0.20584398882694932],
        [0.5758593091748346, 0.8158957494444451],
        [0.5908647616961652, 0.1584550825482285]
    ])

    # The reference line vertices are the node coordinates of each edge's
    # two endpoints, listed edge by edge in row-major order of the non-zero
    # entries of the adjacency matrix.
    edge_endpoints = np.array([
        (0, 1), (1, 0), (1, 4), (1, 6), (1, 7), (3, 8),
        (4, 1), (6, 1), (7, 1), (8, 3), (8, 9), (9, 8),
    ]).ravel()
    want_vertices = want_pos[edge_endpoints]

    steps = layout(adjacency_mat, random_state=0xDEADBEEF)
    pos, line_vertices, _arrows = next(steps)  # arrows are not checked

    assert_allclose(pos, want_pos, atol=1e-7)
    assert_allclose(line_vertices, want_vertices, atol=1e-7)


def test_circular_layout():
    """Compare the 'circular' layout with recorded reference values.

    The layout is called on ``adjacency_mat`` and only the first item it
    produces is inspected.
    """
    layout = get_layout('circular')

    # Reference coordinates, one (x, y) row per node.
    want_pos = np.array([
        [1.0, 0.5],
        [0.9045084714889526, 0.7938926219940186],
        [0.6545084714889526, 0.9755282402038574],
        [0.3454914689064026, 0.9755282402038574],
        [0.09549146890640259, 0.7938926219940186],
        [0.0, 0.4999999701976776],
        [0.09549152851104736, 0.20610731840133667],
        [0.3454914689064026, 0.024471759796142578],
        [0.6545085906982422, 0.024471759796142578],
        [0.9045084714889526, 0.20610734820365906]
    ])

    # The reference line vertices are the node coordinates of each edge's
    # two endpoints, listed edge by edge in row-major order of the non-zero
    # entries of the adjacency matrix.
    edge_endpoints = np.array([
        (0, 1), (1, 0), (1, 4), (1, 6), (1, 7), (3, 8),
        (4, 1), (6, 1), (7, 1), (8, 3), (8, 9), (9, 8),
    ]).ravel()
    want_vertices = want_pos[edge_endpoints]

    steps = layout(adjacency_mat)
    pos, line_vertices, _arrows = next(steps)  # arrows are not checked

    assert_allclose(pos, want_pos, atol=1e-4)
    assert_allclose(line_vertices, want_vertices, atol=1e-4)


# Helper imported from vispy.testing; presumably runs the tests in this
# module when the file is executed directly as a script -- confirm against
# vispy.testing before relying on that.
run_tests_if_main()
