File: rabit.py

package info (click to toggle)
rabit 0.0~git20200628.74bf00a-2
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 720 kB
  • sloc: cpp: 5,015; ansic: 710; python: 360; makefile: 306; sh: 136
file content (364 lines) | stat: -rw-r--r-- 10,642 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
"""
Reliable Allreduce and Broadcast Library.

Author: Tianqi Chen
"""
# pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value,
import pickle
import ctypes
import os
import platform
import sys
import warnings
import numpy as np

# Version string exposed by this module (the original note ties it to the docs).
__version__ = '1.0'

# Module-level handle to the loaded rabit library.  It is assigned by
# _loadlib() and reset to None by _unloadlib(); None means "not loaded".
_LIB = None

def _find_lib_path(dll_name):
    """Locate the rabit shared library called ``dll_name``.

    Three locations are probed, in this order: the directory holding this
    file, its sibling ``../lib/`` and its child ``./lib/``.

    Parameters
    ----------
    dll_name : str
        File name of the library, including the platform suffix.

    Returns
    -------
    lib_path : list of str
        The candidate paths that exist as regular files, in probe order.
        May be empty when ``XGBOOST_BUILD_DOC`` is set to a non-empty value.

    Raises
    ------
    RuntimeError
        If no candidate exists and the ``XGBOOST_BUILD_DOC`` environment
        variable is unset or empty.
    """
    here = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    # Packaging hack kept from the original: setup.py may copy this directory
    # one level up, hence the extra lib/ locations.
    search_dirs = (here,
                   os.path.join(here, '../lib/'),
                   os.path.join(here, './lib/'))
    candidates = [os.path.join(folder, dll_name) for folder in search_dirs]
    found = [path for path in candidates if os.path.isfile(path)]
    if found or os.environ.get('XGBOOST_BUILD_DOC', False):
        return found
    # According to upstream issue reports, most installation errors come from
    # machines without a compiler, which is why the message mentions it.
    raise RuntimeError(
        'Cannot find Rabit Libarary in the candicate path, ' +
        'did you install compilers and run build.sh in root path?\n'
        'List of candidates:\n' + ('\n'.join(candidates)))

# load the rabit shared library
def _loadlib(lib='standard', lib_dll=None):
    """Load the rabit library into the module-level ``_LIB`` handle.

    Parameters
    ----------
    lib : str, optional
        Library flavour.  ``'standard'`` selects ``librabit``; any other
        value selects ``librabit_<lib>``.  The platform suffix (``.dll``,
        ``.dylib`` or ``.so``) is appended automatically.
    lib_dll : object, optional
        An already loaded library object.  When given it is stored as the
        handle directly and ``lib`` is ignored.

    Notes
    -----
    If a library is already loaded, a warning is emitted and nothing else
    happens.
    """
    global _LIB
    if _LIB is not None:
        # warnings.warn() accepts ``stacklevel``, not ``level``; the old
        # keyword raised TypeError instead of emitting the warning.
        warnings.warn('rabit.init call was ignored because it has'
                      ' already been initialized', stacklevel=2)
        return

    if lib_dll is not None:
        _LIB = lib_dll
        return

    dll_name = 'librabit' if lib == 'standard' else 'librabit_' + lib

    if os.name == 'nt':
        dll_name += '.dll'
    elif platform.system() == 'Darwin':
        dll_name += '.dylib'
    else:
        dll_name += '.so'

    # NOTE(review): _find_lib_path() can return an empty list when
    # XGBOOST_BUILD_DOC is set, in which case the [0] below raises IndexError.
    _LIB = ctypes.cdll.LoadLibrary(_find_lib_path(dll_name)[0])
    # Declare the integer return types of the query functions explicitly.
    _LIB.RabitGetRank.restype = ctypes.c_int
    _LIB.RabitGetWorldSize.restype = ctypes.c_int
    _LIB.RabitVersionNumber.restype = ctypes.c_int

def _unloadlib():
    """Drop the module-level reference to the rabit library.

    Only the Python-side handle is reset to ``None``; this function does not
    itself ask the operating system to unload the shared object.
    """
    global _LIB
    # Rebinding releases the previous reference, so no separate ``del`` is
    # needed (a ``del`` would also raise NameError if the global were absent).
    _LIB = None

# Reduction operator codes accepted as the ``op`` argument of allreduce().
# NOTE(review): presumably these mirror the operator enum of the native
# library -- confirm before renumbering.
MAX, MIN, SUM, BITOR = 0, 1, 2, 3

def init(args=None, lib='standard', lib_dll=None):
    """Initialize the rabit module; call this once before using anything.

    Parameters
    ----------
    args : list of str or bytes, optional
        Arguments handed to the rabit engine; callers usually pass
        ``sys.argv``.  When None, an empty argument list is used.
        ``str`` items are encoded as UTF-8 before being passed on.
    lib : {'standard', 'mock', 'mpi'}, optional
        Type of library to load.
    lib_dll : ctypes.DLL, optional
        The DLL object used as lib.
        When this is given, the ``lib`` argument is ignored.
    """
    if args is None:
        args = []
    # A c_char_p array rejects ``str`` items on Python 3, so text arguments
    # are encoded here; bytes (and anything else) pass through untouched.
    encoded = [arg.encode('utf-8') if isinstance(arg, str) else arg
               for arg in args]
    _loadlib(lib, lib_dll)
    argv = (ctypes.c_char_p * len(encoded))()
    argv[:] = encoded
    _LIB.RabitInit(len(encoded), argv)

def finalize():
    """Shut down the rabit engine and drop the library handle.

    Call this once all jobs are finished.  ``RabitFinalize`` is invoked on
    the loaded library, after which ``_unloadlib()`` resets the module-level
    handle.
    """
    engine = _LIB
    engine.RabitFinalize()
    _unloadlib()

def get_rank():
    """Return the rank of the current process.

    Returns
    -------
    rank : int
        The value reported by the library's ``RabitGetRank``.
    """
    return _LIB.RabitGetRank()

def get_world_size():
    """Return the total number of workers.

    Returns
    -------
    n : int
        The value reported by the library's ``RabitGetWorldSize``.
    """
    return _LIB.RabitGetWorldSize()

def tracker_print(msg):
    """Print a message to the tracker.

    This function can be used to communicate progress information to the
    tracker.

    Parameters
    ----------
    msg : str
        The message to be printed to the tracker.  ``bytes`` are sent as-is;
        any other non-``str`` object is converted with ``str()`` first.
    """
    if isinstance(msg, bytes):
        payload = msg
    else:
        if not isinstance(msg, str):
            msg = str(msg)
        # Encode before wrapping: c_char_p only accepts bytes, and the
        # c_char_p object itself has no ``encode`` method.
        payload = msg.encode('utf-8')
    _LIB.RabitTrackerPrint(ctypes.c_char_p(payload))

def get_processor_name():
    """Return the processor (host) name reported by the library.

    Returns
    -------
    name : bytes or str
        The ``value`` of the 256-byte string buffer filled by
        ``RabitGetProcessorName``.  On Python 3 a ctypes string buffer's
        ``value`` is ``bytes``; no decoding is performed here.
    """
    capacity = 256
    name_buf = ctypes.create_string_buffer(capacity)
    # Receives the name length from the library; it is not used afterwards.
    name_len = ctypes.c_ulong()
    _LIB.RabitGetProcessorName(name_buf, ctypes.byref(name_len), capacity)
    return name_buf.value

def broadcast(data, root):
    """Broadcast a picklable object from one node to all other nodes.

    The payload travels in two steps: first its pickled length, then the
    pickled bytes themselves.

    Parameters
    ----------
    data : any type that can be pickled
        Object to send.  Only read on the root node; on every other node it
        may be None.
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object
        On the root node, ``data`` itself; on the other nodes, the object
        unpickled from the received bytes.
    """
    is_root = (root == get_rank())
    nbytes = ctypes.c_ulong()
    if is_root:
        assert data is not None, 'need to pass in data when broadcasting'
        payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        nbytes.value = len(payload)

    # Step 1: every node learns the payload length.
    _LIB.RabitBroadcast(ctypes.byref(nbytes),
                        ctypes.sizeof(ctypes.c_ulong), root)

    # Step 2: the payload itself.
    if is_root:
        _LIB.RabitBroadcast(
            ctypes.cast(ctypes.c_char_p(payload), ctypes.c_void_p),
            nbytes.value, root)
        return data

    recv_buf = (ctypes.c_char * nbytes.value)()
    _LIB.RabitBroadcast(ctypes.cast(recv_buf, ctypes.c_void_p),
                        nbytes.value, root)
    # NOTE(review): pickle.loads runs on bytes received from another node;
    # unpickling can execute arbitrary code, so peers must be trusted.
    return pickle.loads(recv_buf.raw)

# Mapping from numpy dtype to the integer dtype code passed to
# RabitAllreduce.  A name's position in the tuple below is its code, so the
# order must not change.
DTYPE_ENUM__ = {
    np.dtype(type_name): type_code
    for type_code, type_name in enumerate((
        'int8', 'uint8',
        'int32', 'uint32',
        'int64', 'uint64',
        'float32', 'float64',
    ))
}

def allreduce(data, op, prepare_fun=None):
    """Perform allreduce and return the result.

    Parameters
    ----------
    data : numpy.ndarray
        Input data.  Its dtype must be one of the keys of ``DTYPE_ENUM__``.
    op : int
        Reduction operator, one of MIN, MAX, SUM, BITOR.
    prepare_fun : callable, optional
        Lazy preprocessing function.  If not None, ``prepare_fun(data)`` is
        handed to the library as a callback so it can initialize the data
        before the reduction.  If the result of Allreduce can be recovered
        directly, ``prepare_fun`` will NOT be called.

    Returns
    -------
    result : numpy.ndarray
        One-dimensional array with ``data.size`` elements holding the
        reduced values.

    Raises
    ------
    TypeError
        If ``data`` is not a ``numpy.ndarray`` or its dtype is unsupported.
        (``TypeError`` is a subclass of the ``Exception`` raised previously.)

    Notes
    -----
    This function is not thread-safe.

    The flattened buffer is copied only when ``data.ravel().base is
    data.base``.  Otherwise the buffer handed to the library is the
    ``ravel()`` result itself, which can share memory with ``data``, so the
    reduction may overwrite the input in place.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError('allreduce only takes in numpy.ndarray')
    buf = data.ravel()
    # Copy rule kept exactly as in the original; see Notes above.
    if buf.base is data.base:
        buf = buf.copy()
    if buf.dtype not in DTYPE_ENUM__:
        raise TypeError('data type %s not supported' % str(buf.dtype))

    callback = None
    if prepare_fun is not None:
        callback_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

        def _prepare(_unused):
            """Adapter with the C callback signature; ignores its argument."""
            prepare_fun(data)

        # Held in a local so the callback object stays alive during the call.
        callback = callback_type(_prepare)

    _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                        buf.size, DTYPE_ENUM__[buf.dtype],
                        op, callback, None)
    return buf


def _load_model(ptr, length):
    """
    Internal function used by the module,
    unpickle a model from a buffer specified by ptr, length
    Arguments:
        ptr: ctypes.POINTER(ctypes._char)
            pointer to the memory region of buffer
        length: int
            the length of buffer
    """
    data = (ctypes.c_char * length).from_address(ctypes.addressof(ptr.contents))
    return pickle.loads(data.raw)

def load_checkpoint(with_local=False):
    """Load the latest checkpoint.

    Parameters
    ----------
    with_local : bool, optional
        Whether the checkpoint contains a local model.

    Returns
    -------
    tuple
        ``(version, global_model, local_model)`` if ``with_local`` is true,
        otherwise ``(version, global_model)``.
        A returned version of 0 means no model has been checkpointed; the
        model entries are then None.
    """
    global_ptr = ctypes.POINTER(ctypes.c_char)()
    global_size = ctypes.c_ulong()

    if not with_local:
        version = _LIB.RabitLoadCheckPoint(
            ctypes.byref(global_ptr),
            ctypes.byref(global_size),
            None, None)
        if version == 0:
            return (version, None)
        return (version, _load_model(global_ptr, global_size.value))

    local_ptr = ctypes.POINTER(ctypes.c_char)()
    local_size = ctypes.c_ulong()
    version = _LIB.RabitLoadCheckPoint(
        ctypes.byref(global_ptr),
        ctypes.byref(global_size),
        ctypes.byref(local_ptr),
        ctypes.byref(local_size))
    if version == 0:
        return (version, None, None)
    # Unpickle the global model first, then the local one.
    global_model = _load_model(global_ptr, global_size.value)
    local_model = _load_model(local_ptr, local_size.value)
    return (version, global_model, local_model)

def checkpoint(global_model, local_model=None):
    """Checkpoint the model.

    Marks the end of a stage of execution.  Every call to checkpoint
    increases the version number by one.

    Parameters
    ----------
    global_model : any type that can be pickled
        Globally shared model/state at the time of the call.
        The caller needs to guarantee that global_model is the same on all
        nodes.

    local_model : any type that can be pickled
        Local model, specific to the current node/rank.
        This can be None when no local state is needed.

    Notes
    -----
    local_model requires explicit replication of the model for
    fault-tolerance, which adds replication cost to the checkpoint call,
    while global_model does not need explicit replication.
    It is recommended to use global_model if possible.
    """
    global_blob = pickle.dumps(global_model)
    if local_model is not None:
        local_blob = pickle.dumps(local_model)
        _LIB.RabitCheckPoint(global_blob, len(global_blob),
                             local_blob, len(local_blob))
    else:
        # No local state: pass a null pointer and a zero length.
        _LIB.RabitCheckPoint(global_blob, len(global_blob), None, 0)

def version_number():
    """Return the version number of the currently stored model.

    This is the number of calls to CheckPoint made so far.

    Returns
    -------
    version : int
        The value reported by the library's ``RabitVersionNumber``.
    """
    return _LIB.RabitVersionNumber()