File: scoll_ucc.h

package info (click to toggle)
openmpi 5.0.8-10
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 201,692 kB
  • sloc: ansic: 613,078; makefile: 42,351; sh: 11,194; javascript: 9,244; f90: 7,052; java: 6,404; perl: 5,179; python: 1,859; lex: 740; fortran: 61; cpp: 20; tcl: 12
file content (131 lines) | stat: -rw-r--r-- 4,040 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*
 * Copyright (c) 2021      Mellanox Technologies. All rights reserved.
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 * $HEADER$
 */

#ifndef MCA_SCOLL_UCC_H
#define MCA_SCOLL_UCC_H

#include "oshmem_config.h"

#include "shmem.h"
#include "oshmem/mca/mca.h"
#include "oshmem/mca/scoll/scoll.h"
#include "oshmem/proc/proc.h"

#include "scoll_ucc_debug.h"

#include <ucc/api/ucc.h>

BEGIN_C_DECLS

/**
 * Bitmask of UCC collective types: barrier, bcast, allreduce, allgather and
 * alltoall, OR-ed together.
 */
#define SCOLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \
                       UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLGATHER | \
                       UCC_COLL_TYPE_ALLTOALL)

/**
 * Comma-separated list of five collective names, listed in the same order as
 * the bits of SCOLL_UCC_CTS.
 * NOTE(review): presumably "reduce" and "collect" are the scoll-side names for
 * UCC allreduce and allgather, and this string is the default value of the
 * component's `cts` setting -- confirm against the component source.
 */
#define SCOLL_UCC_CTS_STR "barrier,broadcast,reduce,collect,alltoall"

/**
 * Progress entry point of the UCC scoll component.
 *
 * Takes no arguments and returns an int.
 * NOTE(review): presumably registered as a progress callback that advances
 * the UCC context; the meaning of the return value is not visible in this
 * header -- confirm against the implementation.
 */
int mca_scoll_ucc_progress(void);

/**
 * Globally exported structure
 *
 * Component-wide state of the UCC scoll component. A single instance,
 * mca_scoll_ucc_component, is declared below this structure.
 */
struct mca_scoll_ucc_component_t {
    /* Base scoll component (v1.0.0 type); first member of the structure. */
    mca_scoll_base_component_1_0_0_t super;
    /* NOTE(review): presumably the priority reported for this component --
     * confirm where it is registered. */
    int ucc_priority;
    /* NOTE(review): presumably the verbosity level for debug output --
     * confirm. */
    int ucc_verbose;
    /* NOTE(review): presumably a non-zero value enables the component --
     * confirm. */
    int ucc_enable;
    /* NOTE(review): presumably a PE-count threshold for using UCC -- confirm
     * the exact comparison in the implementation. */
    int ucc_np;
    /* String setting. NOTE(review): presumably the UCC collective layers
     * (CLs) to use -- confirm. */
    char * cls;
    /* String setting. NOTE(review): presumably the list of collective types
     * to enable, in the format of SCOLL_UCC_CTS_STR -- confirm. */
    char * cts;
    /* NOTE(review): presumably the number of currently existing UCC modules
     * -- confirm. */
    int nr_modules;
    /* Flag. NOTE(review): presumably true once the UCC library and context
     * have been set up -- confirm. */
    bool libucc_initialized;
    /* UCC context handle. */
    ucc_context_h ucc_context;
    /* UCC library handle. */
    ucc_lib_h ucc_lib;
    /* UCC library attributes. NOTE(review): presumably filled by a library
     * attribute query -- confirm. */
    ucc_lib_attr_t ucc_lib_attr;
    /* Collective types requested. NOTE(review): presumably a bitmask derived
     * from `cts` -- confirm. */
    ucc_coll_type_t cts_requested;
};
typedef struct mca_scoll_ucc_component_t mca_scoll_ucc_component_t;

/* The component instance; defined outside this header.
 * NOTE(review): this OSHMEM component is exported with OMPI_DECLSPEC --
 * confirm this is the intended visibility macro. */
OMPI_DECLSPEC extern mca_scoll_ucc_component_t mca_scoll_ucc_component;

/**
 * UCC enabled team
 *
 * Per-group scoll module: ties an OSHMEM group to a UCC team handle and keeps
 * the collective handlers saved for fallback.
 */
struct mca_scoll_ucc_module_t {
    /* Base scoll module; first member of the structure. */
    mca_scoll_base_module_t super;

    /* OSHMEM group this module is associated with. */
    oshmem_group_t             *group;
    /* UCC team handle. NOTE(review): presumably created by
     * mca_scoll_ucc_team_create() -- confirm. */
    ucc_team_h                  ucc_team;
    /* Pointer to a long array. NOTE(review): presumably a synchronization
     * work array owned by the module -- confirm allocation and lifetime. */
    long                       *pSync;
    
    /* Saved handlers - for fallback */
    /* One (function pointer, module) pair per collective.
     * NOTE(review): presumably these are the handlers that were installed on
     * the group before this module took over -- confirm where they are
     * captured. */
    mca_scoll_base_module_reduce_fn_t previous_reduce;
    mca_scoll_base_module_t *previous_reduce_module;
    mca_scoll_base_module_broadcast_fn_t previous_broadcast;
    mca_scoll_base_module_t *previous_broadcast_module;
    mca_scoll_base_module_barrier_fn_t previous_barrier;
    mca_scoll_base_module_t *previous_barrier_module;
    mca_scoll_base_module_collect_fn_t previous_collect;
    mca_scoll_base_module_t *previous_collect_module;
    mca_scoll_base_module_alltoall_fn_t previous_alltoall;
    mca_scoll_base_module_t *previous_alltoall_module;
};
typedef struct mca_scoll_ucc_module_t mca_scoll_ucc_module_t;

/* Class declaration for the module type; the OBJ_CLASS_DECLARATION macro is
 * defined outside this header. */
OBJ_CLASS_DECLARATION(mca_scoll_ucc_module_t);

/* API functions */

/**
 * Component init-query entry point.
 *
 * @param enable_progress_threads  Boolean flag supplied by the caller.
 * @param enable_mpi_threads       Boolean flag supplied by the caller.
 * @return int status code.
 *
 * NOTE(review): presumably the scoll framework's standard init_query hook,
 * with the flags describing the requested threading support -- confirm how
 * (or whether) they are used in the implementation.
 */
int mca_scoll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_threads);

/**
 * Team creation routine for a UCC module and an OSHMEM group.
 *
 * @param ucc_module  UCC scoll module.
 * @param osh_group   OSHMEM group.
 * @return int status code.
 *
 * NOTE(review): presumably creates a UCC team for the group and stores the
 * handle in ucc_module->ucc_team -- confirm against the implementation.
 */
int mca_scoll_ucc_team_create(mca_scoll_ucc_module_t *ucc_module, 
                              oshmem_group_t *osh_group);

/**
 * Context initialization routine.
 *
 * @param osh_group  OSHMEM group.
 * @return int status code.
 *
 * NOTE(review): presumably initializes the UCC library and context held in
 * mca_scoll_ucc_component -- confirm, including how osh_group is used.
 */
int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group);

/**
 * Per-group query routine.
 *
 * @param osh_group  OSHMEM group being queried.
 * @param priority   Pointer to an int. NOTE(review): presumably an output
 *                   receiving the module priority -- confirm.
 * @return Pointer to a base scoll module. NOTE(review): presumably NULL when
 *         UCC is not used for this group -- confirm.
 */
mca_scoll_base_module_t* mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority);

/**
 * Barrier over an OSHMEM group.
 *
 * @param group  Group taking part in the barrier.
 * @param pSync  Pointer to a long array. NOTE(review): presumably the
 *               OpenSHMEM synchronization work array -- confirm.
 * @param alg    Algorithm selector; its values are not defined in this header.
 * @return int status code.
 *
 * NOTE(review): presumably falls back to the saved previous_barrier handler
 * when UCC cannot run the operation -- confirm.
 */
int mca_scoll_ucc_barrier(struct oshmem_group_t *group, long *pSync, int alg);

/**
 * Broadcast over an OSHMEM group.
 *
 * @param group       Group taking part in the broadcast.
 * @param PE_root     Root PE. NOTE(review): presumably an index relative to
 *                    the group -- confirm.
 * @param target      Destination buffer.
 * @param source      Source buffer (read-only).
 * @param nlong       Size argument. NOTE(review): units (bytes vs. elements)
 *                    are not visible in this header -- confirm.
 * @param pSync       Pointer to a long array. NOTE(review): presumably the
 *                    OpenSHMEM synchronization work array -- confirm.
 * @param nlong_type  Boolean flag; its meaning is not visible in this header.
 * @param alg         Algorithm selector; its values are not defined here.
 * @return int status code.
 */
int mca_scoll_ucc_broadcast(struct oshmem_group_t *group,
                            int PE_root,
                            void *target,
                            const void *source,
                            size_t nlong,
                            long *pSync,
                            bool nlong_type,
                            int alg);

/**
 * Collect over an OSHMEM group.
 *
 * @param group       Group taking part in the operation.
 * @param target      Destination buffer.
 * @param source      Source buffer (read-only).
 * @param nlong       Size argument. NOTE(review): units (bytes vs. elements)
 *                    are not visible in this header -- confirm.
 * @param pSync       Pointer to a long array. NOTE(review): presumably the
 *                    OpenSHMEM synchronization work array -- confirm.
 * @param nlong_type  Boolean flag; its meaning is not visible in this header.
 * @param alg         Algorithm selector; its values are not defined here.
 * @return int status code.
 *
 * NOTE(review): presumably implemented with UCC allgather, given that
 * SCOLL_UCC_CTS includes UCC_COLL_TYPE_ALLGATHER -- confirm.
 */
int mca_scoll_ucc_collect(struct oshmem_group_t *group,
                          void *target,
                          const void *source,
                          size_t nlong,
                          long *pSync,
                          bool nlong_type,
                          int alg);

/**
 * Reduction over an OSHMEM group.
 *
 * @param group   Group taking part in the reduction.
 * @param op      OSHMEM operation descriptor for the reduction.
 * @param target  Destination buffer.
 * @param source  Source buffer (read-only).
 * @param nlong   Size argument. NOTE(review): units (bytes vs. elements) are
 *                not visible in this header -- confirm.
 * @param pSync   Pointer to a long array. NOTE(review): presumably the
 *                OpenSHMEM synchronization work array -- confirm.
 * @param pWrk    Untyped work buffer. NOTE(review): presumably the OpenSHMEM
 *                reduction work array -- confirm.
 * @param alg     Algorithm selector; its values are not defined here.
 * @return int status code.
 *
 * NOTE(review): presumably implemented with UCC allreduce, given that
 * SCOLL_UCC_CTS includes UCC_COLL_TYPE_ALLREDUCE -- confirm.
 */
int mca_scoll_ucc_reduce(struct oshmem_group_t *group,
                         struct oshmem_op_t *op,
                         void *target,
                         const void *source,
                         size_t nlong,
                         long *pSync,
                         void *pWrk,
                         int alg);

/**
 * All-to-all exchange over an OSHMEM group.
 *
 * @param group         Group taking part in the exchange.
 * @param target        Destination buffer.
 * @param source        Source buffer (read-only).
 * @param dst           Signed offset/stride argument for the destination.
 *                      NOTE(review): presumably the target stride in elements
 *                      -- confirm.
 * @param sst           Signed offset/stride argument for the source.
 *                      NOTE(review): presumably the source stride in elements
 *                      -- confirm.
 * @param nelems        Element count. NOTE(review): presumably per-PE --
 *                      confirm.
 * @param element_size  Size of one element. NOTE(review): presumably in bytes
 *                      -- confirm.
 * @param pSync         Pointer to a long array. NOTE(review): presumably the
 *                      OpenSHMEM synchronization work array -- confirm.
 * @param alg           Algorithm selector; its values are not defined here.
 * @return int status code.
 */
int mca_scoll_ucc_alltoall(struct oshmem_group_t *group,
                           void *target,
                           const void *source,
                           ptrdiff_t dst, ptrdiff_t sst,
                           size_t nelems,
                           size_t element_size,
                           long *pSync,
                           int alg);

END_C_DECLS

#endif