# Author: Brian M. Clapper, G. Varoquaux, Lars Buitinck
# License: BSD
from numpy.testing import assert_array_equal, assert_raises
import numpy as np
from scipy.optimize import linear_sum_assignment
def test_linear_sum_assignment():
    """Check assigned costs on known matrices and on their transposes.

    Every case pairs a cost matrix with the cost expected for each
    assigned row, listed in row order.
    """
    cases = [
        # Square
        ([[400, 150, 400],
          [400, 450, 600],
          [300, 225, 300]],
         [150, 400, 300]),
        # Rectangular variant
        ([[400, 150, 400, 1],
          [400, 450, 600, 2],
          [300, 225, 300, 3]],
         [150, 2, 300]),
        # Square
        ([[10, 10, 8],
          [9, 8, 1],
          [9, 7, 4]],
         [10, 1, 7]),
        # Rectangular variant
        ([[10, 10, 8, 11],
          [9, 8, 1, 1],
          [9, 7, 4, 10]],
         [10, 1, 4]),
        # n == 2, m == 0 matrix
        ([[], []],
         []),
    ]

    def assigned_costs(costs):
        # Solve the assignment, insist that the row indices come back in
        # ascending order, and return the cost picked for each pair.
        rows, cols = linear_sum_assignment(costs)
        assert_array_equal(rows, np.sort(rows))
        return costs[rows, cols]

    for raw_costs, expected_cost in cases:
        costs = np.array(raw_costs)

        # Original orientation: costs must match in row order.
        assert_array_equal(expected_cost, assigned_costs(costs))

        # Transposed orientation: only the multiset of selected costs is
        # compared, so both sides are sorted first.
        assert_array_equal(np.sort(expected_cost),
                           np.sort(assigned_costs(costs.T)))
def test_linear_sum_assignment_input_validation():
    """Check rejection of a flat list and agreement across input types."""
    # A flat list of numbers must raise ValueError.
    assert_raises(ValueError, linear_sum_assignment, [1, 2, 3])

    nested = [[1, 2, 3], [4, 5, 6]]
    # The nested list must give the same assignment as the equivalent
    # ndarray and np.matrix built from it.
    for convert in (np.asarray, np.matrix):
        assert_array_equal(linear_sum_assignment(nested),
                           linear_sum_assignment(convert(nested)))
|