1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
#ifndef CAFFE2_MPI_MPI_COMMON_H_
#define CAFFE2_MPI_MPI_COMMON_H_
#include <mpi.h>
#include <mutex>
#include <string>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
namespace caffe2 {
/**
 * @brief Enforces (via CAFFE_ENFORCE) that MPI has already been initialized.
 */
inline void CheckInitializedMPI() {
  // Zero-initialize so that a failing MPI_Initialized call is reported as
  // "not initialized" instead of reading an indeterminate value.
  int flag = 0;
  MPI_Initialized(&flag);
  CAFFE_ENFORCE(flag, "MPI does not seem to have been initialized.");
}
/**
 * @brief Compile-time mapping from a C++ element type to its MPI datatype.
 *
 * The primary template is declared but never defined, so only the explicit
 * specializations below are usable; any other T fails to compile.
 * Note(Yangqing): as necessary, add more specializations.
 */
template <typename T>
class MPIDataTypeWrapper;

template <>
class MPIDataTypeWrapper<char> {
 public:
  static MPI_Datatype type() {
    return MPI_CHAR;
  }
};

template <>
class MPIDataTypeWrapper<float> {
 public:
  static MPI_Datatype type() {
    return MPI_FLOAT;
  }
};

template <>
class MPIDataTypeWrapper<double> {
 public:
  static MPI_Datatype type() {
    return MPI_DOUBLE;
  }
};
// For all Caffe MPI calls, we will wrap it inside an MPI mutex lock guard.
TORCH_API std::mutex& MPIMutex();

/**
 * @brief Runs an MPI call while holding MPIMutex() and enforces that it
 * returned MPI_SUCCESS, reporting file, line and error code otherwise.
 *
 * The lock covers only the evaluation of `condition`; it is released before
 * the result is checked, so the error path does not build its message (and
 * throw) while other threads are blocked on the mutex. The macro's locals
 * carry a caffe2_mpi_check_ prefix so they cannot shadow identifiers used
 * inside `condition` (e.g. MPI_CHECK(f(error)) must see the caller's
 * `error`, not a fresh uninitialized local).
 */
#define MPI_CHECK(condition)                                \
  do {                                                      \
    int caffe2_mpi_check_error_ = MPI_SUCCESS;              \
    {                                                       \
      std::lock_guard<std::mutex> caffe2_mpi_check_guard_(  \
          ::caffe2::MPIMutex());                            \
      caffe2_mpi_check_error_ = (condition);                \
    }                                                       \
    CAFFE_ENFORCE(                                          \
        caffe2_mpi_check_error_ == MPI_SUCCESS,             \
        "Caffe2 MPI Error at: ",                            \
        __FILE__,                                           \
        ":",                                                \
        __LINE__,                                           \
        ": ",                                               \
        caffe2_mpi_check_error_);                           \
  } while (0)
/**
 * @brief Returns the MPI communicator that Caffe2 uses globally.
 *
 * Unless SetGlobalMPIComm() has been called, this is MPI_COMM_WORLD.
 */
TORCH_API MPI_Comm GlobalMPIComm();

/**
 * @brief Replaces Caffe2's global MPI communicator with new_comm.
 *
 * Caffe2 takes over ownership of the communicator passed in.
 */
TORCH_API void SetGlobalMPIComm(MPI_Comm new_comm);

/** @brief Helper returning the size of the communicator comm. */
TORCH_API int MPICommSize(MPI_Comm comm);

/** @brief Helper returning the calling process's rank within comm. */
TORCH_API int MPICommRank(MPI_Comm comm);
/**
* @brief A simple wrapper over an MPI common world.
*/
class MPICommonWorldWrapper {
public:
/**
* @brief Creates a common world wrapper.
*
* The new common world is created by taking the existing communicator
* passed in as src_comm, and splitting it using the color and the rank
* specified. In default, we will split from Caffe2's global communicator,
* and use color 0 as well as rank implicitly given by src_comm. As a result,
* the default constructor basically creates a comm identical to the source
* comm world.
*/
explicit MPICommonWorldWrapper(
MPI_Comm src_comm = MPI_COMM_NULL,
int color = 0,
int rank = -1) {
if (src_comm == MPI_COMM_NULL) {
src_comm = GlobalMPIComm();
}
if (rank == -1) {
MPI_CHECK(MPI_Comm_rank(src_comm, &rank));
}
MPI_CHECK(MPI_Comm_split(src_comm, color, rank, &comm_));
MPI_CHECK(MPI_Comm_size(comm_, &size_));
MPI_CHECK(MPI_Comm_rank(comm_, &rank_));
}
~MPICommonWorldWrapper() {
int ret;
MPI_CHECK(MPI_Finalized(&ret));
if (!ret) {
MPI_Comm_free(&comm_);
}
}
/**
* @brief Returns the common world held by the wrapper.
*/
inline MPI_Comm comm() const {
return comm_;
}
/**
* @brief Returns the size of the world.
*/
inline int size() const {
return size_;
}
/**
* @brief Returns the rank of this process in the world.
*/
inline int rank() const {
return rank_;
}
private:
MPI_Comm comm_;
int size_;
int rank_;
};
/**
 * @brief Performs peer setup so that the binary does not need to be launched
 * with mpirun / mpiexec.
 *
 * If you use mpirun or mpiexec to set up the common world, do not use this
 * function; MPI_Init will already have set it up.
 *
 * This also assumes that all instances can read from a common path (such as
 * NFS).
 *
 * @param replicas the number of replicas that MPI will run with.
 * @param role the role of this process, "server" or "client".
 * @param job_path a file name that the server writes its port into and from
 *     which the clients read the server's port.
 */
void MPISetupPeers(
    int replicas,
    const std::string& role,
    const std::string& job_path);
} // namespace caffe2
#endif // CAFFE2_MPI_MPI_COMMON_H_
|