File: session.py

package info (click to toggle)
python-plumbum 1.9.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,300 kB
  • sloc: python: 10,016; makefile: 130; sh: 8
file content (318 lines) | stat: -rw-r--r-- 11,845 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
from __future__ import annotations

import contextlib
import logging
import random
import threading
import time

from plumbum.commands import BaseCommand, run_proc
from plumbum.commands.processes import ProcessExecutionError
from plumbum.machines.base import PopenAddons


class ShellSessionError(Exception):
    """Raised when something goes wrong when calling
    :func:`ShellSession.popen <plumbum.session.ShellSession.popen>`, i.e. when the
    session has already been closed or a previous command has not finished yet."""


class SSHCommsError(ProcessExecutionError, EOFError):
    """Raised when the communication channel can't be created on the
    remote host or it times out.

    Derives from both ``ProcessExecutionError`` and ``EOFError``, so handlers
    written for either of those types also catch this error."""


class SSHCommsChannel2Error(SSHCommsError):
    """Raised when channel 2 (stderr) is not available"""


class IncorrectLogin(SSHCommsError):
    """Raised when incorrect login credentials are provided"""


class HostPublicKeyUnknown(SSHCommsError):
    """Raised when the host public key isn't known"""


# Module-level logger for the shell-session machinery in this file: it receives
# debug traces of the commands sent and of every line read back, plus an error
# record when connecting to the shell times out.
shell_logger = logging.getLogger("plumbum.shell")


# ===================================================================================================
# Shell Session Popen
# ===================================================================================================
class MarkedPipe:
    """A pipe-like object from which you can read lines; it reports EOF (the empty
    bytes string) once a special marker line is detected.

    :param pipe: The underlying pipe; only its ``readline()`` method is used
    :param marker: The marker text (``str``); it is stored ASCII-encoded as ``bytes``
    """

    __slots__ = ["pipe", "marker", "__weakref__"]

    def __init__(self, pipe, marker):
        self.pipe = pipe
        # Kept as bytes because it is compared against the lines read from the pipe.
        self.marker = bytes(marker, "ascii")

    def close(self):
        """'Closes' the marked pipe by consuming everything up to the marker; following
        calls to ``readline`` will return ``b""``. The wrapped pipe itself is not closed."""
        leftover = self.readline()
        while leftover:
            leftover = self.readline()
        self.pipe = None

    def readline(self):
        """Reads the next line from the pipe; returns ``b""`` when the special marker is
        reached (and on every call after that).
        Raises ``EOFError`` if the underlying pipe has closed"""
        source = self.pipe
        if source is None:
            # Marker already seen (or close() was called): keep reporting EOF.
            return b""
        data = source.readline()
        if not data:
            # An empty read means the real pipe ended before the marker showed up.
            raise EOFError()
        if data.strip() != self.marker:
            return data
        # Marker line reached: detach from the pipe so later reads short-circuit.
        self.pipe = None
        return b""


class SessionPopen(PopenAddons):
    """A shell-session-based ``Popen``-like object (has the following attributes: ``stdin``,
    ``stdout``, ``stderr``, ``returncode``)

    :param proc: The underlying shell process; it is polled for its exit code when one of
                 the pipes reports an unexpected EOF
    :param argv: The full command line that was sent to the shell, as encoded bytes
    :param isatty: If true, only ``stdout`` is read (stdout and stderr are unified)
    :param stdin: The stream that :meth:`communicate` writes its ``input`` to
    :param stdout: A pipe-like object whose ``readline()`` returns ``b""`` when the command's
                   output is over and raises ``EOFError`` if the underlying pipe has closed
    :param stderr: Same as ``stdout``, for the error stream
    :param encoding: The encoding used to decode the output and ``argv`` for error reports
    :param host: Passed through as ``host`` to the exceptions raised by :meth:`communicate`
    """

    def __init__(self, proc, argv, isatty, stdin, stdout, stderr, encoding, *, host):
        self.host = host
        self.proc = proc
        self.argv = argv
        self.isatty = isatty
        self.stdin = stdin
        self.stdout = stdout
        self.stderr = stderr
        self.custom_encoding = encoding
        self.returncode = None
        # Flipped to True by communicate() once the output has been fully consumed.
        self._done = False

    def poll(self):
        """Returns the process' exit code or ``None`` if it's still running"""
        return self.returncode if self._done else None

    def wait(self):
        """Waits for the process to terminate and returns its exit code"""
        self.communicate()
        return self.returncode

    def communicate(self, input=None):  # pylint: disable=redefined-builtin
        """Consumes the process' stdout and stderr until it terminates.

        The last line collected from stdout is parsed as the exit code and stored in
        ``returncode``; if it is missing or not an integer, ``returncode`` is set to
        the string ``"Unknown"``.

        :param input: An optional bytes/buffer object to send to the process over stdin
        :returns: A tuple of (stdout, stderr), both as bytes
        :raises IncorrectLogin: a pipe closed and the shell process exited with code 5
        :raises HostPublicKeyUnknown: a pipe closed and the shell process exited with code 6
        :raises SSHCommsChannel2Error: the stderr pipe closed while the exit code is 0
        :raises SSHCommsError: a pipe closed in any other situation
        """
        stdout = []
        stderr = []
        sources = [("1", stdout, self.stdout)]
        if not self.isatty:
            # in tty mode, stdout and stderr are unified
            sources.append(("2", stderr, self.stderr))
        i = 0
        while sources:
            if input:
                # Feed stdin in slices of 1000, one slice per line read back.
                chunk = input[:1000]
                self.stdin.write(chunk)
                self.stdin.flush()
                input = input[1000:]
            # Rotate through the sources that are still open, reading one line from each.
            i = (i + 1) % len(sources)
            name, coll, pipe = sources[i]
            try:
                line = pipe.readline()
                shell_logger.debug("%s> %r", name, line)
            except EOFError as err:
                shell_logger.debug("%s> Nothing returned.", name)

                self.proc.poll()
                returncode = self.proc.returncode
                # From here on every path raises, so the collected output is decoded
                # in place purely for the benefit of the error report.
                stdout = b"".join(stdout).decode(self.custom_encoding, "ignore")
                stderr = b"".join(stderr).decode(self.custom_encoding, "ignore")
                # Report only the part of the command line before the first ";".
                argv = self.argv.decode(self.custom_encoding, "ignore").split(";")[:1]

                # NOTE(review): 5 and 6 look like sshpass exit codes -- confirm that this
                # is where these values come from.
                if returncode == 5:
                    raise IncorrectLogin(
                        argv,
                        returncode,
                        stdout,
                        stderr,
                        message="Incorrect username or password provided",
                        host=self.host,
                    ) from None
                if returncode == 6:
                    raise HostPublicKeyUnknown(
                        argv,
                        returncode,
                        stdout,
                        stderr,
                        message="The authenticity of the host can't be established",
                        host=self.host,
                    ) from None
                if returncode != 0:
                    raise SSHCommsError(
                        argv,
                        returncode,
                        stdout,
                        stderr,
                        message="SSH communication failed",
                        host=self.host,
                    ) from None
                if name == "2":
                    raise SSHCommsChannel2Error(
                        argv,
                        returncode,
                        stdout,
                        stderr,
                        message="No stderr result detected. Does the remote have Bash as the default shell?",
                        host=self.host,
                    ) from None

                raise SSHCommsError(
                    argv,
                    returncode,
                    stdout,
                    stderr,
                    message="No communication channel detected. Does the remote exist?",
                    host=self.host,
                ) from err
            if not line:
                # This source reported EOF (its marker was reached): stop polling it.
                del sources[i]
            else:
                coll.append(line)
        # Guarded so that an empty capture falls through to the "Unknown" return code
        # below instead of escaping as a bare IndexError.
        if self.isatty and stdout:
            stdout.pop(0)  # discard first line of prompt
        try:
            # The last stdout line is expected to hold the command's exit status.
            self.returncode = int(stdout.pop(-1))
        except (IndexError, ValueError):
            self.returncode = "Unknown"
        self._done = True
        stdout = b"".join(stdout)
        stderr = b"".join(stderr)
        return stdout, stderr


class ShellSession:
    """An abstraction layer over *shell sessions*. A shell session is the execution of an
    interactive shell (``/bin/sh`` or something compatible), over which you may run commands
    (sent over stdin). The output is then read from stdout and stderr. Shell sessions are
    less "robust" than executing a process on its own, and they are susceptible to all sorts
    of malformatted-strings attacks, and there is little benefit from using them locally.
    However, they can greatly speed up remote connections, and are required for the implementation
    of :class:`SshMachine <plumbum.machines.remote.SshMachine>`, as they allow us to send multiple
    commands over a single SSH connection (setting up separate SSH connections incurs a high
    overhead). Try to avoid using shell sessions, unless you know what you're doing.

    Instances of this class may be used as *context-managers*.

    :param proc: The underlying shell process (with open stdin, stdout and stderr)
    :param encoding: The encoding to use for the shell session. If ``"auto"``, the underlying
                     process' encoding is used.
    :param isatty: If true, assume the shell has a TTY and that stdout and stderr are unified
    :param connect_timeout: The timeout to connect to the shell, after which, if no prompt
                            is seen, the shell process is killed
    :param host: Stored on the session and handed to every ``SessionPopen`` it creates
    """

    def __init__(
        self, proc, encoding="auto", isatty=False, connect_timeout=5, *, host=None
    ):
        self.host = host
        self.proc = proc
        self.custom_encoding = proc.custom_encoding if encoding == "auto" else encoding
        self.isatty = isatty
        self._lock = threading.RLock()
        self._current = None
        self._startup_result = None
        if connect_timeout:

            def closer():
                # Runs on the timer thread if the startup command below has not
                # completed within ``connect_timeout`` seconds.
                # ``%s`` (not ``%d``) so that fractional timeouts are logged untruncated.
                shell_logger.error(
                    "Connection to %s timed out (%s sec)", proc, connect_timeout
                )
                self.close()

            timer = threading.Timer(connect_timeout, closer)
            timer.start()
        try:
            # Run an empty command as a connectivity check and keep its result.
            self._startup_result = self.run("")
        finally:
            if connect_timeout:
                timer.cancel()

    def __enter__(self):
        return self

    def __exit__(self, t, v, tb):
        self.close()

    def __del__(self):
        # Best-effort cleanup: errors raised while closing are deliberately ignored.
        with contextlib.suppress(Exception):
            self.close()

    def alive(self):
        """Returns ``True`` if the underlying shell process is alive, ``False`` otherwise"""
        # Read the attribute once and coerce to bool, so that a closed session
        # (``proc`` is ``None``) yields ``False`` rather than ``None``.
        proc = self.proc
        return bool(proc) and proc.poll() is None

    def close(self):
        """Closes (terminates) the shell session. Does nothing if the session is not alive."""
        if not self.alive():
            return
        # Ask the shell to exit first; failures to write are ignored.
        with contextlib.suppress(ValueError, OSError):
            self.proc.stdin.write(b"\nexit\n\n\nexit\n\n")
            self.proc.stdin.flush()
            time.sleep(0.05)
        for p in (self.proc.stdin, self.proc.stdout, self.proc.stderr):
            with contextlib.suppress(Exception):
                p.close()
        with contextlib.suppress(OSError):
            self.proc.kill()
        self.proc = None

    def popen(self, cmd):
        """Runs the given command in the shell, adding some decoration around it. Only a single
        command can be executed at any given time.

        The decoration appends ``echo $?`` and an ``echo`` of a per-call marker line to stdout
        (and, unless in tty mode, of the same marker to stderr). An empty or blank command is
        replaced with ``true``.

        :param cmd: The command (string or :class:`Command <plumbum.commands.BaseCommand>` object)
                    to run
        :returns: A :class:`SessionPopen <plumbum.session.SessionPopen>` instance
        :raises ShellSessionError: if the session has been closed, or the previously started
                                   command has not finished yet
        """
        if self.proc is None:
            raise ShellSessionError("Shell session has already been closed")
        if self._current and not self._current._done:
            raise ShellSessionError("Each shell may start only one process at a time")

        full_cmd = cmd.formulate(1) if isinstance(cmd, BaseCommand) else cmd
        marker = f"--.END{time.time() * random.random()}.--"
        if full_cmd.strip():
            full_cmd += " ; "
        else:
            full_cmd = "true ; "
        full_cmd += f"echo $? ; echo '{marker}'"
        if not self.isatty:
            full_cmd += f" ; echo '{marker}' 1>&2"
        # NOTE(review): if ``custom_encoding`` is falsy the command stays a ``str`` and the
        # bytes concatenation in the write below raises TypeError.
        if self.custom_encoding:
            full_cmd = full_cmd.encode(self.custom_encoding)
        shell_logger.debug("Running %r", full_cmd)
        self.proc.stdin.write(full_cmd + b"\n")
        self.proc.stdin.flush()
        self._current = SessionPopen(
            self.proc,
            full_cmd,
            self.isatty,
            self.proc.stdin,
            MarkedPipe(self.proc.stdout, marker),
            MarkedPipe(self.proc.stderr, marker),
            self.custom_encoding,
            host=self.host,
        )
        return self._current

    def run(self, cmd, retcode=0):
        """Runs the given command, holding the session lock for the whole call

        :param cmd: The command (string or :class:`Command <plumbum.commands.BaseCommand>` object)
                    to run
        :param retcode: The expected return code (0 by default). Set to ``None`` in order to
                        ignore erroneous return codes
        :returns: A tuple of (return code, stdout, stderr)
        """
        with self._lock:
            return run_proc(self.popen(cmd), retcode)