File: base.py

package info (click to toggle)
pytest-testinfra 10.2.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 676 kB
  • sloc: python: 4,951; makefile: 152; sh: 2
file content (324 lines) | stat: -rw-r--r-- 9,550 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import dataclasses
import locale
import logging
import shlex
import subprocess
import urllib.parse
from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
    import testinfra.host

# Module logger named "testinfra"; BaseBackend.result() logs every command
# result through it at DEBUG level.
logger = logging.getLogger("testinfra")


@dataclasses.dataclass
class HostSpec:
    """Components of a ``[user[:password]@]host[:port]`` host specification.

    Produced by ``BaseBackend.parse_hostspec``, which percent-decodes the
    name, user and password but keeps the port as the raw string.
    """

    # Host name or address (an IPv6 literal is stored without its brackets).
    name: str
    # Port exactly as written in the spec (a string, not an int), or None.
    port: Optional[str]
    # Login user from the ``user@`` prefix, or None when absent.
    user: Optional[str]
    # Password from the ``user:password@`` prefix, or None when absent.
    password: Optional[str]


@dataclasses.dataclass
class CommandResult:
    """Object that encapsulates all returned details of the command execution.

    The captured output may be stored either as ``str`` or as ``bytes``;
    the ``stdout``/``stderr`` and ``stdout_bytes``/``stderr_bytes``
    properties convert on demand through the owning backend's ``decode`` /
    ``encode`` methods, so callers can ask for whichever form they need.

    Example:

    >>> cmd = host.run("ls -l /etc/passwd")
    >>> cmd.rc
    0
    >>> cmd.stdout
    '-rw-r--r-- 1 root root 1790 Feb 11 00:28 /etc/passwd\\n'
    >>> cmd.stderr
    ''
    >>> cmd.succeeded
    True
    >>> cmd.failed
    False
    """

    # Backend that ran the command; used to decode/encode the output.
    backend: "BaseBackend"
    exit_status: int
    # The command line as it was sent for execution.
    command: bytes
    _stdout: Union[str, bytes]
    _stderr: Union[str, bytes]

    @property
    def succeeded(self) -> bool:
        """Returns whether the command was successful (exit status 0)

        >>> host.run("true").succeeded
        True
        """
        return self.exit_status == 0

    @property
    def failed(self) -> bool:
        """Returns whether the command failed (non-zero exit status)

        >>> host.run("false").failed
        True
        """
        # Defined in terms of ``succeeded`` so the two can never disagree.
        return not self.succeeded

    @property
    def rc(self) -> int:
        """Gets the returncode of a command

        >>> host.run("true").rc
        0
        """
        return self.exit_status

    @property
    def stdout(self) -> str:
        """Gets standard output (stdout) stream of an executed command

        The exact message depends on the platform's ``mkdir``:

        >>> host.run("mkdir -v new_directory").stdout
        "mkdir: created directory 'new_directory'\\n"
        """
        if isinstance(self._stdout, bytes):
            return self.backend.decode(self._stdout)
        return self._stdout

    @property
    def stderr(self) -> str:
        """Gets standard error (stderr) stream of an executed command

        The exact message depends on the platform's ``mkdir``:

        >>> host.run("mkdir new_directory").stderr
        "mkdir: cannot create directory 'new_directory': File exists\\n"
        """
        if isinstance(self._stderr, bytes):
            return self.backend.decode(self._stderr)
        return self._stderr

    @property
    def stdout_bytes(self) -> bytes:
        """Gets standard output (stdout) stream of an executed command as bytes

        >>> host.run("mkdir -v new_directory").stdout_bytes
        b"mkdir: created directory 'new_directory'\\n"
        """
        if isinstance(self._stdout, str):
            return self.backend.encode(self._stdout)
        return self._stdout

    @property
    def stderr_bytes(self) -> bytes:
        """Gets standard error (stderr) stream of an executed command as bytes

        >>> host.run("mkdir new_directory").stderr_bytes
        b"mkdir: cannot create directory 'new_directory': File exists\\n"
        """
        if isinstance(self._stderr, str):
            return self.backend.encode(self._stderr)
        return self._stderr


class BaseBackend(metaclass=abc.ABCMeta):
    """Represent the connection to the remote or local system.

    Concrete backends implement :meth:`run` and set ``NAME`` (reported by
    :meth:`get_connection_type`). This base class provides the shared
    helpers: shell quoting, sudo wrapping, host-spec parsing, detection of
    the target's text encoding and construction of :class:`CommandResult`.
    """

    HAS_RUN_SALT = False
    HAS_RUN_ANSIBLE = False
    NAME: str

    def __init__(
        self,
        hostname: str,
        sudo: bool = False,
        sudo_user: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ):
        # Detected lazily by the ``encoding`` property.
        self._encoding: Optional[str] = None
        self._host: Optional[testinfra.host.Host] = None
        self.hostname = hostname
        self.sudo = sudo
        self.sudo_user = sudo_user
        # Extra ``args``/``kwargs`` are accepted but not used by the base
        # class (presumably consumed by subclasses).
        super().__init__()

    def set_host(self, host: "testinfra.host.Host") -> None:
        """Store a reference to the testinfra ``Host`` bound to this backend."""
        self._host = host

    @classmethod
    def get_connection_type(cls) -> str:
        """Return the connection backend used as string.

        Can be local, paramiko, ssh, docker, salt or ansible
        """
        return cls.NAME

    def get_hostname(self) -> str:
        """Return the hostname (for testinfra) of the remote or local system


        Can be useful for multi-hosts tests:

        Example:
        ::

            import requests


            def test(TestinfraBackend):
                host = TestinfraBackend.get_hostname()
                response = requests.get("http://" + host)
                assert response.status_code == 200


        ::

            $ testinfra --hosts=server1,server2 test.py

            test.py::test[paramiko://server1] PASSED
            test.py::test[paramiko://server2] PASSED
        """
        return self.hostname

    def get_pytest_id(self) -> str:
        """Return an identifier of the form ``<connection type>://<hostname>``."""
        return self.get_connection_type() + "://" + self.get_hostname()

    @classmethod
    def get_hosts(cls, host: Optional[str], **kwargs: Any) -> list[str]:
        """Return the list of hosts designated by ``host``.

        The base implementation accepts exactly one host and ignores
        ``kwargs``.

        Raises:
            RuntimeError: if ``host`` is None.
        """
        if host is None:
            raise RuntimeError(
                f"One or more hosts is required with the {cls.get_connection_type()} backend"
            )
        return [host]

    @staticmethod
    def quote(command: str, *args: str) -> str:
        """Interpolate shell-quoted ``args`` into a ``%``-style ``command``.

        Every argument is escaped with :func:`shlex.quote` before being
        substituted. Without arguments the command is returned untouched,
        so literal ``%`` characters only need escaping when args are given.

        >>> BaseBackend.quote("ls %s", "my file")
        "ls 'my file'"
        """
        if args:
            return command % tuple(shlex.quote(a) for a in args)
        return command

    def get_sudo_command(self, command: str, sudo_user: Optional[str]) -> str:
        """Wrap ``command`` in ``sudo /bin/sh -c``, as ``sudo_user`` if given."""
        if sudo_user is None:
            return self.quote("sudo /bin/sh -c %s", command)
        return self.quote("sudo -u %s /bin/sh -c %s", sudo_user, command)

    def get_command(self, command: str, *args: str) -> str:
        """Build the final command line: quote ``args``, then apply sudo if enabled."""
        command = self.quote(command, *args)
        if self.sudo:
            command = self.get_sudo_command(command, self.sudo_user)
        return command

    def run(self, command: str, *args: str, **kwargs: Any) -> CommandResult:
        """Execute ``command`` on the target system.

        Raises:
            NotImplementedError: always; concrete backends override this.
        """
        raise NotImplementedError

    def run_local(self, command: str, *args: str) -> CommandResult:
        """Run ``command`` (with shell-quoted ``args``) on the local machine.

        The command is executed through the local shell; only ``args`` are
        escaped, ``command`` itself is used as-is. Unlike
        :meth:`get_command`, no sudo wrapping is applied here.
        """
        command = self.quote(command, *args)
        cmd = self.encode(command)
        # stdin is a pipe closed by communicate(), so the child sees EOF
        # instead of inheriting our stdin. The context manager guarantees
        # the pipes are closed and the child reaped even if communicate()
        # is interrupted.
        with subprocess.Popen(
            cmd,
            shell=True,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ) as p:
            stdout, stderr = p.communicate()
        return self.result(p.returncode, cmd, stdout, stderr)

    @staticmethod
    def parse_hostspec(hostspec: str) -> HostSpec:
        """Parse a ``[user[:password]@]host[:port]`` specification.

        A literal IPv6 address must be enclosed in brackets, optionally
        followed by a port, e.g. ``[fe80:0::a:b:c]:80``. The host name, user
        and password are percent-decoded after splitting (so ``@`` or ``:``
        can be passed as ``%40``/``%3A``); the port is returned verbatim as
        a string.

        Raises:
            ValueError: if a bracketed IPv6 address has no closing ``]``.
        """
        name = hostspec
        port = None
        user = None
        password = None
        if "@" in name:
            user, name = name.split("@", 1)
            if ":" in user:
                user, password = user.split(":", 1)
        if name.startswith("["):
            # Bracketed IPv6 literal: everything up to the first ']' is the
            # address; an optional ':port' may follow it.
            address, closed, rest = name[1:].partition("]")
            if not closed:
                # ``name`` no longer contains user/password, so the message
                # cannot leak credentials.
                raise ValueError(
                    f"Unterminated IPv6 address in host spec: {name!r}"
                )
            name = address
            port = rest[1:] if rest.startswith(":") else None
        elif ":" in name:
            name, port = name.split(":", 1)
        name = urllib.parse.unquote(name)
        if user is not None:
            user = urllib.parse.unquote(user)
        if password is not None:
            password = urllib.parse.unquote(password)
        return HostSpec(name, port, user, password)

    @staticmethod
    def parse_containerspec(containerspec: str) -> tuple[str, Optional[str]]:
        """Split ``[user@]container`` into ``(container, user)``; user may be None."""
        name = containerspec
        user = None
        if "@" in name:
            user, name = name.split("@", 1)
        return name, user

    def get_encoding(self) -> str:
        """Detect the preferred text encoding of the target system.

        Asks ``python3``, then ``python``, on the target for
        ``locale.getpreferredencoding()``. If neither runs successfully and
        prints a value, the local machine's preferred encoding is used.
        ``ANSI_X3.4-1968`` (plain ASCII, seen when ``LANG`` is unset) is
        replaced by UTF-8.
        """
        encoding = None
        for python in ("python3", "python"):
            # NOTE(review): encoding=None presumably keeps run() from
            # consulting ``self.encoding``, which would recurse back into
            # this method — confirm in the concrete backends.
            cmd = self.run(
                "%s -c 'import locale;print(locale.getpreferredencoding())'",
                python,
                encoding=None,
            )
            if cmd.rc != 0:
                continue
            lines = cmd.stdout_bytes.splitlines()
            # An interpreter that exits 0 without printing anything must not
            # crash detection; try the next one or fall back below.
            if lines:
                encoding = lines[0].decode("ascii")
                break
        # Python is not installed, we hope the encoding to be the same as
        # local machine...
        if not encoding:
            encoding = locale.getpreferredencoding()
        if encoding == "ANSI_X3.4-1968":
            # Workaround default encoding ascii without LANG set
            encoding = "UTF-8"
        return encoding

    @property
    def encoding(self) -> str:
        """Target encoding, detected on first access and then cached."""
        if self._encoding is None:
            self._encoding = self.get_encoding()
        return self._encoding

    def decode(self, data: bytes) -> str:
        """Decode output from the target: ASCII first, else the target encoding.

        The ASCII fast path avoids triggering encoding detection (which runs
        a command on the target) for plain-ASCII data.
        """
        try:
            return data.decode("ascii")
        except UnicodeDecodeError:
            return data.decode(self.encoding)

    def encode(self, data: str) -> bytes:
        """Encode text for the target: ASCII first, else the target encoding.

        The ASCII fast path avoids triggering encoding detection for
        plain-ASCII data.
        """
        try:
            return data.encode("ascii")
        except UnicodeEncodeError:
            return data.encode(self.encoding)

    def result(
        self, rc: int, cmd: bytes, stdout: Union[str, bytes], stderr: Union[str, bytes]
    ) -> CommandResult:
        """Wrap a command's exit status and output in a :class:`CommandResult`.

        The result is also logged at DEBUG level on the "testinfra" logger.
        """
        result = CommandResult(
            backend=self,
            exit_status=rc,
            command=cmd,
            _stdout=stdout,
            _stderr=stderr,
        )
        logger.debug("RUN %s", result)
        return result