File: test_beat.py

"""
Unit tests for mir_eval.beat
"""

import numpy as np
import json
import mir_eval
import glob
import pytest

A_TOL = 1e-12

# Paths to the fixture files
REF_GLOB = "data/beat/ref*.txt"
EST_GLOB = "data/beat/est*.txt"
SCORES_GLOB = "data/beat/output*.json"

ref_files = sorted(glob.glob(REF_GLOB))
est_files = sorted(glob.glob(EST_GLOB))
sco_files = sorted(glob.glob(SCORES_GLOB))

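# Sanity check: fail fast if the fixture files are missing or mismatched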
assert len(ref_files) == len(est_files) == len(sco_files) > 0

file_sets = list(zip(ref_files, est_files, sco_files))


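# Each (reference, estimate, scores) file triple is passed to this fixture
# through pytest's indirect parametrization (see test_beat_functions below).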
@pytest.fixture
def beat_data(request):
    ref_f, est_f, sco_f = request.param
    with open(sco_f) as f:
        expected_scores = json.load(f)
    reference_beats = mir_eval.io.load_events(ref_f)
    estimated_beats = mir_eval.io.load_events(est_f)

    return reference_beats, estimated_beats, expected_scores


def test_trim_beats():
    # Construct dummy beat times [0., 1., ..., 9.]
    dummy_beats = np.arange(10, dtype=np.float64)
    # With the default min_beat_time of 5 seconds, trim_beats should remove
    # all beats before 5s
    expected_beats = np.arange(5, 10, dtype=np.float64)
    assert np.allclose(mir_eval.beat.trim_beats(dummy_beats), expected_beats)
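

# A minimal companion check, assuming trim_beats exposes its documented
# min_beat_time keyword (default 5.0); the 3-second threshold below is
# arbitrary and only for illustration.
def test_trim_beats_custom_threshold():
    dummy_beats = np.arange(10, dtype=np.float64)
    # With min_beat_time=3.0, all beats before 3s should be dropped
    expected_beats = np.arange(3, 10, dtype=np.float64)
    assert np.allclose(
        mir_eval.beat.trim_beats(dummy_beats, min_beat_time=3.0), expected_beats
    )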


@pytest.mark.parametrize(
    "metric",
    [
        mir_eval.beat.f_measure,
        mir_eval.beat.cemgil,
        mir_eval.beat.goto,
        mir_eval.beat.p_score,
        mir_eval.beat.continuity,
        mir_eval.beat.information_gain,
    ],
)
def test_beat_empty_warnings(metric):
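    # Metrics should warn rather than raise on empty inputs; when both
    # reference and estimate are empty, the score is expected to be 0.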
    with pytest.warns(UserWarning, match="Reference beats are empty."):
        metric(np.array([]), np.arange(10))
    with pytest.warns(UserWarning, match="Estimated beats are empty."):
        metric(np.arange(10), np.array([]))
    with pytest.warns(UserWarning, match="beats are empty."):
        assert np.allclose(metric(np.array([]), np.array([])), 0)


@pytest.mark.parametrize(
    "metric",
    [
        mir_eval.beat.f_measure,
        mir_eval.beat.cemgil,
        mir_eval.beat.goto,
        mir_eval.beat.p_score,
        mir_eval.beat.continuity,
        mir_eval.beat.information_gain,
    ],
)
@pytest.mark.parametrize(
    "beats",
    [
        np.array([[1.0, 2.0]]),  # beats must be a 1d array
        np.array([1e10, 1e11]),  # beat times must not be absurdly large
        np.array([2.0, 1.0]),  # beats must be sorted
    ],
)
@pytest.mark.xfail(raises=ValueError)
def test_beat_fail(metric, beats):
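    # Each metric delegates input checking to mir_eval.beat.validate, which
    # should raise ValueError for each of the malformed inputs above.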
    metric(beats, beats)


@pytest.mark.parametrize(
    "metric",
    [
        mir_eval.beat.f_measure,
        mir_eval.beat.cemgil,
        mir_eval.beat.goto,
        mir_eval.beat.p_score,
        mir_eval.beat.continuity,
        mir_eval.beat.information_gain,
    ],
)
def test_beat_perfect(metric):
    beats = np.arange(10, dtype=np.float64)
    assert np.allclose(metric(beats, beats), 1)


@pytest.mark.parametrize("beat_data", file_sets, indirect=True)
def test_beat_functions(beat_data):
    reference_beats, estimated_beats, expected_scores = beat_data

    # Compute all metrics at once; evaluate() returns a dict mapping metric
    # names (e.g. "F-measure") to scores
    scores = mir_eval.beat.evaluate(reference_beats, estimated_beats)
    # Compare them against the expected scores loaded from the JSON fixture
    assert scores.keys() == expected_scores.keys()
    for metric in scores:
        assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL)


# Unit tests for specific behavior not covered by the above
def test_goto_proportion_correct():
    # This covers the case where over 75% of the beat tracking is correct
    # but more than 3 beats are incorrect
    assert mir_eval.beat.goto(
        np.arange(100), np.append(np.arange(80), np.arange(80, 100) + 0.2)
    )
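    # goto() is binary: it should return 1.0 when the tracking meets Goto's
    # criteria and 0.0 otherwise, so truthiness is all we assert here.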


@pytest.mark.parametrize(
    "metric",
    [mir_eval.beat.p_score, mir_eval.beat.continuity, mir_eval.beat.information_gain],
)
def test_warning_on_one_beat(metric):
    # For these metrics, passing only a single reference or estimated beat
    # should issue a warning and return 0

    with pytest.warns(UserWarning, match="Only one reference beat"):
        metric(np.array([10]), np.arange(10))
    with pytest.warns(UserWarning, match="Only one estimated beat"):
        metric(np.arange(10), np.array([10]))


def test_continuity_edge_cases():
    # continuity() has special-case logic for inputs with very few beats;
    # these degenerate cases should all score 0.0
    assert np.allclose(
        mir_eval.beat.continuity(np.array([6.0, 6.0]), np.array([6.0, 7.0])), 0.0
    )
    assert np.allclose(
        mir_eval.beat.continuity(np.array([6.0, 6.0]), np.array([6.5, 7.0])), 0.0
    )