﻿# coding: utf-8

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import unittest

import azure.mgmt.sql

from devtools_testutils import AzureMgmtTestCase, ResourceGroupPreparer, AzureMgmtPreparer, FakeResource
from devtools_testutils.fake_credentials import FAKE_LOGIN_PASSWORD


def get_server_params(location):
    """Build the parameter payload passed to ``servers.create_or_update``.

    :param location: Accepted for call-site compatibility but currently
        ignored -- the server is always placed in ``westus2`` (the original
        code notes that ``self.region`` defaults to 'west-us').
    :returns: A fresh dict with the location, server version and
        administrator credentials.
    """
    # NOTE(review): FAKE_LOGIN_PASSWORD may not satisfy the service's password
    # rules -- check this when the tests are run live.
    payload = dict(
        location="westus2",
        version="12.0",
        administrator_login="mysecretname",
        administrator_login_password=FAKE_LOGIN_PASSWORD,
    )
    return payload


class SqlServerPreparer(AzureMgmtPreparer):
    """Test preparer that provides a SQL server to the decorated test.

    The server is handed to the test under the ``server`` key of the mapping
    returned by ``create_resource``. In playback (non-live) mode no service
    call is made; a ``FakeResource`` carrying only the generated name is used.
    """

    def __init__(self, name_prefix="mypysqlserverx"):
        # 24 is forwarded to AzureMgmtPreparer together with the prefix;
        # presumably the generated random-name length -- TODO confirm.
        super(SqlServerPreparer, self).__init__(name_prefix, 24)

    def create_resource(self, name, **kwargs):
        """Create the server (live) or a stand-in (playback).

        :param name: Generated server name.
        :param kwargs: Must provide ``resource_group`` (object with ``.name``)
            and ``location``.
        :returns: ``{"server": <created server or FakeResource>}``.
        """
        if not self.is_live:
            return {"server": FakeResource(name=name, id="")}

        # Reuse the test instance's client and block on the create operation.
        create_op = self.test_class_instance.client.servers.create_or_update(
            kwargs["resource_group"].name,
            name,
            get_server_params(kwargs["location"]),
        )
        return {"server": create_op.result()}


class MgmtSqlTest(AzureMgmtTestCase):
    """Recorded tests for the SQL server and database management APIs."""

    def setUp(self):
        super(MgmtSqlTest, self).setUp()
        # Stored on the instance so SqlServerPreparer can reach it through
        # ``test_class_instance.client``.
        self.client = self.create_mgmt_client(azure.mgmt.sql.SqlManagementClient)

    @ResourceGroupPreparer()
    def test_server(self, resource_group, location):
        """Create a server, exercise get/list/usage/firewall-rule APIs, then delete it."""
        server_name = self.get_resource_name("tstpysqlserverx")

        # create_or_update returns an operation handle; result() waits for it
        # and yields the created server.
        server = self.client.servers.create_or_update(
            resource_group.name,  # Created by the framework
            server_name,
            get_server_params(location),
        ).result()
        self.assertEqual(server.name, server_name)

        server = self.client.servers.get(resource_group.name, server_name)
        self.assertEqual(server.name, server_name)

        # The framework-created group should contain exactly this one server.
        rg_servers = self.client.servers.list_by_resource_group(resource_group.name)
        self.assertEqual([s.name for s in rg_servers], [server_name])

        # The unscoped listing may contain other servers; only require ours.
        all_servers = self.client.servers.list()
        self.assertIn(server_name, [s.name for s in all_servers])

        usages = self.client.server_usages.list_by_server(resource_group.name, server_name)
        self.assertIn("server_dtu_quota", [usage.name for usage in usages])

        firewall_rule_name = self.get_resource_name("firewallrule")
        firewall_rule = self.client.firewall_rules.create_or_update(
            resource_group.name, server_name, firewall_rule_name, "123.123.123.123", "123.123.123.124"
        )
        self.assertEqual(firewall_rule.name, firewall_rule_name)
        self.assertEqual(firewall_rule.start_ip_address, "123.123.123.123")
        self.assertEqual(firewall_rule.end_ip_address, "123.123.123.124")

        # polling=False and the returned handle is discarded, so completion of
        # the delete is not awaited. NOTE(review): presumably resource-group
        # teardown reclaims anything left behind -- confirm.
        self.client.servers.delete(resource_group.name, server_name, polling=False)

    @ResourceGroupPreparer()
    @SqlServerPreparer()
    def test_database(self, resource_group, location, server):
        """Create a database on a prepared server, verify get/list/usages, then delete it."""
        db_name = self.get_resource_name("pyarmdb")

        # Location is hard-coded to westus2, the same region get_server_params() uses.
        database = self.client.databases.create_or_update(
            resource_group.name, server.name, db_name, {"location": "westus2"}
        ).result()
        self.assertEqual(database.name, db_name)

        db = self.client.databases.get(resource_group.name, server.name, db_name)
        self.assertEqual(db.name, db_name)

        # Expect exactly "master" plus the database created above. The names are
        # reported in the failure message rather than printed on every run.
        db_names = [d.name for d in self.client.databases.list_by_server(resource_group.name, server.name)]
        self.assertEqual(len(db_names), 2, db_names)
        self.assertIn("master", db_names)
        self.assertIn(db_name, db_names)

        usages = self.client.database_usages.list_by_database(resource_group.name, server.name, db_name)
        self.assertIn("database_size", [usage.name for usage in usages])

        # wait() blocks until the delete operation finishes.
        self.client.databases.delete(resource_group.name, server.name, db_name).wait()


# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Allow running this module directly; module="__main__" is unittest's
    # default, spelled out to make the discovered test module explicit.
    unittest.main(module="__main__")
