# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import asyncio
from devtools_testutils.perfstress_tests import PerfStressTest

from azure.eventgrid import (
    EventGridPublisherClient as SyncPublisherClient,
    EventGridEvent,
)
from azure.eventgrid.aio import EventGridPublisherClient as AsyncPublisherClient

from azure.core.credentials import AzureKeyCredential


class EventGridPerfTest(PerfStressTest):
    """Perf test that publishes a pre-built batch of EventGridEvents per iteration.

    Reads the ``EG_ACCESS_KEY`` and ``EG_TOPIC_HOSTNAME`` environment variables
    (via ``get_from_env``) and sends ``--num-events`` events per ``run_sync`` /
    ``run_async`` call.
    """

    def __init__(self, arguments):
        super().__init__(arguments)

        # auth configuration
        topic_key = self.get_from_env("EG_ACCESS_KEY")
        endpoint = self.get_from_env("EG_TOPIC_HOSTNAME")
        credential = AzureKeyCredential(topic_key)

        # Create clients: the sync one backs run_sync, the async one backs run_async.
        self.publisher_client = SyncPublisherClient(endpoint=endpoint, credential=credential)
        self.async_publisher_client = AsyncPublisherClient(endpoint=endpoint, credential=credential)

        # Build the payload up front so the timed run_* methods measure only the send call.
        self.event_list = [
            EventGridEvent(
                event_type="Contoso.Items.ItemReceived",
                data={"services": ["EventGrid", "ServiceBus", "EventHubs", "Storage"]},
                subject="Door1",
                data_version="2.0",
            )
            for _ in range(self.args.num_events)
        ]

    async def close(self):
        """This is run after cleanup.

        Closes both publisher clients created in ``__init__``, then defers to
        the base class ``close``.
        """
        try:
            # Both clients are created in __init__, so both must be released here;
            # previously only the async client was closed.
            self.publisher_client.close()
            await self.async_publisher_client.close()
        finally:
            # Always let the base class clean up, even if a client close fails.
            await super().close()

    def run_sync(self):
        """The synchronous perf test.

        Try to keep this minimal and focused. Using only a single client API.
        Avoid putting any ancillary logic (e.g. generating UUIDs), and put this in the setup/init instead
        so that we're only measuring the client API call.
        """
        self.publisher_client.send(self.event_list)

    async def run_async(self):
        """The asynchronous perf test.

        Try to keep this minimal and focused. Using only a single client API.
        Avoid putting any ancillary logic (e.g. generating UUIDs), and put this in the setup/init instead
        so that we're only measuring the client API call.
        """
        await self.async_publisher_client.send(self.event_list)

    @staticmethod
    def add_arguments(parser):
        """Register this test's command line options on ``parser``."""
        super(EventGridPerfTest, EventGridPerfTest).add_arguments(parser)
        # No nargs="?": with it and no const, a bare "-n" parsed to None and
        # crashed later in range(); now argparse itself requires a value.
        parser.add_argument(
            "-n",
            "--num-events",
            type=int,
            help="Number of events to be sent. Defaults to 100",
            default=100,
        )
