#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2021
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program.  If not, see [http://www.gnu.org/licenses/].
import logging
from queue import Queue
from threading import current_thread
from time import sleep

import pytest

from telegram import TelegramError, Message, User, Chat, Update, Bot, MessageEntity
from telegram.ext import (
    MessageHandler,
    Filters,
    Defaults,
    CommandHandler,
    CallbackContext,
    JobQueue,
    BasePersistence,
)
from telegram.ext.dispatcher import run_async, Dispatcher, DispatcherHandlerStop
from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.helpers import DEFAULT_FALSE
from tests.conftest import create_dp
from collections import defaultdict


@pytest.fixture(scope='function')
def dp2(bot):
    yield from create_dp(bot)


class TestDispatcher:
    message_update = Update(
        1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text')
    )
    received = None
    count = 0

    @pytest.fixture(autouse=True, name='reset')
    def reset_fixture(self):
        self.reset()

    def reset(self):
        self.received = None
        self.count = 0

    def error_handler(self, bot, update, error):
        self.received = error.message

    def error_handler_context(self, update, context):
        self.received = context.error.message

    def error_handler_raise_error(self, bot, update, error):
        raise Exception('Failing bigly')

    def callback_increase_count(self, bot, update):
        self.count += 1

    def callback_set_count(self, count):
        def callback(bot, update):
            self.count = count

        return callback

    def callback_raise_error(self, bot, update):
        if isinstance(bot, Bot):
            raise TelegramError(update.message.text)
        raise TelegramError(bot.message.text)

    def callback_if_not_update_queue(self, bot, update, update_queue=None):
        if update_queue is not None:
            self.received = update.message

    def callback_context(self, update, context):
        if (
            isinstance(context, CallbackContext)
            and isinstance(context.bot, Bot)
            and isinstance(context.update_queue, Queue)
            and isinstance(context.job_queue, JobQueue)
            and isinstance(context.error, TelegramError)
        ):
            self.received = context.error.message

    def test_one_context_per_update(self, cdp):
        def one(update, context):
            if update.message.text == 'test':
                context.my_flag = True

        def two(update, context):
            if update.message.text == 'test':
                if not hasattr(context, 'my_flag'):
                    pytest.fail()
            else:
                if hasattr(context, 'my_flag'):
                    pytest.fail()

        cdp.add_handler(MessageHandler(Filters.regex('test'), one), group=1)
        cdp.add_handler(MessageHandler(None, two), group=2)
        u = Update(1, Message(1, None, None, None, text='test'))
        cdp.process_update(u)
        u.message.text = 'something'
        cdp.process_update(u)

    def test_error_handler(self, dp):
        dp.add_error_handler(self.error_handler)
        error = TelegramError('Unauthorized.')
        dp.update_queue.put(error)
        sleep(0.1)
        assert self.received == 'Unauthorized.'

        # Remove handler
        dp.remove_error_handler(self.error_handler)
        self.reset()

        dp.update_queue.put(error)
        sleep(0.1)
        assert self.received is None

    def test_double_add_error_handler(self, dp, caplog):
        dp.add_error_handler(self.error_handler)
        with caplog.at_level(logging.DEBUG):
            dp.add_error_handler(self.error_handler)
            assert len(caplog.records) == 1
            assert caplog.records[-1].getMessage().startswith('The callback is already registered')

    def test_construction_with_bad_persistence(self, caplog, bot):
        class my_per:
            def __init__(self):
                self.store_user_data = False
                self.store_chat_data = False
                self.store_bot_data = False

        with pytest.raises(
            TypeError, match='persistence must be based on telegram.ext.BasePersistence'
        ):
            Dispatcher(bot, None, persistence=my_per())

    def test_error_handler_that_raises_errors(self, dp):
        """
        Make sure that errors raised in error handlers don't break the main loop of the dispatcher
        """
        handler_raise_error = MessageHandler(Filters.all, self.callback_raise_error)
        handler_increase_count = MessageHandler(Filters.all, self.callback_increase_count)
        error = TelegramError('Unauthorized.')

        dp.add_error_handler(self.error_handler_raise_error)

        # From errors caused by handlers
        dp.add_handler(handler_raise_error)
        dp.update_queue.put(self.message_update)
        sleep(0.1)

        # From errors in the update_queue
        dp.remove_handler(handler_raise_error)
        dp.add_handler(handler_increase_count)
        dp.update_queue.put(error)
        dp.update_queue.put(self.message_update)
        sleep(0.1)

        assert self.count == 1

    @pytest.mark.parametrize(['run_async', 'expected_output'], [(True, 5), (False, 0)])
    def test_default_run_async_error_handler(self, dp, monkeypatch, run_async, expected_output):
        def mock_async_err_handler(*args, **kwargs):
            self.count = 5

        # set defaults value to dp.bot
        dp.bot.defaults = Defaults(run_async=run_async)
        try:
            dp.add_handler(MessageHandler(Filters.all, self.callback_raise_error))
            dp.add_error_handler(self.error_handler)

            monkeypatch.setattr(dp, 'run_async', mock_async_err_handler)
            dp.process_update(self.message_update)

            assert self.count == expected_output

        finally:
            # reset dp.bot.defaults values
            dp.bot.defaults = None

    @pytest.mark.parametrize(
        ['run_async', 'expected_output'], [(True, 'running async'), (False, None)]
    )
    def test_default_run_async(self, monkeypatch, dp, run_async, expected_output):
        def mock_run_async(*args, **kwargs):
            self.received = 'running async'

        # set defaults value to dp.bot
        dp.bot.defaults = Defaults(run_async=run_async)
        try:
            dp.add_handler(MessageHandler(Filters.all, lambda u, c: None))
            monkeypatch.setattr(dp, 'run_async', mock_run_async)
            dp.process_update(self.message_update)
            assert self.received == expected_output

        finally:
            # reset defaults value
            dp.bot.defaults = None

    def test_run_async_multiple(self, bot, dp, dp2):
        def get_dispatcher_name(q):
            q.put(current_thread().name)

        q1 = Queue()
        q2 = Queue()

        dp.run_async(get_dispatcher_name, q1)
        dp2.run_async(get_dispatcher_name, q2)

        sleep(0.1)

        name1 = q1.get()
        name2 = q2.get()

        assert name1 != name2

    def test_multiple_run_async_decorator(self, dp, dp2):
        # Make sure we got two dispatchers and that they are not the same
        assert isinstance(dp, Dispatcher)
        assert isinstance(dp2, Dispatcher)
        assert dp is not dp2

        @run_async
        def must_raise_runtime_error():
            pass

        with pytest.raises(RuntimeError):
            must_raise_runtime_error()

    def test_run_async_with_args(self, dp):
        dp.add_handler(
            MessageHandler(
                Filters.all, run_async(self.callback_if_not_update_queue), pass_update_queue=True
            )
        )

        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.received == self.message_update.message

    def test_multiple_run_async_deprecation(self, dp):
        assert isinstance(dp, Dispatcher)

        @run_async
        def callback(update, context):
            pass

        dp.add_handler(MessageHandler(Filters.all, callback))

        with pytest.warns(TelegramDeprecationWarning, match='@run_async decorator'):
            dp.process_update(self.message_update)

    def test_async_raises_dispatcher_handler_stop(self, dp, caplog):
        @run_async
        def callback(update, context):
            raise DispatcherHandlerStop()

        dp.add_handler(MessageHandler(Filters.all, callback))

        with caplog.at_level(logging.WARNING):
            dp.update_queue.put(self.message_update)
            sleep(0.1)
            assert len(caplog.records) == 1
            assert (
                caplog.records[-1]
                .getMessage()
                .startswith('DispatcherHandlerStop is not supported ' 'with async functions')
            )

    def test_async_raises_exception(self, dp, caplog):
        @run_async
        def callback(update, context):
            raise RuntimeError('async raising exception')

        dp.add_handler(MessageHandler(Filters.all, callback))

        with caplog.at_level(logging.WARNING):
            dp.update_queue.put(self.message_update)
            sleep(0.1)
            assert len(caplog.records) == 1
            assert (
                caplog.records[-1]
                .getMessage()
                .startswith('A promise with deactivated error handling')
            )

    def test_add_async_handler(self, dp):
        dp.add_handler(
            MessageHandler(
                Filters.all,
                self.callback_if_not_update_queue,
                pass_update_queue=True,
                run_async=True,
            )
        )

        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.received == self.message_update.message

    def test_run_async_no_error_handler(self, dp, caplog):
        def func():
            raise RuntimeError('Async Error')

        with caplog.at_level(logging.ERROR):
            dp.run_async(func)
            sleep(0.1)
            assert len(caplog.records) == 1
            assert caplog.records[-1].getMessage().startswith('No error handlers are registered')

    def test_async_handler_error_handler(self, dp):
        dp.add_handler(MessageHandler(Filters.all, self.callback_raise_error, run_async=True))
        dp.add_error_handler(self.error_handler)

        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.received == self.message_update.message.text

    def test_async_handler_async_error_handler_context(self, cdp):
        cdp.add_handler(MessageHandler(Filters.all, self.callback_raise_error, run_async=True))
        cdp.add_error_handler(self.error_handler_context, run_async=True)

        cdp.update_queue.put(self.message_update)
        sleep(2)
        assert self.received == self.message_update.message.text

    def test_async_handler_error_handler_that_raises_error(self, dp, caplog):
        handler = MessageHandler(Filters.all, self.callback_raise_error, run_async=True)
        dp.add_handler(handler)
        dp.add_error_handler(self.error_handler_raise_error, run_async=False)

        with caplog.at_level(logging.ERROR):
            dp.update_queue.put(self.message_update)
            sleep(0.1)
            assert len(caplog.records) == 1
            assert caplog.records[-1].getMessage().startswith('An uncaught error was raised')

        # Make sure that the main loop still runs
        dp.remove_handler(handler)
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count, run_async=True))
        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 1

    def test_async_handler_async_error_handler_that_raises_error(self, dp, caplog):
        handler = MessageHandler(Filters.all, self.callback_raise_error, run_async=True)
        dp.add_handler(handler)
        dp.add_error_handler(self.error_handler_raise_error, run_async=True)

        with caplog.at_level(logging.ERROR):
            dp.update_queue.put(self.message_update)
            sleep(0.1)
            assert len(caplog.records) == 1
            assert caplog.records[-1].getMessage().startswith('An uncaught error was raised')

        # Make sure that the main loop still runs
        dp.remove_handler(handler)
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count, run_async=True))
        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 1

    def test_error_in_handler(self, dp):
        dp.add_handler(MessageHandler(Filters.all, self.callback_raise_error))
        dp.add_error_handler(self.error_handler)

        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.received == self.message_update.message.text

    def test_add_remove_handler(self, dp):
        handler = MessageHandler(Filters.all, self.callback_increase_count)
        dp.add_handler(handler)
        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 1
        dp.remove_handler(handler)
        dp.update_queue.put(self.message_update)
        assert self.count == 1

    def test_add_remove_handler_non_default_group(self, dp):
        handler = MessageHandler(Filters.all, self.callback_increase_count)
        dp.add_handler(handler, group=2)
        with pytest.raises(KeyError):
            dp.remove_handler(handler)
        dp.remove_handler(handler, group=2)

    def test_error_start_twice(self, dp):
        assert dp.running
        dp.start()

    def test_handler_order_in_group(self, dp):
        dp.add_handler(MessageHandler(Filters.photo, self.callback_set_count(1)))
        dp.add_handler(MessageHandler(Filters.all, self.callback_set_count(2)))
        dp.add_handler(MessageHandler(Filters.text, self.callback_set_count(3)))
        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 2

    def test_groups(self, dp):
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count))
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count), group=2)
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count), group=-1)

        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 3

    def test_add_handler_errors(self, dp):
        handler = 'not a handler'
        with pytest.raises(TypeError, match='handler is not an instance of'):
            dp.add_handler(handler)

        handler = MessageHandler(Filters.photo, self.callback_set_count(1))
        with pytest.raises(TypeError, match='group is not int'):
            dp.add_handler(handler, 'one')

    def test_flow_stop(self, dp, bot):
        passed = []

        def start1(b, u):
            passed.append('start1')
            raise DispatcherHandlerStop

        def start2(b, u):
            passed.append('start2')

        def start3(b, u):
            passed.append('start3')

        def error(b, u, e):
            passed.append('error')
            passed.append(e)

        update = Update(
            1,
            message=Message(
                1,
                None,
                None,
                None,
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )

        # If Stop raised handlers in other groups should not be called.
        passed = []
        dp.add_handler(CommandHandler('start', start1), 1)
        dp.add_handler(CommandHandler('start', start3), 1)
        dp.add_handler(CommandHandler('start', start2), 2)
        dp.process_update(update)
        assert passed == ['start1']

    def test_exception_in_handler(self, dp, bot):
        passed = []
        err = Exception('General exception')

        def start1(b, u):
            passed.append('start1')
            raise err

        def start2(b, u):
            passed.append('start2')

        def start3(b, u):
            passed.append('start3')

        def error(b, u, e):
            passed.append('error')
            passed.append(e)

        update = Update(
            1,
            message=Message(
                1,
                None,
                None,
                None,
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )

        # If an unhandled exception was caught, no further handlers from the same group should be
        # called. Also, the error handler should be called and receive the exception
        passed = []
        dp.add_handler(CommandHandler('start', start1), 1)
        dp.add_handler(CommandHandler('start', start2), 1)
        dp.add_handler(CommandHandler('start', start3), 2)
        dp.add_error_handler(error)
        dp.process_update(update)
        assert passed == ['start1', 'error', err, 'start3']

    def test_telegram_error_in_handler(self, dp, bot):
        passed = []
        err = TelegramError('Telegram error')

        def start1(b, u):
            passed.append('start1')
            raise err

        def start2(b, u):
            passed.append('start2')

        def start3(b, u):
            passed.append('start3')

        def error(b, u, e):
            passed.append('error')
            passed.append(e)

        update = Update(
            1,
            message=Message(
                1,
                None,
                None,
                None,
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )

        # If a TelegramException was caught, an error handler should be called and no further
        # handlers from the same group should be called.
        dp.add_handler(CommandHandler('start', start1), 1)
        dp.add_handler(CommandHandler('start', start2), 1)
        dp.add_handler(CommandHandler('start', start3), 2)
        dp.add_error_handler(error)
        dp.process_update(update)
        assert passed == ['start1', 'error', err, 'start3']
        assert passed[2] is err

    def test_error_while_saving_chat_data(self, bot):
        increment = []

        class OwnPersistence(BasePersistence):
            def __init__(self):
                super().__init__()
                self.store_user_data = True
                self.store_chat_data = True
                self.store_bot_data = True

            def get_bot_data(self):
                return dict()

            def update_bot_data(self, data):
                raise Exception

            def get_chat_data(self):
                return defaultdict(dict)

            def update_chat_data(self, chat_id, data):
                raise Exception

            def get_user_data(self):
                return defaultdict(dict)

            def update_user_data(self, user_id, data):
                raise Exception

            def get_conversations(self, name):
                pass

            def update_conversation(self, name, key, new_state):
                pass

        def start1(b, u):
            pass

        def error(b, u, e):
            increment.append("error")

        # If updating a user_data or chat_data from a persistence object throws an error,
        # the error handler should catch it

        update = Update(
            1,
            message=Message(
                1,
                None,
                Chat(1, "lala"),
                from_user=User(1, "Test", False),
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )
        my_persistence = OwnPersistence()
        dp = Dispatcher(bot, None, persistence=my_persistence, use_context=False)
        dp.add_handler(CommandHandler('start', start1))
        dp.add_error_handler(error)
        dp.process_update(update)
        assert increment == ["error", "error", "error"]

    def test_flow_stop_in_error_handler(self, dp, bot):
        passed = []
        err = TelegramError('Telegram error')

        def start1(b, u):
            passed.append('start1')
            raise err

        def start2(b, u):
            passed.append('start2')

        def start3(b, u):
            passed.append('start3')

        def error(b, u, e):
            passed.append('error')
            passed.append(e)
            raise DispatcherHandlerStop

        update = Update(
            1,
            message=Message(
                1,
                None,
                None,
                None,
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )

        # If a TelegramException was caught, an error handler should be called and no further
        # handlers from the same group should be called.
        dp.add_handler(CommandHandler('start', start1), 1)
        dp.add_handler(CommandHandler('start', start2), 1)
        dp.add_handler(CommandHandler('start', start3), 2)
        dp.add_error_handler(error)
        dp.process_update(update)
        assert passed == ['start1', 'error', err]
        assert passed[2] is err

    def test_error_handler_context(self, cdp):
        cdp.add_error_handler(self.callback_context)

        error = TelegramError('Unauthorized.')
        cdp.update_queue.put(error)
        sleep(0.1)
        assert self.received == 'Unauthorized.'

    def test_sensible_worker_thread_names(self, dp2):
        thread_names = [thread.name for thread in getattr(dp2, '_Dispatcher__async_threads')]
        print(thread_names)
        for thread_name in thread_names:
            assert thread_name.startswith(f"Bot:{dp2.bot.id}:worker:")

    def test_non_context_deprecation(self, dp):
        with pytest.warns(TelegramDeprecationWarning):
            Dispatcher(
                dp.bot, dp.update_queue, job_queue=dp.job_queue, workers=0, use_context=False
            )

    def test_error_while_persisting(self, cdp, monkeypatch):
        class OwnPersistence(BasePersistence):
            def __init__(self):
                super().__init__()
                self.store_user_data = True
                self.store_chat_data = True
                self.store_bot_data = True

            def update(self, data):
                raise Exception('PersistenceError')

            def update_bot_data(self, data):
                self.update(data)

            def update_chat_data(self, chat_id, data):
                self.update(data)

            def update_user_data(self, user_id, data):
                self.update(data)

            def get_chat_data(self):
                pass

            def get_bot_data(self):
                pass

            def get_user_data(self):
                pass

            def get_conversations(self, name):
                pass

            def update_conversation(self, name, key, new_state):
                pass

        def callback(update, context):
            pass

        test_flag = False

        def error(update, context):
            nonlocal test_flag
            test_flag = str(context.error) == 'PersistenceError'
            raise Exception('ErrorHandlingError')

        def logger(message):
            assert 'uncaught error was raised while handling' in message

        update = Update(
            1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text')
        )
        handler = MessageHandler(Filters.all, callback)
        cdp.add_handler(handler)
        cdp.add_error_handler(error)
        monkeypatch.setattr(cdp.logger, 'exception', logger)

        cdp.persistence = OwnPersistence()
        cdp.process_update(update)
        assert test_flag

    def test_persisting_no_user_no_chat(self, cdp):
        class OwnPersistence(BasePersistence):
            def __init__(self):
                super().__init__()
                self.store_user_data = True
                self.store_chat_data = True
                self.store_bot_data = True
                self.test_flag_bot_data = False
                self.test_flag_chat_data = False
                self.test_flag_user_data = False

            def update_bot_data(self, data):
                self.test_flag_bot_data = True

            def update_chat_data(self, chat_id, data):
                self.test_flag_chat_data = True

            def update_user_data(self, user_id, data):
                self.test_flag_user_data = True

            def update_conversation(self, name, key, new_state):
                pass

            def get_conversations(self, name):
                pass

            def get_user_data(self):
                pass

            def get_bot_data(self):
                pass

            def get_chat_data(self):
                pass

        def callback(update, context):
            pass

        handler = MessageHandler(Filters.all, callback)
        cdp.add_handler(handler)
        cdp.persistence = OwnPersistence()

        update = Update(
            1, message=Message(1, None, None, from_user=User(1, '', False), text='Text')
        )
        cdp.process_update(update)
        assert cdp.persistence.test_flag_bot_data
        assert cdp.persistence.test_flag_user_data
        assert not cdp.persistence.test_flag_chat_data

        cdp.persistence.test_flag_bot_data = False
        cdp.persistence.test_flag_user_data = False
        cdp.persistence.test_flag_chat_data = False
        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
        cdp.process_update(update)
        assert cdp.persistence.test_flag_bot_data
        assert not cdp.persistence.test_flag_user_data
        assert cdp.persistence.test_flag_chat_data

    def test_update_persistence_once_per_update(self, monkeypatch, dp):
        def update_persistence(*args, **kwargs):
            self.count += 1

        def dummy_callback(*args):
            pass

        monkeypatch.setattr(dp, 'update_persistence', update_persistence)

        for group in range(5):
            dp.add_handler(MessageHandler(Filters.text, dummy_callback), group=group)

        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text=None))
        dp.process_update(update)
        assert self.count == 0

        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='text'))
        dp.process_update(update)
        assert self.count == 1

    def test_update_persistence_all_async(self, monkeypatch, dp):
        def update_persistence(*args, **kwargs):
            self.count += 1

        def dummy_callback(*args, **kwargs):
            pass

        monkeypatch.setattr(dp, 'update_persistence', update_persistence)
        monkeypatch.setattr(dp, 'run_async', dummy_callback)

        for group in range(5):
            dp.add_handler(
                MessageHandler(Filters.text, dummy_callback, run_async=True), group=group
            )

        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
        dp.process_update(update)
        assert self.count == 0

        dp.bot.defaults = Defaults(run_async=True)
        try:
            for group in range(5):
                dp.add_handler(MessageHandler(Filters.text, dummy_callback), group=group)

            update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
            dp.process_update(update)
            assert self.count == 0
        finally:
            dp.bot.defaults = None

    @pytest.mark.parametrize('run_async', [DEFAULT_FALSE, False])
    def test_update_persistence_one_sync(self, monkeypatch, dp, run_async):
        def update_persistence(*args, **kwargs):
            self.count += 1

        def dummy_callback(*args, **kwargs):
            pass

        monkeypatch.setattr(dp, 'update_persistence', update_persistence)
        monkeypatch.setattr(dp, 'run_async', dummy_callback)

        for group in range(5):
            dp.add_handler(
                MessageHandler(Filters.text, dummy_callback, run_async=True), group=group
            )
        dp.add_handler(MessageHandler(Filters.text, dummy_callback, run_async=run_async), group=5)

        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
        dp.process_update(update)
        assert self.count == 1

    @pytest.mark.parametrize('run_async,expected', [(DEFAULT_FALSE, 1), (False, 1), (True, 0)])
    def test_update_persistence_defaults_async(self, monkeypatch, dp, run_async, expected):
        def update_persistence(*args, **kwargs):
            self.count += 1

        def dummy_callback(*args, **kwargs):
            pass

        monkeypatch.setattr(dp, 'update_persistence', update_persistence)
        monkeypatch.setattr(dp, 'run_async', dummy_callback)
        dp.bot.defaults = Defaults(run_async=run_async)

        try:
            for group in range(5):
                dp.add_handler(MessageHandler(Filters.text, dummy_callback), group=group)

            update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
            dp.process_update(update)
            assert self.count == expected
        finally:
            dp.bot.defaults = None
