# -*- coding: utf-8 -*-

# Copyright (C) 2010-2023 by Mike Gabriel <mike.gabriel@das-netzwerkteam.de>
#
# Python X2Go is free software; you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Python X2Go is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program; if not, write to the
# Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA.

"""\
Providing mechanisms to ``X2GoControlSession*`` backends for checking host validity.

"""
# module identifier string (its consumers are not visible in this file)
__NAME__ = 'x2gocheckhosts-pylib'

# NOTE(review): overriding __package__ / __name__ presumably keeps the explicit
# relative imports in this module (``from . import ...``) resolving against the
# ``x2go`` package and the module reporting itself as ``x2go.checkhosts`` even
# when it is loaded by other means -- confirm intent before touching.
__package__ = 'x2go'
__name__    = 'x2go.checkhosts'

# modules
import paramiko
import binascii
import sys

# Python X2Go modules
from . import log
from . import x2go_exceptions
import random
import string

class X2GoMissingHostKeyPolicy(paramiko.MissingHostKeyPolicy):
    """\
    Skeleton class for Python X2Go's missing host key policies.

    Subclasses implement ``missing_host_key()`` and store the Paramiko
    client, the server's hostname expression and the server's host key on
    ``self.client``, ``self.hostname`` and ``self.key``. The getter methods
    of this class expose that information (hostname, port, key type,
    fingerprint) for further inspection.

    """
    def __init__(self, caller=None, session_instance=None, fake_hostname=None):
        """\
        :param caller: calling instance
        :type caller: ``class``
        :param session_instance: an X2Go session instance
        :type session_instance: :class:`x2go.session.X2GoSession` instance
        :param fake_hostname: hostname expression that, if set, is reported by
            :func:`get_hostname()` instead of the one passed to
            ``missing_host_key()``
        :type fake_hostname: ``str``

        """
        self.caller = caller
        self.session_instance = session_instance
        self.fake_hostname = fake_hostname
        # Filled in by missing_host_key() of the subclasses. Initialised here so
        # that get_client() / get_key() return None (instead of raising
        # AttributeError) before a host key has been seen.
        self.client = None
        self.hostname = None
        self.key = None

    def get_client(self):
        """\
        Retrieve the Paramiko SSH/Client.


        :returns: the associated X2Go control session instance (``None`` if no
            host key has been handled yet).

        :rtype: ``X2GoControlSession*`` instance

        """
        return self.client

    def get_hostname(self):
        """\
        Retrieve the server hostname:port expression of the server to be validated.

        If a ``fake_hostname`` was given at construction time, it takes
        precedence over the hostname passed to ``missing_host_key()``.


        :returns: hostname:port

        :rtype: ``str``

        """
        return self.fake_hostname or self.hostname

    def _split_hostname(self):
        """\
        Split the expression returned by :func:`get_hostname()` into its host
        and port parts.

        Handles ``host``, ``host:port``, ``[host]``, ``[host]:port`` (where
        ``host`` may be an IPv6 literal containing colons) and unbracketed
        IPv6 literals (which carry no port).


        :returns: tuple ``(host, port)``; ``port`` is a ``str`` or ``None``
            if the expression does not specify one

        :rtype: ``tuple``

        """
        _hostname = self.get_hostname()
        if _hostname.startswith('['):
            # The host part may itself contain colons (IPv6), so split at the
            # closing bracket rather than at the first ':'.
            _host, _sep, _rest = _hostname[1:].partition(']')
            _port = _rest[1:] if _rest.startswith(':') else ''
            return _host, (_port or None)
        if _hostname.count(':') == 1:
            _host, _sep, _port = _hostname.partition(':')
            return _host, _port
        # No colon at all, or an unbracketed IPv6 literal (several colons).
        return _hostname, None

    def get_hostname_name(self):
        """\
        Retrieve the server hostname string of the server to be validated.


        :returns: hostname (without brackets and port)

        :rtype: ``str``

        """
        return self._split_hostname()[0]

    def get_hostname_port(self):
        """\
        Retrieve the server port of the server to be validated.


        :returns: port (22 if the hostname expression does not contain one)

        :rtype: ``int``

        """
        _port = self._split_hostname()[1]
        return int(_port) if _port else 22

    def get_key(self):
        """\
        Retrieve the host key of the server to be validated.


        :returns: host key

        :rtype: Paramiko/SSH key instance

        """
        return self.key

    def get_key_name(self):
        """\
        Retrieve the host key name of the server to be validated.


        :returns: host key name (RSA, DSA, ECDSA...)

        :rtype: ``str``

        """
        return self.key.get_name().upper()

    def get_key_fingerprint(self):
        """\
        Retrieve the host key fingerprint of the server to be validated.


        :returns: host key fingerprint (hex digits)

        :rtype: ``str``

        """
        if sys.version_info[0] >= 3:
            # hexlify() returns bytes on Python 3; callers expect text
            return binascii.hexlify(self.key.get_fingerprint()).decode()
        else:
            return binascii.hexlify(self.key.get_fingerprint())

    def get_key_fingerprint_with_colons(self):
        """\
        Retrieve the (colonized) host key fingerprint of the server
        to be validated, e.g. ``ab:cd:ef:...``.


        :returns: host key fingerprint (with colons)

        :rtype: ``str``

        """
        _fingerprint = self.get_key_fingerprint()
        # group the hex digits in pairs, separated by ':'
        return ':'.join(_fingerprint[idx:idx + 2] for idx in range(0, len(_fingerprint), 2))


class X2GoAutoAddPolicy(X2GoMissingHostKeyPolicy):
    """\
    Policy that accepts unknown host keys without asking.

    The key is added to the client's host keys (under the session profile
    ID if the associated control session uses unique host key aliases,
    otherwise under the server's hostname expression). The host keys file
    is then saved if the client has one configured.

    """

    def missing_host_key(self, client, hostname, key):
        """\
        Add the unknown host key to the client's host keys and persist them.

        :param client: SSH client (``X2GoControlSession*``) instance
        :type client: ``X2GoControlSession*`` instance
        :param hostname: remote hostname expression
        :type hostname: ``str``
        :param key: host key presented by the server
        :type key: Paramiko/SSH key instance

        """
        self.client = client
        self.hostname = hostname
        self.key = key
        if self.session_instance and self.session_instance.control_session.unique_hostkey_aliases:
            _alias = self.session_instance.get_profile_id()
        else:
            _alias = self.get_hostname()
        self.client._host_keys.add(_alias, self.key.get_name(), self.key)
        if self.client._host_keys_filename is not None:
            self.client.save_host_keys(self.client._host_keys_filename)
        # get_key_fingerprint() returns text on Python 3 as well; a bare
        # binascii.hexlify() result would be logged as b'...'
        self.client._log(paramiko.common.DEBUG, 'Adding %s host key for %s: %s' %
                         (self.key.get_name(), self.get_hostname(), self.get_key_fingerprint()))


class X2GoInteractiveAddPolicy(X2GoMissingHostKeyPolicy):
    """\
    Policy for making host key information available to Python X2Go after a
    Paramiko/SSH connect has been attempted. This class needs information
    about the associated :class:`x2go.session.X2GoSession` instance.

    Once called, the :func:`missing_host_key()` method of this class will try to call
    :func:`X2GoSession.HOOK_check_host_dialog() <x2go.session.X2GoSession.HOOK_check_host_dialog()>`. This hook method---if not re-defined
    in your application---will then try to call the :func:`X2GoClient.HOOK_check_host_dialog() <x2go.client.X2GoClient.HOOK_check_host_dialog()>`,
    which then will return ``True`` by default if not customized in your application.

    To accept host key checks, make sure to either customize the
    :func:`X2GoClient.HOOK_check_host_dialog() <x2go.client.X2GoClient.HOOK_check_host_dialog()>` method or the :func:`X2GoSession.HOOK_check_host_dialog() <x2go.session.X2GoSession.HOOK_check_host_dialog()>`
    method and hook some interactive user dialog to either of them.


    """
    def missing_host_key(self, client, hostname, key):
        """\
        Handle a missing host key situation.

        Once called, the :func:`missing_host_key()` method will try to call
        :func:`X2GoSession.HOOK_check_host_dialog() <x2go.session.X2GoSession.HOOK_check_host_dialog()>`. This hook method---if not re-defined
        in your application---will then try to call the :func:`X2GoClient.HOOK_check_host_dialog() <x2go.client.X2GoClient.HOOK_check_host_dialog()>`,
        which then will return ``True`` by default if not customized in your application.

        To accept host key checks, make sure to either customize the
        :func:`X2GoClient.HOOK_check_host_dialog() <x2go.client.X2GoClient.HOOK_check_host_dialog()>` method or the :func:`X2GoSession.HOOK_check_host_dialog() <x2go.session.X2GoSession.HOOK_check_host_dialog()>`
        method and hook some interactive user dialog to either of them.

        If this policy has no associated session instance, no dialog is
        shown. The host key information is only collected on this policy
        instance and an ``SSHException`` is raised so that the caller can
        inspect it via the getter methods.

        :param client: SSH client (``X2GoControlSession*``) instance
        :type client: ``X2GoControlSession*`` instance
        :param hostname: remote hostname
        :type hostname: ``str``
        :param key: host key to validate
        :type key: Paramiko/SSH key instance
        :raises X2GoHostKeyException: if the X2Go server host key is rejected (not accepted by the check host dialog)
        :raises X2GoSSHProxyHostKeyException: if the SSH proxy host key is rejected (not accepted by the check host dialog)
        :raises SSHException: if this policy has no associated session instance (host key information has been collected for introspection)

        """
        self.client = client
        self.hostname = hostname
        if (self.hostname.find(']') == -1) and (self.hostname.count(':') != 1):
            # A hostname without brackets and without a single "host:port"
            # colon is taken to use the standard SSH port. This covers plain
            # hostnames / IPv4 addresses (no colon) as well as unbracketed
            # IPv6 literals (several colons).
            self.hostname = '[%s]:22' % self.hostname
        self.key = key
        # get_key_fingerprint() returns text on Python 3 (hexlify() alone
        # would be logged as b'...')
        self.client._log(paramiko.common.DEBUG, 'Interactively Checking %s host key for %s: %s' %
                         (self.key.get_name(), self.get_hostname(), self.get_key_fingerprint()))

        if not self.session_instance:
            raise x2go_exceptions.SSHException('Policy has collected host key information on %s for further introspection' % self.get_hostname())

        if self.fake_hostname is not None:
            # SSH-proxied host: it may already be known under the address it
            # has behind the proxy (or under the session profile ID).
            server_key = client.get_transport().get_remote_server_key()
            keytype = server_key.get_name()
            our_server_key = client._system_host_keys.get(self.fake_hostname, {}).get(keytype, None)
            if our_server_key is None:
                if self.session_instance.control_session.unique_hostkey_aliases:
                    our_server_key = client._host_keys.get(self.session_instance.get_profile_id(), {}).get(keytype, None)
                    if our_server_key is not None:
                        self.session_instance.logger('SSH host key verification for SSH-proxied host %s with %s fingerprint ,,%s\'\' succeeded. This host is known by the X2Go session profile ID of profile ,,%s\'\'.' % (self.fake_hostname, self.get_key_name(), self.get_key_fingerprint_with_colons(), self.session_instance.profile_name), loglevel=log.loglevel_NOTICE)
                        return
                else:
                    our_server_key = client._host_keys.get(self.fake_hostname, {}).get(keytype, None)
                    if our_server_key is not None:
                        self.session_instance.logger('SSH host key verification for SSH-proxied host %s with %s fingerprint ,,%s\'\' succeeded. This host is known by the address it has behind the SSH proxy host.' % (self.fake_hostname, self.get_key_name(), self.get_key_fingerprint_with_colons()), loglevel=log.loglevel_NOTICE)
                        return

        self.session_instance.logger('SSH host key verification for host %s with %s fingerprint ,,%s\'\' initiated. We are seeing this X2Go server for the first time.' % (self.get_hostname(), self.get_key_name(), self.get_key_fingerprint_with_colons()), loglevel=log.loglevel_NOTICE)
        _valid = self.session_instance.HOOK_check_host_dialog(self.get_hostname_name(),
                                                              port=self.get_hostname_port(),
                                                              fingerprint=self.get_key_fingerprint_with_colons(),
                                                              fingerprint_type=self.get_key_name(),
                                                             )

        # Imported locally: x2go.sshproxy is only needed on this code path. A
        # function-level import also sidesteps a possible import cycle if
        # sshproxy imports this module (unverified).
        from . import sshproxy
        _caller_is_proxy = type(self.caller) in (sshproxy.X2GoSSHProxy, )

        if _valid:
            if self.session_instance.control_session.unique_hostkey_aliases and not _caller_is_proxy:
                paramiko.AutoAddPolicy().missing_host_key(client, self.session_instance.get_profile_id(), key)
            else:
                paramiko.AutoAddPolicy().missing_host_key(client, self.get_hostname(), key)
        elif _caller_is_proxy:
            raise x2go_exceptions.X2GoSSHProxyHostKeyException('Invalid host %s is not authorized for access. Add the host to Paramiko/SSH\'s known_hosts file.' % self.get_hostname())
        else:
            raise x2go_exceptions.X2GoHostKeyException('Invalid host %s is not authorized for access. Add the host to Paramiko/SSH\'s known_hosts file.' % self.get_hostname())


def check_ssh_host_key(x2go_sshclient_instance, hostname, port=22):
    """\
    Perform a Paramiko/SSH host key check by connecting to the host and
    validating the results (i.e. by validating raised exceptions during the
    connect process).

    The connect attempt uses a dummy user and a random throw-away password.
    If authentication fails, the host key was already known. If the
    :class:`X2GoInteractiveAddPolicy` signals that it collected host key
    information, the host is unknown and its fingerprint is returned. Any
    other error is ignored (best effort) and reported as "host not ok".

    :param x2go_sshclient_instance: a Paramiko/SSH client instance to be used for testing host key validity.
    :type x2go_sshclient_instance: ``X2GoControlSession*`` instance
    :param hostname: hostname of server to validate
    :type hostname: ``str``
    :param port: port of server to validate (Default value = 22)
    :type port: ``int``
    :returns: returns a tuple with the following components (<host_ok>, <hostname>, <port>, <fingerprint>, <fingerprint_type>)
    :rtype: ``tuple``
    :raises SSHException: if an SSH exception occurred that we did not provoke in :func:`X2GoInteractiveAddPolicy.missing_host_key() <x2go.checkhosts.X2GoInteractiveAddPolicy.missing_host_key()>`

    """
    _hostname = hostname
    _port = port
    _fingerprint = 'NO-FINGERPRINT'
    _fingerprint_type = 'SOME-KEY-TYPE'

    _check_policy = X2GoInteractiveAddPolicy()
    x2go_sshclient_instance.set_missing_host_key_policy(_check_policy)

    # Throw-away password: the login is expected to fail, we only want to
    # trigger the host key check. string.ascii_letters exists on Python 2 and 3
    # (string.letters is Python 2 only).
    _password = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(19))

    host_ok = False
    try:
        paramiko.SSHClient.connect(x2go_sshclient_instance, hostname=hostname, port=port, username='foo', password=_password)
    except x2go_exceptions.AuthenticationException:
        # the host key passed the check, only the bogus login was refused
        host_ok = True
        x2go_sshclient_instance.logger('SSH host key verification for host [%s]:%s succeeded. Host is already known to the client\'s Paramiko/SSH sub-system.' % (_hostname, _port), loglevel=log.loglevel_NOTICE)
    except x2go_exceptions.SSHException as e:
        if not str(e).startswith('Policy has collected host key information on '):
            raise
        # unknown host: report what the policy collected
        _hostname = _check_policy.get_hostname_name()
        _port = _check_policy.get_hostname_port()
        _fingerprint = _check_policy.get_key_fingerprint_with_colons()
        _fingerprint_type = _check_policy.get_key_name()
        x2go_sshclient_instance.set_missing_host_key_policy(paramiko.RejectPolicy())
    except Exception:
        # best effort: let any other error be handled by subsequent algorithms
        pass

    return (host_ok, _hostname, _port, _fingerprint, _fingerprint_type)
