1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
|
from contextlib import suppress
from docutils import nodes
from docutils.parsers.rst import Directive
from sklearn.utils import all_estimators
from sklearn.utils._testing import SkipTest
from sklearn.utils.estimator_checks import _construct_instance
class AllowNanEstimators(Directive):
@staticmethod
def make_paragraph_for_estimator_type(estimator_type):
intro = nodes.list_item()
intro += nodes.strong(text="Estimators that allow NaN values for type ")
intro += nodes.literal(text=f"{estimator_type}")
intro += nodes.strong(text=":\n")
exists = False
lst = nodes.bullet_list()
for name, est_class in all_estimators(type_filter=estimator_type):
with suppress(SkipTest):
est = _construct_instance(est_class)
if est._get_tags().get("allow_nan"):
module_name = ".".join(est_class.__module__.split(".")[:2])
class_title = f"{est_class.__name__}"
class_url = f"./generated/{module_name}.{class_title}.html"
item = nodes.list_item()
para = nodes.paragraph()
para += nodes.reference(
class_title, text=class_title, internal=False, refuri=class_url
)
exists = True
item += para
lst += item
intro += lst
return [intro] if exists else None
def run(self):
lst = nodes.bullet_list()
for i in ["cluster", "regressor", "classifier", "transformer"]:
item = self.make_paragraph_for_estimator_type(i)
if item is not None:
lst += item
return [lst]
def setup(app):
app.add_directive("allow_nan_estimators", AllowNanEstimators)
return {
"version": "0.1",
"parallel_read_safe": True,
"parallel_write_safe": True,
}
|