1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
|
# Copyright (c) 2017-2026 Juancarlo Añez (apalala@gmail.com)
# SPDX-License-Identifier: BSD-4-Clause
from __future__ import annotations
import dataclasses
import inspect
import warnings
import weakref
from collections.abc import Callable
from functools import cache
from typing import Any, overload
from ..ast import AST
from ..infos import ParseInfo
from ..util.abctools import rowselect
from ..util.asjson import AsJSONMixin, asjson, asjsons
__all__ = ['BaseNode', 'TatSuDataclassParams', 'tatsudataclass']
# Default keyword arguments that `tatsudataclass` forwards to
# `dataclasses.dataclass`; explicit decorator arguments override them.
TatSuDataclassParams = {
    # Do not generate __eq__; classes keep the one they inherit
    # (BaseNode defines identity-based __eq__/__hash__).
    'eq': False,
    # Do not generate __repr__ (BaseNode defines its own).
    'repr': False,
    # Do not generate __match_args__.
    'match_args': False,
    # Do not generate __hash__.
    'unsafe_hash': False,
    # Fields are keyword-only unless a field overrides it
    # (BaseNode.ast is declared with kw_only=False).
    'kw_only': True,
}
@overload
def tatsudataclass[T: type](cls: T) -> T: ...
@overload
def tatsudataclass[T: type](**params: Any) -> Callable[[T], T]: ...
def tatsudataclass[T: type](
cls: T | None = None,
**params: Any,
) -> T | Callable[[T], T]:
# by Gemini (2026-02-07)
# by [apalala@gmail.com](https://github.com/apalala)
def decorator(target: T) -> T:
allparams = {**TatSuDataclassParams, **params}
return dataclasses.dataclass(**allparams)(target)
# If cls is passed, it was used as @tatsudataclass with no arguments
if cls is not None:
return decorator(cls)
return decorator
@tatsudataclass
class BaseNode(AsJSONMixin):
    """Base class for the nodes of an object model.

    A node can be built through the explicit ``__init__`` below
    (``BaseNode(ast, **attributes)``). A subclass decorated with
    ``@tatsudataclass`` gets a dataclass-generated ``__init__`` instead.
    Both paths end in ``__post_init__``, which normalizes ``ast`` and copies
    its entries onto matching attributes.

    Equality and hashing are by identity, not structure.
    """

    ast: Any = dataclasses.field(kw_only=False, default=None)
    # `ast` is the only field that may be passed positionally; the others
    # follow the kw_only=True default of TatSuDataclassParams.
    ctx: Any = None
    parseinfo: ParseInfo | None = None

    def __init__(self, ast: Any = None, **attributes: Any):
        """Create a node from ``ast`` plus explicit attribute values.

        Raises:
            ValueError: an attribute name is not an attribute of the node.
            TypeError: an attribute name refers to a method.
        """
        # NOTE:
        #   A @dataclass subclass may not call this,
        #   but __post_init__() should still be honored
        super().__init__()
        self.ast = ast
        self.__set_attributes(**attributes)
        self.__post_init__()

    def __post_init__(self):
        """Normalize ``ast`` and copy its entries onto declared attributes.

        A non-empty ``dict`` ast is wrapped with ``AST(...)``. This presumably
        also re-wraps an existing ``AST`` if ``AST`` subclasses ``dict``;
        confirm.

        When ``ast`` is an ``AST``:

        - its ``parseinfo`` replaces a falsy ``self.parseinfo``;
        - every key that names an existing, non-method attribute is assigned
          from the ast.
        """
        if self.ast and isinstance(self.ast, dict):
            self.ast = AST(self.ast)

        ast = self.ast
        if not isinstance(ast, AST):
            return

        if not self.parseinfo:
            self.parseinfo = ast.parseinfo

        # NOTE:
        #   Node objects are created by a model builder when invoked by the
        #   parser, which passes only the ast recovered when the object was
        #   created, e.g. for:
        #
        #       point::Point = ... left:... right:... ;
        #
        #   Here the key/value pairs in the AST are injected into the
        #   corresponding attributes declared by the Node subclass.
        #   Synthetic classes override this to create the attributes.
        for name in ast:
            if not hasattr(self, name) or inspect.ismethod(getattr(self, name)):
                continue
            setattr(self, name, ast[name])

    def asjson(self) -> Any:
        """Return the result of the module-level ``asjson()`` for this node."""
        return asjson(self)

    def __set_attributes(self, **attrs: Any) -> None:
        """Assign each keyword in ``attrs`` to an existing attribute.

        Raises:
            ValueError: ``name`` is not an attribute of this node.
            TypeError: ``name`` refers to a method, which must not be
                replaced by data.
        """
        # `**attrs` is always a dict, so no type check is needed. The former
        # "will shadow" warning was unreachable: its condition (the attribute
        # is a method) is exactly the case that raises TypeError below.
        for name, value in attrs.items():
            if not hasattr(self, name):
                raise ValueError(f'Unknown argument {name}={value!r}')
            if inspect.ismethod(method := getattr(self, name)):
                raise TypeError(f'Overriding method {name}={method!r}')
            setattr(self, name, value)

    @staticmethod
    @cache
    def _basenode_keys() -> frozenset[str]:
        """Return the names visible on ``BaseNode`` itself.

        The value is computed once and then cached.
        """
        # Gemini (2026-02-14)
        return frozenset(dir(BaseNode))

    def __pub__(self) -> dict[str, Any]:
        """Return the public attributes to expose, minus BaseNode's own.

        Names that exist on ``BaseNode`` itself (``ast``, ``ctx``,
        ``parseinfo``, methods, ...) are dropped. If nothing remains and
        ``ast`` is set to something other than an ``AST``, ``ast`` alone is
        kept, because it may be all the information this node carries.

        Presumably ``rowselect`` keeps only the ``wanted`` keys of ``pub``;
        confirm.
        """
        pub = super().__pub__()
        # Gemini (2026-02-14)
        wanted = pub.keys() - self._basenode_keys()
        if not wanted and self.ast is not None and not isinstance(self.ast, AST):
            wanted = {'ast'}  # self.ast may be all this object has
        return rowselect(wanted, pub)

    def __repr__(self) -> str:
        """Return ``TypeName(name=value, ...)`` for non-``None`` attributes.

        Only the public attributes from ``__pub__()`` are shown. Dataclass
        fields come first, in declaration order.
        """
        fieldindex = {
            f.name: i for i, f in enumerate(dataclasses.fields(self))  # type: ignore
        }
        nonfield = len(fieldindex)

        def fieldorder(name: str) -> int:
            # Non-field names sort after all fields. The sort is stable, so
            # they keep their original relative order.
            return fieldindex.get(name, nonfield)

        pub = self.__pub__()
        attrs = ', '.join(
            f'{name}={pub[name]!r}'
            for name in sorted(pub.keys(), key=fieldorder)
            if pub[name] is not None
        )
        return f'{type(self).__name__}({attrs})'

    def __str__(self) -> str:
        """Return the result of ``asjsons()`` for this node."""
        return asjsons(self)

    def __eq__(self, other) -> bool:
        # NOTE: No use case for structural equality
        return other is self

    def __hash__(self) -> int:
        # NOTE: No use case for structural equality
        return hash(id(self))

    def __getstate__(self) -> Any:
        """Return a copy of the instance dict for pickling.

        Values that are weak references or weakref proxies are replaced by
        ``None``, because those objects cannot be pickled.
        """
        weak_types = (weakref.ReferenceType, *weakref.ProxyTypes)
        return {
            name: None if isinstance(value, weak_types) else value
            for name, value in vars(self).items()
        }

    def __setstate__(self, state):
        """Restore the attributes produced by ``__getstate__``."""
        self.__dict__.update(state)
|