File: osc_ucx.h

package info (click to toggle)
openmpi 5.0.7-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 202,312 kB
  • sloc: ansic: 612,441; makefile: 42,495; sh: 11,230; javascript: 9,244; f90: 7,052; java: 6,404; perl: 5,154; python: 1,856; lex: 740; fortran: 61; cpp: 20; tcl: 12
file content (278 lines) | stat: -rw-r--r-- 12,696 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
/*
 * Copyright (C) Mellanox Technologies Ltd. 2001-2017. ALL RIGHTS RESERVED.
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 * $HEADER$
 */

#ifndef OMPI_OSC_UCX_H
#define OMPI_OSC_UCX_H

#include <ucp/api/ucp.h>

#include "ompi/group/group.h"
#include "ompi/communicator/communicator.h"
#include "opal/mca/common/ucx/common_ucx.h"
#include "opal/mca/common/ucx/common_ucx_wpool.h"
#include "opal/mca/shmem/shmem.h"
#include "opal/mca/shmem/base/base.h"

/* Component-local names for the common UCX ASSERT/ERROR/VERBOSE macros. */
#define OSC_UCX_ASSERT  MCA_COMMON_UCX_ASSERT
#define OSC_UCX_ERROR   MCA_COMMON_UCX_ERROR
#define OSC_UCX_VERBOSE MCA_COMMON_UCX_VERBOSE

/* Fixed capacities used to size the static arrays declared in this header. */
#define OMPI_OSC_UCX_POST_PEER_MAX    32   /* slots in ompi_osc_ucx_state_t::post_state */
#define OMPI_OSC_UCX_ATTACH_MAX       48   /* dynamic_wins / local_dynamic_win_info entries */
#define OMPI_OSC_UCX_MEM_ADDR_MAX_LEN 1024 /* bytes in ompi_osc_dynamic_win_info_t::mem_addr */


/*
 * Component-wide state for the UCX one-sided (osc) component. A single
 * instance, mca_osc_ucx_component, is declared right after this type.
 */
typedef struct ompi_osc_ucx_component {
    ompi_osc_base_component_t super;
    opal_common_ucx_wpool_t *wpool;
    bool enable_mpi_threads;
    opal_free_list_t requests; /* request free list for the r* communication variants */
    opal_free_list_t accumulate_requests; /* request free list, presumably for accumulate-type requests (TODO confirm at the allocation sites) */
    bool env_initialized; /* UCX environment is initialized or not */
    bool priority_is_set; /* whether ucp_ctx has been created and the component priority has been set */
    int comm_world_size;
    ucp_ep_h *endpoints; /* indexed by comm-world rank; see OSC_UCX_GET_EP */
    int num_modules; /* NOTE(review): presumably the number of live windows using this component — confirm */
    bool no_locks; /* Default value of the no_locks info key for new windows */
    bool acc_single_intrinsic;
    unsigned int priority;
    /* directory where to place backing files */
    char *backing_directory;
} ompi_osc_ucx_component_t;

/* The single component instance; only declared here (extern), defined elsewhere. */
OMPI_DECLSPEC extern ompi_osc_ucx_component_t mca_osc_ucx_component;

/*
 * Atomically add 1 to the outstanding-operation counter
 * (_module)->ctx->num_incomplete_req_ops.
 *
 * The body is wrapped in do { } while (0) with NO trailing semicolon, so a
 * use such as `OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(m);` is exactly one
 * statement. That keeps it safe as the body of an unbraced if/else. The
 * argument is parenthesized so expressions such as `&mod` expand correctly.
 */
#define OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(_module)                         \
    do {                                                                      \
        opal_atomic_add_fetch_64(&(_module)->ctx->num_incomplete_req_ops, 1); \
    } while (0)

/*
 * Counterpart of OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS: atomically subtract 1
 * from (_module)->ctx->num_incomplete_req_ops. Same single-statement form.
 */
#define OSC_UCX_DECREMENT_OUTSTANDING_NB_OPS(_module)                          \
    do {                                                                       \
        opal_atomic_add_fetch_64(&(_module)->ctx->num_incomplete_req_ops, -1); \
    } while (0)

/*
 * Kinds of synchronization epoch. A window records one value for its access
 * side and one for its exposure side (see ompi_osc_ucx_epoch_type_t).
 * The values are written out explicitly and match the implicit 0..5
 * numbering.
 */
typedef enum ompi_osc_ucx_epoch {
    NONE_EPOCH           = 0,
    FENCE_EPOCH          = 1,
    POST_WAIT_EPOCH      = 2,
    START_COMPLETE_EPOCH = 3,
    PASSIVE_EPOCH        = 4,
    PASSIVE_ALL_EPOCH    = 5
} ompi_osc_ucx_epoch_t;

/*
 * Epoch state of a window, tracked separately for the access side and the
 * exposure side (each an ompi_osc_ucx_epoch_t).
 */
typedef struct ompi_osc_ucx_epoch_type {
    ompi_osc_ucx_epoch_t access;
    ompi_osc_ucx_epoch_t exposure;
} ompi_osc_ucx_epoch_type_t;

/*
 * 64-bit lock-word values. The exclusive value has only bit 32 set.
 * NOTE(review): presumably compared against ompi_osc_ucx_state_t::lock —
 * confirm at the use sites.
 */
#define TARGET_LOCK_UNLOCKED  ((uint64_t)0)
#define TARGET_LOCK_EXCLUSIVE (((uint64_t)1) << 32)

#define OSC_UCX_IOVEC_MAX 128

/*
 * Byte offsets into ompi_osc_ucx_state_t. The leading members of that struct
 * are consecutive 64-bit words: lock, req_flag, acc_lock, complete_count,
 * post_index, post_state[OMPI_OSC_UCX_POST_PEER_MAX], dynamic_lock and
 * dynamic_win_count. OSC_UCX_STATE_WORD_OFFSET(i) is the offset of word i.
 * Keep these in sync with the struct layout.
 */
#define OSC_UCX_STATE_WORD_OFFSET(word_) (sizeof(uint64_t) * (word_))

#define OSC_UCX_STATE_LOCK_OFFSET 0
#define OSC_UCX_STATE_REQ_FLAG_OFFSET OSC_UCX_STATE_WORD_OFFSET(1)
#define OSC_UCX_STATE_ACC_LOCK_OFFSET OSC_UCX_STATE_WORD_OFFSET(2)
#define OSC_UCX_STATE_COMPLETE_COUNT_OFFSET OSC_UCX_STATE_WORD_OFFSET(3)
#define OSC_UCX_STATE_POST_INDEX_OFFSET OSC_UCX_STATE_WORD_OFFSET(4)
#define OSC_UCX_STATE_POST_STATE_OFFSET OSC_UCX_STATE_WORD_OFFSET(5)
#define OSC_UCX_STATE_DYNAMIC_LOCK_OFFSET \
    OSC_UCX_STATE_WORD_OFFSET(5 + OMPI_OSC_UCX_POST_PEER_MAX)
#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET \
    OSC_UCX_STATE_WORD_OFFSET(6 + OMPI_OSC_UCX_POST_PEER_MAX)

/*
 * Description of one region attached to a dynamic window. An array of these
 * is exposed through ompi_osc_ucx_state_t::dynamic_wins.
 */
typedef struct ompi_osc_dynamic_win_info {
    uint64_t base; /* presumably the start address of the region */
    size_t size; /* presumably the region length in bytes */
    char mem_addr[OMPI_OSC_UCX_MEM_ADDR_MAX_LEN]; /* opaque buffer of at most OMPI_OSC_UCX_MEM_ADDR_MAX_LEN bytes; NOTE(review): presumably the packed remote-access key for the region — confirm */
} ompi_osc_dynamic_win_info_t;

/*
 * Local bookkeeping for one attached region of a dynamic window; the module
 * keeps OMPI_OSC_UCX_ATTACH_MAX of these in local_dynamic_win_info.
 */
typedef struct ompi_osc_local_dynamic_win_info {
    opal_common_ucx_wpmem_t *mem;
    char *my_mem_addr; /* presumably the local copy of the address published in dynamic_wins[].mem_addr */
    int my_mem_addr_size; /* size of my_mem_addr (presumably in bytes) */
    int refcnt; /* NOTE(review): presumably counts attaches of the same region — confirm */
} ompi_osc_local_dynamic_win_info_t;

/*
 * Per-window state block. Each module embeds one, documented there as
 * "remote accessible flags".
 *
 * Every member up to and including dynamic_win_count is a 64-bit word. The
 * OSC_UCX_STATE_*_OFFSET macros hard-code the byte offset of each of those
 * words, so any change to this layout must be mirrored there.
 *
 * NOTE(review): the members are volatile, presumably because peers update
 * them remotely. volatile is not a substitute for atomic operations.
 */
typedef struct ompi_osc_ucx_state {
    volatile uint64_t lock;
    volatile uint64_t req_flag;
    volatile uint64_t acc_lock;
    volatile uint64_t complete_count; /* # msgs received from complete processes */
    volatile uint64_t post_index;
    volatile uint64_t post_state[OMPI_OSC_UCX_POST_PEER_MAX]; /* OMPI_OSC_UCX_POST_PEER_MAX slots */
    volatile uint64_t dynamic_lock;
    volatile uint64_t dynamic_win_count; /* presumably the number of valid entries in dynamic_wins */
    volatile ompi_osc_dynamic_win_info_t dynamic_wins[OMPI_OSC_UCX_ATTACH_MAX];
} ompi_osc_ucx_state_t;

/*
 * An address range delimited by base and tail. It is used through
 * ompi_osc_ucx_module_t::epoc_outstanding_ops_mems.
 * NOTE(review): whether tail is inclusive or one-past-the-end is not visible
 * here — confirm at the use sites.
 */
typedef struct ompi_osc_ucx_mem_ranges {
    uint64_t base;
    uint64_t tail;
} ompi_osc_ucx_mem_ranges_t;

/*
 * Per-window module for the UCX osc component.
 */
typedef struct ompi_osc_ucx_module {
    ompi_osc_base_module_t super;
    struct ompi_communicator_t *comm;
    int flavor; /* NOTE(review): presumably the MPI window flavor — confirm */
    size_t size;
    uint64_t *addrs; /* presumably per-rank remote base addresses of window memory */
    uint64_t *state_addrs; /* presumably per-rank remote addresses of ompi_osc_ucx_state_t */
    uint64_t *comm_world_ranks; /* window rank -> comm-world rank, used by OSC_UCX_GET_EP */
    int disp_unit; /* if disp_unit >= 0, then everyone has the same
                    * disp unit size; if disp_unit == -1, then we
                    * need to look at disp_units */
    int *disp_units; /* per-rank values, read by OSC_UCX_GET_DISP when disp_unit < 0 */

    ompi_osc_ucx_state_t state; /* remote accessible flags */
    ompi_osc_local_dynamic_win_info_t local_dynamic_win_info[OMPI_OSC_UCX_ATTACH_MAX];
    ompi_osc_ucx_epoch_type_t epoch_type; /* current access and exposure epochs */
    ompi_group_t *start_group;
    ompi_group_t *post_group;
    opal_hash_table_t outstanding_locks;
    opal_list_t pending_posts;
    int lock_count;
    int post_count;
    uint64_t req_result;
    int *start_grp_ranks;
    bool lock_all_is_nocheck;
    bool no_locks;
    bool acc_single_intrinsic;
    opal_common_ucx_ctx_t *ctx; /* holds num_incomplete_req_ops, updated by OSC_UCX_{INCREMENT,DECREMENT}_OUTSTANDING_NB_OPS */
    opal_common_ucx_wpmem_t *mem;
    opal_common_ucx_wpmem_t *state_mem;
    ompi_osc_ucx_mem_ranges_t *epoc_outstanding_ops_mems;
    bool skip_sync_check;
    bool noncontig_shared_win;
    size_t *sizes;
    /* in shared windows, shmem_addrs can be used for direct load store to
     * remote windows */
    uint64_t *shmem_addrs;
    void *segment_base;
    /** opal shared memory structure for the shared memory segment */
    opal_shmem_ds_t seg_ds;
} ompi_osc_ucx_module_t;

/* Lock mode stored in ompi_osc_ucx_lock_t::type (values written out explicitly). */
typedef enum locktype {
    LOCK_EXCLUSIVE = 0,
    LOCK_SHARED    = 1
} lock_type_t;

/*
 * Record describing a lock on target_rank. It starts with an opal_object_t
 * (`super`), so it is an OPAL object.
 * NOTE(review): presumably these are the entries kept in
 * ompi_osc_ucx_module_t::outstanding_locks — confirm.
 */
typedef struct ompi_osc_ucx_lock {
    opal_object_t super;
    int target_rank;
    lock_type_t type;
    bool is_nocheck; /* presumably set when the lock was requested with MPI_MODE_NOCHECK — confirm */
} ompi_osc_ucx_lock_t;

/*
 * Endpoint for window rank rank_. The rank is translated to a comm-world
 * rank through (_module)->comm_world_ranks, and the result indexes the
 * component-wide endpoint table. Expands to an lvalue so its address can be
 * taken (see OSC_UCX_GET_DEFAULT_EP).
 */
#define OSC_UCX_GET_EP(_module, rank_) \
    (mca_osc_ucx_component.endpoints[(_module)->comm_world_ranks[(rank_)]])

/*
 * Displacement unit for window rank rank_. Uses the uniform
 * (module_)->disp_unit when it is non-negative; otherwise uses the per-rank
 * entry (module_)->disp_units[rank_].
 */
#define OSC_UCX_GET_DISP(module_, rank_)                          \
    (((module_)->disp_unit < 0) ? (module_)->disp_units[(rank_)]  \
                                : (module_)->disp_unit)

/*
 * Set _ep_ptr (a ucp_ep_h *) for _target:
 *   - NULL when opal_common_ucx_thread_enabled is set;
 *   - otherwise, the address of the cached endpoint from OSC_UCX_GET_EP.
 *
 * The if/else is wrapped in do { } while (0) so that a use followed by ';'
 * is a single statement. The original bare if/else broke when used as the
 * body of an unbraced if/else (dangling else). Arguments are parenthesized
 * so expressions such as `&mod` expand correctly.
 */
#define OSC_UCX_GET_DEFAULT_EP(_ep_ptr, _module, _target)                 \
    do {                                                                  \
        if (opal_common_ucx_thread_enabled) {                             \
            (_ep_ptr) = NULL;                                             \
        } else {                                                          \
            (_ep_ptr) = (ucp_ep_h *)&(OSC_UCX_GET_EP(_module, _target));  \
        }                                                                 \
    } while (0)

/*
 * Declared here and defined elsewhere.
 * NOTE(review): by its name, this is the outstanding-operation count that
 * triggers a flush — confirm at the definition.
 */
extern size_t ompi_osc_ucx_outstanding_ops_flush_threshold;

/* Window-management entry points: shared query, attach/detach (dynamic windows), free. */
int ompi_osc_ucx_shared_query(struct ompi_win_t *win, int rank, size_t *size,
        int *disp_unit, void * baseptr);
int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len);
int ompi_osc_ucx_win_detach(struct ompi_win_t *win, const void *base);
int ompi_osc_ucx_free(struct ompi_win_t *win);

/*
 * One-sided communication entry points. All of them operate on the window
 * `win`. The r* variants additionally hand back an ompi_request_t through
 * `request`.
 * NOTE(review): how the *_nb variants differ from their blocking
 * counterparts is not visible in this header.
 */
int ompi_osc_ucx_put(const void *origin_addr, int origin_count,
                     struct ompi_datatype_t *origin_dt,
                     int target, ptrdiff_t target_disp, int target_count,
                     struct ompi_datatype_t *target_dt, struct ompi_win_t *win);
int ompi_osc_ucx_get(void *origin_addr, int origin_count,
                     struct ompi_datatype_t *origin_dt,
                     int target, ptrdiff_t target_disp, int target_count,
                     struct ompi_datatype_t *target_dt, struct ompi_win_t *win);
int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
                            struct ompi_datatype_t *origin_dt,
                            int target, ptrdiff_t target_disp, int target_count,
                            struct ompi_datatype_t *target_dt,
                            struct ompi_op_t *op, struct ompi_win_t *win);
int ompi_osc_ucx_accumulate_nb(const void *origin_addr, int origin_count,
                            struct ompi_datatype_t *origin_dt,
                            int target, ptrdiff_t target_disp, int target_count,
                            struct ompi_datatype_t *target_dt,
                            struct ompi_op_t *op, struct ompi_win_t *win);
int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
                                  void *result_addr, struct ompi_datatype_t *dt,
                                  int target, ptrdiff_t target_disp,
                                  struct ompi_win_t *win);
int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
                              struct ompi_datatype_t *dt, int target,
                              ptrdiff_t target_disp, struct ompi_op_t *op,
                              struct ompi_win_t *win);
int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
                                struct ompi_datatype_t *origin_datatype,
                                void *result_addr, int result_count,
                                struct ompi_datatype_t *result_datatype,
                                int target_rank, ptrdiff_t target_disp,
                                int target_count, struct ompi_datatype_t *target_datatype,
                                struct ompi_op_t *op, struct ompi_win_t *win);
int ompi_osc_ucx_get_accumulate_nb(const void *origin_addr, int origin_count,
                                struct ompi_datatype_t *origin_datatype,
                                void *result_addr, int result_count,
                                struct ompi_datatype_t *result_datatype,
                                int target_rank, ptrdiff_t target_disp,
                                int target_count, struct ompi_datatype_t *target_datatype,
                                struct ompi_op_t *op, struct ompi_win_t *win);
int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
                      struct ompi_datatype_t *origin_dt,
                      int target, ptrdiff_t target_disp, int target_count,
                      struct ompi_datatype_t *target_dt,
                      struct ompi_win_t *win, struct ompi_request_t **request);
int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
                      struct ompi_datatype_t *origin_dt,
                      int target, ptrdiff_t target_disp, int target_count,
                      struct ompi_datatype_t *target_dt, struct ompi_win_t *win,
                      struct ompi_request_t **request);
int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
                             struct ompi_datatype_t *origin_dt,
                             int target, ptrdiff_t target_disp, int target_count,
                             struct ompi_datatype_t *target_dt, struct ompi_op_t *op,
                             struct ompi_win_t *win, struct ompi_request_t **request);
int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
                                 struct ompi_datatype_t *origin_datatype,
                                 void *result_addr, int result_count,
                                 struct ompi_datatype_t *result_datatype,
                                 int target_rank, ptrdiff_t target_disp, int target_count,
                                 struct ompi_datatype_t *target_datatype,
                                 struct ompi_op_t *op, struct ompi_win_t *win,
                                 struct ompi_request_t **request);

/*
 * Active-target synchronization entry points: fence, start/complete and
 * post/wait/test. These match FENCE_EPOCH, START_COMPLETE_EPOCH and
 * POST_WAIT_EPOCH in ompi_osc_ucx_epoch_t.
 */
int ompi_osc_ucx_fence(int mpi_assert, struct ompi_win_t *win);
int ompi_osc_ucx_start(struct ompi_group_t *group, int mpi_assert, struct ompi_win_t *win);
int ompi_osc_ucx_complete(struct ompi_win_t *win);
int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_win_t *win);
int ompi_osc_ucx_wait(struct ompi_win_t *win);
int ompi_osc_ucx_test(struct ompi_win_t *win, int *flag);

/*
 * Passive-target synchronization (lock/unlock, lock_all/unlock_all), plus
 * sync and the flush family.
 */
int ompi_osc_ucx_lock(int lock_type, int target, int mpi_assert, struct ompi_win_t *win);
int ompi_osc_ucx_unlock(int target, struct ompi_win_t *win);
int ompi_osc_ucx_lock_all(int mpi_assert, struct ompi_win_t *win);
int ompi_osc_ucx_unlock_all(struct ompi_win_t *win);
int ompi_osc_ucx_sync(struct ompi_win_t *win);
int ompi_osc_ucx_flush(int target, struct ompi_win_t *win);
int ompi_osc_ucx_flush_all(struct ompi_win_t *win);
int ompi_osc_ucx_flush_local(int target, struct ompi_win_t *win);
int ompi_osc_ucx_flush_local_all(struct ompi_win_t *win);

/*
 * Dynamic-window helpers.
 *
 * ompi_osc_find_attached_region_position(): judging by its parameters, it
 * searches dynamic_wins between min_index and max_index for the region given
 * by base and len.
 * NOTE(review): confirm at the definition whether max_index is inclusive,
 * what the return value means, and what is stored in *insert (presumably an
 * insertion index).
 *
 * ompi_osc_ucx_dynamic_lock()/ompi_osc_ucx_dynamic_unlock(): presumably
 * acquire/release the dynamic_lock word of target's ompi_osc_ucx_state_t —
 * confirm.
 */
int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_wins,
                                           int min_index, int max_index,
                                           uint64_t base, size_t len, int *insert);
int ompi_osc_ucx_dynamic_lock(ompi_osc_ucx_module_t *module, int target);
int ompi_osc_ucx_dynamic_unlock(ompi_osc_ucx_module_t *module, int target);

#endif /* OMPI_OSC_UCX_H */