File: coll_ucc_scatter.c

package info (click to toggle)
openmpi 5.0.8-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 201,672 kB
  • sloc: ansic: 613,078; makefile: 42,354; sh: 11,194; javascript: 9,244; f90: 7,052; java: 6,404; perl: 5,179; python: 1,859; lex: 740; fortran: 61; cpp: 20; tcl: 12
file content (134 lines) | stat: -rw-r--r-- 5,020 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/**
 * Copyright (c) 2021 Mellanox Technologies. All rights reserved.
 * Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 */

#include "coll_ucc_common.h"

/*
 * Prepare a UCC scatter collective without posting it.
 *
 * sbuf/scount/sdtype describe the root's send buffer: scount elements for
 * each of the comm_size ranks, handed to UCC as one block of
 * scount * comm_size elements. rbuf/rcount/rdtype describe the local receive
 * buffer, which may be MPI_IN_PLACE on the root. req receives the UCC request
 * handle; coll_req (NULL on the blocking path) is forwarded unchanged to
 * COLL_UCC_REQ_INIT.
 *
 * Returns UCC_OK once COLL_UCC_REQ_INIT has run. Returns
 * UCC_ERR_NOT_SUPPORTED when a buffer layout is not contiguous, when a
 * datatype has no UCC equivalent, or when the COLL_UCC_REQ_INIT macro
 * branches to `fallback`. The caller can then switch to the previous scatter
 * implementation.
 *
 * NOTE(review): the fallback decision is made independently on every rank.
 * The root inspects sdtype (and rdtype unless in place); the other ranks
 * inspect only rdtype. Confirm that ranks cannot disagree here. Otherwise
 * some ranks would enter UCC while others take the fallback path.
 */
static inline
ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount,
                                       struct ompi_datatype_t *sdtype,
                                       void *rbuf, size_t rcount,
                                       struct ompi_datatype_t *rdtype, int root,
                                       mca_coll_ucc_module_t *ucc_module,
                                       ucc_coll_req_h *req,
                                       mca_coll_ucc_req_t *coll_req)
{
    const bool in_place = (rbuf == MPI_IN_PLACE);
    const int  my_rank  = ompi_comm_rank(ucc_module->comm);
    const int  nranks   = ompi_comm_size(ucc_module->comm);
    /* UCC_DT_INT8 is a placeholder for any side this rank does not inspect. */
    ucc_datatype_t ucc_sdt = UCC_DT_INT8;
    ucc_datatype_t ucc_rdt = UCC_DT_INT8;

    if (my_rank != root) {
        /* Non-root: only the receive side is validated. */
        if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) {
            goto fallback;
        }
        ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
        if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
            UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                        rdtype->super.name);
            goto fallback;
        }
    } else {
        /* Root: the receive buffer is not examined when operating in place. */
        const bool recv_contig = in_place ||
            ompi_datatype_is_contiguous_memory_layout(rdtype, rcount);

        if (!recv_contig ||
            !ompi_datatype_is_contiguous_memory_layout(sdtype, scount * nranks)) {
            goto fallback;
        }

        ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
        if (!in_place) {
            ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
        }

        if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
            COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
            /* Name the send type when it is the unsupported one. */
            UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
                        (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
                        sdtype->super.name : rdtype->super.name);
            goto fallback;
        }
    }

    ucc_coll_args_t coll = {
        /* MPI_IN_PLACE is conveyed to UCC through the flags field. */
        .mask      = in_place ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
        .flags     = in_place ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0,
        .coll_type = UCC_COLL_TYPE_SCATTER,
        .root      = root,
        .src.info  = {
            .buffer   = (void *)sbuf,
            .count    = scount * (size_t)nranks,
            .datatype = ucc_sdt,
            .mem_type = UCC_MEMORY_TYPE_UNKNOWN
        },
        .dst.info  = {
            .buffer   = rbuf,
            .count    = rcount,
            .datatype = ucc_rdt,
            .mem_type = UCC_MEMORY_TYPE_UNKNOWN
        },
    };

    COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
    return UCC_OK;

fallback:
    return UCC_ERR_NOT_SUPPORTED;
}

/*
 * Blocking scatter entry point for the UCC collective component.
 *
 * Initializes, posts and waits on a UCC scatter. The only way to reach the
 * `fallback` label is through the COLL_UCC_* macros, which presumably branch
 * there on failure. In that case the call is forwarded, with the original
 * arguments, to the previously selected scatter implementation via
 * mca_coll_ucc_call_previous.
 *
 * Returns OMPI_SUCCESS when the UCC path completes; otherwise returns
 * whatever the previous implementation returns.
 */
int mca_coll_ucc_scatter(const void *sbuf, int scount,
                         struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
                         struct ompi_datatype_t *rdtype, int root,
                         struct ompi_communicator_t *comm,
                         mca_coll_base_module_t *module)
{
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *)module;
    ucc_coll_req_h         ucc_req;

    UCC_VERBOSE(3, "running ucc scatter");

    /* Blocking path: no OMPI request wrapper, hence NULL for coll_req. */
    COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype,
                                             rbuf, rcount, rdtype, root,
                                             ucc_module, &ucc_req, NULL));
    COLL_UCC_POST_AND_CHECK(ucc_req);
    COLL_UCC_CHECK(coll_ucc_req_wait(ucc_req));
    return OMPI_SUCCESS;

fallback:
    UCC_VERBOSE(3, "running fallback scatter");
    return mca_coll_ucc_call_previous(scatter, ucc_module,
                                      sbuf, scount, sdtype, rbuf, rcount,
                                      rdtype, root, comm);
}

/*
 * Non-blocking scatter entry point for the UCC collective component.
 *
 * Obtains a component request wrapper (COLL_UCC_GET_REQ), builds the UCC
 * collective with that wrapper attached, posts it, and hands the embedded
 * ompi_request_t back through *request. The only way to reach the
 * `fallback` label is through the COLL_UCC_* macros, which presumably branch
 * there on failure. In that case the wrapper, if one was obtained, is
 * released, and the call is forwarded to the previously selected iscatter
 * implementation.
 *
 * Returns OMPI_SUCCESS with *request set when the UCC path is taken;
 * otherwise returns whatever the previous implementation returns.
 */
int mca_coll_ucc_iscatter(const void *sbuf, int scount,
                         struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
                         struct ompi_datatype_t *rdtype, int root,
                         struct ompi_communicator_t *comm,
                         ompi_request_t** request,
                         mca_coll_base_module_t *module)
{
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
    ucc_coll_req_h         req;
    mca_coll_ucc_req_t    *coll_req = NULL;

    UCC_VERBOSE(3, "running ucc iscatter");
    COLL_UCC_GET_REQ(coll_req);
    COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype, rbuf, rcount,
                                             rdtype, root, ucc_module, &req,
                                             coll_req));
    COLL_UCC_POST_AND_CHECK(req);
    /* Callers only ever see the embedded ompi_request_t. */
    *request = &coll_req->super;
    return OMPI_SUCCESS;
fallback:
    UCC_VERBOSE(3, "running fallback iscatter");
    if (NULL != coll_req) {
        /* Free through a genuine ompi_request_t * (the same handle the
         * success path returns), rather than reinterpreting &coll_req as an
         * ompi_request_t **. That cast accesses a mca_coll_ucc_req_t * object
         * through an incompatible pointer type. */
        ompi_request_t *ompi_req = &coll_req->super;
        mca_coll_ucc_req_free(&ompi_req);
        coll_req = NULL;
    }
    return mca_coll_ucc_call_previous(iscatter, ucc_module,
        sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, request);
}