1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
|
/**
Copyright (c) 2021 Mellanox Technologies. All rights reserved.
$COPYRIGHT$
Additional copyrights may follow
$HEADER$
*/
#include "scoll_ucc.h"
#include "scoll_ucc_dtypes.h"
#include "scoll_ucc_common.h"
#include <ucc/api/ucc.h>
/**
 * Prepare a UCC allreduce request for an OpenSHMEM reduction.
 *
 * Translates the OpenSHMEM operation into a UCC datatype / reduction op,
 * fills in a ucc_coll_args_t describing an allreduce over @p count elements,
 * makes sure the component UCC context and the module UCC team exist, and
 * finally initializes the request through SCOLL_UCC_REQ_INIT.
 *
 * @param sbuf        Source buffer.
 * @param rbuf        Destination buffer. When it is the same pointer as
 *                    @p sbuf the collective is flagged as in-place.
 * @param count       Element count used for both source and destination.
 * @param op          OpenSHMEM reduction descriptor (op->op and op->dt are
 *                    read here).
 * @param ucc_module  scoll/ucc module; supplies the group and the UCC team.
 * @param req         Request handle handed to SCOLL_UCC_REQ_INIT.
 *
 * @return UCC_OK on success; UCC_ERR_NOT_SUPPORTED when the datatype or the
 *         reduction op has no UCC equivalent (or when control reaches the
 *         fallback label); OSHMEM_ERROR when context or team creation fails.
 *         NOTE(review): OSHMEM_ERROR is returned through a ucc_status_t;
 *         callers are presumably expected to treat any value other than
 *         UCC_OK as failure -- confirm against SCOLL_UCC_CHECK.
 */
static inline ucc_status_t mca_scoll_ucc_reduce_init(const void *sbuf, void *rbuf,
                                                     int count, struct oshmem_op_t * op,
                                                     mca_scoll_ucc_module_t * ucc_module,
                                                     ucc_coll_req_h * req)
{
    ucc_datatype_t     ucc_dt;
    ucc_reduction_op_t ucc_op;

    ucc_dt = shmem_op_to_ucc_dtype(op);
    ucc_op = shmem_op_to_ucc_op(op->op);

    /*
     * Bail out before building the collective arguments: an unsupported
     * datatype or op must never be handed to UCC. Returning
     * UCC_ERR_NOT_SUPPORTED lets the caller run its fallback reduction.
     */
    if (OPAL_UNLIKELY((ucc_datatype_t) SCOLL_UCC_DT_UNSUPPORTED == ucc_dt)) {
        UCC_VERBOSE(5, "shmem datatype is not supported: dtype # = %d",
                    op->dt);
        goto fallback;
    }
    if (OPAL_UNLIKELY((ucc_reduction_op_t) SCOLL_UCC_OP_UNSUPPORTED == ucc_op)) {
        UCC_VERBOSE(5, "shmem reduction op is not supported: op # = %d",
                    op->op);
        goto fallback;
    }

    ucc_coll_args_t coll = {
        .mask      = 0,
        .coll_type = UCC_COLL_TYPE_ALLREDUCE,
        .src.info = {
            .buffer   = (void *)sbuf,
            .count    = count,
            .datatype = ucc_dt,
            .mem_type = UCC_MEMORY_TYPE_UNKNOWN
        },
        .dst.info = {
            .buffer   = rbuf,
            .count    = count,
            .datatype = ucc_dt,
            .mem_type = UCC_MEMORY_TYPE_UNKNOWN
        },
        .op = ucc_op,
    };

    /* Same buffer on both sides: request an in-place collective. */
    if (sbuf == rbuf) {
        coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
        coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
    }

    /* Create the component-wide UCC context on first use. */
    if (NULL == mca_scoll_ucc_component.ucc_context) {
        if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
            return OSHMEM_ERROR;
        }
    }

    /* Create the UCC team for this module on first use. */
    if (NULL == ucc_module->ucc_team) {
        if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
            return OSHMEM_ERROR;
        }
    }

    /* NOTE(review): presumably jumps to `fallback` on failure -- confirm
     * against the macro definition in scoll_ucc_common.h. */
    SCOLL_UCC_REQ_INIT(req, coll, ucc_module);
    return UCC_OK;

fallback:
    return UCC_ERR_NOT_SUPPORTED;
}
/**
 * OpenSHMEM reduction entry point backed by UCC.
 *
 * Initializes, posts and waits for a UCC request built by
 * mca_scoll_ucc_reduce_init(). If any of those steps fails the check done by
 * SCOLL_UCC_CHECK, the reduction is delegated via PREVIOUS_SCOLL_FN.
 *
 * @param group   Group performing the reduction; its scoll_reduce_module is
 *                used as the scoll/ucc module.
 * @param op      Reduction descriptor; op->dt_size is the divisor used to
 *                turn @p nlong into an element count.
 * @param target  Destination buffer.
 * @param source  Source buffer.
 * @param nlong   Request length; a value of zero returns immediately.
 * @param pSync   Forwarded untouched to the fallback implementation.
 * @param pWrk    Forwarded untouched to the fallback implementation.
 * @param alg     Forwarded untouched to the fallback implementation.
 *
 * @return OSHMEM_SUCCESS when the UCC path completes or the request is
 *         empty; otherwise the value left in rc after the fallback call.
 */
int mca_scoll_ucc_reduce(struct oshmem_group_t *group,
                         struct oshmem_op_t *op,
                         void *target,
                         const void *source,
                         size_t nlong,
                         long *pSync,
                         void *pWrk,
                         int alg)
{
    mca_scoll_ucc_module_t *ucc_module;
    ucc_coll_req_h          request;
    size_t                  n_elems;
    /* NOTE(review): rc is never assigned directly in this function;
     * presumably PREVIOUS_SCOLL_FN stores the fallback result in it --
     * confirm against the macro definition. */
    int                     rc;

    UCC_VERBOSE(3, "running ucc reduce");

    ucc_module = (mca_scoll_ucc_module_t *) group->g_scoll.scoll_reduce_module;
    n_elems    = nlong / op->dt_size;

    /* An empty request is a successful no-op. */
    if (OPAL_UNLIKELY(0 == nlong)) {
        return OSHMEM_SUCCESS;
    }

    /* UCC path: build the request, post it, then wait for it. */
    SCOLL_UCC_CHECK(mca_scoll_ucc_reduce_init(source, target, n_elems, op,
                                              ucc_module, &request));
    SCOLL_UCC_CHECK(ucc_collective_post(request));
    SCOLL_UCC_CHECK(scoll_ucc_req_wait(request));

    return OSHMEM_SUCCESS;

fallback:
    /* UCC path was not usable: hand the reduction to the previous module. */
    UCC_VERBOSE(3, "running fallback reduction");
    PREVIOUS_SCOLL_FN(ucc_module, reduce, group, op, target,
                      source, nlong, pSync, pWrk, alg);
    return rc;
}
|