File: socket.py

package info (click to toggle)
pytest-testinfra 10.2.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 676 kB
  • sloc: python: 4,951; makefile: 152; sh: 2
file content (356 lines) | stat: -rw-r--r-- 12,450 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import socket
from typing import Optional

from testinfra.modules.base import Module


def parse_socketspec(socketspec):
    """Split a socket specification into ``(protocol, host, port)``.

    Accepted forms::

        unix:///path/to/socket   -> ("unix", "/path/to/socket", None)
        tcp://127.0.0.1:22       -> ("tcp", "127.0.0.1", 22)
        tcp://:::22              -> ("tcp", "::", 22)
        tcp://22                 -> ("tcp", None, 22)

    For tcp/udp the port is whatever follows the *last* colon, so IPv6
    hosts keep their own colons. A spec without a colon has no host.

    :raises ValueError: if ``socketspec`` contains no ``://`` separator
        (raised by tuple unpacking).
    :raises RuntimeError: if the protocol is not tcp, udp or unix, if the
        host is not a literal IPv4/IPv6 address, or if the port is not an
        integer.
    """
    protocol, address = socketspec.split("://", 1)

    if protocol not in ("udp", "tcp", "unix"):
        raise RuntimeError(
            f"Cannot validate protocol '{protocol}'. Should be tcp, udp or unix"
        )

    if protocol == "unix":
        # Everything after "unix://" is a filesystem path; there is no port.
        return protocol, address, None

    head, separator, port = address.rpartition(":")
    host = head if separator else None

    if host is not None:
        # Accept the host only if it parses as an IPv4 or IPv6 literal.
        for family in (socket.AF_INET, socket.AF_INET6):
            try:
                socket.inet_pton(family, host)
            except OSError:
                continue
            break
        else:
            raise RuntimeError(f"Cannot validate ip address '{host}'")

    try:
        port = int(port)
    except ValueError:
        raise RuntimeError(f"Cannot validate port '{port}'") from None

    return protocol, host, port


class Socket(Module):
    """Test listening tcp/udp and unix sockets

    ``socketspec`` must be specified as ``<protocol>://<host>:<port>``

    On Linux this module requires the ``ss`` or the ``netstat`` command on
    the target host; on BSD systems it requires ``netstat``.

    Example:

      - Unix sockets: ``unix:///var/run/docker.sock``
      - All ipv4 and ipv6 tcp sockets on port 22: ``tcp://22``
      - All ipv4 sockets on port 22: ``tcp://0.0.0.0:22``
      - All ipv6 sockets on port 22: ``tcp://:::22``
      - udp socket on 127.0.0.1 port 69: ``udp://127.0.0.1:69``

    """

    # Command used to list sockets. ``get_module_class`` sets it on a
    # generated Linux subclass; ``BSDSocket`` overrides it with a property.
    _command = None

    def __init__(self, socketspec):
        # ``socketspec`` may be None: ``get_listening_sockets`` builds such an
        # instance so that no protocol filter is applied.
        if socketspec is not None:
            self.protocol, self.host, self.port = parse_socketspec(socketspec)
        else:
            self.protocol = self.host = self.port = None
        super().__init__()

    @property
    def is_listening(self):
        """Test if socket is listening

        >>> host.socket("unix:///var/run/docker.sock").is_listening
        False
        >>> # This HTTP server listen on all ipv4 addresses but not on ipv6
        >>> host.socket("tcp://0.0.0.0:80").is_listening
        True
        >>> host.socket("tcp://127.0.0.1:80").is_listening
        True
        >>> host.socket("tcp://:::80").is_listening
        False
        >>> host.socket("tcp://80").is_listening
        False

        .. note:: A socket listening on the IPv6 wildcard ``::`` makes every
                  spec with the same protocol and port count as listening,
                  whatever host was given (presumably because a dual-stack
                  ``::`` socket also accepts IPv4 connections). Otherwise:
                  a spec without a host is not listening, an IPv4 host is
                  listening when ``0.0.0.0`` is, and any other host must
                  match a listening address exactly.
        """
        sockets = set(self._iter_sockets(True))
        if self.protocol == "unix":
            return ("unix", self.host) in sockets

        allipv4 = (self.protocol, "0.0.0.0", self.port) in sockets
        allipv6 = (self.protocol, "::", self.port) in sockets

        if allipv6:
            # Wildcard IPv6 listener counts for every host (see note).
            return True
        if self.host is None:
            # No host given and no "::" listener: not listening.
            return False
        if ":" not in self.host and allipv4:
            # Any IPv4 address is served by the IPv4 wildcard listener.
            return True
        return (self.protocol, self.host, self.port) in sockets

    @property
    def clients(self) -> list[Optional[tuple[str, int]]]:
        """Return a list of clients connected to a listening socket

        For tcp and udp sockets a list of pair (address, port) is returned.
        For unix sockets a list of None is returned (thus you can make a
        len() for counting clients).

        >>> host.socket("tcp://22").clients
        [('2001:db8:0:1', 44298), ('192.168.31.254', 34866)]
        >>> host.socket("unix:///var/run/docker.sock").clients
        [None, None, None]

        """
        sockets: list[Optional[tuple[str, int]]] = []
        for sock in self._iter_sockets(False):
            # Backends may report other protocols; keep only ours.
            if sock[0] != self.protocol:
                continue

            if self.protocol == "unix":
                if sock[1] == self.host:
                    sockets.append(None)
                continue

            if sock[2] != self.port:
                continue

            # sock[3] is the local address: wildcard specs match by family.
            if (
                self.host is None
                or (self.host == "0.0.0.0" and ":" not in sock[3])
                or (self.host == "::" and ":" in sock[3])
                or self.host == sock[3]
            ):
                sockets.append((sock[3], sock[4]))
        return sockets

    @classmethod
    def get_listening_sockets(cls):
        """Return a list of all listening sockets

        >>> host.socket.get_listening_sockets()
        ['tcp://0.0.0.0:22', 'tcp://:::22', 'unix:///run/systemd/private', ...]
        """
        sockets = []
        for sock in cls(None)._iter_sockets(True):
            if sock[0] == "unix":
                sockets.append("unix://" + sock[1])
            else:
                sockets.append(f"{sock[0]}://{sock[1]}:{sock[2]}")
        return sockets

    def _iter_sockets(self, listening):
        """Yield socket tuples; implemented by the platform subclasses.

        With ``listening`` true, items are ``(protocol, host, port)`` for
        tcp/udp and ``("unix", path)`` for unix sockets. Otherwise they are
        ``(protocol, local_host, local_port, remote_host, remote_port)`` or
        ``("unix", path)``. Sockets of other protocols may be included;
        callers filter on the first field.
        """
        raise NotImplementedError

    def __repr__(self):
        if self.protocol == "unix":
            # Unix sockets have no port; show the spec form unix://<path>.
            return f"<socket unix://{self.host}>"
        return "<socket {}://{}{}>".format(
            self.protocol,
            self.host + ":" if self.host else "",
            self.port,
        )

    @classmethod
    def get_module_class(cls, host):
        """Pick the Socket implementation for ``host``'s operating system.

        On Linux, prefer ``ss`` and fall back to ``netstat``, returning a
        subclass with ``_command`` bound to the command found. BSD systems
        use ``BSDSocket``.

        :raises RuntimeError: on Linux when neither ``ss`` nor ``netstat``
            is found.
        :raises NotImplementedError: on any other system type.
        """
        if host.system_info.type == "linux":
            for cmd, impl in (
                ("ss", LinuxSocketSS),
                ("netstat", LinuxSocketNetstat),
            ):
                try:
                    command = host.find_command(cmd)
                except ValueError:
                    pass
                else:
                    return type(impl.__name__, (impl,), {"_command": command})
            raise RuntimeError(
                'could not use the Socket module, either "ss" or "netstat"'
                " utility is required in $PATH"
            )
        if host.system_info.type.endswith("bsd"):
            return BSDSocket
        raise NotImplementedError


class LinuxSocketSS(Socket):
    """Linux socket backend that parses the output of ``ss``."""

    def _iter_sockets(self, listening):
        """Yield sockets parsed from ``ss --numeric`` output.

        :param listening: if true, run ``ss --listening`` and yield sockets
            in ``LISTEN`` or ``UNCONN`` state as ``(protocol, host, port)``
            (tcp/udp) or ``("unix", local_address)``. Otherwise run
            ``ss --all`` and yield sockets in ``ESTAB`` state as
            ``(protocol, host, port, remote_host, remote_port)`` (tcp/udp)
            or ``("unix", local_address)``.
        """
        cmd = "%s --numeric"
        if listening:
            cmd += " --listening"
        else:
            cmd += " --all"
        if self.protocol == "tcp":
            cmd += " --tcp"
        elif self.protocol == "udp":
            cmd += " --udp"
        elif self.protocol == "unix":
            cmd += " --unix"

        # The first line of output is the column header.
        for line in self.run(cmd, self._command).stdout_bytes.splitlines()[1:]:
            # Blank lines have no first column; skip them instead of crashing.
            if not line.strip():
                continue
            # Ignore unix datagram sockets.
            if line.split(None, 1)[0] == b"u_dgr":
                continue
            splitted = line.decode().split()

            # If listing only TCP or UDP sockets, output has 5 columns:
            # (State, Recv-Q, Send-Q, Local Address:Port, Peer Address:Port)
            if self.protocol in ("tcp", "udp"):
                protocol = self.protocol
                status, local, remote = (splitted[0], splitted[3], splitted[4])
            # If listing all or just unix sockets, output has 6 columns:
            # Netid, State, Recv-Q, Send-Q, LocalAddress:Port, PeerAddress:Port
            else:
                protocol, status, local, remote = (
                    splitted[0],
                    splitted[1],
                    splitted[4],
                    splitted[5],
                )

            # ss reports unix socket as u_str.
            if protocol == "u_str":
                protocol = "unix"
                host, port = local, None
            elif protocol in ("tcp", "udp"):
                host, port = local.rsplit(":", 1)
                port = int(port)
                # new versions of ss output ipv6 addresses enclosed in []
                if host and host[0] == "[" and host[-1] == "]":
                    host = host[1:-1]
            else:
                continue

            # UDP listening sockets may be in 'UNCONN' status.
            if listening and status in ("LISTEN", "UNCONN"):
                if host == "*" and protocol in ("tcp", "udp"):
                    # Wildcard bind: report it for both address families.
                    yield protocol, "::", port
                    yield protocol, "0.0.0.0", port
                elif protocol in ("tcp", "udp"):
                    yield protocol, host, port
                else:
                    yield protocol, host
            elif not listening and status == "ESTAB":
                if protocol in ("tcp", "udp"):
                    remote_host, remote_port = remote.rsplit(":", 1)
                    remote_port = int(remote_port)
                    yield protocol, host, port, remote_host, remote_port
                else:
                    # Report the local address (the socket path), as the
                    # listening branch above does: ``clients`` compares this
                    # field with the requested path.
                    yield protocol, host


class LinuxSocketNetstat(Socket):
    """Linux socket backend that parses the output of ``netstat -n``."""

    def _iter_sockets(self, listening):
        """Yield sockets parsed from ``netstat -n`` output.

        :param listening: if true, add ``-l`` and yield
            ``(protocol, host, port)`` for tcp/udp lines; otherwise yield
            ``(protocol, host, port, remote_host, remote_port)``. Unix lines
            are yielded as ``("unix", last_column)`` in both modes.
        """
        flags = ["%s", "-n"]
        if listening:
            flags.append("-l")
        protocol_flag = {"tcp": "-t", "udp": "-u", "unix": "--unix"}.get(
            self.protocol
        )
        if protocol_flag is not None:
            flags.append(protocol_flag)
        cmd = " ".join(flags)

        output = self.check_output(cmd, self._command)
        for raw in output.splitlines():
            # str.split() with no argument also splits on tabs.
            fields = raw.split()
            kind = fields[0]

            if kind == "unix":
                yield kind, fields[-1]
                continue
            if kind not in ("udp", "tcp", "tcp6", "udp6"):
                # Header lines and anything else netstat prints.
                continue

            # "tcp6" -> "tcp", "udp6" -> "udp"; plain names are unchanged.
            protocol = kind[:3]
            local_host, local_port = fields[3].rsplit(":", 1)
            local_port = int(local_port)

            if listening:
                yield protocol, local_host, local_port
            else:
                peer_host, peer_port = fields[4].rsplit(":", 1)
                yield protocol, local_host, local_port, peer_host, int(peer_port)


class BSDSocket(Socket):
    """BSD socket backend that parses the output of ``netstat -n``."""

    @functools.cached_property
    def _command(self):
        """Path of ``netstat`` on the target, resolved once per instance."""
        return self.find_command("netstat")

    def _iter_sockets(self, listening):
        """Yield sockets parsed from BSD ``netstat -n`` output.

        :param listening: if true, add ``-a`` and yield inet sockets whose
            remote address is ``*.*`` as ``(protocol, host, port)``;
            otherwise yield ``(protocol, host, port, remote_host,
            remote_port)``. Unix sockets are yielded as ``("unix", path)``
            when their inode column is non-zero (listening) or zero (not
            listening).
        """
        options = "%s -n"
        if listening:
            options += " -a"
        if self.protocol == "unix":
            options += " -f unix"

        output = self.check_output(options, self._command)
        for raw in output.splitlines():
            # str.split() with no argument also splits on tabs.
            fields = raw.split()
            head = fields[0]

            # FreeBSD: tcp4/tcp6
            # OpeNBSD/NetBSD: tcp/tcp6
            if head in ("tcp", "udp", "udp4", "tcp4", "tcp6", "udp6"):
                local = fields[3]
                if local == "*.*":
                    # On OpenBSD 6.3 (issue #338)
                    # udp          0      0  *.*                    *.*
                    # udp6         0      0  *.*                    *.*
                    continue

                # BSD netstat separates address and port with the last dot.
                host, port = local.rsplit(".", 1)
                port = int(port)
                if host == "*":
                    host = "::" if head.endswith("6") else "0.0.0.0"
                protocol = head[:3]

                peer = fields[4]
                if listening:
                    if peer == "*.*":
                        yield protocol, host, port
                else:
                    peer_host, peer_port = peer.rsplit(".", 1)
                    yield protocol, host, port, peer_host, int(peer_port)
            elif len(fields) == 9 and fields[1] in ("stream", "dgram"):
                has_inode = fields[4] != "0"
                if has_inode == bool(listening):
                    yield "unix", fields[-1]