1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
|
#!/usr/bin/env python3
import argparse
import array
import socket
import sys
from collections.abc import Iterable
# Characters that may separate the hostname from the vsock CID in the SSH
# destination (e.g. "virtme-ng%30" or "virtme-ng/30"), listed in the order in
# which they are tried.
#
# Important: Keep in sync with `VIRTME_SSH_HOSTNAME_CID_SEPARATORS`
SEPARATORS: tuple[str, ...] = ("%", "/")

# Exit statuses returned by main().
SUCCESS: int = 0
ERROR_USAGE: int = 2  # argparse also exits with 2 on command-line errors
def get_hostname_cid(ssh_dst: str) -> tuple[str | None, int | None]:
    """Split an SSH destination into its hostname and vsock CID.

    The destination has the form ``<hostname><sep><cid>`` where ``<sep>`` is
    one of ``SEPARATORS``. Separators are tried in the order they are listed;
    the first one present in *ssh_dst* is used, and the string is split at its
    LAST occurrence. The hostname may therefore contain the chosen separator
    (and any separator listed after it), but not one listed before it.

    Args:
        ssh_dst: SSH destination string, e.g. ``"virtme-ng/30"``.

    Returns:
        ``(hostname, cid)`` on success. On failure a diagnostic is printed to
        stderr and ``(None, None)`` is returned: either no separator is
        present, or the text after the chosen separator is not a base-10
        integer. In the latter case the next separator is NOT tried.

    >>> hostname, cid = get_hostname_cid("virtme-ng/24")
    >>> hostname
    'virtme-ng'
    >>> cid
    24
    >>> hostname, cid = get_hostname_cid("virtme/ng/21")
    >>> hostname
    'virtme/ng'
    >>> cid
    21
    >>> hostname, cid = get_hostname_cid("v/ng%20")
    >>> hostname
    'v/ng'
    >>> cid
    20
    >>> get_hostname_cid("virtme-ng")
    (None, None)
    """
    for sep in SEPARATORS:
        if sep not in ssh_dst:
            continue
        # Split at the last occurrence so the hostname may contain `sep`.
        hostname, _, cid_str = ssh_dst.rpartition(sep)
        try:
            cid = int(cid_str, 10)
        except ValueError:
            print(f"No integer given for CID: {cid_str}", file=sys.stderr)
            return (None, None)
        return hostname, cid
    print(
        f"SSH destination name includes no CID: {ssh_dst}. For example 'virtme-ng/30'",
        file=sys.stderr,
    )
    return (None, None)
# See https://docs.python.org/3/library/socket.html#socket.socket.sendmsg
def send_fds(sock: socket.socket, fds: Iterable[int]) -> None:
    """Pass file descriptors to the peer of *sock* as SCM_RIGHTS ancillary data.

    Args:
        sock: Connected Unix-domain socket to send on.
        fds: Descriptors to duplicate into the receiving process. The caller
            keeps its own copies and remains responsible for closing them.
    """
    fd_array = array.array("i", fds)
    ancdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fd_array)]
    # Ancillary data on a stream socket must accompany at least one byte of
    # regular data, hence the single NUL byte payload.
    # No flags are passed: MSG_CMSG_CLOEXEC only affects descriptors
    # *received* via recvmsg(), so it is the receiver's decision, not ours.
    sock.sendmsg([b"\0"], ancdata)
def passfds(cid: int, port: int) -> None:
    """Connect to a vsock endpoint and hand the connected socket out via fd 1.

    Opens an ``AF_VSOCK`` stream connection to (*cid*, *port*) and passes its
    descriptor over file descriptor 1 with ``send_fds``, so fd 1 must be a
    Unix-domain socket. NOTE(review): presumably this is the descriptor-passing
    channel set up by ssh's ``ProxyUseFdpass`` — confirm.

    Raises:
        OSError: if connecting, wrapping fd 1 as a socket, or sending fails.
    """
    with socket.socket(
        socket.AF_VSOCK, socket.SOCK_STREAM | socket.SOCK_CLOEXEC, 0
    ) as vsock:
        vsock.connect((cid, port))
        out = socket.socket(fileno=1)
        try:
            send_fds(out, [vsock.fileno()])
        finally:
            # fd 1 belongs to the process, not to this function: detach it so
            # the wrapper does not close stdout when it is garbage-collected.
            out.detach()
    # Leaving the `with` closes only our copy of the connection; the descriptor
    # passed via SCM_RIGHTS is a separate reference held by the receiver.
def main() -> int:
    """Parse the command line, connect to the vsock endpoint and pass the fd.

    Returns:
        ``SUCCESS`` once the descriptor has been handed over, ``ERROR_USAGE``
        if the SSH destination carries no valid CID, or 1 if connecting or
        passing the descriptor fails with an ``OSError``.
    """
    parser = argparse.ArgumentParser(
        prog="virtme-connect",
    )
    parser.add_argument("ssh_destination")
    parser.add_argument("-p", "--port", type=int, default=22)
    args = parser.parse_args()

    _hostname, cid = get_hostname_cid(args.ssh_destination)
    if cid is None:
        # get_hostname_cid() has already printed the reason to stderr.
        return ERROR_USAGE

    try:
        passfds(cid, args.port)
    except OSError as err:
        # Report a one-line diagnostic instead of a raw traceback; exit
        # status 1 matches what an uncaught exception would have produced.
        print(
            f"virtme-connect: cannot pass vsock connection "
            f"(CID {cid}, port {args.port}): {err}",
            file=sys.stderr,
        )
        return 1
    return SUCCESS
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|