File: lookups.py

package info (click to toggle)
python-django-netfields 1.4.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 444 kB
  • sloc: python: 2,067; sh: 7; makefile: 4; sql: 2
file content (213 lines) | stat: -rw-r--r-- 5,999 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import warnings
from django.core.exceptions import FieldError
from django.db.models import Lookup, Transform, IntegerField
from django.db.models.lookups import (
    EndsWith,
    IEndsWith,
    StartsWith,
    IStartsWith,
    Regex,
    IRegex,
)
import ipaddress
from netfields.fields import InetAddressField, CidrAddressField


class InvalidLookup(Lookup):
    """
    Lookup that refuses to compile.

    Reproduces the FieldError Django 1.9 raised for unsupported lookups;
    the error message names this lookup's ``lookup_name``.
    """

    def as_sql(self, qn, connection):
        message = "Unsupported lookup '%s'" % self.lookup_name
        raise FieldError(message)


class InvalidSearchLookup(Lookup):
    """
    ``search`` lookup that refuses to compile.

    Reproduces the Django 1.9 behaviour of raising NotImplementedError
    when a full-text ``search`` lookup is requested.
    """

    lookup_name = "search"

    def as_sql(self, qn, connection):
        reason = "Full-text search is not implemented for this database backend"
        raise NotImplementedError(reason)


class NetFieldDecoratorMixin(object):
    """
    Mixin that wraps the compiled left-hand side of a lookup for net fields.

    The wrapping depends on the field type:

    * ``InetAddressField`` expressions are wrapped in ``HOST(...)``.
    * ``CidrAddressField`` expressions are wrapped in ``TEXT(...)``.
    * Anything else is left as compiled.

    The parameters are always returned as a list.
    """

    def process_lhs(self, qn, connection, lhs=None):
        target = lhs or self.lhs
        sql, params = qn.compile(target)
        # Use the expression's ``source`` field when it has one, otherwise
        # fall back to its ``output_field``.
        if hasattr(target, "source"):
            field = target.source
        else:
            field = target.output_field
        if isinstance(field, InetAddressField):
            sql = "HOST(%s)" % sql
        elif isinstance(field, CidrAddressField):
            sql = "TEXT(%s)" % sql
        return sql, list(params)


class EndsWith(NetFieldDecoratorMixin, EndsWith):
    """Django's ``EndsWith`` with the net-field LHS wrapping applied."""


class IEndsWith(NetFieldDecoratorMixin, IEndsWith):
    """Django's ``IEndsWith`` with the net-field LHS wrapping applied."""


class StartsWith(NetFieldDecoratorMixin, StartsWith):
    """Django's ``StartsWith`` with the net-field LHS wrapping applied."""


class IStartsWith(NetFieldDecoratorMixin, IStartsWith):
    """Django's ``IStartsWith`` with the net-field LHS wrapping applied."""


class Regex(NetFieldDecoratorMixin, Regex):
    """Django's ``Regex`` with the net-field LHS wrapping applied."""


class IRegex(NetFieldDecoratorMixin, IRegex):
    """Django's ``IRegex`` with the net-field LHS wrapping applied."""


class NetworkLookup(object):
    """Mixin that prepares a right-hand side holding a network value."""

    def get_prep_lookup(self):
        """
        Return the right-hand side in a form ready for the database.

        Handling depends on the type of ``self.rhs``:

        * Query expressions (objects with ``resolve_expression``) are
          returned unchanged.
        * IPv4/IPv6 network objects are rendered with ``str``.
        * Any other value is parsed with ``ipaddress.ip_network`` and then
          rendered with ``str``.

        Raises:
            ValueError: if ``ip_network`` rejects the value, e.g. malformed
                input or a network with host bits set (strict mode).
        """
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        # Check against the public concrete classes rather than the private
        # ``ipaddress._BaseNetwork`` base, which is not part of the API.
        if isinstance(self.rhs, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
            return str(self.rhs)
        return str(ipaddress.ip_network(self.rhs))


class AddressLookup(object):
    """Mixin that prepares a right-hand side holding an address value."""

    def get_prep_lookup(self):
        """
        Return the right-hand side in a form ready for the database.

        Handling depends on the type of ``self.rhs``:

        * Query expressions (objects with ``resolve_expression``) are
          returned unchanged.
        * IPv4/IPv6 address objects, including interface objects (which
          subclass them), are rendered with ``str``.
        * Any other value is parsed with ``ipaddress.ip_interface`` and then
          rendered with ``str``.

        Raises:
            ValueError: if ``ip_interface`` rejects the value.
        """
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        # Check against the public concrete classes rather than the private
        # ``ipaddress._BaseAddress`` base, which is not part of the API.
        if isinstance(self.rhs, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
            return str(self.rhs)
        return str(ipaddress.ip_interface(self.rhs))


class NetContains(AddressLookup, Lookup):
    """
    ``net_contains`` lookup, compiled to ``lhs >> rhs``.

    ``>>`` is PostgreSQL's network "contains" operator. The right-hand side
    is prepared by ``AddressLookup.get_prep_lookup``.
    """

    lookup_name = "net_contains"

    def as_sql(self, qn, connection):
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        # Compiled params may be a list on one side and a tuple on the other
        # (e.g. RawSQL or other expressions); ``list + tuple`` raises
        # TypeError, so unpack both into a fresh list.
        params = [*lhs_params, *rhs_params]
        return "%s >> %s" % (lhs, rhs), params


class NetContained(NetworkLookup, Lookup):
    """
    ``net_contained`` lookup, compiled to ``lhs << rhs``.

    ``<<`` is PostgreSQL's network "is contained by" operator. The
    right-hand side is prepared by ``NetworkLookup.get_prep_lookup``.
    """

    lookup_name = "net_contained"

    def as_sql(self, qn, connection):
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        # Compiled params may be a list on one side and a tuple on the other;
        # ``list + tuple`` raises TypeError, so unpack both into a new list.
        params = [*lhs_params, *rhs_params]
        return "%s << %s" % (lhs, rhs), params


class NetContainsOrEquals(AddressLookup, Lookup):
    """
    ``net_contains_or_equals`` lookup, compiled to ``lhs >>= rhs``.

    ``>>=`` is PostgreSQL's network "contains or equals" operator. The
    right-hand side is prepared by ``AddressLookup.get_prep_lookup``.
    """

    lookup_name = "net_contains_or_equals"

    def as_sql(self, qn, connection):
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        # Compiled params may be a list on one side and a tuple on the other;
        # ``list + tuple`` raises TypeError, so unpack both into a new list.
        params = [*lhs_params, *rhs_params]
        return "%s >>= %s" % (lhs, rhs), params


class NetContainedOrEqual(NetworkLookup, Lookup):
    """
    ``net_contained_or_equal`` lookup, compiled to ``lhs <<= rhs``.

    ``<<=`` is PostgreSQL's network "is contained by or equals" operator.
    The right-hand side is prepared by ``NetworkLookup.get_prep_lookup``.
    """

    lookup_name = "net_contained_or_equal"

    def as_sql(self, qn, connection):
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        # Compiled params may be a list on one side and a tuple on the other;
        # ``list + tuple`` raises TypeError, so unpack both into a new list.
        params = [*lhs_params, *rhs_params]
        return "%s <<= %s" % (lhs, rhs), params


class NetOverlaps(NetworkLookup, Lookup):
    """
    ``net_overlaps`` lookup, compiled to ``lhs && rhs``.

    ``&&`` is PostgreSQL's network "contains or is contained by" operator.
    The right-hand side is prepared by ``NetworkLookup.get_prep_lookup``.
    """

    lookup_name = "net_overlaps"

    def as_sql(self, qn, connection):
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        # Compiled params may be a list on one side and a tuple on the other;
        # ``list + tuple`` raises TypeError, so unpack both into a new list.
        params = [*lhs_params, *rhs_params]
        return "%s && %s" % (lhs, rhs), params


class HostMatches(AddressLookup, Lookup):
    """
    ``host`` lookup, compiled to ``HOST(lhs) = HOST(rhs)``.

    Both sides are passed through ``HOST(...)`` before comparison. The
    right-hand side is prepared by ``AddressLookup.get_prep_lookup``.
    """

    lookup_name = "host"

    def as_sql(self, qn, connection):
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        # Compiled params may be a list on one side and a tuple on the other;
        # ``list + tuple`` raises TypeError, so unpack both into a new list.
        params = [*lhs_params, *rhs_params]
        return "HOST(%s) = HOST(%s)" % (lhs, rhs), params


class Family(Transform):
    """
    ``family`` transform: wraps the compiled expression in ``family(...)``.

    Its ``output_field`` is an ``IntegerField``, so further lookups chained
    after it treat the value as an integer.
    """

    lookup_name = "family"

    @property
    def output_field(self):
        return IntegerField()

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        wrapped = "family(%s)" % inner_sql
        return wrapped, inner_params


class _PrefixlenMixin(object):
    """
    Shared implementation of the ``min_prefixlen`` / ``max_prefixlen`` lookups.

    Subclasses set ``format_string`` to a two-placeholder SQL template that
    receives ``MASKLEN(lhs)`` and the right-hand side. Every compilation
    emits a DeprecationWarning pointing users at ``prefixlen__gte`` /
    ``prefixlen__lte``.
    """

    # Two-``%s`` SQL template filled with (lhs, rhs); subclasses must set it.
    format_string = None

    def as_sql(self, qn, connection):
        warnings.warn(
            "min_prefixlen and max_prefixlen will be deprecated in the future; "
            "use prefixlen__gte and prefixlen__lte respectively",
            DeprecationWarning,
        )
        assert (
            self.format_string is not None
        ), "Prefixlen lookups must specify a format_string"
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        # Compiled params may be lists or tuples; unpack both so mixing them
        # cannot raise TypeError.
        params = [*lhs_params, *rhs_params]
        return self.format_string % (lhs, rhs), params

    def process_lhs(self, qn, connection, lhs=None):
        """Compile the left-hand side, wrap it in ``MASKLEN(...)``, list params."""
        lhs = lhs or self.lhs
        lhs_string, lhs_params = qn.compile(lhs)
        lhs_string = "MASKLEN(%s)" % lhs_string
        # Normalise to a list, as NetFieldDecoratorMixin.process_lhs does.
        return lhs_string, list(lhs_params)

    def get_prep_lookup(self):
        """
        Return the right-hand side prepared for the database.

        Query expressions (objects with ``resolve_expression``) are returned
        unchanged instead of being forced through ``int``. Any other value is
        converted with ``int`` (so invalid input raises) and returned as a
        decimal string.
        """
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        return str(int(self.rhs))


class MaxPrefixlen(_PrefixlenMixin, Lookup):
    """``max_prefixlen`` lookup: ``MASKLEN(lhs) <= rhs`` (deprecated)."""

    format_string = "%s <= %s"
    lookup_name = "max_prefixlen"


class MinPrefixlen(_PrefixlenMixin, Lookup):
    """``min_prefixlen`` lookup: ``MASKLEN(lhs) >= rhs`` (deprecated)."""

    format_string = "%s >= %s"
    lookup_name = "min_prefixlen"


class Prefixlen(Transform):
    """
    ``prefixlen`` transform: wraps the compiled expression in ``masklen(...)``.

    Its ``output_field`` is an ``IntegerField``, so lookups such as
    ``prefixlen__gte`` compare integers.
    """

    lookup_name = "prefixlen"

    @property
    def output_field(self):
        return IntegerField()

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        wrapped = "masklen(%s)" % inner_sql
        return wrapped, inner_params