# -*- coding: utf-8 -*-

# (c) Jérôme Laheurte 2015-2026
# See LICENSE.txt

import functools
import collections
import logging
import re
import warnings

from ptk.lexer import EOF, token, LexerPosition, ReLexer
from ptk.grammar import Grammar, Production, GrammarError
# production is only imported so that client code doesn't have to import it from grammar
from ptk.grammar import production # pylint: disable=W0611
from ptk.utils import Singleton, callback_by_name


class ParseError(Exception):
    """
    Syntax error when parsing.

    :ivar token: The unexpected token.
    """
    def __init__(self, grammar, tok, state, tokens):
        """
        :param grammar: Object exposing ``token_types()`` and the
            ``__actions__`` table; used to work out which terminals
            have an action in *state*.
        :param tok: The unexpected token; its ``value``, ``type`` and
            ``position`` attributes are used.
        :param state: Parser state in which the token was met. It is
            sorted to build the error message.
        :param tokens: Tokens seen before the error. The sequence is
            stored as is (not copied).
        """
        self.token = tok
        super().__init__(f'Unexpected token "{tok.value}" ({tok.type}) in state "{sorted(state)}"')

        self._state = state
        actions = grammar.__actions__
        # A terminal is expected when the action table holds a non-None entry for it in this state
        self._expecting = {
            terminal
            for terminal in grammar.token_types()
            if actions.get((state, terminal)) is not None
        }
        self._tokens = tokens

    @property
    def position(self):
        """
        Position of the unexpected token.
        """
        return self.token.position

    def expecting(self):
        """
        Returns a set of tokens types that would have been valid in input.
        """
        return self._expecting

    def state(self):
        """
        Returns the parser state when the error was encountered.
        """
        return self._state

    def last_token(self):
        """
        Returns the last valid token seen before this error, or None
        if the error happened before any token was accepted.
        """
        # The error may be raised on the very first token, in which case there is no previous one
        return self._tokens[-1] if self._tokens else None

    def lastToken(self): # pylint: disable=invalid-name
        """
        Deprecated alias of :py:func:`last_token`.
        """
        # stacklevel=2 attributes the warning to the caller instead of this wrapper
        warnings.warn('lastToken is deprecated in favor of last_token', DeprecationWarning, stacklevel=2)
        return self.last_token()

    def tokens(self):
        """
        Returns all tokens seen
        """
        return self._tokens


def left_assoc(*operators):
    """
    Class decorator declaring left associative operators; apply it to
    your :py:class:`Parser` class. All operators given in a single
    call share the same priority. Priority grows with declaration
    order: the later an associativity is declared, the higher the
    priority. For instance

    .. code-block:: python

       @left_assoc('+', '-')
       @left_assoc('*', '/')
       class MyParser(LRParser):
           # ...

    makes '+' and '-' left associative with equal priority, while '*'
    and '/' are left associative too, with a higher priority than '+'
    and '-'.

    See also the *priority* argument to :py:func:`production`.
    """
    def _wrapper(cls):
        # Prepend: the entry added by the outermost decorator ends up first in the list
        entry = ('left', set(operators))
        cls.__precedence__.insert(0, entry)
        return cls
    return _wrapper


def leftAssoc(*operators): # pylint: disable=invalid-name
    """
    Deprecated alias of :py:func:`left_assoc`; emits a
    DeprecationWarning and forwards its arguments.
    """
    # stacklevel=2 attributes the warning to the caller instead of this wrapper
    warnings.warn('leftAssoc is deprecated in favor of left_assoc', DeprecationWarning, stacklevel=2)
    return left_assoc(*operators)


def right_assoc(*operators):
    """
    Class decorator declaring right associative operators. The remarks
    made for :py:func:`left_assoc` apply here as well.
    """
    def _wrapper(cls):
        # Prepend: the entry added by the outermost decorator ends up first in the list
        entry = ('right', set(operators))
        cls.__precedence__.insert(0, entry)
        return cls
    return _wrapper


def rightAssoc(*operators): # pylint: disable=invalid-name
    """
    Deprecated alias of :py:func:`right_assoc`; emits a
    DeprecationWarning and forwards its arguments.
    """
    # stacklevel=2 attributes the warning to the caller instead of this wrapper
    warnings.warn('rightAssoc is deprecated in favor of right_assoc', DeprecationWarning, stacklevel=2)
    return right_assoc(*operators)


def non_assoc(*operators):
    """
    Class decorator declaring non associative operators. The remarks
    made for :py:func:`left_assoc` apply here as well.
    """
    def _wrapper(cls):
        # Prepend: the entry added by the outermost decorator ends up first in the list
        entry = ('non', set(operators))
        cls.__precedence__.insert(0, entry)
        return cls
    return _wrapper


def nonAssoc(*operators): # pylint: disable=invalid-name
    """
    Deprecated alias of :py:func:`non_assoc`; emits a
    DeprecationWarning and forwards its arguments.
    """
    # stacklevel=2 attributes the warning to the caller instead of this wrapper
    warnings.warn('nonAssoc is deprecated in favor of non_assoc', DeprecationWarning, stacklevel=2)
    return non_assoc(*operators)


# Name of the extra start production that LRParser.prepare() inserts in front
# of the grammar; productions are compared against it with "is".
# NOTE(review): __reprval__ (capital sigma) is presumably what the Singleton
# metaclass uses as repr -- confirm in ptk.utils.
class _StartState(metaclass=Singleton): # pylint: disable=too-few-public-methods
    __reprval__ = '\u03A3'


class _ResolveError(Exception):
    """
    Raised while resolving a shift/reduce conflict between non
    associative operators: the table entry must be removed so that the
    input is reported as an error.
    """


@functools.total_ordering
class _Item:
    """
    LR(1) item: a production, the position of the dot in its right
    side and a lookahead terminal.

    Items are compared, ordered and hashed on the (production, dot,
    terminal) triple only; ``index`` is set afterwards by the parser
    and takes no part in comparisons.
    """
    __slots__ = ('production', 'dot', 'terminal', 'index', 'should_reduce', 'expecting')

    def __init__(self, prod, dot, terminal):
        self.production = prod
        self.dot = dot
        self.terminal = terminal
        self.index = None
        # The dot is at the end of the right side: the production can be reduced
        self.should_reduce = self.dot == len(self.production.right)
        # Symbol right after the dot, if any
        self.expecting = None if self.should_reduce else self.production.right[self.dot]

    def _key(self):
        """
        Returns the tuple used for equality, ordering and hashing.
        """
        return (self.production, self.dot, self.terminal)

    def next(self):
        """
        Returns an item with the dot advanced one position
        """
        return _Item(self.production, self.dot + 1, self.terminal)

    def __repr__(self):
        symbols = list(self.production.right)
        symbols.insert(self.dot, '\u2022')
        return '%s -> %s (%s)' % (self.production.name, ' '.join([repr(sym) for sym in symbols]), self.terminal) # pylint: disable=consider-using-f-string

    def __eq__(self, other):
        # Let Python try the reflected operation instead of failing with AttributeError
        if not isinstance(other, _Item):
            return NotImplemented
        return self._key() == other._key()

    def __lt__(self, other):
        if not isinstance(other, _Item):
            return NotImplemented
        return self._key() < other._key()

    def __hash__(self):
        return hash(self._key())


class _Accept(BaseException):
    """
    Control-flow exception carrying the value of the start symbol once
    it has been reduced. It derives from BaseException, so handlers
    for Exception do not catch it.
    """
    def __init__(self, result):
        super().__init__()
        self.result = result


# Parser stack entry: the state, the semantic value and the source position
_StackItem = collections.namedtuple('_StackItem', ('state', 'value', 'position'))


class _Shift: # pylint: disable=too-few-public-methods
    """
    Parser action pushing the current token on the stack along with
    the state to move to.
    """
    def __init__(self, new_state):
        self.new_state = new_state

    def do_action(self, grammar, stack, tok): # pylint: disable=W0613
        """
        Pushes the value and position of *tok* with the new state.

        :return: True, meaning the token has been consumed.
        """
        pushed = _StackItem(state=self.new_state, value=tok.value, position=tok.position)
        stack.append(pushed)
        return True


class _Reduce: # pylint: disable=too-few-public-methods
    """
    Parser action reducing the production of an item: the values of
    its right side are popped from the stack and replaced by the value
    returned by the production callback.
    """
    def __init__(self, item):
        self.item = item
        self.nargs = len(item.production.right)

    def do_action(self, grammar, stack, tok): # pylint: disable=W0613
        """
        Runs the production callback and pushes its result.

        :return: False, meaning the token has not been consumed.
        """
        position, (callback, kwargs) = self._get_callback(stack)
        value = callback(grammar, **kwargs)
        self._applied(grammar, stack, value, position)
        return False

    def _applied(self, grammar, stack, prod_val, position):
        """
        Pushes *prod_val*, with the state given by the goto table for
        the state now on top of the stack and the production name.
        """
        target = grammar.goto(stack[-1].state, self.item.production.name)
        stack.append(_StackItem(target, prod_val, position))

    def _get_callback(self, stack):
        """
        Pops the right side values from *stack*.

        :return: A tuple (position, whatever ``production.apply``
            returns for the popped values and that position).
        """
        count = self.nargs
        if count:
            args = [entry.value for entry in stack[-count:]]
            # The reduced symbol starts where its first component starts
            pos = stack[-count].position
            del stack[-count:]
        else:
            # Empty production: nothing to pop, use the position of the stack top
            args = []
            pos = stack[-1].position
        return pos, self.item.production.apply(args, pos)


class LRParser(Grammar):
    """
    LR(1) parser. This class is intended to be used with a lexer class
    derived from :py:class:`LexerBase`, using inheritance; it
    overrides :py:func:`LexerBase.new_token` so you must inherit from
    the parser first, then the lexer:

    .. code-block:: python

       class MyParser(LRParser, ReLexer):
           # ...

    """
    def __init__(self): # pylint: disable=R0914,R0912
        """
        Initializes the base classes, then puts the parser in its
        initial state (see :py:func:`_restart_parser`).
        """
        super().__init__()
        self._restart_parser()

    def rstack(self):
        """
        Returns an iterator over the parser stack entries, from the
        most recently pushed one to the oldest.
        """
        return reversed(self.__stack)

    def new_token(self, tok):
        """
        Feeds one token to the parser. Actions found in the table are
        applied one after the other until one of them reports that the
        token has been consumed.

        :return: None, unless the start symbol has been reduced. In
            that case the parser is restarted,
            :py:func:`on_start_symbol` is called and the value of the
            start symbol is returned.
        :raises ParseError: if there is no action for this token in
            the current state.
        """
        try:
            for action, stack in self._process_token(tok):
                consumed = action.do_action(self, stack, tok)
                if consumed:
                    break
        except _Accept as accepted:
            self._restart_parser()
            self.on_start_symbol(accepted.result)
            return accepted.result

        # Only tokens that did not end the sentence are remembered
        self.__tokens.append(tok)
        return None

    def _process_token(self, tok):
        """
        Generator yielding (action, stack) pairs for *tok*, looking up
        the action again after each one from the state then on top of
        the stack. It never ends by itself; the consumer stops
        iterating once the token has been consumed.

        :raises ParseError: if the action table has no entry for the
            current state and the token type.
        """
        while True:
            current = self.__stack[-1].state
            action = self.__actions__.get((current, tok.type))
            if action is None:
                raise ParseError(self, tok, current, self.__tokens)
            yield action, self.__stack

    def on_start_symbol(self, start_symbol): # pragma: no cover
        """
        This is called when the start symbol has been reduced.

        The default implementation forwards to the deprecated
        :py:func:`newSentence` and returns its result.

        :param start_symbol: The value associated with the start symbol.
        """
        return self.newSentence(start_symbol)

    def newSentence(self, start_symbol): # pylint: disable=invalid-name,unused-argument
        """
        Deprecated hook, superseded by :py:func:`on_start_symbol`.
        This implementation emits a DeprecationWarning, then raises
        NotImplementedError.
        """
        warnings.warn('newSentence has been deprecated in favor of on_start_symbol', DeprecationWarning)
        raise NotImplementedError

    @classmethod
    def _create_production_parser(cls, name, priority, attrs):
        """
        Returns a :py:class:`ProductionParser` built from
        ``callback_by_name(name)``, *priority*, this class and *attrs*.
        """
        return ProductionParser(callback_by_name(name), priority, cls, attrs)

    @classmethod
    def _create_reduce_action(cls, item):
        """
        Returns the action object used to reduce *item* (here a
        ``_Reduce``); called when the action table is computed.
        """
        return _Reduce(item)

    @classmethod
    def _create_shift_action(cls, state):
        """
        Returns the action object used to shift to *state* (here a
        ``_Shift``); called when the action table is computed.
        """
        return _Shift(state)

    @classmethod
    def prepare(cls):
        """
        Builds the parsing tables of this class: inserts the start
        production if there is none yet, computes the states and the
        action table, resolves conflicts, then stores the goto table
        and the sorted list of states. Unused tokens, unreachable
        nonterminals and unresolved conflicts are reported through the
        'LRParser' logger.
        """
        for prod in cls.productions():
            if prod.name is _StartState:
                break
        else:
            # No start production yet: add one whose callback raises _Accept with the result
            def acceptor(_, result):
                raise _Accept(result)
            prod = Production(_StartState, acceptor)
            prod.add_symbol(cls._default_start_symbol() if cls.startSymbol is None else cls.startSymbol, name='result')
            cls.__productions__.insert(0, prod)

        cls.startSymbol = _StartState
        super().prepare()

        # At this point prod is the start production, either found or just created
        states, goto = cls.__compute_states(prod)
        reachable = cls.__compute_actions(states, goto)

        logger = logging.getLogger('LRParser')
        cls.__resolve_conflicts(logger)

        # Terminals that appear in at least one action table key, EOF excepted
        used_tokens = {symbol for _, symbol in cls.__actions__ if symbol is not EOF}
        if used_tokens != cls.token_types(): # pragma: no cover pylint: disable=no-member
            logger.warning('The following tokens are not used: %s', ','.join([repr(sym) for sym in sorted(cls.token_types() - used_tokens)])) # pylint: disable=no-member

        if reachable != cls.nonterminals(): # pragma: no cover
            logger.warning('The following nonterminals are not reachable: %s', ','.join([repr(sym) for sym in sorted(cls.nonterminals() - reachable)]))

        # Reductions only need goto entries for nonterminals
        cls._goto = {(state, symbol): new_state for (state, symbol), new_state in goto.items() if symbol not in cls.token_types()} # pylint: disable=no-member

        # Summary of the conflicts that could not be settled by precedence rules
        parts = []
        if cls.nSR:
            parts.append(f'{cls.nSR} shift/reduce conflicts')
        if cls.nRR:
            parts.append(f'{cls.nRR} reduce/reduce conflicts')
        if parts:
            logger.warning(', '.join(parts))

        # Cast to tuple because sets are not totally ordered
        # The start state always comes first, so it gets index 0
        for index, state in enumerate([tuple(cls._startState)] + sorted([tuple(state) for state in states if state != cls._startState])):
            logger.debug('State %d', index)
            for item in sorted(state):
                logger.debug('    %s', item)
                item.index = index
            cls.__lrstates__.append(sorted(state)) # pylint: disable=no-member
        logger.info('%d states.', len(states))

    @classmethod
    def __compute_states(cls, start):
        """
        Computes the set of parser states reachable from the start
        state, and the transitions between them. The start state is
        stored in ``cls._startState``.

        :param start: The start production.
        :return: A tuple (states, transitions) where *transitions*
            maps (state, symbol) to the next state.
        """
        symbols = list(cls.token_types() | cls.nonterminals()) # pylint: disable=no-member
        initial = frozenset([_Item(start, 0, EOF)])
        cls._startState = initial

        states = {initial}
        transitions = {}
        pending = [initial]
        while pending:
            state = pending.pop()
            closure = cls.__item_set_closure(state)
            for symbol in symbols:
                # Items of the closure expecting this symbol, with the dot moved past it
                target = frozenset(item.next() for item in closure if item.expecting == symbol)
                if not target:
                    continue
                transitions[(state, symbol)] = target
                if target not in states:
                    states.add(target)
                    pending.append(target)
        return states, transitions

    @classmethod
    def __compute_actions(cls, states, goto):
        """
        Fills ``cls.__actions__`` with the candidate actions of every
        state; each entry is a list, conflicts are resolved later.

        :param states: The parser states.
        :param goto: Mapping from (state, symbol) to the next state.
        :return: The set of names of the productions that can be
            reduced in at least one state.
        """
        cls.__actions__ = {}
        reduced_names = set()
        terminals = cls.token_types() # pylint: disable=no-member
        for state in states:
            for item in cls.__item_set_closure(state):
                if not item.should_reduce:
                    # Only terminals give shift actions; nonterminals go through the goto table
                    symbol = item.expecting
                    if symbol in terminals:
                        shift = cls._create_shift_action(goto[(state, symbol)])
                        cls.__add_shift_action(state, symbol, shift)
                    continue
                reduction = cls._create_reduce_action(item)
                reduced_names.add(item.production.name)
                cls.__add_reduce_action(state, item.terminal, reduction)
        return reduced_names

    @classmethod
    def __should_prefer_shift(cls, logger, reduce_action, symbol):
        """
        Decides how to settle a shift/reduce conflict on *symbol*.

        Priorities are compared first, then associativity when the
        priorities are equal. When this is not enough to decide, shift
        is chosen and the conflict is counted in ``cls.nSR``.

        :return: True to shift, False to reduce.
        :raises _ResolveError: if both precedences have the same
            priority and are non associative.
        """
        item = reduce_action.item
        logger.info('Shift/reduce conflict for "%s" on "%s"', item, symbol)

        rule_precedence = item.production.precedence(cls)
        lookahead_precedence = cls.terminal_precedence(symbol)

        # Precedence rules only apply when both sides declare one
        if not (rule_precedence is None or lookahead_precedence is None):
            rule_assoc, rule_prio = rule_precedence
            lookahead_assoc, lookahead_prio = lookahead_precedence

            if rule_prio > lookahead_prio:
                logger.info('Resolving in favor of reduction because of priority')
                return False
            if rule_prio < lookahead_prio:
                logger.info('Resolving in favor of shift because of priority')
                return True

            # Same priority: associativity decides, provided both agree on it
            if rule_assoc == lookahead_assoc:
                if rule_assoc == 'non':
                    logger.info('Resolving in favor of error because of associativity')
                    raise _ResolveError()
                if rule_assoc != 'left':
                    logger.info('Resolving in favor of shift because of associativity')
                    return True
                logger.info('Resolving in favor of reduction because of associativity')
                return False

        # At least one of those is not specified; use shift
        logger.warning('Could not disambiguate shift/reduce conflict for "%s" on "%s"; using shift', item, symbol)
        cls.nSR += 1
        return True

    @classmethod
    def __resolve_conflicts(cls, logger):
        """
        Replaces each list of candidate actions in ``cls.__actions__``
        by the single action to use. Entries for which the resolution
        raises ``_ResolveError`` are removed from the table. The
        conflict counters ``cls.nSR`` and ``cls.nRR`` are reset first.
        """
        cls.nSR = 0
        cls.nRR = 0

        # sorted() works on a snapshot, so the table can be changed while looping
        for (state, symbol), candidates in sorted(cls.__actions__.items()):
            key = (state, symbol)
            winner = candidates.pop()
            try:
                while candidates:
                    winner = cls.__resolve_conflict(logger, winner, candidates.pop(), symbol)
            except _ResolveError:
                del cls.__actions__[key]
            else:
                cls.__actions__[key] = winner

    @classmethod
    def __resolve_conflict(cls, logger, action1, action2, symbol):
        """
        Chooses between two conflicting actions for *symbol*.

        A shift/reduce conflict is settled by
        ``__should_prefer_shift``. For a reduce/reduce conflict, a
        warning is logged, ``cls.nRR`` is incremented and the action
        whose production comes first in ``cls.productions()`` wins.

        :return: The action to keep, or None if neither production is
            found in ``cls.productions()``.
        :raises _ResolveError: propagated from the shift/reduce
            resolution.
        """
        # Normalize so that a shift action, if any, is action1
        if isinstance(action2, _Shift):
            action1, action2 = action2, action1

        if isinstance(action1, _Shift):
            # Shift/reduce
            return action1 if cls.__should_prefer_shift(logger, action2, symbol) else action2

        # Reduce/reduce
        logger.warning('Reduce/reduce conflict between "%s" and "%s"', action1.item, action2.item)
        cls.nRR += 1

        # Use the first one to be declared
        for prod in cls.productions():
            for candidate in (action1, action2):
                if prod == candidate.item.production:
                    logger.warning('Using "%s"', prod)
                    return candidate

        return None

    @classmethod
    def __add_reduce_action(cls, state, symbol, action):
        """
        Appends *action* to the candidate actions for (state, symbol),
        creating the list if needed.
        """
        cls.__actions__.setdefault((state, symbol), []).append(action)

    @classmethod
    def __add_shift_action(cls, state, symbol, action):
        """
        Appends *action* to the candidate actions for (state, symbol),
        unless a shift action is already there.
        """
        candidates = cls.__actions__.setdefault((state, symbol), [])
        if not any(isinstance(existing, _Shift) for existing in candidates):
            candidates.append(action)

    @classmethod
    def goto(cls, state, symbol):
        """
        Returns the state stored in the goto table for *state* and
        *symbol*. The lookup fails with KeyError if there is no such
        entry.
        """
        return cls._goto[(state, symbol)]

    def _restart_parser(self):
        """
        Puts the parser back in its initial state: the stack only
        holds the start state (no value, position (1, 1)), the list of
        seen tokens is emptied and the lexer is restarted.
        """
        self.__stack = [_StackItem(self._startState, None, LexerPosition(1, 1))]
        self.__tokens = []
        self.restart_lexer() # pylint: disable=no-member

    @classmethod
    @functools.cache
    def __item_set_closure(cls, items):
        """
        Returns the closure of an item set: for each item whose dot is
        before a nonterminal, the initial items of every production of
        that nonterminal are added, with the lookaheads given by
        ``cls.first`` on what follows the nonterminal and the
        lookahead of the item. This is repeated until nothing new is
        added.

        :param items: A hashable collection of items (results are cached).
        :return: A set of items.
        """
        closure = set(items)
        terminals = cls.token_types() # pylint: disable=no-member
        grown = True
        while grown:
            # The set only ever grows, so comparing sizes detects the fixed point
            size_before = len(closure)
            for item in [candidate for candidate in closure if not candidate.should_reduce]:
                symbol = item.production.right[item.dot]
                if symbol in terminals:
                    continue
                lookaheads = cls.first(*tuple(item.production.right[item.dot + 1:] + [item.terminal]))
                for prod in cls.productions():
                    if prod.name != symbol:
                        continue
                    for lookahead in lookaheads:
                        closure.add(_Item(prod, 0, lookahead))
            grown = len(closure) != size_before
        return closure


class ProductionParser(LRParser, ReLexer): # pylint: disable=R0904
    """
    Parser for productions in other parsers
    """
    def __init__(self, callback, priority, grammar_class, attributes): # pylint: disable=R0915
        """
        :param callback: Callback given to the productions created
            while parsing.
        :param priority: Priority given to those productions.
        :param grammar_class: Class that receives the parsed
            productions, and whose token types are consulted and
            extended.
        :param attributes: Attributes given to those productions.
        """
        # Attributes are set before the base classes are initialized
        self.callback = callback
        self.priority = priority
        self.grammar_class = grammar_class
        self.attributes = attributes

        super().__init__()

    @classmethod
    def prepare(cls, **kwargs): # pylint: disable=R0915
        """
        Declares by hand the grammar of production declarations, then
        lets the parent class prepare the parser. Nothing is done if
        productions have already been declared.
        """
        # Obviously cannot use @production here

        # When mixing async and sync parsers in the same program this may be called twice,
        # because AsyncProductionParser inherits from ProductionParser
        if cls.productions():
            # FIXME remove this, we don't do the async stuff any more
            return

        # DECL -> LEFT arrow PRODS
        prod = Production('DECL', cls.DECL)
        prod.add_symbol('LEFT', 'left')
        prod.add_symbol('arrow')
        prod.add_symbol('PRODS', 'prods')
        cls.__productions__.append(prod)

        # LEFT -> identifier
        prod = Production('LEFT', cls.LEFT)
        prod.add_symbol('identifier', 'name')
        cls.__productions__.append(prod)

        # LEFT -> identifier lchev identifier rchev
        prod = Production('LEFT', cls.LEFT)
        prod.add_symbol('identifier', 'name')
        prod.add_symbol('lchev')
        prod.add_symbol('identifier', 'posarg')
        prod.add_symbol('rchev')
        cls.__productions__.append(prod)

        # PRODS -> P
        prod = Production('PRODS', cls.PRODS1)
        prod.add_symbol('P', 'prodlist')
        cls.__productions__.append(prod)

        # PRODS -> PRODS union P
        prod = Production('PRODS', cls.PRODS2)
        prod.add_symbol('PRODS', 'prods')
        prod.add_symbol('union')
        prod.add_symbol('P', 'prodlist')
        cls.__productions__.append(prod)

        # P -> P SYM
        prod = Production('P', cls.P1)
        prod.add_symbol('P', 'prodlist')
        prod.add_symbol('SYM', 'sym')
        cls.__productions__.append(prod)

        # P -> ɛ
        prod = Production('P', cls.P2)
        cls.__productions__.append(prod)

        # SYM -> SYMNAME PROPERTIES
        prod = Production('SYM', cls.SYM)
        prod.add_symbol('SYMNAME', 'symname')
        prod.add_symbol('PROPERTIES', 'properties')
        cls.__productions__.append(prod)

        # SYM -> SYMNAME repeat PROPERTIES
        prod = Production('SYM', cls.SYMREP)
        prod.add_symbol('SYMNAME', 'symname')
        prod.add_symbol('repeat', 'repeat')
        prod.add_symbol('PROPERTIES', 'properties')
        cls.__productions__.append(prod)

        # SYM -> SYMNAME repeat lparen identifier rparen PROPERTIES
        prod = Production('SYM', cls.SYMREP)
        prod.add_symbol('SYMNAME', 'symname')
        prod.add_symbol('repeat', 'repeat')
        prod.add_symbol('lparen')
        prod.add_symbol('identifier', 'separator')
        prod.add_symbol('rparen')
        prod.add_symbol('PROPERTIES', 'properties')
        cls.__productions__.append(prod)

        # SYM -> SYMNAME repeat lparen litteral rparen PROPERTIES
        prod = Production('SYM', cls.SYMREP_LIT)
        prod.add_symbol('SYMNAME', 'symname')
        prod.add_symbol('repeat', 'repeat')
        prod.add_symbol('lparen')
        prod.add_symbol('litteral', 'separator')
        prod.add_symbol('rparen')
        prod.add_symbol('PROPERTIES', 'properties')
        cls.__productions__.append(prod)

        # SYMNAME -> identifier
        prod = Production('SYMNAME', cls.SYMNAME1)
        prod.add_symbol('identifier', 'identifier')
        cls.__productions__.append(prod)

        # SYMNAME -> litteral
        prod = Production('SYMNAME', cls.SYMNAME2)
        prod.add_symbol('litteral', 'litteral')
        cls.__productions__.append(prod)

        # PROPERTIES -> ɛ
        prod = Production('PROPERTIES', cls.PROPERTIES1)
        cls.__productions__.append(prod)

        # PROPERTIES -> lchev identifier rchev
        prod = Production('PROPERTIES', cls.PROPERTIES2)
        prod.add_symbol('lchev')
        prod.add_symbol('identifier', 'name')
        prod.add_symbol('rchev')
        cls.__productions__.append(prod)

        super().prepare(**kwargs)

    def on_start_symbol(self, start_symbol):
        """
        Called with the value of a whole declaration: gives the left
        side name and positional argument to the productions that have
        no name yet, then adds all productions to the grammar class.

        :param start_symbol: A tuple ((name, posarg), productions).
        """
        left, productions = start_symbol
        name, posarg = left
        for production in productions:
            if production.name is not None:
                continue
            production.name = name
            production.posarg = posarg
        self.grammar_class.__productions__.extend(productions)

    # Lexer

    @staticmethod
    def ignore(char):
        """
        Returns True when *char* is a space, a tab or a newline
        (presumably the characters the lexer skips between tokens).
        """
        return char in ' \t\n'

    @token('->')
    def arrow(self, tok):
        # Token rule for the pattern given to the decorator ('->'); nothing is done with the token here
        pass

    @token('<')
    def lchev(self, tok):
        # Token rule for the pattern given to the decorator ('<'); nothing is done with the token here
        pass

    @token('>')
    def rchev(self, tok):
        # Token rule for the pattern given to the decorator ('>'); nothing is done with the token here
        pass

    @token(r'\|')
    def union(self, tok):
        # Token rule for the pattern given to the decorator (a '|'); nothing is done with the token here
        pass

    @token(r'\*|\+|\?')
    def repeat(self, tok):
        # Token rule for the pattern given to the decorator ('*', '+' or '?'); nothing is done with the token here
        pass

    @token(r'\(')
    def lparen(self, tok):
        # Token rule for the pattern given to the decorator (a '('); nothing is done with the token here
        pass

    @token(r'\)')
    def rparen(self, tok):
        # Token rule for the pattern given to the decorator (a ')'); nothing is done with the token here
        pass

    @token('[a-zA-Z_][a-zA-Z0-9_]*')
    def identifier(self, tok):
        # Token rule for the pattern given to the decorator (a letter or underscore, then letters,
        # digits or underscores); nothing is done with the token here
        pass

    @token(r'"|\'')
    def litteral(self, tok):
        """
        Called on an opening quote (single or double). Installs a
        consumer that receives the following characters one by one and
        builds the 'litteral' token, which ends at the matching
        unescaped quote.
        """
        class StringBuilder: # pylint: disable=missing-class-docstring,too-few-public-methods
            def __init__(self, quotetype):
                self.quotetype = quotetype
                self.chars = []
                # 0: regular character expected; 1: the previous character was a backslash
                self.state = 0

            def feed(self, char):
                if self.state == 1:
                    # Escaped character: taken verbatim, whatever it is
                    self.chars.append(char)
                    self.state = 0
                    return None
                if self.state == 0:
                    if char == self.quotetype:
                        return 'litteral', ''.join(self.chars)
                    if char == '\\':
                        self.state = 1
                    else:
                        self.chars.append(char)
                return None

        builder = StringBuilder(tok.value)
        self.set_consumer(builder)

    # Parser

    def DECL(self, left, prods): # pylint: disable=invalid-name
        """
        Callback for a whole declaration.

        :param left: The (name, posarg) tuple of the left side.
        :param prods: The productions of the right side.
        :return: The tuple (left, prods).
        :raises GrammarError: if the left side name is a token type of
            the grammar class.
        """
        name, _ = left
        token_names = self.grammar_class.token_types()
        if name not in token_names:
            return (left, prods)
        raise GrammarError(f'"{name}" is a token name and cannot be used as non-terminal')

    def LEFT(self, name, posarg=None): # pylint: disable=invalid-name
        """
        Callback for the left side of a declaration; returns the tuple
        (name, posarg), *posarg* being None when it is not given.
        """
        return (name, posarg)

    def PRODS1(self, prodlist): # pylint: disable=invalid-name
        """
        Callback for a right side with a single alternative; returns
        its production list unchanged.
        """
        return prodlist

    def PRODS2(self, prods, prodlist): # pylint: disable=invalid-name
        """
        Callback for an additional alternative; extends *prods* in
        place with the productions of the new alternative and returns
        it.
        """
        prods.extend(prodlist)
        return prods

    def P1(self, sym, prodlist): # pylint: disable=invalid-name
        """
        Callback adding one symbol to the productions being built.

        Productions that already have a name are kept as they are. The
        others get the symbol, either directly or through the extra
        productions needed by the '?', '*' and '+' repetitions.

        :param sym: A tuple (symbol, properties, repeat, separator).
        :param prodlist: The productions built so far.
        :return: The new list of productions.
        :raises GrammarError: if a separator is given with '?'.
        """
        symbol, properties, repeat, sep = sym
        attr_name = properties.get('name', None)
        expanded = []

        for prod in prodlist:
            if prod.name is not None:
                expanded.append(prod)
                continue

            if repeat is None:
                prod.add_symbol(symbol, name=attr_name)
                expanded.append(prod)
            elif repeat == '?':
                if sep is not None:
                    raise GrammarError('A separator makes no sense for "?"')
                self.__add_at_most_one(expanded, prod, symbol, attr_name)
            elif repeat in ('*', '+'):
                self.__add_list(expanded, prod, symbol, attr_name, repeat == '*', sep)

        return expanded

    def __add_at_most_one(self, productions, prod, symbol, name):
        """
        Handles an optional symbol by appending two productions to
        *productions*: a clone of *prod* without the symbol (its
        callback gets None for *name*, if a name is given), then
        *prod* itself with the symbol added.
        """
        # The clone must be taken before the symbol is added to prod
        clone = prod.cloned()
        if name is not None:
            self._wrap_callback_none(name, clone)
        productions.append(clone)

        prod.add_symbol(symbol, name=name)
        productions.append(prod)

    def _wrap_callback_none(self, name, prod):
        """
        Replaces the callback of *prod* by a wrapper that calls the
        previous callback with the keyword argument *name* set to None.
        """
        previous = prod.callback
        def callback(*args, **kwargs):
            kwargs[name] = None
            return previous(*args, **kwargs)
        prod.callback = callback

    def __add_list(self, productions, prod, symbol, name, allow_empty, sep): # pylint: disable=too-many-arguments,too-many-positional-arguments
        """
        Handles a repeated symbol. A new list symbol is created and
        added to *prod*; two productions are declared for it, one for
        a single item and one adding an item (after the separator, if
        any) to an existing list. When *allow_empty* is true, a clone
        of *prod* without the list is also added, its callback getting
        an empty list for *name*.
        """
        # A distinct symbol is created on each call
        class ListSymbol(metaclass=Singleton): # pylint: disable=missing-class-docstring,too-few-public-methods
            __reprval__ = 'List(%s, "%s")' % (symbol, '*' if allow_empty else '+') # pylint: disable=consider-using-f-string

        if allow_empty:
            # The clone must be taken before the list symbol is added to prod
            clone = prod.cloned()
            self._wrap_callback_empty(name, clone)
            productions.append(clone)

        prod.add_symbol(ListSymbol, name=name)
        productions.append(prod)

        # ListSymbol -> symbol
        list_prod = Production(ListSymbol, self._wrap_callback_one())
        list_prod.add_symbol(symbol, name='item')
        productions.append(list_prod)

        # ListSymbol -> ListSymbol [sep] symbol
        list_prod = Production(ListSymbol, self._wrap_callback_next())
        list_prod.add_symbol(ListSymbol, name='items')
        if sep is not None:
            list_prod.add_symbol(sep)
        list_prod.add_symbol(symbol, name='item')
        productions.append(list_prod)

    def _wrap_callback_empty(self, name, prod):
        """
        Replaces the callback of *prod* by a wrapper that calls the
        previous callback with the keyword argument *name* set to a
        new empty list; nothing is added if *name* is None.
        """
        previous = prod.callback
        def callback_empty(*args, **kwargs):
            if name is not None:
                kwargs[name] = []
            return previous(*args, **kwargs)
        prod.callback = callback_empty

    def _wrap_callback_one(self):
        """
        Returns a callback that starts a list with its *item* argument.
        """
        def callback_one(_, item):
            return [item]
        return callback_one

    def _wrap_callback_next(self):
        """
        Returns a callback that appends *item* to the list *items* in
        place and returns that list.
        """
        def callback_next(_, items, item):
            items.append(item)
            return items
        return callback_next

    def P2(self): # pylint: disable=invalid-name
        """
        Callback for the empty alternative; returns a list holding one
        new production without name nor symbols, using the callback,
        priority and attributes of this parser.
        """
        # 'name' is replaced in on_start_symbol()
        return [Production(None, self.callback, priority=self.priority, attributes=self.attributes)]

    def SYMNAME1(self, identifier): # pylint: disable=invalid-name
        """
        Callback for a symbol given as an identifier; returns it
        unchanged.
        """
        return identifier

    def SYMNAME2(self, litteral): # pylint: disable=invalid-name
        """
        Callback for a symbol given as a quoted string. The string is
        registered as a token type of the grammar class if it is not
        one yet, with the escaped string as pattern.

        :return: The string.
        """
        grammar = self.grammar_class
        if litteral not in grammar.token_types():
            grammar.add_token_type(litteral, lambda s, tok: None, re.escape(litteral), None)
        return litteral

    def SYM(self, symname, properties): # pylint: disable=invalid-name
        """
        Callback for a symbol without repetition; returns the tuple
        (symname, properties, None, None).
        """
        return (symname, properties, None, None)

    def SYMREP(self, symname, repeat, properties, separator=None): # pylint: disable=invalid-name
        """
        Callback for a repeated symbol; returns the tuple (symname,
        properties, repeat, separator), *separator* being None when it
        is not given.
        """
        return (symname, properties, repeat, separator)

    def SYMREP_LIT(self, symname, repeat, properties, separator): # pylint: disable=invalid-name
        """
        Callback for a repeated symbol whose separator is a quoted
        string. The separator is registered as a token type of the
        grammar class if it is not one yet, then the result of
        :py:func:`SYMREP` is returned.
        """
        if separator not in self.grammar_class.token_types():
            self.grammar_class.add_token_type(separator, lambda s, tok: None, re.escape(separator), None)
        return self.SYMREP(symname, repeat, properties, separator)

    def PROPERTIES1(self): # pylint: disable=invalid-name
        """
        Callback for absent properties; returns a new empty dict.
        """
        return {}

    def PROPERTIES2(self, name): # pylint: disable=invalid-name
        """
        Callback for properties given between chevrons; returns a dict
        mapping 'name' to the identifier.
        """
        return {'name': name}
