File: plot_directive.py

package info (click to toggle)
contourpy 1.3.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 16,688 kB
  • sloc: python: 7,998; cpp: 6,241; makefile: 13
file content (92 lines) | stat: -rw-r--r-- 3,484 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from __future__ import annotations

import os
from typing import Any, ClassVar

from docutils import nodes
from docutils.parsers.rst.directives import choice, flag
from sphinx.directives.code import CodeBlock

from contourpy.util.mpl_renderer import MplRenderer


class PlotDirective(CodeBlock):
    """Sphinx ``plot`` directive that executes its content and embeds the resulting images.

    The directive content is Python source. It is executed (once per colour mode) with
    ``MplRenderer.show()`` temporarily replaced, so that every ``show()`` call saves an SVG
    file into a ``generated`` subdirectory next to the current document instead of
    displaying it. The saved images are returned as image nodes. Unless
    ``:source-position: none`` is given, the source itself is also rendered as a code block
    (via :class:`CodeBlock`) above or below the images.

    Options:
        separate-modes: if present, the content is executed twice, first with matplotlib's
            ``default`` style ("light") and then with the ``dark_background`` style ("dark").
            Each image is tagged with an ``only-light`` or ``only-dark`` class.
        source-position: ``"below"`` (default), ``"above"`` or ``"none"``; where to place
            the code block relative to the images.
    """

    has_content = True
    optional_arguments = 2

    option_spec: ClassVar[dict[str, Any]] = {  # type: ignore[misc]
        "separate-modes": flag,
        "source-position": lambda x: choice(x, ("below", "above", "none")),
    }

    # dict of string docname -> latest image index used. Class-level, so it is shared by
    # all plot directives in a document and keeps their image filenames unique.
    latest_image_index: ClassVar[dict[str, int]] = {}

    def _mpl_mode_header(self, mode: str) -> str:
        """Return Python source that configures matplotlib's style for ``mode``.

        ``"light"`` switches to matplotlib's ``default`` style. ``"dark"`` switches to
        ``dark_background`` but keeps the ``axes.prop_cycle`` that was active beforehand.

        Raises:
            ValueError: if ``mode`` is neither ``"light"`` nor ``"dark"``.
        """
        if mode == "light":
            return "import matplotlib.pyplot as plt;plt.style.use('default');\n"
        elif mode == "dark":
            return "import matplotlib as mpl;cycler=mpl.rcParams['axes.prop_cycle'];\n" \
                "import matplotlib.pyplot as plt;plt.style.use('dark_background');\n" \
                "mpl.rcParams['axes.prop_cycle']=cycler;\n"
        else:
            raise ValueError(f"Unexpected mode {mode}")

    def _temporary_show(self, renderer: MplRenderer, image_filenames: list[str]) -> None:
        """Save ``renderer`` to a new SVG file; used in place of ``MplRenderer.show()``.

        The file is written to ``<srcdir>/<docdir>/generated/<docbase>_<index>.svg``.
        ``index`` comes from a per-document counter held in ``latest_image_index``. The
        file's path relative to the document's directory is appended to
        ``image_filenames`` so it can be used directly as an image URI.
        """
        docname = self.env.docname
        index = self.latest_image_index.get(docname, -1) + 1
        self.latest_image_index[docname] = index

        directory, filename = os.path.split(docname)
        # Anchor at the Sphinx source directory rather than the current working directory,
        # because the image URI below is resolved relative to the document's source file.
        output_directory = os.path.join(self.env.srcdir, directory, "generated")
        # exist_ok avoids a check-then-create race when documents are read in parallel.
        os.makedirs(output_directory, exist_ok=True)

        output_filename = f"{filename}_{index}.svg"

        renderer.save(os.path.join(output_directory, output_filename), transparent=True)
        image_filenames.append(os.path.join("generated", output_filename))

    def run(self: Any) -> list[Any]:
        """Execute the directive content and return the image nodes plus, optionally, the source.

        Returns:
            The image nodes, preceded or followed by the nodes produced by
            ``CodeBlock.run()`` according to the ``source-position`` option, or only the
            image nodes when it is ``"none"``.
        """
        source_position = self.options.get("source-position", "below")
        combined_source = "\n".join(self.content)

        using_modes = "separate-modes" in self.options
        modes = ["light", "dark"] if using_modes else ["light"]

        svg_filenames: list[str] = []
        # Mode active when each entry of svg_filenames was written. Tracked explicitly
        # because the content may call show() any number of times per mode, so the mode
        # cannot be derived from an image's position in the list.
        svg_modes: list[str] = []

        # Temporarily replace MplRenderer.show() to save to SVG file and include SVG files in
        # sphinx output. The original is restored in ``finally`` so that an exception raised
        # by the executed content does not leave MplRenderer permanently patched.
        old_show = getattr(MplRenderer, "show")
        setattr(MplRenderer, "show", lambda renderer: self._temporary_show(renderer, svg_filenames))
        try:
            for mode in modes:
                n_before = len(svg_filenames)
                # One fresh dict serves as both globals and locals, so the content runs like
                # a standalone script. Functions, lambdas and generator expressions it defines
                # can see its top-level names, and nothing leaks between modes or into this
                # module.
                namespace: dict[str, Any] = {"__name__": "__main__"}
                exec(self._mpl_mode_header(mode) + combined_source, namespace)
                svg_modes.extend([mode] * (len(svg_filenames) - n_before))
        finally:
            setattr(MplRenderer, "show", old_show)

        images: list[nodes.Node] = []
        for svg_filename, mode in zip(svg_filenames, svg_modes):
            image = nodes.image(uri=svg_filename)
            if using_modes:
                image["classes"].append(f"only-{mode}")
            container = nodes.container()
            container += image
            # NOTE(review): ``list += Element`` iterates the element's children, so this adds
            # the image node itself rather than the container. Kept to preserve current
            # output; use ``images.append(container)`` if a wrapping container is intended.
            images += container

        if source_position == "none":
            return images
        code_block = super().run()
        if source_position == "above":
            return code_block + images
        return images + code_block


def setup(app: Any) -> dict[str, bool]:
    """Sphinx extension entry point: register the ``plot`` directive.

    Returns the extension metadata, which declares the extension safe for both
    parallel reading and parallel writing.
    """
    app.add_directive("plot", PlotDirective)
    metadata = {
        "parallel_read_safe": True,
        "parallel_write_safe": True,
    }
    return metadata