File: test_parallel.py

package info (click to toggle)
scikit-learn 1.8.0+dfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 26,132 kB
  • sloc: python: 224,867; cpp: 5,790; ansic: 846; makefile: 190; javascript: 179
file content (197 lines) | stat: -rw-r--r-- 7,589 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import itertools
import re
import time
import warnings

import joblib
import numpy as np
import pytest
from numpy.testing import assert_array_equal

from sklearn import config_context, get_config
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.fixes import _IS_WASM
from sklearn.utils.parallel import Parallel, delayed


def get_working_memory():
    """Return the ``working_memory`` entry of the active scikit-learn config."""
    current_config = get_config()
    return current_config["working_memory"]


@pytest.mark.parametrize("n_jobs", [1, 2])
@pytest.mark.parametrize("backend", ["loky", "threading", "multiprocessing"])
def test_configuration_passes_through_to_joblib(n_jobs, backend):
    """Check that the global configuration is passed to the joblib jobs."""
    n_tasks = 2

    with config_context(working_memory=123):
        # Both the Parallel object and the job generator live inside the
        # context so the jobs observe the temporary configuration.
        parallel = Parallel(n_jobs=n_jobs, backend=backend)
        results = parallel(delayed(get_working_memory)() for _ in range(n_tasks))

    assert_array_equal(results, [123] * n_tasks)


def test_parallel_delayed_warnings():
    """Informative warnings should be raised when mixing sklearn and joblib API"""
    n_tasks = 10

    # sklearn.utils.parallel.Parallel combined with joblib.delayed must warn once
    # per task: the configuration cannot be propagated to the workers.
    msg_parallel = "`sklearn.utils.parallel.Parallel` needs to be used in conjunction"
    with pytest.warns(UserWarning, match=msg_parallel) as records:
        Parallel()(joblib.delayed(time.sleep)(0) for _ in range(n_tasks))
    assert len(records) == n_tasks

    # sklearn.utils.parallel.delayed combined with joblib.Parallel must warn once
    # per task as well.
    msg_delayed = (
        "`sklearn.utils.parallel.delayed` should be used with "
        "`sklearn.utils.parallel.Parallel` to make it possible to propagate"
    )
    with pytest.warns(UserWarning, match=msg_delayed) as records:
        joblib.Parallel()(delayed(time.sleep)(0) for _ in range(n_tasks))
    assert len(records) == n_tasks


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_dispatch_config_parallel(n_jobs):
    """Check that we properly dispatch the configuration in parallel processing.

    The pipeline below only succeeds when each intermediate step outputs a
    pandas DataFrame. That requires the ``transform_output="pandas"`` setting
    to reach the jobs run by the column transformer, the random forest and
    the grid search, all of which use ``n_jobs``.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/25239
    """
    pd = pytest.importorskip("pandas")
    iris = load_iris(as_frame=True)

    class TransformerRequiredDataFrame(StandardScaler):
        """StandardScaler that fails unless it receives a pandas DataFrame."""

        def fit(self, X, y=None):
            assert isinstance(X, pd.DataFrame), "X should be a DataFrame"
            return super().fit(X, y)

        def transform(self, X, copy=None):
            assert isinstance(X, pd.DataFrame), "X should be a DataFrame"
            # The second parameter of StandardScaler.transform is `copy`, not
            # `y`: mirror the parent signature and forward it by keyword so a
            # target can never be misrouted into `copy`.
            return super().transform(X, copy=copy)

    dropper = make_column_transformer(
        ("drop", [0]),
        remainder="passthrough",
        n_jobs=n_jobs,
    )
    param_grid = {"randomforestclassifier__max_depth": [1, 2, 3]}
    search_cv = GridSearchCV(
        make_pipeline(
            dropper,
            TransformerRequiredDataFrame(),
            RandomForestClassifier(n_estimators=5, n_jobs=n_jobs),
        ),
        param_grid,
        cv=5,
        n_jobs=n_jobs,
        error_score="raise",  # this search should not fail
    )

    # make sure that `fit` would fail in case we don't request dataframe
    with pytest.raises(AssertionError, match="X should be a DataFrame"):
        search_cv.fit(iris.data, iris.target)

    with config_context(transform_output="pandas"):
        # we expect each intermediate steps to output a DataFrame
        search_cv.fit(iris.data, iris.target)

    assert not np.isnan(search_cv.cv_results_["mean_test_score"]).any()


def raise_warning():
    """Emit a ``ConvergenceWarning``; used as the job body in the tests below."""
    warnings.warn("Convergence warning", category=ConvergenceWarning)


def _yield_n_jobs_backend_combinations():
    """Yield every ``(n_jobs, backend)`` pair used to parametrize tests.

    The ``(2, "loky")`` pair is wrapped in ``pytest.param`` and marked
    ``thread_unsafe`` to avoid
    "RuntimeError: The executor underlying Parallel has been shutdown."
    See https://github.com/joblib/joblib/issues/1743 for more details.
    """
    all_n_jobs = [1, 2]
    all_backends = ["loky", "threading", "multiprocessing"]
    for combination in itertools.product(all_n_jobs, all_backends):
        if combination != (2, "loky"):
            yield combination
            continue
        # XXX thread-unsafe combination, see the docstring.
        yield pytest.param(*combination, marks=pytest.mark.thread_unsafe)


@pytest.mark.parametrize("n_jobs, backend", _yield_n_jobs_backend_combinations())
def test_filter_warning_propagates(n_jobs, backend):
    """Check that a warning filter set by the caller is honoured inside the jobs."""
    with warnings.catch_warnings():
        # Escalate ConvergenceWarning to an exception in the caller; if the
        # filter reaches the jobs, raise_warning makes Parallel raise it.
        warnings.simplefilter("error", category=ConvergenceWarning)

        with pytest.raises(ConvergenceWarning):
            parallel = Parallel(n_jobs=n_jobs, backend=backend)
            parallel(delayed(raise_warning)() for _ in range(2))


def get_warning_filters():
    """Return the warnings filters that apply to the calling context.

    In free-threading Python >= 3.14, warnings filters are managed through a
    ContextVar and ``warnings.filters`` is not modified inside a
    ``warnings.catch_warnings`` context; ``warnings._get_filters()`` must be
    used instead when it exists. For more details, see
    https://docs.python.org/3.14/whatsnew/3.14.html#concurrent-safe-warnings-control
    """
    try:
        context_getter = warnings._get_filters
    except AttributeError:
        # Older Pythons: the module-level list is the source of truth.
        return warnings.filters
    return context_getter()


def test_check_warnings_threading():
    """Check that warnings filters are set correctly in the threading backend."""

    def normalize_main_module(filters):
        """Return ``filters`` as a list with ``__main__`` module regexes unwrapped.

        In Python 3.14 free-threaded, there is a small discrepancy: the main
        warning filters have an entry with module = "__main__" whereas it is
        a regex in the workers. Replacing such regexes by their pattern string
        makes both sides comparable.
        """
        normalized = []
        for action, message, type_, module, lineno in filters:
            if isinstance(module, re.Pattern) and "__main__" in str(module):
                module = module.pattern
            normalized.append((action, message, type_, module, lineno))
        return normalized

    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        main_warning_filters = get_warning_filters()

        assert ("error", None, ConvergenceWarning, None, 0) in main_warning_filters

        all_worker_warning_filters = Parallel(n_jobs=2, backend="threading")(
            delayed(get_warning_filters)() for _ in range(2)
        )

        expected_filters = normalize_main_module(main_warning_filters)
        for worker_warning_filters in all_worker_warning_filters:
            # Normalize the *worker* filters so that each worker is actually
            # compared against the caller's filters.
            assert normalize_main_module(worker_warning_filters) == expected_filters


@pytest.mark.xfail(_IS_WASM, reason="Pyodide always use the sequential backend")
def test_filter_warning_propagates_no_side_effect_with_loky_backend():
    """Check that sklearn's warning filters do not leak into reused loky workers."""
    n_tasks = 10
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        # First run sklearn's Parallel while ConvergenceWarning is an error.
        Parallel(n_jobs=2, backend="loky")(
            delayed(time.sleep)(0) for _ in range(n_tasks)
        )

        # Since loky workers are reused, the filters used by the previous call
        # must have been reset to their original value inside the workers:
        # plain joblib must not turn ConvergenceWarning into an error.
        warn_tasks = (
            joblib.delayed(warnings.warn)("Convergence warning", ConvergenceWarning)
            for _ in range(n_tasks)
        )
        joblib.Parallel(n_jobs=2, backend="loky")(warn_tasks)