1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
|
// Copyright (C) 2019 EDF
// All Rights Reserved
// This code is published under the GNU Lesser General Public License (GNU LGPL)
/** \file Pybind11StOptTree.cpp
* \brief Map tree classes to Python
* \author Xavier Warin
*/
#include <iostream>
#include <memory>
#include <Eigen/Dense>
#include <pybind11/pybind11.h>
#include <pybind11/eigen.h>
#include <pybind11/stl_bind.h>
#include <pybind11/stl.h>
#include "StOpt/core/utils/constant.h"
#include "StOpt/core/utils/version.h"
#include "StOpt/core/grids/SpaceGrid.h"
#include "StOpt/core/grids/FullGrid.h"
#include "StOpt/tree/Tree.h"
#include "StOpt/tree/ContinuationValueTree.h"
#include "StOpt/tree/ContinuationCutsTree.h"
#include "StOpt/tree/GridTreeValue.h"
namespace py = pybind11;
/// \brief Python module StOptTree exposing the tree classes (Tree, ContinuationValueTree, GridTreeValue, ContinuationCutsTree)
PYBIND11_MODULE(StOptTree, m)
{
    // Library version.
    m.def("getVersion", StOpt::getStOptVersion);

    // NOTE(review): StOpt::SpaceGrid, StOpt::Interpolator and StOpt::InterpolatorSpectral
    // appear in the signatures below but are not registered in this module. They are
    // presumably exposed by another StOpt Python module that must be imported before
    // these methods are used -- confirm.

    // Map the Tree class.
    py::class_<StOpt::Tree, std::shared_ptr<StOpt::Tree> >(m, "Tree")
        .def(py::init<const std::vector<double> &, const std::vector< std::vector< std::array<int, 2> > > &>())
        .def(py::init<>())
        .def("update", &StOpt::Tree::update)
        // The member is selected explicitly through its exact signature
        // (presumably getProba is overloaded in StOpt::Tree).
        .def("getProba", static_cast<std::vector<double> (StOpt::Tree::*)() const>(&StOpt::Tree::getProba))
        .def("getConnected", &StOpt::Tree::getConnected)
        .def("expCond", &StOpt::Tree::expCond)
        .def("expCondMultiple", &StOpt::Tree::expCondMultiple)
        .def("getNbNodes", &StOpt::Tree::getNbNodes)
        .def("getNbNodesNextDate", &StOpt::Tree::getNbNodesNextDate)
        ;

    // Map continuation values.
    // getValueAtNodes / getValueAtANode are each bound twice; pybind11 dispatches on the
    // Python argument type: an Eigen::ArrayXd or a StOpt::Interpolator.
    py::class_<StOpt::ContinuationValueTree, std::shared_ptr<StOpt::ContinuationValueTree> >(m, "ContinuationValueTree")
        .def(py::init<>())
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &, const std::shared_ptr< StOpt::Tree > &, const Eigen::ArrayXXd &>())
        .def("loadForSimulation", &StOpt::ContinuationValueTree::loadForSimulation)
        .def("getValueAtNodes",
             static_cast<Eigen::ArrayXd (StOpt::ContinuationValueTree::*)(const Eigen::ArrayXd &) const>(
                 &StOpt::ContinuationValueTree::getValueAtNodes))
        .def("getValueAtNodes",
             static_cast<Eigen::ArrayXd (StOpt::ContinuationValueTree::*)(const StOpt::Interpolator &) const>(
                 &StOpt::ContinuationValueTree::getValueAtNodes))
        .def("getValueAtANode",
             static_cast<double (StOpt::ContinuationValueTree::*)(const int &, const Eigen::ArrayXd &) const>(
                 &StOpt::ContinuationValueTree::getValueAtANode))
        .def("getValueAtANode",
             static_cast<double (StOpt::ContinuationValueTree::*)(const int &, const StOpt::Interpolator &) const>(
                 &StOpt::ContinuationValueTree::getValueAtANode))
        .def("getValues", &StOpt::ContinuationValueTree::getValues)
        .def("getGrid", &StOpt::ContinuationValueTree::getGrid)
        ;

    // Map grid tree values.
    py::class_<StOpt::GridTreeValue, std::shared_ptr<StOpt::GridTreeValue> >(m, "GridTreeValue")
        .def(py::init<>())
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &, const Eigen::ArrayXXd &>())
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &>())
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &, const std::vector< std::shared_ptr<StOpt::InterpolatorSpectral> > &>())
        .def("getValue", &StOpt::GridTreeValue::getValue)
        .def("getValues", &StOpt::GridTreeValue::getValues)
        .def("getGrid", &StOpt::GridTreeValue::getGrid)
        .def("getInterpolators", &StOpt::GridTreeValue::getInterpolators)
        ;

    // Map continuation cuts.
    py::class_<StOpt::ContinuationCutsTree, std::shared_ptr<StOpt::ContinuationCutsTree> >(m, "ContinuationCutsTree")
        .def(py::init<>())
        // The array is taken by const reference, as in the ContinuationValueTree binding.
        // Binding it by value forces a second copy of the array that pybind11 has
        // already converted from NumPy.
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &, const std::shared_ptr< StOpt::Tree > &, const Eigen::ArrayXXd &>())
        .def("loadForSimulation", &StOpt::ContinuationCutsTree::loadForSimulation)
        .def("getCutsAllNodes", &StOpt::ContinuationCutsTree::getCutsAllNodes)
        .def("getCutsANode", &StOpt::ContinuationCutsTree::getCutsANode)
        .def("getValues", &StOpt::ContinuationCutsTree::getValues)
        .def("getGrid", &StOpt::ContinuationCutsTree::getGrid)
        .def("getNbNodes", &StOpt::ContinuationCutsTree::getNbNodes)
        ;
}
|