1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
|
/*
* Copyright (C) Mellanox Technologies Ltd. 2001-2011. ALL RIGHTS RESERVED.
* Copyright (c) 2016 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2022 IBM Corporation. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
*
* $HEADER$
*/
#include "pml_ucx_request.h"
#include "ompi/mca/pml/base/pml_base_bsend.h"
#include "ompi/message/message.h"
#include "ompi/runtime/ompi_spc.h"
#include "ompi/request/request.h"
#include <inttypes.h>
/*
 * Free handler for non-persistent requests (installed by
 * mca_pml_ucx_request_init()).
 *
 * Stores MPI_REQUEST_NULL in the caller's handle, resets the request with
 * mca_pml_ucx_request_reset() and hands it to ucp_request_free().
 *
 * Always returns OMPI_SUCCESS.
 */
static int mca_pml_ucx_request_free(ompi_request_t **rptr)
{
    ompi_request_t *request = *rptr;

    PML_UCX_VERBOSE(9, "free request *%p=%p", (void*)rptr, (void*)request);

    /* The handle is cleared first; the request itself is released below. */
    *rptr = MPI_REQUEST_NULL;

    mca_pml_ucx_request_reset(request);
    ucp_request_free(request);

    return OMPI_SUCCESS;
}
/*
 * Cancel handler for non-persistent requests: passes the request to
 * ucp_request_cancel() on the PML's UCP worker.
 *
 * The `flag` argument is not used. Always returns OMPI_SUCCESS.
 */
int mca_pml_ucx_request_cancel(ompi_request_t *req, int flag)
{
    ucp_request_cancel(ompi_pml_ucx.ucp_worker, req);

    return OMPI_SUCCESS;
}
#if MPI_VERSION >= 4
/*
 * Cancel handler for send requests, compiled only when MPI_VERSION >= 4.
 *
 * Invokes mca_pml_cancel_send_callback() first and then delegates to
 * mca_pml_ucx_request_cancel(), whose return value is passed through.
 */
int mca_pml_ucx_request_cancel_send(ompi_request_t *req, int flag)
{
    mca_pml_cancel_send_callback(req, flag);

    return mca_pml_ucx_request_cancel(req, flag);
}
#endif
/*
 * Common body of the send completion callbacks.
 *
 * Records the send status derived from `status` in the request and then
 * marks the request complete. The request is asserted not to be complete
 * already.
 */
__opal_attribute_always_inline__ static inline void
mca_pml_ucx_send_completion_internal(void *request, ucs_status_t status)
{
    ompi_request_t *req = (ompi_request_t*)request;

    PML_UCX_VERBOSE(8, "send request %p completed with status %s", (void*)req,
                    ucs_status_string(status));

    mca_pml_ucx_set_send_status(&req->req_status, status);

    PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
    ompi_request_complete(req, true);
}
/*
 * Common body of the buffered-send completion callbacks.
 *
 * The pointer stored in req_complete_cb_data is released through
 * mca_pml_base_bsend_request_free() and the field is cleared. The send
 * status is then recorded and the request itself is freed with
 * mca_pml_ucx_request_free() rather than being marked complete.
 */
__opal_attribute_always_inline__ static inline void
mca_pml_ucx_bsend_completion_internal(void *request, ucs_status_t status)
{
    ompi_request_t *req = (ompi_request_t*)request;
    void *bsend_buffer  = req->req_complete_cb_data;

    PML_UCX_VERBOSE(8, "bsend request %p buffer %p completed with status %s", (void*)req,
                    bsend_buffer, ucs_status_string(status));

    /* Release the attached buffer and drop the reference to it. */
    mca_pml_base_bsend_request_free(bsend_buffer);
    req->req_complete_cb_data = NULL;

    mca_pml_ucx_set_send_status(&req->req_status, status);
    PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));

    /* `req` is a local copy of the handle, so only this copy is nulled. */
    mca_pml_ucx_request_free(&req);
}
/*
 * Common body of the receive completion callbacks.
 *
 * Feeds the received length into the SPC byte counters (user vs. MPI counter
 * selected from the MPI tag extracted from info->sender_tag), records the
 * receive status from `status` and `info`, and marks the request complete.
 * The request is asserted not to be complete already.
 */
__opal_attribute_always_inline__ static inline void
mca_pml_ucx_recv_completion_internal(void *request, ucs_status_t status,
                                     const ucp_tag_recv_info_t *info)
{
    ompi_request_t *req = (ompi_request_t*)request;

    PML_UCX_VERBOSE(8, "receive request %p completed with status %s tag %"PRIx64" len %zu",
                    (void*)req, ucs_status_string(status), info->sender_tag,
                    info->length);

    SPC_USER_OR_MPI(PML_UCX_TAG_GET_MPI_TAG(info->sender_tag), info->length,
                    OMPI_SPC_BYTES_RECEIVED_USER, OMPI_SPC_BYTES_RECEIVED_MPI);

    mca_pml_ucx_set_recv_status(&req->req_status, status, info);

    PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
    ompi_request_complete(req, true);
}
/*
 * Send completion entry point taking (request, status); forwards both
 * arguments unchanged to mca_pml_ucx_send_completion_internal().
 */
void mca_pml_ucx_send_completion(void *request, ucs_status_t status)
{
    mca_pml_ucx_send_completion_internal(request, status);
}
/*
 * Send completion entry point that intentionally does nothing: neither
 * `request` nor `status` is examined.
 */
void mca_pml_ucx_send_completion_empty(void *request, ucs_status_t status)
{
    /* intentionally left blank */
}
/*
 * Buffered-send completion entry point taking (request, status); forwards
 * both arguments unchanged to mca_pml_ucx_bsend_completion_internal().
 */
void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status)
{
    mca_pml_ucx_bsend_completion_internal(request, status);
}
/*
 * Receive completion entry point taking (request, status, info); forwards
 * all three arguments to mca_pml_ucx_recv_completion_internal().
 */
void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
                                 ucp_tag_recv_info_t *info)
{
    mca_pml_ucx_recv_completion_internal(request, status, info);
}
/*
 * Send completion entry point with an extra `user_data` argument.
 *
 * `user_data` is ignored; `request` and `status` are forwarded to
 * mca_pml_ucx_send_completion_internal().
 */
void mca_pml_ucx_send_nbx_completion(void *request, ucs_status_t status,
                                     void *user_data)
{
    mca_pml_ucx_send_completion_internal(request, status);
}
/*
 * Buffered-send completion entry point with an extra `user_data` argument.
 *
 * `user_data` is ignored; `request` and `status` are forwarded to
 * mca_pml_ucx_bsend_completion_internal().
 */
void mca_pml_ucx_bsend_nbx_completion(void *request, ucs_status_t status,
                                      void *user_data)
{
    mca_pml_ucx_bsend_completion_internal(request, status);
}
/*
 * Receive completion entry point with an extra `user_data` argument.
 *
 * `user_data` is ignored; `request`, `status` and `info` are forwarded to
 * mca_pml_ucx_recv_completion_internal().
 */
void mca_pml_ucx_recv_nbx_completion(void *request, ucs_status_t status,
                                     const ucp_tag_recv_info_t *info,
                                     void *user_data)
{
    mca_pml_ucx_recv_completion_internal(request, status, info);
}
/*
 * Breaks the two-way link between a persistent request and its temporary
 * request: preq->tmp_req and tmp_req->req_complete_cb_data are both set to
 * NULL. Neither object is freed here.
 */
static void mca_pml_ucx_persistent_request_detach(mca_pml_ucx_persistent_request_t *preq,
                                                  ompi_request_t *tmp_req)
{
    preq->tmp_req                 = NULL;
    tmp_req->req_complete_cb_data = NULL;
}
/*
 * Completes a persistent request from its finished temporary request.
 *
 * Steps, in order:
 *   1. copy the temporary request's status into the persistent request;
 *   2. reset the temporary request, detach it from the persistent request
 *      and pass it to ucp_request_free();
 *   3. mark the persistent request complete.
 */
inline void
mca_pml_ucx_persistent_request_complete(mca_pml_ucx_persistent_request_t *preq,
                                        ompi_request_t *tmp_req)
{
    /* The status must be copied before the temporary request is reset. */
    preq->ompi.req_status = tmp_req->req_status;

    mca_pml_ucx_request_reset(tmp_req);
    mca_pml_ucx_persistent_request_detach(preq, tmp_req);
    ucp_request_free(tmp_req);

    ompi_request_complete(&preq->ompi, true);
}
/*
 * Shared tail of the persistent send/receive completion callbacks.
 *
 * Marks the temporary request complete, then checks whether a persistent
 * request is attached through req_complete_cb_data. If one is attached, the
 * persistent request is completed as well; otherwise nothing more is done.
 */
static inline void mca_pml_ucx_preq_completion(ompi_request_t *tmp_req)
{
    mca_pml_ucx_persistent_request_t *preq;

    ompi_request_complete(tmp_req, false);

    /* The attachment is read only after the temporary request is completed. */
    preq = (mca_pml_ucx_persistent_request_t*)tmp_req->req_complete_cb_data;
    if (preq == NULL) {
        return;
    }

    PML_UCX_ASSERT(preq->tmp_req != NULL);
    mca_pml_ucx_persistent_request_complete(preq, tmp_req);
}
/*
 * Completion entry point for the temporary request of a persistent send.
 *
 * Records the send status derived from `status` in the temporary request and
 * then runs mca_pml_ucx_preq_completion() on it.
 */
void mca_pml_ucx_psend_completion(void *request, ucs_status_t status)
{
    ompi_request_t *send_req = (ompi_request_t*)request;

    PML_UCX_VERBOSE(8, "persistent send request %p completed with status %s",
                    (void*)send_req, ucs_status_string(status));

    mca_pml_ucx_set_send_status(&send_req->req_status, status);
    mca_pml_ucx_preq_completion(send_req);
}
/*
 * Completion entry point for the temporary request of a persistent receive.
 *
 * Records the receive status derived from `status` and `info` in the
 * temporary request and then runs mca_pml_ucx_preq_completion() on it.
 */
void mca_pml_ucx_precv_completion(void *request, ucs_status_t status,
                                  ucp_tag_recv_info_t *info)
{
    ompi_request_t *recv_req = (ompi_request_t*)request;

    PML_UCX_VERBOSE(8, "persistent receive request %p completed with status %s tag %"PRIx64" len %zu",
                    (void*)recv_req, ucs_status_string(status), info->sender_tag,
                    info->length);

    mca_pml_ucx_set_recv_status(&recv_req->req_status, status, info);
    mca_pml_ucx_preq_completion(recv_req);
}
/*
 * Field setup shared by every request flavour in this file.
 *
 * Runs OMPI_REQUEST_INIT() and then fills in the PML request type, the
 * initial state, the start handler (mca_pml_ucx_start) and the supplied
 * free/cancel handlers. req_complete_cb_data is cleared as well.
 *
 * ompi_req        request to set up
 * req_persistent  forwarded to OMPI_REQUEST_INIT()
 * state           initial value of req_state
 * req_free        handler stored in req_free
 * req_cancel      handler stored in req_cancel
 */
static void mca_pml_ucx_request_init_common(ompi_request_t* ompi_req,
                                            bool req_persistent,
                                            ompi_request_state_t state,
                                            ompi_request_free_fn_t req_free,
                                            ompi_request_cancel_fn_t req_cancel)
{
    OMPI_REQUEST_INIT(ompi_req, req_persistent);

    /* req_complete_cb_data links a temporary request to its persistent
     * request. A receive (ucp_tag_recv_nb) may invoke the completion callback
     * before that link is stored, so the field has to start out as NULL:
     * with a stale non-NULL value mca_pml_ucx_preq_completion() would try to
     * complete a bogus persistent request.
     */
    ompi_req->req_complete_cb_data = NULL;

    /* Handlers. */
    ompi_req->req_start  = mca_pml_ucx_start;
    ompi_req->req_free   = req_free;
    ompi_req->req_cancel = req_cancel;

    /* Classification and initial state. */
    ompi_req->req_type  = OMPI_REQUEST_PML;
    ompi_req->req_state = state;
}
void mca_pml_ucx_request_init(void *request)
{
ompi_request_t* ompi_req = request;
OBJ_CONSTRUCT(ompi_req, ompi_request_t);
mca_pml_ucx_request_init_common(ompi_req, false, OMPI_REQUEST_ACTIVE,
mca_pml_ucx_request_free,
mca_pml_ucx_request_cancel);
}
void mca_pml_ucx_request_cleanup(void *request)
{
ompi_request_t* ompi_req = request;
ompi_req->req_state = OMPI_REQUEST_INVALID;
OMPI_REQUEST_FINI(ompi_req);
OBJ_DESTRUCT(ompi_req);
}
/*
 * Free handler for persistent requests.
 *
 * Marks the persistent request INVALID. If a temporary request is still
 * attached, it is detached and passed to ucp_request_free(). The datatype
 * reference is then released, the persistent request is returned to the
 * persistent_reqs free list, and the caller's handle is set to
 * MPI_REQUEST_NULL.
 *
 * Always returns OMPI_SUCCESS.
 */
static int mca_pml_ucx_persistent_request_free(ompi_request_t **rptr)
{
    mca_pml_ucx_persistent_request_t *preq = (mca_pml_ucx_persistent_request_t*)*rptr;
    ompi_request_t *attached               = preq->tmp_req;

    preq->ompi.req_state = OMPI_REQUEST_INVALID;

    if (NULL != attached) {
        /* Unlink first so the temporary request no longer points at preq. */
        mca_pml_ucx_persistent_request_detach(preq, attached);
        ucp_request_free(attached);
    }

    OMPI_DATATYPE_RELEASE(preq->ompi_datatype);
    PML_UCX_FREELIST_RETURN(&ompi_pml_ucx.persistent_reqs, &preq->ompi.super);

    *rptr = MPI_REQUEST_NULL;
    return OMPI_SUCCESS;
}
/*
 * Cancel handler for persistent requests.
 *
 * When a temporary request is attached, it is passed to ucp_request_cancel()
 * on the PML's UCP worker; with nothing attached there is nothing to cancel.
 * The `flag` argument is not used. Always returns OMPI_SUCCESS.
 */
static int mca_pml_ucx_persistent_request_cancel(ompi_request_t *req, int flag)
{
    mca_pml_ucx_persistent_request_t *persistent = (mca_pml_ucx_persistent_request_t*)req;

    if (NULL == persistent->tmp_req) {
        return OMPI_SUCCESS;
    }

    ucp_request_cancel(ompi_pml_ucx.ucp_worker, persistent->tmp_req);
    return OMPI_SUCCESS;
}
/*
 * Constructor for mca_pml_ucx_persistent_request_t (passed to
 * OBJ_CLASS_INSTANCE below).
 *
 * Starts with no temporary request attached and sets up the embedded
 * ompi_request_t as a persistent PML request in the INACTIVE state, using
 * the persistent free and cancel handlers.
 */
static void mca_pml_ucx_persisternt_request_construct(mca_pml_ucx_persistent_request_t* req)
{
    req->tmp_req = NULL;

    mca_pml_ucx_request_init_common(&req->ompi, true, OMPI_REQUEST_INACTIVE,
                                    mca_pml_ucx_persistent_request_free,
                                    mca_pml_ucx_persistent_request_cancel);
}
/*
 * Destructor for mca_pml_ucx_persistent_request_t (passed to
 * OBJ_CLASS_INSTANCE below): marks the embedded request INVALID and runs
 * OMPI_REQUEST_FINI() on it.
 */
static void mca_pml_ucx_persisternt_request_destruct(mca_pml_ucx_persistent_request_t* req)
{
    req->ompi.req_state = OMPI_REQUEST_INVALID;

    OMPI_REQUEST_FINI(&req->ompi);
}
/*
 * Class instance for mca_pml_ucx_persistent_request_t. The arguments are the
 * class name, ompi_request_t, and the construct/destruct functions defined
 * just above.
 * NOTE: the "persisternt" spelling is deliberate here; it must match the
 * names of those two static functions.
 */
OBJ_CLASS_INSTANCE(mca_pml_ucx_persistent_request_t,
ompi_request_t,
mca_pml_ucx_persisternt_request_construct,
mca_pml_ucx_persisternt_request_destruct);
/*
 * Free handler for requests set up by mca_pml_ucx_completed_request_init().
 *
 * Only replaces the caller's handle with MPI_REQUEST_NULL; nothing is
 * released here. Always returns OMPI_SUCCESS.
 */
static int mca_pml_completed_request_free(struct ompi_request_t** rptr)
{
    *rptr = MPI_REQUEST_NULL;

    return OMPI_SUCCESS;
}
/*
 * Cancel handler for requests set up by mca_pml_ucx_completed_request_init().
 *
 * Does nothing with either argument and always returns OMPI_SUCCESS.
 */
static int mca_pml_completed_request_cancel(struct ompi_request_t* ompi_req, int flag)
{
    /* no work to do */
    return OMPI_SUCCESS;
}
/*
 * Sets up `ompi_req` as a request that is complete from the start.
 *
 * The request is initialised as a non-persistent PML request in the ACTIVE
 * state with the no-op "completed" free/cancel handlers, is associated with
 * ompi_mpi_comm_world, and is then immediately marked complete.
 */
void mca_pml_ucx_completed_request_init(ompi_request_t *ompi_req)
{
    mca_pml_ucx_request_init_common(ompi_req, false, OMPI_REQUEST_ACTIVE,
                                    mca_pml_completed_request_free,
                                    mca_pml_completed_request_cancel);

    ompi_req->req_mpi_object.comm = &ompi_mpi_comm_world.comm;

    /* Completion is the last step, after every field above has been set. */
    ompi_request_complete(ompi_req, false);
}
|