File: worker.py

package info (click to toggle)
tryton-server 7.0.40-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 7,748 kB
  • sloc: python: 53,502; xml: 5,194; sh: 803; sql: 217; makefile: 28
file content (189 lines) | stat: -rw-r--r-- 6,875 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# This file is part of Tryton.  The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import datetime as dt
import logging
import random
import selectors
import signal
import time
from multiprocessing import Pool as MPool
from multiprocessing import cpu_count

from sql import Flavor

from trytond import backend
from trytond.config import config
from trytond.exceptions import UserError, UserWarning
from trytond.pool import Pool
from trytond.status import processing
from trytond.transaction import Transaction, TransactionError

__all__ = ['work']
logger = logging.getLogger(__name__)


class Queue:
    """Task queue of one database, bound to a shared pool of processes.

    The instance keeps the connected database and one autocommit
    connection of it; both are handed to ``ir.queue`` when pulling.
    """

    def __init__(self, database_name, mpool):
        database = backend.Database(database_name).connect()
        self.database = database
        self.connection = database.get_connection(autocommit=True)
        self.mpool = mpool

    def pull(self, name=None):
        """Return what ``ir.queue`` ``pull`` returns for this database.

        The pool of the database is initialised beforehand when the
        database is not listed by ``Pool.database_list()``.
        ``name`` is forwarded unchanged to ``ir.queue``.
        """
        db_name = self.database.name
        # The list must be read before instantiating the pool
        listed = Pool.database_list()
        pool = Pool(db_name)
        if db_name not in listed:
            with Transaction().start(db_name, 0, readonly=True):
                pool.init()
        IrQueue = pool.get('ir.queue')
        return IrQueue.pull(self.database, self.connection, name=name)

    def run(self, task_id):
        """Submit the task to the process pool and return its async result."""
        arguments = (self.database.name, task_id)
        return self.mpool.apply_async(run_task, arguments)


class TaskList(list):
    """List of pending results; each item must provide ``ready()``."""

    def filter(self):
        """Remove in place the items that are ready and return the list."""
        finished = [task for task in self if task.ready()]
        for task in finished:
            self.remove(task)
        return self


def work(options):
    """Pull the queued tasks of the databases and run them on a process pool.

    Nothing is done unless the ``worker`` option of the ``queue``
    configuration section is set.

    ``options`` must provide ``processes``, ``database_names``,
    ``maxtasksperchild``, ``timeout`` and ``name``.

    The function loops until it is interrupted by ``KeyboardInterrupt``.
    It then stops submitting tasks and waits for the tasks already
    submitted to finish; a second interruption during that wait
    terminates the worker processes at once.
    """
    Flavor.set(backend.Database.flavor)
    if not config.getboolean('queue', 'worker', default=False):
        return
    try:
        processes = options.processes or cpu_count()
    except NotImplementedError:
        processes = 1
    logger.info("start %d workers", processes)
    mpool = MPool(
        processes, initializer, (options.database_names,),
        options.maxtasksperchild)
    queues = [Queue(name, mpool) for name in options.database_names]

    tasks = TaskList()
    selector = selectors.DefaultSelector()
    for queue in queues:
        # Wake up as soon as something is readable on a queue connection
        selector.register(queue.connection, selectors.EVENT_READ)
    try:
        while True:
            timeout = options.timeout
            # Add some randomness to avoid concurrent pulling
            time.sleep(0.1 * random.random())
            # Do not pull more tasks than there are processes to run them
            while len(tasks.filter()) >= processes:
                time.sleep(0.1)
            for queue in queues:
                try:
                    task_id, next_ = queue.pull(options.name)
                except backend.DatabaseOperationalError:
                    # Restart the pulling without waiting on the selector
                    break
                if next_ is not None:
                    timeout = min(next_, timeout)
                if task_id:
                    tasks.append(queue.run(task_id))
                    # Pull again right away, more tasks may be waiting
                    break
            else:
                # No queue had a task: wait for activity on a connection
                # or for the timeout
                for key, _ in selector.select(timeout=timeout):
                    connection = key.fileobj
                    connection.poll()
                    # The notifications only serve to wake up the loop
                    # (presumably database notifications - TODO confirm)
                    while connection.notifies:
                        connection.notifies.pop(0)
    except KeyboardInterrupt:
        # close() alone does not wait: without join() the function returns
        # and the processes still running a task are terminated when the
        # pool is finalized at interpreter exit.
        mpool.close()
        try:
            mpool.join()
        except KeyboardInterrupt:
            # Interrupted again while waiting: stop immediately
            mpool.terminate()
    finally:
        selector.close()


def initializer(database_names, worker=True):
    """Return the pools of the given databases, initialising them if needed.

    A pool is initialised only when its database is not listed by
    ``Pool.database_list()`` at the time of the call.
    When ``worker`` is true, SIGINT is ignored in the calling process.
    """
    if worker:
        signal.signal(signal.SIGINT, signal.SIG_IGN)
    listed = Pool.database_list()
    ready_pools = []
    for db_name in database_names:
        db_pool = Pool(db_name)
        if db_name not in listed:
            with Transaction().start(db_name, 0, readonly=True):
                db_pool.init()
        ready_pools.append(db_pool)
    return ready_pools


def run_task(pool, task_id):
    """Run the ``ir.queue`` task ``task_id`` and log how it ended.

    ``pool`` is either a ``Pool`` instance or a database name; in the
    latter case the pool is instantiated and initialised if needed.

    A ``TransactionError`` makes the task run again in a transaction
    started with the extras set by the error.  A database operational
    error makes it run again, with an increasing pause, up to the
    ``retry`` option of the ``database`` configuration section; past
    that, a copy of the task is pushed to the queue for a later time if
    the database has a channel.  ``UserError`` and ``UserWarning`` are
    reported through ``ir.error``.

    Any ``Exception`` raised by the run is logged instead of propagated.
    """
    if not isinstance(pool, Pool):
        # A database name was given instead of a pool
        listed = Pool.database_list()
        pool = Pool(pool)
        if pool.database_name not in listed:
            with Transaction().start(pool.database_name, 0, readonly=True):
                pool.init()
    IrQueue = pool.get('ir.queue')
    IrError = pool.get('ir.error')

    started = time.monotonic()

    def elapsed_ms():
        return (time.monotonic() - started) * 1000

    name = '<Task %s@%s>' % (task_id, pool.database_name)
    retry = config.getint('database', 'retry')
    try:
        attempt = 0
        extras = {}
        while True:
            if attempt:
                # Pause longer after each failed attempt
                time.sleep(0.02 * attempt)
            with Transaction().start(
                    pool.database_name, 0, **extras) as transaction:
                try:
                    try:
                        task, = IrQueue.search([('id', '=', task_id)])
                    except ValueError:
                        # The task can not be found (it was rolled back),
                        # there is nothing to run
                        break
                    with processing(name):
                        task.run()
                    break
                except TransactionError as error:
                    transaction.rollback()
                    # Let the error adjust how the next transaction starts
                    error.fix(extras)
                    continue
                except backend.DatabaseOperationalError:
                    if count_left(attempt, retry):
                        transaction.rollback()
                        attempt += 1
                        logger.debug("Retry: %i", attempt)
                        continue
                    raise
                except (UserError, UserWarning) as error:
                    IrError.report(task, error)
                    raise
        logger.info("%s in %i ms", name, elapsed_ms())
    except backend.DatabaseOperationalError:
        logger.info(
            "%s failed after %i ms, retrying", name, elapsed_ms(),
            exc_info=logger.isEnabledFor(logging.DEBUG))
        if not config.getboolean('queue', 'worker', default=False):
            time.sleep(0.02 * retry)
        try:
            with Transaction().start(pool.database_name, 0) as transaction:
                if not transaction.database.has_channel():
                    logger.critical('%s failed', name, exc_info=True)
                    return
                task = IrQueue(task_id)
                if task.scheduled_at and task.enqueued_at < task.scheduled_at:
                    # Double the delay the task was originally given
                    delay = (task.scheduled_at - task.enqueued_at) * 2
                else:
                    delay = dt.timedelta(seconds=2 * retry)
                delay = min(delay, dt.timedelta(hours=1))
                # Pick a random moment within the delay
                scheduled_at = dt.datetime.now() + delay * random.random()
                IrQueue.push(task.name, task.data, scheduled_at=scheduled_at)
        except Exception:
            logger.critical(
                "rescheduling %s failed", name, exc_info=True)
    except (UserError, UserWarning):
        logger.info(
            "%s failed after %i ms", name, elapsed_ms(),
            exc_info=logger.isEnabledFor(logging.DEBUG))
    except Exception:
        logger.critical(
            "%s failed after %i ms", name, elapsed_ms(), exc_info=True)


def count_left(attempt, retry):
    """Tell whether another attempt is allowed after ``attempt`` retries."""
    return attempt < retry