File: common.py

package info (click to toggle)
python-pynvim 0.6.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 424 kB
  • sloc: python: 3,067; makefile: 4
file content (255 lines) | stat: -rw-r--r-- 8,148 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""Code shared between the API classes."""
import functools
import sys
from abc import ABC, abstractmethod
from typing import (Any, Callable, Generic, Iterator, List, Optional, Tuple, TypeVar,
                    Union, overload)

from msgpack import unpackb
if sys.version_info < (3, 8):
    from typing_extensions import Literal, Protocol
else:
    from typing import Literal, Protocol

from pynvim.compat import unicode_errors_default

# Star-imports of this module export no names.
__all__ = ()


# Generic type variable shared by RemoteMap.get, RemoteSequence and
# decode_if_bytes below.
T = TypeVar('T')
# Decode mode accepted by decode_if_bytes: ``True`` selects the default
# error handler (``unicode_errors_default``); a string is used directly as
# the ``errors=`` handler passed to ``bytes.decode``.
TDecodeMode = Union[Literal[True], str]


class NvimError(Exception):
    """Error type for failed Nvim API requests.

    RemoteMap catches it around its request calls and converts some
    messages (missing keys / unknown options) into KeyError.
    """


class IRemote(Protocol):
    """Structural type for objects that can issue Nvim API requests."""

    def request(self, name: str, *args: Any, **kwargs: Any) -> Any:
        """Send the API request `name` with the given arguments."""
        raise NotImplementedError


class Remote(ABC):

    """Base class for Nvim objects (buffer/window/tabpage).

    Each kind of object has its own specialized subclass that wraps the
    msgpack-rpc session with API helpers. Equality and hashing are based on
    the remote object's ``code_data``, so two wrappers around the same
    remote object compare equal.
    """

    def __init__(self, session: IRemote, code_data: Tuple[int, Any]):
        """Initialize from a session and an immutable `code_data` object.

        `code_data` carries the serialization information needed for
        msgpack-rpc calls; its second element is unpacked to obtain
        `handle`. It must be immutable (and hashable) for equality and
        hashing to work.
        """
        self._session = session
        self.code_data = code_data
        self.handle = unpackb(code_data[1])
        # Subclasses supply the API method prefix (e.g. for get_var/set_var).
        prefix = self._api_prefix
        self.api = RemoteApi(self, prefix)
        self.vars = RemoteMap(
            self,
            prefix + 'get_var',
            prefix + 'set_var',
            prefix + 'del_var',
        )
        self.options = RemoteMap(
            self,
            prefix + 'get_option',
            prefix + 'set_option',
        )

    @property
    @abstractmethod
    def _api_prefix(self) -> str:
        """Prefix prepended to API method names for this object type."""
        raise NotImplementedError()

    def __repr__(self) -> str:
        """Return a text representation like ``<ClassName(handle=...)>``."""
        return f'<{self.__class__.__name__}(handle={self.handle!r})>'

    def __eq__(self, other: Any) -> bool:
        """Return True if `other` wraps the same remote object."""
        if not hasattr(other, 'code_data'):
            return False
        return other.code_data == self.code_data

    def __hash__(self) -> int:
        """Return a hash derived from the remote object's `code_data`."""
        return hash(self.code_data)

    def request(self, name: str, *args: Any, **kwargs: Any) -> Any:
        """Forward to the session's request, passing `self` as first arg."""
        return self._session.request(name, self, *args, **kwargs)


class RemoteApi:
    """Wrapper to allow api methods to be called like python methods.

    ``api.foo(*args)`` is forwarded as
    ``obj.request(api_prefix + 'foo', *args)``.
    """

    def __init__(self, obj: IRemote, api_prefix: str):
        """Initialize a RemoteApi with object and api prefix."""
        self._obj = obj
        self._api_prefix = api_prefix

    def __getattr__(self, name: str) -> Callable[..., Any]:
        """Return wrapper to named api method.

        Dunder names (``__deepcopy__``, ``__getstate__``, ``__setstate__``,
        ...) raise AttributeError. Protocol probes made by ``copy``,
        ``pickle`` or ``hasattr`` must not be mistaken for remote API calls.
        Without this guard they would issue bogus requests, or recurse
        through ``self._obj`` on a not-yet-initialized copy. Names with a
        single leading underscore are still forwarded, so ``api._id`` maps
        to ``api_prefix + '_id'``.
        """
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')
        return functools.partial(self._obj.request, self._api_prefix + name)


E = TypeVar('E', bound=Exception)


def transform_keyerror(exc: E) -> Union[E, KeyError]:
    """Map Nvim "missing key/option" errors onto KeyError.

    Returns a new KeyError carrying the original message when `exc` is an
    NvimError whose message starts with ``'Key not found:'`` or
    ``'Invalid option name:'``. Any other exception is returned unchanged
    so the caller can re-raise it. This includes an NvimError with no
    arguments or with a non-string message.
    """
    if isinstance(exc, NvimError) and exc.args:
        message = exc.args[0]
        # Guard the type: a bytes/None message must not raise here and
        # mask the original error inside the caller's except block.
        if isinstance(message, str) and message.startswith(
                ('Key not found:', 'Invalid option name:')):
            return KeyError(message)
    return exc


class RemoteMap:
    """Represents a string->object map stored in Nvim.

    This is the dict counterpart to the `RemoteSequence` class, but it is used
    as a generic way of retrieving values from the various map-like data
    structures present in Nvim.

    It is used to provide a dict-like API to vim variables and options.
    Missing keys surface as KeyError; maps built without a setter (or
    deleter) raise TypeError on assignment (or deletion).
    """

    # Bound request callables for writing/deleting. They stay None
    # (read-only) unless the corresponding method name is given to __init__.
    _set = None
    _del = None

    def __init__(
        self,
        obj: IRemote,
        get_method: str,
        set_method: Optional[str] = None,
        del_method: Optional[str] = None
    ):
        """Initialize a RemoteMap with session, getter/setter/deleter names.

        Each method name is bound to ``obj.request`` so that, e.g.,
        ``self._get(key)`` issues ``obj.request(get_method, key)``.
        """
        self._get = functools.partial(obj.request, get_method)
        if set_method:
            self._set = functools.partial(obj.request, set_method)
        if del_method:
            self._del = functools.partial(obj.request, del_method)

    def __getitem__(self, key: str) -> Any:
        """Return a map value by key.

        Raises KeyError when Nvim reports the key/option as missing; other
        NvimErrors propagate unchanged.
        """
        try:
            return self._get(key)
        except NvimError as exc:
            raise transform_keyerror(exc)

    def __setitem__(self, key: str, value: Any) -> None:
        """Set a map value by key (if the setter was provided)."""
        if not self._set:
            raise TypeError('This dict is read-only')
        self._set(key, value)

    def __delitem__(self, key: str) -> None:
        """Delete a map value by calling the remote delete method.

        Raises TypeError when no deleter was provided, and KeyError when
        Nvim reports the key as missing.
        """
        if not self._del:
            raise TypeError('This dict is read-only')
        try:
            return self._del(key)
        except NvimError as exc:
            raise transform_keyerror(exc)

    def __contains__(self, key: str) -> bool:
        """Check if key is present in the map.

        Only an error reported by Nvim (NvimError) means "absent". This
        matches ``__getitem__``/``get``. Any other exception propagates
        instead of being silently reported as a missing key.
        """
        try:
            self._get(key)
        except NvimError:
            return False
        return True

    @overload
    def get(self, key: str, default: T) -> T: ...

    @overload
    def get(self, key: str, default: Optional[T] = None) -> Optional[T]: ...

    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
        """Return value for key if present, else a default value."""
        try:
            return self.__getitem__(key)
        except KeyError:
            return default


class RemoteSequence(Generic[T]):

    """Represents a sequence of objects stored in Nvim.

    This class is used to wrap msgpack-rpc functions that work on Nvim
    sequences (of lines, buffers, windows and tabpages) with an API that
    is similar to the one provided by the python-vim interface.

    For example, the 'windows' property of the `Nvim` class is a
    RemoteSequence instance, and the expression `nvim.windows[0]` is
    translated to session.request('nvim_list_wins')[0].

    One important detail about this class is that all methods will fetch the
    sequence into a list and perform the necessary manipulation
    locally (iteration, indexing, counting, etc).
    """

    def __init__(self, session: IRemote, method: str):
        """Initialize a RemoteSequence with session, method.

        Every access calls ``session.request(method)`` afresh; nothing is
        cached between calls.
        """
        self._fetch = functools.partial(session.request, method)

    def __len__(self) -> int:
        """Return the length of the remote sequence."""
        return len(self._fetch())

    @overload
    def __getitem__(self, idx: int) -> T: ...

    @overload
    def __getitem__(self, idx: slice) -> List[T]: ...

    def __getitem__(self, idx: Union[slice, int]) -> Union[T, List[T]]:
        """Return a sequence item by index, or a list for a slice.

        The fetched sequence is indexed locally with the full index object,
        so negative indices and slice steps (e.g. ``seq[::-1]``) behave as
        they do on a Python list.
        """
        return self._fetch()[idx]

    def __iter__(self) -> Iterator[T]:
        """Return an iterator for the sequence.

        The sequence is fetched once, when iteration starts.
        """
        yield from self._fetch()

    def __contains__(self, item: T) -> bool:
        """Check if an item is present in the sequence."""
        return item in self._fetch()


@overload
def decode_if_bytes(obj: bytes, mode: TDecodeMode = True) -> str: ...


@overload
def decode_if_bytes(obj: T, mode: TDecodeMode = True) -> Union[T, str]: ...


def decode_if_bytes(obj: T, mode: TDecodeMode = True) -> Union[T, str]:
    """Return `obj` decoded as UTF-8 when it is ``bytes``, else unchanged.

    `mode` is the ``errors=`` handler for ``bytes.decode``; ``True`` means
    use ``unicode_errors_default``.
    """
    if not isinstance(obj, bytes):
        return obj
    errors = unicode_errors_default if mode is True else mode
    return obj.decode("utf-8", errors=errors)


def walk(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Recursively apply `fn` to every leaf of an object graph.

    Exact ``dict`` instances are rebuilt with both keys and values walked;
    exact ``list`` and ``tuple`` instances become lists of walked items.
    Everything else, including subclasses of those types, is passed to
    `fn`.
    """
    # Hot path: dispatch on the exact type with identity checks, which is
    # cheaper than isinstance and intentionally ignores subclasses.
    kind = type(obj)

    if kind is dict:
        return {walk(fn, key): walk(fn, val) for key, val in obj.items()}
    if kind is tuple or kind is list:
        return [walk(fn, item) for item in obj]
    return fn(obj)