# Copyright 2012-2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for testing Motor with Tornado."""

import concurrent.futures
import functools
from test.test_environment import CA_PEM, CLIENT_PEM, env
from test.version import Version
from unittest import SkipTest

from bson import SON
from mockupdb import MockupDB
from tornado import testing

import motor


async def get_command_line(client):
    """Run ``getCmdLineOpts`` on the admin database and return the reply.

    Asserts that the reply's ``ok`` field equals 1.
    """
    admin_db = client.admin
    reply = await admin_db.command("getCmdLineOpts")
    succeeded = reply["ok"] == 1
    assert succeeded, "getCmdLineOpts() failed"
    return reply


async def server_is_mongos(client):
    """Return True if the ``ismaster`` reply's ``msg`` field is "isdbgrid"."""
    reply = await client.admin.command("ismaster")
    msg = reply.get("msg")
    return msg == "isdbgrid"


async def skip_if_mongos(client):
    """Raise SkipTest when server_is_mongos() reports a mongos for `client`."""
    if await server_is_mongos(client):
        raise SkipTest("connected to mongos")


async def remove_all_users(db):
    """Send the ``dropAllUsersFromDatabase`` command to `db`."""
    drop_users = {"dropAllUsersFromDatabase": 1}
    await db.command(drop_users)


class MotorTest(testing.AsyncTestCase):
    """Base class for Motor tests running on Tornado's AsyncTestCase.

    setUp() provides ``self.cx`` (a client), ``self.db`` (its ``motor_test``
    database) and ``self.collection`` (``motor_test.test_collection``).
    """

    longMessage = True  # Used by unittest.TestCase
    ssl = False  # If True, connect with SSL, skip if mongod isn't SSL

    def setUp(self):
        # Decide whether to skip *before* super().setUp(): unittest does not
        # run tearDown() when setUp() raises, so skipping afterwards would
        # leave whatever the parent setUp() created (the per-test IOLoop)
        # without its matching tearDown().
        if self.ssl and not env.mongod_started_with_ssl:
            raise SkipTest("mongod doesn't support SSL, or is down")

        super().setUp()

        self.cx = self.motor_client()
        self.db = self.cx.motor_test
        self.collection = self.db.test_collection

    async def make_test_data(self):
        """Empty the test collection, then insert documents with _id 0..199."""
        await self.collection.delete_many({})
        await self.collection.insert_many([{"_id": i} for i in range(200)])

    # The name contains "test" but this is a helper, not a test case; flag it
    # for collectors that honor the __test__ attribute.
    make_test_data.__test__ = False

    # Workaround for https://github.com/pytest-dev/pytest/issues/12263.
    def runTest(self):
        pass

    async def set_fail_point(self, client, command_args):
        """Run ``configureFailPoint: failCommand`` on `client`'s admin database.

        `command_args` is merged into the command document after the
        ``configureFailPoint`` key.
        """
        cmd = SON([("configureFailPoint", "failCommand")])
        cmd.update(command_args)
        await client.admin.command(cmd)

    def get_client_kwargs(self, **kwargs):
        """Return `kwargs` with TLS and ``io_loop`` defaults filled in.

        Values passed explicitly by the caller are never overridden.
        """
        if env.mongod_started_with_ssl:
            kwargs.setdefault("tlsCAFile", CA_PEM)
            kwargs.setdefault("tlsCertificateKeyFile", CLIENT_PEM)

        kwargs.setdefault("tls", env.mongod_started_with_ssl)
        kwargs.setdefault("io_loop", self.io_loop)

        return kwargs

    def motor_client(self, uri=None, *args, **kwargs):
        """Get a MotorClient for `uri`, defaulting to ``env.uri``.

        Ignores self.ssl: TLS defaults to whether mongod was started with SSL
        (see get_client_kwargs); pass a 'tls' keyword argument to override.
        You'll probably need to close the client to avoid file-descriptor
        problems after AsyncTestCase calls self.io_loop.close(all_fds=True).
        """
        return motor.MotorClient(uri or env.uri, *args, **self.get_client_kwargs(**kwargs))

    def motor_rsc(self, uri=None, *args, **kwargs):
        """Get a MotorClient for replica set `uri`, defaulting to ``env.rs_uri``.

        Ignores self.ssl: TLS defaults to whether mongod was started with SSL
        (see get_client_kwargs); pass a 'tls' keyword argument to override.
        """
        return motor.MotorClient(uri or env.rs_uri, *args, **self.get_client_kwargs(**kwargs))

    def tearDown(self):
        """Empty the test collection, close ``self.cx``, then tear down the base.

        Each step runs even if the previous one raised, so a failing cleanup
        neither leaks the client nor skips the parent tearDown().
        """
        try:
            env.sync_cx.motor_test.test_collection.delete_many({})
        finally:
            try:
                if hasattr(self, "cx"):
                    self.cx.close()
            finally:
                super().tearDown()


class MotorReplicaSetTestBase(MotorTest):
    """Base class for tests that require a replica set.

    Adds ``self.rsc``, a client obtained from motor_rsc(), on top of the
    attributes MotorTest.setUp() provides.
    """

    def setUp(self):
        # Skip before super().setUp(): unittest does not call tearDown() when
        # setUp() raises, so raising after the parent has opened self.cx would
        # leave that client unclosed.
        if not env.is_replica_set:
            raise SkipTest("Not connected to a replica set")

        super().setUp()

        self.rsc = self.motor_rsc()
        # MotorTest.tearDown() only closes self.cx, so register this client's
        # close() as a cleanup to avoid leaking it.
        self.addCleanup(self.rsc.close)


class MotorMockServerTest(MotorTest):
    """MotorTest variant with helpers for talking to a MockupDB server."""

    # Single-worker thread pool shared at class level. Nothing in this class
    # references it; run_thread() passes None to run_in_executor() instead.
    executor = concurrent.futures.ThreadPoolExecutor(1)

    def server(self, *args, **kwargs):
        """Build a MockupDB from the arguments, run() it, and return it.

        The server's stop() is registered as a test cleanup.
        """
        mock_server = MockupDB(*args, **kwargs)
        mock_server.run()
        self.addCleanup(mock_server.stop)
        return mock_server

    def client_server(self, *args, **kwargs):
        """Return ``(client, server)``: a mock server and a client for its URI.

        The server comes from self.server(); the client is bound to
        self.io_loop and its close() is registered as a test cleanup.
        """
        mock_server = self.server(*args, **kwargs)
        mock_client = motor.motor_tornado.MotorClient(mock_server.uri, io_loop=self.io_loop)
        self.addCleanup(mock_client.close)
        return mock_client, mock_server

    async def run_thread(self, fn, *args, **kwargs):
        """Await ``fn(*args, **kwargs)`` via self.io_loop.run_in_executor()."""
        call = functools.partial(fn, *args, **kwargs)
        return await self.io_loop.run_in_executor(None, call)


class AsyncVersion(Version):
    """Version subclass whose from_client() awaits an async client, so it can
    be used from within a coroutine."""

    @classmethod
    async def from_client(cls, client):
        """Build a version from the result of ``await client.server_info()``.

        Uses the ``versionArray`` field when the reply has one, otherwise
        parses the ``version`` string.
        """
        server_info = await client.server_info()
        if "versionArray" not in server_info:
            return cls.from_string(server_info["version"])
        return cls.from_version_array(server_info["versionArray"])
