#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>


import os
import subprocess
import traceback
from collections.abc import Iterator, Sequence
from contextlib import suppress
from typing import Any

from kitty.types import run_once
from kitty.utils import SSHConnectionData


@run_once
def ssh_options() -> dict[str, str]:
    """Return the short options understood by the local ``ssh`` program.

    ``ssh`` is run without arguments and the usage text it prints on stderr
    is parsed (see _parse_ssh_usage). Keys are option letters; the value is
    the option's argument description, or ``''`` for boolean flags. If
    ``ssh`` cannot be executed at all, a built-in table is returned instead.
    """
    try:
        p = subprocess.run(['ssh'], stderr=subprocess.PIPE, encoding='utf-8', errors='replace')
        raw = p.stderr or ''
    except OSError:
        # ssh is missing or not executable: fall back to a static option table
        return {
            '4': '', '6': '', 'A': '', 'a': '', 'C': '', 'f': '', 'G': '', 'g': '', 'K': '', 'k': '',
            'M': '', 'N': '', 'n': '', 'q': '', 's': '', 'T': '', 't': '', 'V': '', 'v': '', 'X': '',
            'x': '', 'Y': '', 'y': '', 'B': 'bind_interface', 'b': 'bind_address', 'c': 'cipher_spec',
            'D': '[bind_address:]port', 'E': 'log_file', 'e': 'escape_char', 'F': 'configfile', 'I': 'pkcs11',
            'i': 'identity_file', 'J': '[user@]host[:port]', 'L': 'address', 'l': 'login_name', 'm': 'mac_spec',
            'O': 'ctl_cmd', 'o': 'option', 'p': 'port', 'Q': 'query_option', 'R': 'address',
            'S': 'ctl_path', 'W': 'host:port', 'w': 'local_tun[:remote_tun]'
        }
    return _parse_ssh_usage(raw)


def _parse_ssh_usage(raw: str) -> dict[str, str]:
    """Parse the bracketed option groups out of ssh's usage text.

    A group ``[-o desc]`` maps ``o`` to ``desc`` (an option taking a value);
    a group ``[-abc]`` maps each of ``a``, ``b`` and ``c`` to ``''`` (boolean
    flags). Nested brackets are skipped together with their enclosing group,
    and parsing stops at an unbalanced ``[``.
    """
    ans: dict[str, str] = {}
    pos = 0
    while True:
        pos = raw.find('[', pos)  # ]
        if pos < 0:
            break
        depth = 1
        epos = pos
        while depth > 0:
            epos += 1
            if epos >= len(raw):
                # unbalanced "[": nothing after it can be parsed reliably
                return ans
            ch = raw[epos]
            if ch == '[':
                depth += 1
            elif ch == ']':
                depth -= 1
        q = raw[pos+1:epos]
        pos = epos
        if len(q) < 2 or q[0] != '-':
            continue
        if ' ' in q:
            opt, desc = q.split(' ', 1)
            ans[opt[1:]] = desc
        else:
            ans.update(dict.fromkeys(q[1:], ''))
    return ans


def is_kitten_cmdline(q: Sequence[str]) -> bool:
    """Return True if the command line *q* runs the ssh kitten.

    Recognized forms (the executable is matched on its lower-cased basename):
    ``kitten ssh ...``, ``kitty +kitten ssh ...``, ``kitty + kitten ssh ...``
    and ``kitty +runpy 'from kittens.runner import main; main()' ...`` with
    ``ssh`` as the sixth element. The ``kitty`` forms need at least four
    elements.
    """
    if not q:
        return False
    # Normalize to a list: the slice comparisons below are against lists, and
    # a tuple (or other Sequence) slice never compares equal to a list.
    q = list(q)
    exe_name = os.path.basename(q[0]).lower()
    if exe_name == 'kitten' and q[1:2] == ['ssh']:
        return True
    if len(q) < 4:
        return False
    if exe_name != 'kitty':
        return False
    if q[1:3] == ['+kitten', 'ssh'] or q[1:4] == ['+', 'kitten', 'ssh']:
        return True
    return q[1:3] == ['+runpy', 'from kittens.runner import main; main()'] and len(q) >= 6 and q[5] == 'ssh'


def patch_cmdline(key: str, val: str, argv: list[str]) -> None:
    """Set the ssh kitten option *key* to *val* in *argv*, in place.

    The first existing setting is rewritten: either a ``--kitten=key=...``
    argument, or a ``key=...`` / ``key ...`` value that directly follows a
    bare ``--kitten``. If there is none, ``--kitten=key=val`` is inserted
    right after the ``ssh`` element (ValueError if ``ssh`` is absent).
    """
    joined_prefix = f'--kitten={key}='
    separate_prefixes = (f'{key}=', f'{key} ')
    previous = None
    for position, item in enumerate(argv):
        if item.startswith(joined_prefix):
            argv[position] = f'--kitten={key}={val}'
            return
        if previous == '--kitten' and item.startswith(separate_prefixes):
            argv[position] = f'{key}={val}'
            return
        previous = item
    argv.insert(argv.index('ssh') + 1, f'--kitten={key}={val}')


def remove_env_var_from_cmdline(key: str, argv: list[str]) -> None:
    while True:
        for i, arg in enumerate(tuple(argv)):
            if arg.startswith(f'--kitten=env={key}='):
                del argv[i]
                break
            elif i > 0 and argv[i-1] == '--kitten' and (arg.startswith(f'env={key}=') or arg.startswith(f'env {key}=')):
                del argv[i-1:i+1]
                break
        else:
            break


def set_single_env_var_in_cmdline(key: str, val: str, argv: list[str]) -> None:
    """Make *argv* set env var *key* to *val* exactly once, in place.

    Any existing settings of *key* are removed first, then
    ``--kitten=env=KEY=VAL`` is placed directly after the ``ssh`` element.
    """
    remove_env_var_from_cmdline(key, argv)
    insert_at = argv.index('ssh') + 1
    argv[insert_at:insert_at] = [f'--kitten=env={key}={val}']


def set_cwd_in_cmdline(cwd: str, argv: list[str]) -> None:
    """Set the ssh kitten's ``cwd`` option to *cwd* by editing *argv* in place."""
    patch_cmdline(key='cwd', val=cwd, argv=argv)


def create_shared_memory(data: Any, prefix: str) -> str:
    """Serialize *data* as JSON into a new kitty shared-memory segment.

    The JSON is UTF-8 encoded and written with ``write_data_with_size`` into
    a segment of ``len(payload) + SharedMemory.num_bytes_for_size`` bytes
    (presumably a size header followed by the payload -- see kitty.shm).
    *prefix* is passed to SharedMemory as the name prefix. This is the format
    read back by read_data_from_shared_memory.

    Cleanup is deferred rather than done here: ``shm.close`` is registered
    with :mod:`atexit`, and the segment name is handed to the boss's
    ``atexit.shm_unlink`` (NOTE(review): presumably unlinks it when kitty
    exits -- confirm in the boss code).

    Returns the name of the shared-memory segment.
    """
    import atexit
    import json

    from kitty.fast_data_types import get_boss
    from kitty.shm import SharedMemory
    db = json.dumps(data).encode('utf-8')
    with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, prefix=prefix) as shm:
        shm.write_data_with_size(db)
        shm.flush()
        atexit.register(shm.close)  # keeps shm alive till exit
        # deferred unlink so a reader can still open the segment by name
        get_boss().atexit.shm_unlink(shm.name)
    return shm.name


def read_data_from_shared_memory(shm_name: str) -> Any:
    """Load and return the JSON payload stored in shared memory *shm_name*.

    The segment is opened read-only and unlinked immediately, so its name
    cannot be opened again afterwards. Before the data is trusted, the
    segment must be owned by this process's effective uid and gid and have
    permission bits of exactly owner read/write (0o600); otherwise
    ValueError is raised.
    """
    import json
    import stat

    from kitty.shm import SharedMemory
    with SharedMemory(shm_name, readonly=True) as segment:
        segment.unlink()
        owned_by_us = segment.stats.st_uid == os.geteuid() and segment.stats.st_gid == os.getegid()
        if not owned_by_us:
            raise ValueError(f'Incorrect owner on pwfile: uid={segment.stats.st_uid} gid={segment.stats.st_gid}')
        perms = stat.S_IMODE(segment.stats.st_mode)
        if perms != stat.S_IREAD | stat.S_IWRITE:
            raise ValueError(f'Incorrect permissions on pwfile: 0o{perms:03o}')
        return json.loads(segment.read_data_with_size())


def get_ssh_data(msgb: memoryview, request_id: str) -> Iterator[bytes|memoryview]:
    from base64 import standard_b64decode
    yield b'\nKITTY_DATA_START\n'  # to discard leading data
    try:
        msg = standard_b64decode(msgb).decode('utf-8')
        md = dict(x.split('=', 1) for x in msg.split(':'))
        pw = md['pw']
        pwfilename = md['pwfile']
        rq_id = md['id']
    except Exception:
        traceback.print_exc()
        yield b'invalid ssh data request message\n'
    else:
        try:
            env_data = read_data_from_shared_memory(pwfilename)
            if pw != env_data['pw']:
                raise ValueError('Incorrect password')
            if rq_id != request_id:
                raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window')
        except Exception as e:
            traceback.print_exc()
            yield f'{e}\n'.encode()
        else:
            yield b'OK\n'
            encoded_data = memoryview(env_data['tarfile'].encode('ascii'))
            # macOS has a 255 byte limit on its input queue as per man stty.
            # Not clear if that applies to canonical mode input as well, but
            # better to be safe.
            line_sz = 254
            while encoded_data:
                yield encoded_data[:line_sz]
                yield b'\n'
                encoded_data = encoded_data[line_sz:]
            yield b'KITTY_DATA_END\n'


def set_env_in_cmdline(env: dict[str, str], argv: list[str], clone: bool = True) -> None:
    """Pass the environment *env* to the ssh kitten by editing *argv* in place.

    With *clone* (the default), *env* is stored as JSON in a shared-memory
    segment (name prefix ``ksse-``) and the segment name is set as the
    ``clone_env`` kitten option via patch_cmdline.

    Otherwise one ``--kitten=env=KEY=VALUE`` argument is added per entry.
    Entries whose value is ``DELETE_ENV_VAR`` become ``--kitten=env=KEY``
    (NOTE(review): presumably "unset KEY" -- confirm in the kitten's option
    handling). The new arguments are inserted after the last ``--kitten``
    option that follows ``ssh``, or directly after ``ssh`` if there is none.
    """
    from kitty.options.utils import DELETE_ENV_VAR
    if clone:
        patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv)
        return
    # idx ends at the last argv element belonging to a --kitten option (the
    # option itself for --kitten=..., its value for a bare --kitten), or at
    # 'ssh' itself, so insertion at idx+1 can never land before 'ssh'.
    idx = argv.index('ssh')
    for i in range(idx + 1, len(argv)):
        if argv[i] == '--kitten':
            idx = i + 1
        elif argv[i].startswith('--kitten='):
            idx = i
    env_dirs = [
        f'--kitten=env={k}' if v is DELETE_ENV_VAR else f'--kitten=env={k}={v}'
        for k, v in env.items()
    ]
    argv[idx+1:idx+1] = env_dirs


def get_ssh_cli() -> tuple[set[str], set[str]]:
    """Split the local ssh's short options into two sets of ``-X`` strings.

    Returns ``(boolean_flags, value_options)``. An option belongs to the
    second set when ssh_options() gives it a non-empty argument description.
    """
    options = ssh_options()
    boolean_ssh_args = {f'-{letter}' for letter, desc in options.items() if not desc}
    other_ssh_args = {f'-{letter}' for letter, desc in options.items() if desc}
    return boolean_ssh_args, other_ssh_args


def is_extra_arg(arg: str, extra_args: tuple[str, ...]) -> str:
    """Return the first option in *extra_args* that *arg* spells, else ``''``.

    *arg* matches an option either exactly (``--opt``) or in its
    ``--opt=value`` form.
    """
    return next((opt for opt in extra_args if arg == opt or arg.startswith(f'{opt}=')), '')


passthrough_args = {'-' + flag for flag in 'NnfGT'}  # the ssh flags -N -n -f -G -T


def set_server_args_in_cmdline(
    server_args: list[str], argv: list[str],
    extra_args: tuple[str, ...] = ('--kitten',),
    allocate_tty: bool = False
) -> None:
    """Replace the remote command of an ssh command line, in place.

    The arguments after the ``ssh`` element of *argv* are scanned using the
    local ssh's option table (get_ssh_cli) plus the long options in
    *extra_args* (``--opt value`` or ``--opt=value``). Everything after the
    host (or after ``-- host``) is dropped and *server_args* appended. When
    *allocate_tty* is true, ``-t`` is inserted just before the host (or the
    ``--``) unless the argument already there is ``-t``.

    Raises KeyError for a short option unknown to the local ssh.
    """
    boolean_ssh_args, other_ssh_args = get_ssh_cli()
    expecting_option_val = False
    ans = list(argv)
    found_ssh = False
    for i, argument in enumerate(argv):
        if not found_ssh:
            found_ssh = argument == 'ssh'
            continue
        if argument.startswith('-') and not expecting_option_val:
            if argument == '--':
                # keep "--" and the host following it, drop the old command
                del ans[i+2:]
                if allocate_tty and ans[i-1] != '-t':
                    ans.insert(i, '-t')
                break
            if extra_args and is_extra_arg(argument, extra_args):
                # --opt=value is self-contained; a bare --opt consumes the next argument
                if '=' not in argument:
                    expecting_option_val = True
                continue
            # one or more clustered short options, e.g. -tvp22
            cluster = argument[1:]
            for j, ch in enumerate(cluster):
                opt = f'-{ch}'
                if opt in boolean_ssh_args:
                    continue
                if opt in other_ssh_args:
                    # the value is the rest of the cluster, or else the next argument
                    if not cluster[j+1:]:
                        expecting_option_val = True
                    break
                raise KeyError(f'unknown option -- {ch}')
            continue
        if expecting_option_val:
            # this argument is the value of the preceding option
            expecting_option_val = False
            continue
        # first positional argument is the host; drop the old remote command
        del ans[i+1:]
        # check the argument before the host (as in the "--" branch), not the host itself
        if allocate_tty and ans[i-1] != '-t':
            ans.insert(i, '-t')
        break
    argv[:] = ans + server_args


def get_connection_data(args: list[str], cwd: str = '', extra_args: tuple[str, ...] = ()) -> SSHConnectionData | None:
    """Extract connection details from an ssh command line.

    *args* is scanned for the first element whose lower-cased basename is
    ``ssh`` or ``ssh.exe``. The arguments after it are parsed for the host,
    the port (``-p``), the identity file (``-i``) and any long options listed
    in *extra_args* (``--opt value`` or ``--opt=value``). Parsing stops at
    the remote command: the first non-option word after the host, or
    whatever follows ``-- host``.

    The host may be an ``ssh://[user@]host[:port]`` URL; its port is used
    only when no ``-p`` was given. A relative identity file is ``~``-expanded
    and then resolved against *cwd* (default: the current directory).

    Returns None when no host is found.
    """
    boolean_ssh_args, _ = get_ssh_cli()
    port: int | None = None
    expecting_port = expecting_identity = False
    expecting_option_val = False
    expecting_hostname = False
    expecting_extra_val = ''
    host_name = identity_file = found_ssh = ''
    found_extra_args: list[tuple[str, str]] = []

    for arg in args:
        if not found_ssh:
            if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
                found_ssh = arg
            continue
        if expecting_hostname:
            # After "--" the next word is the host (unless one was already
            # given) and everything else is the remote command.
            if not host_name:
                host_name = arg
            break
        if arg.startswith('-') and not expecting_option_val:
            if arg in boolean_ssh_args:
                continue
            if arg == '--':
                expecting_hostname = True
                continue
            if arg.startswith('-p'):
                if arg[2:].isdigit():
                    with suppress(Exception):
                        port = int(arg[2:])
                    continue
                elif arg == '-p':
                    expecting_port = True
            elif arg.startswith('-i'):
                if arg == '-i':
                    expecting_identity = True
                else:
                    identity_file = arg[2:]
                    continue
            if arg.startswith('--') and extra_args:
                matching_ex = is_extra_arg(arg, extra_args)
                if matching_ex:
                    if '=' in arg:
                        found_extra_args.append((matching_ex, arg.partition('=')[-1]))
                        continue
                    expecting_extra_val = matching_ex

            expecting_option_val = True
            continue

        if expecting_option_val:
            if expecting_port:
                with suppress(Exception):
                    port = int(arg)
                expecting_port = False
            elif expecting_identity:
                identity_file = arg
                # reset, otherwise every later option value would overwrite identity_file
                expecting_identity = False
            elif expecting_extra_val:
                found_extra_args.append((expecting_extra_val, arg))
                expecting_extra_val = ''
            expecting_option_val = False
            continue

        if host_name:
            # First non-option word after the host starts the remote command;
            # options inside it (e.g. "grep -p 5") must not be parsed as ssh's.
            break
        host_name = arg
    if not host_name:
        return None
    if host_name.startswith('ssh://'):
        from urllib.parse import urlparse
        purl = urlparse(host_name)
        if purl.hostname:
            host_name = purl.hostname
        if purl.username:
            host_name = f'{purl.username}@{host_name}'
        if port is None and purl.port:
            port = purl.port
    if identity_file:
        if not os.path.isabs(identity_file):
            identity_file = os.path.expanduser(identity_file)
        if not os.path.isabs(identity_file):
            identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))

    return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))
