1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
|
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>
namespace py = pybind11;
#include <hera/wasserstein_pure_geom.hpp>
// Type aliases used by the 2-D/3-D point-cloud Wasserstein helpers and bindings in this file.
// Container of points in hera's dynamic-dimension point representation (hera::ws::dnn), double precision.
using DynamicPointVector = hera::ws::dnn::DynamicPointVector<double>;
// A 2-D point as an (x, y) pair; .first/.second become coordinates 0/1 on conversion.
using Pair2dPoint = std::pair<double, double>;
// A 3-D point as an (x, y, z) tuple; elements 0/1/2 become coordinates 0/1/2 on conversion.
using Tuple3dPoint = std::tuple<double, double, double>;
// A 2-D point cloud in the form accepted from Python (via pybind11/stl.h conversions).
using Vector2dPoints = std::vector<Pair2dPoint>;
// A 3-D point cloud in the same tuple-based form.
using Vector3dPoints = std::vector<Tuple3dPoint>;
// Auction parameters and result types from hera, specialised for double.
using Params = hera::AuctionParams<double>;
using Result = hera::AuctionResult<double>;
// Traits object for DynamicPointVector; constructed with a dimension and used to allocate containers.
using Traits = hera::ws::dnn::DynamicPointTraits<double>;
/// Copy a list of (x, y) pairs into a new 2-D hera dynamic point container.
///
/// @param points 2-D points; pair.first is written to coordinate 0, pair.second to coordinate 1.
/// @return a container with one point per input, in the same order.
// NOTE(review): assumes Traits::container(n) yields n points of the traits' dimension — hera detail, confirm.
DynamicPointVector convert_2d_points_to_dnn(const Vector2dPoints& points)
{
    constexpr int kDim = 2;
    Traits dnn_traits(kDim);
    DynamicPointVector dnn_points = dnn_traits.container(points.size());

    std::size_t idx = 0;
    for (const Pair2dPoint& pt : points) {
        dnn_points[idx][0] = pt.first;
        dnn_points[idx][1] = pt.second;
        ++idx;
    }
    return dnn_points;
}
/// Copy a list of (x, y, z) tuples into a new 3-D hera dynamic point container.
///
/// @param points 3-D points; tuple elements 0/1/2 are written to coordinates 0/1/2.
/// @return a container with one point per input, in the same order.
// NOTE(review): assumes Traits::container(n) yields n points of the traits' dimension — hera detail, confirm.
DynamicPointVector convert_3d_points_to_dnn(const Vector3dPoints& points)
{
    constexpr int kDim = 3;
    Traits dnn_traits(kDim);
    DynamicPointVector dnn_points = dnn_traits.container(points.size());

    std::size_t idx = 0;
    for (const Tuple3dPoint& pt : points) {
        double x = 0.0;
        double y = 0.0;
        double z = 0.0;
        std::tie(x, y, z) = pt;
        dnn_points[idx][0] = x;
        dnn_points[idx][1] = y;
        dnn_points[idx][2] = z;
        ++idx;
    }
    return dnn_points;
}
/// Run hera's auction on two equal-size 2-D point clouds and return the full result.
///
/// @param points_1 first point cloud, as (x, y) pairs.
/// @param points_2 second point cloud; must have the same number of points as points_1.
/// @param params   auction parameters, forwarded to hera::ws::wasserstein_cost_detailed.
///                 params.return_matching must be false (matchings are not supported here).
/// @param prices   forwarded unchanged to hera::ws::wasserstein_cost_detailed.
/// @return the hera::AuctionResult<double> computed by hera.
/// @throws std::runtime_error if the clouds differ in size or a matching is requested.
///
/// NOTE(review): points are always converted as 2-D, but params.dim is not checked here;
/// presumably callers must pass params.dim == 2 (as wasserstein_cost_geom_no_params does) — confirm.
Result wasserstein_cost_geom_detailed(const Vector2dPoints& points_1, const Vector2dPoints& points_2, Params& params, const std::vector<double>& prices)
{
    if (points_1.size() != points_2.size()) {
        std::cerr << "points_1.size = " << points_1.size() << " != points_2.size = " << points_2.size() << std::endl;
        throw std::runtime_error("Point clouds must have same cardinality");
    }
    // Reject unsupported requests before converting, so a failing call does not
    // first build two DNN point containers only to throw them away.
    if (params.return_matching) {
        // A matching would need a per-point id on DynamicPoint, which is not available.
        std::cerr << "Matching for point clouds not implemented, need id for DynamicPoint" << std::endl;
        throw std::runtime_error("Matching for point clouds not supported");
    }
    auto dpoints_1 = convert_2d_points_to_dnn(points_1);
    auto dpoints_2 = convert_2d_points_to_dnn(points_2);
    return hera::ws::wasserstein_cost_detailed(dpoints_1, dpoints_2, params, prices);
}
double wasserstein_cost_geom(const Vector2dPoints& points_1, const Vector2dPoints& points_2, Params& params, const std::vector<double>& prices)
{
return wasserstein_cost_geom_detailed(points_1, points_2, params, prices).cost;
}
double wasserstein_cost_geom_no_params(const Vector2dPoints& points_1, const Vector2dPoints& points_2, const std::vector<double>& prices)
{
hera::AuctionParams<double> params;
params.dim = 2;
return wasserstein_cost_geom(points_1, points_2, params, prices);
}
/// Register the 2-D point-cloud Wasserstein functions on the given pybind11 module.
///
/// Exposed functions:
///  - wasserstein_cost_geom_detailed_(points_1, points_2, params, prices) -> auction result
///  - wasserstein_cost_geom_(points_1, points_2, prices)         -> float (default params, dim = 2)
///  - wasserstein_cost_geom_(points_1, points_2, params, prices) -> float
///
/// Arguments are named so Python callers may pass them by keyword; positional calls work as before.
/// pybind11 tries overloads in registration order. The two wasserstein_cost_geom_ overloads
/// differ in arity, so they never compete, but keep this order stable regardless.
void init_ws_geom(py::module& m)
{
    m.def("wasserstein_cost_geom_detailed_", &wasserstein_cost_geom_detailed,
          py::arg("points_1"), py::arg("points_2"), py::arg("params"), py::arg("prices"),
          "Auction result for two equal-size 2-D point clouds (matching output is not supported).");
    m.def("wasserstein_cost_geom_", &wasserstein_cost_geom_no_params,
          py::arg("points_1"), py::arg("points_2"), py::arg("prices"),
          "Wasserstein cost between two equal-size 2-D point clouds using default auction parameters.");
    m.def("wasserstein_cost_geom_", &wasserstein_cost_geom,
          py::arg("points_1"), py::arg("points_2"), py::arg("params"), py::arg("prices"),
          "Wasserstein cost between two equal-size 2-D point clouds using the given auction parameters.");
}
|