File: process_worker_pool.py

Package: python-parsl 2025.01.13+ds-1
#!/usr/bin/env python3

import argparse
import json
import logging
import math
import multiprocessing
import os
import pickle
import platform
import queue
import subprocess
import sys
import threading
import time
import uuid
from multiprocessing.managers import DictProxy
from multiprocessing.sharedctypes import Synchronized
from typing import Dict, List, Optional, Sequence

import psutil
import zmq

from parsl import curvezmq
from parsl.addresses import tcp_url
from parsl.app.errors import RemoteExceptionWrapper
from parsl.executors.execute_task import execute_task
from parsl.executors.high_throughput.errors import WorkerLost
from parsl.executors.high_throughput.mpi_prefix_composer import (
    VALID_LAUNCHERS,
    compose_all,
)
from parsl.executors.high_throughput.mpi_resource_management import (
    MPITaskScheduler,
    TaskScheduler,
)
from parsl.executors.high_throughput.probe import probe_addresses
from parsl.multiprocessing import SpawnContext
from parsl.process_loggers import wrap_with_logs
from parsl.serialize import serialize
from parsl.version import VERSION as PARSL_VERSION

HEARTBEAT_CODE = (2 ** 32) - 1
DRAINED_CODE = (2 ** 32) - 2
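# These sentinel values are sent by the interchange in place of a pickled task
# list; pull_tasks() below treats them as "heartbeat" and "fully drained"
# signals respectively rather than as tasks to execute.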


class Manager:
    """ Manager manages task execution by the workers

                |         zmq              |    Manager         |   Worker Processes
                |                          |                    |
                | <-----Register with -----+                    |      Request task<--+
                |       N task capacity    |                    |          |          |
    Interchange | -------------------------+->Receive task batch|          |          |
                |                          |  Distribute tasks--+----> Get(block) &   |
                |                          |                    |      Execute task   |
                |                          |                    |          |          |
                | <------------------------+--Return results----+----  Post result    |
                |                          |                    |          |          |
                |                          |                    |          +----------+
                |                          |                IPC-Queues

    """
    def __init__(self, *,
                 addresses,
                 address_probe_timeout,
                 task_port,
                 result_port,
                 cores_per_worker,
                 mem_per_worker,
                 max_workers_per_node,
                 prefetch_capacity,
                 uid,
                 block_id,
                 heartbeat_threshold,
                 heartbeat_period,
                 poll_period,
                 cpu_affinity,
                 enable_mpi_mode: bool = False,
                 mpi_launcher: str = "mpiexec",
                 available_accelerators: Sequence[str],
                 cert_dir: Optional[str],
                 drain_period: Optional[int]):
        """
        Parameters
        ----------
        addresses : str
             comma separated list of addresses for the interchange

        address_probe_timeout : int
             Timeout in seconds for the address probe to detect viable addresses
             to the interchange.

        uid : str
             string unique identifier

        block_id : str
             Block identifier that maps managers to the provider blocks they belong to.

        cores_per_worker : float
             cores to be assigned to each worker. Oversubscription is possible
             by setting cores_per_worker < 1.0.

        mem_per_worker : float
             GB of memory required per worker. If this option is specified, the node manager
             will check the available memory at startup and limit the number of workers such that
             there is sufficient memory for each worker. If set to None, memory on the node is not
             considered in the determination of workers to be launched on node by the manager.

        max_workers_per_node : int | float
             Caps the maximum number of workers that can be launched.

        prefetch_capacity : int
             Number of tasks that could be prefetched over available worker capacity.
             When there are a few tasks (<100) or when tasks are long running, this option should
             be set to 0 for better load balancing.

        heartbeat_threshold : int
             Number of seconds since the last message from the interchange after which the
             interchange is assumed to be unavailable and the manager initiates shutdown.

        heartbeat_period : int
             Number of seconds after which a heartbeat message is sent to the interchange, and workers
             are checked for liveness.

        poll_period : int
             Timeout period used by the manager in milliseconds.

        cpu_affinity : str
             Whether and how each worker should pin its affinity to specific CPUs

        available_accelerators: list of str
            List of accelerators available to the workers.

        enable_mpi_mode: bool
            When set to true, the manager assumes ownership of the batch job and each worker
            claims a subset of nodes from a shared pool to execute multi-node mpi tasks. Node
            info is made available to workers via env vars.

        mpi_launcher: str
            Set to one of the supported MPI launchers: ("srun", "aprun", "mpiexec")

        cert_dir : str | None
            Path to the certificate directory.

        drain_period: int | None
            Number of seconds after startup at which the manager requests a drain from the interchange.
            TODO: could be a nicer timespec involving m,s,h qualifiers for user friendliness?
        """

        logger.info("Manager initializing")

        self._start_time = time.time()

        try:
            ix_address = probe_addresses(addresses.split(','), task_port, timeout=address_probe_timeout)
            if not ix_address:
                raise Exception("No viable address found")
            else:
                logger.info("Connection to Interchange successful on {}".format(ix_address))
                task_q_url = tcp_url(ix_address, task_port)
                result_q_url = tcp_url(ix_address, result_port)
                logger.info("Task url : {}".format(task_q_url))
                logger.info("Result url : {}".format(result_q_url))
        except Exception:
            logger.exception("Caught exception while trying to determine viable address to interchange")
            print("Failed to find a viable address to connect to interchange. Exiting")
            exit(5)

        self.cert_dir = cert_dir
        self.zmq_context = curvezmq.ClientContext(self.cert_dir)
        self.task_incoming = self.zmq_context.socket(zmq.DEALER)
        self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        # Linger is set to 0, so that the manager can exit even when there might be
        # messages in the pipe
        self.task_incoming.setsockopt(zmq.LINGER, 0)
        self.task_incoming.connect(task_q_url)

        self.result_outgoing = self.zmq_context.socket(zmq.DEALER)
        self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        self.result_outgoing.setsockopt(zmq.LINGER, 0)
        self.result_outgoing.connect(result_q_url)
        logger.info("Manager connected to interchange")

        self.uid = uid
        self.block_id = block_id
        self.start_time = time.time()

        self.enable_mpi_mode = enable_mpi_mode
        self.mpi_launcher = mpi_launcher

        if os.environ.get('PARSL_CORES'):
            cores_on_node = int(os.environ['PARSL_CORES'])
        else:
            cores_on_node = SpawnContext.cpu_count()

        if os.environ.get('PARSL_MEMORY_GB'):
            available_mem_on_node = float(os.environ['PARSL_MEMORY_GB'])
        else:
            available_mem_on_node = round(psutil.virtual_memory().available / (2**30), 1)

        self.max_workers_per_node = max_workers_per_node
        self.prefetch_capacity = prefetch_capacity

        mem_slots = max_workers_per_node
        # Avoid a divide by 0 error.
        if mem_per_worker and mem_per_worker > 0:
            mem_slots = math.floor(available_mem_on_node / mem_per_worker)

        self.worker_count: int = min(max_workers_per_node,
                                     mem_slots,
                                     math.floor(cores_on_node / cores_per_worker))
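        # Illustrative example (not from the source): with 32 cores on the node,
        # cores_per_worker=0.5, mem_per_worker=None and max_workers_per_node=inf,
        # worker_count = min(inf, inf, floor(32 / 0.5)) = 64 workers.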

        self._mp_manager = SpawnContext.Manager()  # Starts a server process

        self.monitoring_queue = self._mp_manager.Queue()
        self.pending_task_queue = SpawnContext.Queue()
        self.pending_result_queue = SpawnContext.Queue()
        self.task_scheduler: TaskScheduler
        if self.enable_mpi_mode:
            self.task_scheduler = MPITaskScheduler(
                self.pending_task_queue,
                self.pending_result_queue,
            )
        else:
            self.task_scheduler = TaskScheduler(
                self.pending_task_queue,
                self.pending_result_queue
            )
        self.ready_worker_count = SpawnContext.Value("i", 0)
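        # Shared counter of workers currently blocked waiting for a task: each
        # worker process increments it while idle and decrements it when it
        # picks up a task; pull_tasks() logs it alongside the pending task count.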

        self.max_queue_size = self.prefetch_capacity + self.worker_count

        self.tasks_per_round = 1

        self.heartbeat_period = heartbeat_period
        self.heartbeat_threshold = heartbeat_threshold
        self.poll_period = poll_period

        self.drain_time: float
        if drain_period:
            self.drain_time = self._start_time + drain_period
            logger.info(f"Will request drain at {self.drain_time}")
        else:
            self.drain_time = float('inf')

        self.cpu_affinity = cpu_affinity

        # Determine whether accelerators are available and adjust the worker count accordingly
        self.available_accelerators = available_accelerators
        self.accelerators_available = len(available_accelerators) > 0
        if self.accelerators_available:
            self.worker_count = min(len(self.available_accelerators), self.worker_count)
        logger.info("Manager will spawn {} workers".format(self.worker_count))

    def create_reg_message(self):
        """ Creates a registration message to identify the worker to the interchange
        """
        msg = {'type': 'registration',
               'parsl_v': PARSL_VERSION,
               'python_v': "{}.{}.{}".format(sys.version_info.major,
                                             sys.version_info.minor,
                                             sys.version_info.micro),
               'worker_count': self.worker_count,
               'uid': self.uid,
               'block_id': self.block_id,
               'start_time': self.start_time,
               'prefetch_capacity': self.prefetch_capacity,
               'max_capacity': self.worker_count + self.prefetch_capacity,
               'os': platform.system(),
               'hostname': platform.node(),
               'dir': os.getcwd(),
               'cpu_count': psutil.cpu_count(logical=False),
               'total_memory': psutil.virtual_memory().total,
               }
        b_msg = json.dumps(msg).encode('utf-8')
        return b_msg

    def heartbeat_to_incoming(self):
        """ Send heartbeat to the incoming task queue
        """
        msg = {'type': 'heartbeat'}
        # don't need to dumps and encode this every time - could do as a global on import?
        b_msg = json.dumps(msg).encode('utf-8')
        self.task_incoming.send(b_msg)
        logger.debug("Sent heartbeat")

    def drain_to_incoming(self):
        """ Send heartbeat to the incoming task queue
        """
        msg = {'type': 'drain'}
        b_msg = json.dumps(msg).encode('utf-8')
        self.task_incoming.send(b_msg)
        logger.debug("Sent drain")

    @wrap_with_logs
    def pull_tasks(self, kill_event):
        """ Pull tasks from the incoming tasks zmq pipe onto the internal
        pending task queue

        Parameters:
        -----------
        kill_event : threading.Event
              Event to let the thread know when it is time to die.
        """
        logger.info("starting")
        poller = zmq.Poller()
        poller.register(self.task_incoming, zmq.POLLIN)

        # Send a registration message
        msg = self.create_reg_message()
        logger.debug("Sending registration message: {}".format(msg))
        self.task_incoming.send(msg)
        last_beat = time.time()
        last_interchange_contact = time.time()
        task_recv_counter = 0

        while not kill_event.is_set():

            # This loop will sit inside poller.poll until either a message
            # arrives or one of these event times is reached. This code
            # assumes that the event times won't change except on iteration
            # of this loop - so it will break if a different thread does
            # anything to bring one of the event times earlier - and that the
            # times here are correctly copy-pasted from the relevant if
            # statements.
            next_interesting_event_time = min(last_beat + self.heartbeat_period,
                                              self.drain_time,
                                              last_interchange_contact + self.heartbeat_threshold)
            try:
                pending_task_count = self.pending_task_queue.qsize()
            except NotImplementedError:
                # Ref: https://github.com/python/cpython/blob/6d5e0dc0e330f4009e8dc3d1642e46b129788877/Lib/multiprocessing/queues.py#L125
                pending_task_count = f"pending task count is not available on {platform.system()}"

            logger.debug("ready workers: {}, pending tasks: {}".format(self.ready_worker_count.value,
                                                                       pending_task_count))

            if time.time() >= last_beat + self.heartbeat_period:
                self.heartbeat_to_incoming()
                last_beat = time.time()

            if time.time() > self.drain_time:
                logger.info("Requesting drain")
                self.drain_to_incoming()
                # This will start the pool draining...
                # Drained exit behaviour does not happen here. It will be
                # driven by the interchange sending a DRAINED_CODE message.

                # now set drain time to the far future so we don't send a drain
                # message every iteration.
                self.drain_time = float('inf')

            poll_duration_s = max(0, next_interesting_event_time - time.time())
            socks = dict(poller.poll(timeout=poll_duration_s * 1000))

            if self.task_incoming in socks and socks[self.task_incoming] == zmq.POLLIN:
                _, pkl_msg = self.task_incoming.recv_multipart()
                tasks = pickle.loads(pkl_msg)
                last_interchange_contact = time.time()

                if tasks == HEARTBEAT_CODE:
                    logger.debug("Got heartbeat from interchange")
                elif tasks == DRAINED_CODE:
                    logger.info("Got fully drained message from interchange - setting kill flag")
                    kill_event.set()
                else:
                    task_recv_counter += len(tasks)
                    logger.debug("Got executor tasks: {}, cumulative count of tasks: {}".format(
                        [t['task_id'] for t in tasks], task_recv_counter
                    ))

                    for task in tasks:
                        self.task_scheduler.put_task(task)

            else:
                logger.debug("No incoming tasks")

                # Only check if no messages were received.
                if time.time() >= last_interchange_contact + self.heartbeat_threshold:
                    logger.critical("Missing contact with interchange beyond heartbeat_threshold")
                    kill_event.set()
                    logger.critical("Exiting")
                    break

    @wrap_with_logs
    def push_results(self, kill_event):
        """ Listens on the pending_result_queue and sends out results via zmq

        Parameters:
        -----------
        kill_event : threading.Event
              Event to let the thread know when it is time to die.
        """

        logger.debug("Starting result push thread")

        push_poll_period = max(10, self.poll_period) / 1000    # push_poll_period must be at least 10 ms
        logger.debug("push poll period: {}".format(push_poll_period))

        last_beat = time.time()
        last_result_beat = time.time()
        items = []

        while not kill_event.is_set():
            try:
                logger.debug("Starting pending_result_queue get")
                r = self.task_scheduler.get_result(block=True, timeout=push_poll_period)
                logger.debug("Got a result item")
                items.append(r)
            except queue.Empty:
                logger.debug("pending_result_queue get timeout without result item")
            except Exception as e:
                logger.exception("Got an exception: {}".format(e))

            if time.time() > last_result_beat + self.heartbeat_period:
                heartbeat_message = f"last_result_beat={last_result_beat} heartbeat_period={self.heartbeat_period} seconds"
                logger.info(f"Sending heartbeat via results connection: {heartbeat_message}")
                last_result_beat = time.time()
                items.append(pickle.dumps({'type': 'heartbeat'}))

            if len(items) >= self.max_queue_size or time.time() > last_beat + push_poll_period:
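                # Flush policy: send the accumulated batch once it reaches
                # max_queue_size items, or at least once every push_poll_period
                # seconds even if the batch is smaller.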
                last_beat = time.time()
                if items:
                    logger.debug(f"Result send: Pushing {len(items)} items")
                    self.result_outgoing.send_multipart(items)
                    logger.debug("Result send: Pushed")
                    items = []
                else:
                    logger.debug("Result send: No items to push")
            else:
                logger.debug(f"Result send: check condition not met - deferring {len(items)} result items")

        logger.critical("Exiting")

    @wrap_with_logs
    def worker_watchdog(self, kill_event: threading.Event):
        """Keeps workers alive.

        Parameters:
        -----------
        kill_event : threading.Event
              Event to let the thread know when it is time to die.
        """

        logger.debug("Starting worker watchdog")

        while not kill_event.wait(self.heartbeat_period):
            for worker_id, p in self.procs.items():
                if not p.is_alive():
                    logger.error("Worker {} has died".format(worker_id))
                    try:
                        task = self._tasks_in_progress.pop(worker_id)
                        logger.info("Worker {} was busy when it died".format(worker_id))
                        try:
                            raise WorkerLost(worker_id, platform.node())
                        except Exception:
                            logger.info("Putting exception for executor task {} in the pending result queue".format(task['task_id']))
                            result_package = {'type': 'result',
                                              'task_id': task['task_id'],
                                              'exception': serialize(RemoteExceptionWrapper(*sys.exc_info()))}
                            pkl_package = pickle.dumps(result_package)
                            self.pending_result_queue.put(pkl_package)
                    except KeyError:
                        logger.info("Worker {} was not busy when it died".format(worker_id))

                    p = self._start_worker(worker_id)
                    self.procs[worker_id] = p
                    logger.info("Worker {} has been restarted".format(worker_id))

        logger.critical("Exiting")

    @wrap_with_logs
    def handle_monitoring_messages(self, kill_event: threading.Event):
        """Transfer messages from the managed monitoring queue to the result queue.

        We separate the queues so that the result queue does not rely on a manager
        process, which adds overhead that causes slower queue operations but enables
        use across processes started in fork and spawn contexts.

        We transfer the messages to the result queue to reuse the ZMQ connection between
        the manager and the interchange.
        """
        logger.debug("Starting monitoring handler thread")

        poll_period_s = max(10, self.poll_period) / 1000    # Must be at least 10 ms

        while not kill_event.is_set():
            try:
                logger.debug("Starting monitor_queue.get()")
                msg = self.monitoring_queue.get(block=True, timeout=poll_period_s)
            except queue.Empty:
                logger.debug("monitoring_queue.get() has timed out")
            except Exception as e:
                logger.exception(f"Got an exception: {e}")
            else:
                logger.debug("Got a monitoring message")
                self.pending_result_queue.put(msg)
                logger.debug("Put monitoring message on pending_result_queue")

        logger.critical("Exiting")

    def start(self):
        """ Start the worker processes.

        TODO: Move task receiving to a thread
        """
        self._kill_event = threading.Event()
        self._tasks_in_progress = self._mp_manager.dict()

        self.procs = {}
        for worker_id in range(self.worker_count):
            p = self._start_worker(worker_id)
            self.procs[worker_id] = p

        logger.debug("Workers started")

        self._task_puller_thread = threading.Thread(target=self.pull_tasks,
                                                    args=(self._kill_event,),
                                                    name="Task-Puller")
        self._result_pusher_thread = threading.Thread(target=self.push_results,
                                                      args=(self._kill_event,),
                                                      name="Result-Pusher")
        self._worker_watchdog_thread = threading.Thread(target=self.worker_watchdog,
                                                        args=(self._kill_event,),
                                                        name="worker-watchdog")
        self._monitoring_handler_thread = threading.Thread(target=self.handle_monitoring_messages,
                                                           args=(self._kill_event,),
                                                           name="Monitoring-Handler")

        self._task_puller_thread.start()
        self._result_pusher_thread.start()
        self._worker_watchdog_thread.start()
        self._monitoring_handler_thread.start()

        logger.info("Manager threads started")

        # This might need a multiprocessing event to signal back.
        self._kill_event.wait()
        logger.critical("Received kill event, terminating worker processes")

        self._task_puller_thread.join()
        self._result_pusher_thread.join()
        self._worker_watchdog_thread.join()
        self._monitoring_handler_thread.join()
        for proc_id in self.procs:
            self.procs[proc_id].terminate()
            logger.critical("Terminating worker {}: is_alive()={}".format(self.procs[proc_id],
                                                                          self.procs[proc_id].is_alive()))
            self.procs[proc_id].join()
            logger.debug("Worker {} joined successfully".format(self.procs[proc_id]))

        self.task_incoming.close()
        self.result_outgoing.close()
        self.zmq_context.term()
        delta = time.time() - self._start_time
        logger.info("process_worker_pool ran for {} seconds".format(delta))
        return

    def _start_worker(self, worker_id: int):
        p = SpawnContext.Process(
            target=worker,
            args=(
                worker_id,
                self.uid,
                self.worker_count,
                self.pending_task_queue,
                self.pending_result_queue,
                self.monitoring_queue,
                self.ready_worker_count,
                self._tasks_in_progress,
                self.cpu_affinity,
                self.available_accelerators[worker_id] if self.accelerators_available else None,
                self.block_id,
                self.heartbeat_period,
                os.getpid(),
                args.logdir,
                args.debug,
                self.mpi_launcher,
            ),
            name="HTEX-Worker-{}".format(worker_id),
        )
        p.start()
        return p


def update_resource_spec_env_vars(mpi_launcher: str, resource_spec: Dict, node_info: List[str]) -> None:
    prefix_table = compose_all(mpi_launcher, resource_spec=resource_spec, node_hostnames=node_info)
    for key in prefix_table:
        os.environ[key] = prefix_table[key]


def _init_mpi_env(mpi_launcher: str, resource_spec: Dict):
    node_list = resource_spec.get("MPI_NODELIST")
    if node_list is None:
        return
    nodes_for_task = node_list.split(',')
    logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
    update_resource_spec_env_vars(mpi_launcher=mpi_launcher, resource_spec=resource_spec, node_info=nodes_for_task)


@wrap_with_logs(target="worker_log")
def worker(
    worker_id: int,
    pool_id: str,
    pool_size: int,
    task_queue: multiprocessing.Queue,
    result_queue: multiprocessing.Queue,
    monitoring_queue: queue.Queue,
    ready_worker_count: Synchronized,
    tasks_in_progress: DictProxy,
    cpu_affinity: str,
    accelerator: Optional[str],
    block_id: str,
    task_queue_timeout: int,
    manager_pid: int,
    logdir: str,
    debug: bool,
    mpi_launcher: str,
):
    # override the global logger inherited from the __main__ process (which
    # usually logs to manager.log) with one specific to this worker.
    global logger
    logger = start_file_logger('{}/block-{}/{}/worker_{}.log'.format(logdir, block_id, pool_id, worker_id),
                               worker_id,
                               name="worker_log",
                               level=logging.DEBUG if debug else logging.INFO)

    # Store worker ID as an environment variable
    os.environ['PARSL_WORKER_RANK'] = str(worker_id)
    os.environ['PARSL_WORKER_COUNT'] = str(pool_size)
    os.environ['PARSL_WORKER_POOL_ID'] = str(pool_id)
    os.environ['PARSL_WORKER_BLOCK_ID'] = str(block_id)

    import parsl.executors.high_throughput.monitoring_info as mi
    mi.result_queue = monitoring_queue

    logger.info('Worker {} started'.format(worker_id))
    if debug:
        logger.debug("Debug logging enabled")

    # If desired, set process affinity
    if cpu_affinity != "none":
        # Count the number of cores per worker
        # OSX does not implement os.sched_getaffinity
        avail_cores = sorted(os.sched_getaffinity(0))  # type: ignore[attr-defined, unused-ignore]
        cores_per_worker = len(avail_cores) // pool_size
        assert cores_per_worker > 0, "Affinity does not work if there are more workers than cores"

        # Determine this worker's cores
        if cpu_affinity == "block":
            my_cores = avail_cores[cores_per_worker * worker_id:cores_per_worker * (worker_id + 1)]
        elif cpu_affinity == "block-reverse":
            cpu_worker_id = pool_size - worker_id - 1  # To assign in reverse order
            my_cores = avail_cores[cores_per_worker * cpu_worker_id:cores_per_worker * (cpu_worker_id + 1)]
        elif cpu_affinity == "alternating":
            my_cores = avail_cores[worker_id::pool_size]
        elif cpu_affinity[0:4] == "list":
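            # The "list" strategy encodes an explicit core list per worker, e.g.
            # (illustrative) "list:0-3,8:4-7,9" for a pool of two workers gives
            # worker 0 cores {0,1,2,3,8} and worker 1 cores {4,5,6,7,9}.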
            thread_ranks = cpu_affinity.split(":")[1:]
            if len(thread_ranks) != pool_size:
                raise ValueError("Affinity list {} has wrong number of thread ranks".format(cpu_affinity))
            threads = thread_ranks[worker_id]
            thread_list = threads.split(",")
            my_cores = []
            for tl in thread_list:
                thread_range = tl.split("-")
                if len(thread_range) == 1:
                    my_cores.append(int(thread_range[0]))
                elif len(thread_range) == 2:
                    start_thread = int(thread_range[0])
                    end_thread = int(thread_range[1]) + 1
                    my_cores += list(range(start_thread, end_thread))
                else:
                    raise ValueError("Affinity list formatting is not expected {}".format(cpu_affinity))
        else:
            raise ValueError("Affinity strategy {} is not supported".format(cpu_affinity))

        # Set the affinity for OpenMP
        #  See: https://hpc-tutorials.llnl.gov/openmp/ProcessThreadAffinity.pdf
        proc_list = ",".join(map(str, my_cores))
        os.environ["OMP_NUM_THREADS"] = str(len(my_cores))
        os.environ["GOMP_CPU_AFFINITY"] = proc_list  # Compatible with GCC OpenMP
        os.environ["KMP_AFFINITY"] = f"explicit,proclist=[{proc_list}]"  # For Intel OpenMP

        # Set the affinity for this worker
        # OSX does not implement os.sched_setaffinity so type checking
        # is ignored here in two ways:
        # On a platform without sched_setaffinity, that attribute will not
        # be defined, so ignore[attr-defined] will tell mypy to ignore this
        # incorrect-for-OS X attribute access.
        # On a platform with sched_setaffinity, that type: ignore message
        # will be redundant, and ignore[unused-ignore] tells mypy to ignore
        # that this ignore is unneeded.
        os.sched_setaffinity(0, my_cores)  # type: ignore[attr-defined, unused-ignore]
        logger.info("Set worker CPU affinity to {}".format(my_cores))

    # If desired, pin to accelerator
    if accelerator is not None:

        # If CUDA devices, find total number of devices to allow for MPS
        # See: https://developer.nvidia.com/system-management-interface
        nvidia_smi_cmd = "nvidia-smi -L > /dev/null && nvidia-smi -L | wc -l"
        nvidia_smi_ret = subprocess.run(nvidia_smi_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if nvidia_smi_ret.returncode == 0:
            num_cuda_devices = int(nvidia_smi_ret.stdout.split()[0])
        else:
            num_cuda_devices = None
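        # Illustrative example (not from the source): with pool_size=8 and 4 CUDA
        # devices, procs_per_cuda_device=2, so the worker given accelerator "5"
        # is pinned to CUDA_VISIBLE_DEVICES=2 and shares that GPU with worker "4".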

        try:
            if num_cuda_devices is not None:
                procs_per_cuda_device = pool_size // num_cuda_devices
                partitioned_accelerator = str(int(accelerator) // procs_per_cuda_device)  # multiple workers will share a GPU
                os.environ["CUDA_VISIBLE_DEVICES"] = partitioned_accelerator
                logger.info(f'Pinned worker to partitioned cuda device: {partitioned_accelerator}')
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = accelerator
        except (TypeError, ValueError, ZeroDivisionError):
            os.environ["CUDA_VISIBLE_DEVICES"] = accelerator
        os.environ["ROCR_VISIBLE_DEVICES"] = accelerator
        os.environ["ZE_AFFINITY_MASK"] = accelerator
        os.environ["ZE_ENABLE_PCI_ID_DEVICE_ORDER"] = '1'

        logger.info(f'Pinned worker to accelerator: {accelerator}')

    def manager_is_alive():
        try:
            # This does not kill the process, but instead raises
            # an exception if the process doesn't exist
            os.kill(manager_pid, 0)
        except OSError:
            logger.critical(f"Manager ({manager_pid}) died; worker {worker_id} shutting down")
            return False
        else:
            return True

    worker_enqueued = False
    while manager_is_alive():
        if not worker_enqueued:
            with ready_worker_count.get_lock():
                ready_worker_count.value += 1
            worker_enqueued = True

        try:
            # The worker will receive {'task_id':<tid>, 'buffer':<buf>}
            req = task_queue.get(timeout=task_queue_timeout)
        except queue.Empty:
            continue

        tasks_in_progress[worker_id] = req
        tid = req['task_id']
        logger.info("Received executor task {}".format(tid))

        with ready_worker_count.get_lock():
            ready_worker_count.value -= 1
        worker_enqueued = False

        _init_mpi_env(mpi_launcher=mpi_launcher, resource_spec=req["resource_spec"])

        try:
            result = execute_task(req['buffer'])
            serialized_result = serialize(result, buffer_threshold=1000000)
        except Exception as e:
            logger.info('Caught an exception: {}'.format(e))
            result_package = {'type': 'result', 'task_id': tid, 'exception': serialize(RemoteExceptionWrapper(*sys.exc_info()))}
        else:
            result_package = {'type': 'result', 'task_id': tid, 'result': serialized_result}
            # logger.debug("Result: {}".format(result))

        logger.info("Completed executor task {}".format(tid))
        try:
            pkl_package = pickle.dumps(result_package)
        except Exception:
            logger.exception("Caught exception while trying to pickle the result package")
            pkl_package = pickle.dumps({'type': 'result', 'task_id': tid,
                                        'exception': serialize(RemoteExceptionWrapper(*sys.exc_info()))
                                        })

        result_queue.put(pkl_package)
        tasks_in_progress.pop(worker_id)
        logger.info("All processing finished for executor task {}".format(tid))


def start_file_logger(filename, rank, name='parsl', level=logging.DEBUG, format_string=None):
    """Add a stream log handler.

    Args:
        - filename (string): Name of the file to write logs to
        - name (string): Logger name
        - level (logging.LEVEL): Set the logging level.
        - format_string (string): Set the format string

    Returns:
       -  None
    """
    if format_string is None:
        format_string = "%(asctime)s.%(msecs)03d %(name)s:%(lineno)d " \
                        "%(process)d %(threadName)s " \
                        "[%(levelname)s]  %(message)s"

    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    handler = logging.FileHandler(filename)
    handler.setLevel(level)
    formatter = logging.Formatter(format_string, datefmt='%Y-%m-%d %H:%M:%S')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--debug", action='store_true',
                        help="Enable logging at DEBUG level")
    parser.add_argument("-a", "--addresses", default='',
                        help="Comma separated list of addresses at which the interchange could be reached")
    parser.add_argument("--cert_dir", required=True,
                        help="Path to certificate directory.")
    parser.add_argument("-l", "--logdir", default="process_worker_pool_logs",
                        help="Process worker pool log directory")
    parser.add_argument("-u", "--uid", default=str(uuid.uuid4()).split('-')[-1],
                        help="Unique identifier string for Manager")
    parser.add_argument("-b", "--block_id", default=None,
                        help="Block identifier for Manager")
    parser.add_argument("-c", "--cores_per_worker", default="1.0",
                        help="Number of cores assigned to each worker process. Default=1.0")
    parser.add_argument("-m", "--mem_per_worker", default=0,
                        help="GB of memory assigned to each worker process. Default=0, no assignment")
    parser.add_argument("-t", "--task_port", required=True,
                        help="REQUIRED: Task port for receiving tasks from the interchange")
    parser.add_argument("--max_workers_per_node", default=float('inf'),
                        help="Caps the maximum workers that can be launched, default:infinity")
    parser.add_argument("-p", "--prefetch_capacity", default=0,
                        help="Number of tasks that can be prefetched to the manager. Default is 0.")
    parser.add_argument("--hb_period", default=30,
                        help="Heartbeat period in seconds. Uses manager default unless set")
    parser.add_argument("--hb_threshold", default=120,
                        help="Heartbeat threshold in seconds. Uses manager default unless set")
    parser.add_argument("--drain_period", default=None,
                        help="Drain this pool after specified number of seconds. By default, does not drain.")
    parser.add_argument("--address_probe_timeout", default=30,
                        help="Timeout to probe for viable address to interchange. Default: 30s")
    parser.add_argument("--poll", default=10,
                        help="Poll period used in milliseconds")
    parser.add_argument("-r", "--result_port", required=True,
                        help="REQUIRED: Result port for posting results to the interchange")

    def strategyorlist(s: str):
        allowed_strategies = ["none", "block", "alternating", "block-reverse"]
        if s in allowed_strategies:
            return s
        elif s[0:4] == "list":
            return s
        else:
            raise argparse.ArgumentTypeError("cpu-affinity must be one of {} or a list format".format(allowed_strategies))
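    # Accepted values (per strategyorlist above and the worker-side parsing):
    # "none", "block", "block-reverse", "alternating", or a per-worker core list
    # such as (illustrative) "list:0-3:4-7".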

    parser.add_argument("--cpu-affinity", type=strategyorlist,
                        required=True,
                        help="Whether/how workers should control CPU affinity.")
    parser.add_argument("--available-accelerators", type=str, nargs="*",
                        help="Names of available accelerators, if not given assumed to be zero accelerators available", default=[])
    parser.add_argument("--enable_mpi_mode", action='store_true',
                        help="Enable MPI mode")
    parser.add_argument("--mpi-launcher", type=str, choices=VALID_LAUNCHERS,
                        help="MPI launcher to use iff enable_mpi_mode=true")

    args = parser.parse_args()

    os.makedirs(os.path.join(args.logdir, "block-{}".format(args.block_id), args.uid), exist_ok=True)

    try:
        logger = start_file_logger('{}/block-{}/{}/manager.log'.format(args.logdir, args.block_id, args.uid),
                                   0,
                                   level=logging.DEBUG if args.debug is True else logging.INFO)

        logger.info("Python version: {}".format(sys.version))
        logger.info("Debug logging: {}".format(args.debug))
        logger.info("Certificates dir: {}".format(args.cert_dir))
        logger.info("Log dir: {}".format(args.logdir))
        logger.info("Manager ID: {}".format(args.uid))
        logger.info("Block ID: {}".format(args.block_id))
        logger.info("cores_per_worker: {}".format(args.cores_per_worker))
        logger.info("mem_per_worker: {}".format(args.mem_per_worker))
        logger.info("task_port: {}".format(args.task_port))
        logger.info("result_port: {}".format(args.result_port))
        logger.info("addresses: {}".format(args.addresses))
        logger.info("max_workers_per_node: {}".format(args.max_workers_per_node))
        logger.info("poll_period: {}".format(args.poll))
        logger.info("address_probe_timeout: {}".format(args.address_probe_timeout))
        logger.info("Prefetch capacity: {}".format(args.prefetch_capacity))
        logger.info("Heartbeat threshold: {}".format(args.hb_threshold))
        logger.info("Heartbeat period: {}".format(args.hb_period))
        logger.info("Drain period: {}".format(args.drain_period))
        logger.info("CPU affinity: {}".format(args.cpu_affinity))
        logger.info("Accelerators: {}".format(" ".join(args.available_accelerators)))
        logger.info("enable_mpi_mode: {}".format(args.enable_mpi_mode))
        logger.info("mpi_launcher: {}".format(args.mpi_launcher))

        manager = Manager(task_port=args.task_port,
                          result_port=args.result_port,
                          addresses=args.addresses,
                          address_probe_timeout=int(args.address_probe_timeout),
                          uid=args.uid,
                          block_id=args.block_id,
                          cores_per_worker=float(args.cores_per_worker),
                          mem_per_worker=None if args.mem_per_worker == 'None' else float(args.mem_per_worker),
                          max_workers_per_node=(
                              args.max_workers_per_node if args.max_workers_per_node == float('inf')
                              else int(args.max_workers_per_node)
                          ),
                          prefetch_capacity=int(args.prefetch_capacity),
                          heartbeat_threshold=int(args.hb_threshold),
                          heartbeat_period=int(args.hb_period),
                          drain_period=None if args.drain_period == "None" else int(args.drain_period),
                          poll_period=int(args.poll),
                          cpu_affinity=args.cpu_affinity,
                          enable_mpi_mode=args.enable_mpi_mode,
                          mpi_launcher=args.mpi_launcher,
                          available_accelerators=args.available_accelerators,
                          cert_dir=None if args.cert_dir == "None" else args.cert_dir)
        manager.start()

    except Exception:
        logger.critical("Process worker pool exiting with an exception", exc_info=True)
        raise
    else:
        logger.info("Process worker pool exiting normally")
        print("Process worker pool exiting normally")