1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
/**
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
* Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
*
*/
#include "coll_ucc_common.h"
/*
 * Build the UCC argument set for a scatter and initialize the UCC request.
 *
 * On the root, both sides are validated:
 *  - The send side must be contiguous for scount * (communicator size)
 *    elements and have a UCC datatype equivalent.
 *  - The receive side gets the same checks unless rbuf is MPI_IN_PLACE.
 *
 * On every other rank only rdtype/rcount are checked and converted. The send
 * datatype keeps its UCC_DT_INT8 placeholder.
 *
 * In-place operation (rbuf == MPI_IN_PLACE) is signalled to UCC through
 * UCC_COLL_ARGS_FIELD_FLAGS / UCC_COLL_ARGS_FLAG_IN_PLACE.
 *
 * coll_req is forwarded unchanged to COLL_UCC_REQ_INIT; the blocking caller
 * passes NULL here.
 *
 * Returns UCC_OK once COLL_UCC_REQ_INIT has run. Returns UCC_ERR_NOT_SUPPORTED
 * whenever control reaches the fallback label: a non-contiguous layout, a
 * datatype with no UCC mapping, or (presumably) a failure inside
 * COLL_UCC_REQ_INIT, which is defined outside this file.
 */
static inline
ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount,
                                       struct ompi_datatype_t *sdtype,
                                       void *rbuf, size_t rcount,
                                       struct ompi_datatype_t *rdtype, int root,
                                       mca_coll_ucc_module_t *ucc_module,
                                       ucc_coll_req_h *req,
                                       mca_coll_ucc_req_t *coll_req)
{
    /* UCC_DT_INT8 stands in for whichever side is not examined below. */
    ucc_datatype_t src_dt   = UCC_DT_INT8;
    ucc_datatype_t dst_dt   = UCC_DT_INT8;
    bool           in_place = (MPI_IN_PLACE == rbuf);
    int            my_rank  = ompi_comm_rank(ucc_module->comm);
    int            nranks   = ompi_comm_size(ucc_module->comm);

    if (root == my_rank) {
        /* The receive layout only matters when the root is not in place. */
        if (!in_place &&
            !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) {
            goto fallback;
        }
        /* The root's send buffer holds scount elements for every rank. */
        if (!ompi_datatype_is_contiguous_memory_layout(sdtype,
                                                       scount * nranks)) {
            goto fallback;
        }

        src_dt = ompi_dtype_to_ucc_dtype(sdtype);
        if (!in_place) {
            dst_dt = ompi_dtype_to_ucc_dtype(rdtype);
        }

        if (COLL_UCC_DT_UNSUPPORTED == src_dt) {
            UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                        sdtype->super.name);
            goto fallback;
        }
        if (COLL_UCC_DT_UNSUPPORTED == dst_dt) {
            UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                        rdtype->super.name);
            goto fallback;
        }
    } else {
        /* Non-root ranks: only the receive description is inspected. */
        if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) {
            goto fallback;
        }

        dst_dt = ompi_dtype_to_ucc_dtype(rdtype);
        if (COLL_UCC_DT_UNSUPPORTED == dst_dt) {
            UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                        rdtype->super.name);
            goto fallback;
        }
    }

    /*
     * Buffers and counts are handed to UCC as given. The source count is the
     * per-rank count times the communicator size. Memory type is left for UCC
     * to detect.
     */
    ucc_coll_args_t args = {
        .mask      = in_place ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
        .flags     = in_place ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0,
        .coll_type = UCC_COLL_TYPE_SCATTER,
        .root      = root,
        .src.info  = {
            .buffer   = (void *) sbuf,
            .count    = ((size_t) scount) * nranks,
            .datatype = src_dt,
            .mem_type = UCC_MEMORY_TYPE_UNKNOWN
        },
        .dst.info  = {
            .buffer   = (void *) rbuf,
            .count    = rcount,
            .datatype = dst_dt,
            .mem_type = UCC_MEMORY_TYPE_UNKNOWN
        },
    };

    COLL_UCC_REQ_INIT(coll_req, req, args, ucc_module);
    return UCC_OK;

fallback:
    return UCC_ERR_NOT_SUPPORTED;
}
/*
 * Blocking scatter through UCC.
 *
 * The collective is set up without an OMPI request object (coll_req is NULL),
 * then posted and waited on. There is no explicit goto to the `fallback`
 * label in this function, so the COLL_UCC_* macros presumably jump there when
 * a step fails (they are defined in coll_ucc_common.h, not visible here). The
 * fallback path forwards the same arguments to the previously selected
 * scatter implementation.
 *
 * Returns OMPI_SUCCESS on the UCC path, otherwise whatever the previous
 * implementation returns.
 */
int mca_coll_ucc_scatter(const void *sbuf, int scount,
                         struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
                         struct ompi_datatype_t *rdtype, int root,
                         struct ompi_communicator_t *comm,
                         mca_coll_base_module_t *module)
{
    ucc_coll_req_h req;
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;

    UCC_VERBOSE(3, "running ucc scatter");

    /* Blocking call: no OMPI request to complete, hence the NULL coll_req. */
    COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype,
                                             rbuf, rcount, rdtype, root,
                                             ucc_module, &req, NULL));
    COLL_UCC_POST_AND_CHECK(req);
    COLL_UCC_CHECK(coll_ucc_req_wait(req));
    return OMPI_SUCCESS;

fallback:
    UCC_VERBOSE(3, "running fallback scatter");
    return mca_coll_ucc_call_previous(scatter, ucc_module,
                                      sbuf, scount, sdtype,
                                      rbuf, rcount, rdtype,
                                      root, comm);
}
/*
 * Non-blocking scatter through UCC.
 *
 * An OMPI request is obtained with COLL_UCC_GET_REQ. It is passed to the
 * scatter setup so UCC can be bound to it, and the collective is then posted.
 * On success, *request receives the address of the request's `super` member
 * and the call returns immediately.
 *
 * Any step that lands on the `fallback` label releases the request (if one
 * was obtained) and forwards the call, including `request`, to the previously
 * selected iscatter implementation.
 */
int mca_coll_ucc_iscatter(const void *sbuf, int scount,
                          struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
                          struct ompi_datatype_t *rdtype, int root,
                          struct ompi_communicator_t *comm,
                          ompi_request_t** request,
                          mca_coll_base_module_t *module)
{
    /* Starts as NULL so the fallback path can tell whether it must be freed. */
    mca_coll_ucc_req_t *coll_req = NULL;
    ucc_coll_req_h req;
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;

    UCC_VERBOSE(3, "running ucc iscatter");

    COLL_UCC_GET_REQ(coll_req);
    COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype,
                                             rbuf, rcount, rdtype, root,
                                             ucc_module, &req, coll_req));
    COLL_UCC_POST_AND_CHECK(req);

    *request = &coll_req->super;
    return OMPI_SUCCESS;

fallback:
    UCC_VERBOSE(3, "running fallback iscatter");
    if (NULL != coll_req) {
        mca_coll_ucc_req_free((ompi_request_t **) &coll_req);
    }
    return mca_coll_ucc_call_previous(iscatter, ucc_module,
                                      sbuf, scount, sdtype,
                                      rbuf, rcount, rdtype,
                                      root, comm, request);
}
|