# -*- coding: utf-8 -*-

# (c) Jérôme Laheurte 2015-2026
# See LICENSE.txt

"""
Context-free grammars objects. To define a grammar, inherit the
Grammar class and define a method decorated with 'production' for each
production.
"""

import copy
import functools
import inspect
import logging
import warnings

from ptk.lexer import EOF, _LexerMeta
from ptk.utils import Singleton


class Epsilon(metaclass=Singleton): # pylint: disable=too-few-public-methods
    """
    Marker for the empty production (epsilon). The class object itself
    is used as the value: it is stored in first sets to flag symbols
    that can derive the empty string, and it stands for the right side
    of empty productions when they are displayed.
    """
    # Greek small letter epsilon. Presumably picked up by the Singleton
    # metaclass to build the repr -- confirm in ptk.utils.
    __reprval__ = '\u03B5'


class GrammarError(Exception):
    """
    Base class for errors in a grammar definition, for instance the
    same production being declared twice.
    """


class GrammarParseError(GrammarError):
    """
    Syntax error in a production specification, for instance the same
    identifier name being bound to two symbols of one production.
    """


@functools.total_ordering
class Production:
    """
    A single production: a left-hand symbol name, the ordered list of
    right-hand symbols and the callback associated with it.

    Equality, ordering and hashing only take the name and the right
    side into account; callback, priority and attributes are ignored.
    """
    def __init__(self, name, callback, priority=None, attributes=None):
        """
        name is the left-hand symbol, callback the callable returned by
        apply(), priority the optional symbol whose precedence overrides
        the rightmost terminal's one, and attributes an optional dict
        stored as-is on the production.
        """
        self.name = name
        # Name of the keyword argument receiving the position in
        # apply(), or None if the callback does not want it.
        self.posarg = None
        self.callback = callback
        self.right = []
        self.attributes = attributes or {}
        self.__priority = priority
        # position => id
        self.__ids = {} # pylint: disable=unused-private-member

    def add_symbol(self, identifier, name=None):
        """
        Append a symbol to the production's right side. If name is
        given, the value at that position is passed to the callback as
        a keyword argument of that name.

        Raises GrammarParseError if name is already bound to another
        symbol of this production.
        """
        if name is not None:
            if name in self.__ids.values():
                raise GrammarParseError(f'Duplicate identifier name "{name}"')
            self.__ids[len(self.right)] = name
        self.right.append(identifier)

    def addSymbol(self, identifier, name=None): # pylint: disable=invalid-name
        """
        Deprecated alias for add_symbol.
        """
        # stacklevel=2 attributes the warning to the caller, which is
        # the code that actually needs to be updated.
        warnings.warn('addSymbol is deprecated in favor of add_symbol', DeprecationWarning, stacklevel=2)
        self.add_symbol(identifier, name=name)

    def cloned(self):
        """
        Returns an independent copy of this production. The right side,
        the identifier names and the attributes are copied, so changing
        the clone does not affect the original.
        """
        prod = Production(self.name, self.callback, self.__priority, dict(self.attributes))
        prod.posarg = self.posarg
        prod.right = list(self.right)
        prod.__ids = dict(self.__ids) # pylint: disable=protected-access,unused-private-member
        return prod

    def apply(self, args, position):
        """
        Returns a (callback, kwargs) tuple. kwargs maps each identifier
        name to the item of args found at the corresponding right-side
        position, plus the position itself under the 'posarg' name when
        one is set.
        """
        kwargs = {name: args[index] for index, name in self.__ids.items()}
        if self.posarg is not None:
            kwargs[self.posarg] = position
        return self.callback, kwargs

    def rightmost_terminal(self, grammar):
        """
        Returns the rightmost terminal, or None if there is none
        """
        # The token types do not depend on the symbol; ask for them once.
        terminals = grammar.token_types()
        for symbol in reversed(self.right):
            if symbol in terminals:
                return symbol
        return None

    def rightmostTerminal(self, grammar): # pylint: disable=invalid-name
        """
        Deprecated alias for rightmost_terminal.
        """
        warnings.warn('rightmostTerminal is deprecated in favor of rightmost_terminal', DeprecationWarning, stacklevel=2)
        return self.rightmost_terminal(grammar)

    def precedence(self, grammar):
        """
        Returns the production's priority (specified through the
        'priority' keyword argument to the 'production' decorator), or
        if there is none, the priority for the rightmost terminal.
        Returns None when the production has neither.
        """
        if self.__priority is not None:
            return grammar.terminal_precedence(self.__priority)
        symbol = self.rightmost_terminal(grammar)
        if symbol is not None:
            return grammar.terminal_precedence(symbol)
        return None

    def __eq__(self, other):
        # Let Python fall back to its default handling instead of
        # failing with AttributeError on unrelated objects.
        if not isinstance(other, Production):
            return NotImplemented
        return (self.name, self.right) == (other.name, other.right)

    def __lt__(self, other):
        if not isinstance(other, Production):
            return NotImplemented
        return (self.name, self.right) < (other.name, other.right)

    def __repr__(self): # pragma: no cover
        return '%s -> %s' % (self.name, ' '.join([repr(p) for p in self.right]) if self.right else repr(Epsilon)) # pylint: disable=consider-using-f-string

    def __hash__(self):
        # Must stay consistent with __eq__: same name and right side.
        return hash((self.name, tuple(self.right)))


# Same remark as in lexer.py.
# Pending productions: the 'production' decorator appends
# (func, prod, priority, kwargs) tuples here; _GrammarMeta.__new__
# consumes them and then rebinds this name to a fresh empty list.
_PRODREGISTER = []


class _GrammarMeta(_LexerMeta):
    """
    Metaclass for grammar classes. It gives each new class its own
    bookkeeping attributes, then feeds every production registered
    through the 'production' decorator to a production parser obtained
    from the new class.
    """
    def __new__(mcs, name, bases, attrs):
        global _PRODREGISTER # pylint: disable=W0603
        try:
            # Fresh per-class state, so that nothing is shared with the
            # base classes.
            initial_state = (
                ('__productions__', []),
                ('__precedence__', []),
                ('__prepared__', False),
                ('__lrstates__', []),
            )
            for key, value in initial_state:
                attrs[key] = value
            grammar_class = super().__new__(mcs, name, bases, attrs)
            for method, spec, priority, attributes in _PRODREGISTER:
                spec_parser = grammar_class._create_production_parser(method.__name__, priority, attributes) # pylint: disable=W0212
                spec_parser.parse(spec)
            return grammar_class
        finally:
            # Always start the next class definition with an empty
            # register, even if this one failed.
            _PRODREGISTER = []


def production(prod, priority=None, **kwargs):
    """
    Decorator factory: records the decorated function, together with
    the production specification 'prod', its priority and any extra
    keyword arguments, in the register of pending productions. The
    function itself is returned unchanged.

    Raises TypeError if a different function with the same name is
    already pending.
    """
    def _wrap(func):
        for pending, _spec, _priority, _attributes in _PRODREGISTER:
            same_name = func.__name__ == pending.__name__
            if same_name and func != pending:
                raise TypeError(f'Duplicate production method name "{func.__name__}"')
        entry = (func, prod, priority, kwargs)
        _PRODREGISTER.append(entry)
        return func
    return _wrap


class Grammar(metaclass=_GrammarMeta):
    """
    Base class for a context-free grammar
    """

    __productions__ = [] # Make pylint happy
    __precedence__ = []
    __prepared__ = False

    # Start symbol; when left to None, prepare() picks a default one.
    startSymbol = None

    def __init__(self):
        """
        Prepares the grammar class on first instantiation.
        """
        # pylint: disable=R0912
        super().__init__()
        if not self.__prepared__:
            self.prepare()

    @classmethod
    def prepare(cls):
        """
        Validates the grammar and computes its first sets: picks the
        default start symbol if none was set, checks for duplicate
        productions, stores the first sets on the class and logs the
        productions at debug level.

        Raises GrammarError if a production is declared twice.
        """
        cls.startSymbol = cls._default_start_symbol() if cls.startSymbol is None else cls.startSymbol

        productions = cls.productions()
        seen = set()
        for prod in productions:
            if prod in seen:
                raise GrammarError(f'Duplicate production "{prod}"')
            seen.add(prod)

        cls.__all_firsts__ = cls.__compute_firsts()

        logger = logging.getLogger('Grammar')
        if logger.isEnabledFor(logging.DEBUG):
            # default=0 so that a grammar without productions does not
            # fail because of the debug output alone.
            max_width = max((len(prod.name) for prod in productions), default=0)
            line_format = '%%- %ds -> %%s' % max_width # pylint: disable=consider-using-f-string
            for prod in productions:
                logger.debug(line_format, prod.name, ' '.join([repr(name) for name in prod.right]) if prod.right else Epsilon)

        cls.__prepared__ = True

    @classmethod
    def __compute_firsts(cls):
        """
        Returns a dict mapping each terminal, EOF and each symbol that
        has at least one production to its first set. Epsilon belongs
        to the first set of a symbol only if that symbol can derive the
        empty string.
        """
        all_firsts = {symbol: {symbol} for symbol in cls.token_types() | {EOF}} # pylint: disable=no-member
        productions = cls.productions()
        # Iterate until a whole pass over the productions adds nothing.
        while True:
            prev = copy.deepcopy(all_firsts)
            for prod in productions:
                firsts = all_firsts.setdefault(prod.name, set())
                for symbol in prod.right:
                    symbol_firsts = all_firsts.get(symbol, set())
                    # Epsilon in a symbol's first set only means that
                    # the next symbol must be looked at; it must not
                    # leak into the left-hand symbol's first set.
                    firsts.update(symbol_firsts - {Epsilon})
                    if Epsilon not in symbol_firsts:
                        break
                else:
                    # Empty right side, or every symbol can derive the
                    # empty string.
                    firsts.add(Epsilon)
            if prev == all_firsts:
                break
        return all_firsts

    @classmethod
    def _default_start_symbol(cls):
        """
        Returns the name of the first production. Raises IndexError if
        the grammar has no production.
        """
        return cls.productions()[0].name

    @classmethod
    def productions(cls):
        """
        Returns all productions, those of this class first, followed by
        those of its Grammar base classes in method resolution order.
        """
        productions = []
        for base in inspect.getmro(cls):
            if issubclass(base, Grammar):
                productions.extend(base.__productions__)
        return productions

    @classmethod
    def nonterminals(cls):
        """
        Return all non-terminal symbols, that is every production name
        plus every right-side symbol that is not a token type.
        """
        terminals = cls.token_types() # pylint: disable=no-member
        result = set()
        for prod in cls.productions():
            result.add(prod.name)
            for symbol in prod.right:
                if symbol not in terminals:
                    result.add(symbol)
        return result

    @classmethod
    def precedences(cls):
        """
        Returns all precedence declarations, those of this class first,
        followed by those of its Grammar base classes in method
        resolution order.
        """
        precedences = []
        for base in inspect.getmro(cls):
            if issubclass(base, Grammar):
                precedences.extend(base.__precedence__)
        return precedences

    @classmethod
    def terminal_precedence(cls, symbol):
        """
        Returns an (associativity, index) tuple for the first precedence
        declaration containing symbol, index being its position in
        precedences(), or None if there is none.
        """
        for index, (associativity, terminals) in enumerate(cls.precedences()):
            if symbol in terminals:
                return associativity, index
        return None

    @classmethod
    def terminalPrecedence(cls, symbol): # pylint: disable=invalid-name
        """
        Deprecated alias for terminal_precedence.
        """
        # stacklevel=2 attributes the warning to the caller, which is
        # the code that actually needs to be updated.
        warnings.warn('terminalPrecedence is deprecated in favor of terminal_precedence', DeprecationWarning, stacklevel=2)
        return cls.terminal_precedence(symbol)

    @classmethod
    @functools.cache
    def first(cls, *symbols):
        """
        Returns the first set for a group of symbols. It contains
        Epsilon only if every symbol can derive the empty string (which
        includes the case where no symbol is given).

        The grammar must have been prepared. Raises KeyError for a
        symbol that has no first set. The result is cached per class
        and symbols, so the returned set is shared between calls and
        must not be modified.
        """
        result = set()
        for symbol in symbols:
            rfirst = cls.__all_firsts__[symbol]
            result |= {a for a in rfirst if a is not Epsilon}
            if Epsilon not in rfirst:
                break
        else:
            result.add(Epsilon)
        return result
