"""Taken from https://github.com/btubbs/sseclient"""

import codecs
import logging
import re
import time
import warnings

import requests
from oauthlib.oauth2 import TokenExpiredError
from requests.exceptions import HTTPError

# The SSE spec lets a stream mix CRLF, CR and LF line endings; this regex does
# not.  It assumes the server uses one line-ending style consistently and
# matches the blank line (two consecutive line endings) that closes an event.
end_of_field = re.compile("|".join((r"\r\n\r\n", r"\r\r", r"\n\n")))

# Module logger used by both the client and the event parser below.
LOGGER = logging.getLogger(name="homeconnect.sseclient")


class SSEClient(object):
    """Iterate over Server-Sent Events read from an HTTP event stream.

    Each call to ``next()`` blocks until a complete event has been received
    and returns it as an ``Event``.  When reading fails (including the end of
    the stream), the client waits ``retry`` milliseconds, reconnects and
    carries on; the last seen event id is sent back as ``Last-Event-ID``.
    The iterator therefore never ends on its own.
    """

    def __init__(
        self, url, last_id=None, retry=3000, session=None, chunk_size=1024, **kwargs
    ):
        """Store the configuration and open the first connection.

        Args:
            url: URL of the event stream.
            last_id: Event id to resume from; sent as ``Last-Event-ID``.
            retry: Reconnection delay in milliseconds.  It is replaced by the
                value of any ``retry`` field the server sends.
            session: Optional ``requests.Session`` used instead of the
                ``requests`` module.
            chunk_size: Number of bytes requested per ``iter_content`` read.
            **kwargs: Extra keyword arguments passed to every ``get`` call.
                A ``headers`` mapping is copied, not modified in place.
        """
        self.url = url
        self.last_id = last_id
        self.retry = retry
        self.chunk_size = chunk_size

        # Optional support for passing in a requests.Session()
        self.session = session

        # Any extra kwargs will be fed into the requests.get call later.
        self.requests_kwargs = kwargs

        # Work on a copy of the caller's headers.  SSE headers are added below
        # and Last-Event-ID on every reconnect; none of that should leak into
        # a mapping the caller may reuse for other requests.  A missing or
        # None ``headers`` argument becomes an empty dict.
        headers = self.requests_kwargs.get("headers")
        headers = dict(headers) if headers else {}

        # The SSE spec asks clients to bypass caches: Cache-Control: no-cache
        headers["Cache-Control"] = "no-cache"

        # The 'Accept' header is not required, but explicit > implicit
        headers["Accept"] = "text/event-stream"
        self.requests_kwargs["headers"] = headers

        # Keep data here as it streams in
        self.buf = ""

        self._connect()

    def _connect(self):
        """(Re)open the HTTP stream, retrying until the status is not an error.

        On an HTTP error status, the failed response is closed and the request
        is retried after ``10 * retry`` milliseconds.  The retry is a loop, not
        recursion, so a long outage cannot exhaust the stack.  Exceptions
        raised by the ``get`` call itself are not caught here.
        """
        # Use session if set.  Otherwise fall back to requests module.
        requester = self.session or requests

        while True:
            LOGGER.info("Connecting ...")
            if self.last_id:
                self.requests_kwargs["headers"]["Last-Event-ID"] = self.last_id

            self.resp = requester.get(self.url, stream=True, **self.requests_kwargs)
            self.resp_iterator = self.resp.iter_content(chunk_size=self.chunk_size)

            self.resp.encoding = "UTF-8"

            # TODO: Ensure we're handling redirects.  Might also stick the 'origin'
            # attribute on Events like the Javascript spec requires.
            try:
                self.resp.raise_for_status()
            except HTTPError:
                LOGGER.error("Failed connecting.")
                # Close the failed streaming response so its connection is
                # released before we try again.
                self.resp.close()
                # Wait 10 times longer if connection failed due to rate limits
                time.sleep(10 * self.retry / 1000.0)
                continue
            break

        # One incremental decoder per connection.  It must persist across
        # chunks so that a multi-byte UTF-8 sequence split between two chunks
        # is decoded correctly.  It is recreated here on every reconnect,
        # because the new stream starts from scratch.
        self._decoder = codecs.getincrementaldecoder(self.resp.encoding)(
            errors="replace"
        )

    def _event_complete(self):
        """Return True if ``self.buf`` contains at least one complete event."""
        return end_of_field.search(self.buf) is not None

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next complete ``Event``, reading and reconnecting as needed.

        Any exception while reading a chunk (including the end of the stream)
        is logged and followed by a reconnect after ``retry`` milliseconds.
        Exceptions raised by ``_connect`` during that reconnect propagate.
        """
        while not self._event_complete():
            try:
                next_chunk = next(self.resp_iterator)
                if not next_chunk:
                    LOGGER.error("EOFError")
                    raise EOFError()
                self.buf += self._decoder.decode(next_chunk)

            # Deliberately broad: any failure while reading (end of stream,
            # network error, ...) is handled the same way, by reconnecting.
            except Exception:
                LOGGER.warning("Exception while reading event: ", exc_info=True)
                time.sleep(self.retry / 1000.0)
                self._connect()

                # We only get here while no complete event is buffered, so the
                # buffer holds the start of an interrupted event.  The SSE spec
                # discards an incomplete event when the stream ends; drop all
                # of it rather than gluing its lines onto new data.
                self.buf = ""
                continue

        # Split the complete event (up to the end_of_field) into event_string,
        # and retain anything after the current complete event in self.buf
        # for next time.
        event_string, self.buf = end_of_field.split(self.buf, maxsplit=1)
        msg = Event.parse(event_string)

        # If the server requests a specific retry delay, we need to honor it.
        if msg.retry:
            self.retry = msg.retry

        # last_id should only be set if included in the message.  It's not
        # forgotten if a message omits it.
        if msg.id:
            self.last_id = msg.id

        return msg


class Event(object):
    """A single Server-Sent Events message.

    Attributes:
        data: Message payload; multiple ``data`` lines are joined with newlines.
        event: Event type, ``"message"`` unless the stream names another.
        id: Event id string, or None if the message had none.
        retry: Reconnection delay (int) requested by the server, or None.
    """

    sse_line_pattern = re.compile("(?P<name>[^:]*):?( ?(?P<value>.*))?")

    # The SSE spec recognises only CRLF, CR and LF as line terminators.
    # str.splitlines() would also split on characters such as U+2028, which
    # may legally appear unescaped inside JSON carried in a "data" field.
    _line_break = re.compile(r"\r\n|\r|\n")

    # Per the spec, a "retry" value is honoured only if it is all ASCII digits.
    _retry_value = re.compile(r"[0-9]+")

    def __init__(self, data="", event="message", id=None, retry=None):
        self.data = data
        self.event = event
        self.id = id
        self.retry = retry

    def dump(self):
        """Serialise the event to SSE wire format, ending with a blank line.

        ``id`` and ``retry`` are written only when truthy, and ``event`` only
        when it is not the default ``"message"``.  Each line of ``data``
        becomes its own ``data:`` field.
        """
        lines = []
        if self.id:
            lines.append("id: %s" % self.id)

        # Only include an event line if it's not the default already.
        if self.event != "message":
            lines.append("event: %s" % self.event)

        if self.retry:
            lines.append("retry: %s" % self.retry)

        lines.extend("data: %s" % d for d in self.data.split("\n"))
        return "\n".join(lines) + "\n\n"

    @classmethod
    def parse(cls, raw):
        """
        Given a possibly-multiline string representing an SSE message, parse it
        and return a Event object.

        Lines are split on CRLF, CR or LF only.  Comment lines (starting with
        ":") and unknown field names are ignored.  Multiple ``data`` fields
        are joined with newlines, including empty ones.  A ``retry`` value
        that is not made of ASCII digits is ignored instead of raising.
        """
        msg = cls()
        data_lines = []
        for line in cls._line_break.split(raw):
            m = cls.sse_line_pattern.match(line)
            if m is None:
                # Every part of sse_line_pattern is optional, so it matches
                # any string; this branch is purely defensive.
                # Malformed line.  Discard but warn.
                warnings.warn('Invalid SSE line: "%s"' % line, SyntaxWarning)
                LOGGER.warning('Invalid SSE line: "%s"', line)
                continue

            name = m.group("name")
            if name == "":
                # line began with a ":", so is a comment.  Ignore
                continue
            value = m.group("value")

            if name == "data":
                # Collect every data line; joining at the end keeps a leading
                # empty data line, which a truthiness check would drop.
                data_lines.append(value)
            elif name == "event":
                msg.event = value
            elif name == "id":
                msg.id = value
            elif name == "retry":
                if cls._retry_value.fullmatch(value):
                    msg.retry = int(value)

        if data_lines:
            msg.data = "\n".join(data_lines)
        return msg

    def __str__(self):
        """Return the event's data payload."""
        return self.data
