File: _utils.py

package info (click to toggle)
python-pot 0.9.5+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,884 kB
  • sloc: python: 56,498; cpp: 2,310; makefile: 265; sh: 19
file content (41 lines) | stat: -rw-r--r-- 1,268 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# -*- coding: utf-8 -*-
"""
Common tools of Bregman projections solvers for entropic regularized OT
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#         Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

from ..utils import list_to_array
from ..backend import get_backend


def geometricBar(weights, alldistribT):
    """Return the weighted geometric mean of distributions.

    Computes ``exp(log(alldistribT) @ weights)``: for a 1-D weight vector,
    entry ``i`` of the result is the product over ``j`` of
    ``alldistribT[i, j] ** weights[j]``.

    Parameters
    ----------
    weights : array-like, shape (n_hists,)
        Weight of each distribution, one entry per column of
        ``alldistribT``. Used as given (no normalization is applied here).
    alldistribT : array-like, shape (dim, n_hists)
        Distributions stored column-wise. ``nx.log`` is applied
        element-wise, so entries are presumably expected to be strictly
        positive -- TODO confirm how callers guard against zeros.

    Returns
    -------
    array, shape (dim,)
        The weighted geometric mean for 1-D ``weights`` (not renormalized).

    Raises
    ------
    AssertionError
        If ``len(weights)`` differs from the number of columns of
        ``alldistribT``.
    """
    weights, alldistribT = list_to_array(weights, alldistribT)
    nx = get_backend(weights, alldistribT)
    assert len(weights) == alldistribT.shape[1], (
        "weights must have one entry per column of alldistribT"
    )
    # For the usual 1-D weight vector `.T` is a no-op in NumPy, but PyTorch
    # deprecates `.T` on tensors that are not 2-D (it warns and may raise in
    # future releases). Only transpose when there is something to transpose,
    # which keeps the result identical for every input shape.
    if len(weights.shape) > 1:
        weights = weights.T
    return nx.exp(nx.dot(nx.log(alldistribT), weights))


def geometricMean(alldistribT):
    """Return the (unweighted) geometric mean of distributions.

    Averages ``log(alldistribT)`` along axis 1 and exponentiates the result,
    so each output entry is the geometric mean of the corresponding row.

    Parameters
    ----------
    alldistribT : array-like, shape (dim, n_hists)
        Distributions stored column-wise (axis 1 is the one averaged over).
        Entries are presumably expected to be strictly positive because of
        the element-wise log -- TODO confirm.

    Returns
    -------
    array, shape (dim,)
        Row-wise geometric mean (not renormalized).
    """
    alldistribT = list_to_array(alldistribT)
    nx = get_backend(alldistribT)
    log_hists = nx.log(alldistribT)
    mean_log = nx.mean(log_hists, axis=1)
    return nx.exp(mean_log)


def projR(gamma, p):
    """Return the KL projection of ``gamma`` onto the row constraints.

    Each row ``i`` of ``gamma`` is multiplied by ``p[i]`` and divided by its
    current row sum, so that (for rows with non-negligible mass) row ``i``
    sums to ``p[i]``.

    Parameters
    ----------
    gamma : array-like, shape (n, m)
        Non-negative coupling matrix to rescale.
    p : array-like, expected shape (n,)
        Target row marginal.

    Returns
    -------
    array, shape (n, m)
        Row-rescaled copy of ``gamma``.

    Notes
    -----
    Row sums are clamped from below at 1e-10 to avoid division by zero. An
    all-zero row therefore stays zero, and rows with tiny mass do not reach
    ``p`` exactly.
    """
    gamma, p = list_to_array(gamma, p)
    nx = get_backend(gamma, p)
    row_mass = nx.maximum(nx.sum(gamma, axis=1), 1e-10)
    # Work on the transpose so that `p` and `row_mass` broadcast along rows
    # of the original matrix; then transpose back.
    rescaled_t = gamma.T * p / row_mass
    return rescaled_t.T


def projC(gamma, q):
    """Return the KL projection of ``gamma`` onto the column constraints.

    Each column ``j`` of ``gamma`` is multiplied by ``q[j]`` and divided by
    its current column sum, so that (for columns with non-negligible mass)
    column ``j`` sums to ``q[j]``.

    Parameters
    ----------
    gamma : array-like, shape (n, m)
        Non-negative coupling matrix to rescale.
    q : array-like, expected shape (m,)
        Target column marginal.

    Returns
    -------
    array, shape (n, m)
        Column-rescaled copy of ``gamma``.

    Notes
    -----
    Column sums are clamped from below at 1e-10 to avoid division by zero.
    An all-zero column therefore stays zero.
    """
    gamma, q = list_to_array(gamma, q)
    nx = get_backend(gamma, q)
    col_mass = nx.maximum(nx.sum(gamma, axis=0), 1e-10)
    return gamma * q / col_mass