File: cmdbase.py

package info (click to toggle)
python-nubia 0.2.5-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 772 kB
  • sloc: python: 4,182; makefile: 9; sh: 1
file content (522 lines) | stat: -rw-r--r-- 20,321 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

import asyncio
import copy
import inspect
import sys
import traceback
from collections import OrderedDict
from textwrap import dedent
from typing import Iterable, Optional, Callable

from prompt_toolkit.completion import CompleteEvent, Completion, WordCompleter
from prompt_toolkit.document import Document
from termcolor import cprint

from nubia.internal import parser
from nubia.internal.completion import AutoCommandCompletion
from nubia.internal.exceptions import CommandParseError
from nubia.internal.helpers import (
    find_approx,
    function_to_str,
    suggestions_msg,
    try_await,
)
from nubia.internal.options import Options
from nubia.internal.typing import FunctionInspection, inspect_object
from nubia.internal.typing.argparse import (
    get_arguments_for_command,
    get_arguments_for_inspection,
    register_command,
)
from nubia.internal.typing.builder import apply_typing
from nubia.internal.typing.inspect import is_list_type

from . import context


class Command:
    """Base class for the commands executed by the shell.

    A single Command object may expose one or more command names. Sub-classes
    MUST implement `get_command_names` (the names this object handles) and
    `run_interactive` (the interactive-mode entry point). The remaining hooks
    have defaults and MAY be overridden to add CLI support, help text and
    auto-completion.
    """

    def __init__(self):
        # Set later through set_command_registry().
        self._command_registry = None
        self._built_in = False

    @property
    def built_in(self) -> bool:
        """Whether this command is flagged as built-in (False by default)."""
        return self._built_in

    def set_command_registry(self, command_registry):
        """Store the command registry object on this command."""
        self._command_registry = command_registry

    async def run_interactive(self, cmd, args, raw):
        """
        This function MUST be overridden by all commands. It will be called when
        the command is executed in interactive mode.
        """
        raise NotImplementedError("run_interactive must be overridden")

    async def run_cli(self, args):
        """
        This function SHOULD be implemented in order to expose a subcommand in
        the CLI interface. It will be called when run from the CLI.
        """
        pass

    async def add_arguments(self, parser):
        """
        This function receives an instance of an "argparse.ArgumentParser".
        Every command SHOULD use it to tell the CLI interface which options
        it needs. The default implementation registers nothing.
        """
        pass

    @property
    def metadata(self) -> FunctionInspection:
        """
        Returns the command specification as an instance of FunctionInspection
        object. This is used to generate a completion model for external
        completers. The base implementation returns an empty dict.
        """
        return {}

    def get_completions(self, cmd, document, complete_event) -> Iterable[Completion]:
        """
        This function SHOULD be implemented to feed the interactive auto
        completion of command arguments. Example: auto complete the available
        tables in the "describe" query command. The default yields nothing.
        """
        return []

    def get_command_names(self):
        """
        This function MUST be implemented to tell the framework which commands
        this module implements. Must return a list of strings.
        """
        raise NotImplementedError("get_command_names must be overridden")

    def get_cli_aliases(self):
        """
        This function SHOULD be implemented to instruct the command dispatcher
        about alternative commands available in the CLI. Example: while the
        "commands/query.py" exports "select" and "describe" in interactive
        mode, the CLI uses the subcommand "query" to run those commands.
        Must return a list of strings.
        """
        return []

    def get_help(self, cmd, *args):
        """
        This function SHOULD be implemented to show command help when running
        ':help'. It must return a string associated with the given command.
        The default returns None (no help).
        """
        pass

    def get_help_short(self, cmd, *args):
        """Return a shortened help: the first line of `get_help`, or None.

        This is for example used for interactive autocompletion."""
        help_text = self.get_help(cmd, *args)
        return help_text.split("\n", 1)[0] if help_text else None

    @property
    def super_command(self) -> bool:
        """
        Does this command parse sub-commands?
        """
        return False

    def has_subcommand(self, subcommand) -> bool:
        """
        Does this command have `subcommand` as a valid sub-command?
        """
        return False


class AutoCommand(Command):
    """A Command generated from a function (or class) decorated with @command.

    Argument names, types, defaults, positional flags and choices are read
    from the inspection metadata of `fn` (via `inspect_object`). When the
    inspected object declares sub-commands, it is treated as a "super
    command": the class is instantiated with the super-command level
    arguments and the selected sub-command method is then invoked.
    """

    def __init__(self, fn, options: Optional[Options] = None):
        """
        :param fn: a function or bound method (or a class, for super
            commands) annotated with @command.
        :param options: shell options; `Options()` is used when omitted.
        :raises ValueError: if `fn` is not callable, looks like an unbound
            method (has a `self` argument), or is not annotated with @command.
        """
        # Initialise base-class state so every attribute defined by Command
        # (e.g. _command_registry) also exists on AutoCommand instances.
        super().__init__()
        if not callable(fn):
            raise ValueError("fn argument must be a callable")
        self._fn = fn
        self._options = options or Options()

        self._obj_metadata = inspect_object(fn)
        self._is_super_command = len(self.metadata.subcommands) > 0
        self._subcommand_names = []

        # We never expect a function to be passed here that has a self argument
        # In that case, we should get a bound method
        if "self" in self.metadata.arguments and not inspect.ismethod(self._fn):
            raise ValueError(
                "Expecting either a function (eg. bar) or "
                "a bound method (eg. Foo().bar). "
                "You passed what appears to be an unbound method "
                "(eg. Foo.bar) it has a 'self' argument: %s" % function_to_str(fn)
            )

        if not self.metadata.command:
            raise ValueError(
                "function or class {} needs to be annotated with "
                "@command".format(function_to_str(fn))
            )
        # If this is a super command, we need a completer for sub-commands
        if self.super_command:
            self._commands_completer = WordCompleter(
                [], ignore_case=True, sentence=True
            )
            for _, inspection in self.metadata.subcommands:
                sub_name = inspection.command.name
                self._commands_completer.words.append(sub_name)
                sub_help = inspection.command.help
                # Guard like get_help() does: dedent(None) would raise for a
                # sub-command without help text.
                if sub_help:
                    self._commands_completer.meta_dict[sub_name] = dedent(
                        sub_help
                    ).strip()
                self._subcommand_names.append(sub_name)

    @property
    def metadata(self) -> FunctionInspection:
        """
        The Inspection object of this command. This object contains all the
        information required by AutoCommand to understand the command arguments
        type information, help messages, aliases, and attributes.
        """
        return self._obj_metadata

    def _create_subcommand_obj(self, key_values):
        """
        Instantiates an object of the super command class, passes the right
        arguments and returns a tuple of (instance, remaining) where
        `remaining` holds the key/values that were not consumed by the class
        constructor.
        """
        kwargs = {
            k: v
            for k, v in get_arguments_for_inspection(self.metadata, key_values).items()
            if v is not None
        }
        # Constructor kwargs use underscores; user keys may use dashes.
        remaining = {
            k: v for k, v in key_values.items() if k.replace("-", "_") not in kwargs
        }
        return self._fn(**kwargs), remaining

    async def run_interactive(self, cmd, args, raw):
        """Parse `args` and run the command in interactive mode.

        :param cmd: the command name as typed (used in messages).
        :param args: the argument string that followed `cmd`.
        :param raw: not used by this implementation.
        :returns: the wrapped function's return value on success, otherwise
            an error code: 1 when parsing fails or the function raises, 2 for
            sub-command / positional / duplicate / unknown-argument errors,
            3 for missing required arguments, 4 for type-conversion or
            choices errors.
        """
        try:
            args_metadata = self.metadata.arguments
            parsed = parser.parse(args, expect_subcommand=self.super_command)

            # prepare args dict
            parsed_dict = parsed.asDict()
            args_dict = parsed.kv.asDict()
            # Separate copy of the key=value arguments, used to detect keys
            # supplied both by name and positionally.
            key_values = parsed.kv.asDict()
            command_name = cmd
            # if this is a super command, we first need to create an instance
            # of the class (fn) and pass it the right arguments
            if self.super_command:
                subcommand = parsed_dict.get("__subcommand__")
                if not subcommand:
                    cprint(
                        "A sub-command must be supplied, valid values: "
                        "{}".format(", ".join(self._get_subcommands())),
                        "red",
                    )
                    return 2
                subcommands = self._get_subcommands()

                if subcommand not in subcommands:
                    suggestions = find_approx(subcommand, subcommands)
                    if (
                        len(suggestions) == 1
                        and self._options.auto_execute_single_suggestions
                    ):
                        print()
                        cprint(
                            "Auto-correcting '{}' to '{}'".format(
                                subcommand, suggestions[0]
                            ),
                            "red",
                            attrs=["bold"],
                        )
                        subcommand = suggestions[0]
                    else:
                        print()
                        cprint(
                            "Invalid sub-command '{}'{} "
                            "valid sub-commands: {}".format(
                                subcommand,
                                suggestions_msg(suggestions),
                                ", ".join(self._get_subcommands()),
                            ),
                            "red",
                            attrs=["bold"],
                        )
                        return 2

                sub_inspection = self.subcommand_metadata(subcommand)
                instance, remaining_args = self._create_subcommand_obj(args_dict)
                assert instance
                args_dict = remaining_args
                key_values = copy.copy(args_dict)
                args_metadata = sub_inspection.arguments
                attrname = self._find_subcommand_attr(subcommand)
                command_name = subcommand
                assert attrname is not None
                fn = getattr(instance, attrname)
            else:
                # not a super-command, use the function itself
                fn = self._fn
            positionals = parsed_dict["positionals"] if parsed.positionals != "" else []
            # We only allow positionals for arguments that have positional=True.
            # We filter the OrderedDict this way to ensure we don't lose the
            # order of the arguments. We also filter out arguments that have
            # been passed by name already. The order of the positional arguments
            # follows the order of the function definition.
            can_be_positional = self._positional_arguments(
                args_metadata, args_dict.keys()
            )

            if len(positionals) > len(can_be_positional):
                if len(can_be_positional) == 0:
                    err = "This command does not support positional arguments"
                else:
                    # We have more positionals than we should
                    err = (
                        "This command only supports ({}) positional arguments, "
                        "namely arguments ({}). You have passed {} arguments ({})"
                        " instead!"
                    ).format(
                        len(can_be_positional),
                        ", ".join(can_be_positional.keys()),
                        len(positionals),
                        ", ".join(str(x) for x in positionals),
                    )
                cprint(err, "red")
                return 2
            # construct key_value dict from positional arguments.
            args_from_positionals = {
                key: value for value, key in zip(positionals, can_be_positional)
            }
            # update the total arguments dict with the positionals
            args_dict.update(args_from_positionals)

            # do we have keys that are supplied in both positionals and
            # key_value style?
            duplicate_keys = set(args_from_positionals.keys()).intersection(
                set(key_values.keys())
            )
            if duplicate_keys:
                cprint(
                    "Arguments '{}' have been passed already, cannot have"
                    " duplicate keys".format(list(duplicate_keys)),
                    "red",
                )
                return 2

            ctx = context.get_context()
            old_verbose = ctx.args.verbose
            try:
                # check for verbosity override in kwargs. Only args_dict is
                # consulted from here on, so key_values is left alone (the
                # key may have arrived positionally and be absent there).
                if "verbose" in args_dict:
                    ctx.set_verbose(args_dict["verbose"])
                    del args_dict["verbose"]

                # do we have keys that we know nothing about?
                extra_keys = set(args_dict.keys()) - set(args_metadata)
                if extra_keys:
                    cprint(
                        f"Unknown argument(s) {sorted(extra_keys)} were passed",
                        "magenta",
                    )
                    return 2

                # is there any required keys that were not resolved from
                # positionals nor key_values?
                missing_keys = set(args_metadata) - set(args_dict.keys())
                required_missing = [
                    key
                    for key in missing_keys
                    if not args_metadata[key].default_value_set
                ]
                if required_missing:
                    cprint(
                        "Missing required argument(s) {} for command"
                        " {}".format(required_missing, command_name),
                        "yellow",
                    )
                    return 3

                error_code = self._apply_argument_types(args_dict, args_metadata)
                if error_code is not None:
                    return error_code

                error_code = self._validate_choices(args_dict, args_metadata)
                if error_code is not None:
                    return error_code

                # arguments appear to be fine, time to run the function
                try:
                    # convert argument names back to match the function signature
                    args_dict = {args_metadata[k].arg: v for k, v in args_dict.items()}

                    ret = await try_await(fn(**args_dict))
                except Exception as e:
                    cprint("Error running command: {}".format(str(e)), "red")
                    cprint("-" * 60, "yellow")
                    traceback.print_exc(file=sys.stderr)
                    cprint("-" * 60, "yellow")
                    return 1

                return ret
            finally:
                # Restore the verbosity on every exit path (validation
                # failures and exceptions included), not only on success.
                ctx.set_verbose(old_verbose)

        except CommandParseError as e:
            cprint("Error parsing command", "red")
            cprint(cmd + " " + args, "white", attrs=["bold"])
            cprint((" " * (e.col + len(cmd))) + "^", "white", attrs=["bold"])
            cprint(str(e), "yellow")
            return 1

    def _apply_argument_types(self, args_dict, args_metadata) -> Optional[int]:
        """Convert each value of `args_dict` in place to its declared type.

        Arguments whose metadata declares no type are converted to `str`.
        Returns 4 (after printing an error) on the first value that cannot be
        converted, otherwise None.
        """
        for key, value in args_dict.items():
            target_type = args_metadata[key].type
            if target_type is None:
                target_type = str
            try:
                new_value = apply_typing(value, target_type)
            except ValueError:
                fn_name = function_to_str(target_type, False, False)
                cprint(
                    'Cannot convert value "{}" to {} on argument {}'.format(
                        value, fn_name, key
                    ),
                    "yellow",
                )
                return 4
            # Replacing values of existing keys is safe while iterating.
            args_dict[key] = new_value
        return None

    def _validate_choices(self, args_dict, args_metadata) -> Optional[int]:
        """Check arguments that declare static `choices` against their values.

        List-typed arguments have every element checked. Returns 4 (after
        printing an error) on the first violation, otherwise None.
        """
        for arg, value in args_dict.items():
            choices = args_metadata[arg].choices
            # Callable choices are dynamic completions; we can't validate
            # those yet.
            if not choices or callable(choices):
                continue
            if is_list_type(args_metadata[arg].type):
                bad_inputs = [v for v in value if v not in choices]
                if bad_inputs:
                    cprint(
                        f"Argument '{arg}' got an unexpected "
                        f"value(s) '{bad_inputs}'. Expected one "
                        f"or more of {choices}.",
                        "red",
                    )
                    return 4
            elif value not in choices:
                cprint(
                    f"Argument '{arg}' got an unexpected value "
                    f"'{value}'. Expected one of "
                    f"{choices}.",
                    "red",
                )
                return 4
        return None

    def _positional_arguments(self, args_metadata, filter_out):
        """Return, in declaration order, the positional-capable arguments
        whose names are not in `filter_out`."""
        positionals = OrderedDict()
        for k, v in args_metadata.items():
            if v.positional and k not in filter_out:
                positionals[k] = v
        return positionals

    def subcommand_metadata(self, name: str) -> Optional[FunctionInspection]:
        """Return the inspection of the sub-command named `name` (aliases are
        not matched), or None if there is no such sub-command."""
        assert self.super_command
        for _, inspection in self.metadata.subcommands:
            if inspection.command.name == name:
                return inspection
        return None

    def _find_subcommand_attr(self, name):
        """Return the attribute name implementing sub-command `name` (matched
        by name or alias), or None if not found."""
        assert self.super_command
        for attr, inspection in self.metadata.subcommands:
            if inspection.command.name == name or name in inspection.command.aliases:
                return attr
        # be explicit about returning None for readability
        return None

    def _get_subcommands(self) -> Iterable[str]:
        """Return the primary names of all sub-commands."""
        assert self.super_command
        return [inspection.command.name for _, inspection in self.metadata.subcommands]

    def _kwargs_for_fn(self, fn, args):
        """Extract the keyword arguments for `fn` from the parsed CLI `args`,
        dropping those whose value is None."""
        return {
            k: v
            for k, v in get_arguments_for_command(fn, args).items()
            if v is not None
        }

    async def run_cli(self, args):
        """Run the command from the CLI; returns the function's result, or 1
        after printing the traceback if it raised."""
        kwargs = self._kwargs_for_fn(self._fn, args)
        try:
            # if this is a super-command, we need to dispatch the call to the
            # correct function
            if self._is_super_command:
                # let's instantiate an instance of the klass
                instance = self._fn(**kwargs)
                # find the method implementing the selected sub-command and
                # extract the correct kwargs for that method
                attrname = self._find_subcommand_attr(args._subcmd)
                assert attrname is not None
                fn = getattr(instance, attrname)
                kwargs = self._kwargs_for_fn(fn, args)
            else:
                fn = self._fn

            return await try_await(fn(**kwargs))
        except Exception as e:
            cprint("Error running command: {}".format(str(e)), "red")
            cprint("-" * 60, "yellow")
            traceback.print_exc(file=sys.stderr)
            cprint("-" * 60, "yellow")
            return 1

    @property
    def super_command(self):
        return self._is_super_command

    def has_subcommand(self, subcommand):
        """Case-insensitively check whether `subcommand` is a sub-command
        name, consistent with the ignore_case sub-command completer."""
        assert self.super_command
        wanted = subcommand.lower()
        return any(name.lower() == wanted for name in self._subcommand_names)

    async def add_arguments(self, parser):
        register_command(parser, self.metadata)

    def get_command_names(self):
        """Return the command name followed by its aliases."""
        command = self.metadata.command
        return [command.name] + command.aliases

    def get_completions(
        self, _: str, document: Document, complete_event: CompleteEvent
    ) -> Iterable[Completion]:
        """Complete sub-command names on the first word of a super command,
        otherwise delegate to AutoCommandCompletion."""
        if self._is_super_command:
            exploded = document.text.lstrip().split(" ", 1)
            # Are we at the first word? we expect a sub-command here
            if len(exploded) <= 1:
                return self._commands_completer.get_completions(
                    document, complete_event
                )

        state_machine = AutoCommandCompletion(self, document, complete_event)
        return state_machine.get_completions()

    def get_help(self, cmd, *args):
        """Return the dedented, stripped help text, or None when empty."""
        help_text = self.metadata.command.help
        return dedent(help_text).strip() if help_text else None