1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
|
import itertools
import logging
import multiprocessing
import os
import pickle
import queue
import subprocess
from enum import Enum
from typing import Dict, List

from parsl.multiprocessing import SpawnContext
from parsl.serialize import pack_res_spec_apply_message, unpack_res_spec_apply_message
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class Scheduler(Enum):
    """Batch schedulers whose node lists MPI mode can discover.

    ``Unknown`` is what ``identify_scheduler`` reports when neither a Slurm
    nor a PBS environment variable is present.
    """

    # Values are explicit so they stay stable if members are ever reordered.
    Unknown = 0
    Slurm = 1
    PBS = 2
def get_slurm_hosts_list() -> List[str]:
    """Return the hostnames of the current Slurm allocation.

    Runs ``scontrol show hostname $SLURM_NODELIST`` through the shell (so the
    variable is expanded by the shell) and splits the decoded output on
    whitespace.

    NOTE(review): stderr is merged into the captured output, so any
    diagnostics scontrol prints would appear as extra list entries.
    """
    raw_output: bytes = subprocess.check_output(
        "scontrol show hostname $SLURM_NODELIST",
        stderr=subprocess.STDOUT,
        shell=True,
    )
    decoded = raw_output.decode()
    return decoded.strip().split()
def get_pbs_hosts_list() -> List[str]:
    """Get list of PBS hosts from envvar: PBS_NODEFILE

    Each non-blank line of the node file is one entry, with surrounding
    whitespace stripped. Duplicate lines are kept as-is.

    Raises:
        KeyError: if ``PBS_NODEFILE`` is not set.
        OSError: if the node file cannot be opened.
    """
    nodefile_name = os.environ["PBS_NODEFILE"]
    with open(nodefile_name) as f:
        stripped = (line.strip() for line in f)
        # Skip blank lines: an empty string would otherwise be handed out
        # as a hostname and counted as a free node.
        return [host for host in stripped if host]
def get_nodes_in_batchjob(scheduler: Scheduler) -> List[str]:
    """Return the node names of the current batch job.

    Dispatches to the Slurm or PBS discovery helper based on ``scheduler``.

    Raises:
        RuntimeError: if ``scheduler`` is neither Slurm nor PBS.
    """
    if scheduler == Scheduler.Slurm:
        return get_slurm_hosts_list()
    if scheduler == Scheduler.PBS:
        return get_pbs_hosts_list()
    raise RuntimeError(f"mpi_mode does not support scheduler:{scheduler}")
def identify_scheduler() -> Scheduler:
    """Infer the batch scheduler from environment variables.

    ``SLURM_NODELIST`` is checked before ``PBS_NODEFILE``; a variable that is
    unset or empty counts as absent. Returns ``Scheduler.Unknown`` if neither
    is present.
    """
    probes = (
        ("SLURM_NODELIST", Scheduler.Slurm),
        ("PBS_NODEFILE", Scheduler.PBS),
    )
    for env_var, detected in probes:
        if os.environ.get(env_var):
            return detected
    return Scheduler.Unknown
class MPINodesUnavailable(Exception):
    """Raised if there are no free nodes available for an MPI request

    Attributes:
        requested: number of nodes asked for.
        available: number of free nodes at the time of the request.
    """

    def __init__(self, requested: int, available: int):
        # Forward the arguments to Exception so ``self.args`` is populated.
        # BaseException.__reduce__ re-creates the exception as cls(*self.args)
        # on unpickling; with empty args that call would raise TypeError.
        super().__init__(requested, available)
        self.requested = requested
        self.available = available

    def __str__(self):
        return f"MPINodesUnavailable(requested={self.requested} available={self.available})"
class TaskScheduler:
    """Pass-through scheduler: tasks and results are forwarded unchanged.

    It only wraps the pending task queue and pending result queue. Subclasses
    add real placement policy by overriding ``put_task`` and ``get_result``.
    """

    def __init__(
        self,
        pending_task_q: multiprocessing.Queue,
        pending_result_q: multiprocessing.Queue,
    ):
        self.pending_task_q = pending_task_q
        self.pending_result_q = pending_result_q

    def put_task(self, task) -> None:
        """Enqueue ``task`` on the pending task queue as-is."""
        task_q = self.pending_task_q
        return task_q.put(task)

    def get_result(self, block: bool, timeout: float):
        """Dequeue the next item from the pending result queue.

        ``block`` and ``timeout`` are passed straight through to the queue's
        ``get``.
        """
        result_q = self.pending_result_q
        return result_q.get(block, timeout=timeout)
class MPITaskScheduler(TaskScheduler):
    """Extends TaskScheduler to schedule MPI functions over provisioned nodes

    The MPITaskScheduler runs on a Manager on the lead node of a batch job, as
    such it is expected to control task placement over this single batch job.

    The MPITaskScheduler adds the following functionality:
    1) Determine list of nodes attached to current batch job
    2) put_task for execution onto workers:
        a) if resources are available attach resource list
        b) if unavailable place tasks into backlog
    3) get_result will fetch a result and relinquish nodes,
       and attempt to schedule tasks in backlog if any.
    """

    def __init__(
        self,
        pending_task_q: multiprocessing.Queue,
        pending_result_q: multiprocessing.Queue,
    ):
        super().__init__(pending_task_q, pending_result_q)
        self.scheduler = identify_scheduler()
        # PriorityQueue is threadsafe. Entries are (nodes_needed, seq, task_package).
        # ``seq`` is a unique, increasing tie-breaker. Tasks needing the same
        # number of nodes are then served FIFO and the task_package dicts are
        # never compared; comparing them would raise TypeError.
        self._backlog_queue: queue.PriorityQueue = queue.PriorityQueue()
        self._backlog_seq = itertools.count()
        # task_id -> nodes allocated to that task. An entry is removed when the
        # task's result is collected in get_result.
        self._map_tasks_to_nodes: Dict[str, List[str]] = {}
        self.available_nodes = get_nodes_in_batchjob(self.scheduler)
        # mp.Value has issues with mypy
        # issue https://github.com/python/typeshed/issues/8799
        # from mypy 0.981 onwards
        self._free_node_counter = SpawnContext.Value("i", len(self.available_nodes))
        self.nodes_q: queue.Queue = queue.Queue()
        for node in self.available_nodes:
            self.nodes_q.put(node)

        logger.info(
            f"Starting MPITaskScheduler with {len(self.available_nodes)}"
        )

    def _get_nodes(self, num_nodes: int) -> List[str]:
        """Thread safe method to acquire num_nodes from free resources

        Raises: MPINodesUnavailable if there aren't enough resources
        Returns: List of nodenames:str
        """
        logger.debug(
            f"Requesting : {num_nodes=} we have {self._free_node_counter.value}"  # type: ignore[attr-defined]
        )
        with self._free_node_counter.get_lock():
            available = self._free_node_counter.value  # type: ignore[attr-defined]
            if num_nodes > available:
                raise MPINodesUnavailable(requested=num_nodes, available=available)
            self._free_node_counter.value -= num_nodes  # type: ignore[attr-defined]

        # _return_nodes enqueues node names before it bumps the counter. A
        # successful reservation above therefore means at least num_nodes names
        # are already in nodes_q, so these gets do not wait on a refill.
        return [self.nodes_q.get() for _ in range(num_nodes)]

    def _return_nodes(self, nodes: List[str]) -> None:
        """Threadsafe method to return a list of nodes"""
        for node in nodes:
            self.nodes_q.put(node)
        # Increment only after the names are queued (see _get_nodes).
        with self._free_node_counter.get_lock():
            self._free_node_counter.value += len(nodes)  # type: ignore[attr-defined]

    def put_task(self, task_package: dict):
        """Schedule task if resources are available otherwise backlog the task

        Tasks whose resource_spec has a truthy ``num_nodes`` get that many nodes
        allocated. The nodes are recorded as a comma-separated
        ``MPI_NODELIST`` in the resource_spec, which is re-packed into the task
        buffer. If too few nodes are free, the task goes to the backlog and is
        retried when nodes are returned. All other tasks are forwarded
        unchanged.
        """
        _f, _args, _kwargs, resource_spec = unpack_res_spec_apply_message(task_package["buffer"])

        nodes_needed = resource_spec.get("num_nodes")
        if nodes_needed:
            try:
                allocated_nodes = self._get_nodes(nodes_needed)
            except MPINodesUnavailable:
                logger.warning("Not enough resources, placing task into backlog")
                self._backlog_queue.put((nodes_needed, next(self._backlog_seq), task_package))
                return
            else:
                resource_spec["MPI_NODELIST"] = ",".join(allocated_nodes)
                self._map_tasks_to_nodes[task_package["task_id"]] = allocated_nodes
                buffer = pack_res_spec_apply_message(_f, _args, _kwargs, resource_spec)
                task_package["buffer"] = buffer
                task_package["resource_spec"] = resource_spec

        self.pending_task_q.put(task_package)

    def _schedule_backlog_tasks(self):
        """Attempt to schedule backlogged tasks

        The backlog is first drained completely, in priority order (fewest
        nodes first, then FIFO). Each drained task is then re-submitted via
        put_task, which pushes it back onto the backlog if it still cannot be
        placed. Draining before re-submitting guarantees termination: a task
        that gets re-backlogged is not pulled again in the same pass.
        """
        drained = []
        while True:
            try:
                _nodes_requested, _seq, task_package = self._backlog_queue.get(block=False)
            except queue.Empty:
                break
            drained.append(task_package)

        for task_package in drained:
            self.put_task(task_package)

    def get_result(self, block: bool, timeout: float):
        """Return result and relinquish provisioned nodes

        ``block`` and ``timeout`` are passed to the result queue's ``get``.
        The pickled result is returned unchanged.
        """
        result_pkl = self.pending_result_q.get(block, timeout=timeout)
        result_dict = pickle.loads(result_pkl)
        if result_dict["type"] == "result":
            # Tasks that did not request num_nodes never received nodes and
            # have no entry in the map (Issue #3427), so a missing entry is
            # normal. pop() also drops the entry for finished MPI tasks, which
            # keeps the map from growing without bound.
            nodes_to_reallocate = self._map_tasks_to_nodes.pop(result_dict["task_id"], None)
            if nodes_to_reallocate:
                self._return_nodes(nodes_to_reallocate)
                # Freed nodes may let backlogged tasks run now.
                self._schedule_backlog_tasks()
        return result_pkl
|