# Tests for the tie-selection helpers get_all_ties and get_different_ties from sorted_nearest.
from io import StringIO
import numpy as np
import pandas as pd
import pytest
from sorted_nearest import get_all_ties, get_different_ties
@pytest.fixture
def data():
    """Build the DataFrame shared by the tie-handling tests.

    The frame has two integer columns, ``ids`` and ``dist``, and 25 rows on
    the default positional index (0..24).  Rows with the same ``ids`` value
    are contiguous, and ``dist`` is non-decreasing within each such run, with
    several repeated distances so that tie handling is exercised.

    Returns:
        pandas.DataFrame: the parsed whitespace-separated table.
    """
    table = """ids dist
1 1
1 1
1 2
1 3
0 5000
0 5000
0 5000
2 100
2 110
2 110
2 111
3 111
3 111
3 111
3 112
3 112
3 113
4 112
4 113
4 113
4 113
4 113
4 113
4 113
4 150"""
    # Raw string: "\s" is not a valid Python escape sequence, so a plain
    # "\s+" literal triggers an invalid-escape warning on modern Pythons.
    return pd.read_table(StringIO(table), header=0, sep=r"\s+")
def test_get_all_ties(data):
    """Check the row indices returned by ``get_all_ties`` for ``k = 2``.

    Judging by the expected list, each id keeps its two smallest distances
    plus every further row whose distance equals the second one.
    """
    frame = data
    print(frame)
    n_nearest = 2
    kept = get_all_ties(
        frame.index.values,
        frame["ids"].values,
        frame["dist"].values,
        n_nearest,
    )
    # Printed only as a debugging aid when the assertion below fails.
    print(frame.reindex(kept))
    print(kept)
    wanted = [0, 1, 4, 5, 6, 7, 8, 9, 11, 12, 13, 17, 18, 19, 20, 21, 22, 23]
    assert list(kept) == wanted
def test_get_different_ties(data):
    """Check the row indices returned by ``get_different_ties`` for ``k = 2``.

    Judging by the expected list, each id keeps every row whose distance is
    one of its two smallest distinct distance values.
    """
    frame = data
    n_nearest = 2
    kept = get_different_ties(
        frame.index.values,
        frame["ids"].values,
        frame["dist"].values,
        n_nearest,
    )
    # Printed only as a debugging aid when the assertion below fails.
    print(frame.reindex(kept))
    print(kept)
    wanted = [
        0, 1, 2,
        4, 5, 6,
        7, 8, 9,
        11, 12, 13, 14, 15,
        17, 18, 19, 20, 21, 22, 23,
    ]
    assert list(kept) == wanted