File: coll_ucc_gatherv.c

package info (click to toggle)
openmpi 5.0.7-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 202,312 kB
  • sloc: ansic: 612,441; makefile: 42,495; sh: 11,230; javascript: 9,244; f90: 7,052; java: 6,404; perl: 5,154; python: 1,856; lex: 740; fortran: 61; cpp: 20; tcl: 12
file content (119 lines) | stat: -rw-r--r-- 4,710 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119

/**
 * Copyright (c) 2021 Mellanox Technologies. All rights reserved.
 * Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 */

#include "coll_ucc_common.h"

/*
 * Translate MPI gatherv arguments into a UCC gatherv collective and
 * initialize it (the collective is not posted here).
 *
 * sbuf/scount/sdtype      - local contribution; at the root sbuf may be
 *                           MPI_IN_PLACE, in which case the IN_PLACE flag is
 *                           set on the UCC args.
 * rbuf/rcounts/disps/rdtype - receive description; rdtype is only converted
 *                           and validated on the root rank.
 * root                    - rank that gathers the data.
 * ucc_module              - provides the communicator used to find this
 *                           rank, and is handed to COLL_UCC_REQ_INIT.
 * req                     - out: UCC request handle, filled by
 *                           COLL_UCC_REQ_INIT.
 * coll_req                - NULL for the blocking path, or the OMPI request
 *                           wrapper for the non-blocking path (see callers).
 *
 * Returns UCC_OK once the collective is initialized, or
 * UCC_ERR_NOT_SUPPORTED when a datatype this rank must describe has no UCC
 * equivalent. NOTE(review): COLL_UCC_REQ_INIT (coll_ucc_common.h) presumably
 * also branches to `fallback` if UCC initialization fails — confirm.
 */
static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
                                                     void *rbuf, const int *rcounts, const int *disps,
                                                     struct ompi_datatype_t *rdtype, int root,
                                                     mca_coll_ucc_module_t *ucc_module,
                                                     ucc_coll_req_h *req,
                                                     mca_coll_ucc_req_t *coll_req)
{
    ucc_datatype_t ucc_sdt, ucc_rdt;
    int comm_rank = ompi_comm_rank(ucc_module->comm);

    ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
    if (comm_rank == root) {
        ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
        /* With MPI_IN_PLACE the root's send datatype is not used, so only
         * reject an unsupported sdtype when a real send buffer was given. */
        if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ||
            (MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt)) {
            UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                        (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ?
                        rdtype->super.name : sdtype->super.name);
            goto fallback;
        }
    } else {
        if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) {
            UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                        sdtype->super.name);
            goto fallback;
        }
        /* rdtype is not converted on non-root ranks, but ucc_rdt is still
         * copied into coll.dst.info_v below. Give it a defined value (the
         * send datatype, already validated above) instead of reading an
         * uninitialized variable, which is undefined behavior. */
        ucc_rdt = ucc_sdt;
    }

    ucc_coll_args_t coll = {
        .mask      = 0,
        .flags     = 0,
        .coll_type = UCC_COLL_TYPE_GATHERV,
        .root      = root,
        .src.info = {
            .buffer   = (void*)sbuf,
            .count    = scount,
            .datatype = ucc_sdt,
            .mem_type = UCC_MEMORY_TYPE_UNKNOWN
        },
        /* rcounts/disps are int arrays. No 64-bit count/displacement flag
         * is set, so UCC is expected to read them as 32-bit values.
         * NOTE(review): confirm against the UCC ucc_coll_args_t docs. */
        .dst.info_v = {
            .buffer        = (void*)rbuf,
            .counts        = (ucc_count_t*)rcounts,
            .displacements = (ucc_aint_t*)disps,
            .datatype      = ucc_rdt,
            .mem_type      = UCC_MEMORY_TYPE_UNKNOWN
        },
    };

    if (MPI_IN_PLACE == sbuf) {
        /* flags is only honored when its field bit is set in mask. */
        coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
        coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
    }
    COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
    return UCC_OK;
fallback:
    return UCC_ERR_NOT_SUPPORTED;
}

/*
 * Blocking MPI_Gatherv entry point for the UCC coll component.
 *
 * Steps:
 * 1. Initializes the UCC collective via mca_coll_ucc_gatherv_init, passing
 *    NULL for coll_req (no OMPI request wrapper on the blocking path).
 * 2. Posts the collective.
 * 3. Waits for it with coll_ucc_req_wait.
 *
 * The `fallback` label has no explicit goto in this body, so it is reached
 * only through the COLL_UCC_* macros. NOTE(review): those macros
 * (coll_ucc_common.h) are assumed to jump there on any non-UCC_OK status.
 * On fallback, the call is forwarded unchanged to the previously selected
 * gatherv implementation.
 *
 * Returns OMPI_SUCCESS when UCC completes the operation; otherwise the
 * previous module's return value.
 */
int mca_coll_ucc_gatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
                         void *rbuf, const int *rcounts, const int *disps,
                         struct ompi_datatype_t *rdtype, int root,
                         struct ompi_communicator_t *comm,
                         mca_coll_base_module_t *module)
{
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
    ucc_coll_req_h         req;

    UCC_VERBOSE(3, "running ucc gatherv");
    COLL_UCC_CHECK(mca_coll_ucc_gatherv_init(sbuf, scount, sdtype, rbuf, rcounts,
                                             disps, rdtype, root, ucc_module,
                                             &req, NULL));
    COLL_UCC_POST_AND_CHECK(req);
    /* NOTE(review): if the wait fails after a successful post, the whole
     * gatherv is re-issued through the fallback module below; confirm that
     * is safe when UCC may have partially written rbuf. */
    COLL_UCC_CHECK(coll_ucc_req_wait(req));
    return OMPI_SUCCESS;
fallback:
    UCC_VERBOSE(3, "running fallback gatherv");
    return mca_coll_ucc_call_previous(gatherv, ucc_module,
        sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, root, comm);
}

/*
 * Non-blocking MPI_Igatherv entry point for the UCC coll component.
 *
 * Obtains an OMPI request wrapper (COLL_UCC_GET_REQ), then initializes and
 * posts the UCC collective bound to that wrapper. The wrapper is returned
 * through *request without waiting for completion.
 *
 * On any failure, control reaches `fallback` through the COLL_UCC_* macros
 * (NOTE(review): assumed from coll_ucc_common.h, since there is no explicit
 * goto here). The wrapper, if one was obtained, is released, and the call is
 * forwarded to the previous igatherv implementation with the same `request`
 * out-pointer.
 *
 * Returns OMPI_SUCCESS after a successful post; otherwise the previous
 * module's return value.
 */
int mca_coll_ucc_igatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
                          void *rbuf, const int *rcounts, const int *disps,
                          struct ompi_datatype_t *rdtype, int root,
                          struct ompi_communicator_t *comm,
                          ompi_request_t** request,
                          mca_coll_base_module_t *module)
{
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
    ucc_coll_req_h         req;
    /* Starts NULL so the fallback path can tell whether a wrapper was
     * acquired and must be released. */
    mca_coll_ucc_req_t    *coll_req = NULL;

    UCC_VERBOSE(3, "running ucc igatherv");
    COLL_UCC_GET_REQ(coll_req);
    COLL_UCC_CHECK(mca_coll_ucc_gatherv_init(sbuf, scount, sdtype, rbuf, rcounts,
                                             disps, rdtype, root, ucc_module,
                                             &req, coll_req));
    COLL_UCC_POST_AND_CHECK(req);
    /* Completion is reported later through the wrapper handed to the init
     * helper; the caller owns *request from here on. */
    *request = &coll_req->super;
    return OMPI_SUCCESS;
fallback:
    UCC_VERBOSE(3, "running fallback igatherv");
    if (coll_req) {
        mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
    }
    return mca_coll_ucc_call_previous(igatherv, ucc_module,
        sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, root, comm, request);
}