# -*- coding: utf-8 -*-

# Copyright: (c) 2012, Jeroen Hoekx <jeroen@hoekx.be>
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

from __future__ import annotations


DOCUMENTATION = r"""
---
module: wait_for
short_description: Waits for a condition before continuing
description:
     - You can wait for a set amount of time O(timeout), this is the default if nothing is specified or just O(timeout) is specified.
       This does not produce an error.
     - Waiting for a port to become available is useful for when services are not immediately available after their init scripts return
       which is true of certain Java application servers.
     - It is also useful when starting guests with the M(community.libvirt.virt) module and needing to pause until they are ready.
     - This module can also be used to wait for a regular expression to match a string in a file.
     - In Ansible 1.6 and later, this module can also be used to wait for a file to be available or
       absent on the filesystem.
     - In Ansible 1.8 and later, this module can also be used to wait for active connections to be closed before continuing, useful if a node
       is being rotated out of a load balancer pool.
     - For Windows targets, use the M(ansible.windows.win_wait_for) module instead.
version_added: "0.7"
options:
  host:
    description:
      - A resolvable hostname or IP address to wait for.
    type: str
    default: 127.0.0.1
  timeout:
    description:
      - Maximum number of seconds to wait for, when used with another condition it will force an error.
      - When used without other conditions it is equivalent of just sleeping.
    type: int
    default: 300
  connect_timeout:
    description:
      - Maximum number of seconds to wait for a connection to happen before closing and retrying.
    type: int
    default: 5
  delay:
    description:
      - Number of seconds to wait before starting to poll.
    type: int
    default: 0
  port:
    description:
      - Port number to poll.
      - O(path) and O(port) are mutually exclusive parameters.
    type: int
  active_connection_states:
    description:
      - The list of TCP connection states which are counted as active connections.
    type: list
    elements: str
    default: [ ESTABLISHED, FIN_WAIT1, FIN_WAIT2, SYN_RECV, SYN_SENT, TIME_WAIT ]
    version_added: "2.3"
  state:
    description:
      - One of V(present), V(started), V(stopped), V(absent), or V(drained).
      - When checking a port V(started) will ensure the port is open, V(stopped) will check that it is closed, V(drained) will check for active connections.
      - When checking for a file or a search string V(present) or V(started) will ensure that the file or string is present before continuing,
        V(absent) will check that file is absent or removed.
    type: str
    choices: [ absent, drained, present, started, stopped ]
    default: started
  path:
    description:
      - Path to a file on the filesystem that must exist before continuing.
      - O(path) and O(port) are mutually exclusive parameters.
    type: path
    version_added: "1.4"
  search_regex:
    description:
      - Can be used to match a string in either a file or a socket connection.
      - Defaults to a multiline regex.
      - When inspecting a system log file and a static string, remember that Ansible by default logs its own actions there;
        see the notes and examples for information.
    type: str
    version_added: "1.4"
  exclude_hosts:
    description:
      - List of hosts or IPs to ignore when looking for active TCP connections for V(drained) state.
    type: list
    elements: str
    version_added: "1.8"
  sleep:
    description:
      - Number of seconds to sleep between checks.
      - Before Ansible 2.3 this was hardcoded to 1 second.
    type: int
    default: 1
    version_added: "2.3"
  msg:
    description:
      - This overrides the normal error message from a failure to meet the required conditions.
    type: str
    version_added: "2.4"
extends_documentation_fragment: action_common_attributes
attributes:
    check_mode:
        support: none
    diff_mode:
        support: none
    platform:
        platforms: posix
notes:
  - Under some circumstances when using mandatory access control, a path may always be treated as being absent even if it exists, but
    can't be modified or created by the remote user either.
  - When waiting for a path, symbolic links will be followed.  Many other modules that manipulate files do not follow symbolic links,
    so operations on the path using other modules may not work exactly as expected.
  - When searching a static string within a system log file, it is important to account for potential self-matching against log entries
    generated by the Ansible modules.  To prevent this, add a regular expression construct into the search string. For example, to match
    a literal string 'this thing', one could use a regular expression like 'this t[h]ing'.
seealso:
- module: ansible.builtin.wait_for_connection
- module: ansible.windows.win_wait_for
- module: community.windows.win_wait_for_process
author:
    - Jeroen Hoekx (@jhoekx)
    - John Jarvis (@jarv)
    - Andrii Radyk (@AnderEnder)
"""

EXAMPLES = r"""
- name: Sleep for 300 seconds and continue with play
  ansible.builtin.wait_for:
    timeout: 300
  delegate_to: localhost

- name: Wait for port 8000 to become open on the host, don't start checking for 10 seconds
  ansible.builtin.wait_for:
    port: 8000
    delay: 10

- name: Waits for port 8000 of any IP to close active connections, don't start checking for 10 seconds
  ansible.builtin.wait_for:
    host: 0.0.0.0
    port: 8000
    delay: 10
    state: drained

- name: Wait for port 8000 of any IP to close active connections, ignoring connections for specified hosts
  ansible.builtin.wait_for:
    host: 0.0.0.0
    port: 8000
    state: drained
    exclude_hosts: 10.2.1.2,10.2.1.3

- name: Wait until the file /tmp/foo is present before continuing
  ansible.builtin.wait_for:
    path: /tmp/foo

- name: Wait until the string "completed" is in the file /tmp/foo before continuing
  ansible.builtin.wait_for:
    path: /tmp/foo
    search_regex: completed

- name: Wait until the string "tomcat up" is in syslog, use regex character set to avoid self match
  ansible.builtin.wait_for:
    path: /var/log/syslog
    search_regex: 'tomcat [u]p'

- name: Wait until regex pattern matches in the file /tmp/foo and print the matched group
  ansible.builtin.wait_for:
    path: /tmp/foo
    search_regex: completed (?P<task>\w+)
  register: waitfor
- ansible.builtin.debug:
    msg: Completed {{ waitfor['match_groupdict']['task'] }}

- name: Wait until the lock file is removed
  ansible.builtin.wait_for:
    path: /var/lock/file.lock
    state: absent

- name: Wait until the process is finished and pid was destroyed
  ansible.builtin.wait_for:
    path: /proc/3466/status
    state: absent

- name: Output customized message when failed
  ansible.builtin.wait_for:
    path: /tmp/foo
    state: present
    msg: Timeout to find file /tmp/foo

# Do not assume the inventory_hostname is resolvable and delay 10 seconds at start
- name: Wait 300 seconds for port 22 to become open and contain "OpenSSH"
  ansible.builtin.wait_for:
    port: 22
    host: '{{ (ansible_ssh_host|default(ansible_host))|default(inventory_hostname) }}'
    search_regex: OpenSSH
    delay: 10
    timeout: 300
  delegate_to: localhost

# Same as above but using config lookup for the target,
# most plugins use 'remote_addr', but ssh uses 'host'
- name: Wait 300 seconds for port 22 to become open and contain "OpenSSH"
  ansible.builtin.wait_for:
    port: 22
    host: "{{ lookup('config', 'host', plugin_name='ssh', plugin_type='connection') }}"
    search_regex: OpenSSH
    delay: 10
    timeout: 300
  delegate_to: localhost
"""

RETURN = r"""
elapsed:
  description: The number of seconds that elapsed while waiting
  returned: always
  type: int
  sample: 23
match_groups:
  description: Tuple containing all the subgroups of the match as returned by U(https://docs.python.org/3/library/re.html#re.Match.groups)
  returned: always
  type: list
  sample: ['match 1', 'match 2']
match_groupdict:
  description: Dictionary containing all the named subgroups of the match, keyed by the subgroup name,
    as returned by U(https://docs.python.org/3/library/re.html#re.Match.groupdict)
  returned: always
  type: dict
  sample:
    {
      'group': 'match'
    }
"""

import binascii
import contextlib
import errno
import math
import mmap
import os
import re
import select
import socket
import time

from datetime import datetime, timedelta, timezone

from ansible.module_utils.basic import AnsibleModule, missing_required_lib
from ansible.module_utils.common.sys_info import get_platform_subclass
from ansible.module_utils.common.text.converters import to_bytes, to_native


# psutil is an optional dependency. The import outcome is recorded here
# instead of failing at import time, so that the generic TCPConnectionInfo
# strategy can report the missing library (with the original ImportError)
# only when it is actually instantiated.
HAS_PSUTIL = False
PSUTIL_IMP_ERR = None
try:
    import psutil
    HAS_PSUTIL = True
    # just because we can import it on Linux doesn't mean we will use it
except ImportError as ex:
    # keep the exception so it can be passed along to fail_json() later
    PSUTIL_IMP_ERR = ex


class TCPConnectionInfo(object):
    """
    Generic strategy for counting active TCP connections on a port.

    This implementation relies on the psutil module, which is not ideal
    for targets, but necessary for cross platform support. Creating an
    instance returns an object of the class chosen by
    get_platform_subclass(), so a platform specific subclass is used
    when one is available.

    A subclass may wish to override some or all of these methods.
      - _get_exclude_ips()
      - get_active_connections_count()

    All subclasses MUST define platform and distribution (which may be None).
    """
    platform = 'Generic'
    distribution = None

    # "any address" wildcard for each address family
    match_all_ips = {
        socket.AF_INET: '0.0.0.0',
        socket.AF_INET6: '::',
    }
    # textual markers used to recognise IPv4-mapped IPv6 local addresses
    ipv4_mapped_ipv6_address = {
        'prefix': '::ffff',
        'match_all': '::ffff:0.0.0.0'
    }

    def __new__(cls, *args, **kwargs):
        # Dispatch to the most specific subclass for the running platform.
        platform_cls = get_platform_subclass(TCPConnectionInfo)
        return super(cls, platform_cls).__new__(platform_cls)

    def __init__(self, module):
        """Resolve the host/exclude parameters of ``module`` and require psutil."""
        self.module = module
        self.ips = _convert_host_to_ip(module.params['host'])
        self.port = int(self.module.params['port'])
        self.exclude_ips = self._get_exclude_ips()
        if HAS_PSUTIL:
            return
        module.fail_json(msg=missing_required_lib('psutil'), exception=PSUTIL_IMP_ERR)

    def _get_exclude_ips(self):
        """Return the (family, ip) tuples for every host in ``exclude_hosts``."""
        hosts = self.module.params['exclude_hosts']
        if hosts is None:
            return []
        resolved = []
        for name in hosts:
            resolved.extend(_convert_host_to_ip(name))
        return resolved

    def get_active_connections_count(self):
        """Count connections on the watched port that are in an active state.

        A connection is counted when its status is listed in
        ``active_connection_states``, its local port equals the watched port,
        its remote address is not excluded, and its local address matches the
        requested host (exactly, through the wildcard address, or through the
        IPv4-mapped IPv6 wildcard).
        """
        mapped_prefix = self.ipv4_mapped_ipv6_address['prefix']
        mapped_match_all = self.ipv4_mapped_ipv6_address['match_all']

        total = 0
        for proc in psutil.process_iter():
            try:
                # older psutil releases expose get_connections() instead
                if hasattr(proc, 'get_connections'):
                    proc_connections = proc.get_connections(kind='inet')
                else:
                    proc_connections = proc.connections(kind='inet')
            except psutil.Error:
                # Process is Zombie or other error state
                continue

            for conn in proc_connections:
                if conn.status not in self.module.params['active_connection_states']:
                    continue

                # attribute names differ between psutil versions
                local_endpoint = conn.local_address if hasattr(conn, 'local_address') else conn.laddr
                (local_ip, local_port) = local_endpoint
                if self.port != local_port:
                    continue

                remote_endpoint = conn.remote_address if hasattr(conn, 'remote_address') else conn.raddr
                (remote_ip, remote_port) = remote_endpoint
                if (conn.family, remote_ip) in self.exclude_ips:
                    continue

                exact_match = (conn.family, local_ip) in self.ips
                wildcard_match = (conn.family, self.match_all_ips[conn.family]) in self.ips
                mapped_match = (
                    local_ip.startswith(mapped_prefix) and
                    (conn.family, mapped_match_all) in self.ips
                )
                if exact_match or wildcard_match or mapped_match:
                    total += 1
        return total


# ===========================================
# Subclass: Linux

class LinuxTCPConnectionInfo(TCPConnectionInfo):
    """
    This is a TCP Connection Info evaluation strategy class
    that utilizes information from Linux's procfs. While less universal,
    does allow Linux targets to not require an additional library.

    Hosts and the port are stored in the hexadecimal text form produced by
    _convert_host_to_hex() and "%0.4X" so that they can be compared directly
    with the columns read from the proc files.
    """
    platform = 'Linux'
    distribution = None

    # proc file to parse for each address family
    source_file = {
        socket.AF_INET: '/proc/net/tcp',
        socket.AF_INET6: '/proc/net/tcp6'
    }
    # "any address" wildcard for each family, in the proc files' hex form
    match_all_ips = {
        socket.AF_INET: '00000000',
        socket.AF_INET6: '00000000000000000000000000000000',
    }
    # hex markers used to recognise IPv4-mapped IPv6 local addresses
    ipv4_mapped_ipv6_address = {
        'prefix': '0000000000000000FFFF0000',
        'match_all': '0000000000000000FFFF000000000000'
    }
    # indexes of the whitespace separated columns of interest
    local_address_field = 1
    remote_address_field = 2
    connection_state_field = 3

    def __init__(self, module):
        """Convert the host, port and excluded hosts of ``module`` to hex form."""
        self.module = module
        self.ips = _convert_host_to_hex(module.params['host'])
        self.port = "%0.4X" % int(module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        """Return (family, hex ip) tuples for every host in ``exclude_hosts``."""
        exclude_hosts = self.module.params['exclude_hosts']
        exclude_ips = []
        if exclude_hosts is not None:
            for host in exclude_hosts:
                exclude_ips.extend(_convert_host_to_hex(host))
        return exclude_ips

    def get_active_connections_count(self):
        """Count connections on the watched port that are in an active state.

        Each existing proc file in ``source_file`` is scanned line by line.
        A connection is counted when its state is one of
        ``active_connection_states``, its local port equals the watched port,
        its remote address is not excluded, and its local address matches the
        requested host (exactly, through the wildcard address, or through the
        IPv4-mapped IPv6 wildcard). Files that cannot be read are skipped.

        Returns:
            Number of matching connections (int).
        """
        # The wanted state ids do not change while scanning, so compute them
        # once per call instead of once per line of the proc file.
        active_state_ids = frozenset(
            get_connection_state_id(state_name)
            for state_name in self.module.params['active_connection_states']
        )
        mapped_prefix = self.ipv4_mapped_ipv6_address['prefix']
        mapped_match_all = self.ipv4_mapped_ipv6_address['match_all']

        active_connections = 0
        for family, proc_path in self.source_file.items():
            if not os.path.isfile(proc_path):
                continue
            try:
                with open(proc_path) as f:
                    # iterate lazily; the table can be large on busy hosts
                    for line in f:
                        fields = line.split()
                        if len(fields) <= self.connection_state_field:
                            # blank or truncated line, nothing to evaluate
                            continue
                        if fields[self.local_address_field] == 'local_address':
                            # header row
                            continue
                        if fields[self.connection_state_field] not in active_state_ids:
                            continue
                        (local_ip, local_port) = fields[self.local_address_field].split(':')
                        if self.port != local_port:
                            continue
                        (remote_ip, remote_port) = fields[self.remote_address_field].split(':')
                        if (family, remote_ip) in self.exclude_ips:
                            continue
                        if any((
                            (family, local_ip) in self.ips,
                            (family, self.match_all_ips[family]) in self.ips,
                            local_ip.startswith(mapped_prefix) and
                                (family, mapped_match_all) in self.ips,
                        )):
                            active_connections += 1
            except OSError:
                # best effort: an unreadable proc file contributes no connections
                pass

        return active_connections


def _convert_host_to_ip(host):
    """
    Perform forward DNS resolution on host, IP will give the same IP

    Args:
        host: String with either hostname, IPv4, or IPv6 address

    Returns:
        List of tuples containing address family and IP
    """
    addrinfo = socket.getaddrinfo(host, 80, 0, 0, socket.SOL_TCP)
    ips = []
    for family, socktype, proto, canonname, sockaddr in addrinfo:
        ip = sockaddr[0]
        ips.append((family, ip))
        if family == socket.AF_INET:
            ips.append((socket.AF_INET6, "::ffff:" + ip))
    return ips


def _convert_host_to_hex(host):
    """
    Convert the provided host to the format in /proc/net/tcp*

    /proc/net/tcp uses little-endian four byte hex for ipv4
    /proc/net/tcp6 uses little-endian per 4B word for ipv6

    Args:
        host: String with either hostname, IPv4, or IPv6 address

    Returns:
        List of tuples containing address family and the
        little-endian converted host
    """
    ips = []
    if host is not None:
        for family, ip in _convert_host_to_ip(host):
            hexip_nf = binascii.b2a_hex(socket.inet_pton(family, ip))
            hexip_hf = ""
            for i in range(0, len(hexip_nf), 8):
                ipgroup_nf = hexip_nf[i:i + 8]
                ipgroup_hf = socket.ntohl(int(ipgroup_nf, base=16))
                hexip_hf = "%s%08X" % (hexip_hf, ipgroup_hf)
            ips.append((family, hexip_hf))
    return ips


def _timedelta_total_seconds(timedelta):
    return (
        timedelta.microseconds + 0.0 +
        (timedelta.seconds + timedelta.days * 24 * 3600) * 10 ** 6) / 10 ** 6


def get_connection_state_id(state):
    """Map a TCP connection state name to its two character id.

    The ids are the values compared against the state column when the
    /proc/net/tcp* files are parsed.

    Args:
        state: one of ESTABLISHED, SYN_SENT, SYN_RECV, FIN_WAIT1,
            FIN_WAIT2 or TIME_WAIT

    Returns:
        The id as a zero padded string, '01' through '06'

    Raises:
        KeyError: when state is not one of the known names
    """
    # position in this tuple (starting at 1) is the numeric id of the state
    ordered_states = (
        'ESTABLISHED',
        'SYN_SENT',
        'SYN_RECV',
        'FIN_WAIT1',
        'FIN_WAIT2',
        'TIME_WAIT',
    )
    ids_by_state = {
        name: '%02X' % number
        for number, name in enumerate(ordered_states, start=1)
    }
    return ids_by_state[state]


def _elapsed_seconds(start):
    """Return the whole number of seconds elapsed since ``start``.

    ``total_seconds()`` is used rather than ``timedelta.seconds`` because the
    latter is only the seconds component of the delta and therefore wraps
    around for waits longer than one day.

    Args:
        start: timezone-aware UTC datetime taken when waiting began

    Returns:
        Elapsed time truncated to an int
    """
    return int((datetime.now(timezone.utc) - start).total_seconds())


def _shutdown_and_close(sock):
    """Shut down and close ``sock``, tolerating a peer that already hung up.

    ENOTCONN raised by shutdown() is ignored; any other OSError is
    re-raised. The socket is closed in every case so no descriptor leaks.
    """
    try:
        sock.shutdown(socket.SHUT_RDWR)
    except OSError as ex:
        if ex.errno != errno.ENOTCONN:
            raise
    finally:
        sock.close()


def main():
    """Module entry point: validate parameters, wait, and report the result.

    Depending on the parameters this either
      - sleeps for ``timeout`` seconds (no port/path and state != drained),
      - waits for a path to disappear or a port to refuse connections
        (state absent/stopped),
      - waits for a path to exist or a port to accept connections, optionally
        until ``search_regex`` matches the file content or the data received
        from the socket (state started/present), or
      - waits for the active connections on a port to reach zero
        (state drained).

    The module exits through ``exit_json`` on success and ``fail_json`` on
    invalid parameters or when ``timeout`` expires; both report ``elapsed``.
    """

    module = AnsibleModule(
        argument_spec=dict(
            host=dict(type='str', default='127.0.0.1'),
            timeout=dict(type='int', default=300),
            connect_timeout=dict(type='int', default=5),
            delay=dict(type='int', default=0),
            port=dict(type='int'),
            active_connection_states=dict(type='list', elements='str', default=['ESTABLISHED', 'FIN_WAIT1', 'FIN_WAIT2', 'SYN_RECV', 'SYN_SENT', 'TIME_WAIT']),
            path=dict(type='path'),
            search_regex=dict(type='str'),
            state=dict(type='str', default='started', choices=['absent', 'drained', 'present', 'started', 'stopped']),
            exclude_hosts=dict(type='list', elements='str'),
            sleep=dict(type='int', default=1),
            msg=dict(type='str'),
        ),
    )

    host = module.params['host']
    timeout = module.params['timeout']
    connect_timeout = module.params['connect_timeout']
    delay = module.params['delay']
    port = module.params['port']
    state = module.params['state']
    sleep = module.params['sleep']

    path = module.params['path']
    b_path = to_bytes(path, errors='surrogate_or_strict', nonstring='passthru')

    search_regex = module.params['search_regex']
    b_search_regex = to_bytes(search_regex, errors='surrogate_or_strict', nonstring='passthru')

    msg = module.params['msg']

    # The pattern is compiled as bytes because both the file content and the
    # socket data it is matched against are bytes.
    if search_regex is not None:
        try:
            b_compiled_search_re = re.compile(b_search_regex, re.MULTILINE)
        except re.error as e:
            module.fail_json(msg="Invalid regular expression: %s" % e)
    else:
        b_compiled_search_re = None

    match_groupdict = {}
    match_groups = ()

    # ---- parameter validation -------------------------------------------
    if port and path:
        module.fail_json(msg="port and path parameter can not both be passed to wait_for", elapsed=0)
    if path and state == 'stopped':
        module.fail_json(msg="state=stopped should only be used for checking a port in the wait_for module", elapsed=0)
    if path and state == 'drained':
        module.fail_json(msg="state=drained should only be used for checking a port in the wait_for module", elapsed=0)
    if state == 'drained' and port is None:
        # Without this check the connection info classes would crash on int(None).
        module.fail_json(msg="port is required when state=drained in the wait_for module", elapsed=0)
    if module.params['exclude_hosts'] is not None and state != 'drained':
        module.fail_json(msg="exclude_hosts should only be with state=drained", elapsed=0)
    for _connection_state in module.params['active_connection_states']:
        try:
            get_connection_state_id(_connection_state)
        except Exception:
            module.fail_json(msg="unknown active_connection_state (%s) defined" % _connection_state, elapsed=0)

    # The clock starts before the initial delay, so the delay counts towards
    # both the timeout and the reported elapsed time.
    start = datetime.now(timezone.utc)

    if delay:
        time.sleep(delay)

    if not port and not path and state != 'drained':
        # Nothing to check: behave like a plain sleep.
        time.sleep(timeout)
    elif state in ['absent', 'stopped']:
        # first wait for the stop condition
        end = start + timedelta(seconds=timeout)

        while datetime.now(timezone.utc) < end:
            if path:
                try:
                    if not os.access(b_path, os.F_OK):
                        break
                except OSError:
                    break
            elif port:
                try:
                    s = socket.create_connection((host, port), connect_timeout)
                    try:
                        s.shutdown(socket.SHUT_RDWR)
                    finally:
                        # always release the probe socket
                        s.close()
                except Exception:
                    # Any failure to connect (or to shut down) means the port is stopped
                    break
            # Conditions not yet met, wait and try again
            time.sleep(sleep)
        else:   # while-else: the loop ended without break, i.e. timeout expired
            if port:
                module.fail_json(msg=msg or "Timeout when waiting for %s:%s to stop." % (host, port), elapsed=_elapsed_seconds(start))
            elif path:
                module.fail_json(msg=msg or "Timeout when waiting for %s to be absent." % (path), elapsed=_elapsed_seconds(start))

    elif state in ['started', 'present']:
        # wait for start condition
        end = start + timedelta(seconds=timeout)
        while datetime.now(timezone.utc) < end:
            if path:
                try:
                    os.stat(b_path)
                except OSError as e:
                    # If anything except file not present, throw an error
                    if e.errno != errno.ENOENT:
                        module.fail_json(msg=msg or "Failed to stat %s, %s" % (path, e.strerror), elapsed=_elapsed_seconds(start))
                    # file doesn't exist yet, so continue
                else:
                    # File exists.  Are there additional things to check?
                    if not b_compiled_search_re:
                        # nope, succeed!
                        break

                    try:
                        with open(b_path, 'rb') as f:
                            try:
                                with contextlib.closing(mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)) as mm:
                                    search = b_compiled_search_re.search(mm)
                                    if search:
                                        if search.groupdict():
                                            match_groupdict = search.groupdict()
                                        if search.groups():
                                            match_groups = search.groups()
                                        break
                            except (ValueError, OSError) as e:
                                module.debug('wait_for failed to use mmap on "%s": %s. Falling back to file read().' % (path, to_native(e)))
                                # cannot mmap this file, try normal read
                                search = b_compiled_search_re.search(f.read())
                                if search:
                                    if search.groupdict():
                                        match_groupdict = search.groupdict()
                                    if search.groups():
                                        match_groups = search.groups()
                                    break
                            except Exception as e:
                                module.warn('wait_for failed on "%s", unexpected exception(%s): %s.).' % (path, to_native(e.__class__), to_native(e)))
                    except OSError:
                        # file vanished or is unreadable right now; retry on the next pass
                        pass
            elif port:
                # Never let a single connect attempt outlive the overall timeout.
                alt_connect_timeout = math.ceil(
                    _timedelta_total_seconds(end - datetime.now(timezone.utc)),
                )
                try:
                    s = socket.create_connection((host, int(port)), min(connect_timeout, alt_connect_timeout))
                except Exception:
                    # Failed to connect by connect_timeout. wait and try again
                    pass
                else:
                    # Connected -- are there additional conditions?
                    if b_compiled_search_re:
                        b_data = b''
                        matched = False
                        while datetime.now(timezone.utc) < end:
                            max_timeout = math.ceil(
                                _timedelta_total_seconds(
                                    end - datetime.now(timezone.utc),
                                ),
                            )
                            readable = select.select([s], [], [], max_timeout)[0]
                            if not readable:
                                # No new data.  Probably means our timeout
                                # expired
                                continue
                            response = s.recv(1024)
                            if not response:
                                # Server shutdown
                                break
                            b_data += response
                            if b_compiled_search_re.search(b_data):
                                matched = True
                                break

                        # Shutdown the client socket
                        _shutdown_and_close(s)
                        if matched:
                            # Found our string, success!
                            break
                    else:
                        # Connection established, success!
                        _shutdown_and_close(s)
                        break

            # Conditions not yet met, wait and try again
            time.sleep(sleep)

        else:   # while-else
            # Timeout expired
            if port:
                if search_regex:
                    module.fail_json(msg=msg or "Timeout when waiting for search string %s in %s:%s" % (search_regex, host, port),
                                     elapsed=_elapsed_seconds(start))
                else:
                    module.fail_json(msg=msg or "Timeout when waiting for %s:%s" % (host, port), elapsed=_elapsed_seconds(start))
            elif path:
                if search_regex:
                    module.fail_json(msg=msg or "Timeout when waiting for search string %s in %s" % (search_regex, path),
                                     elapsed=_elapsed_seconds(start))
                else:
                    module.fail_json(msg=msg or "Timeout when waiting for file %s" % (path), elapsed=_elapsed_seconds(start))

    elif state == 'drained':
        # wait until all active connections are gone
        end = start + timedelta(seconds=timeout)
        tcpconns = TCPConnectionInfo(module)
        while datetime.now(timezone.utc) < end:
            if tcpconns.get_active_connections_count() == 0:
                break

            # Conditions not yet met, wait and try again
            time.sleep(sleep)
        else:   # while-else: timeout expired with connections still active
            module.fail_json(msg=msg or "Timeout when waiting for %s:%s to drain" % (host, port), elapsed=_elapsed_seconds(start))

    module.exit_json(state=state, port=port, search_regex=search_regex, match_groups=match_groups, match_groupdict=match_groupdict, path=path,
                     elapsed=_elapsed_seconds(start))


# Run the module logic only when this file is executed as a script.
if __name__ == '__main__':
    main()
