File: ucx_wrapper.h

package info (click to toggle)
mpich 4.3.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 101,184 kB
  • sloc: ansic: 1,040,629; cpp: 82,270; javascript: 40,763; perl: 27,933; python: 16,041; sh: 14,676; xml: 14,418; f90: 12,916; makefile: 9,270; fortran: 8,046; java: 4,635; asm: 324; ruby: 103; awk: 27; lisp: 19; php: 8; sed: 4
file content (417 lines) | stat: -rw-r--r-- 11,746 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
/*
 * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2020. ALL RIGHTS RESERVED.
 *
 * See file LICENSE for terms.
 */

#ifndef IODEMO_UCX_WRAPPER_H_
#define IODEMO_UCX_WRAPPER_H_

#include <ucp/api/ucp.h>
#include <ucs/algorithm/crc.h>
#include <ucs/datastruct/list.h>
#include <ucs/sys/math.h>
#include <ucs/sys/sock.h>
#include <deque>
#include <exception>
#include <iostream>
#include <list>
#include <map>
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include <queue>
#include <sys/epoll.h>

/* Size in bytes of the per-connection log prefix buffer
 * (UcxConnection::_log_prefix) */
#define MAX_LOG_PREFIX_SIZE   64

/*
 * Forward declarations:
 * - UcxConnection is referenced by UcxContext before it is defined below.
 * - ucx_request is only used through pointers in this header; its definition
 *   is not in this file.
 */
class UcxConnection;
struct ucx_request;

/*
 * Describes an Active Message (AM) that has arrived: the payload pointer
 * handed over by the receive path plus the receive parameters that came with
 * it. Both pointers are stored as-is; nothing is copied or released here.
 */
struct UcxAmDesc {
    /* Payload pointer of the arrived message */
    void                         *_data;
    /* Receive parameters of the arrived message (e.g. recv_attr flags) */
    const ucp_am_recv_param_t    *_param;

    UcxAmDesc(void *am_data, const ucp_am_recv_param_t *am_param)
        : _data(am_data),
          _param(am_param)
    {
    }
};

/**
 * Abstract completion handler for UCX send/receive operations.
 * Concrete handlers are invoked with the final status of the operation.
 */
class UcxCallback {
public:
    /* Polymorphic base: virtual destructor, defined out of line */
    virtual ~UcxCallback();

    /* Completion hook; receives the operation's status */
    virtual void operator()(ucs_status_t completion_status) = 0;
};


/**
 * Empty callback, shared as a singleton obtained through get().
 * Used as the default callback argument of UcxConnection's send/recv methods.
 */
class EmptyCallback : public UcxCallback {
public:
    /* Accessor for the shared instance */
    static EmptyCallback* get();

    /// @override UcxCallback::operator()
    virtual void operator()(ucs_status_t completion_status);
};


/*
 * Logger which can be enabled/disabled.
 *
 * Values streamed with operator<< are appended to the internal string stream
 * _ss; when _ss is NULL every insertion is a no-op (presumably the disabled
 * case -- the constructor is defined out of line, TODO confirm). Flushing the
 * buffered text to _os is assumed to happen in the out-of-line destructor
 * (TODO confirm).
 *
 * Non-copyable: the class holds a raw std::stringstream pointer and has a
 * user-declared destructor, so a memberwise copy would leave two objects
 * pointing at the same stream (a double release if the destructor frees it).
 */
class UcxLog {
public:
    static bool use_human_time;

    UcxLog(const char* prefix, bool enable = true,
           std::ostream *os = &std::cout, bool abort = false);
    ~UcxLog();

    /* Appends any streamable value; does nothing when _ss is NULL */
    template<typename T>
    UcxLog& operator<<(const T &t) {
        if (_ss != NULL) {
            (*_ss) << t;
        }
        return *this;
    }

    /*
     * Stream manipulators such as std::endl / std::flush are function
     * templates, so T cannot be deduced for the generic operator<< above.
     * This overload gives them a concrete target type. Does nothing when
     * _ss is NULL.
     */
    UcxLog& operator<<(std::ostream& (*manip)(std::ostream&)) {
        if (_ss != NULL) {
            (*_ss) << manip;
        }
        return *this;
    }

private:
    /* Copying is disabled; declared but intentionally never defined
     * (pre-C++11 idiom, consistent with the rest of this header). */
    UcxLog(const UcxLog &);
    UcxLog& operator=(const UcxLog &);

    std::stringstream        *_ss;
    std::ostream             *_os;
    bool                     _abort;
};


/**
 * Holds UCX global context and worker.
 *
 * Abstract base: subclasses must implement dispatch_io_message,
 * dispatch_am_message and dispatch_connection_error.
 *
 * Non-copyable: the object holds raw UCX handles (_context, _worker,
 * _listener), containers of raw UcxConnection pointers and has a
 * user-declared destructor, so a memberwise copy would alias all of them.
 */
class UcxContext {
    class UcxAcceptCallback : public UcxCallback {
    public:
        UcxAcceptCallback(UcxContext &context, UcxConnection &connection);

        virtual void operator()(ucs_status_t status);

    private:
        UcxContext    &_context;
        UcxConnection &_connection;
    };

protected:
    class UcxDisconnectCallback : public UcxCallback {
    public:
        virtual void operator()(ucs_status_t status);
    };

public:
    typedef std::vector<uint8_t> iomsg_buffer_t;

    static const uint64_t CLIENT_ID_UNDEFINED = 0;
    static ucp_request_param_t recv_param;

    UcxContext(size_t iomsg_size, double connect_timeout, bool use_am,
               bool use_epoll = false,
               uint64_t client_id = CLIENT_ID_UNDEFINED);

    virtual ~UcxContext();

    bool init(const char *name);

    bool listen(const struct sockaddr* saddr, size_t addrlen);

    void progress(unsigned count = 1);

    static const std::string sockaddr_str(const struct sockaddr* saddr,
                                          size_t addrlen);

    static double get_time();

    static void *malloc(size_t size, const char *name);

    static void *memalign(size_t alignment, size_t size, const char *name);

    static void free(void *ptr);

    bool map_buffer(size_t length, void *address, ucp_mem_h *memh);

    bool unmap_buffer(ucp_mem_h memh);

protected:

    // Called when new IO message is received
    virtual void dispatch_io_message(UcxConnection* conn, const void *buffer,
                                     size_t length) = 0;

    // Called when new AM message is received
    // (note IO message can be bundled with data)
    virtual void dispatch_am_message(UcxConnection* conn, const void *hdr,
                                     size_t hdr_length,
                                     const UcxAmDesc &data_desc) = 0;

    // Called when there is a fatal failure on the connection
    virtual void dispatch_connection_error(UcxConnection* conn) = 0;

    // Called when new server connection is accepted
    virtual void dispatch_connection_accepted(UcxConnection* conn);

    void destroy_connections();

    void wait_disconnected_connections();

    void destroy_listener();

    // Payload pointer stored in an AM descriptor
    static inline void *ucx_am_get_data(const UcxAmDesc &desc)
    {
        return desc._data;
    }

    // True when the AM receive attributes carry the rendezvous flag
    static inline bool ucx_am_is_rndv(const UcxAmDesc &desc)
    {
        return desc._param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV;
    }

private:
    // Copying is disabled; declared but intentionally never defined
    // (pre-C++11 idiom, consistent with the rest of this header).
    UcxContext(const UcxContext &);
    UcxContext &operator=(const UcxContext &);

    typedef enum {
        WAIT_STATUS_OK,
        WAIT_STATUS_FAILED,
        WAIT_STATUS_TIMED_OUT
    } wait_status_t;

    typedef struct {
        ucp_conn_request_h conn_request;
        struct timeval     arrival_time;
    } conn_req_t;

    typedef std::map<uint64_t, UcxConnection*> conn_map_t;

    typedef std::vector<std::pair<double, UcxConnection*> > timeout_conn_t;

    friend class UcxConnection;

    static const ucp_tag_t IOMSG_TAG = 1ull << 63;

    static uint32_t get_next_conn_id();

    static void request_init(void *request);

    static void request_reset(ucx_request *r);

    static void request_release(void *request);

    static void connect_callback(ucp_conn_request_h conn_req, void *arg);

    static void iomsg_recv_callback(void *request, ucs_status_t status,
                                    ucp_tag_recv_info *info);

    static ucs_status_t am_recv_callback(void *arg, const void *header,
                                         size_t header_length,
                                         void *data, size_t length,
                                         const ucp_am_recv_param_t *param);

    ucp_worker_h worker() const;

    double connect_timeout() const;

    int is_timeout_elapsed(struct timeval const *tv_prior, double timeout);

    ucs_status_t epoll_init();

    bool progress_worker_event();

    void progress_timed_out_conns();

    void progress_conn_requests();

    void progress_io_message();

    void progress_failed_connections();

    void progress_disconnected_connections();

    wait_status_t wait_completion(ucs_status_ptr_t status_ptr, const char *title,
                                  double timeout = 1e6);

    void recv_io_message();

    void add_connection(UcxConnection *conn);

    void remove_connection(UcxConnection *conn);

    timeout_conn_t::iterator find_connection_inprogress(UcxConnection *conn);

    void remove_connection_inprogress(UcxConnection *conn);

    void move_connection_to_disconnecting(UcxConnection *conn);

    // Linear search of the disconnecting list; read-only, hence const
    bool is_in_disconnecting_list(UcxConnection *conn) const
    {
        return std::find(_disconnecting_conns.begin(),
                         _disconnecting_conns.end(), conn) !=
                _disconnecting_conns.end();
    }

    void handle_connection_error(UcxConnection *conn);

    void destroy_worker();

    void set_am_handler(ucp_am_recv_callback_t cb, void *arg);

    ucp_context_h               _context;
    ucp_worker_h                _worker;
    ucp_listener_h              _listener;
    conn_map_t                  _conns;
    std::deque<conn_req_t>      _conn_requests;
    timeout_conn_t              _conns_in_progress; // ordered in time
    std::deque<UcxConnection *> _failed_conns;
    std::list<UcxConnection *>  _disconnecting_conns;
    ucx_request                 *_iomsg_recv_request;
    iomsg_buffer_t              _iomsg_buffer;
    double                      _connect_timeout;
    bool                        _use_am;
    int                         _worker_fd;
    int                         _epoll_fd;
    uint64_t                    _client_id;
};


/*
 * A single connection over a UCX endpoint (_ep), bound to the UcxContext it
 * was constructed with.
 *
 * Non-copyable: the object holds raw UCX handles (_ep, _close_request),
 * pending callback pointers and an intrusive list head (_all_requests), and
 * has a user-declared destructor; a memberwise copy would alias all of them.
 */
class UcxConnection {
public:
    UcxConnection(UcxContext &context, bool use_am);

    ~UcxConnection();

    void connect(const struct sockaddr *src_saddr,
                 const struct sockaddr *dst_saddr,
                 socklen_t addrlen,
                 UcxCallback *callback);

    void accept(ucp_conn_request_h conn_req, UcxCallback *callback);

    /**
     * The connection will be destroyed automatically after callback is called.
     */
    void disconnect(UcxCallback *callback);

    bool disconnect_progress();

    bool send_io_message(const void *buffer, size_t length,
                         UcxCallback* callback = EmptyCallback::get());

    bool send_data(const void *buffer, size_t length, ucp_mem_h memh,
                   uint32_t sn, UcxCallback *callback = EmptyCallback::get());

    bool recv_data(void *buffer, size_t length, ucp_mem_h memh, uint32_t sn,
                   UcxCallback *callback = EmptyCallback::get());

    bool send_am(const void *meta, size_t meta_length, const void *buffer,
                 size_t length, ucp_mem_h memh,
                 UcxCallback *callback = EmptyCallback::get());

    bool recv_am_data(void *buffer, size_t length, ucp_mem_h memh,
                      const UcxAmDesc &data_desc,
                      UcxCallback *callback = EmptyCallback::get());

    void iomsg_recv_defer(const UcxContext::iomsg_buffer_t &iomsg,
                          size_t iomsg_length);

    void cancel_all();

    uint64_t id() const {
        return _conn_id;
    }

    ucs_status_t ucx_status() const {
        return _ucx_status;
    }

    const char* get_log_prefix() const {
        return _log_prefix;
    }

    // Established once the pending establish callback has been cleared
    bool is_established() const {
        return _establish_cb == NULL;
    }

    const std::string& get_peer_name() const {
        return _remote_address;
    }

    // True while a disconnect callback is stored
    bool is_disconnecting() const {
        return _disconnect_cb != NULL;
    }

    void handle_connection_error(ucs_status_t status);

    static size_t get_num_instances() {
        return _num_instances;
    }

private:
    // Copying is disabled; declared but intentionally never defined
    // (pre-C++11 idiom, consistent with the rest of this header).
    UcxConnection(const UcxConnection &);
    UcxConnection &operator=(const UcxConnection &);

    static ucp_tag_t make_data_tag(uint32_t conn_id, uint32_t sn);

    static ucp_tag_t make_iomsg_tag(uint32_t conn_id, uint32_t sn);

    static void stream_recv_callback(void *request, ucs_status_t status,
                                     size_t recv_len);

    static void common_request_callback(void *request, ucs_status_t status);

    static void am_data_recv_callback(void *request, ucs_status_t status,
                                      size_t length, void *user_data);

    static void data_recv_callback(void *request, ucs_status_t status,
                                   ucp_tag_recv_info *info);

    static void error_callback(void *arg, ucp_ep_h ep, ucs_status_t status);

    void set_log_prefix(const struct sockaddr* saddr, socklen_t addrlen);

    void print_addresses();

    void connect_common(ucp_ep_params_t &ep_params, UcxCallback *callback);

    void connect_tag(UcxCallback *callback);

    void connect_am(UcxCallback *callback);

    void established(ucs_status_t status);

    bool send_common(const void *buffer, size_t length, ucp_mem_h memh,
                     ucp_tag_t tag, UcxCallback *callback);

    void request_started(ucx_request *r);

    void request_completed(ucx_request *r);

    void ep_close(enum ucp_ep_close_mode mode);

    bool process_request(const char *what, ucs_status_ptr_t ptr_status,
                         UcxCallback* callback);

    static void invoke_callback(UcxCallback *&cb, ucs_status_t status);

    static unsigned _num_instances;

    UcxContext                             &_context;
    UcxCallback                            *_establish_cb;
    UcxCallback                            *_disconnect_cb;
    uint64_t                               _conn_id;
    uint64_t                               _remote_conn_id;
    char                                   _log_prefix[MAX_LOG_PREFIX_SIZE];
    ucp_ep_h                               _ep;
    std::string                            _remote_address;
    void                                   *_close_request;
    ucs_list_link_t                        _all_requests;
    ucs_status_t                           _ucx_status;
    bool                                   _use_am;
    std::queue<UcxContext::iomsg_buffer_t> _iomsg_recv_backlog;
};

#endif