File: hera_ws_geom.cpp

package info (click to toggle)
hera 2.0.0%2Bgit20221115.8bfdd4b%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,120 kB
  • sloc: cpp: 30,506; python: 8,986; ansic: 2,088; sh: 49; makefile: 16
file content (100 lines) | stat: -rw-r--r-- 3,126 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <iostream>
#include <tuple>
#include <utility>
#include <vector>


namespace py = pybind11;

#include <hera/wasserstein_pure_geom.hpp>

// Container type produced by the convert_*_points_to_dnn helpers below and
// passed on to hera::ws::wasserstein_cost_detailed.
using DynamicPointVector = hera::ws::dnn::DynamicPointVector<double>;

// Point representations accepted from the caller: a pair for 2D points
// (first -> coordinate 0, second -> coordinate 1) and a 3-tuple for 3D points
// (element k -> coordinate k).
using Pair2dPoint = std::pair<double, double>;
using Tuple3dPoint = std::tuple<double, double, double>;

// Whole point clouds in the caller-facing representation.
using Vector2dPoints = std::vector<Pair2dPoint>;
using Vector3dPoints = std::vector<Tuple3dPoint>;

// Auction parameters / result in double precision. This file reads
// Params::return_matching, writes Params::dim and reads Result::cost.
using Params = hera::AuctionParams<double>;
using Result = hera::AuctionResult<double>;

// Traits object constructed with the point dimension; its container(n) method
// is used to create the DynamicPointVector instances.
using Traits = hera::ws::dnn::DynamicPointTraits<double>;


// Copy a list of 2D points into a DynamicPointVector.
//
// points: (x, y) pairs; `first` is written to coordinate 0 and `second` to
//         coordinate 1 of the output entry with the same index.
// Returns the container obtained from Traits(2).container(points.size()) with
// one entry filled per input point, in input order.
DynamicPointVector convert_2d_points_to_dnn(const Vector2dPoints& points)
{
    constexpr int kDim = 2;
    const std::size_t n_points = points.size();

    Traits point_traits(kDim);
    DynamicPointVector converted = point_traits.container(n_points);

    // Walk the input once, keeping a running index into the output container.
    std::size_t idx = 0;
    for(const auto& pt : points) {
        converted[idx][0] = pt.first;
        converted[idx][1] = pt.second;
        ++idx;
    }

    return converted;
}

// Copy a list of 3D points into a DynamicPointVector.
//
// points: (x, y, z) tuples; tuple element k is written to coordinate k of the
//         output entry with the same index.
// Returns the container obtained from Traits(3).container(points.size()) with
// one entry filled per input point, in input order.
DynamicPointVector convert_3d_points_to_dnn(const Vector3dPoints& points)
{
    constexpr int kDim = 3;
    const std::size_t n_points = points.size();

    Traits point_traits(kDim);
    DynamicPointVector converted = point_traits.container(n_points);

    // Walk the input once, keeping a running index into the output container.
    std::size_t idx = 0;
    for(const auto& pt : points) {
        converted[idx][0] = std::get<0>(pt);
        converted[idx][1] = std::get<1>(pt);
        converted[idx][2] = std::get<2>(pt);
        ++idx;
    }

    return converted;
}

// Run hera's Wasserstein auction on two 2D point clouds and return the full
// result object.
//
// points_1, points_2: point clouds given as (x, y) pairs; they must contain
//                     the same number of points.
// params:             auction parameters, forwarded by non-const reference to
//                     hera::ws::wasserstein_cost_detailed.
//                     params.return_matching must be false.
// prices:             forwarded unchanged to hera::ws::wasserstein_cost_detailed.
//
// Throws std::runtime_error if the clouds differ in size or if
// params.return_matching is set; in both cases a diagnostic is also written to
// std::cerr before throwing.
Result wasserstein_cost_geom_detailed(const Vector2dPoints& points_1, const Vector2dPoints& points_2,  Params& params, const std::vector<double>& prices)
{
    if (points_1.size() != points_2.size()) {
        std::cerr << "points_1.size = " << points_1.size() << " != points_2.size = " << points_2.size() << std::endl;
        throw std::runtime_error("Point clouds must have same cardinality");
    }

    // Reject the unsupported option before converting the inputs, so that no
    // containers are allocated and filled for a call that is going to fail.
    if (params.return_matching) {
        std::cerr << "Matching for point clouds not implemented, need id for DynamicPoint" << std::endl;
        throw std::runtime_error("Matching for point clouds not supported");
    }

    auto dpoints_1 = convert_2d_points_to_dnn(points_1);
    auto dpoints_2 = convert_2d_points_to_dnn(points_2);

    return hera::ws::wasserstein_cost_detailed(dpoints_1, dpoints_2, params, prices);
}

// Convenience wrapper around wasserstein_cost_geom_detailed that returns only
// the `cost` field of the result. Arguments are forwarded unchanged, so the
// same exceptions propagate.
double wasserstein_cost_geom(const Vector2dPoints& points_1, const Vector2dPoints& points_2,  Params& params, const std::vector<double>& prices)
{
    const auto detailed = wasserstein_cost_geom_detailed(points_1, points_2, params, prices);
    return detailed.cost;
}

double wasserstein_cost_geom_no_params(const Vector2dPoints& points_1, const Vector2dPoints& points_2, const std::vector<double>& prices)
{
    hera::AuctionParams<double> params;
    params.dim = 2;

    return wasserstein_cost_geom(points_1, points_2, params, prices);
}


// Register the 2D point-cloud Wasserstein functions on the given module.
//
// "wasserstein_cost_geom_" is registered twice, so it is exposed as an
// overloaded function: first the variant without a params argument, then the
// variant that takes one. The registration order is kept as is.
void init_ws_geom(py::module& m)
{
    // Full result object.
    m.def("wasserstein_cost_geom_detailed_", &wasserstein_cost_geom_detailed);

    // Cost only: (points_1, points_2, prices) and (points_1, points_2, params, prices).
    m.def("wasserstein_cost_geom_", &wasserstein_cost_geom_no_params);
    m.def("wasserstein_cost_geom_", &wasserstein_cost_geom);
}