File: selectors.py

package info (click to toggle)
python-xmlschema 4.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 5,208 kB
  • sloc: python: 39,174; xml: 1,282; makefile: 36
file content (228 lines) | stat: -rw-r--r-- 8,280 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#
# Copyright (c), 2025, SISSA (International School for Advanced Studies).
# All rights reserved.
# This file is distributed under the terms of the MIT License.
# See the file 'LICENSE' in the root directory of the present
# distribution, or http://opensource.org/licenses/MIT.
#
# @author Davide Brunato <brunato@sissa.it>
#
import re
from collections import deque
from collections.abc import Callable, Iterable, Iterator
from functools import cached_property
from typing import cast, Optional, TYPE_CHECKING, Union
from xml.etree.ElementTree import Element

from elementpath import XPath2Parser, XPathToken, ElementNode, XPathContext

from xmlschema.aliases import ElementType, NsmapType
from xmlschema.exceptions import XMLSchemaTypeError

if TYPE_CHECKING:
    from xmlschema.resources import XMLResource

# Type of the keys used for the selectors cache: a (path, selector class) pair,
# optionally extended with the sorted (prefix, uri) items of a namespace mapping.
CacheKeyType = Union[
    tuple[str, type['ElementSelector']],
    tuple[Union[str, type['ElementSelector'], Iterable[tuple[str, str]], tuple[str, str]], ...]
]

# Module-level cache of already built selectors. It is filled by
# ElementSelector.cached_selector(), which empties it before adding a new
# entry once it holds more than 100 selectors.
_selectors_cache: dict[CacheKeyType, 'ElementSelector'] = {}

# Placeholder element on which ElementSelector.__init__() runs a trial
# selection right after parsing the path.
_dummy_element = Element('dummy')


# Matches a './' self step that can be dropped without changing the meaning of
# the path: the dot must not close a '..' parent step or a name, and the slash
# must not open a '//' descendant step.
_SELF_STEP_PATTERN = re.compile(r'(?<![\w.])\./(?!/)')


def _is_ncname_start(c: str) -> bool:
    """Returns `True` if the character can start an NCName (a letter or an underscore)."""
    return c.isalpha() or c == '_'


def is_ncname(s: str) -> bool:
    """
    Checks if a string is a non-colonized name (NCName).

    :param s: the string to check.
    :return: `True` if the string is not empty, doesn't contain colons, starts with \
    a letter or an underscore and continues only with NCName continuation characters.
    """
    if not s or ':' in s:
        return False
    return _is_ncname_start(s[0]) and all(is_ncname_continuation(c) for c in s[1:])


def is_ncname_continuation(c: str) -> bool:
    """
    Checks if a character can follow the first character of an NCName.

    :param c: a single character string.
    :return: `True` for alphanumeric characters, for hyphen, dot and underscore, \
    for the listed extender/combining characters and for the U+0300..U+036F range.
    """
    return (c.isalnum() or c in '-._\u00B7\u0387\u06DD\u06DE\u203F\u2040'
            or 0x300 <= ord(c) <= 0x36F)


def split_path(path: str, namespaces: Optional[NsmapType] = None,
               extended_names: bool = False) -> deque[str]:
    """
    Splits a path expression to a sequence of chunks that put in evidence path steps,
    predicates and other parts that can be useful for checking some properties of the
    provided path, like the path depth or if the path is composed only by path steps
    and wildcards.

    Before splitting, the path is normalized by removing spaces, tabs and redundant
    './' self steps. Parent steps ('../') and descendant steps ('.//') are preserved.

    :param path: the path expression to split.
    :param namespaces: an optional namespace mapping to use on prefixed names.
    :param extended_names: if `True` maps prefixed names to extended form. For \
    default only the default namespace is used, if defined and not empty.
    :return: a deque with the chunks of the normalized path, that can be joined \
    to obtain the normalized path expression.
    """
    start = end = 0  # boundaries of the chunk that is currently being scanned

    def flush() -> None:
        """Appends the pending chunk, if any, and restarts the scan from its end."""
        nonlocal start
        if start < end:
            chunks.append(path[start:end])
            start = end

    def advance(condition: Callable[[str], bool]) -> None:
        """
        Moves the end of the chunk forward by one character and then while the
        condition holds. Raises an IndexError when the end of the path is reached.
        """
        nonlocal end
        end += 1
        while condition(path[end]):
            end += 1

    # Path normalization. NOTE: blanks are removed everywhere in the expression,
    # including inside quoted literals.
    path = _SELF_STEP_PATTERN.sub('', path.replace(' ', '').replace('\t', ''))
    chunks: deque[str] = deque([''])  # add an empty element to avoid index errors
    default_namespace = None if not namespaces else namespaces.get('')

    while True:
        try:
            flush()
            if path[end] in '"\'':
                # A quoted literal: consume up to the matching closing quote.
                advance(lambda x: x != path[start])
                end += 1
            elif path[end] == '{':
                # A name in extended form: consume the braced URI and the local name.
                advance(lambda x: x != '}')
                end += 1
                if _is_ncname_start(path[end]):
                    advance(is_ncname_continuation)
            elif _is_ncname_start(path[end]):
                advance(is_ncname_continuation)
                if path[end] == ':':
                    prefix = path[start:end]
                    end += 1
                    if _is_ncname_start(path[end]):
                        advance(is_ncname_continuation)
                        if extended_names and namespaces and prefix in namespaces:
                            flush()
                            uri = namespaces[prefix]
                            chunks[-1] = f'{{{uri}}}{chunks[-1][len(prefix)+1:]}'

                elif default_namespace and not chunks[-1].endswith('@'):
                    # Unprefixed name not preceded by the attribute marker.
                    flush()
                    chunks[-1] = f'{{{default_namespace}}}{chunks[-1]}'
            elif path[end] == '/':
                advance(lambda x: x == '/')  # group consecutive slashes ('/', '//')
            else:
                end += 1

        except IndexError:
            # End of the path reached: a name that terminates the path is still
            # pending, because the scan was interrupted before its mapping.
            if start < len(path):
                flush()
                last = chunks[-1]
                prefix, colon, local_name = last.partition(':')
                if colon:
                    if extended_names and namespaces and prefix in namespaces \
                            and is_ncname(prefix) and is_ncname(local_name):
                        chunks[-1] = f'{{{namespaces[prefix]}}}{local_name}'
                elif default_namespace and is_ncname(last) and chunks[-2] != '@':
                    chunks[-1] = f'{{{default_namespace}}}{last}'

            chunks.popleft()  # drop the initial empty element
            return chunks


class ElementSelector:
    """
    A selector of ElementTree elements driven by an XPath expression. The
    construction fails with an error if the path cannot be parsed or if it
    is not compatible with the selector type.

    :param path: the XPath expression.
    :param namespaces: an optional namespace mapping.
    """

    path: str
    """The normalized XPath expression of the path provided by argument."""

    namespaces: Optional[dict[str, str]]
    """The namespaces mapping associated with the XPath expression path."""

    _parser: XPath2Parser
    _token: XPathToken

    @classmethod
    def cached_selector(cls, path: str, namespaces: Optional[NsmapType] = None) \
            -> 'ElementSelector':
        """
        Returns a selector for the provided arguments, taking it from the module
        cache when an instance built with the same path, class and namespaces
        already exists, otherwise building and caching a new one.
        """
        key: CacheKeyType
        if namespaces is None:
            key = (path, cls)
        else:
            key = (path, cls, *sorted(namespaces.items()))

        selector = _selectors_cache.get(key)
        if selector is None:
            # Keep the cache bounded: start from scratch when it grows too much.
            if len(_selectors_cache) > 100:
                _selectors_cache.clear()
            selector = _selectors_cache[key] = cls(path, namespaces)
        return selector

    def __init__(self, path: str, namespaces: Optional[NsmapType] = None) -> None:
        if namespaces is None:
            self.namespaces = None
        else:
            self.namespaces = dict(namespaces.items())  # keep a private copy

        self._parts = split_path(path, namespaces)
        self.path = ''.join(self._parts)

        self._parser = XPath2Parser(namespaces, strict=False)
        self._token = self._parser.parse(self.path)

        # Trial selection on a placeholder element, for checking the parsed path.
        self.select(_dummy_element)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(path={self.path!r}, )'

    @property
    def parts(self) -> list[str]:
        """A new list containing the chunks of the parsed path."""
        return [*self._parts]

    @cached_property
    def relative_path(self) -> str:
        """The equivalent path expression relative to root element."""
        steps = self._parts.copy()
        if steps and steps[0] == '/':
            # Absolute path: drop the leading slash and the root step, that
            # ends where the next slash-starting chunk begins.
            steps.popleft()
            while steps and not steps[0].startswith('/'):
                steps.popleft()
            steps.appendleft('.')
        elif not steps or steps[0] == '//':
            steps.appendleft('.')
        return ''.join(steps)

    @cached_property
    def select_all(self) -> bool:
        """Returns `True` if the path is composed only by wildcards or path steps."""
        return {'*', '/', '.'}.issuperset(self._parts)

    @cached_property
    def depth(self) -> int:
        """Path depth, 0 means a self axis selector, -1 means an unlimited depth."""
        chunks = self._parts
        if not chunks:
            return 0
        if '//' in chunks:
            return -1

        slashes = chunks.count('/')
        first = chunks[0]
        if first == '/':
            return slashes - 1
        if first == '.':
            return slashes
        return slashes + 1

    def select(self, root: Union[ElementType, 'XMLResource']) -> list[ElementType]:
        """Returns the list of the elements selected by `iter_select()` on the root."""
        return [*self.iter_select(root)]

    def iter_select(self, root: Union[ElementType, 'XMLResource']) -> Iterator[Element]:
        """
        Iterates the elements selected by the XPath expression, evaluated on a
        context built on the `xpath_root` attribute of the argument, if it has
        one, or otherwise on the argument itself. Raises an `XMLSchemaTypeError`
        if a selected item is not an element node.
        """
        context_root = root.xpath_root if hasattr(root, 'xpath_root') else root
        context = XPathContext(context_root)

        for node in self._token.select(context):
            if isinstance(node, ElementNode):
                yield cast(ElementType, node.obj)
            else:  # pragma: no cover
                msg = "XPath expressions on XML resources can select only elements"
                raise XMLSchemaTypeError(msg)


class ElementPathSelector(ElementSelector):
    """
    A selector variant that delegates the selection to the `iterfind()` method
    of the root element, using the relative form of the path expression.
    """
    def iter_select(self, root: Union[ElementType, 'XMLResource']) -> Iterator[ElementType]:
        # Use the wrapped root element when the argument exposes a 'root'
        # attribute, otherwise the argument itself.
        target = root.root if hasattr(root, 'root') else root
        yield from target.iterfind(self.relative_path, self.namespaces)