File: Pybind11StOptTree.cpp

package info (click to toggle)
stopt 5.5%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,772 kB
  • sloc: cpp: 70,373; python: 5,942; makefile: 67; sh: 57
file content (83 lines) | stat: -rw-r--r-- 4,153 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// Copyright (C) 2019 EDF
// All Rights Reserved
// This code is published under the GNU Lesser General Public License (GNU LGPL)

/** \file Pybind11StOptTree.cpp
 * \brief Map tree classes to Python
 * \author Xavier Warin
 */

#include <iostream>
#include <memory>
#include <Eigen/Dense>
#include <pybind11/pybind11.h>
#include <pybind11/eigen.h>
#include <pybind11/stl_bind.h>
#include <pybind11/stl.h>
#include "StOpt/core/utils/constant.h"
#include "StOpt/core/utils/version.h"
#include "StOpt/core/grids/SpaceGrid.h"
#include "StOpt/core/grids/FullGrid.h"
#include "StOpt/tree/Tree.h"
#include "StOpt/tree/ContinuationValueTree.h"
#include "StOpt/tree/ContinuationCutsTree.h"
#include "StOpt/tree/GridTreeValue.h"

namespace py = pybind11;


/// \brief Encapsulation for tree module
///
/// Defines the Python extension module `StOptTree`. Every class is bound with
/// a std::shared_ptr holder, consistent with the constructors below that take
/// their SpaceGrid / Tree arguments as `const std::shared_ptr<...> &`.
///
/// Overloaded C++ members are disambiguated with a static_cast to the exact
/// member-function-pointer type. Unlike a C-style cast, a static_cast cannot
/// silently degrade to a reinterpret_cast, so if a signature in the StOpt
/// headers changes, this file fails to compile instead of binding a pointer
/// of the wrong type (undefined behaviour when Python calls it).
PYBIND11_MODULE(StOptTree, m)
{
    m.doc() = "Python bindings for the StOpt tree classes";

    // version of the StOpt library
    m.def("getVersion", &StOpt::getStOptVersion);

    // map the Tree class
    py::class_<StOpt::Tree, std::shared_ptr<StOpt::Tree> >(m, "Tree")
        .def(py::init<const std::vector<double> &, const std::vector< std::vector< std::array<int, 2> > > &>())
        .def(py::init<>())
        .def("update", &StOpt::Tree::update)
        .def("getProba", static_cast<std::vector<double> (StOpt::Tree::*)() const>(&StOpt::Tree::getProba))
        .def("getConnected", &StOpt::Tree::getConnected)
        .def("expCond", &StOpt::Tree::expCond)
        .def("expCondMultiple", &StOpt::Tree::expCondMultiple)
        .def("getNbNodes", &StOpt::Tree::getNbNodes)
        .def("getNbNodesNextDate", &StOpt::Tree::getNbNodesNextDate)
        ;

    // map continuation values
    // getValueAtNodes / getValueAtANode are each exposed as two Python
    // overloads: one taking an Eigen::ArrayXd, one taking an StOpt::Interpolator.
    // pybind11 tries overloads in registration order.
    py::class_<StOpt::ContinuationValueTree, std::shared_ptr<StOpt::ContinuationValueTree> >(m, "ContinuationValueTree")
        .def(py::init<>())
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &, const std::shared_ptr< StOpt::Tree > &, const Eigen::ArrayXXd &>())
        .def("loadForSimulation", &StOpt::ContinuationValueTree::loadForSimulation)
        .def("getValueAtNodes",
             static_cast<Eigen::ArrayXd (StOpt::ContinuationValueTree::*)(const Eigen::ArrayXd &) const>(&StOpt::ContinuationValueTree::getValueAtNodes))
        .def("getValueAtNodes",
             static_cast<Eigen::ArrayXd (StOpt::ContinuationValueTree::*)(const StOpt::Interpolator &) const>(&StOpt::ContinuationValueTree::getValueAtNodes))
        .def("getValueAtANode",
             static_cast<double (StOpt::ContinuationValueTree::*)(const int &, const Eigen::ArrayXd &) const>(&StOpt::ContinuationValueTree::getValueAtANode))
        .def("getValueAtANode",
             static_cast<double (StOpt::ContinuationValueTree::*)(const int &, const StOpt::Interpolator &) const>(&StOpt::ContinuationValueTree::getValueAtANode))
        .def("getValues", &StOpt::ContinuationValueTree::getValues)
        .def("getGrid", &StOpt::ContinuationValueTree::getGrid)
        ;

    // map grid tree values
    py::class_<StOpt::GridTreeValue, std::shared_ptr<StOpt::GridTreeValue> >(m, "GridTreeValue")
        .def(py::init<>())
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &, const Eigen::ArrayXXd &>())
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &>())
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &, const std::vector< std::shared_ptr<StOpt::InterpolatorSpectral> > &>())
        .def("getValue", &StOpt::GridTreeValue::getValue)
        .def("getValues", &StOpt::GridTreeValue::getValues)
        .def("getGrid", &StOpt::GridTreeValue::getGrid)
        .def("getInterpolators", &StOpt::GridTreeValue::getInterpolators)
        ;

    // map continuation cuts
    // The value array is taken by const reference, like the other
    // (grid, tree, values) constructors in this module.
    py::class_<StOpt::ContinuationCutsTree, std::shared_ptr<StOpt::ContinuationCutsTree> >(m, "ContinuationCutsTree")
        .def(py::init<>())
        .def(py::init<const std::shared_ptr< StOpt::SpaceGrid > &, const std::shared_ptr< StOpt::Tree > &, const Eigen::ArrayXXd &>())
        .def("loadForSimulation", &StOpt::ContinuationCutsTree::loadForSimulation)
        .def("getCutsAllNodes", &StOpt::ContinuationCutsTree::getCutsAllNodes)
        .def("getCutsANode", &StOpt::ContinuationCutsTree::getCutsANode)
        .def("getValues", &StOpt::ContinuationCutsTree::getValues)
        .def("getGrid", &StOpt::ContinuationCutsTree::getGrid)
        .def("getNbNodes", &StOpt::ContinuationCutsTree::getNbNodes)
        ;
}