File: coll_ucc_gather.c

package info (click to toggle)
openmpi 5.0.8-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 201,672 kB
  • sloc: ansic: 613,078; makefile: 42,354; sh: 11,194; javascript: 9,244; f90: 7,052; java: 6,404; perl: 5,179; python: 1,859; lex: 740; fortran: 61; cpp: 20; tcl: 12
file content (131 lines) | stat: -rw-r--r-- 4,859 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131

/**
 * Copyright (c) 2021 Mellanox Technologies. All rights reserved.
 * Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 */

#include "coll_ucc_common.h"

/**
 * Translate MPI gather arguments into a UCC gather description and hand it
 * to COLL_UCC_REQ_INIT.
 *
 * The root validates the receive side and, unless MPI_IN_PLACE was passed as
 * the send buffer, the send side as well. Every other rank validates only
 * its send side. A layout that is not contiguous, or a datatype for which
 * ompi_dtype_to_ucc_dtype() yields COLL_UCC_DT_UNSUPPORTED, is rejected.
 *
 * @param sbuf        send buffer, or MPI_IN_PLACE
 * @param scount      number of elements in the send buffer
 * @param sdtype      datatype of the send elements
 * @param rbuf        receive buffer
 * @param rcount      per-rank receive count; multiplied by the communicator
 *                    size to obtain the destination count
 * @param rdtype      datatype of the receive elements
 * @param root        rank that gathers the data
 * @param ucc_module  module whose communicator supplies rank and size
 * @param req         UCC request handle, forwarded to COLL_UCC_REQ_INIT
 * @param coll_req    request wrapper, forwarded to COLL_UCC_REQ_INIT
 *
 * @return UCC_OK when the end of the function is reached,
 *         UCC_ERR_NOT_SUPPORTED when control reaches the `fallback` label.
 *
 * NOTE(review): COLL_UCC_REQ_INIT is defined outside this file; it presumably
 * jumps to `fallback` when the UCC initialization fails - confirm there.
 */
static inline
ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
                                      void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
                                      int root, mca_coll_ucc_module_t *ucc_module,
                                      ucc_coll_req_h *req,
                                      mca_coll_ucc_req_t *coll_req)
{
    const bool     in_place = (MPI_IN_PLACE == sbuf);
    const int      my_rank  = ompi_comm_rank(ucc_module->comm);
    const int      nranks   = ompi_comm_size(ucc_module->comm);
    /* INT8 is the placeholder for a side that is never converted. */
    ucc_datatype_t src_dt   = UCC_DT_INT8;
    ucc_datatype_t dst_dt   = UCC_DT_INT8;

    if (root != my_rank) {
        /* Non-root ranks: only the send side is inspected. */
        if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) {
            goto fallback;
        }

        src_dt = ompi_dtype_to_ucc_dtype(sdtype);
        if (COLL_UCC_DT_UNSUPPORTED == src_dt) {
            UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                        sdtype->super.name);
            goto fallback;
        }
    } else {
        /* Root: the send side is skipped entirely for in-place gathers. */
        if (!in_place &&
            !ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) {
            goto fallback;
        }
        if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * nranks)) {
            goto fallback;
        }

        dst_dt = ompi_dtype_to_ucc_dtype(rdtype);
        if (!in_place) {
            src_dt = ompi_dtype_to_ucc_dtype(sdtype);
        }

        if ((COLL_UCC_DT_UNSUPPORTED == src_dt) ||
            (COLL_UCC_DT_UNSUPPORTED == dst_dt)) {
            UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                        (COLL_UCC_DT_UNSUPPORTED == src_dt) ?
                        sdtype->super.name : rdtype->super.name);
            goto fallback;
        }
    }

    /* The in-place request is encoded directly in mask/flags. */
    ucc_coll_args_t coll = {
        .mask      = in_place ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
        .flags     = in_place ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0,
        .coll_type = UCC_COLL_TYPE_GATHER,
        .root      = root,
        .src.info = {
            .buffer   = (void*)sbuf,
            .count    = scount,
            .datatype = src_dt,
            .mem_type = UCC_MEMORY_TYPE_UNKNOWN
        },
        .dst.info = {
            .buffer   = (void*)rbuf,
            .count    = rcount * nranks,
            .datatype = dst_dt,
            .mem_type = UCC_MEMORY_TYPE_UNKNOWN
        },
    };

    COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
    return UCC_OK;

fallback:
    return UCC_ERR_NOT_SUPPORTED;
}

/**
 * Blocking gather entry point of the UCC collective component.
 *
 * Initializes a UCC gather through mca_coll_ucc_gather_init() (passing NULL
 * as the request wrapper), posts it and waits for it with
 * coll_ucc_req_wait(). The code after the `fallback` label forwards the
 * unchanged arguments to the previous gather implementation via
 * mca_coll_ucc_call_previous().
 *
 * NOTE(review): the COLL_UCC_CHECK / COLL_UCC_POST_AND_CHECK macros are
 * defined elsewhere; they are presumably what transfers control to
 * `fallback` on failure - confirm against their definitions.
 *
 * @return OMPI_SUCCESS when the UCC path runs to the end, otherwise the
 *         result of the previous gather implementation.
 */
int mca_coll_ucc_gather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
                        void *rbuf, int rcount, struct ompi_datatype_t *rdtype,
                        int root, struct ompi_communicator_t *comm,
                        mca_coll_base_module_t *module)
{
    mca_coll_ucc_module_t *ucc_mod = (mca_coll_ucc_module_t *) module;
    ucc_coll_req_h ucc_req;

    UCC_VERBOSE(3, "running ucc gather");

    /* Blocking variant: no request wrapper is needed, hence NULL. */
    COLL_UCC_CHECK(mca_coll_ucc_gather_init(sbuf, scount, sdtype,
                                            rbuf, rcount, rdtype,
                                            root, ucc_mod, &ucc_req, NULL));
    COLL_UCC_POST_AND_CHECK(ucc_req);
    COLL_UCC_CHECK(coll_ucc_req_wait(ucc_req));

    return OMPI_SUCCESS;

fallback:
    UCC_VERBOSE(3, "running fallback gather");
    return mca_coll_ucc_call_previous(gather, ucc_mod,
                                      sbuf, scount, sdtype,
                                      rbuf, rcount, rdtype,
                                      root, comm);
}

/**
 * Non-blocking gather entry point of the UCC collective component.
 *
 * Obtains a request wrapper with COLL_UCC_GET_REQ, initializes a UCC gather
 * through mca_coll_ucc_gather_init(), posts it and returns the wrapper's
 * embedded request through @p request. The code after the `fallback` label
 * releases the wrapper if one was obtained and forwards the unchanged
 * arguments to the previous igather implementation via
 * mca_coll_ucc_call_previous().
 *
 * NOTE(review): the COLL_UCC_* macros are defined elsewhere; they are
 * presumably what transfers control to `fallback` on failure - confirm
 * against their definitions.
 *
 * @return OMPI_SUCCESS when the UCC path runs to the end, otherwise the
 *         result of the previous igather implementation.
 */
int mca_coll_ucc_igather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
                         void *rbuf, int rcount, struct ompi_datatype_t *rdtype,
                         int root, struct ompi_communicator_t *comm,
                         ompi_request_t** request,
                         mca_coll_base_module_t *module)
{
    mca_coll_ucc_module_t *ucc_mod = (mca_coll_ucc_module_t *) module;
    mca_coll_ucc_req_t    *nb_req  = NULL;
    ucc_coll_req_h         ucc_req;

    UCC_VERBOSE(3, "running ucc igather");

    COLL_UCC_GET_REQ(nb_req);
    COLL_UCC_CHECK(mca_coll_ucc_gather_init(sbuf, scount, sdtype,
                                            rbuf, rcount, rdtype,
                                            root, ucc_mod, &ucc_req, nb_req));
    COLL_UCC_POST_AND_CHECK(ucc_req);

    /* Hand the embedded request back to the caller. */
    *request = &nb_req->super;
    return OMPI_SUCCESS;

fallback:
    UCC_VERBOSE(3, "running fallback igather");
    /* nb_req stays NULL if the wrapper was never obtained. */
    if (NULL != nb_req) {
        mca_coll_ucc_req_free((ompi_request_t **) &nb_req);
    }
    return mca_coll_ucc_call_previous(igather, ucc_mod,
                                      sbuf, scount, sdtype,
                                      rbuf, rcount, rdtype,
                                      root, comm, request);
}