#!/usr/bin/env python
"""
This module scans the RDKit sources for docstrings that lack parameter
definitions, or for member functions that do not have an explicit "self"
parameter, and patches the C++ sources accordingly.
"""
import sys
import os
import re
import itertools
import glob
import json
import importlib
import queue
import subprocess
import multiprocessing
import shutil
import traceback
import logging
import tempfile
from threading import Thread
from pathlib import Path

RDKIT_MODULE_NAME = "rdkit"
CLANG_CPP_EXE = os.environ.get("CLANG_CPP_EXE", "clang++")
CLANG_FORMAT_EXE = os.environ.get("CLANG_FORMAT_EXE", "clang-format")
CLANG_PYTHON_BINDINGS_PATH = os.environ.get("CLANG_PYTHON_BINDINGS_PATH", None)
if CLANG_PYTHON_BINDINGS_PATH is None:
    raise ValueError("Please set CLANG_PYTHON_BINDINGS_PATH to the absolute path "
                     "to the bindings/python directory under the clang tree")
if CLANG_PYTHON_BINDINGS_PATH not in sys.path:
    sys.path.insert(0, CLANG_PYTHON_BINDINGS_PATH)

if sys.platform.startswith("linux"):
    CLANG_LIBCLANG = "libclang.so"
elif sys.platform.startswith("darwin"):
    CLANG_LIBCLANG = "libclang.dylib"
elif sys.platform.startswith("win32"):
    CLANG_LIBCLANG = "clang.dll"
else:
    raise ValueError(f"Unsupported platform {sys.platform}")

CLANG_LIBRARY_PATH = os.environ.get("CLANG_LIBRARY_PATH", None)
if CLANG_LIBRARY_PATH is None:
    # Start from the clang resource directory and walk up towards the
    # filesystem root until the libclang shared library is found.
    res = subprocess.run([CLANG_CPP_EXE, "-print-resource-dir"], capture_output=True, check=True)
    lib_path = res.stdout.decode("utf-8").strip()
    have_libclang = False
    while True:
        libclang_path = os.path.join(lib_path, CLANG_LIBCLANG)
        have_libclang = os.path.exists(libclang_path)
        parent_path = os.path.dirname(lib_path)
        # os.path.dirname() returns its input unchanged at the root
        # (/ on POSIX, a drive root such as C:\ on Windows), so the
        # loop always terminates, unlike a comparison with os.sep
        if have_libclang or parent_path == lib_path:
            break
        lib_path = parent_path
    if have_libclang:
        CLANG_LIBRARY_PATH = libclang_path
if CLANG_LIBRARY_PATH is None:
    raise ValueError(f"Please set CLANG_LIBRARY_PATH to the absolute path to {CLANG_LIBCLANG}")
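# The CLANG_LIBRARY_PATH discovery above starts from the directory printed by
# "clang++ -print-resource-dir" and walks up towards the filesystem root until
# it finds the libclang shared library. The function below is a minimal,
# standalone sketch of that upward search. Its name (_sketch_find_in_parents)
# is hypothetical, and nothing else in this module uses it. The walk stops
# when os.path.dirname() no longer changes the path, which also covers
# Windows drive roots.
def _sketch_find_in_parents(start_dir, file_name):
    """Return the first existing <dir>/<file_name>, walking up from start_dir, or None."""
    import os  # local import keeps the sketch self-contained

    current = os.path.abspath(start_dir)
    while True:
        candidate = os.path.join(current, file_name)
        if os.path.exists(candidate):
            return candidate
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root without a match
            return None
        current = parent

# Usage (illustrative): _sketch_find_in_parents(resource_dir, CLANG_LIBCLANG)
# mirrors the loop above, where resource_dir would be the decoded output of
# "clang++ -print-resource-dir".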

clang_cindex = importlib.import_module("clang.cindex")
clang_cindex.Config.set_library_file(CLANG_LIBRARY_PATH)
TranslationUnit = clang_cindex.TranslationUnit
CursorKind = clang_cindex.CursorKind


class FunctionDef:
    """A Python function definition found in the AST: the "def" cursor,
    the Python function name, whether it is a static method, and its
    AST nesting level.
    """

    def __init__(self, def_cursor, func_name, is_staticmethod, level):
        self.def_cursor = def_cursor
        self.func_name = func_name
        self.is_staticmethod = is_staticmethod
        self.level = level


class WorkerResult:
    """Result generated by a Worker thread."""

    def __init__(self, worker_idx):
        self.worker_idx = worker_idx
        self.processed_cpp_files = set()
        self.proc_error = ""


class DictLike(dict):
    """Base class that stores attributes as dict items, so that
    any class derived from it can be serialized to and from JSON.
    """

    def __getattr__(self, key):
        # raise AttributeError (rather than KeyError) for missing keys,
        # so that hasattr() and getattr() with a default work as expected
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        self[key] = value

    def to_json(self):
        """Serialize class to a JSON string.

        Returns:
            str: JSON-serialized class content
        """
        return json.dumps(self)

    @classmethod
    def from_json(cls, j):
        """Deserialize class from a JSON string.

        Args:
            j (str): JSON string

        Returns:
            cls: an instance of cls
        """
        instance = cls()
        for k, v in json.loads(j).items():
            setattr(instance, k, v)
        return instance
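# DictLike stores every attribute as a dict item, which is why instances can
# be passed straight to json.dumps(). The class below is a standalone sketch
# of that pattern. _SketchRecord and its "name"/"count" fields are
# hypothetical, for illustration only, and not used by this module.
# Note that from_json() calls cls() with no arguments, so every __init__
# parameter needs a default. CppFile provides one; ClassInfo, whose __init__
# requires hash and parents, does not.
class _SketchRecord(dict):
    """Minimal stand-in for DictLike with two hypothetical fields."""

    def __init__(self, name=None, count=0):
        super().__init__()
        # these assignments go through __setattr__ and become dict items
        self.name = name
        self.count = count

    def __getattr__(self, key):
        # only called when normal attribute lookup fails
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        self[key] = value

    def to_json(self):
        import json  # local import keeps the sketch self-contained
        return json.dumps(self)

    @classmethod
    def from_json(cls, j):
        import json
        instance = cls()
        for k, v in json.loads(j).items():
            setattr(instance, k, v)
        return instance

# Usage (illustrative): _SketchRecord.from_json(_SketchRecord("mol.cpp", 3).to_json())
# yields an equal record whose fields are readable as .name and .count.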


class ClassInfo(DictLike):
    """Information about a python::class_ node: its hash, its parent
    CALL_EXPR nodes, and the associated Python and C++ class names.
    """

    def __init__(self, hash, parents):
        self.hash = hash
        self.parents = parents
        self.python_class_name = None
        self.cpp_class_name = None


class CppFile(DictLike):
    """Class associated with a single C++ file.
    """
    QUOTED_FIELD_REGEX = re.compile(r"\"([^\"]*)\"")
    EXTRACT_BASE_CLASS_NAME_REGEX = re.compile(r"\s*(\S+)\s*<[^>]+>\s*$")
    EXTRACT_INIT_ARGS = re.compile(r"^<(.*)\s>+\s$")
    IS_TEMPLATE_TYPE = re.compile(r"^T\d*$")
    SELF_LITERAL = "self"

    def __init__(self, cpp_path=None):
        """Initialize an instance associated with a C++ file.

        Args:
            cpp_path (str, optional): absolute path to a C++ file.
                Defaults to None.
        """
        self.cpp_path = cpp_path
        self.arg1_func_defs = []
        # type_ref_dict is a dictionary of class alias typedefs
        # relating typedefs to the actual class name
        # found while walking the AST tree
        self.type_ref_dict = {}
        self.ast_error = None
        self.ast_warning = None

    @property
    def ast_path(self):
        """Return the absolute path to the .ast file associated with this C++ file.

        Returns:
            str: absolute path to the .ast file associated with this C++ file
        """
        return f"{self.cpp_path_noext}.ast"

    @property
    def cpp_path_noext(self):
        """Return the absolute path to this C++ file without extension.

        Returns:
            str: absolute path to this C++ file without extension
        """
        return os.path.splitext(self.cpp_path)[0]

    @property
    def cpp_path_ext(self):
        """Return the extension of this C++ file.

        Returns:
            str: extension of this C++ file
        """
        return os.path.splitext(self.cpp_path)[1]

    def extract_quoted_content(self, s):
        """Extract the string between double quotes.

        Args:
            s (str): string to be parsed

        Returns:
            str: string between quotes, or None if s does not
            start with a double-quoted string
        """
        res = None
        m = self.QUOTED_FIELD_REGEX.match(s)
        if m:
            res = m.group(1)
        return res
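    # Worked example (illustrative; assumes clang reports STRING_LITERAL
    # spellings including their surrounding quotes, as the callers here
    # suggest). Because QUOTED_FIELD_REGEX is applied with match(), the
    # opening quote must be the first character of s:
    #   CppFile().extract_quoted_content('"GetNumAtoms"')    -> 'GetNumAtoms'
    #   CppFile().extract_quoted_content('"a" "b"')          -> 'a'
    #   CppFile().extract_quoted_content('u8"GetNumAtoms"')  -> None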

    def generate_ast(self, clang_flags):
        """Generate AST file with clang++.

        Args:
            clang_flags (list): list of flags to be passed to clang++

        Returns:
            bool: True if success, False if failure
        """
        res = False
        # initialize before the try block so the error handler can always use them
        cpp_file = self.cpp_path
        self.ast_error = ""
        try:
            cpp_dir = os.path.dirname(cpp_file)
            proc = subprocess.run([CLANG_CPP_EXE] + clang_flags + [cpp_file],
                                  capture_output=True, cwd=cpp_dir)
            if proc.returncode:
                self.ast_error += f"{cpp_file}: Failed to parse with clang."
                if proc.stderr:
                    self.ast_error += "\nError was:\n" + proc.stderr.decode("utf-8")
            else:
                res = True
                if proc.stderr:
                    self.ast_warning = proc.stderr.decode("utf-8")
        except Exception as e:
            res = False
            tb = traceback.format_exc()
            self.ast_error += f"{cpp_file}: Failed to run clang\n{str(e)}\n{str(tb)}\n"
        return res

    def recurse_ast_cursor(self, cursor, hnd=sys.stdout, recursion_level=0):
        """Recursively walk the AST tree and write a dump to hnd.

        Typedef aliases found along the way are recorded in self.type_ref_dict.

        Args:
            cursor (Cursor): current cursor position
            hnd (file-like object, optional): Handle to which the dump should be written.
                Defaults to sys.stdout.
            recursion_level (int, optional): Used to indent the dump. Defaults to 0.
        """
        recursion_level += 1
        tabs = " " * recursion_level
        for child in cursor.get_children():
            if child.kind == CursorKind.TYPEDEF_DECL:
                typedef_key = child.spelling
                typedef_value = None
                for child2 in child.get_children():
                    if child2.kind in (CursorKind.TYPE_REF, CursorKind.TEMPLATE_REF):
                        typedef_value = child2.spelling
                        break
                if typedef_value is not None:
                    typedef_value = typedef_value.split(" ")[-1]
                    self.type_ref_dict[typedef_key] = typedef_value
            print(f"{tabs}{child.kind}:{child.spelling}", file=hnd)
            hnd.flush()
            self.recurse_ast_cursor(child, hnd, recursion_level)

    def get_func_name_if_has_arg1_param_r(self, cursor, arg1_func_names):
        """Return the name of the function that needs fixing.

        Args:
            cursor (Cursor): current cursor position
            arg1_func_names (iterable): function names that need fixing

        Returns:
            str: function name, or None if no matching name was found
        """
        if cursor.kind == CursorKind.STRING_LITERAL:
            func_name = self.extract_quoted_content(cursor.spelling)
            if func_name is not None and func_name in arg1_func_names:
                return func_name
        for child in cursor.get_children():
            res = self.get_func_name_if_has_arg1_param_r(child, arg1_func_names)
            if res is not None:
                return res
        return None

    def have_decl_ref_expr_r(self, cursor):
        """Check if there is a "def" DECL_REF_EXPR node at or below this cursor.

        Args:
            cursor (Cursor): current cursor position

        Returns:
            bool: True if there is a "def" node
        """
        if cursor.kind == CursorKind.DECL_REF_EXPR and cursor.spelling == "def":
            return True
        for child in cursor.get_children():
            if self.have_decl_ref_expr_r(child):
                return True
        return False

    def find_non_class_defs(self, cursor, class_method_node_hashes, arg1_func_names):
        """Recursively find free function nodes that need fixing.

        Args:
            cursor (Cursor): current cursor position
            class_method_node_hashes (iterable): set of hashes corresponding to previously
                found class method nodes that need fixing
            arg1_func_names (iterable): function names that need fixing

        Returns:
            dict: dict relating a function name to a list of nodes
        """
        non_class_defs = {}
        self.find_non_class_defs_r(cursor, non_class_defs, class_method_node_hashes, set(arg1_func_names))
        return non_class_defs

    def find_non_class_defs_r(self, cursor, non_class_defs, class_method_node_hashes, arg1_func_names):
        """Find free function nodes that need fixing (recursive).

        Args:
            cursor (Cursor): current cursor position
            non_class_defs (dict): dict relating a function name to a list of nodes
            class_method_node_hashes (iterable): set of hashes corresponding to previously
                found class method nodes that need fixing
            arg1_func_names (iterable): function names that need fixing
        """
        if (cursor.kind == CursorKind.CALL_EXPR and cursor.spelling == "def"
                and self.have_decl_ref_expr_r(cursor) and cursor.hash not in class_method_node_hashes):
            func_name = self.get_func_name_if_has_arg1_param_r(cursor, arg1_func_names)
            if func_name is not None:
                node_list = non_class_defs.get(func_name, [])
                node_list.append(cursor)
                non_class_defs[func_name] = node_list
        for child in cursor.get_children():
            self.find_non_class_defs_r(child, non_class_defs, class_method_node_hashes, arg1_func_names)

    def find_nodes(self, cursor):
        """Recursively walk the AST tree and associate nodes
        with their python::class_ hash.

        Args:
            cursor (Cursor): current cursor position

        Returns:
            dict[int, ClassInfo]: dict associating a class hash
            with a ClassInfo instance
        """
        class_info_by_class_hash = {}
        self.find_nodes_r(cursor, [], class_info_by_class_hash)
        return class_info_by_class_hash

    def find_nodes_r(self, cursor, parents, class_info_by_class_hash):
        """Walk the AST tree and associate nodes
        with their python::class_ hash (recursive).

        Args:
            cursor (Cursor): current cursor position
            parents (list): growing list of parents
            class_info_by_class_hash (dict): dict associating a class hash
                with a ClassInfo instance
        """
        if cursor.kind == CursorKind.CALL_EXPR:
            parents = [cursor] + parents
        if cursor.kind in (CursorKind.CALL_EXPR, CursorKind.TEMPLATE_REF) and cursor.spelling == "class_":
            class_hash = cursor.hash
            prev_class_info = class_info_by_class_hash.get(class_hash, None)
            if prev_class_info is None or len(parents) + 1 > len(prev_class_info.parents):
                class_info = ClassInfo(class_hash, parents)
                class_info_by_class_hash[class_hash] = class_info
                for child in cursor.get_children():
                    if child.kind == CursorKind.TYPE_REF:
                        cpp_class_name = child.spelling.split(" ")[-1].split("::")[-1]
                        class_info.cpp_class_name = self.type_ref_dict.get(cpp_class_name, cpp_class_name)
                        break
            return
        for child in cursor.get_children():
            self.find_nodes_r(child, parents, class_info_by_class_hash)

    def find_class_name_r(self, class_info, cursor, found_class_names, arg1_func_byclass_dict):
        """Find the name of the python::class_ associated with class_info.hash
        and store it in class_info.python_class_name.

        Args:
            class_info (ClassInfo): ClassInfo instance
            cursor (Cursor): current cursor position
            found_class_names (dict): dict relating class hash to class name
            arg1_func_byclass_dict (dict): dict relating class names to methods
                that need fixing. Free functions are also included, under the class name
                FixSignatures.NO_CLASS_KEY

        Returns:
            bool: True if the class name corresponding to class_info.hash was found
        """
        class_hash = class_info.hash
        if cursor.kind == CursorKind.STRING_LITERAL:
            class_name = self.extract_quoted_content(cursor.spelling)
            if (class_name is not None and class_name in arg1_func_byclass_dict
                    and class_hash not in found_class_names and class_name not in found_class_names.values()):
                found_class_names[class_hash] = class_name
                class_info.python_class_name = class_name
                return True
        for child in cursor.get_children():
            if self.find_class_name_r(class_info, child, found_class_names, arg1_func_byclass_dict):
                return True
        return False

    def prune_nodes(self, tu_cursor, class_info_by_class_hash, arg1_func_byclass_dict):
        """Return a dict relating class name to a ClassInfo instance,
        keeping only the classes that have methods that need fixing.

        Args:
            tu_cursor (Cursor): translation unit cursor
            class_info_by_class_hash (dict): dict relating class hash to class_info
            arg1_func_byclass_dict (dict): dict relating class names to methods

        Returns:
            dict: dict relating class name to a ClassInfo instance
        """
        # populate found_class_names dictionary {class_hash: class_name}
        # with classes that have methods we need to fix
        found_class_names = {}
        for class_hash, class_info in class_info_by_class_hash.items():
            call_expr_class_node = class_info.parents[0]
            # we might not find the class name as STRING_LITERAL for template classes
            self.find_class_name_r(class_info, call_expr_class_node, found_class_names, arg1_func_byclass_dict)
        for class_hash, class_info in class_info_by_class_hash.items():
            # 2nd pass over the whole translation unit to find template classes
            self.find_class_name_r(class_info, tu_cursor, found_class_names, arg1_func_byclass_dict)
        # prune class_hash entries that do not have methods we need to fix
        class_info_by_class_name = {}
        for class_hash in tuple(class_info_by_class_hash.keys()):
            if class_hash in found_class_names:
                found_class_name = found_class_names[class_hash]
                class_info_by_class_name[found_class_name] = class_info_by_class_hash[class_hash]
        return class_info_by_class_name

    def have_python_range_r(self, cursor, requested_level, level=0):
        """Return True if there is a python::range among the children of cursor.

        Args:
            cursor (Cursor): current cursor position
            requested_level (int): requested nesting level
            level (int, optional): current nesting level. Defaults to 0.

        Returns:
            bool: True if there is a python::range among the children of cursor
        """
        level += 1
        for child in cursor.get_children():
            if level == requested_level and child.kind == CursorKind.CALL_EXPR and child.spelling == "range":
                return True
            if self.have_python_range_r(child, requested_level, level):
                return True
        return False

    def find_func_name_r(self, cursor, def_cursor, func_names, func_name_to_hash, def_init_nodes, level=0):
        """Find the Python function name connected to this cursor.

        Args:
            cursor (Cursor): current cursor position
            def_cursor (Cursor): node corresponding to the function "def"
            func_names (list): function names that need fixing
            func_name_to_hash (dict): dict relating function names to the
                "def" node hash
            def_init_nodes (dict): dict relating node hash to a FunctionDef instance
            level (int, optional): current nesting level. Defaults to 0.
        """
        log_path = self.cpp_path_noext + ".log"
        is_staticmethod = (def_cursor.spelling == "staticmethod" and level == 1)
        level += 1
        for child in cursor.get_children():
            if child.kind == CursorKind.STRING_LITERAL:
                func_name = self.extract_quoted_content(child.spelling)
                if (func_name is not None and func_name not in ("__init__", "__exit__", "__enter__")
                        and not (func_name == "__iter__" and self.have_python_range_r(def_cursor, level))):
                    try:
                        func_name_idx = func_names.index(func_name)
                        if def_cursor.hash in def_init_nodes:
                            with open(log_path, "a") as hnd:
                                print(f"1) find_func_name_r def_cursor.hash {def_cursor.hash} level {level} func_name {func_name} kind {def_cursor.kind} tokens {[t.spelling for t in def_cursor.get_tokens()]}", file=hnd)
                                hnd.flush()
                        assert def_cursor.hash not in def_init_nodes
                        func_names.pop(func_name_idx)
                        func_name_to_hash[func_name] = def_cursor.hash
                        def_init_nodes[def_cursor.hash] = FunctionDef(def_cursor, func_name, is_staticmethod, level)
                        with open(log_path, "a") as hnd:
                            print(f"2) find_func_name_r def_cursor.hash {def_cursor.hash} level {level} func_name {func_name} kind {def_cursor.kind} tokens {[t.spelling for t in def_cursor.get_tokens()]}", file=hnd)
                            hnd.flush()
                    except ValueError:
                        hash_for_func_name = func_name_to_hash.get(func_name, None)
                        if hash_for_func_name is not None and hash_for_func_name != def_cursor.hash:
                            prev_function_def = def_init_nodes.get(hash_for_func_name, None)
                            if prev_function_def is not None:
                                if prev_function_def.is_staticmethod and not is_staticmethod:
                                    def_init_nodes[hash_for_func_name] = FunctionDef(def_cursor, func_name, True, level)
                                elif not prev_function_def.is_staticmethod and is_staticmethod:
                                    def_init_nodes[hash_for_func_name] = FunctionDef(prev_function_def.def_cursor, func_name, True, prev_function_def.level)
                                elif (not (prev_function_def.is_staticmethod ^ is_staticmethod)
                                      and prev_function_def.func_name == func_name
                                      and def_cursor.kind != CursorKind.MEMBER_REF_EXPR
                                      and def_cursor.hash not in def_init_nodes
                                      and ([t.spelling for t in def_cursor.get_tokens()].count("def")
                                           != [t.spelling for t in prev_function_def.def_cursor.get_tokens()].count("def"))
                                      and self.is_last_def(func_name, list(def_cursor.get_tokens()))):
                                    with open(log_path, "a") as hnd:
                                        print(f"3) find_func_name_r def_cursor.hash {def_cursor.hash} level {level} func_name {func_name} kind {def_cursor.kind} tokens {[t.spelling for t in def_cursor.get_tokens()]}", file=hnd)
                                        hnd.flush()
                                    def_init_nodes[def_cursor.hash] = FunctionDef(def_cursor, func_name, is_staticmethod, prev_function_def.level)
                elif func_name == "__init__":
                    def_init_nodes[cursor.hash] = FunctionDef(cursor, "__init__", False, level)
            self.find_func_name_r(child, def_cursor, func_names, func_name_to_hash, def_init_nodes, level)

    def find_cpp_func_r(self, cursor, requested_level, func_name, param_count=-1, level=0):
        """Find the C++ function corresponding to this cursor.

        Args:
            cursor (Cursor): current cursor position
            requested_level (int): requested nesting level
            func_name (str): function name (passed through to recursive calls)
            param_count (int, optional): parameter count for this function. Defaults to -1.
            level (int, optional): current nesting level. Defaults to 0.

        Returns:
            tuple: (param_count, func_name) tuple. func_name can be prefixed with the
            class name if the function is a class method.
        """
        log_path = self.cpp_path_noext + ".log"
        res = None
        level += 1
        for child in cursor.get_children():
            if level == requested_level - 1 and "CAST" in str(child.kind):
                param_count = 0
            if level == requested_level:
                if child.kind == CursorKind.DECL_REF_EXPR and child.spelling != "def":
                    res = child
                elif child.kind == CursorKind.PARM_DECL:
                    assert param_count != -1
                    param_count += 1
                elif child.kind in (CursorKind.UNARY_OPERATOR, CursorKind.UNEXPOSED_EXPR):
                    for child2 in child.get_children():
                        if child2.kind == CursorKind.DECL_REF_EXPR:
                            res = child2
                            break
                elif child.kind == CursorKind.CALL_EXPR and child.spelling == "make_constructor":
                    for child2 in child.get_children():
                        if child2.kind in (CursorKind.UNARY_OPERATOR, CursorKind.UNEXPOSED_EXPR) and not child2.spelling:
                            for child3 in child2.get_children():
                                if child3.kind == CursorKind.DECL_REF_EXPR:
                                    res = child3
                                    break
                        if res is not None:
                            break
            if res is None:
                res = self.find_cpp_func_r(child, requested_level, func_name, param_count, level)
        if res is not None and not isinstance(res, tuple):
            decl_ref = res
            res = None
            for child in decl_ref.get_children():
                if child.kind == CursorKind.TEMPLATE_REF and decl_ref.spelling:
                    template_ref = child.spelling.split("::")[-1]
                    with open(log_path, "a") as hnd:
                        print(f"1) find_cpp_func_r template_ref {template_ref}", file=hnd)
                        hnd.flush()
                    template_ref = self.type_ref_dict.get(template_ref, template_ref)
                    with open(log_path, "a") as hnd:
                        print(f"2) find_cpp_func_r template_ref {template_ref}", file=hnd)
                        hnd.flush()
                    res = template_ref + "::" + decl_ref.spelling
                    break
                elif child.kind == CursorKind.TYPE_REF and decl_ref.spelling:
                    type_ref = child.spelling.split("::")[-1]
                    with open(log_path, "a") as hnd:
                        print(f"3) find_cpp_func_r type_ref {type_ref}", file=hnd)
                        hnd.flush()
                    type_ref = self.type_ref_dict.get(type_ref, type_ref)
                    with open(log_path, "a") as hnd:
                        print(f"4) find_cpp_func_r type_ref {type_ref}", file=hnd)
                        hnd.flush()
                    res = type_ref + "::" + decl_ref.spelling
                    break
                elif child.kind == CursorKind.OVERLOADED_DECL_REF and not decl_ref.spelling and child.spelling:
                    decl_ref = child
            if res is None and decl_ref.spelling:
                res = decl_ref.spelling
            res = (param_count, res)
        return res

    def extract_base_class_name(self, cpp_class_name):
        """Extract the base class name, if present.

        Args:
            cpp_class_name (str): C++ class name

        Returns:
            str: base class name, if present, otherwise input class name
        """
        m = self.EXTRACT_BASE_CLASS_NAME_REGEX.match(cpp_class_name)
        if m:
            cpp_class_name = m.group(1)
        return cpp_class_name.split("::")[-1]
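    # Worked example (illustrative, on plain strings; the exact spelling
    # clang gives a CXX_BASE_SPECIFIER may differ). Template arguments are
    # dropped and only the last namespace component is kept:
    #   CppFile().extract_base_class_name("RDKit::Base<T, U>")  -> 'Base'
    #   CppFile().extract_base_class_name("RDKit::ROMol")       -> 'ROMol'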

    def find_cpp_class_r(self, cursor, cpp_class_name, func_name):
        """Find the declaration node of the C++ class cpp_class_name
        that declares the func_name method.

        If the method cannot be found, the base class name is returned
        instead, so that a new recursive search on the base class can
        be carried out.

        Args:
            cursor (Cursor): current cursor position
            cpp_class_name (str): C++ class name
            func_name (str): C++ method name

        Returns:
            Cursor|str: class declaration node containing the method, or
            base class name if the method could not be found (None if
            neither was found)
        """
        res = None
        base_cpp_class_name = None
        for child in cursor.get_children():
            if (child.kind in (CursorKind.CLASS_DECL, CursorKind.CLASS_TEMPLATE, CursorKind.STRUCT_DECL)
                    and child.spelling == cpp_class_name):
                for child2 in child.get_children():
                    if child2.spelling.split("<")[0] == func_name and (
                        (child2.kind in (CursorKind.CXX_METHOD, CursorKind.FUNCTION_TEMPLATE)
                         or (func_name == cpp_class_name and child2.kind == CursorKind.CONSTRUCTOR))
                    ):
                        res = child
                        break
                    elif child2.kind == CursorKind.CXX_BASE_SPECIFIER:
                        base_cpp_class_name = self.extract_base_class_name(child2.spelling)
                if res is not None:
                    break
            res = self.find_cpp_class_r(child, cpp_class_name, func_name)
            if res is not None:
                break
        if res is None and base_cpp_class_name is not None:
            return base_cpp_class_name
        return res

    @staticmethod
    def have_param(param_list, param):
        """If param is part of param_list, pop it from param_list and return True.

        Args:
            param_list (list[str]): list of parameters
            param (str): parameter

        Returns:
            bool: True if param is part of param_list, False if not
        """
        res = param in param_list
        if res:
            param_list.pop(param_list.index(param))
        return res

    def num_matching_parameters(self, expected_params, params):
        """Count the matching parameter types between params
        (list of individual parameter typenames) and expected_params
        (concatenated string of expected parameter typenames).

        Namespace qualifiers are stripped before tokens are compared,
        and each expected token can be matched at most once.

        Args:
            expected_params (str): concatenated string of expected parameter typenames
            params (list[str]): list of individual parameter typenames

        Returns:
            tuple[int, int]: number of matching params and the negated number of
            non-matching params, so that comparing tuples favors more matches
            first and fewer mismatches second
        """
        expected_params_tok = [p.split("::")[-1] for p in expected_params.split()]
        params_tok = [p.split("::")[-1] for p in " ".join(params).split()]
        num_matched_params = [self.have_param(expected_params_tok, p) for p in params_tok].count(True)
        num_non_matched_params = len(params_tok) - num_matched_params
        return num_matched_params, -num_non_matched_params
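    # Worked example (illustrative): with expected_params tokenized as
    # ["ROMol", "unsigned", "int"] after namespace stripping,
    #   m = CppFile().num_matching_parameters
    #   m("RDKit::ROMol unsigned int", ["ROMol", "int"])     -> (2, 0)
    #   m("RDKit::ROMol unsigned int", ["ROMol", "double"])  -> (1, -1)
    # and (2, 0) > (1, -1), so the first overload would be preferred.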

    def find_cpp_func_params(self, cursor, is_staticmethod, cpp_class_name, func_name,
                             expected_cpp_params, expected_param_count):
        """Find parameter names of a C++ method.

        First we try to find an overload with exactly the expected number
        of parameters. If that fails, we accept an overload with more
        parameters than expected, on the assumption that some parameters
        are optional. If no overload is found, placeholder names are returned.

        Args:
            cursor (Cursor): current cursor position
            is_staticmethod (bool): True if the Python function is a static method
            cpp_class_name (str): C++ class name
            func_name (str): C++ method name
            expected_cpp_params (str): expected parameter string based
                on the Python function signature. This is a cumulative, concatenated
                string with no spaces which is used when there are multiple overloads
                with the same number of parameters to try and pick the C++ function whose
                parameter types best fit the Python signature.
            expected_param_count (int): expected parameter count based
                on the Python function signature

        Returns:
            list[str]: list of parameter names
        """
        self.params = None
        assigned_overloads = None
        if cpp_class_name == func_name:
            key = f"{cpp_class_name}::{cpp_class_name}"
            assigned_overloads = self.assigned_overloads.get(key, [])
            if not assigned_overloads:
                self.assigned_overloads[key] = assigned_overloads
        self.assigned_overloads_for_func = assigned_overloads
        for accept_params_no_type in (False, True):
            self.accept_params_no_type = accept_params_no_type
            for cmp_func in (int.__eq__, int.__gt__):
                self.find_cpp_func_params_r(cursor, cpp_class_name, func_name,
                                            expected_cpp_params, expected_param_count, cmp_func)
                if self.params is not None:
                    break
            if self.params is not None:
                if assigned_overloads is not None and not self.has_template_type(self.params):
                    assigned_overloads.append(self.get_params_hash(self.params))
                break
        if self.params is None:
            params = [f"arg{i + 1}" for i in range(expected_param_count)]
            if not is_staticmethod:
                params.insert(0, "self")
            return params
        return [p for p, _ in self.params]
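    # Worked example (illustrative): when no C++ overload matches and
    # expected_param_count == 2, find_cpp_func_params() falls back to
    # placeholder names: ["self", "arg1", "arg2"] for a regular method and
    # ["arg1", "arg2"] when is_staticmethod is True.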

    def has_template_type(self, params):
        """Find if any parameter in params is of template type.

        Args:
            params (list[tuple[str, str]]): list of (name, type) tuples

        Returns:
            bool: True if params contain parameters of template type
            (i.e., T, optionally followed by a number)
        """
        return any(self.IS_TEMPLATE_TYPE.match(t) for _, t in params)

    @staticmethod
    def get_params_hash(params):
        """Get a hashable key from function parameters.

        Args:
            params (list[tuple[str, str]]): list of function parameters
                as (parameter name, parameter type) tuples

        Returns:
            tuple: a sorted tuple of the parameters that can be used as a hashable key
        """
        return tuple(sorted(params))
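    # Worked example (illustrative): sorting makes the key independent of
    # parameter order, and IS_TEMPLATE_TYPE only accepts bare template names:
    #   CppFile.get_params_hash([("y", "int"), ("x", "double")])
    #       -> (('x', 'double'), ('y', 'int'))
    #   CppFile().has_template_type([("v", "T1")])  -> True
    #   CppFile().has_template_type([("v", "Tp")])  -> False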
def find_cpp_func_params_r(self, cursor, cpp_class_name, func_name,
expected_cpp_params, expected_param_count, cmp_func):
"""Find parameter names of a C++ method (recursive).
Args:
cursor (Cursor): current cursor position
cpp_class_name (str): C++ class name
func_name (str): C++ method name
expected_cpp_params (str): expected parameter types based on the
Python function signature, as a concatenated string with no spaces.
When multiple overloads have the same number of parameters, it is used
to pick the C++ function whose parameter types best fit the Python
signature.
expected_param_count (int): expected parameter count based
on the Python function signature
cmp_func (function): the comparator to use between the expected number
of parameters and the best-fitting found number of parameters
"""
accepted_kinds = [CursorKind.FUNCTION_DECL,
CursorKind.FUNCTION_TEMPLATE]
if cpp_class_name is not None:
if func_name != cpp_class_name:
accepted_kinds.append(CursorKind.CXX_METHOD)
else:
accepted_kinds.append(CursorKind.CONSTRUCTOR)
for child in cursor.get_children():
if child.kind in accepted_kinds and child.spelling.split("<")[0] == func_name:
params = [(child2.spelling, " ".join(child3.spelling for child3 in child2.get_children()
if child3.kind in (CursorKind.TEMPLATE_REF, CursorKind.TYPE_REF)))
for child2 in child.get_children() if child2.kind == CursorKind.PARM_DECL]
# some C++ headers declare only the parameter type without a name;
# in that case replace "" with a dummy parameter name, since
# python::args("") is not acceptable
params = [(p or f"arg{i + 1}", t) for i, (p, t) in enumerate(params)]
params_hash = self.get_params_hash(params)
if self.assigned_overloads_for_func is not None and params_hash in self.assigned_overloads_for_func:
continue
if ((expected_param_count == -1 or cmp_func(len(params), expected_param_count))
and (not expected_cpp_params or (self.accept_params_no_type and self.params is None)
or (self.params is not None and
self.num_matching_parameters(expected_cpp_params, [t for _, t in params])
> self.num_matching_parameters(expected_cpp_params, [t for _, t in self.params])))):
if expected_param_count != -1:
params = params[:expected_param_count]
self.params = params
else:
self.find_cpp_func_params_r(child, cpp_class_name, func_name, expected_cpp_params, expected_param_count, cmp_func)
def find_def_init_nodes_in_class_r(self, cursor, func_names, func_name_to_hash, def_init_nodes):
"""Find nodes corresponding to Python constructors and methods for a class.
Args:
cursor (Cursor): current cursor position
func_names (list): function names that need fixing
func_name_to_hash (dict): dict relating function names to the
"def" node hash
def_init_nodes (dict): dict relating node hash to a FunctionDef instance
"""
if cursor.kind in (CursorKind.CALL_EXPR, CursorKind.TEMPLATE_REF, CursorKind.MEMBER_REF_EXPR):
if cursor.spelling == "init":
if cursor.hash not in def_init_nodes:
def_init_nodes[cursor.hash] = FunctionDef(cursor, "__init__", False, 0)
# templated python::class_ may have no "def", so we accept empty spelling
elif not cursor.spelling or cursor.spelling == "def" or cursor.spelling == "staticmethod":
self.find_func_name_r(cursor, cursor, func_names, func_name_to_hash, def_init_nodes)
for child in cursor.get_children():
self.find_def_init_nodes_in_class_r(child, func_names, func_name_to_hash, def_init_nodes)
def is_class_hash_among_node_children_r(self, class_hash, node):
"""Return True if class_hash is found among the children of node.
Args:
class_hash (int): class hash
node (Cursor): cursor from which to start the search
Returns:
bool: True if class_hash is found among the children of node.
"""
if node.hash == class_hash:
return True
for child in node.get_children():
if self.is_class_hash_among_node_children_r(class_hash, child):
return True
return False
def find_def_init_nodes(self, class_info_by_class_name, arg1_func_byclass_dict):
"""Find Python constructors and methods.
Args:
class_info_by_class_name (dict): dict relating class name
to a ClassInfo instance
arg1_func_byclass_dict (dict): dict relating class name to methods
Returns:
dict: dict relating class name to a (def_init_nodes, func_names) tuple,
where def_init_nodes is an iterable of constructors and methods,
and func_names a list of function names that need fixing but could
not be associated to any methods (currently unused)
"""
res = {}
for class_name, class_info in class_info_by_class_name.items():
def_init_nodes = {}
func_name_to_hash = {}
func_names = arg1_func_byclass_dict[class_name]
for i, node in enumerate(class_info.parents):
if i and not self.is_class_hash_among_node_children_r(class_info.hash, node):
break
self.find_def_init_nodes_in_class_r(node, func_names, func_name_to_hash, def_init_nodes)
res[class_name] = (def_init_nodes.values(), func_names)
return res
def find_python_args(self, tokens):
"""Insert the "self" arg into existing python::args.
Args:
tokens (iterable): iterable of Token objects
Returns:
list(tuple)|None: list of tuples with source line number, source column number
and string to be inserted on that line at that column position
"""
for i, t in enumerate(tokens):
if (t.spelling == "python"
and len(tokens[i:]) > 4
and tokens[i+1].spelling == "::"
and tokens[i+2].spelling == "args"
and tokens[i+3].spelling == "("):
concat_tokens = "".join(t.spelling for t in tokens)
# if there are multiple python::args keywords or default parameters,
# treat them as if they were python::arg
potential_non_self_token = tokens[i+4]
arg_name = self.extract_quoted_content(potential_non_self_token.spelling)
if arg_name != self.SELF_LITERAL and (concat_tokens.count("python::args") > 1 or "=" in concat_tokens):
return self.find_python_arg(tokens, "args")
if arg_name is not None:
if arg_name == self.SELF_LITERAL:
return []
source_loc = potential_non_self_token.extent.start
source_line = source_loc.line
source_col = source_loc.column
return [(source_line, source_col, f"\"{self.SELF_LITERAL}\", ")]
return None
def find_python_arg(self, tokens, arg_keyword="arg"):
"""Insert the "self" arg into existing python::arg.
Args:
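tokens (iterable): iterable of Token objects
arg_keyword (str, optional): Boost.Python keyword to look for,
i.e. "arg" or "args". Defaults to "arg".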
Returns:
list(tuple)|None: list of tuples with source line number, source column number
and string to be inserted on that line at that column position
"""
need_additional_bracket = False
open_bracket = ""
for i, t in enumerate(tokens):
if (t.spelling == "python"
and len(tokens[i:]) > 4
and tokens[i+1].spelling == "::"
and tokens[i+2].spelling == arg_keyword
and tokens[i+3].spelling == "("):
bracket_count = 0
j = i
while j:
j -= 1
if tokens[j].spelling == "(":
bracket_count += 1
elif bracket_count:
break
assert bracket_count
if bracket_count == 1:
need_additional_bracket = True
open_bracket = "("
j = i + 4
potential_non_self_token = tokens[j]
arg_name = self.extract_quoted_content(potential_non_self_token.spelling)
if arg_name is not None:
if arg_name == self.SELF_LITERAL:
return []
source_loc = t.extent.start
source_line = source_loc.line
source_col = source_loc.column
res = [(source_line, source_col, f"{open_bracket}python::{arg_keyword}(\"{self.SELF_LITERAL}\"), ")]
if need_additional_bracket:
found = False
j += 1
while tokens[j+1:] and not found:
j += 1
found = tokens[j].spelling in (",", ")")
assert found
source_loc = tokens[j].extent.start
source_line = source_loc.line
source_col = source_loc.column
res += [(source_line, source_col, ")")]
return res
return None
def find_no_arg(self, is_init, tokens, is_staticmethod, cpp_func_name, expected_param_count, cursor, class_info):
"""Insert the appropriate python::args where needed based on the C++
method parameter names.
Args:
is_init (bool): True if the method is a constructor
tokens (iterable): iterable of Token objects
is_staticmethod (bool): True if the method is static
cpp_func_name (str): name of the C++ function
expected_param_count (int): expected number of parameters
based on the Python function signature
cursor (Cursor): current cursor position
class_info (ClassInfo): ClassInfo instance
Raises:
IndexError: in case there are unexpected inconsistencies
(should never happen)
Returns:
list(tuple)|None: list of tuples with source line number, source column number
and string to be inserted on that line at that column position
"""
log_path = self.cpp_path_noext + ".log"
bracket_count = 0
init_args = ""
expected_cpp_params = None
for i, t in enumerate(tokens):
num_downstream_tokens = len(tokens[i:])
if is_init:
open_bracket_count = t.spelling.count("<")
closed_bracket_count = t.spelling.count(">")
if open_bracket_count or bracket_count:
init_args += t.spelling + " "
bracket_count += (open_bracket_count - closed_bracket_count)
if bracket_count == 0:
if init_args:
m = self.EXTRACT_INIT_ARGS.match(init_args)
if not m or "python::optional" in init_args:
init_args = ""
is_init = False
else:
init_args = m.group(1).replace("<", "").strip()
if init_args:
cpp_func_name = f"{class_info.cpp_class_name}::{class_info.cpp_class_name}"
expected_param_count = 1 + init_args.count(",")
expected_cpp_params = init_args
init_args = ""
is_def = (t.spelling == "def")
if (num_downstream_tokens > 2 and (is_init or is_def)
and tokens[i+1].spelling == "("):
need_comma = (tokens[i+2].spelling != ")")
is_make_constructor = "make_constructor" in (t.spelling for t in tokens)
python_args = "python::args("
need_self = not is_staticmethod and not is_make_constructor
if need_self:
python_args += f"\"{self.SELF_LITERAL}\"" + init_args
if cpp_func_name is not None and expected_param_count is not None and cursor is not None:
cpp_func_name_tokens = cpp_func_name.split("::")
cpp_class_name = cpp_func_name_tokens[-2] if len(cpp_func_name_tokens) > 1 else None
func_name = cpp_func_name_tokens[-1]
with open(log_path, "a") as hnd:
print(f"1) find_no_arg cpp_func_name {cpp_func_name} cpp_class_name {cpp_class_name} func_name {func_name} expected_param_count {expected_param_count} is_staticmethod {is_staticmethod} tokens {[t.spelling for t in tokens]}", file=hnd)
hnd.flush()
rename_first_param = need_self
if cpp_class_name is not None:
while 1:
res = self.find_cpp_class_r(cursor, cpp_class_name, func_name)
with open(log_path, "a") as hnd:
print(f"2) find_no_arg res {res}", file=hnd)
hnd.flush()
if not isinstance(res, str):
break
cpp_class_name = res
if res is not None:
rename_first_param = False
cursor = res
params = self.find_cpp_func_params(cursor, is_staticmethod, cpp_class_name, func_name, expected_cpp_params, expected_param_count)
if rename_first_param:
if not params:
raise IndexError(f"Expected at least one parameter on {func_name}, found none")
params[0] = self.SELF_LITERAL
with open(log_path, "a") as hnd:
print(f"3) find_no_arg params {params}", file=hnd)
hnd.flush()
if params is not None:
params = ", ".join(f"\"{p}\"" for p in params if p != self.SELF_LITERAL)
if params:
if need_self:
python_args += ", "
python_args += params
python_args += ")"
if is_init:
token_idx = i + 2
last_seen_idx = token_idx
if need_comma:
python_args += ", "
else:
token_idx = i + 4
bracket_count = 0
last_seen_idx = None
while token_idx < len(tokens):
s = tokens[token_idx].spelling
if (s == "."
and token_idx + 1 < len(tokens)
and tokens[token_idx + 1].spelling in ("def", "def_pickle", "staticmethod")):
break
if s and s[0] in ("(", "<"):
incr = s.count(s[0])
bracket_count += incr
last_seen_idx = None
elif s and s[0] in (")", ">") and bracket_count:
incr = s.count(s[0])
assert bracket_count >= incr
bracket_count -= incr
elif last_seen_idx is None and not bracket_count and s in (",", ")"):
last_seen_idx = token_idx
token_idx += 1
if last_seen_idx is None:
raise IndexError(f"Failed to find end of definitions; tokens[i+4]: {[t.spelling for t in tokens[i+4:]]}")
if need_comma:
python_args = ", " + python_args
potential_non_self_token = tokens[last_seen_idx]
source_loc = potential_non_self_token.extent.start
source_line = source_loc.line
source_col = source_loc.column
return [(source_line, source_col, python_args)]
return None
def find_func_def(self, func_name, tokens):
"""Find the tokens corresponding to the Python def
for func_name.
Args:
func_name (str): Python function name
tokens (iterable): iterable of Token objects
Returns:
iterable: iterable of Token objects
"""
for i, t in reversed(list(enumerate(tokens))):
if (t.spelling == "def"
and i + 2 < len(tokens)
and tokens[i+1].spelling == "("
and tokens[i+2].spelling == f"\"{func_name}\""):
return tokens[i:]
return None
def is_last_def(self, func_name, tokens):
"""Return True if the last "def" in tokens corresponds to func_name.
Args:
func_name (str): Python function name
tokens (iterable): iterable of Token objects
Returns:
bool: True if the last "def" in tokens corresponds to func_name, False if not
"""
for i, t in reversed(list(enumerate(tokens))):
if t.spelling == "def" and i + 2 < len(tokens) and tokens[i+1].spelling == "(":
return (tokens[i+2].spelling == f"\"{func_name}\"")
return False
def get_insertion(self, is_init, tokens, is_staticmethod=False, cpp_func_name=None, param_count=None, tu_cursor=None, class_info=None):
"""Get the insertion string to fix a Python function signature.
Args:
is_init (bool): True if we are dealing with a constructor
tokens (iterable): iterable of Token objects
is_staticmethod (bool, optional): True if this is a static method. Defaults to False.
cpp_func_name (str, optional): C++ function name. Defaults to None.
param_count (int, optional): expected number of parameters. Defaults to None.
tu_cursor (Cursor, optional): translation unit cursor. Defaults to None.
class_info (ClassInfo, optional): ClassInfo instance. Defaults to None.
Returns:
list(tuple)|None: list of tuples with source line number, source column number
and string to be inserted on that line at that column position
"""
insertion = self.find_python_args(tokens)
log_path = self.cpp_path_noext + ".log"
with open(log_path, "a") as hnd:
print(f"1) get_insertion insertion {insertion}", file=hnd)
hnd.flush()
if insertion is not None:
return insertion if not is_staticmethod else None
insertion = self.find_python_arg(tokens)
with open(log_path, "a") as hnd:
print(f"2) get_insertion insertion {insertion}", file=hnd)
hnd.flush()
if insertion is not None:
return insertion if not is_staticmethod else None
insertion = self.find_no_arg(is_init, tokens, is_staticmethod, cpp_func_name, param_count, tu_cursor, class_info)
with open(log_path, "a") as hnd:
print(f"3) get_insertion insertion {insertion}", file=hnd)
hnd.flush()
return insertion
def apply_insertions(self, insertions_by_line):
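"""Apply insertions to the C++ file and reformat the result with clang-format.
Insertions on each line are applied from the rightmost column leftwards, so
that earlier column positions stay valid. The modified source is written to a
DOCMOD_SUFFIX copy; if clang-format succeeds, its output overwrites the
original C++ file and the copy is removed.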
Args:
insertions_by_line (dict): dictionary of insertions keyed by line number (1-based)
Returns:
str: errors/warnings reported by clang-format
"""
res = ""
if not insertions_by_line:
return res
cpp_docmod_path = self.cpp_path_noext + FixSignatures.DOCMOD_SUFFIX + self.cpp_path_ext
with open(self.cpp_path, "r") as cpp_in:
with open(cpp_docmod_path, "w") as cpp_out:
for line_num, line in enumerate(cpp_in):
insertions = insertions_by_line.get(line_num + 1, None)
if insertions is not None:
insertions_at_line = sorted(insertions, reverse=True)
for col_num, insertion in insertions_at_line:
col_num -= 1
line = line[:col_num] + insertion + line[col_num:]
cpp_out.write(line)
with tempfile.NamedTemporaryFile() as clang_format_out:
proc = subprocess.run([CLANG_FORMAT_EXE, cpp_docmod_path], stdout=clang_format_out.file, stderr=subprocess.PIPE)
if proc.returncode:
res += f"Failed to run {CLANG_FORMAT_EXE} on {cpp_docmod_path}\n"
else:
shutil.copyfile(clang_format_out.name, self.cpp_path)
os.remove(cpp_docmod_path)
if proc.stderr:
msg = proc.stderr.decode("utf-8")
res += f"{CLANG_FORMAT_EXE} reported the following on stderr:\n{msg}\n"
return res
@staticmethod
def add_insertion(insertion, insertions_by_line):
"""Add insertion to the insertions_by_line dict.
Args:
insertion (list[tuple]|None): list of (line, col, text insertion) tuples
insertions_by_line (dict): dictionary of insertions keyed by line number (1-based)
"""
if insertion is not None:
for ins in insertion:
line, col, text = ins
insertions_at_line = insertions_by_line.get(line, set())
insertions_at_line.add((col, text))
insertions_by_line[line] = insertions_at_line
def parse_ast(self, arg1_func_byclass_dict):
"""Parse AST file generated by clang++.
Args:
arg1_func_byclass_dict (dict): dict relating class names to methods
that need fixing. Free functions are also included under the class name
FixSignatures.NO_CLASS_KEY
"""
self.assigned_overloads = {}
try:
translation_unit = TranslationUnit.from_ast_file(self.ast_path)
out_path = self.cpp_path_noext + ".out"
with open(out_path, "w") as hnd:
self.recurse_ast_cursor(translation_unit.cursor, hnd)
log_path = self.cpp_path_noext + ".log"
with open(log_path, "w") as hnd:
pass
class_info_by_class_hash = self.find_nodes(translation_unit.cursor)
class_method_node_hashes = set(itertools.chain.from_iterable([node.hash for node in class_info.parents] for class_info in class_info_by_class_hash.values()))
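# initialize so that the free-function loop below is a no-op when there
# are no free functions to fix (otherwise non_class_defs is undefined)
non_class_defs = {}
arg1_non_class_func_names = arg1_func_byclass_dict.get(FixSignatures.NO_CLASS_KEY, None)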
if arg1_non_class_func_names is not None:
non_class_defs = self.find_non_class_defs(translation_unit.cursor, class_method_node_hashes, arg1_non_class_func_names)
with open(log_path, "a") as hnd:
print(f"1) parse_ast cpp_path {self.cpp_path} class_info_by_class_hash {tuple(class_info_by_class_hash.keys())}", file=hnd)
hnd.flush()
class_info_by_class_name = self.prune_nodes(translation_unit.cursor, class_info_by_class_hash, arg1_func_byclass_dict)
with open(log_path, "a") as hnd:
print(f"2) parse_ast cpp_path {self.cpp_path} class_info_by_class_name {[(class_name, class_info.hash) for class_name, class_info in class_info_by_class_name.items()]}", file=hnd)
hnd.flush()
def_init_nodes_and_unassigned_func_names_by_class_name = self.find_def_init_nodes(
class_info_by_class_name, arg1_func_byclass_dict)
insertions = {}
with open(log_path, "a") as hnd:
print(f"3) parse_ast cpp_path {self.cpp_path} def_init_nodes_and_unassigned_func_names_by_class_name {def_init_nodes_and_unassigned_func_names_by_class_name}", file=hnd)
hnd.flush()
with open(log_path, "a") as hnd:
for class_name, (def_init_nodes,_unassigned_func_names) in def_init_nodes_and_unassigned_func_names_by_class_name.items():
class_info = class_info_by_class_name[class_name]
for function_def in def_init_nodes:
tokens = list(function_def.def_cursor.get_tokens())
insertion = None
boost_python_entity = "".join(t.spelling for t in tokens[:3])
is_init = False
if boost_python_entity == "python::init":
print(f"4) parse_ast cpp_path {self.cpp_path} class_name {class_name} cpp_class_name {class_info.cpp_class_name} func_name {function_def.func_name} python::init tokens {[t.spelling for t in tokens]}", file=hnd)
hnd.flush()
is_init = True
insertion = self.get_insertion(is_init, tokens[3:], tu_cursor=translation_unit.cursor, class_info=class_info)
elif boost_python_entity == "python::class_":
res = self.find_cpp_func_r(function_def.def_cursor, function_def.level, function_def.func_name)
param_count = None
cpp_func_name = None
if res is not None:
param_count, cpp_func_name = res
print(f"5) parse_ast cpp_path {self.cpp_path} cpp_func_name {cpp_func_name} func_name {function_def.func_name} param_count {param_count} tokens {[t.spelling for t in tokens[3:]]}", file=hnd)
hnd.flush()
tokens_from_func_def = self.find_func_def(function_def.func_name, tokens[3:])
if tokens_from_func_def is not None:
print(f"6) parse_ast cpp_path {self.cpp_path} python::class_ tokens_from_func_def {[t.spelling for t in tokens_from_func_def]}", file=hnd)
hnd.flush()
insertion = self.get_insertion(is_init, tokens_from_func_def, function_def.is_staticmethod, cpp_func_name, param_count, translation_unit.cursor, class_info=class_info)
self.add_insertion(insertion, insertions)
print(f"8) parse_ast cpp_path {self.cpp_path} {insertions}", file=hnd)
hnd.flush()
for func_name, def_nodes in non_class_defs.items():
for def_node in def_nodes:
tokens = list(def_node.get_tokens())
insertion = None
is_init = False
requested_level = 2
is_staticmethod = True
hnd.flush()
boost_python_entity = "".join(t.spelling for t in tokens[:3])
if boost_python_entity == "python::def":
res = self.find_cpp_func_r(def_node, requested_level, func_name)
param_count = None
cpp_func_name = None
if res is not None:
param_count, cpp_func_name = res
print(f"9) parse_ast cpp_path {self.cpp_path} cpp_func_name {cpp_func_name} func_name {func_name} param_count {param_count}", file=hnd)
hnd.flush()
tokens_from_func_def = self.find_func_def(func_name, tokens[2:])
if tokens_from_func_def is not None:
print(f"10) parse_ast cpp_path {self.cpp_path} python::def func_name {func_name} tokens_from_func_def {[t.spelling for t in tokens_from_func_def]}", file=hnd)
hnd.flush()
insertion = self.get_insertion(is_init, tokens_from_func_def, is_staticmethod, cpp_func_name, param_count, translation_unit.cursor)
self.add_insertion(insertion, insertions)
self.ast_error += self.apply_insertions(insertions)
except Exception as e:
tb = traceback.format_exc()
self.ast_error += f"{self.cpp_path}: Failed to parse AST\n{str(e)}\n{str(tb)}\n"
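# Illustrative sketch only (hypothetical helper, not called anywhere in this
# module): apply_insertions() above applies the insertions collected for a
# line in descending column order, so that inserting text at a higher column
# never shifts the 1-based column numbers recorded for insertions further to
# the left on the same line. This assumes (col, text) pairs with 1-based
# columns, as stored by add_insertion().
def example_apply_line_insertions(line, insertions):
    """Return line with every (col, text) insertion applied (1-based cols)."""
    # Right-to-left: smaller columns are still valid after each insertion.
    for col_num, text in sorted(insertions, reverse=True):
        idx = col_num - 1
        line = line[:idx] + text + line[idx:]
    return line
# Example: example_apply_line_insertions("abc", {(1, "<"), (4, ">")})
# returns "<abc>", whereas applying the same insertions left-to-right
# would give "<ab>c".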
class ClangWorkerData(DictLike):
"""Data class passed to Worker as JSON string."""
def __init__(self, clang_flags=None):
self.clang_flags = clang_flags
self.arg1_func_byclass_dict = {}
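import re

# Illustrative sketch only (hypothetical names, not used by the workflow):
# FixSignatures.get_rdk_build_flags() scans RDGeneral/RDConfig.h with
# DEFINE_RDK_REGEX and turns every value-less "#define RDK_*" line into a
# -D flag for clang++. The same pattern is reproduced here; lines that
# assign a value to the macro (e.g. "#define RDK_X 1") do not match it.
EXAMPLE_DEFINE_RDK_REGEX = re.compile(r"^\s*#define\s+(RDK_\S+)\s*$")


def example_rdk_define_flags(config_lines):
    """Return sorted, space-separated -D flags for value-less RDK_* defines."""
    definitions = set()
    for line in config_lines:
        m = EXAMPLE_DEFINE_RDK_REGEX.match(line)
        if m:
            definitions.add(m.group(1))
    return " ".join(f"-D{d}" for d in sorted(definitions))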
class FixSignatures:
"""Main FixSignatures class.
Raises:
ValueError
"""
concurrency = max(1, multiprocessing.cpu_count() - 2)
log_level = "INFO"
cpp_source_path = os.environ.get("RDBASE", os.getcwd())
rdkit_stubs_path = os.path.join(os.getcwd(), f"{RDKIT_MODULE_NAME}-stubs")
clean = False
include_path = os.path.join(os.environ.get("CONDA_PREFIX", os.getcwd()), "include")
python_include_path = None
rdkit_include_path = None
clang_flags = "-emit-ast"
user_clang_flags = ""
CLANG_WORKER_SCRIPT = os.path.join(os.path.dirname(__file__), "clang_worker.py")
DOCORIG_SUFFIX = "_RDKDOCORIG"
DOCMOD_SUFFIX = "_RDKDOCMOD"
DEFINE_RDK_REGEX = re.compile(r"^\s*#define\s+(RDK_\S+)\s*$")
INCLUDE_PATH_BY_DEFINITION = {
"RDK_BUILD_COORDGEN_SUPPORT": ["External", "CoordGen"],
"RDK_USE_URF": ["External", "RingFamilies", "RingDecomposerLib", "src", "RingDecomposerLib"],
"RDK_HAS_EIGEN3": os.environ.get("EIGEN3_INCLUDE_DIR", include_path),
"RDK_BUILD_CAIRO_SUPPORT": [include_path, "cairo"],
}
NO_CLASS_KEY = "-"
def __init__(self, args=None):
"""Constructor. Runs the whole workflow.
Args:
args (Namespace, optional): ArgParser args
"""
if args:
for k, v in args._get_kwargs():
setattr(self, k, v)
self.logger = logging.getLogger(self.__class__.__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter('[%(asctime)s %(levelname)s] %(message)s')
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(getattr(logging, self.log_level))
if self.python_include_path is None:
python_include_path = sorted(glob.glob(os.path.join(self.include_path, "python*")))
if python_include_path:
self.python_include_path = python_include_path[0]
if self.rdkit_include_path is None:
print(f"Failed to find RDKit include path. Please set {self.__class__.__name__}.rdkit_include_path")
sys.exit(1)
self.init_cpp_file_dict()
self.init_clang_worker_data()
msg = self.generate_ast_files()
if msg:
self.logger.warning(msg)
def init_cpp_file_dict(self):
"""Initialize dict relating each C++ file to patch to a CppFile object.
Also creates backups of original C++ files with DOCORIG_SUFFIX extension
if they do not exist yet. If they exist, it overwrites the current C++
file with its backup. The --clean command line switch forces re-generating
fresh backups from the current C++ file.
Raises:
ValueError
"""
cpp_source_path = Path(self.cpp_source_path)
self.cpp_file_dict = dict()
paths = [p for p in sorted(cpp_source_path.rglob("*.cpp"))
if self.DOCMOD_SUFFIX not in str(p) and self.DOCORIG_SUFFIX not in str(p)
and "Demos" not in str(p)]
cpp_paths_to_be_modified = []
for p in paths:
cpp_path = os.path.abspath(str(p))
with open(cpp_path) as hnd:
if any("python::class_" in line or "python::def" in line for line in hnd):
cpp_paths_to_be_modified.append(cpp_path)
cpp_path_noext, cpp_path_ext = os.path.splitext(cpp_path)
if cpp_path_noext in self.cpp_file_dict:
raise ValueError("There are multiple C++ files defining python::class_ "
f"or python::def sharing the same basename {cpp_path_noext} "
"but with different extensions; this should never happen")
self.cpp_file_dict[cpp_path_noext] = CppFile(cpp_path)
for cpp_path in cpp_paths_to_be_modified:
cpp_path_noext, cpp_path_ext = os.path.splitext(cpp_path)
cpp_docorig_path = cpp_path_noext + self.DOCORIG_SUFFIX + cpp_path_ext
have_docorig = os.path.exists(cpp_docorig_path)
if self.clean and have_docorig:
os.remove(cpp_docorig_path)
have_docorig = False
if not have_docorig:
shutil.copyfile(cpp_path, cpp_docorig_path)
else:
shutil.copyfile(cpp_docorig_path, cpp_path)
def get_rdk_build_flags(self):
"""Generate command line clang++ flags to build RDKit
based on the contents of RDGeneral/RDConfig.h.
Returns:
str: command line clang++ flags
"""
rdconfig_h = os.path.join(self.rdkit_include_path, RDKIT_MODULE_NAME, "RDGeneral", "RDConfig.h")
definitions = set()
includes = set()
with open(rdconfig_h, "r") as hnd:
for line in hnd:
m = self.DEFINE_RDK_REGEX.match(line)
if not m:
continue
macro_name = m.group(1)
definitions.add(macro_name)
include_path = self.INCLUDE_PATH_BY_DEFINITION.get(macro_name, None)
if include_path is None:
continue
if not isinstance(include_path, str):
include_path = os.path.join(self.cpp_source_path, *include_path)
includes.add(include_path)
return (" ".join(f"-D{d}" for d in sorted(definitions)) +
" " + " ".join(f"-I{i}" for i in sorted(includes)))
@staticmethod
def get_include_flags_from_include_path(include_path):
"""Generate command line clang++ include flags from include_path.
Args:
include_path (str): include path
Returns:
str: command line clang++ include flags
"""
include_path = include_path or ""
res = " ".join(f"-I{i}" for i in include_path.split(os.pathsep))
if res:
res = " " + res
return res
def add_func_to_dict_if_arg1(self, func, class_name=None):
"""Add the passed function to the dict of functions to be fixed if:
1. it is a callable
2. it has a docstring
3. its docstring contains arg1
Args:
func (function): candidate function
class_name (str, optional): class name if the function is a class method
"""
arg1_func_byclass_dict = self.clang_worker_data.arg1_func_byclass_dict
if not isinstance(func.__doc__, str) or not callable(func) or "arg1" not in func.__doc__:
return
if class_name is None:
class_name = self.NO_CLASS_KEY
arg1_func_name_set = set(arg1_func_byclass_dict.get(class_name, []))
arg1_func_name_set.add(func.__name__)
arg1_func_byclass_dict[class_name] = sorted(arg1_func_name_set)
def init_clang_worker_data(self):
"""Initialize ClangWorkerData."""
rdkit_stubs_path = Path(self.rdkit_stubs_path)
python_include_path = f"-I{self.python_include_path}" if self.python_include_path else ""
rdkit_code = os.path.join(self.cpp_source_path, "Code")
rdkit_external = os.path.join(self.cpp_source_path, "External")
user_clang_flags = " " + self.user_clang_flags if self.user_clang_flags else ""
rdk_build_defs = self.get_rdk_build_flags()
qt_include_dirs = self.get_include_flags_from_include_path(os.environ.get("QT_INCLUDE_DIRS", None))
rdkit_external_path = Path(rdkit_external)
avalon_include_dir = os.path.abspath(str(max(rdkit_external_path.rglob("AvalonTools/ava-formake-AvalonToolkit_*/src/main/C/include"))))
clang_flags = (
f"-I{self.include_path} {python_include_path} -I{rdkit_code} "
f"-I{rdkit_external} -I{avalon_include_dir} -I. -I..{qt_include_dirs}"
f" {rdk_build_defs} {self.clang_flags}{user_clang_flags}"
).strip().split()
self.clang_worker_data = ClangWorkerData(clang_flags)
for p in sorted(rdkit_stubs_path.rglob("*.pyi")):
if str(p.stem) == "__init__":
p = p.parent
pyi_module_path = os.path.splitext(str(p.relative_to(rdkit_stubs_path)).replace("/", "."))[0]
if pyi_module_path == ".":
pyi_module_path = RDKIT_MODULE_NAME
else:
pyi_module_path = RDKIT_MODULE_NAME + "." + pyi_module_path
try:
pyi_module = importlib.import_module(pyi_module_path)
except Exception as e:
self.logger.warning(f"ERROR: {str(e)}")
continue
for entry_name in dir(pyi_module):
entry = getattr(pyi_module, entry_name, None)
if entry is None:
continue
if entry.__class__.__name__ != "class":
self.add_func_to_dict_if_arg1(entry)
else:
for method_name in dir(entry):
method = getattr(entry, method_name)
self.add_func_to_dict_if_arg1(method, entry_name)
def clang_worker_thread(self, worker_idx):
"""Function run by each Worker thread.
Args:
worker_idx (int): Worker index (0-based)
"""
proc = None
res = WorkerResult(worker_idx)
while 1:
e = ""
try:
cpp_file_class = self.queue.get_nowait()
self.logger.info(f"Processing {cpp_file_class.cpp_path}")
except queue.Empty:
self.logger.debug("Queue empty")
break
if proc is None:
cmd = [sys.executable, self.CLANG_WORKER_SCRIPT, self.clang_worker_data.to_json()]
try:
self.logger.debug(f"Attempting to run '{cmd}'")
proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except Exception as exc:
# keep the message: Python 3 unbinds the "as" target at the end of the except block
e = str(exc)
if proc is None:
res.proc_error += f"Worker {worker_idx}: failed to start process.\n"
if e:
res.proc_error += f"Exception was: {str(e)}\n"
self.queue.task_done()
continue
stdout_data = None
try:
proc.stdin.write((cpp_file_class.to_json() + "\n").encode("utf-8"))
proc.stdin.flush()
stdout_data = proc.stdout.readline()
if stdout_data:
stdout_data = stdout_data.decode("utf-8").strip()
cpp_file_class = CppFile.from_json(stdout_data)
self.cpp_file_dict[cpp_file_class.cpp_path_noext] = cpp_file_class
res.processed_cpp_files.add(cpp_file_class.cpp_path_noext)
else:
proc.poll()
except Exception as exc:
res.proc_error += (f"Exception while attempting to send {cpp_file_class.cpp_path} for processing "
f"to {self.CLANG_WORKER_SCRIPT}:\n{str(exc)}\n")
self.queue.task_done()
if not stdout_data and proc.returncode:
res.proc_error += f"{self.CLANG_WORKER_SCRIPT} daemon not running."
stderr_data = proc.stderr.read()
if stderr_data:
res.proc_error += "\nError was:\n" + stderr_data.decode("utf-8")
proc = None
break
if proc is not None:
try:
proc.stdin.write("\n".encode("utf-8"))
proc.stdin.flush()
stdout_data = proc.stdout.readline()
if stdout_data:
stdout_data = stdout_data.decode("utf-8").strip()
if stdout_data:
res.proc_error += f"Worker {worker_idx}: expected empty message, found:\n{stdout_data}"
else:
res.proc_error += f"Worker {worker_idx}: failed to receive empty message.\n"
except Exception:
pass
self.thread_results[worker_idx] = res
def generate_ast_files(self):
"""Generate clang++ AST files.
Returns:
str: errors generated by clang++
"""
msg = ""
self.queue = queue.Queue()
cpp_class_files = list(self.cpp_file_dict.values())
# Uncomment the following to troubleshoot specific file(s)
# cpp_class_files = [f for f in cpp_class_files if os.path.basename(f.cpp_path) == "Atom.cpp"]
n_files = len(cpp_class_files)
self.logger.debug(f"Number of files: {n_files}")
n_workers = min(self.concurrency, n_files)
self.thread_results = {}
clang_tasks = [Thread(target=self.clang_worker_thread, args=(i,), daemon=True) for i in range(n_workers)]
for cpp_class_file in cpp_class_files:
self.queue.put_nowait(cpp_class_file)
for clang_task in clang_tasks:
clang_task.start()
have_alive_thread = True
to_go_prev = n_files + 1
while have_alive_thread:
have_alive_thread = False
for clang_task in clang_tasks:
clang_task.join(timeout=0.1)
have_alive_thread |= clang_task.is_alive()
to_go_curr = [cpp_class.ast_error for cpp_class in cpp_class_files].count(None)
if to_go_curr < to_go_prev:
to_go_prev = to_go_curr
for thread_idx, res in self.thread_results.items():
if res.proc_error:
msg += f"Process error in thread {thread_idx}:\n{res.proc_error}\n"
for cpp_file_no_ext in sorted(res.processed_cpp_files):
ast_error = self.cpp_file_dict[cpp_file_no_ext].ast_error
if ast_error:
msg += f"clang AST errors in thread {thread_idx}:\n{ast_error}\n"
return msg
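import queue
import threading

# Illustrative sketch only (hypothetical helper, not used by FixSignatures):
# generate_ast_files() and clang_worker_thread() follow a "pre-filled queue,
# drain until empty" pattern: all work items are queued before the threads
# start, each thread pulls items with get_nowait() until queue.Empty is
# raised, and each thread stores its results under its own index. Here
# process() stands in for sending a CppFile to the clang worker process.
def example_drain_queue_with_threads(items, n_workers, process):
    """Process items with up to n_workers threads; return {worker_idx: results}."""
    items = list(items)
    work = queue.Queue()
    for item in items:
        work.put_nowait(item)
    results = {}

    def worker(worker_idx):
        done = []
        while True:
            try:
                item = work.get_nowait()
            except queue.Empty:
                break  # no work left: let the thread exit
            try:
                done.append(process(item))
            finally:
                work.task_done()
        results[worker_idx] = done  # each thread writes only its own key

    n_threads = max(1, min(n_workers, len(items)))
    threads = [threading.Thread(target=worker, args=(i,), daemon=True)
               for i in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results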