# Copyright (c) 2020 Chris Richardson
# FEniCS Project
# SPDX-License-Identifier: MIT

import basix
import numpy as np


def test_regge_tri():
    """Check the lowest-degree Regge element on a triangle at the origin.

    Tabulates ``basix.Regge("triangle", 1)`` (no derivatives) at the point
    (0, 0) and compares every basis function, viewed as a 2x2 tensor,
    against hard-coded reference values.
    """
    regge = basix.Regge("triangle", 1)

    # tabulate at origin
    pts = [[0.0, 0.0]]
    w = regge.tabulate(0, pts)[0]
    # Split the flat table into 4 rows, transpose, and view each resulting
    # row as a 2x2 tensor. NOTE(review): this assumes basix lays the table
    # out component-major (4 tensor entries x ndofs) - confirm.
    w = w.reshape((4, -1)).transpose().reshape(-1, 2, 2)

    ref = np.array([[[-0.,  0.5],
                     [0.5, -0.]],

                    [[0.,  0.5],
                     [0.5, -0.]],

                    [[-0.,  1.],
                     [1.,  2.]],

                    [[-0., -0.5],
                     [-0.5, -1.]],

                    [[2.,  1.],
                     [1., -0.]],

                    [[-1., -0.5],
                     [-0.5,  0.]],

                    [[-0.,  0.],
                     [0.,  0.]],

                    [[0., -0.],
                     [-0., -0.]],

                    [[-0., -1.5],
                     [-1.5,  0.]]])

    # np.isclose/np.allclose broadcast their operands, so a tabulation with
    # the wrong number of DOFs could still compare equal. Pin the shape.
    assert w.shape == ref.shape
    assert np.allclose(ref, w)


def test_regge_tri2():
    """Check the degree-2 Regge element on a triangle at the origin.

    Tabulates ``basix.Regge("triangle", 2)`` (no derivatives) at the point
    (0, 0) and compares every basis function, viewed as a 2x2 tensor,
    against hard-coded reference values.
    """
    regge = basix.Regge("triangle", 2)
    # tabulate at origin
    pts = [[0.0, 0.0]]
    w = regge.tabulate(0, pts)[0]
    # Split the flat table into 4 rows, transpose, and view each resulting
    # row as a 2x2 tensor. NOTE(review): this assumes basix lays the table
    # out component-major (4 tensor entries x ndofs) - confirm.
    w = w.reshape((4, -1)).transpose().reshape(-1, 2, 2)

    ref = np.array([[[0., -0.5],
                     [-0.5,  0.]],

                    [[0., -0.5],
                     [-0.5, -0.]],

                    [[-0., -0.5],
                     [-0.5,  0.]],

                    [[-0.,  1.5],
                     [1.5,  3.]],

                    [[0., -1.5],
                     [-1.5, -3.]],

                    [[-0.,  0.5],
                     [0.5,  1.]],

                    [[3.,  1.5],
                     [1.5, -0.]],

                    [[-3., -1.5],
                     [-1.5,  0.]],

                    [[1.,  0.5],
                     [0.5, -0.]],

                    [[-0., -0.],
                     [-0., -0.]],

                    [[0., -0.],
                     [-0., -0.]],

                    [[0., -3.],
                     [-3.,  0.]],

                    [[0., -0.],
                     [-0., -0.]],

                    [[-0., -0.],
                     [-0.,  0.]],

                    [[-0.,  2.],
                     [2., -0.]],

                    [[-0.,  0.],
                     [0.,  0.]],

                    [[0.,  0.],
                     [0., -0.]],

                    [[0.,  2.],
                     [2., -0.]]])
    # np.isclose/np.allclose broadcast their operands, so a tabulation with
    # the wrong number of DOFs could still compare equal. Pin the shape.
    assert w.shape == ref.shape
    assert np.allclose(ref, w)


def test_regge_tet():
    """Check the lowest-degree Regge element on a tetrahedron at the origin.

    Tabulates ``basix.Regge("tetrahedron", 1)`` (no derivatives) at the
    point (0, 0, 0) and compares every basis function, viewed as a 3x3
    tensor, against hard-coded reference values.
    """
    regge = basix.Regge("tetrahedron", 1)
    # tabulate at origin
    pts = [[0.0, 0.0, 0.0]]
    w = regge.tabulate(0, pts)[0]
    # Split the flat table into 9 rows, transpose, and view each resulting
    # row as a 3x3 tensor. NOTE(review): this assumes basix lays the table
    # out component-major (9 tensor entries x ndofs) - confirm.
    w = w.reshape((9, -1)).transpose().reshape(-1, 3, 3)

    ref = np.array([[[0.,  0.,  0.],
                     [0.,  0.,  0.5],
                     [0.,  0.5, -0.]],

                    [[-0.,  0., -0.],
                     [0., -0.,  0.5],
                     [-0.,  0.5,  0.]],

                    [[0., -0.,  0.5],
                     [-0.,  0.,  0.],
                     [0.5,  0.,  0.]],

                    [[-0.,  0.,  0.5],
                     [0., -0.,  0.],
                     [0.5,  0.,  0.]],

                    [[-0.,  0.5,  0.],
                     [0.5, -0., -0.],
                     [0., -0.,  0.]],

                    [[0.,  0.5,  0.],
                     [0.5, -0.,  0.],
                     [0.,  0.,  0.]],

                    [[0.,  0.,  1.],
                     [0.,  0.,  1.],
                     [1.,  1.,  2.]],

                    [[0., -0., -0.5],
                     [-0.,  0., -0.5],
                     [-0.5, -0.5, -1.]],

                    [[0.,  1., -0.],
                     [1.,  2.,  1.],
                     [-0.,  1., -0.]],

                    [[-0., -0.5, -0.],
                     [-0.5, -1., -0.5],
                     [-0., -0.5, -0.]],

                    [[2.,  1.,  1.],
                     [1., -0.,  0.],
                     [1.,  0.,  0.]],

                    [[-1., -0.5, -0.5],
                     [-0.5, -0.,  0.],
                     [-0.5,  0., -0.]],

                    [[-0.,  0., -0.],
                     [0., -0., -0.],
                     [-0., -0., -0.]],

                    [[-0., -0., -0.],
                     [-0., -0., -0.],
                     [-0., -0., -0.]],

                    [[-0.,  0., -0.],
                     [0.,  0., -0.],
                     [-0., -0., -0.]],

                    [[0., -0.,  0.],
                     [-0., -0., -0.],
                     [0., -0.,  0.]],

                    [[-0., -0., -0.],
                     [-0., -0.,  0.],
                     [-0.,  0.,  0.]],

                    [[-0., -0., -0.],
                     [-0., -0., -1.5],
                     [-0., -1.5, -0.]],

                    [[0., -0.,  0.],
                     [-0.,  0., -0.],
                     [0., -0.,  0.]],

                    [[0., -0., -0.],
                     [-0., -0., -0.],
                     [-0., -0.,  0.]],

                    [[-0.,  0., -1.5],
                     [0., -0., -0.],
                     [-1.5, -0., -0.]],

                    [[-0., -0., -0.],
                     [-0.,  0., -0.],
                     [-0., -0., -0.]],

                    [[0., -0., -0.],
                     [-0., -0., -0.],
                     [-0., -0.,  0.]],

                    [[-0., -1.5, -0.],
                     [-1.5,  0., -0.],
                     [-0., -0., -0.]]])
    # np.isclose/np.allclose broadcast their operands, so a tabulation with
    # the wrong number of DOFs could still compare equal. Pin the shape.
    assert w.shape == ref.shape
    assert np.allclose(ref, w)
