# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from datetime import datetime, timedelta
from io import StringIO
import logging
import math
import os
import random
import zlib

import pytest

from devtools_testutils import AzureRecordedTestCase

from .. import FakeTokenCredential

# Pull the account-SAS helpers from whichever storage package is installed;
# all three packages expose the same names. Narrowed the original bare
# ``except:`` (which would also swallow SystemExit/KeyboardInterrupt) to
# ImportError — the only failure these fallbacks are meant to handle.
try:
    from azure.storage.blob import (
        generate_account_sas,
        AccountSasPermissions,
        ResourceTypes,
    )
except ImportError:
    try:
        from azure.storage.queue import (
            generate_account_sas,
            AccountSasPermissions,
            ResourceTypes,
        )
    except ImportError:
        from azure.storage.fileshare import (
            generate_account_sas,
            AccountSasPermissions,
            ResourceTypes,
        )

# Record format shared by enable_logging() and LogCaptured's handler.
LOGGING_FORMAT = "%(asctime)s %(name)-20s %(levelname)-5s %(message)s"
# Global switch read by configure_logging() before every test method.
ENABLE_LOGGING = True


def generate_sas_token():
    """Build a syntactically valid (but fake) account SAS query string.

    Uses a dummy key and read/list permissions on object resources; the
    window spans from 24 hours ago to 8 days from now.
    """
    dummy_key = "a" * 30 + "b" * 30

    token = generate_account_sas(
        account_name="test",  # name of the storage account
        account_key=dummy_key,  # key for the storage account
        resource_types=ResourceTypes(object=True),
        permission=AccountSasPermissions(read=True, list=True),
        start=datetime.now() - timedelta(hours=24),
        expiry=datetime.now() + timedelta(days=8),
    )
    return "?" + token


class StorageRecordedTestCase(AzureRecordedTestCase):
    """Base class for recorded storage tests.

    Provides logging setup/teardown, fake connection strings and account URLs,
    deterministic test data, and assertion helpers for chunked transfer
    progress callbacks.
    """

    def setup_class(cls):
        # pytest invokes setup_class with the class object; no decorator needed.
        cls.logger = logging.getLogger("azure.storage")
        cls.sas_token = generate_sas_token()

    def setup_method(self, _):
        # Reset logging before every test so earlier tests cannot leak handlers.
        self.configure_logging()

    def connection_string(self, account_name, key):
        """Return a storage connection string for the given account.

        The mixed-case key names (AcCounTName, ...) are deliberate: they
        exercise case-insensitive connection-string parsing.
        """
        return (
            "DefaultEndpointsProtocol=https;AcCounTName="
            + account_name
            + ";AccOuntKey="
            + str(key)
            + ";EndpoIntSuffix=core.windows.net"
        )

    def account_url(self, storage_account, storage_type):
        """Return an url of storage account.

        :param str storage_account: Storage account name
        :param str storage_type: The Storage type part of the URL. Should be "blob", or "queue", etc.
        """
        protocol = os.environ.get("PROTOCOL", "https")
        suffix = os.environ.get("ACCOUNT_URL_SUFFIX", "core.windows.net")
        return f"{protocol}://{storage_account}.{storage_type}.{suffix}"

    def configure_logging(self):
        """Enable or disable logging according to the global ENABLE_LOGGING flag."""
        # The original used a conditional *expression* as a statement; a plain
        # if/else is the idiomatic form and behaves identically.
        if ENABLE_LOGGING:
            self.enable_logging()
        else:
            self.disable_logging()

    def enable_logging(self):
        """Route 'azure.storage' DEBUG records to stderr with the shared format."""
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
        self.logger.handlers = [handler]
        self.logger.setLevel(logging.DEBUG)
        self.logger.propagate = True
        self.logger.disabled = False

    def disable_logging(self):
        """Silence the 'azure.storage' logger and drop its handlers."""
        self.logger.propagate = False
        self.logger.disabled = True
        self.logger.handlers = []

    def get_random_bytes(self, size):
        """Return ``size`` deterministic bytes.

        Recordings don't like random stuff, so this is intentionally constant.
        """
        return b"a" * size

    def get_random_text_data(self, size):
        """Returns random unicode text data exceeding the size threshold for
        chunking blob upload.

        Seeded from the qualified test name so each test regenerates the same
        data on every run (required for recording playback).
        """
        checksum = zlib.adler32(self.qualified_test_name.encode()) & 0xFFFFFFFF
        rand = random.Random(checksum)
        text = ""
        words = ["hello", "world", "python", "啊齄丂狛狜"]
        while len(text) < size:
            # NOTE: int(rand.random() * (len(words) - 1)) almost never picks
            # the last word; kept as-is so existing recordings stay valid.
            index = int(rand.random() * (len(words) - 1))
            text = text + " " + words[index]

        return text

    @staticmethod
    def _set_test_proxy(service, settings):
        # Best-effort: only configure a proxy when the settings ask for one.
        if settings.USE_PROXY:
            service.set_proxy(
                settings.PROXY_HOST,
                settings.PROXY_PORT,
                settings.PROXY_USER,
                settings.PROXY_PASSWORD,
            )

    def assertNamedItemInContainer(self, container, item_name, msg=None):
        """Fail unless ``container`` holds an item matching ``item_name``.

        Items may be plain strings, dicts with a 'name' key, or objects
        exposing ``.name`` (or ``.snapshot``).
        """
        for item in container:
            # isinstance check replaces the original one-line _is_string helper.
            if isinstance(item, str):
                if item == item_name:
                    return
            elif isinstance(item, dict):
                if item_name == item["name"]:
                    return
            elif item.name == item_name:
                return
            elif hasattr(item, "snapshot") and item.snapshot == item_name:
                return

        error_message = f"{repr(item_name)} not found in {[str(c) for c in container]}"
        pytest.fail(error_message)

    def assertNamedItemNotInContainer(self, container, item_name, msg=None):
        """Fail if any item in ``container`` has ``.name`` equal to ``item_name``."""
        for item in container:
            if item.name == item_name:
                error_message = f"{repr(item_name)} unexpectedly found in {repr(container)}"
                pytest.fail(error_message)

    def assert_upload_progress(self, size, max_chunk_size, progress, unknown_size=False):
        """Validates that the progress chunks align with our chunking procedure."""
        total = None if unknown_size else size
        small_chunk_size = size % max_chunk_size
        assert len(progress) == math.ceil(size / max_chunk_size)
        for i in progress:
            assert i[0] % max_chunk_size == 0 or i[0] % max_chunk_size == small_chunk_size
            assert i[1] == total

    def assert_download_progress(self, size, max_chunk_size, max_get_size, progress):
        """Validates that the progress chunks align with our chunking procedure."""
        if size <= max_get_size:
            # Small downloads complete in a single GET.
            assert len(progress) == 1
            # BUG FIX: the original wrote ``assert progress[0][0], size`` —
            # the assert-with-message form, which never compares the values
            # (it passes for any truthy first element). ``==`` is the intent.
            assert progress[0][0] == size
            assert progress[0][1] == size
        else:
            small_chunk_size = (size - max_get_size) % max_chunk_size
            assert len(progress) == 1 + math.ceil((size - max_get_size) / max_chunk_size)

            # Same ``assert a, b`` → ``assert a == b`` fix as above.
            assert progress[0][0] == max_get_size
            assert progress[0][1] == size
            for i in progress[1:]:
                assert i[0] % max_chunk_size == 0 or i[0] % max_chunk_size == small_chunk_size
                assert i[1] == size

    def get_datetime_variable(self, variables, name, dt):
        """Record ``dt`` under ``name`` on first run; return the stored value on playback.

        :param dict variables: Recorded-variable mapping (mutated on first run).
        :param str name: Variable key.
        :param datetime dt: Value stored when ``name`` is not yet recorded.
        """
        dt_string = variables.setdefault(name, dt.isoformat())
        return datetime.strptime(dt_string, "%Y-%m-%dT%H:%M:%S.%f")


class LogCaptured(object):
    """Context manager that captures 'azure.storage' log output in a StringIO.

    Entering returns the stream; exiting detaches the capture handler, closes
    the stream, and restores the test case's logging configuration.
    """

    def __init__(self, test_case=None):
        # Keep the test case around so its logging config can be restored
        # once capturing is finished.
        self.test_case = test_case

    def __enter__(self):
        # The global logging flag may be off — force logging on for capture.
        self.test_case.enable_logging()

        stream = StringIO()
        handler = logging.StreamHandler(stream)
        handler.setFormatter(logging.Formatter(LOGGING_FORMAT))

        logger = logging.getLogger("azure.storage")
        logger.level = logging.DEBUG
        logger.addHandler(handler)

        # Stash everything __exit__ needs for teardown.
        self.log_stream = stream
        self.handler = handler
        self.logger = logger

        # Hand the stream back so the caller can inspect captured logs.
        return stream

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Detach the capture handler and release the stream.
        self.logger.removeHandler(self.handler)
        self.log_stream.close()

        # Restore whatever logging setup the test case normally uses.
        self.test_case.configure_logging()
