// Copyright 2009 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <algorithm>
#include <chrono>
#include <future>
#include <stdexcept>
#include <thread>
#include "Collectives.h"
#include "MPIBcastFabric.h"
#include "maml/maml.h"

namespace mpicommon {

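// The fabric works on a duplicate of the parent group's communicator, which
// keeps its traffic separate from the parent's; it is freed in the destructor.
MPIFabric::MPIFabric(const Group &parentGroup, int bcastRoot)
    : group(parentGroup.dup()), bcastRoot(bcastRoot)
{
  if (!group.valid()) {
    throw std::runtime_error(
        "#osp:mpi: trying to set up an MPI fabric "
        "with an invalid MPI communicator");
  }

  // On an intercommunicator the only supported configuration is
  // bcastRoot == MPI_ROOT, i.e. this process is the broadcasting root.
  int isInter = 0;
  MPI_CALL(Comm_test_inter(group.comm, &isInter));
  if (isInter && bcastRoot != MPI_ROOT) {
    throw std::runtime_error(
        "Invalid MPIFabric group config "
        "on an MPI intercomm group");
  }
}
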
MPIFabric::~MPIFabric()
{
  // Wait for all outstanding sends to complete before freeing the
  // communicator they use.
  flushBcastSends();
  MPI_Comm_free(&group.comm);
}

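void MPIFabric::sendBcast(std::shared_ptr<utility::AbstractArray<uint8_t>> buf)
{
  // Start an asynchronous broadcast (mpicommon::bcast returns a future) and
  // keep buf alive through the PendingSend until the operation completes.
  auto future = mpicommon::bcast(
      buf->data(), buf->size(), MPI_BYTE, bcastRoot, group.comm);
  pendingSends.emplace_back(
      std::make_shared<PendingSend>(std::move(future), buf));
  checkPendingSends();
}
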
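// Poll until every pending operation (broadcasts and point-to-point sends
// alike) has completed.
void MPIFabric::flushBcastSends()
{
  while (!pendingSends.empty()) {
    checkPendingSends();
  }
}
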
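void MPIFabric::recvBcast(utility::AbstractArray<uint8_t> &buf)
{
  // Block until the broadcast data has arrived in buf.
  mpicommon::bcast(buf.data(), buf.size(), MPI_BYTE, bcastRoot, group.comm)
      .wait();
  // Retire any of our own sends that completed in the meantime.
  checkPendingSends();
}
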
void MPIFabric::send(
    std::shared_ptr<utility::AbstractArray<uint8_t>> buf, int rank)
{
  // Asynchronous point-to-point send with tag 0; buf stays alive in
  // pendingSends until the send completes.
  auto future =
      mpicommon::send(buf->data(), buf->size(), MPI_BYTE, rank, 0, group.comm);
  pendingSends.emplace_back(
      std::make_shared<PendingSend>(std::move(future), buf));
  checkPendingSends();
}

void MPIFabric::recv(utility::AbstractArray<uint8_t> &buf, int rank)
{
  mpicommon::recv(buf.data(), buf.size(), MPI_BYTE, rank, 0, group.comm).wait();
  checkPendingSends();
}

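// Retire completed sends: partition the list so that operations still in
// flight come first, then erase the completed tail, which drops the
// references those entries held on their buffers.
void MPIFabric::checkPendingSends()
{
  if (!pendingSends.empty()) {
    auto done = std::partition(pendingSends.begin(),
        pendingSends.end(),
        [](const std::shared_ptr<PendingSend> &ps) {
          // A zero-timeout wait_for polls the future without blocking.
          return ps->future.wait_for(std::chrono::milliseconds(0))
              != std::future_status::ready;
        });
    pendingSends.erase(done, pendingSends.end());
  }
}
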
MPIFabric::PendingSend::PendingSend(std::future<void *> future,
    std::shared_ptr<utility::AbstractArray<uint8_t>> &buf)
    : future(std::move(future)), buf(buf)
{}

} // namespace mpicommon