# Copyright (C) 2004, Thomas Hamelryck (thamelry@binf.ku.dk)
# This code is part of the Biopython distribution and governed by its
# license.  Please see the LICENSE file that should have been included
# as part of this package.

"""Vector class, including rotation-related functions."""

from __future__ import print_function

import numpy


def m2rotaxis(m):
    """
    Return angles, axis pair that corresponds to rotation matrix m.

    The angle is always in [0, pi]; the sense of rotation is carried by
    the orientation of the returned axis. For the identity matrix
    (angle 0) any axis is valid and (0.0, Vector(1, 0, 0)) is returned.

    @type m: 3x3 Numeric array
    @param m: a left multiplying rotation matrix (e.g. from rotaxis2m)

    @return: (angle, axis) with angle in radians and axis a unit L{Vector}
    """
    m = numpy.asarray(m, dtype=float)
    # trace(m) = 1 + 2*cos(angle); clamp against roundoff before arccos
    t = 0.5 * (numpy.trace(m) - 1)
    t = min(1, max(-1, t))
    angle = numpy.arccos(t)
    if angle < 1e-15:
        # Angle is 0
        return 0.0, Vector(1, 0, 0)
    # The antisymmetric part m - m^T encodes 2*sin(angle) times the axis
    antisym = numpy.array((m[2, 1] - m[1, 2],
                           m[0, 2] - m[2, 0],
                           m[1, 0] - m[0, 1]))
    # Near pi, sin(angle) -> 0: the antisymmetric part is exactly zero for a
    # symmetric pi rotation and otherwise dominated by roundoff, so it can
    # no longer define the axis even if the trace says angle < pi.
    near_pi = (angle > numpy.pi / 2 and
               numpy.linalg.norm(antisym) < 1e-6)
    if angle < numpy.pi and not near_pi:
        axis = Vector(antisym)
        axis.normalize()
        return angle, axis
    # Recover the axis u from the symmetric part instead:
    # (m + m^T)/2 = cos(angle)*I + (1 - cos(angle))*u*u^T
    # (here t = cos(angle) < 0, so 1 - t > 1 and the division is safe).
    uut = (0.5 * (m + m.T) - t * numpy.identity(3)) / (1 - t)
    # Use the column whose diagonal entry u_i**2 is largest (>= 1/3 for a
    # unit axis) so we never divide by a (near-)zero component.
    i = int(numpy.argmax(numpy.diagonal(uut)))
    u = uut[:, i] / numpy.sqrt(uut[i, i])
    # u*u^T fixes u only up to sign; orient it like the antisymmetric part.
    # At exactly pi that part is zero and the sign is arbitrary (component
    # i is then kept positive).
    if numpy.dot(u, antisym) < 0:
        u = -u
    axis = Vector(u)
    axis.normalize()
    return angle, axis


def vector_to_axis(line, point):
    """
    Return the vector between a point and the closest point on a line
    (i.e. the component of the point perpendicular to the line).

    The line passes through the origin with direction `line`. The result
    points from the foot of the perpendicular towards the point.

    @type line: L{Vector}
    @param line: vector defining a line (any non-zero length)

    @type point: L{Vector}
    @param point: vector defining the point
    """
    line = line.normalized()
    # The projection of point onto the unit direction is (point . line) * line.
    # Using the dot product directly avoids an arccos/cos round trip that
    # loses precision for (anti)parallel vectors and breaks for a zero point.
    return point - line ** (point * line)


def rotaxis2m(theta, vector):
    """
    Calculate a left multiplying rotation matrix that rotates
    theta rad around vector.

    Example:

        >>> m=rotaxis(pi, Vector(1, 0, 0))
        >>> rotated_vector=any_vector.left_multiply(m)

    @type theta: float
    @param theta: the rotation angle in radians

    @type vector: L{Vector}
    @param vector: the rotation axis (need not be unit length; the
        caller's vector is not modified)

    @return: The rotation matrix, a 3x3 Numeric array.
    """
    # Normalize a private copy so the caller's axis stays untouched
    unit_axis = vector.copy()
    unit_axis.normalize()
    ux, uy, uz = unit_axis.get_array()
    cos_a = numpy.cos(theta)
    sin_a = numpy.sin(theta)
    versine = 1 - cos_a
    # Rodrigues' rotation formula written out element by element
    return numpy.array([
        [versine*ux*ux + cos_a,
         versine*ux*uy - sin_a*uz,
         versine*ux*uz + sin_a*uy],
        [versine*ux*uy + sin_a*uz,
         versine*uy*uy + cos_a,
         versine*uy*uz - sin_a*ux],
        [versine*ux*uz - sin_a*uy,
         versine*uy*uz + sin_a*ux,
         versine*uz*uz + cos_a],
    ])


# Older name kept for backwards compatibility
rotaxis = rotaxis2m


def refmat(p, q):
    """
    Return a (left multiplying) matrix that mirrors p onto q.

    The result is the reflection through the plane perpendicular to the
    difference of the normalized vectors. Only the directions of p and q
    matter, and the arguments themselves are not modified.

    Example:
        >>> mirror=refmat(p, q)
        >>> qq=p.left_multiply(mirror)
        >>> print(q)
        >>> print(qq) # qq points along q and has the length of p

    @type p,q: L{Vector}
    @return: The mirror operation, a 3x3 Numeric array.
    """
    # Work on normalized copies so the caller's vectors are left untouched
    p = p.normalized()
    q = q.normalized()
    pq = p - q
    if pq.norm() < 1e-5:
        # Directions (nearly) coincide: nothing to mirror
        return numpy.identity(3)
    pq.normalize()
    b = pq.get_array()
    # Householder reflection I - 2*b*b^T about the plane normal to b
    return numpy.identity(3) - 2 * numpy.outer(b, b)


def rotmat(p, q):
    """
    Return a (left multiplying) matrix that rotates p onto q.

    The arguments are not modified.

    Example:
        >>> r=rotmat(p, q)
        >>> print(q)
        >>> print(p.left_multiply(r)) # along q, with the length of p

    @param p: moving vector
    @type p: L{Vector}

    @param q: fixed vector
    @type q: L{Vector}

    @return: rotation matrix that rotates p onto q
    @rtype: 3x3 Numeric array
    """
    # Hand refmat private normalized copies so that the caller's p and q
    # are never modified here, whatever refmat does with its arguments.
    p = p.normalized()
    q = q.normalized()
    # Two reflections compose to a rotation: refmat(p, -p) sends p to -p,
    # then refmat(q, -p) (a reflection, hence its own inverse) sends -p to q.
    # NOTE: if q is antiparallel to p, refmat(q, -p) is the identity and the
    # result is the single reflection refmat(p, -p) (determinant -1), which
    # maps p onto q but is not a proper rotation.
    rot = numpy.dot(refmat(q, -p), refmat(p, -p))
    return rot


def calc_angle(v1, v2, v3):
    """
    Calculate the angle at v2 formed by the three connected points
    v1-v2-v3.

    @param v1, v2, v3: the three points that define the angle
    @type v1, v2, v3: L{Vector}

    @return: angle in radians, in [0, pi]
    @rtype: float
    """
    # Angle between the two arms that leave the central point v2
    return (v1 - v2).angle(v3 - v2)


def calc_dihedral(v1, v2, v3, v4):
    """
    Calculate the dihedral angle between 4 vectors
    representing 4 connected points. The angle is in
    ]-pi, pi] (radians).

    Degenerate input (three collinear consecutive points) gives 0.0.

    @param v1, v2, v3, v4: the four points that define the dihedral angle
    @type v1, v2, v3, v4: L{Vector}
    """
    ab = v1 - v2
    cb = v3 - v2
    db = v4 - v3
    # Normals of the planes (v1, v2, v3) and (v2, v3, v4)
    u = ab ** cb
    v = db ** cb
    w = u ** v
    # w = u x v is (anti)parallel to cb with |w| = |u||v|sin(angle), so
    # w . cb = +-|u||v||cb|sin(angle) carries the sign; scale the cosine
    # term by |cb| to match (multiplying instead of dividing keeps it
    # finite when cb has zero length).
    sin_term = w * cb
    cos_term = (u * v) * cb.norm()
    angle = numpy.arctan2(sin_term, cos_term)
    # arctan2 yields -pi for a negatively signed zero sine with a negative
    # cosine; the documented range excludes -pi.
    if angle <= -numpy.pi:
        angle = numpy.pi
    return angle


class Vector(object):
    """3D vector.

    Coordinates are stored as a numpy float ('d') array of length 3.
    Operators: + and - (with a Vector or anything numpy can broadcast),
    * (dot product with another Vector), / (element-wise division by a
    scalar or 3 numbers), ** (cross product with a Vector, element-wise
    scaling otherwise).
    """

    def __init__(self, x, y=None, z=None):
        """Create a Vector from three numbers or one length-3 sequence.

        Raises ValueError if a single argument does not have length 3.
        """
        if y is None and z is None:
            # Array, list, tuple...
            if len(x) != 3:
                raise ValueError("Vector: x is not a "
                                 "list/tuple/array of 3 numbers")
            self._ar = numpy.array(x, 'd')
        else:
            # Three numbers
            self._ar = numpy.array((x, y, z), 'd')

    def __repr__(self):
        """Return a short representation with 2-decimal coordinates."""
        x, y, z = self._ar
        return "<Vector %.2f, %.2f, %.2f>" % (x, y, z)

    def __neg__(self):
        """Return Vector(-x, -y, -z)."""
        a = -self._ar
        return Vector(a)

    def __add__(self, other):
        """Return Vector+other Vector or scalar."""
        if isinstance(other, Vector):
            a = self._ar + other._ar
        else:
            a = self._ar + numpy.array(other)
        return Vector(a)

    def __sub__(self, other):
        """Return Vector-other Vector or scalar."""
        if isinstance(other, Vector):
            a = self._ar - other._ar
        else:
            a = self._ar - numpy.array(other)
        return Vector(a)

    def __mul__(self, other):
        """Return Vector.Vector (dot product); other must be a Vector."""
        return sum(self._ar * other._ar)

    def __truediv__(self, x):
        """Return Vector(coords/x); x is a scalar or 3 numbers."""
        a = self._ar / numpy.array(x)
        return Vector(a)

    # Python 2 calls __div__ for "/" unless true division is enabled
    __div__ = __truediv__

    def __pow__(self, other):
        """Return VectorxVector (cross product) or Vectorxscalar."""
        if isinstance(other, Vector):
            a, b, c = self._ar
            d, e, f = other._ar
            # Explicit cross product: exact for integer-valued coordinates,
            # unlike 2x2 numpy.linalg.det calls which go through LU.
            c1 = b * f - c * e
            c2 = c * d - a * f
            c3 = a * e - b * d
            return Vector(c1, c2, c3)
        else:
            a = self._ar * numpy.array(other)
            return Vector(a)

    def __getitem__(self, i):
        """Return coordinate i."""
        return self._ar[i]

    def __setitem__(self, i, value):
        """Set coordinate i in place."""
        self._ar[i] = value

    def __contains__(self, i):
        """Return True if i equals one of the coordinates."""
        return (i in self._ar)

    def norm(self):
        """Return vector norm."""
        return numpy.sqrt(sum(self._ar * self._ar))

    def normsq(self):
        """Return square of vector norm."""
        return abs(sum(self._ar * self._ar))

    def normalize(self):
        """Normalize the Vector in place (a zero vector becomes NaNs)."""
        self._ar = self._ar / self.norm()

    def normalized(self):
        """Return a normalized copy of the Vector."""
        v = self.copy()
        v.normalize()
        return v

    def angle(self, other):
        """Return angle between two vectors, in radians in [0, pi]."""
        n1 = self.norm()
        n2 = other.norm()
        c = (self * other) / (n1 * n2)
        # Take care of roundoff errors
        c = min(c, 1)
        c = max(-1, c)
        return numpy.arccos(c)

    def get_array(self):
        """Return (a copy of) the array of coordinates."""
        return numpy.array(self._ar)

    def left_multiply(self, matrix):
        """Return Vector=Matrix x Vector."""
        a = numpy.dot(matrix, self._ar)
        return Vector(a)

    def right_multiply(self, matrix):
        """Return Vector=Vector x Matrix."""
        a = numpy.dot(self._ar, matrix)
        return Vector(a)

    def copy(self):
        """Return a deep copy of the Vector."""
        return Vector(self._ar)

if __name__=="__main__":
    # Ad-hoc demonstration / smoke test of this module. It mostly prints
    # intermediate results; the only hard check is the dihedral sign assert.

    from numpy.random import random

    # Four connected points v1-v2-v3-v4 used for the angle and dihedral
    v1=Vector(0, 0, 1)
    v2=Vector(0, 0, 0)
    v3=Vector(0, 1, 0)
    v4=Vector(1, 1, 0)

    v4.normalize()

    print(v4)

    print(calc_angle(v1, v2, v3))
    dih=calc_dihedral(v1, v2, v3, v4)
    # Test dihedral sign: this configuration must give a positive dihedral
    assert(dih>0)
    print("DIHEDRAL %f" % dih)

    # Mirror and rotation matrices that map v1 onto v3
    ref=refmat(v1, v3)
    rot=rotmat(v1, v3)

    print(v3)
    print(v1.left_multiply(ref))
    print(v1.left_multiply(rot))
    # v x R^T equals R x v, so this should match the previous line
    print(v1.right_multiply(numpy.transpose(rot)))

    # subtraction/addition with a Vector, a scalar and a tuple
    print(v1-v2)
    print(v1-1)
    print(v1+(1, 2, 3))
    # addition/subtraction with a Vector, a scalar and a tuple
    print(v1+v2)
    print(v1+3)
    print(v1-(1, 2, 3))
    # * (dot product of two Vectors)
    print(v1*v2)
    # / (element-wise); on Python 3 "/" dispatches to Vector.__truediv__,
    # a class providing only __div__ supports it on Python 2 only
    print(v1/2)
    print(v1/(1, 2, 3))
    # ** (cross product with a Vector, element-wise scaling otherwise)
    print(v1**v2)
    print(v1**2)
    print(v1**(1, 2, 3))
    # norm
    print(v1.norm())
    # norm squared
    print(v1.normsq())
    # setitem (modifies v1 in place)
    v1[2]=10
    print(v1)
    # getitem
    print(v1[2])

    print(numpy.array(v1))

    print("ROT")

    # Round trip: build a matrix from a random angle in [0, pi) and a random
    # unit axis, then recover both; the printed differences are expected ~0
    angle=random()*numpy.pi
    axis=Vector(random(3)-random(3))
    axis.normalize()

    m=rotaxis(angle, axis)

    cangle, caxis=m2rotaxis(m)

    print(angle-cangle)
    print(axis-caxis)
    print("")

