File: gallery_scraper.py

package info (click to toggle)
python-vispy 0.15.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 8,868 kB
  • sloc: python: 59,799; javascript: 6,800; makefile: 69; sh: 6
file content (268 lines) | stat: -rw-r--r-- 10,048 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# -*- coding: utf-8 -*-
# Copyright (c) Vispy Development Team. All Rights Reserved.
# Distributed under the (new) BSD License. See LICENSE.txt for more info.
"""Scraper for sphinx-gallery.

This is used to collect screenshots from the examples when executed via
sphinx-gallery. This can be included in any project wanting to take
advantage of this by adding the following to your sphinx ``conf.py``:

.. code-block:: python

    sphinx_gallery_conf = {
        ...
        'image_scrapers': ('vispy',)
    }

The scraper is provided to sphinx-gallery via the
``vispy._get_sg_image_scraper()`` function.

"""

from __future__ import annotations

import os
import time
import shutil
from vispy.io import imsave
from vispy.gloo.util import _screenshot
from vispy.scene import SceneCanvas

from sphinx_gallery.scrapers import optipng, figure_rst


class VisPyGalleryScraper:
    """Custom sphinx-gallery scraper to save the current Canvas to an image.

    Which frames to capture is controlled by a ``# vispy:`` header comment in
    the first 10 lines of each example:

    * ``# vispy: gallery N`` grabs draw-frame ``N``.
    * ``# vispy: gallery start:stop[:step]`` grabs every frame in that range,
      which is saved as a GIF animation.
    * ``# vispy: gallery-exports file1 file2 ...`` uses image files the
      example wrote itself. These files are moved into the gallery.
    """

    def __repr__(self):
        # Stable, human-readable representation: just the class name.
        return self.__class__.__name__

    def __call__(self, block, block_vars, gallery_conf):
        """Scrape VisPy Canvases and applications.

        Parameters
        ----------
        block : tuple
            A tuple containing the (label, content, line_number) of the block.
        block_vars : dict
            Dict of block variables.
        gallery_conf : dict
            Contains the configuration of Sphinx-Gallery

        Returns
        -------
        rst : str
            The reStructuredText that will be rendered to HTML containing
            the images. This is often produced by
            :func:`sphinx_gallery.scrapers.figure_rst`.

        Raises
        ------
        RuntimeError
            If the example requests frames but no canvas or Qt widget can be
            found in its globals.

        """
        example_fn = block_vars["src_file"]
        frame_num_list = self._get_frame_list_from_source(example_fn)
        image_path_iterator = block_vars['image_path_iterator']
        if not frame_num_list:
            image_paths = []
        elif isinstance(frame_num_list[0], str):
            # example produces an image/animation as output
            image_paths = self._move_exported_files(frame_num_list, image_path_iterator)
        else:
            # Only look up the canvas when frames really need grabbing. The
            # lookup may instantiate a ``Canvas`` class from the example's
            # globals, which must not happen for examples with nothing to grab.
            canvas_or_widget = get_canvaslike_from_globals(block_vars["example_globals"])
            if canvas_or_widget is None:
                raise RuntimeError(
                    "Example requested gallery frames but no canvas or Qt "
                    "widget was found in its globals.")
            image_paths = self._save_example_to_files(
                canvas_or_widget, frame_num_list, gallery_conf, image_path_iterator)

        fig_titles = ""  # alt text
        # FUTURE: Handle non-images (ex. MP4s) with raw HTML
        return figure_rst(image_paths, gallery_conf['src_dir'], fig_titles)

    @staticmethod
    def _move_exported_files(frame_paths, image_path_iterator):
        """Move example-written files onto sphinx-gallery image paths.

        Each destination keeps the base name from ``image_path_iterator`` but
        takes its extension from the exported file. The extension of a GIF,
        for example, is preserved.

        Returns
        -------
        list of str
            The destination paths, in the same order as ``frame_paths``.
        """
        image_paths = []
        for frame_image, image_path in zip(frame_paths, image_path_iterator):
            image_path = os.path.splitext(image_path)[0] + os.path.splitext(frame_image)[1]
            shutil.move(frame_image, image_path)
            image_paths.append(image_path)
        return image_paths

    def _save_example_to_files(self, canvas_or_widget, frame_num_list, gallery_conf, image_path_iterator):
        """Capture the requested frames and write them as one PNG or GIF.

        A single requested frame is saved as a still image. Several frames
        are saved as a ``.gif`` animation next to the sphinx-gallery path.

        Returns
        -------
        list of str
            A one-element list containing the written file's path.
        """
        image_path = next(image_path_iterator)
        frame_grabber = FrameGrabber(canvas_or_widget, frame_num_list)
        try:
            frame_grabber.collect_frames()
            if len(frame_num_list) > 1:
                # let's make an animation
                # FUTURE: mp4 with imageio?
                image_path = os.path.splitext(image_path)[0] + ".gif"
                frame_grabber.save_animation(image_path)
            else:
                frame_grabber.save_frame(image_path)
        finally:
            # Close windows even if collection failed, so they do not leak
            # into the next example.
            frame_grabber.cleanup()
        # optipng only optimizes PNGs. Given a GIF it would write an extra
        # (static) .png next to it instead of shrinking the animation.
        is_png = os.path.splitext(image_path)[1].lower() == ".png"
        if is_png and 'images' in gallery_conf['compress_images']:
            optipng(image_path, gallery_conf['compress_images_args'])
        return [image_path]

    def _get_frame_list_from_source(self, filename):
        """Parse the ``# vispy:`` header of ``filename`` into a frame list.

        Returns
        -------
        list
            Ints (frame indices) for ``gallery``, strings (existing file
            paths) for ``gallery-exports``, or ``[]`` when the first 10 lines
            contain no matching hint.
        """
        with open(filename, 'rb') as fid:
            lines = fid.read().decode('utf-8').splitlines()
        for line in lines[:10]:
            if not line.startswith("# vispy:"):
                continue
            # Check "gallery-exports" first: it also contains "gallery".
            if "gallery-exports" in line:
                _frames = line.split('gallery-exports')[1].split(',')[0].strip()
                frames = self._frame_exports_to_list(_frames)
                break
            if "gallery " in line:
                # Get what frames to grab
                _frames = line.split('gallery')[1].split(',')[0].strip()
                frames = self._frame_specifier_to_list(_frames)
                break
        else:
            # no frame number hint - don't grab any frames
            frames = []
        return frames

    def _frame_specifier_to_list(self, frame_specifier):
        """Convert ``"N"`` or ``"start:stop[:step]"`` to a list of frame ints.

        An empty specifier means frame 0.
        """
        # str.split always returns at least one item, so ``frames`` is never empty.
        frames = [int(i) for i in (frame_specifier or '0').split(':')]
        if len(frames) > 1:
            # Range form: expand to every frame index in the range.
            frames = list(range(*frames))
        return frames

    def _frame_exports_to_list(self, frame_specifier):
        """Split a whitespace-separated list of example-written files.

        Raises
        ------
        FileNotFoundError
            If any listed name is not an existing file.
        """
        frame_paths = []
        # split() with no argument tolerates repeated spaces and tabs.
        for frame_fn in frame_specifier.split():
            # existing image file created by the example
            if not os.path.isfile(frame_fn):
                raise FileNotFoundError(
                    "Example gallery frame specifier must be a frame number, "
                    "frame range, or relative filename produced by the example.")
            frame_paths.append(frame_fn)
        return frame_paths


def get_canvaslike_from_globals(globals_dict):
    """Pick the object to screenshot from an executed example's globals.

    A top-level Qt widget, as found by ``_get_qt_top_parent``, wins.
    Otherwise the first of these names present in ``globals_dict`` is used:

    * ``canvas``: returned as-is.
    * ``Canvas``: called with no arguments and the new instance returned.
    * ``fig``: returned as-is.

    Returns ``None`` when nothing suitable is found.
    """
    top_widget = _get_qt_top_parent(globals_dict)
    if top_widget is not None:
        return top_widget

    for candidate_name in ("canvas", "Canvas", "fig"):
        if candidate_name not in globals_dict:
            continue
        candidate = globals_dict[candidate_name]
        # ``Canvas`` is expected to be a class, so build an instance of it.
        return candidate() if candidate_name == "Canvas" else candidate
    return None


def _get_qt_top_parent(globals_dict):
    """Return the most likely top-level Qt widget among ``globals_dict`` values.

    The Qt widget classes are resolved from the ``QWidget`` / ``QMainWindow``
    names in the globals, falling back to attributes of a ``QtWidgets``
    module object. A ``QMainWindow`` instance is preferred over any other
    ``QWidget`` instance. Within each kind, the first one in dict order wins.

    Returns
    -------
    object or None
        The chosen widget, or ``None`` when the example does not appear to
        use Qt widgets or defines no widget instances.
    """
    if "QWidget" not in globals_dict and "QMainWindow" not in globals_dict and "QtWidgets" not in globals_dict:
        return None

    qtwidgets = globals_dict.get("QtWidgets")
    qmainwindow = globals_dict.get("QMainWindow", getattr(qtwidgets, "QMainWindow", None))
    qwidget = globals_dict.get("QWidget", getattr(qtwidgets, "QWidget", qmainwindow))
    if qwidget is None:
        # No widget class could be resolved; isinstance() would raise TypeError.
        return None

    all_qt_widgets = [widget for widget in globals_dict.values() if isinstance(widget, qwidget)]
    # QMainWindow may be unknown (e.g. only ``QWidget`` was imported); in that
    # case skip the main-window preference instead of calling isinstance(x, None).
    if qmainwindow is not None:
        for widget in all_qt_widgets:
            if isinstance(widget, qmainwindow):
                return widget
    if all_qt_widgets:
        return all_qt_widgets[0]
    return None


class FrameGrabber:
    """Helper to grab a series of screenshots from the current Canvas-like object.

    There are two capture paths (see :meth:`collect_frames`):

    * PyQt5 widgets are shown and captured once from the screen.
    * vispy canvases are redrawn repeatedly, and the requested draw-frame
      indices are captured from the draw callback.
    """

    def __init__(self, canvas_obj, frame_grab_list: list[int]):
        """Prepare to capture ``frame_grab_list`` from ``canvas_obj``.

        Parameters
        ----------
        canvas_obj : object
            A PyQt5 ``QWidget``, or a vispy canvas exposing ``events.draw``,
            ``update()``, ``app.process_events()`` and usable as a context
            manager.
        frame_grab_list : list of int
            0-based draw-frame indices to capture. Only the vispy path uses
            this list.
        """
        self._canvas = canvas_obj
        self._done = False
        # Incremented before it is compared, so the first draw is frame 0.
        self._current_frame = -1
        self._collected_images = []
        self._frames_to_grab = frame_grab_list[:]  # copy so original list is preserved

    def cleanup(self):
        """Close all Qt widgets and flush pending Qt events.

        This is a no-op when PyQt5 is not installed or no ``QApplication``
        exists, e.g. when examples run on a non-Qt backend.
        """
        try:
            from PyQt5.QtWidgets import QApplication
        except ImportError:
            return
        if QApplication.instance() is None:
            # No application means there are no widgets to close.
            return
        for child_widget in QApplication.allWidgets():
            if hasattr(child_widget, 'close'):
                child_widget.close()
        QApplication.processEvents()

    def on_draw(self, _):
        """Draw-event callback: capture this frame if it was requested."""
        if self._done:
            return  # every requested frame has already been handled
        self._current_frame += 1
        if self._current_frame in self._frames_to_grab:
            self._frames_to_grab.remove(self._current_frame)
            if isinstance(self._canvas, SceneCanvas):
                im = self._canvas.render(alpha=True)
            else:
                # Non-scene canvas: fall back to vispy.gloo's screenshot helper.
                im = _screenshot()
            self._collected_images.append(im)
        # Finish when nothing is left, or when the next requested frame has
        # already been passed (e.g. a negative index that can never occur).
        if not self._frames_to_grab or self._current_frame > self._frames_to_grab[0]:
            self._done = True

    def collect_frames(self):
        """Show current Canvas and render and collect all frames requested."""
        if self._is_qt_widget():
            self._grab_qt_screenshot()
        else:
            self._grab_vispy_screenshots()

    def _is_qt_widget(self):
        """Return True if the target is a PyQt5 ``QWidget``.

        Returns False if PyQt5 is not installed.
        """
        try:
            from PyQt5.QtWidgets import QWidget
        except ImportError:
            return False
        return isinstance(self._canvas, QWidget)

    def _grab_qt_screenshot(self):
        """Show the widget and schedule a single screen capture of it."""
        from PyQt5.QtWidgets import QApplication
        from PyQt5.QtCore import QTimer
        self._canvas.show()
        # Qt is going to grab from the screen so we need the window on top
        self._canvas.raise_()
        # We need to give the GUI event loop and OS time to draw everything
        time.sleep(1.5)
        QApplication.processEvents()
        # The timer's callback runs from the processEvents() call after the
        # sleep below, by which point its 1 s delay has expired.
        QTimer.singleShot(1000, self._grab_widget_screenshot)
        time.sleep(1.5)
        QApplication.processEvents()

    def _grab_widget_screenshot(self):
        """Grab the widget's native window from its screen as an RGB array."""
        from PyQt5.QtWidgets import QApplication
        screen = QApplication.screenAt(self._canvas.pos())
        if screen is None:
            # screenAt() returns None when the position is outside every
            # screen (e.g. an off-screen or virtual display). Use the primary
            # screen instead of crashing.
            screen = QApplication.primaryScreen()
        screenshot = screen.grabWindow(int(self._canvas.windowHandle().winId()))
        arr = self._qpixmap_to_ndarray(screenshot)
        self._collected_images.append(arr)

    @staticmethod
    def _qpixmap_to_ndarray(pixmap):
        """Convert a ``QPixmap`` to a ``(height, width, 3)`` RGB numpy array."""
        from PyQt5 import QtGui
        import numpy as np
        im = pixmap.toImage().convertToFormat(QtGui.QImage.Format.Format_RGB32)
        size = pixmap.size()
        width = size.width()
        height = size.height()
        im_bits = im.constBits()
        # constBits() returns an unsized pointer; tell it how many bytes to expose.
        im_bits.setsize(height * width * 4)
        # Convert 0xffRRGGBB buffer -> (B, G, R, 0xff) -> (R, G, B)
        return np.array(im_bits).reshape((height, width, 4))[:, :, 2::-1]

    def _grab_vispy_screenshots(self):
        """Hook ``on_draw`` into the canvas and redraw until frames are collected."""
        os.environ['VISPY_IGNORE_OLD_VERSION'] = 'true'
        # 'last' so the capture sees the fully drawn frame.
        self._canvas.events.draw.connect(self.on_draw, position='last')
        with self._canvas as c:
            self._collect_frames(c)

    def _collect_frames(self, canvas, limit=10000):
        """Request redraws and pump events until ``on_draw`` reports done.

        Raises
        ------
        RuntimeError
            If ``limit`` iterations pass first, or requested frames remain.
        """
        n = 0
        while not self._done and n < limit:
            canvas.update()
            canvas.app.process_events()
            n += 1
        if n >= limit or len(self._frames_to_grab) > 0:
            raise RuntimeError("Could not collect any images")

    def save_frame(self, filename, frame_index=0):
        """Write one collected image (the first by default) to ``filename``."""
        imsave(filename, self._collected_images[frame_index])

    def save_animation(self, filename):
        """Write all collected images to ``filename`` as a multi-frame image."""
        import imageio  # multiple gif not properly supported yet
        imageio.mimsave(filename, self._collected_images)