// pybind11 extension module "pyBind": exposes the BaseA / DerA class hierarchy
// (with trampoline classes PyBaseA / PyDerA) to Python.
#include <Eigen/Dense>
#include <iostream>
#include <memory>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/eigen.h>
#include <pybind11/stl_bind.h>
#include <pybind11/stl.h>
#include "StOpt/regression/ContinuationValue.h"
#include "StOpt/core/grids/SpaceGrid.h"
namespace py = pybind11;
/// Abstract interface exposing a single value through get().
///
/// Instances are always handled through derived classes (DerA, or the
/// PyBaseA trampoline used by the Python bindings).
class BaseA
{
public:
    BaseA() = default;

    // Virtual destructor: derived objects are owned and destroyed through
    // BaseA pointers, which is undefined behaviour without it.
    virtual ~BaseA() = default;

    /// \return the value provided by the concrete implementation
    virtual double get() const = 0;
};
class DerA : public BaseA
{
double m_x ;
public:
DerA(const double &x) : m_x(x) {}
virtual double get() const
{
return 2. ;
}
};
// Wrapper (pybind11 "trampoline") for BaseA: it derives from BaseA and routes
// the virtual get() call through pybind11's overload macro so that a get()
// defined by a Python subclass can be invoked from C++.
class PyBaseA : public BaseA
{
public:
// Make BaseA's constructors available on the wrapper.
using BaseA::BaseA;
// BaseA::get is pure virtual, hence the _PURE form of the macro: there is no
// C++ implementation to fall back on when Python does not provide get().
double get() const override
{
PYBIND11_OVERLOAD_PURE(double, BaseA, get);
}
};
class PyDerA : public DerA
{
public:
using DerA::DerA;
double get() const override
{
PYBIND11_OVERLOAD_PURE(double, DerA, get);
}
};
// Definition of the Python module "pyBind".
//
// Exposed classes:
//   BaseA : default constructor and get(); PyBaseA is its trampoline.
//   DerA  : constructor taking a Python list whose first element is converted
//           to double, and get(); PyDerA is its trampoline, BaseA its base.
PYBIND11_MODULE(pyBind, m)
{
    // BaseA uses the same holder type (std::shared_ptr) as DerA: pybind11
    // refuses to register a class with a non-default holder when its
    // registered base uses the default one.
    py::class_< BaseA, std::shared_ptr<BaseA>, PyBaseA >(m, "BaseA")
        .def(py::init<>())
        .def("get", &BaseA::get)
        ;

    py::class_< DerA, std::shared_ptr<DerA>, PyDerA, BaseA >(m, "DerA")
        // List-based constructor (used in place of py::init< double >()).
        // Two factories with the same signature are supplied: the first
        // builds a plain DerA, the second builds the PyDerA trampoline and is
        // selected by pybind11 when the Python object is an instance of a
        // Python subclass, so that overriding get() in Python works.
        // pybind11 takes ownership of the returned raw pointers.
        .def(py::init(
                 [](const py::list &p_x)
                 {
                     return new DerA(p_x[0].cast<double>()) ;
                 },
                 [](const py::list &p_x)
                 {
                     return new PyDerA(p_x[0].cast<double>()) ;
                 }))
        .def("get", &DerA::get)
        ;
}
// End of the pyBind module definition.