1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
|
#!/usr/bin/python3
import argparse
import logging
import re
from rpaths import Path
import sys
def to_bytes(s):
    """Return *s* as ``bytes``.

    A ``bytes`` argument is handed back untouched; anything else is encoded
    with the ASCII codec via its ``encode`` method.
    """
    return s if isinstance(s, bytes) else s.encode('ascii')
# The ``__future__`` feature names (as bytes) known to this script.
FUTURES = {
    b'absolute_import',
    b'division',
    b'print_function',
    b'unicode_literals',
}
def make_import_statement(imports):
    """Build a ``from __future__ import ...`` statement as bytes.

    The features are emitted in the order given, separated by ``", "``.
    When adding a feature would push the current physical line past 79
    columns, the line is ended with a backslash continuation and the feature
    starts a new line.

    :param imports: iterable of feature names (bytes).
    :return: the statement, terminated by a single ``b'\\n'``.
    """
    imports = list(imports)
    last = len(imports) - 1
    parts = [b'from __future__ import ']
    # Width of the physical line currently being filled.  Measuring the whole
    # statement instead would force every feature after the first wrap onto
    # its own line.
    line_len = len(parts[0])
    for i, feature in enumerate(imports):
        separator = b', ' if i != 0 else b''
        # Keep room after a non-final feature for a following ", \".
        reserve = 4 if i != last else 0
        if line_len + len(separator) + len(feature) + reserve <= 79:
            parts.append(separator + feature)
            line_len += len(separator) + len(feature)
        else:
            parts.append(separator + b'\\\n' + feature)
            line_len = len(feature)
    parts.append(b'\n')
    return b''.join(parts)
# Matches a line that is empty or consists only of whitespace.
re_whitespace = re.compile(br'^\s*$')
# Matches a line whose first non-whitespace character is ``#`` (a comment).
re_comment = re.compile(br'^\s*#')
# Matches a ``from __future__ import ...`` line; group 1 captures the text
# that follows ``import`` (starting with the whitespace after the keyword).
re_import = re.compile(br'^\s*from\s+__future__\s+import(\s.+)$')
def _read_import_statement(lines, start):
    """Collect the logical statement that begins at ``lines[start]``.

    Follows backslash continuations and an open parenthesis across physical
    lines, dropping trailing comments and line terminators on the way.

    :return: ``(text, end)`` where ``text`` is the statement joined on a
        single line and ``end`` is the index just past its last physical
        line; ``text`` is ``None`` if the statement is still open at EOF.
    """
    pieces = []
    in_parens = False
    for index in range(start, len(lines)):
        # Drop any trailing comment together with the line terminator.
        code = lines[index].split(b'#', 1)[0].rstrip()
        continued = code.endswith(b'\\')
        if continued:
            code = code[:-1]
        pieces.append(code)
        if b'(' in code:
            in_parens = True
        if b')' in code:
            in_parens = False
        if not (continued or in_parens):
            return b' '.join(pieces), index + 1
    return None, len(lines)


def _skip_docstring(lines, index, quote):
    """Return the index of the line on which the docstring opened at
    ``lines[index]`` ends (``len(lines)`` if it is never closed)."""
    if lines[index].count(quote) == 1:
        index += 1
        while index < len(lines) and quote not in lines[index]:
            index += 1
    return index


def process_file(filename, enable):
    """Make sure the file imports every ``__future__`` feature in *enable*.

    If the file already has a ``from __future__ import`` statement, the
    missing features are merged into it (the statement is rewritten with the
    sorted union).  Otherwise a new statement is inserted before the first
    line of code, after leading comments, blank lines and the module
    docstring.  The file is only rewritten when something changed.

    :param filename: path object providing ``open(mode)``.
    :param enable: set of feature names (bytes) that must be imported.
    """
    logging.debug("Processing %s...", filename)
    with filename.open('rb') as fp:
        lines = fp.readlines()

    changed = False

    # Look for an import statement to check or fix
    done = False
    i = 0
    while i < len(lines):
        if re_import.match(lines[i]) is None:
            i += 1
            continue
        # The statement may span several physical lines; parse it as one.
        text, end = _read_import_statement(lines, i)
        f_import = re_import.match(text) if text is not None else None
        if f_import is None:
            i += 1
            continue
        names = f_import.group(1).strip(b' \t\r\n()').split(b',')
        # A trailing comma leaves an empty name behind; discard it.
        imports = set(name.strip() for name in names) - set([b''])
        missing = enable - imports
        if not missing:
            logging.debug("Found needed imports (line %d)", i + 1)
        else:
            # Feature names are bytes: join as bytes, then decode for the log.
            logging.debug("Enabling imports: %s (line %d)",
                          b', '.join(sorted(missing)).decode('ascii'),
                          i + 1)
            lines[i:end] = [make_import_statement(sorted(imports | enable))]
            changed = True
        done = True
        break

    # Didn't find a statement, must add one
    # Where? Before first non-comment and non-docstring line
    if not done:
        i = 0
        docstring_seen = False
        while i < len(lines):
            line = lines[i]
            if re_whitespace.match(line) or re_comment.match(line):
                pass
            else:
                quote = None
                if not docstring_seen:
                    # Lines are bytes, so the quote markers must be bytes too.
                    for candidate in (b'"""', b"'''"):
                        if candidate in line:
                            quote = candidate
                            break
                if quote is not None:
                    i = _skip_docstring(lines, i, quote)
                    docstring_seen = True
                else:
                    # Code here! Insert before
                    logging.debug("Inserting imports (line %d)", i + 1)
                    # Sorted so the generated statement is deterministic.
                    lines.insert(
                        i, make_import_statement(sorted(enable)) + b'\n')
                    changed = True
                    break
            i += 1

    if changed:
        with filename.open('wb') as fp:
            fp.writelines(lines)
def main():
    """Command-line entry point.

    Parses the arguments, configures logging from the ``-v`` count, checks
    the requested features against ``FUTURES`` and runs ``process_file`` on
    each given file, or on every ``*.py`` file below each given directory.
    Exits with status 1 when no feature was requested or one is unknown.
    """
    parser = argparse.ArgumentParser(
        description="Adds __future__ imports to Python files")
    parser.add_argument('-v', '--verbose', action='count', dest='verbosity',
                        default=1)
    parser.add_argument('-e', '--enable', action='append',
                        help="Future import to enable")
    parser.add_argument('file',
                        nargs=argparse.ONE_OR_MORE,
                        help="File or directory in which to replace")
    args = parser.parse_args()

    levels = [logging.CRITICAL, logging.WARNING, logging.INFO, logging.DEBUG]
    # Clamp so that extra -v flags select DEBUG instead of raising IndexError.
    verbosity = min(args.verbosity, len(levels) - 1)
    logging.basicConfig(level=levels[verbosity])

    if not args.enable:
        logging.critical("Nothing to do")
        sys.exit(1)

    enable = set(to_bytes(feature) for feature in args.enable)
    unrecognized = enable - FUTURES
    if unrecognized:
        # The names are bytes: join as bytes, then decode for the message.
        logging.critical(
            "Error: unknown futures %s"
            % b', '.join(sorted(unrecognized)).decode('ascii'))
        sys.exit(1)

    for target in args.file:
        target = Path(target)
        if target.is_file():
            if not target.name.endswith('.py'):
                logging.warning("File %s doesn't end with .py, processing "
                                "anyway..." % target)
            process_file(target, enable)
        elif target.is_dir():
            logging.info("Processing %s recursively..." % target)
            for filename in target.recursedir('*.py'):
                process_file(filename, enable)
        else:
            logging.warning("Skipping %s..." % target)
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()
|