1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
|
from astroid import MANAGER, scoped_nodes, nodes, inference_tip
import sys
from pylint_django import utils
# Field class names, grouped by the Python builtin that apply_type_shim
# attaches to them.  Only the bare class name is compared, so one entry
# covers both the model field and the form field of that name.

# Fields shimmed as ``str``.  ('SlugField' used to be listed twice; the
# duplicate had no effect on membership tests and has been dropped.)
_STR_FIELDS = (
    'CharField', 'SlugField', 'URLField', 'TextField', 'EmailField',
    'CommaSeparatedIntegerField', 'FilePathField', 'GenericIPAddressField',
    'IPAddressField', 'RegexField',
)

# Fields shimmed as ``int``.
_INT_FIELDS = (
    'IntegerField', 'SmallIntegerField', 'BigIntegerField',
    'PositiveIntegerField', 'PositiveSmallIntegerField',
)

# Fields shimmed as ``bool``.
_BOOL_FIELDS = ('BooleanField', 'NullBooleanField')
def is_model_field(cls):
    """Return True when *cls* is qualified under ``django.db.models.fields``.

    *cls* is any object exposing a ``qname()`` method that returns its
    dotted, fully qualified name as a string.
    """
    qualified_name = cls.qname()
    return qualified_name.startswith('django.db.models.fields')
def is_form_field(cls):
    """Return True when *cls* is qualified under ``django.forms.fields``.

    *cls* is any object exposing a ``qname()`` method that returns its
    dotted, fully qualified name as a string.
    """
    qualified_name = cls.qname()
    return qualified_name.startswith('django.forms.fields')
def is_model_or_form_field(cls):
    """Return True when *cls* is either a model field or a form field.

    The model-field check runs first and the form-field check is skipped
    as soon as one of them succeeds.
    """
    return any(check(cls) for check in (is_model_field, is_form_field))
def apply_type_shim(cls, context=None):
    """Yield *cls* together with the Python type its field values act like.

    The field is identified by ``cls.name`` alone.  Recognised names are
    mapped either to a builtin (``str``, ``int``, ``bool``, ``float``) or
    to a class looked up in another module (``Decimal``, the ``datetime``
    types, Django's ``QuerySet`` and ``File``).  Unrecognised names yield
    only *cls*.

    *context* is accepted but not used by this function.

    Returns an iterator whose first item is *cls*, followed by the nodes
    found by the lookup (if any).
    """
    field_name = cls.name

    # First see whether the field maps onto one of the builtins.
    builtin_groups = (
        (_STR_FIELDS, 'str'),
        (_INT_FIELDS, 'int'),
        (_BOOL_FIELDS, 'bool'),
        (('FloatField',), 'float'),
    )
    builtin_name = None
    for field_names, type_name in builtin_groups:
        if field_name in field_names:
            builtin_name = type_name
            break

    if builtin_name is not None:
        lookup_result = scoped_nodes.builtin_lookup(builtin_name)
    else:
        # On Python 3.5+ Decimal is looked up in ``_decimal`` rather than
        # ``decimal``; the original author recorded only that this works,
        # not why.
        decimal_module = '_decimal' if sys.version_info >= (3, 5) else 'decimal'
        # Field name -> (module to load, attribute to look up in it).
        module_targets = {
            'DecimalField': (decimal_module, 'Decimal'),
            'SplitDateTimeField': ('datetime', 'datetime'),
            'DateTimeField': ('datetime', 'datetime'),
            'TimeField': ('datetime', 'time'),
            'DateField': ('datetime', 'date'),
            'ManyToManyField': ('django.db.models.query', 'QuerySet'),
            'ImageField': ('django.core.files.base', 'File'),
            'FileField': ('django.core.files.base', 'File'),
        }
        target = module_targets.get(field_name)
        if target is None:
            # Not a field we know how to shim: leave inference untouched.
            return iter([cls])
        module_name, attr_name = target
        lookup_result = MANAGER.ast_from_module_name(module_name).lookup(attr_name)

    # The found nodes are the second element of the lookup result.
    #
    # XXX: on Python 3 one of those nodes can be an ImportFrom, which has
    # no qname() method; letting it through was observed to trip the
    # deprecated-method check in pylint's StdlibChecker and kill the
    # checker, so such nodes are dropped there.
    if utils.PY3:
        extra_nodes = [
            node for node in lookup_result[1]
            if not isinstance(node, nodes.ImportFrom)
        ]
    else:
        extra_nodes = list(lookup_result[1])

    return iter([cls] + extra_nodes)
def add_transforms(manager):
    """Register the field type shim on *manager*.

    ``apply_type_shim`` is installed as an inference tip for class nodes,
    restricted by ``is_model_or_form_field`` to Django model and form
    field classes.
    """
    # astroid renamed the class node type from ``Class`` to ``ClassDef``
    # (the same rename that introduced ``ImportFrom``, which this module
    # already relies on).  Prefer the current name and fall back to the
    # legacy one so older astroid releases keep working.
    class_node_type = getattr(nodes, 'ClassDef', None)
    if class_node_type is None:
        class_node_type = nodes.Class

    manager.register_transform(
        class_node_type,
        inference_tip(apply_type_shim),
        is_model_or_form_field,
    )
|