1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
|
// -*- tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
// SPDX-FileCopyrightInfo: Copyright © DUNE Project contributors, see file LICENSE.md in module root
// SPDX-License-Identifier: LicenseRef-GPL-2.0-only-with-DUNE-exception
#ifndef DUNE_PYTHON_COMMON_MPIHELPER_HH
#define DUNE_PYTHON_COMMON_MPIHELPER_HH
#include <vector>
#include <dune/common/parallel/communication.hh>
#include <dune/common/parallel/mpihelper.hh>
#include <dune/python/common/typeregistry.hh>
#include <dune/python/pybind11/pybind11.h>
namespace Dune
{
namespace Python
{
// registerCommunication
// ---------------------

/** \brief Attach the communication interface to an existing pybind11 class.
 *
 *  Adds to \p cls:
 *   - read-only properties `rank` and `size` (bound to Comm::rank / Comm::size),
 *   - `barrier()`,
 *   - `min(x)`, `max(x)`, `sum(x)` and `broadcast(x, root)`, each for a single
 *     double and for a std::vector<double>; the vector overloads work on a
 *     by-value copy of the argument and return that copy,
 *   - `gather(x, root)` and `scatter(x, root)` for a single double.
 *
 *  \tparam Comm     communication class wrapped by \p cls
 *  \tparam objects  further template arguments of the pybind11 class
 *  \param  cls      pybind11 class object receiving the definitions
 */
template< class Comm, class... objects >
inline static void registerCommunication ( pybind11::class_< Comm, objects... > cls )
{
  using pybind11::operator""_a;

  cls.def_property_readonly( "rank", &Comm::rank );
  cls.def_property_readonly( "size", &Comm::size );

  cls.def( "barrier", &Comm::barrier );

  // reductions: scalar overload returns the reduced value, vector overload
  // reduces entry-wise in place on its local copy and returns it
  cls.def( "min", [] ( const Comm &self, double x ) { return self.min( x ); }, "x"_a );
  cls.def( "min", [] ( const Comm &self, std::vector< double > x ) { self.min( x.data(), x.size() ); return x; }, "x"_a );

  cls.def( "max", [] ( const Comm &self, double x ) { return self.max( x ); }, "x"_a );
  cls.def( "max", [] ( const Comm &self, std::vector<double> x ) { self.max( x.data(), x.size() ); return x; }, "x"_a );

  cls.def( "sum", [] ( const Comm &self, double x ) { return self.sum( x ); }, "x"_a );
  cls.def( "sum", [] ( const Comm &self, std::vector<double> x ) { self.sum( x.data(), x.size() ); return x; }, "x"_a );

  // broadcast: the (possibly overwritten) local copy of x is returned
  cls.def( "broadcast", [] ( const Comm &self, double x, int root ) { self.broadcast( &x, 1, root); return x; }, "x"_a, "root"_a );
  cls.def( "broadcast", [] ( const Comm &self, std::vector<double> x, int root ) { self.broadcast( x.data(), x.size(), root); return x; }, "x"_a, "root"_a );

  cls.def( "gather", [] ( const Comm &self, double x, int root )
    {
      // result will contain valid values only on rank=root;
      // every other rank returns an empty vector
      std::vector< double > out;
      if( self.rank() == root )
        out.resize( self.size(), x );
      self.gather( &x, out.data(), 1, root);
      return out;
    }, "x"_a, "root"_a );

  cls.def( "scatter", [] ( const Comm &self, double x, int root )
    {
      // A scatter hands out `len` entries per rank from the root's send
      // buffer, so the root must provide size() entries. Passing the address
      // of the single double x would be read past its end for size > 1.
      // All other ranks only need a valid one-entry buffer.
      std::vector< double > send( 1, x );
      if( self.rank() == root )
        send.resize( self.size(), x );

      double out = x;
      self.scatter( send.data(), &out, 1, root);
      return out;
    }, "x"_a, "root"_a );
}
/** \brief Insert the class "Communication" for Dune::Communication< Dune::MPIHelper::MPICommunicator >
 *         into \p scope and set the attribute `comm` on \p scope.
 *
 *  \param scope  pybind11 handle that receives the class and the `comm` attribute
 */
inline static void registerCommunication ( pybind11::handle scope )
{
  using Comm = Dune::Communication< Dune::MPIHelper::MPICommunicator >;

  // C++ type name and header files passed along to insertClass
  auto cxxTypeName = GenerateTypeName( "Dune::Communication", "Dune::MPIHelper::MPICommunicator" );
  auto headerFiles = IncludeFiles{ "dune/common/parallel/communication.hh", "dune/common/parallel/mpihelper.hh" };

  auto [ commClass, unregistered ] = insertClass< Comm >( scope, "Communication", cxxTypeName, headerFiles );
  if( unregistered )
  {
    // add properties and methods only when insertClass sets its flag
    // (presumably: the class had not been registered before)
    registerCommunication( commClass );
  }

  // the attribute is (re)assigned on every call, independent of the flag above
  scope.attr( "comm" ) = pybind11::cast( Dune::MPIHelper::getCommunication() );
}
} // namespace Python
} // namespace Dune
#endif // #ifndef DUNE_PYTHON_COMMON_MPIHELPER_HH
|