File: ng_mpi_wrapper.cpp

package info (click to toggle)
netgen 6.2.2601%2Bdfsg1-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 13,076 kB
  • sloc: cpp: 166,627; tcl: 6,310; python: 2,868; sh: 528; makefile: 90
file content (206 lines) | stat: -rw-r--r-- 6,194 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#ifdef PARALLEL

#include <filesystem>
#include <iostream>
#include <stdexcept>

#include "ng_mpi.hpp"
#include "ngstream.hpp"
#ifdef NG_PYTHON
#include "python_ngcore.hpp"
#endif // NG_PYTHON
#include "utils.hpp"

using std::cerr;
using std::cout;
using std::endl;

#ifndef NG_MPI_WRAPPER
#ifdef NG_PYTHON
#define MPI4PY_LIMITED_API 1
#define MPI4PY_LIMITED_API_SKIP_MESSAGE 1
#define MPI4PY_LIMITED_API_SKIP_SESSION 1
#include "mpi4py_pycapi.h"  // mpi4py < 4.0.0
#endif // NG_PYTHON
#endif // NG_MPI_WRAPPER

namespace ngcore {

#ifdef NG_MPI_WRAPPER
// Handle of the MPI (or mpi4py extension) shared library opened by InitMPI().
static std::unique_ptr<SharedLibrary> mpi_lib;
// Handle of the vendor-specific ng_mpi wrapper library opened by InitMPI().
static std::unique_ptr<SharedLibrary> ng_mpi_lib;
// True when InitMPI() called MPI_Init itself, so MPI_Finalize is owed at exit.
static bool need_mpi_finalize = false;

// Static-lifetime guard whose destructor calls NG_MPI_Finalize() at program
// exit, but only if this module performed MPI_Init.
// NOTE(review): it is declared after the library handles above, presumably so
// it is destroyed first and the libraries are still loaded when it runs —
// keep this declaration order.
struct MPIFinalizer {
  ~MPIFinalizer() {
    if (!need_mpi_finalize)
      return;
    cout << IM(5) << "Calling MPI_Finalize" << endl;
    NG_MPI_Finalize();
  }
};
MPIFinalizer mpi_finalizer;

// Reports whether the ng_mpi wrapper library handle (`ng_mpi_lib`) is set.
bool MPI_Loaded() {
  const bool wrapper_present = static_cast<bool>(ng_mpi_lib);
  return wrapper_present;
}

void InitMPI(std::optional<std::filesystem::path> mpi_lib_path) {
  if (ng_mpi_lib) return;

  cout << IM(3) << "InitMPI" << endl;

  std::string vendor = "";
  std::string mpi4py_lib_file = "";

  if (mpi_lib_path) {
    // Dynamic load of given shared MPI library
    // Then call MPI_Init, read the library version and set the vender name
    try {
      typedef int (*init_handle)(int *, char ***);
      typedef int (*mpi_initialized_handle)(int *);
      mpi_lib =
          std::make_unique<SharedLibrary>(*mpi_lib_path, std::nullopt, true);
      auto mpi_init = mpi_lib->GetSymbol<init_handle>("MPI_Init");
      auto mpi_initialized =
          mpi_lib->GetSymbol<mpi_initialized_handle>("MPI_Initialized");

      int flag = 0;
      mpi_initialized(&flag);
      if (!flag) {
        typedef const char *pchar;
        int argc = 1;
        pchar args[] = {"netgen", nullptr};
        pchar *argv = &args[0];
        cout << IM(5) << "Calling MPI_Init" << endl;
        mpi_init(&argc, (char ***)argv);
        need_mpi_finalize = true;
      }

      char c_version_string[65536];
      c_version_string[0] = '\0';
      int result_len = 0;
      typedef void (*get_version_handle)(char *, int *);
      auto get_version =
          mpi_lib->GetSymbol<get_version_handle>("MPI_Get_library_version");
      get_version(c_version_string, &result_len);
      std::string version = c_version_string;

      if (version.substr(0, 8) == "Open MPI")
        vendor = "Open MPI";
      else if (version.substr(0, 5) == "MPICH")
        vendor = "MPICH";
      else if (version.substr(0, 13) == "Microsoft MPI")
        vendor = "Microsoft MPI";
      else if (version.substr(0, 12) == "Intel(R) MPI")
        vendor = "Intel MPI";
      else
        throw std::runtime_error(
            std::string("Unknown MPI version: " + version));
    } catch (std::runtime_error &e) {
      cerr << "Could not load MPI: " << e.what() << endl;
      throw e;
    }
  } else {
#ifdef NG_PYTHON
    // Use mpi4py to init MPI library and get the vendor name
    auto mpi4py = py::module::import("mpi4py.MPI");
    vendor = mpi4py.attr("get_vendor")()[py::int_(0)].cast<std::string>();

#ifndef WIN32
    // Load mpi4py library (it exports all MPI symbols) to have all MPI symbols
    // available before the ng_mpi wrapper is loaded This is not necessary on
    // windows as the matching mpi dll is linked to the ng_mpi wrapper directly
    mpi4py_lib_file = mpi4py.attr("__file__").cast<std::string>();
    mpi_lib =
        std::make_unique<SharedLibrary>(mpi4py_lib_file, std::nullopt, true);
#endif  // WIN32
#endif // NG_PYTHON
  }

  std::string ng_lib_name = "";
  if (vendor == "Open MPI")
    ng_lib_name = "ng_openmpi";
  else if (vendor == "MPICH")
    ng_lib_name = "ng_mpich";
  else if (vendor == "Microsoft MPI")
    ng_lib_name = "ng_microsoft_mpi";
  else if (vendor == "Intel MPI")
    ng_lib_name = "ng_intel_mpi";
  else
    throw std::runtime_error("Unknown MPI vendor: " + vendor);

  ng_lib_name += NETGEN_SHARED_LIBRARY_SUFFIX;

  // Load the ng_mpi wrapper and call ng_init_mpi to set all function pointers
  typedef void (*ng_init_handle)();
  ng_mpi_lib = std::make_unique<SharedLibrary>(ng_lib_name);
  ng_mpi_lib->GetSymbol<ng_init_handle>("ng_init_mpi")();
  std::cout << IM(3) << "MPI wrapper loaded, vendor: " << vendor << endl;
}

// Builds (but does not throw) the exception signalling that no MPI
// implementation is available; call sites do `throw no_mpi();`.
static std::runtime_error no_mpi()
{
  const char *const message = "MPI not enabled";
  return std::runtime_error{message};
}

#ifdef NG_PYTHON
// Placeholder converter used before the ng_mpi wrapper runtime is loaded.
// The first mpi4py -> NG_MPI_Comm conversion triggers InitMPI(), which is
// expected to replace this pointer with the real converter; the call is then
// forwarded to it. Throws no_mpi() if the pointer was not replaced.
decltype(NG_MPI_CommFromMPI4Py) NG_MPI_CommFromMPI4Py =
    [](py::handle py_obj, NG_MPI_Comm &ng_comm) -> bool {
  // Must be captured before InitMPI(): comparing against it afterwards is how
  // we detect that initialization installed a different converter.
  const auto placeholder = NG_MPI_CommFromMPI4Py;

  InitMPI();

  const auto current = NG_MPI_CommFromMPI4Py;
  if (current == placeholder)
    throw no_mpi();
  return current(py_obj, ng_comm);
};

// Placeholder for the reverse conversion. Unlike the converter above, it
// does not attempt to initialize the runtime; it always throws no_mpi().
decltype(NG_MPI_CommToMPI4Py) NG_MPI_CommToMPI4Py =
    [](NG_MPI_Comm /* comm */) -> py::handle {
  throw no_mpi();
};
#endif  // NG_PYTHON

#include "ng_mpi_generated_dummy_init.hpp"
#else  // NG_MPI_WRAPPER

// Set once the mpi4py C API has been imported successfully.
static bool imported_mpi4py = false;
#ifdef NG_PYTHON
// Imports the mpi4py C API on first use. import_mpi4py__MPI() signals failure
// with a negative return value and a pending Python exception; propagate it
// as py::error_already_set (and retry on the next call) instead of going on
// to use an unimported C API.
static void EnsureMPI4PyImported() {
  if (imported_mpi4py) return;
  if (import_mpi4py__MPI() < 0) throw py::error_already_set();
  imported_mpi4py = true;
}

// Converts an mpi4py communicator to NG_MPI_Comm.
// Returns false (leaving `dst` untouched) if `src` is not an mpi4py Comm,
// and false if a Python error is pending after the conversion.
// Throws py::error_already_set if mpi4py cannot be imported or
// PyMPIComm_Get() fails.
decltype(NG_MPI_CommFromMPI4Py) NG_MPI_CommFromMPI4Py =
    [](py::handle src, NG_MPI_Comm &dst) -> bool {
  EnsureMPI4PyImported();
  PyObject *py_src = src.ptr();
  if (!PyObject_TypeCheck(py_src, &PyMPIComm_Type)) return false;
  auto *comm = PyMPIComm_Get(py_src);
  // PyMPIComm_Get returns NULL with a Python exception set on failure;
  // dereferencing it unchecked would crash.
  if (comm == nullptr) throw py::error_already_set();
  dst = *comm;
  return !PyErr_Occurred();
};

// Wraps an NG_MPI_Comm in a new mpi4py Comm object (handle to the result of
// PyMPIComm_New). Throws py::error_already_set if mpi4py cannot be imported.
decltype(NG_MPI_CommToMPI4Py) NG_MPI_CommToMPI4Py =
    [](NG_MPI_Comm src) -> py::handle {
  EnsureMPI4PyImported();
  return py::handle(PyMPIComm_New(src));
};

#endif  // NG_PYTHON

// Build without the runtime-loaded MPI wrapper (NG_MPI_WRAPPER undefined):
// there is nothing to load, so MPI is always reported as available.
bool MPI_Loaded()
{
  return true;
}

// No-op in this configuration; the library path argument is ignored.
void InitMPI(std::optional<std::filesystem::path> /* mpi_lib_path */)
{
}

#endif  // NG_MPI_WRAPPER

}  // namespace ngcore

#endif  // PARALLEL