1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
|
"""
Experimental support for distributed training with external memory
==================================================================
.. versionadded:: 3.0.0
See :doc:`the tutorial </tutorials/external_memory>` for more details. To run the
example, the following packages are required in addition to XGBoost's native dependencies:
- scikit-learn
- loky
If `device` is `cuda`, the following packages are also needed:
- cupy
- cuda-python
- rmm
- rmm
"""
import argparse
import multiprocessing as mp
import os
import sys
import tempfile
import traceback
from functools import partial, update_wrapper, wraps
from typing import Callable, List, ParamSpec, Tuple, TypeVar
import numpy as np
from loky import get_reusable_executor
from sklearn.datasets import make_regression
import xgboost
from xgboost import collective as coll
from xgboost.tracker import RabitTracker
def make_batches(
    n_samples_per_batch: int, n_features: int, n_batches: int, tmpdir: str, rank: int
) -> List[Tuple[str, str]]:
    """Write ``n_batches`` synthetic regression batches for one worker to disk.

    Each batch is stored as two ``.npy`` files (features and labels) in ``tmpdir``.
    The file names embed ``rank`` and the batch index. The random state is seeded
    with ``rank``, so different ranks draw different data.

    Returns
    -------
    The ``(X_path, y_path)`` pairs, in batch order.
    """
    rng = np.random.RandomState(rank)

    def _write_batch(batch_idx: int) -> Tuple[str, str]:
        # Every batch draws from the same generator, so consecutive batches differ.
        data, target = make_regression(
            n_samples_per_batch, n_features, random_state=rng
        )
        pair = (
            os.path.join(tmpdir, f"X-r{rank}-{batch_idx}.npy"),
            os.path.join(tmpdir, f"y-r{rank}-{batch_idx}.npy"),
        )
        np.save(pair[0], data)
        np.save(pair[1], target)
        return pair

    return [_write_batch(batch_idx) for batch_idx in range(n_batches)]
class Iterator(xgboost.DataIter):
    """A custom iterator for loading files in batches.

    Each call to :py:meth:`next` loads one ``(X, y)`` pair of ``.npy`` files and
    passes it to XGBoost. :py:meth:`reset` rewinds to the first pair.
    """

    def __init__(self, device: str, file_paths: List[Tuple[str, str]]) -> None:
        """
        Parameters
        ----------
        device :
            ``"cpu"`` loads batches with NumPy; any other value loads them with CuPy.
        file_paths :
            ``(X_path, y_path)`` pairs, one per batch.
        """
        self.device = device
        self._file_paths = file_paths
        self._it = 0
        # XGBoost will generate some cache files under the current directory with the
        # prefix "cache"
        super().__init__(cache_prefix=os.path.join(".", "cache"))

    def load_file(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load the batch at the current iterator position."""
        X_path, y_path = self._file_paths[self._it]
        # When the `ExtMemQuantileDMatrix` is used, the device must match. GPU cannot
        # consume CPU input data and vice-versa.
        if self.device == "cpu":
            X = np.load(X_path)
            y = np.load(y_path)
        else:
            # Import CuPy locally. The script binds a module-level `cp` only under
            # `if __name__ == "__main__"`, so relying on that global raises
            # NameError for any other entry point. Repeated imports are cached.
            import cupy as cp

            X = cp.load(X_path)
            y = cp.load(y_path)
        assert X.shape[0] == y.shape[0]
        return X, y

    def next(self, input_data: Callable) -> bool:
        """Advance the iterator by 1 step and pass the data to XGBoost. This function
        is called by XGBoost during the construction of ``DMatrix``

        Returns
        -------
        ``False`` once every file pair has been consumed, ``True`` otherwise.
        """
        if self._it == len(self._file_paths):
            # return False to let XGBoost know this is the end of iteration
            return False
        # input_data is a keyword-only function passed in by XGBoost and has the similar
        # signature to the ``DMatrix`` constructor.
        X, y = self.load_file()
        input_data(data=X, label=y)
        self._it += 1
        return True

    def reset(self) -> None:
        """Reset the iterator to its beginning"""
        self._it = 0
def setup_rmm() -> None:
    """Setup RMM for GPU-based external memory training.

    It's important to use RMM with `CudaAsyncMemoryResource` or `ArenaMemoryResource`
    for GPU-based external memory to improve performance. If XGBoost is not built with
    RMM support, this function returns without installing a memory resource (and a
    warning is raised when constructing the `DMatrix`).

    Raises
    ------
    RuntimeError
        If querying the device memory with ``cudaMemGetInfo`` fails.
    """
    # CuPy is imported here rather than read from a module-level `cp`. The script
    # binds that global only under `if __name__ == "__main__"`.
    import cupy as cp
    import rmm
    from cuda import cudart
    from rmm.allocators.cupy import rmm_cupy_allocator
    from rmm.mr import ArenaMemoryResource

    if not xgboost.build_info()["USE_RMM"]:
        return

    status, _free, total = cudart.cudaMemGetInfo()
    if status != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(cudart.cudaGetErrorString(status))

    # An arena sized at 90% of the device's total memory, on top of plain CUDA
    # allocations.
    mr = rmm.mr.CudaMemoryResource()
    mr = ArenaMemoryResource(mr, arena_size=int(total * 0.9))
    rmm.mr.set_current_device_resource(mr)
    # Set the allocator for cupy as well.
    cp.cuda.set_allocator(rmm_cupy_allocator)
R = TypeVar("R")
P = ParamSpec("P")
def try_run(fn: Callable[P, R]) -> Callable[P, R]:
"""Loky aborts the process without printing out any error message if there's an
exception.
"""
@wraps(fn)
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
try:
return fn(*args, **kwargs)
except Exception as e:
print(traceback.format_exc(), file=sys.stderr)
raise RuntimeError("Running into exception in worker.") from e
return inner
@try_run
def hist_train(worker_idx: int, tmpdir: str, device: str, rabit_args: dict) -> None:
    """Train a small booster on this worker's data shard using external memory.

    The hist tree method can use a special data structure `ExtMemQuantileDMatrix` for
    faster initialization and lower memory usage.

    Parameters
    ----------
    worker_idx :
        Index of the task. Unused: the data is sharded by the collective rank.
    tmpdir :
        Directory where this worker writes its batch files.
    device :
        ``"cpu"`` or ``"cuda"``; used both for loading the data and for training.
    rabit_args :
        Keyword arguments forwarded to ``coll.CommunicatorContext``.
    """
    # Make sure XGBoost is using RMM for all allocations.
    with coll.CommunicatorContext(**rabit_args), xgboost.config_context(use_rmm=True):
        # Generate the data for demonstration. The synthetic data is sharded by workers.
        files = make_batches(
            n_samples_per_batch=4096,
            n_features=16,
            n_batches=17,
            tmpdir=tmpdir,
            rank=coll.get_rank(),
        )
        # The workers run on a single node, so divide the number of threads between
        # them. `os.cpu_count()` may return None. Fall back to 1 instead of relying
        # on an `assert`, which `python -O` strips.
        n_threads = max((os.cpu_count() or 1) // coll.get_world_size(), 1)

        it = Iterator(device, files)
        Xy = xgboost.ExtMemQuantileDMatrix(
            it, missing=np.nan, enable_categorical=False, nthread=n_threads
        )

        # Check the device is correctly set.
        if device == "cuda":
            # The first visible device index must be lower than the world size.
            assert (
                int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0])
                < coll.get_world_size()
            )

        booster = xgboost.train(
            {
                "tree_method": "hist",
                "max_depth": 4,
                "device": it.device,
                "nthread": n_threads,
            },
            Xy,
            evals=[(Xy, "Train")],
            num_boost_round=10,
        )
        booster.predict(Xy)
def main(tmpdir: str, args: argparse.Namespace) -> None:
    """Start a local tracker and run :py:func:`hist_train` on two worker processes.

    Parameters
    ----------
    tmpdir :
        Directory in which the workers write their batch files.
    args :
        Parsed command line arguments; ``args.device`` is ``"cpu"`` or ``"cuda"``.

    Raises
    ------
    Exception
        Whatever a failing worker raised is propagated to the caller.
    """
    n_workers = 2
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
    tracker.start()
    rabit_args = tracker.worker_args()

    def initializer(device: str) -> None:
        # Runs inside each worker process when it starts, before any task is executed.
        if device == "cuda":
            # name: LokyProcess-1 (1-based indexing from loky). The modulo keeps the
            # index in range if a worker is ever restarted with a higher number.
            _, sidx = mp.current_process().name.rsplit("-", 1)
            idx = (int(sidx) - 1) % n_workers
            # Rotate the device list so each worker sees a different GPU first:
            # P0: CUDA_VISIBLE_DEVICES=0,1
            # P1: CUDA_VISIBLE_DEVICES=1,0
            devices = ",".join(str((idx + i) % n_workers) for i in range(n_workers))
            os.environ["CUDA_VISIBLE_DEVICES"] = devices
            setup_rmm()

    with get_reusable_executor(
        max_workers=n_workers, initargs=(args.device,), initializer=initializer
    ) as pool:
        # Poor man's currying
        fn = update_wrapper(
            partial(
                hist_train, tmpdir=tmpdir, device=args.device, rabit_args=rabit_args
            ),
            hist_train,
        )
        # `map` returns a lazy iterator. An exception raised in a worker only surfaces
        # when its result is retrieved, so consume the results. Otherwise a failed
        # worker would be silently ignored.
        list(pool.map(fn, range(n_workers)))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    args = parser.parse_args()
    if args.device == "cuda":
        # Import CuPy up front so a missing dependency fails before any worker
        # starts. This also binds the module-level name `cp`.
        import cupy as cp
    # Both devices run the same driver; only the CuPy import above differs.
    with tempfile.TemporaryDirectory() as tmpdir:
        main(tmpdir, args)
|