# Copyright (c) DataLab Platform Developers, BSD 3-Clause license, see LICENSE file.

"""
.. Coordinates Algorithms (see parent package :mod:`sigima.tools`)
"""

# pylint: disable=invalid-name  # Allows short reference names like x, y, ...

from __future__ import annotations

from typing import Literal

import numpy as np

from sigima.tools.checks import check_1d_arrays


def circle_to_diameter(
    xc: float, yc: float, r: float
) -> tuple[float, float, float, float]:
    """Convert circle center and radius to X diameter coordinates

    Args:
        xc: Circle center X coordinate
        yc: Circle center Y coordinate
        r: Circle radius

    Returns:
        tuple: Circle X diameter coordinates
    """
    return xc - r, yc, xc + r, yc


def array_circle_to_diameter(data: np.ndarray) -> np.ndarray:
    """Convert circle center and radius to X diameter coordinates (array version)

    Args:
        data: Circle center and radius, in the form of a 2D array (N, 3)

    Returns:
        Circle X diameter coordinates, in the form of a 2D array (N, 4)
    """
    xc, yc, r = data[:, 0], data[:, 1], data[:, 2]
    x_start = xc - r
    x_end = xc + r
    result = np.column_stack((x_start, yc, x_end, yc)).astype(float)
    return result


def circle_to_center_radius(
    x0: float, y0: float, x1: float, y1: float
) -> tuple[float, float, float]:
    """Convert circle X diameter coordinates to center and radius

    Args:
        x0: Diameter start X coordinate
        y0: Diameter start Y coordinate
        x1: Diameter end X coordinate
        y1: Diameter end Y coordinate

    Returns:
        tuple: Circle center and radius
    """
    xc, yc = (x0 + x1) / 2, (y0 + y1) / 2
    r = np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2) / 2
    return xc, yc, r


def array_circle_to_center_radius(data: np.ndarray) -> np.ndarray:
    """Convert circle X diameter coordinates to center and radius (array version)

    Args:
        data: Circle X diameter coordinates, in the form of a 2D array (N, 4)

    Returns:
        Circle center and radius, in the form of a 2D array (N, 3)
    """
    x0, y0, x1, y1 = data[:, 0], data[:, 1], data[:, 2], data[:, 3]
    xc, yc = (x0 + x1) / 2, (y0 + y1) / 2
    r = np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2) / 2
    result = np.column_stack((xc, yc, r)).astype(float)
    return result


def ellipse_to_diameters(
    xc: float, yc: float, a: float, b: float, theta: float
) -> tuple[float, float, float, float, float, float, float, float]:
    """Convert ellipse center, axes and angle to X/Y diameters coordinates

    Args:
        xc: Ellipse center X coordinate
        yc: Ellipse center Y coordinate
        a: Ellipse half larger axis
        b: Ellipse half smaller axis
        theta: Ellipse angle

    Returns:
        Ellipse X/Y diameters (major/minor axes) coordinates
    """
    dxa, dya = a * np.cos(theta), a * np.sin(theta)
    dxb, dyb = b * np.sin(theta), b * np.cos(theta)
    x0, y0, x1, y1 = xc - dxa, yc - dya, xc + dxa, yc + dya
    x2, y2, x3, y3 = xc - dxb, yc - dyb, xc + dxb, yc + dyb
    return x0, y0, x1, y1, x2, y2, x3, y3


def array_ellipse_to_diameters(data: np.ndarray) -> np.ndarray:
    """Convert ellipse center, axes and angle to X/Y diameters coordinates
    (array version)

    Args:
        data: Ellipse center, axes and angle, in the form of a 2D array (N, 5)

    Returns:
        Ellipse X/Y diameters (major/minor axes) coordinates,
         in the form of a 2D array (N, 8)
    """
    xc, yc, a, b, theta = data[:, 0], data[:, 1], data[:, 2], data[:, 3], data[:, 4]
    dxa, dya = a * np.cos(theta), a * np.sin(theta)
    dxb, dyb = b * np.sin(theta), b * np.cos(theta)
    x0, y0, x1, y1 = xc - dxa, yc - dya, xc + dxa, yc + dya
    x2, y2, x3, y3 = xc - dxb, yc - dyb, xc + dxb, yc + dyb
    result = np.column_stack((x0, y0, x1, y1, x2, y2, x3, y3)).astype(float)
    return result


def ellipse_to_center_axes_angle(
    x0: float,
    y0: float,
    x1: float,
    y1: float,
    x2: float,
    y2: float,
    x3: float,
    y3: float,
) -> tuple[float, float, float, float, float]:
    """Convert ellipse X/Y diameters coordinates to center, axes and angle

    Args:
        x0: major axis start X coordinate
        y0: major axis start Y coordinate
        x1: major axis end X coordinate
        y1: major axis end Y coordinate
        x2: minor axis start X coordinate
        y2: minor axis start Y coordinate
        x3: minor axis end X coordinate
        y3: minor axis end Y coordinate

    Returns:
        Ellipse center, axes and angle
    """
    xc, yc = (x0 + x1) / 2, (y0 + y1) / 2
    a = np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2) / 2
    b = np.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) / 2
    theta = np.arctan2(y1 - y0, x1 - x0)
    return xc, yc, a, b, theta


def array_ellipse_to_center_axes_angle(data: np.ndarray) -> np.ndarray:
    """Convert ellipse X/Y diameters coordinates to center, axes and angle
    (array version)

    Args:
        data: Ellipse X/Y diameters coordinates, in the form of a 2D array (N, 8)

    Returns:
        Ellipse center, axes and angle, in the form of a 2D array (N, 5)
    """
    x0, y0, x1, y1, x2, y2, x3, y3 = (
        data[:, 0],
        data[:, 1],
        data[:, 2],
        data[:, 3],
        data[:, 4],
        data[:, 5],
        data[:, 6],
        data[:, 7],
    )
    xc, yc = (x0 + x1) / 2, (y0 + y1) / 2
    a = np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2) / 2
    b = np.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) / 2
    theta = np.arctan2(y1 - y0, x1 - x0)
    result = np.column_stack((xc, yc, a, b, theta)).astype(float)
    return result


@check_1d_arrays
def to_polar(
    x: np.ndarray, y: np.ndarray, unit: Literal["°", "rad"] = "rad"
) -> tuple[np.ndarray, np.ndarray]:
    """Convert Cartesian coordinates to polar coordinates.

    Args:
        x: Cartesian x-coordinate.
        y: Cartesian y-coordinate.
        unit: Unit of the angle ("°" or "rad").

    Returns:
        Polar coordinates (r, theta) where r is the radius and theta is the angle.

    Raises:
        ValueError: If the unit is not "°" or "rad".
    """
    if unit not in ["rad", "°"]:
        raise ValueError(f"Unit must be radian ('rad') or degree ('°'), got {unit}.")
    r = np.sqrt(x**2 + y**2)
    theta = np.arctan2(y, x)
    if unit == "°":
        theta = np.rad2deg(theta)
    return r, theta


@check_1d_arrays
def to_cartesian(
    r: np.ndarray, theta: np.ndarray, unit: Literal["°", "rad"] = "rad"
) -> tuple[np.ndarray, np.ndarray]:
    """Convert polar coordinates to Cartesian coordinates.

    Args:
        r: Polar radius.
        theta: Polar angle.
        unit: Unit of the angle ("°" or "rad").

    Returns:
        Cartesian coordinates (x, y) where x is the x-coordinate and y is the
        y-coordinate.

    Raises:
        ValueError: If the unit is not "°" or "rad".
        ValueError: If any value of the radius is negative.
    """
    if unit not in ["rad", "°"]:
        raise ValueError(f"Unit must be radian ('rad') or degree ('°'), got {unit}.")
    if np.any(r < 0.0):
        raise ValueError("Negative radius values are not allowed.")
    if unit == "°":
        theta = np.deg2rad(theta)
    x = r * np.cos(theta)
    y = r * np.sin(theta)
    return x, y


def rotate(angle: float) -> np.ndarray:
    """Return rotation matrix

    Args:
        angle: Rotation angle (in radians)

    Returns:
        Rotation matrix
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    return np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]], dtype=float)


def colvector(x: float, y: float) -> np.ndarray:
    """Return vector from coordinates

    Args:
        x: x-coordinate
        y: y-coordinate

    Returns:
        Vector
    """
    return np.array([x, y, 1]).T


def vector_rotation(theta: float, dx: float, dy: float) -> tuple[float, float]:
    """Compute theta-rotation on vector

    Args:
        theta: Rotation angle
        dx: x-coordinate of vector
        dy: y-coordinate of vector

    Returns:
        Tuple of (x, y) coordinates of rotated vector
    """
    return (rotate(theta) @ colvector(dx, dy)).ravel()[:2]


@check_1d_arrays(x_require_1d=False, y_require_1d=False)
def polar_to_complex(
    r: np.ndarray, theta: np.ndarray, unit: Literal["°", "rad"] = "rad"
) -> np.ndarray:
    """Convert polar coordinates to complex number.

    Args:
        r: Polar radius.
        theta: Polar angle.
        unit: Unit of the angle ("°" or "rad").

    Returns:
        Complex numbers corresponding to the polar coordinates.

    Raises:
        ValueError: If the unit is not "°" or "rad".
        ValueError: If any value of the radius is negative.
    """
    if unit not in ["rad", "°"]:
        raise ValueError(f"Unit must be radian ('rad') or degree ('°'), got {unit}.")
    if np.any(r < 0.0):
        raise ValueError("Negative radius values are not allowed.")
    if unit == "°":
        theta = np.deg2rad(theta)
    return r * np.exp(1j * theta)
