import unicodedata
import os
from itertools import product
from collections import deque
from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet

###{standalone
import sys, re
import logging

# Package-wide logger named "lark". Records that pass the logger's level are emitted by the
# attached StreamHandler (which writes to stderr by default).
logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Default to the highest level: the code emits some warnings, but by default the library
# should not output any log messages. Users can lower the level to see them.
logger.setLevel(logging.CRITICAL)


# Unique sentinel object (compare with `is`). Presumably used as a "no value supplied" marker
# where None is a legitimate value -- NOTE(review): confirm at its use sites.
NO_VALUE = object()

# Generic type variable for the typed helpers in this module (e.g. dedup_list, OrderedSet).
T = TypeVar("T")


def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of `seq` into a dict of lists.

    Each item is filed under ``key(item)`` (or the item itself when `key` is None),
    and what gets stored is ``value(item)`` (or the item itself when `value` is None).
    Groups appear in the order their first member was seen. Within a group, items keep
    their original relative order.
    """
    groups: Dict[Any, Any] = {}
    for element in seq:
        group_key = element if key is None else key(element)
        stored = element if value is None else value(element)
        groups.setdefault(group_key, []).append(stored)
    return groups


def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
    if isinstance(data, dict):
        if '__type__' in data:  # Object
            class_ = namespace[data['__type__']]
            return class_.deserialize(data, memo)
        elif '@' in data:
            return memo[data['@']]
        return {key:_deserialize(value, namespace, memo) for key, value in data.items()}
    elif isinstance(data, list):
        return [_deserialize(value, namespace, memo) for value in data]
    return data


# Type variable bound to Serialize, so that `Serialize.deserialize` can be annotated as
# returning an instance of whichever subclass it is called on.
_T = TypeVar("_T", bound="Serialize")

class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
                                        Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        """Serialize `self`, storing instances of `types_to_memoize` only once.

        Returns a 2-tuple ``(data, memo_data)``. ``data`` is this object's serialized form,
        in which memoized instances are replaced by ``{'@': id}`` references. ``memo_data``
        is the serialized id table produced by ``SerializeMemoizer.serialize``.
        """
        memo = SerializeMemoizer(types_to_memoize)
        # `self.serialize(memo)` must run first: it is what fills the memo table.
        return self.serialize(memo), memo.serialize()

    def serialize(self, memo = None) -> Dict[str, Any]:
        """Return a dict representation of this object.

        If `memo` is given and this object's type is one it memoizes, only a reference
        ``{'@': id}`` is returned, where the id is assigned by ``memo.memoized``. Otherwise
        every attribute named in ``__serialize_fields__`` is serialized and the class name
        is stored under ``'__type__'``. An optional ``_serialize(res, memo)`` hook may then
        amend the result in place.
        """
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        fields = getattr(self, '__serialize_fields__')
        res = {f: _serialize(getattr(self, f), memo) for f in fields}
        res['__type__'] = type(self).__name__
        if hasattr(self, '_serialize'):
            self._serialize(res, memo)
        return res

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        """Build an instance of `cls` from `data` without calling ``__init__``.

        If `data` is a memo reference (``{'@': id}``), ``memo[id]`` is returned. Otherwise
        each field in ``__serialize_fields__`` is deserialized and set on a fresh instance.
        Classes listed in ``__serialize_namespace__`` are the only ones nested values may
        instantiate. An optional ``_deserialize()`` hook runs afterwards.

        Raises:
            KeyError: if `data` lacks one of the required fields; the original KeyError
                is chained as the cause.
        """
        namespace = getattr(cls, '__serialize_namespace__', [])
        namespace = {c.__name__:c for c in namespace}

        fields = getattr(cls, '__serialize_fields__')

        if '@' in data:
            return memo[data['@']]

        inst = cls.__new__(cls)
        for f in fields:
            try:
                setattr(inst, f, _deserialize(data[f], namespace, memo))
            except KeyError as e:
                # Chain the original error so the missing key's traceback is preserved.
                raise KeyError("Cannot find key for class", cls, e) from e

        if hasattr(inst, '_deserialize'):
            inst._deserialize()

        return inst


class SerializeMemoizer(Serialize):
    """Serialization helper that stores instances of selected types once, in an id table.

    Serializing an instance of a memoized type then yields only an ``{'@': id}``
    reference, which keeps the output smaller when objects repeat.
    """

    __serialize_fields__ = ('memoized',)

    def __init__(self, types_to_memoize: List) -> None:
        # Stored as a tuple so it can be passed straight to isinstance().
        self.types_to_memoize = tuple(types_to_memoize)
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        """Return True if `value` is an instance of one of the memoized types."""
        memoized_types = self.types_to_memoize
        return isinstance(value, memoized_types)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        """Serialize the id table as ``{id: serialized object}``."""
        id_to_obj = self.memoized.reversed()
        return _serialize(id_to_obj, None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        """Deserialize a table previously produced by :meth:`serialize`."""
        table = _deserialize(data, namespace, memo)
        return table


# The third-party `regex` module is optional. When it is installed, get_regexp_width
# tolerates Unicode category escapes and syntax that the stdlib parser rejects.
try:
    import regex
    _has_regex = True
except ImportError:
    _has_regex = False

# The stdlib regex parser/constants modules live under private names in `re` from
# Python 3.11 on (the old top-level `sre_parse` / `sre_constants` are deprecated there).
if sys.version_info >= (3, 11):
    import re._parser as sre_parse
    import re._constants as sre_constants
else:
    import sre_parse
    import sre_constants

# Matches Unicode category escapes such as \p{Lu} inside a regexp source string.
categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')

def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    """Return the minimum and maximum length of strings that `expr` can match.

    Normally the widths come from the stdlib regex parser (`sre_parse`) and are returned
    as a ``[min, max]`` list.

    If that parser rejects `expr` but the third-party `regex` module is installed, a
    coarser ``(min, max)`` tuple is returned instead. Here ``min`` is 0 when `expr` matches
    the empty string and 1 otherwise. ``max`` is the parser's "unbounded" sentinel
    (``MAXWIDTH`` where available, else ``MAXREPEAT``) converted to a plain int.

    Raises:
        ImportError: if `expr` contains Unicode category escapes (``\\p{...}``) and
            `regex` is not installed.
        ValueError: if the stdlib parser rejects `expr` and `regex` is not installed.
            The parser error is chained as the cause.
    """
    if _has_regex:
        # `sre_parse` cannot deal with Unicode categories of the form \p{Mn}. Both such an
        # escape and the letter 'A' match exactly one character, so replacing them leaves
        # the possible match lengths unchanged.
        regexp_final = re.sub(categ_pattern, 'A', expr)
    else:
        if re.search(categ_pattern, expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    try:
        # Convert to plain ints: the upper bound may be the MAXREPEAT/MAXWIDTH sentinel,
        # an int subclass that is not picklable (and therefore blocks caching).
        return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]
    except sre_constants.error as e:
        if not _has_regex:
            raise ValueError(expr) from e
        else:
            # sre_parse does not support the newer syntax accepted by `regex`. Rather than fail
            # outright, test for the most important property: whether the empty string matches.
            c = regex.compile(regexp_final)
            # Python 3.11.7 introduced sre_parse.MAXWIDTH, which is used instead of MAXREPEAT.
            # See lark-parser/lark#1376 and python/cpython#109859
            MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
            if c.match('') is None:
                return 1, int(MAXWIDTH)
            else:
                return 0, int(MAXWIDTH)

###}


# Unicode general categories accepted by is_id_start / is_id_continue. The underscore is
# accepted separately by _test_unicode_category.
# NOTE(review): this approximates PEP 3131 rather than matching it exactly. Python's
# ID_Start admits Nl but not Mn/Mc/Pc, whereas here Mn/Mc/Pc are in the start set and
# Nl only in the continue set.
_ID_START =    'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
    if len(s) != 1:
        return all(_test_unicode_category(char, categories) for char in s)
    return s == '_' or unicodedata.category(s) in categories

def is_id_continue(s: str) -> bool:
    """
    Return True when every character of `s` may appear inside an identifier.

    A character qualifies if it is '_' or its Unicode general category is listed in
    `_ID_CONTINUE`: letters, combining marks, connector punctuation, decimal digits and
    letter numbers. Accented letters, Indic vowel signs and non-Latin digits therefore
    pass. This approximates Python's `ID_CONTINUE` (see PEP 3131). The empty string
    yields True.
    """
    allowed = _ID_CONTINUE
    return _test_unicode_category(s, allowed)

def is_id_start(s: str) -> bool:
    """
    Checks if every character in `s` may start an identifier, by lark's approximation of PEP 3131.

    A character passes if it is '_' or its Unicode general category is one of `_ID_START`:
    letters (Lu/Ll/Lt/Lm/Lo), combining marks (Mn/Mc) or connector punctuation (Pc).
    Diacritics and Indic vowel signs pass. Unlike `is_id_continue`, however, decimal digits
    (Nd) and letter numbers (Nl) are rejected, so '1' or a non-Latin digit does NOT pass.

    This is not identical to Python's own `ID_START`, which admits Nl and excludes Mn/Mc/Pc
    apart from '_'. The empty string yields True (vacuously).
    """
    return _test_unicode_category(s, _ID_START)


def dedup_list(l: Sequence[T]) -> List[T]:
    """Return a new list with the duplicates in `l` removed.

    The first occurrence of each entry is kept, so the original order is preserved.
    Entries must be hashable.
    """
    seen = set()
    unique: List[T] = []
    for entry in l:
        if entry in seen:
            continue
        seen.add(entry)
        unique.append(entry)
    return unique


class Enumerator(Serialize):
    """Assigns consecutive integer ids (0, 1, 2, ...) to items, in first-seen order.

    Items are used as dict keys, so they must be hashable.
    """
    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        """Return the id of `item`, assigning the next free id if it has not been seen."""
        # len() is evaluated before insertion, so a new item receives the next id.
        return self.enums.setdefault(item, len(self.enums))

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        """Return the inverse mapping ``{id: item}``."""
        inverse = dict(zip(self.enums.values(), self.enums.keys()))
        # Ids are unique, so inverting must not lose any entries.
        assert len(inverse) == len(self.enums)
        return inverse



def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Each combination is a tuple, as produced by `itertools.product`. Combinations are
    ordered like nested loops, with the last alternative varying fastest. An empty `lists`
    yields a single empty combination, returned as ``[[]]``.

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]

    Raises:
        AssertionError: if any alternative is empty (only when assertions are enabled).
    """
    if not lists:
        return [[]]
    # An empty alternative would silently make the whole product empty; presumably a caller bug.
    assert all(l for l in lists), lists
    return list(product(*lists))

try:
    import atomicwrites
except ImportError:
    _has_atomicwrites = False
else:
    _has_atomicwrites = True

class FS:
    """Minimal file-system facade with `exists` and `open`.

    `open` routes write modes through `atomicwrites.atomic_write` (with ``overwrite=True``)
    when that optional package is installed, so the target is replaced atomically. In every
    other case it falls back to the builtin `open`.
    """
    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        use_atomic = _has_atomicwrites and "w" in mode
        if not use_atomic:
            # Inside the method body, `open` resolves to the builtin, not to FS.open.
            return open(name, mode, **kwargs)
        return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)


class fzset(frozenset):
    """A frozenset whose repr uses set-literal braces instead of ``frozenset({...})``."""
    def __repr__(self):
        inner = ', '.join(repr(member) for member in self)
        return '{' + inner + '}'


def classify_bool(seq: Iterable, pred: Callable) -> Any:
    """Split `seq` by `pred` into ``(matching, non_matching)`` lists.

    `pred` is called once per element. Both lists keep the input order.
    """
    matching = []
    non_matching = []
    for element in seq:
        if pred(element):
            matching.append(element)
        else:
            non_matching.append(element)
    return matching, non_matching


def bfs(initial: Iterable, expand: Callable) -> Iterator:
    """Yield nodes in breadth-first order, starting from `initial`.

    ``expand(node)`` must return the node's neighbours. Nodes must be hashable. Each
    node discovered through `expand` is yielded at most once, but duplicates already
    present in `initial` are yielded as given.
    """
    frontier = deque(initial)
    seen = set(frontier)
    while frontier:
        current = frontier.popleft()
        yield current
        for neighbour in expand(current):
            if neighbour in seen:
                continue
            seen.add(neighbour)
            frontier.append(neighbour)

def bfs_all_unique(initial, expand):
    """Breadth-first traversal without a visited set.

    Intended for graphs where no node can be reached twice (e.g. trees). Nothing
    guards against repeats or cycles.
    """
    frontier = deque(list(initial))
    while frontier:
        current = frontier.popleft()
        yield current
        frontier.extend(expand(current))


def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    """Convert `value` into plain data suitable for `_deserialize`.

    Serialize instances delegate to ``value.serialize(memo)``. Lists and dict values are
    converted element-wise. Frozensets become lists, but their elements are NOT converted.
    Anything else is returned unchanged.
    """
    if isinstance(value, Serialize):
        return value.serialize(memo)
    if isinstance(value, list):
        return [_serialize(item, memo) for item in value]
    if isinstance(value, frozenset):
        # TODO: reversible? Deserialization gets back a list, not a frozenset.
        return list(value)
    if isinstance(value, dict):
        return {k: _serialize(v, memo) for k, v in value.items()}
    return value




def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Decompose `n` into a chain of small multiply-and-add steps.

    Returns pairs ``[(a, b), ...]`` such that starting from ``acc = 1`` and applying
    ``acc = acc * a + b`` for each pair in order reproduces `n`. Every pair currently
    satisfies ``a + b <= max_factor``, though that might change.
    """
    assert n >= 0
    assert max_factor > 2
    if n <= max_factor:
        return [(n, 0)]

    # Try the largest divisor first. divisor == 2 always succeeds since max_factor > 2.
    for divisor in reversed(range(2, max_factor + 1)):
        quotient, remainder = divmod(n, divisor)
        if divisor + remainder <= max_factor:
            return small_factors(quotient, max_factor) + [(divisor, remainder)]
    assert False, "Failed to factorize %s" % n


class OrderedSet(AbstractSet[T]):
    """A set that remembers insertion order.

    Elements are kept as the keys of a plain dict whose values are all None. Because
    dicts preserve insertion order, iteration and ``repr`` follow the order in which
    elements were first added.
    """
    def __init__(self, items: Iterable[T] = ()):
        self.d = {element: None for element in items}

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def add(self, item: T):
        # Re-adding an existing element leaves its position unchanged.
        self.d[item] = None

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def remove(self, item: T):
        # Raises KeyError if `item` is absent.
        self.d.pop(item)

    def __bool__(self):
        return len(self.d) != 0

    def __len__(self) -> int:
        return len(self.d)

    def __repr__(self):
        body = ', '.join(repr(element) for element in self)
        return f"{type(self).__name__}({body})"
