File: matcher.py

package info (click to toggle)
python-url-matcher 0.6.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 372 kB
  • sloc: python: 627; makefile: 17
file content (208 lines) | stat: -rw-r--r-- 9,078 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""
The matcher module contains the URLMatcher class, together with the Patterns
and PatternsMatcher helpers it relies on.
"""

from __future__ import annotations

from collections.abc import Iterable, Iterator, Mapping
from dataclasses import dataclass, field
from itertools import chain
from typing import Any

from url_matcher.patterns import PatternMatcher, get_pattern_domain, hierarchical_str
from url_matcher.util import get_domain


@dataclass(init=False, frozen=True)
class Patterns:
    """Immutable description of which URLs a rule applies to.

    ``include`` patterns select URLs, ``exclude`` patterns reject them and
    ``priority`` is used when ordering rules (higher priorities sort first).
    """

    include: tuple[str, ...]
    exclude: tuple[str, ...]
    priority: int

    def __init__(self, include: list[str], exclude: list[str] | None = None, priority: int = 500):
        # Callers hand in lists on purpose: a list literal is hard to get wrong,
        # whereas a tuple silently degrades when its trailing comma is forgotten
        # (("element") is a plain string, ("element",) is a tuple).
        # Internally everything is stored as tuples, so the instance is truly
        # immutable on top of being declared frozen=True.
        # Frozen dataclasses reject normal attribute assignment, so values are
        # written through object.__setattr__, exactly like the dataclasses
        # module does itself:
        #     * https://github.com/python/cpython/blob/v3.10.2/Lib/dataclasses.py#L1117-L1120
        values = {
            "include": tuple(include),
            "exclude": tuple(exclude) if exclude else (),
            "priority": priority,
        }
        for attr_name, attr_value in values.items():
            object.__setattr__(self, attr_name, attr_value)

    def get_domains(self) -> list[str]:
        """Return the distinct, non-empty domains of the include patterns, in first-seen order."""
        unique: dict[str, None] = {}
        for raw in self.include:
            pattern_domain = get_pattern_domain(raw)
            if pattern_domain:
                unique.setdefault(pattern_domain, None)
        return list(unique)

    def get_includes_without_domain(self) -> list[str]:
        """Return the include patterns for which get_pattern_domain() yields None."""
        without = []
        for raw in self.include:
            if get_pattern_domain(raw) is None:
                without.append(raw)
        return without

    def all_includes_have_domain(self) -> bool:
        """Return true if all the include patterns have a domain"""
        missing = self.get_includes_without_domain()
        return len(missing) == 0

    def is_universal_pattern(self) -> bool:
        """Return true if there are no include patterns or they are empty. A universal pattern matches any domain"""
        return all(not raw for raw in self.include)

    def get_includes_for(self, domain: str) -> list[str]:
        """Return the include patterns whose domain equals *domain*, preserving their order."""
        selected = []
        for raw in self.include:
            if get_pattern_domain(raw) == domain:
                selected.append(raw)
        return selected


@dataclass
class PatternsMatcher:
    """Pairs a rule identifier with matcher objects built from its Patterns.

    The include/exclude matchers are derived in ``__post_init__`` and are not
    constructor arguments.
    """

    identifier: Any
    patterns: Patterns
    include_matchers: list[PatternMatcher] = field(init=False)
    exclude_matchers: list[PatternMatcher] = field(init=False)

    def __post_init__(self) -> None:
        # Build one PatternMatcher per raw pattern up front so match() only
        # has to run the prepared matchers.
        self.include_matchers = list(map(PatternMatcher, self.patterns.include))
        self.exclude_matchers = list(map(PatternMatcher, self.patterns.exclude))

    def match(self, url: str) -> bool:
        """Return True when *url* passes the include matchers and hits no exclude matcher.

        An empty include list accepts every URL; otherwise at least one include
        matcher must match.
        """
        includes = self.include_matchers
        if includes and not any(candidate.match(url) for candidate in includes):
            return False
        for rejecter in self.exclude_matchers:
            if rejecter.match(url):
                return False
        return True


class IncludePatternsWithoutDomainError(ValueError):
    """Error carrying the rule whose include patterns lack a domain.

    Attributes:
        id: identifier of the offending rule.
        patterns: the rejected Patterns object.
        wrong_patterns: the include patterns reported as lacking a domain.
    """

    def __init__(self, *args: Any, identifier: Any, patterns: Patterns, wrong_patterns: list[str]):
        self.id, self.patterns, self.wrong_patterns = identifier, patterns, wrong_patterns
        super().__init__(*args)


class URLMatcher:
    def __init__(self, data: Mapping[Any, Patterns] | Iterable[tuple[Any, Patterns]] | None = None):
        """
        A class that matches URLs against a list of patterns, returning
        the identifier of the rule that matched the URL.

        Example usage::

            matcher = URLMatcher()
            matcher.add_or_update(1, Patterns(include=["example.com/product"]))
            matcher.add_or_update(2, Patterns(include=["other.com"]))

            assert matcher.match("http://example.com/product/a_product.html") == 1
            assert matcher.match("http://other.com/a_different_page") == 2

        :param data: A map or a list of tuples with identifier, patterns pairs to
                     initialize the object from
        """
        # Sorted matchers per domain. The "" key holds the universal matchers
        # (rules whose include patterns are all empty) and mirrors
        # ``matchers_universal``.
        self.matchers_by_domain: dict[str, list[PatternsMatcher]] = {}
        # Universal matchers, tried after the domain-specific ones.
        self.matchers_universal: list[PatternsMatcher] = []
        # The Patterns currently registered for every identifier.
        self.patterns: dict[Any, Patterns] = {}

        if data:
            items = data.items() if isinstance(data, Mapping) else data
            for identifier, patterns in items:
                self.add_or_update(identifier, patterns)

    def add_or_update(self, identifier: Any, patterns: Patterns) -> None:
        """
        Register *patterns* under *identifier*, replacing any rule previously
        stored under the same identifier.

        :raises IncludePatternsWithoutDomainError: if some include pattern has no
            domain, unless every include pattern is empty (a universal rule).
        """
        if not patterns.all_includes_have_domain() and not patterns.is_universal_pattern():
            wrong_patterns = [p for p in patterns.get_includes_without_domain() if p]
            raise IncludePatternsWithoutDomainError(
                f"All include patterns must belong to a domain "
                f"but the patterns {wrong_patterns} doesn't. "
                f"For example, the include pattern '/product/*' "
                f"is invalid whereas the pattern 'example.com/product/*' isn't. "
                f"The only exception is the empty pattern which matches everything "
                f"and it is allowed. "
                f"identifier: {identifier}.",
                identifier=identifier,
                patterns=patterns,
                wrong_patterns=wrong_patterns,
            )
        # Build the matcher before mutating any state: should building it raise,
        # the previously registered rule stays intact instead of leaving
        # ``self.patterns`` out of sync with the matcher indexes.
        matcher = PatternsMatcher(identifier, patterns)
        if identifier in self.patterns:
            self.remove(identifier)
        self.patterns[identifier] = patterns
        for domain in patterns.get_domains():
            self._add_matcher(domain, matcher)
        if patterns.is_universal_pattern():
            self._add_matcher("", matcher)

    def remove(self, identifier: Any) -> None:
        """Unregister the rule stored under *identifier*; unknown identifiers are ignored."""
        patterns = self.patterns.pop(identifier, None)
        if patterns is None:
            return
        for domain in patterns.get_domains():
            self._del_matcher(domain, identifier)
        if patterns.is_universal_pattern():
            self._del_matcher("", identifier)

    def get(self, identifier: Any) -> Patterns | None:
        """Return the Patterns registered under *identifier*, or None if there are none."""
        return self.patterns.get(identifier)

    def match(self, url: str, *, include_universal: bool = True) -> Any | None:
        """Return the identifier of the first rule matching *url*, or None if no rule matches."""
        return next(self.match_all(url, include_universal=include_universal), None)

    def match_all(self, url: str, *, include_universal: bool = True) -> Iterator[Any]:
        """
        Yield the identifiers of every rule matching *url*, in matching order.

        Rules registered for the URL's domain come first (already sorted by
        ``_sort_domain``), followed by the universal rules when
        *include_universal* is true.
        """
        domain = get_domain(url)
        # The "" bucket of matchers_by_domain is the universal bucket, so it is
        # never consulted as a regular domain. Should get_domain() return "",
        # looking it up would yield universal rules twice, or even when
        # include_universal is False.
        matchers: Iterable[PatternsMatcher] = (self.matchers_by_domain.get(domain) or []) if domain else []
        if include_universal:
            matchers = chain(matchers, self.matchers_universal)
        for matcher in matchers:
            if matcher.match(url):
                yield matcher.identifier

    def match_universal(self) -> Iterator[Any]:
        """Return an iterator over the identifiers of the universal rules, in their current order."""
        return (m.identifier for m in self.matchers_universal)

    def _sort_domain(self, domain: str) -> None:
        """
        Sort all the rules within a domain so that the matching can be done in sequence:
        the first rule matching wins.

        A total ordering is defined; including the rule identifier in the
        sorting criteria guarantees it. The universal list is re-sorted as well.

        Sorting criteria:
          * Priority (descending)
          * Sorted list of includes for this domain (descending)
          * Rule identifier (descending)
        """

        def sort_key(matcher: PatternsMatcher) -> tuple[int, list[str], Any]:
            sorted_includes = sorted(map(hierarchical_str, matcher.patterns.get_includes_for(domain)))
            return (matcher.patterns.priority, sorted_includes, matcher.identifier)

        self.matchers_by_domain[domain].sort(key=sort_key, reverse=True)
        self.matchers_universal.sort(key=sort_key, reverse=True)

    @staticmethod
    def _remove_first(matchers: list[PatternsMatcher], identifier: Any) -> None:
        """Delete, in place, the first matcher in *matchers* whose identifier equals *identifier*, if any."""
        for idx, matcher in enumerate(matchers):
            if matcher.identifier == identifier:
                # Safe: we stop iterating right after mutating the list.
                del matchers[idx]
                return

    def _del_matcher(self, domain: str, identifier: Any) -> None:
        """Remove *identifier*'s matcher from *domain*'s bucket (dropping the bucket once empty) and from the universal list."""
        matchers = self.matchers_by_domain[domain]
        self._remove_first(matchers, identifier)
        if not matchers:
            del self.matchers_by_domain[domain]
        self._remove_first(self.matchers_universal, identifier)

    def _add_matcher(self, domain: str, matcher: PatternsMatcher) -> None:
        """Append *matcher* to *domain*'s bucket (and to the universal list for domain "") and re-sort."""
        # FIXME: This can be made much more efficient if we insert the data directly in order instead of resorting.
        # The bisect module could be used for this purpose.
        # I'm leaving it for the future as insertion time is not critical.
        self.matchers_by_domain.setdefault(domain, []).append(matcher)
        if domain == "":
            self.matchers_universal.append(matcher)
        self._sort_domain(domain)