File: virtme-ssh-proxy

package info (click to toggle)
virtme-ng 1.40-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 616 kB
  • sloc: python: 5,185; sh: 518; makefile: 34
file content (88 lines) | stat: -rwxr-xr-x 2,215 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python3
import argparse
import array
import socket
import sys
from collections.abc import Iterable

# Supported hostname and CID separators
#
# Important: Keep in sync with `VIRTME_SSH_HOSTNAME_CID_SEPARATORS`
SEPARATORS = ("%", "/")

# Some exit codes
SUCCESS = 0
ERROR_USAGE = 2


def get_hostname_cid(ssh_dst: str) -> tuple[str | None, int | None]:
    """Return the hostname and CID

    >>> hostname, cid = get_hostname_cid("virtme-ng/24")
    >>> hostname
    'virtme-ng'
    >>> cid
    24
    >>> hostname, cid = get_hostname_cid("virtme/ng/21")
    >>> hostname
    'virtme/ng'
    >>> cid
    21
    >>> hostname, cid = get_hostname_cid("v/ng,20")
    >>> hostname
    'v/ng'
    >>> cid
    20

    """
    for sep in SEPARATORS:
        if sep not in ssh_dst:
            continue

        splitted = ssh_dst.rsplit(sep, 1)
        try:
            hostname, cid = splitted[0], int(splitted[1], 10)
        except ValueError:
            print(f"No integer given for CID: {splitted[1]}", file=sys.stderr)
            return (None, None)
        return hostname, cid

    print(
        f"SSH destination name includes no CID: {ssh_dst}. For example 'virtme-ng/30'",
        file=sys.stderr,
    )
    return (None, None)


# See https://docs.python.org/3/library/socket.html#socket.socket.sendmsg
def send_fds(sock: socket.socket, fds: Iterable[int]) -> None:
    """Send the file descriptors *fds* over *sock* as SCM_RIGHTS ancillary data.

    A single NUL byte is sent as the regular payload alongside one
    ``SOL_SOCKET``/``SCM_RIGHTS`` control message carrying the descriptors
    packed as C ints.
    """
    packed_fds = array.array("i", fds)
    control_message = (socket.SOL_SOCKET, socket.SCM_RIGHTS, packed_fds)
    payload = [b"\0"]
    sock.sendmsg(payload, [control_message], socket.MSG_CMSG_CLOEXEC)


def passfds(cid: int, port: int) -> None:
    """Connect to ``(cid, port)`` over VSOCK and pass the connection to fd 1.

    A stream socket of family ``AF_VSOCK`` is connected to the given CID and
    port, and its file descriptor is then sent via ``send_fds`` over the
    socket found on file descriptor 1 (standard output).

    Both sockets are closed explicitly once the descriptor has been sent
    instead of being left to their finalizers.
    NOTE(review): presumably the receiver is ssh using ``ProxyUseFdpass``;
    confirm against the caller's SSH configuration.

    Raises ``OSError`` if the connection fails or fd 1 is not a socket.
    """
    with socket.socket(
        socket.AF_VSOCK, socket.SOCK_STREAM | socket.SOCK_CLOEXEC, 0
    ) as sock:
        sock.connect((cid, port))
        # Wrapping fd 1 hands ownership of it to the socket object, so it is
        # closed when this block exits.
        with socket.socket(fileno=1) as stdout_sock:
            send_fds(stdout_sock, [sock.fileno()])


def main() -> int:
    """Parse the command line and pass a VSOCK connection to the caller.

    Takes a positional SSH destination (hostname plus CID) and an optional
    ``-p``/``--port`` (default 22). Returns ``SUCCESS`` after the descriptor
    has been passed, or ``ERROR_USAGE`` when no CID could be extracted from
    the destination.
    """
    cli = argparse.ArgumentParser(prog="virtme-connect")
    cli.add_argument("ssh_destination")
    cli.add_argument("-p", "--port", type=int, default=22)
    options = cli.parse_args()

    # Only the CID matters here; the hostname part is ignored.
    cid = get_hostname_cid(options.ssh_destination)[1]
    if cid is not None:
        passfds(cid, options.port)
        return SUCCESS
    return ERROR_USAGE


# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())