import collections
import typing
import uuid

from qtpy.QtCore import QObject, QPointF, QSizeF

from .base import NodeBase, Serializable
from .enums import ReactToConnectionState
from .node_data import NodeData, NodeDataModel, NodeDataType
from .node_geometry import NodeGeometry
from .node_graphics_object import NodeGraphicsObject
from .node_state import NodeState
from .port import Port, PortType
from .style import NodeStyle


class Node(QObject, Serializable, NodeBase):
    """
    A single Node in the scene

    Ties together a data model with its per-node state, geometry, style and
    (once assigned) graphics object.
    """

    def __init__(self, data_model: NodeDataModel):
        """
        A single Node in the scene

        Parameters
        ----------
        data_model : NodeDataModel
            The model backing this node; its ``node_style`` is used as the
            node's style and its update signals are connected to this node.
        """
        super().__init__()
        self._model = data_model
        self._uid = str(uuid.uuid4())
        self._style = data_model.node_style
        self._state = NodeState(self)
        self._geometry = NodeGeometry(self)
        # The graphics object is attached later through the
        # ``graphics_object`` setter.
        self._graphics_obj = None
        self._geometry.recalculate_size()

        # propagate data: model => node
        self._model.data_updated.connect(self._on_port_index_data_updated)
        self._model.embedded_widget_size_updated.connect(self.on_node_size_updated)

    def __hash__(self):
        """
        Hash of the node, derived from the value of its unique identifier.

        ``__eq__`` compares identifiers by value, so the hash must be based on
        the identifier's value as well (not on the address of the string
        object holding it).
        """
        return hash(self._uid)

    def __eq__(self, node):
        """
        Nodes are equal when they share the same id and the very same model.

        Objects without ``id``/``model`` attributes compare unequal.
        """
        try:
            return node.id == self.id and self.model is node.model
        except AttributeError:
            return False

    def has_any_connection(self, node: 'Node') -> bool:
        """
        Is this node connected to `node` through any port?

        Connectivity is checked for every ``PortType`` by walking the
        connection paths, so nodes reached through intermediate nodes count
        as connected.

        Parameters
        ----------
        node : Node
            The node to check connectivity

        Returns
        -------
        connected : bool
        """
        return any(self.has_connection_by_port_type(node, port_type)
                   for port_type in PortType)

    def has_connection_by_port_type(self, target: 'Node',
                                    port_type: PortType) -> bool:
        """
        Is this node connected to `target` through an input/output port?

        Parameters
        ----------
        target : Node
            The target node to check connectivity
        port_type : PortType
            The port type (``PortType.input``, ``PortType.output``) to check

        Returns
        -------
        connected : bool
            True if any path walked for ``port_type`` ends at ``target``.

        Raises
        ------
        ValueError
            If ``port_type`` is neither ``PortType.input`` nor
            ``PortType.output``.
        """
        return any(
            path[-1] == target
            for path in self.walk_paths_by_port_type(port_type)
        )

    def walk_paths_by_port_type(
            self, port_type: PortType
    ) -> typing.Iterable[typing.Tuple['Node', ...]]:
        """
        Yields paths to connected nodes by port type

        The graph is walked breadth-first starting from this node, following
        output connections for ``PortType.output`` and input connections for
        ``PortType.input``.

        Parameters
        ----------
        port_type : PortType
            The direction in which connections are followed

        Yields
        ------
        node_path : tuple
            The path to the node: a tuple of nodes starting with this node and
            ending with the reached node. This node itself is never yielded as
            a destination.

        Raises
        ------
        ValueError
            If ``port_type`` is neither ``PortType.input`` nor
            ``PortType.output`` (raised on first iteration).
        """
        # None is pre-seeded so that a missing node on a connection end is
        # never queued.
        seen = {None}
        pending = collections.deque([([], self)])

        if port_type == PortType.output:
            def get_connection_nodes(state):
                for con in state.output_connections:
                    yield con.input_node
        elif port_type == PortType.input:
            def get_connection_nodes(state):
                for con in state.input_connections:
                    yield con.output_node
        else:
            raise ValueError(f'Unexpected port_type {port_type}')

        while pending:
            node_path, node = pending.popleft()
            # Nodes are marked as seen when dequeued (not when queued), so a
            # node queued through several routes before its first visit is
            # yielded once per queued route.
            seen.add(node)
            if node is not self:
                yield tuple(node_path) + (node, )

            next_path = list(node_path) + [node]
            for neighbor in get_connection_nodes(node.state):
                if neighbor not in seen:
                    pending.append((next_path, neighbor))

    def __getitem__(self, key):
        """
        Item access, delegated to the node state.

        For example ``node[PortType.output][port_index]`` is used in this
        class to look up an output port.
        """
        return self._state[key]

    def _cleanup(self):
        """
        Release the graphics object (and the geometry) if one is attached.

        Does nothing when no graphics object has been set.
        """
        if self._graphics_obj is not None:
            self._graphics_obj._cleanup()
            self._graphics_obj = None
            self._geometry = None

    def __getstate__(self) -> dict:
        """
        Save

        The position is read from the graphics object, so one must be
        attached when the node is saved.

        Returns
        -------
        value : dict
            Dictionary with the keys ``"id"``, ``"model"`` and ``"position"``
            (itself a dictionary with ``"x"`` and ``"y"``).
        """
        return {
            "id": self._uid,
            "model": self._model.__getstate__(),
            "position": {"x": self._graphics_obj.pos().x(),
                         "y": self._graphics_obj.pos().y()}
        }

    def __setstate__(self, state: dict):
        """
        Restore

        The position is only restored when a graphics object is attached.

        Parameters
        ----------
        state : dict
            A dictionary in the format produced by ``__getstate__``
        """
        self._uid = state["id"]
        if self._graphics_obj:
            pos = state["position"]
            self.position = (pos["x"], pos["y"])
        self._model.__setstate__(state["model"])

    @property
    def id(self) -> str:
        """
        Node unique identifier (uuid)

        Returns
        -------
        value : str
        """
        return self._uid

    def react_to_possible_connection(self, reacting_port_type: PortType,
                                     reacting_data_type: NodeDataType,
                                     scene_point: QPointF
                                     ):
        """
        React to possible connection

        Maps ``scene_point`` into the graphics object's coordinates (when its
        scene transform is invertible) and stores it as the geometry's
        dragging position, then repaints and marks the state as reacting.

        Parameters
        ----------
        reacting_port_type : PortType
        reacting_data_type : NodeDataType
        scene_point : QPointF
        """
        transform = self._graphics_obj.sceneTransform()
        inverted, invertible = transform.inverted()
        if invertible:
            pos = inverted.map(scene_point)
            self._geometry.dragging_position = pos
        self._graphics_obj.update()
        self._state.set_reaction(ReactToConnectionState.reacting,
                                 reacting_port_type, reacting_data_type)

    def reset_reaction_to_connection(self):
        """
        Mark the node state as no longer reacting and repaint the node
        """
        self._state.set_reaction(ReactToConnectionState.not_reacting)
        self._graphics_obj.update()

    @property
    def graphics_object(self) -> NodeGraphicsObject:
        """
        Node graphics object

        Returns
        -------
        value : NodeGraphicsObject
            The attached graphics object, or None if none has been set
        """
        return self._graphics_obj

    @graphics_object.setter
    def graphics_object(self, graphics: NodeGraphicsObject):
        """
        Set graphics object

        The geometry size is recalculated after the assignment.

        Parameters
        ----------
        graphics : NodeGraphicsObject
        """
        self._graphics_obj = graphics
        self._geometry.recalculate_size()

    @property
    def geometry(self) -> NodeGeometry:
        """
        Node geometry

        Returns
        -------
        value : NodeGeometry
        """
        return self._geometry

    @property
    def model(self) -> NodeDataModel:
        """
        Node data model

        Returns
        -------
        value : NodeDataModel
        """
        return self._model

    def propagate_data(self, node_data: NodeData, input_port: Port):
        """
        Propagates incoming data to the underlying model.

        Parameters
        ----------
        node_data : NodeData
        input_port : Port
            The input port of this node that receives the data

        Raises
        ------
        ValueError
            If ``input_port`` does not belong to this node, or is not an
            input port.
        """
        if input_port.node is not self:
            raise ValueError('Port does not belong to this Node')
        elif input_port.port_type != PortType.input:
            raise ValueError('Port is not an input port')

        self._model.set_in_data(node_data, input_port)

        # Recalculate the nodes visuals. A data change can result in the node
        # taking more space than before, so this forces a recalculate+repaint
        # on the affected node
        self._graphics_obj.set_geometry_changed()
        self._geometry.recalculate_size()
        self._graphics_obj.update()
        self._graphics_obj.move_connections()

    def _on_port_index_data_updated(self, port_index: int):
        """
        Data has been updated on this Node's output port port_index;
        propagate it to any connections.

        Parameters
        ----------
        port_index : int
            Index of the output port whose data changed
        """
        port = self[PortType.output][port_index]
        self.on_data_updated(port)

    def on_data_updated(self, port: Port):
        """
        Fetches data from model's output port and propagates it along the
        connection

        The port's data is read once and handed to every connection of the
        port.

        Parameters
        ----------
        port : Port
        """
        node_data = port.data
        for conn in port.connections:
            conn.propagate_data(node_data)

    def on_node_size_updated(self):
        """
        Update the graphic part if the size of the embedded widget changes

        Adjusts the embedded widget's size (if the model has one),
        recalculates the node geometry and moves the graphics objects of all
        connections of this node.
        """
        widget = self.model.embedded_widget()
        if widget:
            widget.adjustSize()

        self.geometry.recalculate_size()
        for conn in self.state.all_connections:
            conn.graphics_object.move()

    @property
    def size(self) -> QSizeF:
        """
        Get the node size

        Returns
        -------
        value : QSizeF
            The size as reported by the node geometry
        """
        return self._geometry.size

    @property
    def position(self) -> QPointF:
        """
        Get the node position

        Returns
        -------
        value : QPointF
            The position as reported by the graphics object
        """
        return self._graphics_obj.pos()

    @position.setter
    def position(self, pos):
        """
        Set the node position and move its connections accordingly

        Parameters
        ----------
        pos : QPointF or pair of numbers
            Anything that is not a QPointF is unpacked as ``(x, y)``
        """
        if not isinstance(pos, QPointF):
            px, py = pos
            pos = QPointF(px, py)

        self._graphics_obj.setPos(pos)
        self._graphics_obj.move_connections()

    @property
    def style(self) -> NodeStyle:
        """
        Node style

        Returns
        -------
        value : NodeStyle
        """
        return self._style

    @property
    def state(self) -> NodeState:
        """
        Node state

        Returns
        -------
        value : NodeState
        """
        return self._state

    def __repr__(self):
        return (f'<{self.__class__.__name__} model={self._model} '
                f'uid={self._uid!r}>')
