# Copyright 2012-2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Discover environment and server configuration, initialize PyMongo client."""

import os
import socket
import sys
from functools import wraps
from test.utils import create_user
from test.version import Version
from unittest import SkipTest

import pymongo.errors

# mypy: ignore-errors


# Optional dependencies. Each module name is bound to the imported module, or
# to None when the import fails, and the matching HAVE_* flag records which.
try:
    import ssl
except ImportError:
    ssl = None
HAVE_SSL = ssl is not None

try:
    import tornado
except ImportError:
    tornado = None
HAVE_TORNADO = tornado is not None

try:
    import asyncio
except ImportError:
    asyncio = None
HAVE_ASYNCIO = asyncio is not None

try:
    import aiohttp
except ImportError:
    aiohttp = None
HAVE_AIOHTTP = aiohttp is not None

# pymongocrypt is only probed for availability; on failure the name is left
# unbound rather than set to None.
try:
    import pymongocrypt  # noqa: F401
except ImportError:
    HAVE_PYMONGOCRYPT = False
else:
    HAVE_PYMONGOCRYPT = True


# Adapted from PyMongo.
def partition_node(node):
    """Split a ``host[:port]`` string into a ``(host, int(port))`` pair.

    The port defaults to 27017 when ``node`` carries none. Square brackets
    around the host part (as used for IPv6 literals, e.g. ``[::1]:27017``)
    are stripped from the returned host.

    Raises ``ValueError`` if the text after the last colon is not an integer.
    """
    host, port = node, 27017
    # A trailing "]" means a bracketed address with no port: every colon in
    # it belongs to the address, so there is nothing to split off.
    if not node.endswith("]"):
        head, sep, tail = node.rpartition(":")
        if sep:
            host, port = head, int(tail)
    if host.startswith("["):
        host = host[1:-1]
    return host, port


def connected(client):
    """Run a ``ping`` command on ``client.admin`` and return ``client``.

    Issuing the command forces the client to connect, so callers can wrap
    the creation of a new PyMongo MongoClient in this function.
    """
    admin_db = client.admin
    admin_db.command("ping")
    return client


# Credentials from the environment; an unset or empty variable becomes None.
db_user = os.getenv("DB_USER") or None
db_password = os.getenv("DB_PASSWORD") or None

# Certificate directory: $CERT_DIR if set, otherwise the "certificates"
# directory located next to this file.
CERT_PATH = os.environ.get("CERT_DIR")
if CERT_PATH is None:
    CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
CLIENT_PEM, CA_PEM = (os.path.join(CERT_PATH, pem) for pem in ("client.pem", "ca.pem"))
MONGODB_X509_USERNAME = "CN=client,OU=kerneluser,O=10Gen,L=New York City,ST=New York,C=US"


def is_server_resolvable():
    """Report whether the hostname 'server' can be resolved.

    The default socket timeout is set to one second for the duration of the
    lookup and restored to its previous value afterwards, whatever the
    outcome.
    """
    saved_timeout = socket.getdefaulttimeout()
    socket.setdefaulttimeout(1)
    try:
        socket.gethostbyname("server")
    except OSError:
        resolvable = False
    else:
        resolvable = True
    finally:
        socket.setdefaulttimeout(saved_timeout)
    return resolvable


class TestEnvironment:
    """Facts about the MongoDB deployment the tests run against.

    A new instance holds only defaults. :meth:`setup` connects with PyMongo
    and fills in the attributes (topology, TLS configuration, URIs, server
    version). The ``require*`` methods build decorators that raise
    ``SkipTest`` when the discovered deployment does not satisfy a test's
    precondition.
    """

    # Not a test case, despite the "Test" prefix in the class name.
    __test__ = False

    def __init__(self):
        """Initialize every attribute to its "nothing discovered yet" value."""
        self.initialized = False
        self.host = None
        self.port = None
        self.mongod_started_with_ssl = False
        self.mongod_validates_client_cert = False
        self.server_is_resolvable = is_server_resolvable()
        self.sync_cx = None
        self.is_standalone = False
        self.is_mongos = False
        self.is_replica_set = False
        self.rs_name = None
        # Used as the "w" write concern by create_user() and drop_user();
        # set to the number of replica set hosts when there is a replica set.
        self.w = 1
        self.hosts = None
        self.arbiters = None
        self.primary = None
        self.secondaries = None
        self.v8 = False
        self.auth = False
        self.uri = None
        self.rs_uri = None
        self.version = None
        self.sessions_enabled = False
        self.fake_hostname_uri = None
        self.server_status = None

    def setup(self):
        """Connect to the server and populate all discovered attributes.

        Must be called at most once per instance.
        """
        assert not self.initialized
        self.setup_sync_cx()
        self.setup_auth_and_uri()
        self.setup_version()
        self.setup_v8()
        self.server_status = self.sync_cx.admin.command("serverStatus")
        self.initialized = True

    def setup_sync_cx(self):
        """Get a synchronous PyMongo MongoClient and determine SSL config.

        Connects to ``$DB_IP:$DB_PORT`` (default ``localhost:27017``) with
        short timeouts, records the topology reported by ``ismaster``, then
        reconnects with default timeouts to the primary if the server is a
        replica set member, or to the same address otherwise. Sets
        ``self.sync_cx``, ``self.host`` and ``self.port``.
        """
        host = os.environ.get("DB_IP", "localhost")
        port = int(os.environ.get("DB_PORT", 27017))

        probe = self._connect_probe(host, port)
        try:
            response = probe.admin.command("ismaster")
        finally:
            # The probe is only needed for this one command; close it so it
            # does not linger once the long-lived client replaces it.
            probe.close()

        primary = self._record_topology(response)
        if primary is not None:
            host, port = primary

        # Reconnect to found primary, without short timeouts.
        if self.mongod_started_with_ssl:
            client = connected(
                pymongo.MongoClient(
                    host,
                    port,
                    username=db_user,
                    password=db_password,
                    directConnection=True,
                    tlsCAFile=CA_PEM,
                    tlsCertificateKeyFile=CLIENT_PEM,
                )
            )
        else:
            client = connected(
                pymongo.MongoClient(
                    host,
                    port,
                    username=db_user,
                    password=db_password,
                    directConnection=True,
                    ssl=False,
                )
            )

        self.sync_cx = client
        self.host = host
        self.port = port

    def _connect_probe(self, host, port):
        """Return a connected short-timeout client, detecting the TLS setup.

        Tries TLS with the CA file only, then TLS with the CA file plus the
        client certificate, then a connection without TLS options. Sets
        ``mongod_started_with_ssl`` and ``mongod_validates_client_cert``
        according to the first attempt that succeeds. An error from the last
        attempt propagates to the caller.
        """
        options = dict(
            username=db_user,
            password=db_password,
            directConnection=True,
            connectTimeoutMS=100,
            socketTimeoutMS=10000,
            serverSelectionTimeoutMS=100,
        )

        try:
            client = connected(
                pymongo.MongoClient(host, port, tlsCAFile=CA_PEM, ssl=True, **options)
            )
        except pymongo.errors.ServerSelectionTimeoutError:
            pass
        else:
            self.mongod_started_with_ssl = True
            return client

        try:
            client = connected(
                pymongo.MongoClient(
                    host,
                    port,
                    tlsCAFile=CA_PEM,
                    tlsCertificateKeyFile=CLIENT_PEM,
                    **options,
                )
            )
        except pymongo.errors.ServerSelectionTimeoutError:
            pass
        else:
            self.mongod_started_with_ssl = True
            self.mongod_validates_client_cert = True
            return client

        return connected(pymongo.MongoClient(host, port, **options))

    def _record_topology(self, response):
        """Store topology facts from an ``ismaster`` response.

        Returns the primary's ``(host, port)`` pair when the response
        describes a replica set, otherwise ``None``.
        """
        self.sessions_enabled = "logicalSessionTimeoutMinutes" in response
        self.is_mongos = response.get("msg") == "isdbgrid"

        if "setName" not in response:
            if not self.is_mongos:
                self.is_standalone = True
            return None

        self.is_replica_set = True
        self.rs_name = str(response["setName"])
        members = [partition_node(h) for h in response["hosts"]]
        self.w = len(members)
        self.hosts = set(members)
        self.primary = partition_node(response["primary"])
        self.arbiters = {partition_node(h) for h in response.get("arbiters", [])}
        self.secondaries = self._select_secondaries(members, self.primary, self.arbiters)
        return self.primary

    @staticmethod
    def _select_secondaries(hosts, primary, arbiters):
        """Return the members of ``hosts`` that are neither primary nor arbiter.

        All arguments hold ``(host, port)`` pairs. The comparison has to be
        made on parsed pairs: a raw ``"host:port"`` string never equals a
        pair, so it would let the primary through. Order is preserved.
        """
        return [member for member in hosts if member != primary and member not in arbiters]

    @staticmethod
    def _build_uri(host, port, user=None, password=None):
        """Return a ``mongodb://`` URI for the ``admin`` database.

        When ``user`` is given, ``user`` and ``password`` are percent-escaped
        with ``quote_plus`` before being embedded, so that characters such as
        ``@``, ``:`` or ``/`` in them cannot corrupt the URI.
        """
        from urllib.parse import quote_plus

        credentials = ""
        if user:
            credentials = "%s:%s@" % (quote_plus(user), quote_plus(password))
        return "mongodb://%s%s:%s/admin" % (credentials, host, port)

    def setup_auth_and_uri(self):
        """Set self.auth and self.uri.

        Also sets ``self.fake_hostname_uri`` and, when a replica set name is
        known, ``self.rs_uri``. Writes a message to stderr and exits with
        status 1 if only one of DB_USER / DB_PASSWORD is set.
        """
        if db_user or db_password:
            if not (db_user and db_password):
                sys.stderr.write("You must set both DB_USER and DB_PASSWORD, or neither\n")
                sys.exit(1)

            self.auth = True

        self.uri = self._build_uri(self.host, self.port, db_user, db_password)

        # If the hostname 'server' is resolvable, this URI lets us use it
        # to test SSL hostname validation (with auth when credentials are set).
        self.fake_hostname_uri = self._build_uri("server", self.port, db_user, db_password)

        if self.rs_name:
            self.rs_uri = self.uri + "?replicaSet=" + self.rs_name

    def setup_version(self):
        """Set self.version to the server's version."""
        self.version = Version.from_client(self.sync_cx)

    def setup_v8(self):
        """Determine if server is running SpiderMonkey or V8."""
        if self.sync_cx.server_info().get("javascriptEngine") == "V8":
            self.v8 = True

    @property
    def storage_engine(self):
        """Storage engine name from ``serverStatus``, or None if unavailable."""
        try:
            return self.server_status.get("storageEngine", {}).get("name")
        except AttributeError:
            # Raised if self.server_status is None.
            return None

    def supports_transactions(self):
        """Return True if the deployment type and version allow transactions.

        Never with the mmapv1 storage engine; from 4.1.8 on mongos or a
        replica set; from 4.0 on a replica set only; otherwise not at all.
        """
        if self.storage_engine == "mmapv1":
            return False

        if self.version.at_least(4, 1, 8):
            return self.is_mongos or self.is_replica_set

        if self.version.at_least(4, 0):
            return self.is_replica_set

        return False

    def require(self, condition, msg, func=None):
        """Build a wrapper that runs a test only if ``condition()`` is true.

        ``condition`` is evaluated each time the wrapped function is called;
        when it is false, ``SkipTest(msg)`` is raised instead of calling it.
        With ``func`` given, the wrapped ``func`` is returned. Without it, a
        decorator is returned, to be applied to the function later.
        """

        def make_wrapper(f):
            @wraps(f)
            def wrap(*args, **kwargs):
                assert self.initialized
                if condition():
                    return f(*args, **kwargs)
                raise SkipTest(msg)

            return wrap

        if func is None:
            return make_wrapper
        return make_wrapper(func)

    def require_auth(self, func):
        """Run a test only if the server is started with auth."""
        return self.require(lambda: self.auth, "Server must be start with auth", func=func)

    def require_version_min(self, *ver):
        """Run a test only if the server version is at least ``version``."""
        other_version = Version(*ver)
        return self.require(
            lambda: self.version >= other_version,
            "Server version must be at least %s" % str(other_version),
        )

    def require_version_max(self, *ver):
        """Run a test only if the server version is at most ``version``."""
        other_version = Version(*ver)
        return self.require(
            lambda: self.version <= other_version,
            "Server version must be at most %s" % str(other_version),
        )

    def require_no_standalone(self, func):
        """Run a test only if the client is not connected to a standalone."""
        return self.require(lambda: not self.is_standalone, "Connected to a standalone", func=func)

    def require_replica_set(self, func):
        """Run a test only if the client is connected to a replica set."""
        return self.require(
            lambda: self.is_replica_set, "Not connected to a replica set", func=func
        )

    def require_transactions(self, func):
        """Run a test only if the deployment might support transactions.

        *Might* because this does not test the FCV.
        """
        return self.require(self.supports_transactions, "Transactions are not supported", func=func)

    def require_csfle(self, func):
        """Run a test only if the deployment supports CSFLE."""
        return self.require(
            lambda: HAVE_PYMONGOCRYPT and self.version >= Version(4, 2),
            "CSFLE requires pymongocrypt and MongoDB >=4.2",
            func=func,
        )

    def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
        """Create ``user`` on database ``dbname`` with write concern ``w=self.w``.

        Any ``writeConcern`` passed in ``kwargs`` is overridden.
        """
        kwargs["writeConcern"] = {"w": self.w}
        return create_user(self.sync_cx[dbname], user, pwd, roles, **kwargs)

    def drop_user(self, dbname, user):
        """Drop ``user`` from database ``dbname`` with write concern ``w=self.w``."""
        self.sync_cx[dbname].command("dropUser", user, writeConcern={"w": self.w})


# Module-wide environment object. Discovery runs at import time unless the
# SKIP_ENV_SETUP environment variable is present (with any value).
env = TestEnvironment()
if os.environ.get("SKIP_ENV_SETUP") is None:
    env.setup()
