File: host.py

package info (click to toggle)
python-pynvim 0.5.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 432 kB
  • sloc: python: 3,040; makefile: 4
file content (299 lines) | stat: -rw-r--r-- 11,764 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
# type: ignore
"""Implements a Nvim host for python plugins."""

import importlib
import inspect
import logging
import os
import os.path
import pathlib
import re
import sys
from functools import partial
from traceback import format_exc
from typing import Any, Sequence

from pynvim.api import Nvim, decode_if_bytes, walk
from pynvim.msgpack_rpc import ErrorResponse
from pynvim.plugin import script_host
from pynvim.util import format_exc_skip, get_client_info

__all__ = ('Host',)

# Module-level logger, plus short aliases for its most frequently used methods.
logger = logging.getLogger(__name__)
error, debug, info, warn = (logger.error, logger.debug, logger.info,
                            logger.warning,)

# Method spec of the host itself; handed to get_client_info() in Host._load.
# The keys mirror the request handlers that Host.__init__ registers.
host_method_spec = {"poll": {}, "specs": {"nargs": 1}, "shutdown": {}}


def _handle_import(path: str, name: str):
    """Import python module `name` from a known file path or module directory.

    The path should be the base directory from which the module can be imported.
    To support python 3.12, the use of `imp` should be avoided.
    @see https://docs.python.org/3.12/whatsnew/3.12.html#imp

    Args:
        path: Directory that is added to `sys.path` so `name` can be found.
        name: Name of the module to import.

    Returns:
        The module object returned by `importlib.import_module`.

    Raises:
        ValueError: If `name` is empty.
    """
    if not name:
        raise ValueError("Missing module name.")

    # Register the directory only once: several plugins living in the same
    # directory would otherwise add one duplicate sys.path entry per import.
    if path not in sys.path:
        sys.path.append(path)
    return importlib.import_module(name)


class Host(object):

    """Nvim host for python plugins.

    Takes care of loading/unloading plugins and routing msgpack-rpc
    requests/notifications to the appropriate handlers.
    """

    def __init__(self, nvim: Nvim):
        """Create the handler tables and register the host's own handlers.

        Args:
            nvim: Nvim API object used to talk to the editor.
        """
        self.nvim = nvim
        # plugin path -> list of rpc specs collected from its handlers
        self._specs = {}
        # plugin path -> {'handlers': [...], 'module': module}
        self._loaded = {}
        # plugin path -> formatted message of the error raised while loading
        self._load_errors = {}
        self._notification_handlers = {
            'nvim_error_event': self._on_error_event
        }
        self._request_handlers = {
            'poll': lambda: 'ok',
            'specs': self._on_specs_request,
            'shutdown': self.shutdown
        }

        self._decode_default = True

    def _on_async_err(self, msg: str) -> None:
        """Report an uncaught python exception to Nvim's error output."""
        self.nvim.err_write(msg, async_=True)

    def _on_error_event(self, kind: Any, msg: str) -> str:
        """Report an error Nvim raised for an async request.

        Such errors come from calls like nvim.command(..., async_=True).
        Reads `self.name`, which is assigned by `_load`.

        Returns:
            The message that was written to Nvim's error output.
        """
        errmsg = "{}: Async request caused an error:\n{}\n".format(
            self.name, decode_if_bytes(msg))
        self.nvim.err_write(errmsg, async_=True)
        return errmsg

    def start(self, plugins):
        """Start listening for msgpack-rpc requests and notifications.

        Args:
            plugins: Plugin paths that are passed to `_load` by the setup
                callback given to the event loop.
        """
        self.nvim.run_loop(self._on_request,
                           self._on_notification,
                           lambda: self._load(plugins),
                           err_cb=self._on_async_err)

    def shutdown(self) -> None:
        """Shutdown the host."""
        self._unload()
        self.nvim.stop_loop()

    def _wrap_delayed_function(self, cls, delayed_handlers, name, sync,
                               module_handlers, path, *args):
        """Instantiate plugin class `cls`, then dispatch the pending call.

        Removes the placeholder handlers registered for the class, creates
        the plugin instance, registers the handlers found on the instance
        and finally forwards `args` to the handler registered as `name`.
        """
        # delete the delayed handlers to be sure
        for handler in delayed_handlers:
            method_name = handler._nvim_registered_name
            if handler._nvim_rpc_sync:
                del self._request_handlers[method_name]
            else:
                del self._notification_handlers[method_name]
        # create an instance of the plugin and pass the nvim object
        plugin = cls(self._configure_nvim_for(cls))

        # discover handlers in the plugin instance
        self._discover_functions(plugin, module_handlers, path, False)

        if sync:
            return self._request_handlers[name](*args)
        return self._notification_handlers[name](*args)

    def _wrap_function(self, fn, sync, decode, nvim_bind, name, *args):
        """Call handler `fn`, decoding arguments and reporting failures.

        Args:
            fn: The plugin handler to call.
            sync: True for request handlers, False for notification handlers.
            decode: Decode mode; when truthy the arguments are passed
                through `walk(decode_if_bytes, ...)`.
            nvim_bind: Nvim object to prepend to the arguments, or None.
            name: Registered handler name, used in error messages.
            *args: Arguments received with the rpc message.

        Returns:
            The return value of `fn`; None if an async handler raised.

        Raises:
            ErrorResponse: If a sync handler raised an exception.
        """
        if decode:
            args = walk(decode_if_bytes, args, decode)
        if nvim_bind is not None:
            # Build a new list instead of calling insert(): without decoding
            # `args` is still the *args tuple, which has no insert().
            args = [nvim_bind, *args]
        try:
            return fn(*args)
        except Exception:
            if sync:
                msg = ("error caught in request handler '{} {}':\n{}"
                       .format(name, args, format_exc_skip(1)))
                raise ErrorResponse(msg)
            else:
                msg = ("error caught in async handler '{} {}'\n{}\n"
                       .format(name, args, format_exc_skip(1)))
                self._on_async_err(msg + "\n")

    def _on_request(self, name: str, args: Sequence[Any]) -> Any:
        """Handle a msgpack-rpc request.

        Returns:
            Whatever the registered request handler returns.

        Raises:
            ErrorResponse: If no request handler is registered for `name`.
        """
        name = decode_if_bytes(name)
        handler = self._request_handlers.get(name, None)
        if not handler:
            raise ErrorResponse(self._missing_handler_error(name, 'request'))

        return handler(*args)

    def _on_notification(self, name: str, args: Sequence[Any]) -> None:
        """Handle a msgpack-rpc notification.

        A missing handler is reported through `_on_async_err` rather than
        raised.
        """
        name = decode_if_bytes(name)
        handler = self._notification_handlers.get(name, None)
        if not handler:
            msg = self._missing_handler_error(name, 'notification')
            self._on_async_err(msg + "\n")
            return

        handler(*args)

    def _missing_handler_error(self, name, kind):
        """Build the error message for a call without a registered handler.

        If `name` has the form `<path>:<x>:<y>` and loading `<path>` failed
        earlier, the recorded load error is appended to the message.
        """
        msg = 'no {} handler registered for "{}"'.format(kind, name)
        pathmatch = re.match(r'(.+):[^:]+:[^:]+', name)
        if pathmatch:
            loader_error = self._load_errors.get(pathmatch.group(1))
            if loader_error is not None:
                msg = msg + "\n" + loader_error
        return msg

    def _load(self, plugins: Sequence[str]) -> None:
        """Load the remote plugins and register handlers defined in the plugins.

        Errors raised while loading a plugin are recorded in
        `self._load_errors` instead of being propagated. Afterwards the
        client info is sent to Nvim and `self.name` is set.

        Args:
            plugins: List of plugin paths to rplugin python modules
                registered by remote#host#RegisterPlugin('python3', ...)
                (see the generated rplugin.vim manifest)
        """
        has_script = False
        for path in plugins:
            path = pathlib.Path(os.path.normpath(path)).as_posix()  # normalize path
            if path in self._loaded:
                # already loaded, nothing to do
                continue
            try:
                if path == "script_host.py":
                    module = script_host
                    has_script = True
                else:
                    directory, name = os.path.split(os.path.splitext(path)[0])
                    module = _handle_import(directory, name)
                handlers = []
                self._discover_classes(module, handlers, path)
                self._discover_functions(module, handlers, path, False)
                if not handlers:
                    # the module exports no handlers; do not record it
                    continue
                self._loaded[path] = {'handlers': handlers, 'module': module}
            except Exception as e:
                self._load_errors[path] = (
                    'Encountered {} loading plugin at {}: {}\n{}'
                    .format(type(e).__name__, path, e, format_exc(5)))

        kind = ("script-host" if len(plugins) == 1 and has_script
                else "rplugin-host")
        # Not called `info`, to avoid shadowing the module-level logger alias.
        client_info = get_client_info(kind, 'host', host_method_spec)
        self.name = client_info[0]
        self.nvim.api.set_client_info(*client_info, async_=True)

    def _unload(self) -> None:
        """Call shutdown hooks and unregister the handlers of all plugins."""
        for plugin in self._loaded.values():
            for handler in plugin['handlers']:
                method_name = handler._nvim_registered_name
                if hasattr(handler, '_nvim_shutdown_hook'):
                    handler()
                elif handler._nvim_rpc_sync:
                    # pop() rather than del: after a plugin class has been
                    # instantiated its handler list contains the delayed
                    # wrapper and the real handler under the same name, so
                    # the name may already have been removed.
                    self._request_handlers.pop(method_name, None)
                else:
                    self._notification_handlers.pop(method_name, None)
        self._specs = {}
        self._loaded = {}

    def _discover_classes(self, module, handlers, plugin_path):
        """Register delayed handlers for every plugin class in `module`.

        A class counts as a plugin when its `_nvim_plugin` attribute is
        truthy.
        """
        for _, cls in inspect.getmembers(module, inspect.isclass):
            if getattr(cls, '_nvim_plugin', False):
                # discover handlers in the plugin class (instantiated lazily)
                self._discover_functions(cls, handlers, plugin_path, True)

    def _discover_functions(self, obj, handlers, plugin_path, delay):
        """Register all rpc handlers found on `obj`.

        Args:
            obj: Module, plugin class or plugin instance to inspect.
            handlers: List that every wrapped handler is appended to.
            plugin_path: Plugin path; prefixes the method name of handlers
                whose `_nvim_prefix_plugin_path` is truthy and keys the
                collected specs.
            delay: If true, `obj` is a class and the registered handlers
                are placeholders that instantiate it on first call.

        Raises:
            Exception: If a handler with the same name is already
                registered.
        """
        def predicate(o):
            return hasattr(o, '_nvim_rpc_method_name')

        cls_handlers = []
        specs = []
        objdecode = getattr(obj, '_nvim_decode', self._decode_default)
        for _, fn in inspect.getmembers(obj, predicate):
            method = fn._nvim_rpc_method_name
            if fn._nvim_prefix_plugin_path:
                method = '{}:{}'.format(plugin_path, method)
            sync = fn._nvim_rpc_sync
            if delay:
                # cls_handlers is filled while looping, so by the time the
                # placeholder runs it lists every placeholder of this class
                fn_wrapped = partial(self._wrap_delayed_function, obj,
                                     cls_handlers, method, sync,
                                     handlers, plugin_path)
            else:
                decode = getattr(fn, '_nvim_decode', objdecode)
                nvim_bind = None
                if fn._nvim_bind:
                    nvim_bind = self._configure_nvim_for(fn)

                fn_wrapped = partial(self._wrap_function, fn,
                                     sync, decode, nvim_bind, method)
            self._copy_attributes(fn, fn_wrapped)
            fn_wrapped._nvim_registered_name = method
            # register in the rpc handler dict
            if sync:
                if method in self._request_handlers:
                    raise Exception(('Request handler for "{}" is '
                                     'already registered').format(method))
                self._request_handlers[method] = fn_wrapped
            else:
                if method in self._notification_handlers:
                    raise Exception(('Notification handler for "{}" is '
                                     'already registered').format(method))
                self._notification_handlers[method] = fn_wrapped
            if hasattr(fn, '_nvim_rpc_spec'):
                specs.append(fn._nvim_rpc_spec)
            handlers.append(fn_wrapped)
            cls_handlers.append(fn_wrapped)
        if specs:
            self._specs[plugin_path] = specs

    def _copy_attributes(self, fn, fn2):
        """Copy all `_nvim_*` attributes from `fn` to `fn2`."""
        for attr in dir(fn):
            if attr.startswith('_nvim_'):
                setattr(fn2, attr, getattr(fn, attr))

    def _on_specs_request(self, path):
        """Return the specs collected for the plugin at `path`.

        If loading that plugin failed, the recorded error is written to
        Nvim's output first. Returns 0 when no specs are recorded for the
        path.
        """
        path = decode_if_bytes(path)
        path = pathlib.Path(os.path.normpath(path)).as_posix()  # normalize path
        if path in self._load_errors:
            self.nvim.out_write(self._load_errors[path] + '\n')
        return self._specs.get(path, 0)

    def _configure_nvim_for(self, obj):
        """Return the nvim object to hand to `obj`.

        Uses the `_nvim_decode` attribute of `obj` (falling back to the host
        default); when it is truthy the nvim object is wrapped with
        `with_decode`.
        """
        nvim = self.nvim
        decode = getattr(obj, '_nvim_decode', self._decode_default)
        if decode:
            nvim = nvim.with_decode(decode)
        return nvim