File: allow_nan_estimators.py

package info (click to toggle)
scikit-learn 1.7.2%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 25,616 kB
  • sloc: python: 219,123; cpp: 5,790; ansic: 846; makefile: 172; javascript: 110
file content (58 lines) | stat: -rw-r--r-- 2,187 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from contextlib import suppress

from docutils import nodes
from docutils.parsers.rst import Directive

from sklearn.utils import all_estimators
from sklearn.utils._test_common.instance_generator import _construct_instances
from sklearn.utils._testing import SkipTest


class AllowNanEstimators(Directive):
    """Docutils directive listing the estimators whose input tags allow NaN.

    The rendered output is one bullet list with an entry per estimator type;
    each entry carries a nested bullet list of links to the estimator pages.
    """

    @staticmethod
    def make_paragraph_for_estimator_type(estimator_type):
        """Build the list item describing a single estimator type.

        Parameters
        ----------
        estimator_type : str
            Forwarded unchanged to ``all_estimators(type_filter=...)`` and
            shown in the heading of the generated entry.

        Returns
        -------
        list or None
            A one-element list holding the ``list_item`` node (heading plus
            nested bullet list of links), or ``None`` when no estimator of
            this type reports ``allow_nan``.
        """
        heading = nodes.list_item()
        heading.append(nodes.strong(text="Estimators that allow NaN values for type "))
        heading.append(nodes.literal(text=f"{estimator_type}"))
        heading.append(nodes.strong(text=":\n"))

        links = nodes.bullet_list()
        found_any = False
        for _name, est_cls in all_estimators(type_filter=estimator_type):
            # A SkipTest raised anywhere while handling this estimator makes
            # the loop move on to the next one without emitting an entry.
            with suppress(SkipTest):
                # Only the first generated instance is inspected, so this
                # directive should not be used for meta-estimators whose tags
                # depend on the wrapped sub-estimator.
                instance = next(_construct_instances(est_cls))
                if not instance.__sklearn_tags__().input_tags.allow_nan:
                    continue

                # Keep only the first two components of the dotted module path
                # when building the link target.
                module_path = ".".join(est_cls.__module__.split(".")[:2])
                title = est_cls.__name__
                url = f"./generated/{module_path}.{title}.html"

                link_para = nodes.paragraph()
                link_para.append(
                    nodes.reference(title, text=title, internal=False, refuri=url)
                )
                entry = nodes.list_item()
                entry.append(link_para)
                links.append(entry)
                found_any = True

        heading.append(links)
        return [heading] if found_any else None

    def run(self):
        """Return a single bullet list with one entry per estimator type.

        Estimator types for which no entry could be built are left out.
        """
        overview = nodes.bullet_list()
        for estimator_type in ("cluster", "regressor", "classifier", "transformer"):
            entries = self.make_paragraph_for_estimator_type(estimator_type)
            if entries is None:
                continue
            overview.extend(entries)
        return [overview]


def setup(app):
    """Register the ``allow_nan_estimators`` directive on ``app``.

    Returns a metadata dictionary giving the extension version and flagging
    it as safe for parallel reading and parallel writing.
    """
    app.add_directive("allow_nan_estimators", AllowNanEstimators)

    metadata = dict(version="0.1", parallel_read_safe=True, parallel_write_safe=True)
    return metadata