File: spark_estimator.rst

################################
Distributed XGBoost with PySpark
################################

Starting from version 1.7.0, XGBoost supports PySpark estimator APIs.

.. note::

  The integration is only tested on Linux distributions.

.. contents::
  :backlinks: none
  :local:

*************************
XGBoost PySpark Estimator
*************************

SparkXGBRegressor
=================

``SparkXGBRegressor`` is a PySpark ML estimator. It implements the XGBoost regression
algorithm based on the XGBoost Python library, and it can be used in PySpark Pipelines
and PySpark ML meta algorithms like CrossValidator/TrainValidationSplit/OneVsRest.

We can create a ``SparkXGBRegressor`` estimator like:

.. code-block:: python

  from xgboost.spark import SparkXGBRegressor
  xgb_regressor = SparkXGBRegressor(
    features_col="features",
    label_col="label",
    num_workers=2,
  )


The above snippet creates a Spark estimator that can fit on a Spark dataset and return a
Spark model that transforms a Spark dataset and generates a dataset with a prediction
column. Almost all of the XGBoost sklearn estimator parameters can be set as
``SparkXGBRegressor`` parameters, but some, such as ``nthread``, are forbidden in the
Spark estimator, and some are replaced with PySpark-specific parameters such as
``weight_col`` and ``validation_indicator_col``. Please see the ``SparkXGBRegressor``
doc for details.
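
For instance, instead of the sklearn-style ``sample_weight`` and ``eval_set`` arguments,
one passes column names. A brief sketch (the ``"weight"`` and ``"is_val"`` column names
here are only examples):

.. code-block:: python

  from xgboost.spark import SparkXGBRegressor

  xgb_regressor = SparkXGBRegressor(
    features_col="features",
    label_col="label",
    weight_col="weight",                # column holding per-row instance weights
    validation_indicator_col="is_val",  # boolean column marking validation rows
    num_workers=2,
  )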

The following code snippet shows how to train a Spark XGBoost regressor model. First we
need to prepare a training dataset as a Spark dataframe containing a "label" column and
the "features" column(s); the "features" can be a single column of
``pyspark.ml.linalg.Vector`` or Spark array type, or a list of feature column names.

.. code-block:: python

  xgb_regressor_model = xgb_regressor.fit(train_spark_dataframe)


The following code snippet shows how to predict on test data using a Spark XGBoost
regressor model. First we need to prepare a test dataset as a Spark dataframe containing
the "features" and "label" columns, where the "features" column is of
``pyspark.ml.linalg.Vector`` or Spark array type.

.. code-block:: python

  transformed_test_spark_dataframe = xgb_regressor_model.transform(test_spark_dataframe)


The above snippet returns a ``transformed_test_spark_dataframe`` that contains the input
dataset columns and an appended column "prediction" representing the prediction results.
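
For example, to take a quick look at the predictions (a minimal sketch):

.. code-block:: python

  transformed_test_spark_dataframe.select("label", "prediction").show(5)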

SparkXGBClassifier
==================

The ``SparkXGBClassifier`` estimator has a similar API to ``SparkXGBRegressor``, but it adds some
PySpark classifier-specific parameters, e.g. ``raw_prediction_col`` and ``probability_col``.
Correspondingly, by default, transforming a test dataset with ``SparkXGBClassifierModel``
generates a result dataset with 3 new columns (a minimal usage sketch follows the list):

- "prediction": represents the predicted label.
- "raw_prediction": represents the output margin values.
- "probability": represents the prediction probability on each label.

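A minimal ``SparkXGBClassifier`` sketch (assuming a training dataframe with an integer
"label" column and a "features" vector column):

.. code-block:: python

  from xgboost.spark import SparkXGBClassifier

  xgb_classifier = SparkXGBClassifier(
    features_col="features",
    label_col="label",
    num_workers=2,
  )
  xgb_classifier_model = xgb_classifier.fit(train_spark_dataframe)

  # the transformed dataframe contains the three columns described above
  result_df = xgb_classifier_model.transform(test_spark_dataframe)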

***************************
XGBoost PySpark GPU support
***************************

XGBoost PySpark fully supports GPU acceleration. Users can not only enable efficient
training but also utilize their GPUs for the whole PySpark pipeline, including ETL and
inference. In the sections below, we walk through an example of training on a Spark
standalone cluster with GPU support. To get started, we first need to install some
additional packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.

Prepare the necessary packages
==============================

Aside from the PySpark and XGBoost modules, we also need the `cuDF
<https://docs.rapids.ai/api/cudf/stable/>`_ package for handling Spark dataframes. We
recommend using either Conda or Virtualenv to manage Python dependencies for PySpark
jobs. Please refer to `How to Manage Python Dependencies in PySpark
<https://www.databricks.com/blog/2020/12/22/how-to-manage-python-dependencies-in-pyspark.html>`_
for more details on PySpark dependency management.

In short, to create a Python environment that can be sent to a remote cluster using
virtualenv and pip:

.. code-block:: bash

  python -m venv xgboost_env
  source xgboost_env/bin/activate
  pip install pyarrow pandas venv-pack xgboost
  # https://docs.rapids.ai/install#pip-install
  pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com
  venv-pack -o xgboost_env.tar.gz

With Conda:

.. code-block:: bash

  conda create -y -n xgboost_env -c conda-forge conda-pack python=3.9
  conda activate xgboost_env
  # use conda when the supported version of xgboost (1.7) is released on conda-forge
  pip install xgboost
  conda install cudf pyarrow pandas -c rapids -c nvidia -c conda-forge
  conda pack -f -o xgboost_env.tar.gz


Write your PySpark application
==============================

The snippet below is a small example of training an XGBoost model with PySpark. Notice
that we are using a list of feature names instead of a vector column as the input. The
parameter ``device="cuda"`` indicates that the training will be performed on a GPU.

.. code-block:: python

  from pyspark.sql import SparkSession
  from xgboost.spark import SparkXGBRegressor

  spark = SparkSession.builder.getOrCreate()

  # read data into spark dataframe
  train_data_path = "xxxx/train"
  train_df = spark.read.parquet(train_data_path)

  test_data_path = "xxxx/test"
  test_df = spark.read.parquet(test_data_path)

  # assume the label column is named "class"
  label_name = "class"

  # get a list with feature column names
  feature_names = [x.name for x in train_df.schema if x.name != label_name]

  # create an XGBoost PySpark regressor estimator and set device="cuda"
  regressor = SparkXGBRegressor(
    features_col=feature_names,
    label_col=label_name,
    num_workers=2,
    device="cuda",
  )

  # train and return the model
  model = regressor.fit(train_df)

  # predict on test data
  predict_df = model.transform(test_df)
  predict_df.show()

Like the other distributed interfaces, the ``device`` parameter doesn't support specifying a GPU ordinal, as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).

.. _stage-level-scheduling:

Submit the PySpark application
==============================

This section assumes that you have configured the Spark standalone cluster with GPU support. If not, please
refer to `spark standalone configuration with GPU support <https://nvidia.github.io/spark-rapids/docs/get-started/getting-started-on-prem.html#spark-standalone-cluster>`_.

Starting from XGBoost 2.0.1, stage-level scheduling is automatically enabled. Therefore,
if you are using a Spark standalone cluster of version 3.4.0 or higher, we strongly recommend
configuring the ``"spark.task.resource.gpu.amount"`` as a fractional value. This will
enable running multiple tasks in parallel during the ETL phase. An example configuration
would be ``"spark.task.resource.gpu.amount=1/spark.executor.cores"``. However, if you are
using an XGBoost version earlier than 2.0.1 or a Spark standalone cluster version below 3.4.0,
you still need to set ``"spark.task.resource.gpu.amount"`` equal to ``"spark.executor.resource.gpu.amount"``.

.. note::

  As of now, the stage-level scheduling feature in XGBoost is limited to the Spark standalone cluster mode.
  However, we have plans to expand its compatibility to YARN and Kubernetes once Spark 3.5.1 is officially released.

.. code-block:: bash

  export PYSPARK_DRIVER_PYTHON=python
  export PYSPARK_PYTHON=./environment/bin/python

  spark-submit \
    --master spark://<master-ip>:7077 \
    --conf spark.executor.cores=12 \
    --conf spark.task.cpus=1 \
    --conf spark.executor.resource.gpu.amount=1 \
    --conf spark.task.resource.gpu.amount=0.08 \
    --archives xgboost_env.tar.gz#environment \
    xgboost_app.py

The above command submits the XGBoost PySpark application with the Python environment created by pip or conda,
requesting 1 GPU and 12 CPU cores per executor. With ``spark.task.resource.gpu.amount=0.08``, up to 12 tasks
can share the single GPU (12 × 0.08 = 0.96 ≤ 1), so 12 tasks per executor will run concurrently during the ETL phase.

Model Persistence
=================

Similar to standard PySpark ML estimators, one can persist and reuse the model with the ``save``
and ``load`` methods:

.. code-block:: python

  from xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel

  regressor = SparkXGBRegressor()
  model = regressor.fit(train_df)
  # save the model
  model.save("/tmp/xgboost-pyspark-model")
  # load the model
  model2 = SparkXGBRegressorModel.load("/tmp/xgboost-pyspark-model")

To export the underlying booster model used by XGBoost:

.. code-block:: python

  import xgboost as xgb
  from xgboost.spark import SparkXGBRegressor

  regressor = SparkXGBRegressor()
  model = regressor.fit(train_df)
  # the same booster object returned by xgboost.train
  booster: xgb.Booster = model.get_booster()
  booster.predict(...)
  booster.save_model("model.json")  # or model.ubj, depending on your choice of format.

This booster is not only shared by other Python interfaces but also used by all the
XGBoost bindings, including the C, Java, and R packages. Lastly, one can extract the
booster file directly from a saved Spark model without going through the getter:

.. code-block:: python

  import xgboost as xgb
  bst = xgb.Booster()
  # Loading the model saved in previous snippet
  bst.load_model("/tmp/xgboost-pyspark-model/model/part-00000")


Accelerate the whole pipeline for XGBoost PySpark
=================================================

With `RAPIDS Accelerator for Apache Spark <https://nvidia.github.io/spark-rapids/>`_, you
can leverage GPUs to accelerate the whole pipeline (ETL, Train, Transform) for XGBoost
PySpark without any code modifications. Likewise, you have the option to configure
``"spark.task.resource.gpu.amount"`` as a fractional value, enabling a higher number of
tasks to be executed in parallel during the ETL phase. Please refer to
:ref:`stage-level-scheduling` for more details.


An example submit command is shown below with additional Spark configurations and dependencies:

.. code-block:: bash

  export PYSPARK_DRIVER_PYTHON=python
  export PYSPARK_PYTHON=./environment/bin/python

  spark-submit \
    --master spark://<master-ip>:7077 \
    --conf spark.executor.cores=12 \
    --conf spark.task.cpus=1 \
    --conf spark.executor.resource.gpu.amount=1 \
    --conf spark.task.resource.gpu.amount=0.08 \
    --packages com.nvidia:rapids-4-spark_2.12:24.04.1 \
    --conf spark.plugins=com.nvidia.spark.SQLPlugin \
    --conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
    --archives xgboost_env.tar.gz#environment \
    xgboost_app.py

When the RAPIDS plugin is enabled, both the JVM RAPIDS plugin and the cuDF Python package
are required. More configuration options can be found in the RAPIDS link above, along with
details on the plugin.

Advanced Usage
==============

XGBoost needs to repartition the input dataset to ``num_workers`` partitions to ensure
that ``num_workers`` training tasks run at the same time. However, repartitioning is a
costly operation.

If the data can be read from the source and fitted to XGBoost directly, without
introducing a shuffle stage, users can avoid the repartition by setting the Spark
configuration parameters ``spark.sql.files.maxPartitionNum`` and
``spark.sql.files.minPartitionNum`` to ``num_workers``. This tells Spark to automatically
partition the dataset into the desired number of partitions.
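
A minimal sketch of setting these configurations when creating the session (assuming
``num_workers=2``; the configuration keys are the ones named above):

.. code-block:: python

  from pyspark.sql import SparkSession

  num_workers = 2
  spark = (
      SparkSession.builder
      .config("spark.sql.files.maxPartitionNum", str(num_workers))
      .config("spark.sql.files.minPartitionNum", str(num_workers))
      .getOrCreate()
  )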

However, if the input dataset is skewed (i.e. the data is not evenly distributed), setting
the partition number to ``num_workers`` may not be efficient. In this case, users can set
``force_repartition=True`` to explicitly force XGBoost to repartition the dataset, even if
the partition number already equals ``num_workers``. This ensures the data is evenly
distributed across the workers.
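
A brief sketch of enabling the option (``force_repartition`` is a constructor parameter of
the estimators):

.. code-block:: python

  from xgboost.spark import SparkXGBRegressor

  regressor = SparkXGBRegressor(
    features_col="features",
    label_col="label",
    num_workers=2,
    force_repartition=True,
  )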