File: thread_pool.py

package info (click to toggle)
python-b2sdk 2.8.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,020 kB
  • sloc: python: 30,902; sh: 13; makefile: 8
file content (114 lines) | stat: -rw-r--r-- 3,550 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
######################################################################
#
# File: b2sdk/_internal/utils/thread_pool.py
#
# Copyright 2022 Backblaze Inc. All Rights Reserved.
#
# License https://www.backblaze.com/using_b2_code.html
#
######################################################################
from __future__ import annotations

import os
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable

try:
    from typing_extensions import Protocol
except ImportError:
    from typing import Protocol

from b2sdk._internal.utils import B2TraceMetaAbstract


class DynamicThreadPoolExecutorProtocol(Protocol):
    """
    Structural interface of a thread pool whose size can be changed at runtime.
    """

    def submit(self, fn: Callable, *args, **kwargs) -> Future:
        """Schedule ``fn(*args, **kwargs)`` for execution and return its future."""
        ...

    def set_size(self, max_workers: int) -> None:
        """Change the size of the thread pool to ``max_workers``."""
        ...

    def get_size(self) -> int:
        """Report the size the thread pool is currently configured with."""
        ...


class LazyThreadPool:
    """
    Lazily initialized thread pool.

    The underlying executor is created on the first call to :meth:`submit`;
    until then the object only remembers the requested size.

    NOTE: creation and replacement of the executor are not guarded by a lock.
    """

    _THREAD_POOL_FACTORY = ThreadPoolExecutor

    def __init__(self, max_workers: int | None = None, **kwargs):
        """
        :param max_workers: maximum number of worker threads; when ``None``,
                            a default derived from the CPU count is used
        :param kwargs: forwarded to the next ``__init__`` in the MRO
        """
        if max_workers is None:
            max_workers = min(
                32, (os.cpu_count() or 1) + 4
            )  # same default as in ThreadPoolExecutor
        self._max_workers = max_workers
        self._thread_pool: ThreadPoolExecutor | None = None
        super().__init__(**kwargs)

    def submit(self, fn: Callable, *args, **kwargs) -> Future:
        """
        Schedule ``fn(*args, **kwargs)`` on the pool, creating the pool first if needed.

        :param fn: callable to execute
        :return: future representing the execution of ``fn``
        """
        if self._thread_pool is None:
            self._thread_pool = self._THREAD_POOL_FACTORY(self._max_workers)
        return self._thread_pool.submit(fn, *args, **kwargs)

    def set_size(self, max_workers: int) -> None:
        """
        Set the size of the thread pool.

        If a thread pool has already been created, it is replaced by a new one and
        this operation will block until all tasks in the old thread pool are completed.
        If no thread pool has been created yet, only the requested size is recorded
        and the pool is still created lazily on the first :meth:`submit`.

        :param max_workers: New size of the thread pool
        :raises ValueError: if ``max_workers`` is not greater than 0
        :return: None
        """
        if self._max_workers == max_workers:
            return
        old_thread_pool = self._thread_pool
        if old_thread_pool is None:
            # Nothing to replace: stay lazy, but keep failing fast on a bad size
            # just as constructing a ThreadPoolExecutor here would have done.
            if max_workers <= 0:
                raise ValueError('max_workers must be greater than 0')
            self._max_workers = max_workers
            return
        self._thread_pool = self._THREAD_POOL_FACTORY(max_workers=max_workers)
        # Publish the new size together with the new pool, *before* the blocking
        # shutdown, so get_size() never disagrees with the pool actually in use.
        self._max_workers = max_workers
        old_thread_pool.shutdown(wait=True)

    def get_size(self) -> int:
        """Return the current size of the thread pool."""
        return self._max_workers


class ThreadPoolMixin(metaclass=B2TraceMetaAbstract):
    """
    Mixin giving a class its own resizable thread pool.
    """

    DEFAULT_THREAD_POOL_CLASS = LazyThreadPool

    def __init__(
        self,
        thread_pool: DynamicThreadPoolExecutorProtocol | None = None,
        max_workers: int | None = None,
        **kwargs,
    ):
        """
        :param thread_pool: thread pool to use; when ``None``, one is created
                            from ``DEFAULT_THREAD_POOL_CLASS``
        :param max_workers: maximum number of worker threads for the created pool
                            (ignored if thread_pool is not None)
        :param kwargs: forwarded to the next ``__init__`` in the MRO
        """
        if thread_pool is None:
            # No pool supplied by the caller: build the default one.
            thread_pool = self.DEFAULT_THREAD_POOL_CLASS(max_workers=max_workers)
        self._thread_pool = thread_pool
        self._max_workers = max_workers
        super().__init__(**kwargs)

    def set_thread_pool_size(self, max_workers: int) -> None:
        """
        Resize the thread pool by delegating to its ``set_size`` method.

        This operation may block until the tasks already submitted to the
        current thread pool are completed.

        :param max_workers: New size of the thread pool
        :return: None
        """
        pool = self._thread_pool
        return pool.set_size(max_workers)

    def get_thread_pool_size(self) -> int:
        """
        Return the size reported by the thread pool's ``get_size`` method.
        """
        pool = self._thread_pool
        return pool.get_size()