1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
|
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes that roughly correspond to Clang AST node types."""
import collections
import dataclasses
import re
from typing import Dict, Optional, Sequence, Tuple, Union
# Clang performs the real parsing of the C sources. This module only does a
# small amount of extra inspection of function parameter type declarations,
# and the two constants below exist to sanity-check that inspection; they are
# not a parser in their own right.

# A single C identifier: a letter or underscore followed by letters, digits or
# underscores.
VALID_TYPE_NAME_PATTERN = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')

# Words that match VALID_TYPE_NAME_PATTERN but can never be a type name on
# their own: C keywords that are not type specifiers, C11 underscore keywords,
# and a couple of compiler/preprocessor constructs.
C_INVALID_TYPE_NAMES = frozenset({
    # Classic C keywords.
    'auto', 'break', 'case', 'const', 'continue', 'default', 'do', 'else',
    'enum', 'extern', 'for', 'goto', 'if', 'inline', 'register', 'restrict',
    'return', 'sizeof', 'static', 'struct', 'switch', 'typedef', 'union',
    'volatile', 'while',
    # C11 keywords.
    '_Alignas', '_Atomic', '_Generic', '_Imaginary', '_Noreturn',
    '_Static_assert', '_Thread_local',
    # Attribute / pragma syntax.
    '__attribute__', '_Pragma',
})
def _is_valid_integral_type(type_str: str) -> bool:
    """Checks if a string plausibly spells a (possibly multi-word) integral type.

    The string is split on whitespace. Every word must be either one of the
    integral type keywords (`signed`, `unsigned`, `short`, `long`, `int`,
    `char`) or some other identifier that is not a reserved word listed in
    `C_INVALID_TYPE_NAMES`. Such an identifier is tolerated because it may be
    a typedef name (it also lets e.g. `long double` through). The words are
    then checked for a plausible combination: at most one sign keyword, at
    most one `short`, at most two `long`s, no `short` together with `long`,
    no `short`/`long` together with `char`, and at most one base type among
    `char`, `int` and other identifiers.

    This is a sanity check, not a full C type parser.

    Args:
      type_str: The type spelling to check, e.g. 'unsigned long long int'.

    Returns:
      True if the words form a plausible integral type spelling, else False.
    """
    specifiers = collections.Counter()
    other_names = 0
    for word in re.split(r'\s+', type_str):
        if word in ('signed', 'unsigned', 'short', 'long', 'int', 'char'):
            specifiers[word] += 1
        elif (VALID_TYPE_NAME_PATTERN.fullmatch(word) and
              word not in C_INVALID_TYPE_NAMES):
            # A non-keyword identifier can be a typedef for an integer type.
            # Reserved words such as `struct` or `const` are not, so they must
            # not be counted here (previously 'unsigned struct' was accepted).
            other_names += 1
        else:
            # Punctuation, empty strings from leading/trailing whitespace, or
            # reserved words.
            return False

    return not (
        specifiers['signed'] + specifiers['unsigned'] > 1 or
        specifiers['short'] > 1 or
        specifiers['long'] > 2 or
        (specifiers['short'] and specifiers['long']) or
        ((specifiers['short'] or specifiers['long']) and specifiers['char']) or
        specifiers['char'] + specifiers['int'] + other_names > 1)
@dataclasses.dataclass
class ValueType:
    """A C type that is neither a pointer type nor an array type.

    Attributes:
      name: The type's spelling without qualifiers: a single identifier such
        as 'mjtNum', or a multi-word integral spelling such as
        'unsigned long long'.
      is_const: Whether the type is `const`-qualified.
      is_volatile: Whether the type is `volatile`-qualified.
    """
    name: str
    is_const: bool = False
    is_volatile: bool = False

    def __init__(self, name: str, is_const: bool = False,
                 is_volatile: bool = False):
        # A bare C reserved word (e.g. 'struct') is never accepted, even
        # though it matches the identifier pattern.
        looks_like_type = (bool(VALID_TYPE_NAME_PATTERN.fullmatch(name)) or
                           _is_valid_integral_type(name))
        if name in C_INVALID_TYPE_NAMES or not looks_like_type:
            raise ValueError(f'{name!r} is not a valid value type name')
        self.name, self.is_const, self.is_volatile = (
            name, is_const, is_volatile)

    def decl(self, name_or_decl: Optional[str] = None) -> str:
        """Returns a C declaration of `name_or_decl` having this type.

        Qualifiers come first, then the type name, then `name_or_decl` if it
        is non-empty; e.g. 'const volatile int x'. With no argument only the
        type spelling is returned.
        """
        qualifiers = [word for word, present
                      in (('const', self.is_const),
                          ('volatile', self.is_volatile))
                      if present]
        trailing = [name_or_decl] if name_or_decl else []
        return ' '.join(qualifiers + [self.name] + trailing)

    def __str__(self):
        return self.decl(None)
@dataclasses.dataclass
class ArrayType:
    """Represents a C array type, possibly multidimensional.

    Attributes:
      inner_type: The element type. Multidimensional arrays are expressed
        through several `extents` rather than by nesting `ArrayType`s.
      extents: The size of each dimension, outermost first; e.g. (3, 4)
        declares `x[3][4]`.
    """
    inner_type: Union[ValueType, 'PointerType']
    # Variable-length tuple: `Tuple[int]` would mean "exactly one int".
    extents: Tuple[int, ...]

    def __init__(self, inner_type: Union[ValueType, 'PointerType'],
                 extents: Sequence[int]):
        self.inner_type = inner_type
        # Store an immutable copy so later mutation of the caller's sequence
        # does not affect this node, and equality does not depend on the
        # sequence type passed in.
        self.extents = tuple(extents)

    @property
    def _extents_str(self) -> str:
        """The `[n]` suffixes for every dimension, e.g. '[3][4]'."""
        return ''.join(f'[{n}]' for n in self.extents)

    def decl(self, name_or_decl: Optional[str] = None) -> str:
        """Returns a C declaration of `name_or_decl` with this array type.

        The extents are appended to `name_or_decl` (or to nothing, for an
        abstract declaration) and the result is handed to the element type,
        e.g. 'int x[3][4]'.
        """
        name_or_decl = name_or_decl or ''
        return self.inner_type.decl(f'{name_or_decl}{self._extents_str}')

    def __str__(self):
        return self.decl()
@dataclasses.dataclass
class PointerType:
    """Represents a C pointer type.

    Attributes:
      inner_type: The pointee type.
      is_const: Whether the pointer itself (not the pointee) is `const`.
      is_volatile: Whether the pointer itself is `volatile`.
      is_restrict: Whether the pointer is `restrict`-qualified.
    """
    inner_type: Union[ValueType, ArrayType, 'PointerType']
    is_const: bool = False
    is_volatile: bool = False
    is_restrict: bool = False

    def decl(self, name_or_decl: Optional[str] = None) -> str:
        """Creates a string that declares an object of this type.

        The pointer's own qualifiers follow the `*`, e.g. 'int * const p'.
        A pointer to an array is parenthesized, e.g. 'int (* p)[3]'.
        """
        pointer_qualifiers = (
            ('const', self.is_const),
            ('volatile', self.is_volatile),
            ('restrict', self.is_restrict),
        )
        tokens = ['*']
        tokens.extend(word for word, enabled in pointer_qualifiers if enabled)
        if name_or_decl:
            tokens.append(name_or_decl)
        declarator = ' '.join(tokens)

        if isinstance(self.inner_type, ArrayType):
            # Without the parentheses `* p[3]` would declare an array of
            # pointers rather than a pointer to an array.
            declarator = '(' + declarator + ')'
        return self.inner_type.decl(declarator)

    def __str__(self):
        return self.decl(None)
@dataclasses.dataclass
class FunctionParameterDecl:
    """A single parameter in a function declaration.

    Under C's rules a parameter declared with array type is adjusted to a
    pointer (array-to-pointer decay), so it shows up as a pointer parameter in
    an actual C AST. This class keeps the array type anyway, because the
    array's extents are informative.

    Attributes:
      name: The parameter's name.
      type: The parameter's declared type.
    """
    name: str
    type: Union[ValueType, ArrayType, PointerType]

    def __str__(self):
        declarator = self.name
        return self.type.decl(declarator)

    @property
    def decltype(self) -> str:
        """The parameter's type spelled without a name, e.g. 'int *'."""
        return self.type.decl(None)
@dataclasses.dataclass
class FunctionDecl:
    """Represents a function declaration.

    Attributes:
      name: The function's name.
      return_type: The declared return type.
      parameters: The function's parameters, in declaration order.
      doc: Documentation text for the function.
    """
    name: str
    return_type: Union[ValueType, ArrayType, PointerType]
    # Variable-length tuple: `Tuple[X]` would mean "exactly one X".
    parameters: Tuple[FunctionParameterDecl, ...]
    doc: str

    def __init__(self, name: str,
                 return_type: Union[ValueType, ArrayType, PointerType],
                 parameters: Sequence[FunctionParameterDecl],
                 doc: str):
        self.name = name
        self.return_type = return_type
        self.parameters = tuple(parameters)
        self.doc = doc

    def __str__(self):
        """Returns the C prototype, e.g. 'int * foo(int x, double y)'."""
        param_str = ', '.join(str(p) for p in self.parameters)
        # Let the return type wrap the function declarator instead of pasting
        # its abstract spelling in front. The result is identical for value
        # and plain pointer return types. For return types with postfix
        # syntax it is the only valid C: a function returning a pointer to an
        # array is 'int (* foo(int n))[3]', not 'int (*)[3] foo(int n)'.
        return self.return_type.decl(f'{self.name}({param_str})')

    @property
    def decltype(self) -> str:
        """Returns the function's type without its name, e.g. 'int * (int)'."""
        param_str = ', '.join(p.decltype for p in self.parameters)
        # Same reasoning as in __str__: the abstract declarator '(params)'
        # is wrapped by the return type, e.g. 'int (* (int))[3]'.
        return self.return_type.decl(f'({param_str})')
class _EnumDeclValues(Dict[str, int]):
    """A dict whose stringified representation ends with a trailing comma.

    `repr()` of a non-empty instance is a dict literal with a comma after the
    last item, e.g. "{'A': 0, 'B': 1,}". The comma is a hint for code
    formatters to place one item per line when the repr is pasted into
    generated Python code; the text is still a valid dict literal.
    """

    def __repr__(self):
        out = super().__repr__()
        # dict.__repr__ always renders a `{...}` literal and never includes
        # the subclass name (the earlier regexes expected an OrderedDict-style
        # `Name([...])` and therefore never matched). Insert the comma right
        # before the closing brace. An empty mapping stays '{}', because
        # '{,}' would not be valid Python.
        if self:
            out = out[:-1] + ',}'
        return out
@dataclasses.dataclass
class EnumDecl:
    """Represents an enum declaration.

    Attributes:
      name: The enum's name.
      declname: A second name for the enum; presumably the spelling used in
        the C declaration itself (NOTE(review): confirm against the parser
        that builds these nodes).
      values: Mapping from enumerator name to its integer value. Whatever
        mapping is passed in is copied into an `_EnumDeclValues`, a dict
        subclass with a customized repr.
    """
    name: str
    declname: str
    values: Dict[str, int]

    def __post_init__(self):
        # The generated __init__ stores `values` as given; re-wrap it in the
        # dict subclass so its repr is the customized one.
        self.values = _EnumDeclValues(self.values)
|