# ------------------------------------------------------------------------------
# pycparser: c_json.py
#
# by Michael White (@mypalmike)
#
# This example includes functions to serialize and deserialize an ast
# to and from json format. Serializing involves walking the ast and converting
# each node from a python Node object into a python dict. Deserializing
# involves the opposite conversion, walking the tree formed by the
# dict and converting each dict into the specific Node object it represents.
# The dict itself is serialized and deserialized using the python json module.
#
# The dict representation is a fairly direct transformation of the object
# attributes. Each node in the dict gets one metadata field referring to the
# specific node class name, _nodetype. Each local attribute (i.e. not linking
# to child nodes) has a string value or array of string values. Each child
# attribute is either another dict or an array of dicts, exactly as in the
# Node object representation. The "coord" attribute, representing the
# node's location within the source code, is serialized/deserialized from
# a Coord object into a string of the format "filename:line[:column]".
#
# Example TypeDecl node, with IdentifierType child node, represented as a dict:
#     "type": {
#         "_nodetype": "TypeDecl",
#         "coord": "c_files/funky.c:8",
#         "declname": "o",
#         "quals": [],
#         "type": {
#             "_nodetype": "IdentifierType",
#             "coord": "c_files/funky.c:8",
#             "names": [
#                 "char"
#             ]
#         }
#     }
# ------------------------------------------------------------------------------
import functools
import json
import re
import sys
from typing import Any, Callable, Dict, Optional, Set, TypeVar

# This is not required if you've installed pycparser into
# your site-packages/ with setup.py
sys.path.extend([".", ".."])

from pycparser import parse_file, c_ast
from pycparser.c_parser import Coord


# Matches the name of an array child as produced by Node.children(), e.g.
# 'block_items[1]'. Group 1 is the attribute name, group 2 the index text.
RE_CHILD_ARRAY = re.compile(r"(.*)\[(.*)\]")
# Used with .match() on __slots__ entries: matches names that start with '__'
# and contain a further '__' (e.g. '__weakref__'), which are treated as
# internal and are not serialized.
RE_INTERNAL_ATTR = re.compile("__.*__")


class CJsonError(Exception):
    """Error raised when an ast cannot be converted to its dict form."""


_T = TypeVar("_T")
_R = TypeVar("_R")


def memodict(fn: Callable[[_T], _R]) -> Callable[[_T], _R]:
    """Memoization decorator for a function taking a single argument.

    The argument is used as a dict key, so it must be hashable. Results are
    kept for the lifetime of the decorated function in an unbounded dict
    private to that function. If ``fn`` raises, nothing is cached and the
    exception propagates.

    The returned wrapper carries the name, docstring and other metadata of
    ``fn`` (via ``functools.wraps``), so help() and introspection keep working
    on decorated functions.
    """
    cache: Dict[_T, _R] = {}

    @functools.wraps(fn)
    def memoized(arg: _T) -> _R:
        # Membership test (rather than .get) so that falsy/None results are
        # cached like any other value.
        if arg not in cache:
            cache[arg] = fn(arg)
        return cache[arg]

    return memoized


@memodict
def child_attrs_of(klass: type[c_ast.Node]) -> Set[str]:
    """
    Return the slot names of a Node class that are not local attributes.

    Slots whose names look internal (see RE_INTERNAL_ATTR) are ignored, and
    every name listed in klass.attr_names is removed from what remains.
    Memoized per class so the filtering is done only once for each class.

    """
    local_attrs = klass.attr_names
    visible_slots = {
        slot for slot in klass.__slots__ if RE_INTERNAL_ATTR.match(slot) is None
    }
    return visible_slots.difference(local_attrs)


def to_dict(node: c_ast.Node) -> Dict[str, Any]:
    """Recursively convert an ast node (and its subtree) into a dict.

    The dict holds, in this order: the "_nodetype" class name, each local
    attribute listed in attr_names, the "coord" string (or None), every child
    reported by node.children() (single children as dicts, indexed children
    gathered into lists), and finally None for any name from child_attrs_of()
    that was not filled in.

    Raises CJsonError if indexed children do not arrive in order 0, 1, 2, ...
    """
    node_cls = node.__class__

    # Metadata first, then the node's own (non-child) attributes.
    result: Dict[str, Any] = {"_nodetype": node_cls.__name__}
    result.update((name, getattr(node, name)) for name in node_cls.attr_names)

    # The coord object is stored in its string form, or None when absent.
    result["coord"] = str(node.coord) if node.coord else None

    for child_name, child in node.children():
        # Child names are either simple ('value') or indexed ('block_items[1]').
        array_match = RE_CHILD_ARRAY.match(child_name)
        if array_match is None:
            result[child_name] = to_dict(child)
            continue

        array_name, index_text = array_match.groups()
        index = int(index_text)
        # Indexed children must arrive in order, so each one is expected to
        # land exactly at the end of the list built so far.
        siblings = result.setdefault(array_name, [])
        expected = len(siblings)
        if index != expected:
            raise CJsonError(
                "Internal ast error. Array {} out of order. "
                "Expected index {}, got {}".format(array_name, expected, index)
            )
        siblings.append(to_dict(child))

    # Child attributes that never showed up still need a null entry in the json.
    for child_attr in child_attrs_of(node_cls):
        result.setdefault(child_attr, None)

    return result


def to_json(node: c_ast.Node, **kwargs: Any) -> str:
    """Serialize an ast node to a json string.

    Any keyword arguments are forwarded unchanged to json.dumps.
    """
    node_dict = to_dict(node)
    return json.dumps(node_dict, **kwargs)


def file_to_dict(filename: str) -> Dict[str, Any]:
    """Parse a C file (with use_cpp=True) and return its ast as a dict."""
    return to_dict(parse_file(filename, use_cpp=True))


def file_to_json(filename: str, **kwargs: Any) -> str:
    """Parse a C file (with use_cpp=True) and return its ast as a json string.

    Any keyword arguments are forwarded to to_json.
    """
    return to_json(parse_file(filename, use_cpp=True), **kwargs)


def _parse_coord(coord_str: Optional[str]) -> Optional[Coord]:
    """Parse coord string (file:line[:column]) into Coord object.

    The line and column are read from the *end* of the string, so a filename
    that itself contains ':' (for example a drive letter, "C:/src/a.c:8:3")
    is kept intact instead of being cut at its first colon.

    Returns None when coord_str is None. A missing or empty line becomes 0 and
    a missing or empty column becomes None.

    Raises ValueError if the field in the line position is not an integer.
    """
    if coord_str is None:
        return None

    filename, sep, last = coord_str.rpartition(":")
    if not sep:
        # No ':' at all: the whole string is the filename.
        return Coord(coord_str, 0, None)

    # Start from the "file:line" reading, then upgrade to "file:line:column"
    # only when the two trailing fields both look like numbers. Otherwise the
    # earlier colon belongs to the filename.
    line, column = last, ""
    head, sep, middle = filename.rpartition(":")
    if sep and _is_coord_number(middle) and _is_coord_number(last):
        filename, line, column = head, middle, last

    line_num = int(line) if line else 0
    column_num = int(column) if column else None
    return Coord(filename, line_num, column_num)


def _is_coord_number(field: str) -> bool:
    """Return True if field is empty or consists only of decimal digits."""
    return not field or field.isdecimal()


def _convert_to_obj(value: Any) -> Any:
    """
    Convert an object in the dict representation into an object.
    Note: Mutually recursive with from_dict.

    """
    match value:
        case dict():
            return from_dict(value)
        case list():
            return [_convert_to_obj(item) for item in value]
        case _:
            # String
            return value


def from_dict(node_dict: Dict[str, Any]) -> c_ast.Node:
    """Recursively build an ast from dict representation.

    The "_nodetype" entry names the c_ast class to instantiate. The "coord"
    entry is parsed with _parse_coord, and every other entry is converted with
    _convert_to_obj and passed to the class as a keyword argument.

    node_dict is only read, never modified, so the same dict (including its
    nested dicts) can be reused or converted again afterwards.

    Raises KeyError if "_nodetype" is missing and AttributeError if it names
    something c_ast does not define.
    """
    klass = getattr(c_ast, node_dict["_nodetype"])

    # Build the constructor kwargs from everything except the metadata key.
    # Keyword parameters work thanks to beautifully consistent ast Node
    # initializers.
    objs = {
        key: _parse_coord(value) if key == "coord" else _convert_to_obj(value)
        for key, value in node_dict.items()
        if key != "_nodetype"
    }
    return klass(**objs)


def from_json(ast_json: str) -> c_ast.Node:
    """Decode a json string with json.loads and build the ast it describes."""
    node_dict = json.loads(ast_json)
    return from_dict(node_dict)


# ------------------------------------------------------------------------------
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Please provide a filename as argument")
    else:
        # Some test code: round trip C -> ast -> dict -> ast -> json, then
        # print the resulting json.
        round_tripped = from_dict(file_to_dict(cli_args[0]))
        print(to_json(round_tripped, sort_keys=True, indent=4))
