"""
Some unit tests for the pattern discovery task.
"""
import glob
import json

import numpy as np
import pytest

import mir_eval

A_TOL = 1e-12

# Paths to the fixture files
REF_GLOB = "data/pattern/ref*.txt"
EST_GLOB = "data/pattern/est*.txt"
SCORES_GLOB = "data/pattern/output*.json"
ref_files = sorted(glob.glob(REF_GLOB))
est_files = sorted(glob.glob(EST_GLOB))
sco_files = sorted(glob.glob(SCORES_GLOB))
assert len(ref_files) == len(est_files) == len(sco_files) > 0
file_sets = list(zip(ref_files, est_files, sco_files))
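# NB: the globs above are relative, so these tests assume they run from the
# directory that contains data/pattern (e.g., the test suite's root folder).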


@pytest.fixture
def pattern_data(request):
    ref_f, est_f, sco_f = request.param
    with open(sco_f) as f:
        expected_scores = json.load(f)
    reference_patterns = mir_eval.io.load_patterns(ref_f)
    estimated_patterns = mir_eval.io.load_patterns(est_f)
    return reference_patterns, estimated_patterns, expected_scores
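

# Note on data layout: judging from the loader and the toy inputs below, a
# pattern corpus is a list of patterns, each pattern is a list of occurrences,
# and each occurrence is a list of (onset, midi) tuples. So [[[(100, 20)]]]
# is a corpus with one pattern that has one single-note occurrence.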


@pytest.mark.parametrize(
    "metric",
    [
        mir_eval.pattern.standard_FPR,
        mir_eval.pattern.establishment_FPR,
        mir_eval.pattern.occurrence_FPR,
        mir_eval.pattern.three_layer_FPR,
        mir_eval.pattern.first_n_three_layer_P,
        mir_eval.pattern.first_n_target_proportion_R,
    ],
)
def test_pattern_empty(metric):
    # First, test for a warning on empty patterns
    with pytest.warns(UserWarning, match="Reference patterns are empty"):
        metric([[[]]], [[[(100, 20)]]])
    with pytest.warns(UserWarning, match="Estimated patterns are empty"):
        metric([[[(100, 20)]]], [[[]]])
    with pytest.warns(UserWarning, match="patterns are empty"):
        # ...and that the metric is 0 when both sides are empty
        assert np.allclose(metric([[[]]], [[[]]]), 0)
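

# Note: [[[]]] above is a corpus holding one pattern with one empty
# occurrence; by convention the metrics warn rather than raise when either
# side is empty, and return 0 when both are.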


@pytest.mark.parametrize(
    "metric",
    [
        mir_eval.pattern.standard_FPR,
        mir_eval.pattern.establishment_FPR,
        mir_eval.pattern.occurrence_FPR,
        mir_eval.pattern.three_layer_FPR,
        mir_eval.pattern.first_n_three_layer_P,
        mir_eval.pattern.first_n_target_proportion_R,
    ],
)
@pytest.mark.parametrize(
    "patterns",
    [
        [[[(100, 20)]], []],  # patterns must have at least one occurrence
        [[[(100, 20, 3)]]],  # an (onset, midi) tuple must contain exactly 2 elements
    ],
)
@pytest.mark.xfail(raises=ValueError)
def test_pattern_failure(metric, patterns):
    metric(patterns, patterns)
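

# For readers less familiar with the xfail(raises=...) idiom above, the same
# validation check can be written explicitly with pytest.raises. A minimal
# sketch (the test name and variable are illustrative; the malformed
# 3-element tuple is expected to trigger ValueError, as asserted above):
def test_pattern_failure_explicit():
    bad_patterns = [[[(100, 20, 3)]]]
    with pytest.raises(ValueError):
        mir_eval.pattern.standard_FPR(bad_patterns, bad_patterns)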


@pytest.mark.parametrize(
    "metric",
    [
        mir_eval.pattern.standard_FPR,
        mir_eval.pattern.establishment_FPR,
        mir_eval.pattern.occurrence_FPR,
        mir_eval.pattern.three_layer_FPR,
        mir_eval.pattern.first_n_three_layer_P,
        mir_eval.pattern.first_n_target_proportion_R,
    ],
)
def test_pattern_perfect(metric):
    # Identical valid patterns should produce a score of 1 for all metrics
    patterns = [[[(100, 20), (200, 30)]]]
    assert np.allclose(metric(patterns, patterns), 1)
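

# mir_eval.pattern.evaluate computes all of the metrics above at once and
# returns them keyed by name; the regression test below compares its output
# against scores pre-computed and stored alongside the fixture files.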
@pytest.mark.parametrize("pattern_data", file_sets, indirect=True)
def test_pattern_functions(pattern_data):
reference_patterns, estimated_patterns, expected_scores = pattern_data
# Compute scores
scores = mir_eval.pattern.evaluate(reference_patterns, estimated_patterns)
# Compare them
assert scores.keys() == expected_scores.keys()
for metric in scores:
assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL)
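

# For illustration (not executed): the regression check above is equivalent
# to loading a fixture pair by hand and calling evaluate directly, e.g.:
#
#     ref = mir_eval.io.load_patterns(ref_files[0])
#     est = mir_eval.io.load_patterns(est_files[0])
#     scores = mir_eval.pattern.evaluate(ref, est)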