File: engines.py

package info (click to toggle)
python-papermill 2.6.0-3.1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,216 kB
  • sloc: python: 4,977; makefile: 17; sh: 5
file content (449 lines) | stat: -rw-r--r-- 16,013 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
"""Engines to perform different roles"""
import datetime
import sys
from functools import wraps

import dateutil
from importlib.metadata import entry_points

from .clientwrap import PapermillNotebookClient
from .exceptions import PapermillException
from .iorw import write_ipynb
from .log import logger
from .utils import merge_kwargs, nb_kernel_name, nb_language, remove_args


class PapermillEngines:
    """
    Registry of the execution engines known to papermill.

    One module-level instance is shared so that engines can be registered
    under a name (by papermill itself, or by other packages through entry
    points) and later looked up by that name.
    """

    def __init__(self):
        # engine name (``None`` is allowed and used for the default) -> engine object
        self._engines = {}

    def register(self, name, engine):
        """Store ``engine`` under ``name``, replacing any previous entry with that name."""
        self._engines[name] = engine

    def register_entry_points(self):
        """Register every engine advertised in the ``papermill.engine`` entry-point group.

        This is how engines shipped by other installed packages are loaded.
        """
        discovered = entry_points().select(group="papermill.engine")
        for ep in discovered:
            self.register(ep.name, ep.load())

    def get_engine(self, name=None):
        """Return the engine registered under ``name``.

        Raises:
            PapermillException: if no (truthy) engine is registered under ``name``.
        """
        found = self._engines.get(name)
        if found:
            return found
        raise PapermillException(f"No engine named '{name}' found")

    def execute_notebook_with_engine(self, engine_name, nb, kernel_name, **kwargs):
        """Look up ``engine_name`` and run ``nb`` through that engine's ``execute_notebook``."""
        engine = self.get_engine(engine_name)
        return engine.execute_notebook(nb, kernel_name, **kwargs)

    def nb_kernel_name(self, engine_name, nb, name=None):
        """Resolve the kernel name of ``nb`` by delegating to the named engine."""
        engine = self.get_engine(engine_name)
        return engine.nb_kernel_name(nb, name)

    def nb_language(self, engine_name, nb, language=None):
        """Resolve the programming language of ``nb`` by delegating to the named engine."""
        engine = self.get_engine(engine_name)
        return engine.nb_language(nb, language)


def catch_nb_assignment(func):
    """
    Decorate a method so that an ``nb`` keyword argument replaces ``self.nb``.

    Engines may hand callbacks a fresh copy of the notebook being executed.
    When a truthy ``nb`` keyword is supplied, it is stored on ``self`` before
    the wrapped method runs. The keyword is still forwarded to the method.
    """

    @wraps(func)
    def _reassign_then_call(self, *args, **kwargs):
        # Only swap the tracked notebook when a (truthy) replacement was passed.
        if replacement := kwargs.get('nb'):
            self.nb = replacement
        return func(self, *args, **kwargs)

    return _reassign_then_call


class NotebookExecutionManager:
    """
    Wrapper for execution state of a notebook.

    This class is a wrapper for notebook objects to house execution state
    related to the notebook being run through an engine.

    In particular the NotebookExecutionManager provides common update callbacks
    for use within engines to facilitate metadata and persistence actions in a
    shared manner.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"

    def __init__(self, nb, output_path=None, log_output=False, progress_bar=True, autosave_cell_every=30):
        """
        Args:
            nb: Notebook object to track. The callbacks access ``nb.cells`` and
                ``nb.metadata.papermill``.
            output_path: Destination used by ``save``. ``None`` disables writing.
            log_output (bool): Emit "Executing Cell"/"Ending Cell" lines through ``logger``.
            progress_bar (bool | dict): ``True`` shows a tqdm bar with default settings.
                A dict is merged over those defaults and passed to tqdm. A falsy
                value disables the bar.
            autosave_cell_every (int): Minimum seconds between saves triggered by
                ``autosave_cell``. ``0`` disables autosaving.

        Raises:
            TypeError: if ``progress_bar`` is truthy but neither a bool nor a dict.
        """
        self.nb = nb
        self.output_path = output_path
        self.log_output = log_output
        self.start_time = None
        self.end_time = None
        self.autosave_cell_every = autosave_cell_every
        self.max_autosave_pct = 25
        self.last_save_time = self.now()  # Not exactly true, but simplifies testing logic
        self.pbar = None
        if progress_bar:
            # lazy import due to implicit slow ipython import
            from tqdm.auto import tqdm

            if isinstance(progress_bar, bool):
                self.pbar = tqdm(total=len(self.nb.cells), unit="cell", desc="Executing")
            elif isinstance(progress_bar, dict):
                _progress_bar = {"unit": "cell", "desc": "Executing"}
                _progress_bar.update(progress_bar)
                self.pbar = tqdm(total=len(self.nb.cells), **_progress_bar)
            else:
                raise TypeError(
                    f"progress_bar must be instance of bool or dict, but actual type '{type(progress_bar)}'."
                )

    def now(self):
        """Helper to return current UTC time (as a naive datetime)"""
        return datetime.datetime.utcnow()

    def set_timer(self):
        """
        Initializes the execution timer for the notebook.

        Called from ``notebook_start``. The constructor does not call it, so
        ``start_time`` stays ``None`` until execution begins.
        """
        self.start_time = self.now()
        self.end_time = None

    @catch_nb_assignment
    def save(self, **kwargs):
        """
        Saves the wrapped notebook state.

        If an output path is known, this triggers a save of the wrapped
        notebook state to the provided path. The last-save timestamp is
        updated either way.

        Can be used outside of cell state changes if execution is taking
        a long time to conclude but the notebook object should be synced.

        For example, you may want to save the notebook every 10 minutes when running
        a 5 hour cell execution to capture output messages in the notebook.
        """
        if self.output_path:
            write_ipynb(self.nb, self.output_path)
        self.last_save_time = self.now()

    @catch_nb_assignment
    def autosave_cell(self, **kwargs):
        """Saves the notebook if at least self.autosave_cell_every seconds
        have passed since it was last saved.

        If the save itself takes longer than ``max_autosave_pct`` percent of the
        interval, the interval is doubled (exponential back-off) and a warning
        is logged. ``**kwargs`` is accepted, like the other callbacks, so that a
        replacement ``nb`` keyword can be passed without a TypeError.
        """
        if self.autosave_cell_every == 0:
            # feature is disabled
            return
        time_since_last_save = (self.now() - self.last_save_time).total_seconds()
        if time_since_last_save >= self.autosave_cell_every:
            start_save = self.now()
            self.save()
            save_elapsed = (self.now() - start_save).total_seconds()
            if save_elapsed > self.autosave_cell_every * self.max_autosave_pct / 100.0:
                # Autosave is taking too long, so exponentially back off.
                self.autosave_cell_every *= 2
                logger.warning(
                    "Autosave too slow: {:.2f} sec, over {}% limit. Backing off to {} sec".format(
                        save_elapsed, self.max_autosave_pct, self.autosave_cell_every
                    )
                )

    @catch_nb_assignment
    def notebook_start(self, **kwargs):
        """
        Initialize a notebook, clearing its metadata, and save it.

        When starting a notebook, this initializes and clears the metadata for
        the notebook and its cells, and saves the notebook to the given
        output path. Code cells also have their execution counts and outputs
        cleared.

        Called by Engine when execution begins.
        """
        self.set_timer()

        self.nb.metadata.papermill['start_time'] = self.start_time.isoformat()
        self.nb.metadata.papermill['end_time'] = None
        self.nb.metadata.papermill['duration'] = None
        self.nb.metadata.papermill['exception'] = None

        for cell in self.nb.cells:
            # Reset the cell execution counts.
            if cell.get("cell_type") == "code":
                cell.execution_count = None

            # Clear out the papermill metadata for each cell.
            cell.metadata.papermill = dict(
                exception=None,
                start_time=None,
                end_time=None,
                duration=None,
                status=self.PENDING,  # pending, running, completed, failed
            )
            if cell.get("cell_type") == "code":
                cell.outputs = []

        self.save()

    @catch_nb_assignment
    def cell_start(self, cell, cell_index=None, **kwargs):
        """
        Set and save a cell's start state.

        Optionally called by engines during execution to initialize the
        metadata for a cell and save the notebook to the output path. If the
        cell carries a ``papermill_description=`` marker, the progress bar
        description is updated with it.
        """
        if self.log_output:
            cell_num = cell_index + 1 if cell_index is not None else ''
            logger.info(f'Executing Cell {cell_num:-<40}')

        cell.metadata.papermill['start_time'] = self.now().isoformat()
        cell.metadata.papermill["status"] = self.RUNNING
        cell.metadata.papermill['exception'] = False

        # injects optional description of the current cell directly in the tqdm
        cell_description = self.get_cell_description(cell)
        if cell_description is not None and hasattr(self, 'pbar') and self.pbar:
            self.pbar.set_description(f"Executing {cell_description}")

        self.save()

    @catch_nb_assignment
    def cell_exception(self, cell, cell_index=None, **kwargs):
        """
        Set metadata when an exception is raised.

        Called by engines when an exception is raised within a notebook to
        set the metadata on the notebook indicating the location of the
        failure. Does not save; the following ``cell_complete`` does.
        """
        cell.metadata.papermill['exception'] = True
        cell.metadata.papermill['status'] = self.FAILED
        self.nb.metadata.papermill['exception'] = True

    @catch_nb_assignment
    def cell_complete(self, cell, cell_index=None, **kwargs):
        """
        Finalize metadata for a cell and save notebook.

        Optionally called by engines during execution to finalize the
        metadata for a cell and save the notebook to the output path.
        A cell already marked FAILED keeps that status. Otherwise it becomes
        COMPLETED. The progress bar, if any, advances by one.
        """
        end_time = self.now()

        if self.log_output:
            cell_num = cell_index + 1 if cell_index is not None else ''
            logger.info(f'Ending Cell {cell_num:-<43}')
            # Ensure our last cell messages are not buffered by python
            sys.stdout.flush()
            sys.stderr.flush()

        cell.metadata.papermill['end_time'] = end_time.isoformat()
        if cell.metadata.papermill.get('start_time'):
            # Import the submodule explicitly: a bare ``import dateutil`` does not
            # load ``dateutil.parser`` on python-dateutil < 2.9.
            from dateutil import parser as date_parser

            start_time = date_parser.parse(cell.metadata.papermill['start_time'])
            cell.metadata.papermill['duration'] = (end_time - start_time).total_seconds()
        if cell.metadata.papermill['status'] != self.FAILED:
            cell.metadata.papermill['status'] = self.COMPLETED

        self.save()
        if self.pbar:
            self.pbar.update(1)

    @catch_nb_assignment
    def notebook_complete(self, **kwargs):
        """
        Finalize the metadata for a notebook and save the notebook to
        the output path.

        Called by Engine when execution concludes, regardless of exceptions.
        """
        self.end_time = self.now()
        self.nb.metadata.papermill['end_time'] = self.end_time.isoformat()
        # self.start_time is only set by notebook_start (via set_timer). Guard so
        # a manager that was never started does not fail subtracting from None.
        if self.start_time is not None and self.nb.metadata.papermill.get('start_time'):
            self.nb.metadata.papermill['duration'] = (self.end_time - self.start_time).total_seconds()

        # Cleanup cell statuses in case callbacks were never called. Cells after
        # the first failed cell are left untouched.
        for cell in self.nb.cells:
            if cell.metadata.papermill['status'] == self.FAILED:
                break
            elif cell.metadata.papermill['status'] == self.PENDING:
                cell.metadata.papermill['status'] = self.COMPLETED

        self.complete_pbar()
        self.cleanup_pbar()

        # Force a final sync
        self.save()

    def get_cell_description(self, cell, escape_str="papermill_description="):
        """Fetches cell description if present.

        Returns the first whitespace-delimited token that follows the first
        ``escape_str`` in the cell source. Returns None when there is no cell,
        no source, no marker, or nothing after the marker.
        """
        if cell is None:
            return None

        cell_code = cell["source"]
        if cell_code is None or escape_str not in cell_code:
            return None

        tokens = cell_code.split(escape_str)[1].split()
        # A marker with nothing after it (e.g. at the very end of the source)
        # carries no description; previously this raised IndexError.
        return tokens[0] if tokens else None

    def complete_pbar(self):
        """Refresh progress bar, marking every cell as done"""
        if hasattr(self, 'pbar') and self.pbar:
            self.pbar.n = len(self.nb.cells)
            self.pbar.refresh()

    def cleanup_pbar(self):
        """Clean up a progress bar (safe to call repeatedly)"""
        if hasattr(self, 'pbar') and self.pbar:
            self.pbar.close()
            self.pbar = None

    def __del__(self):
        self.cleanup_pbar()


class Engine:
    """
    Base class for papermill execution engines.

    Concrete engines subclass this and implement ``execute_managed_notebook``.
    The public entry point, ``execute_notebook``, wraps the notebook in a
    ``NotebookExecutionManager`` so every engine tracks execution state in the
    same way.
    """

    @classmethod
    def execute_notebook(
        cls,
        nb,
        kernel_name,
        output_path=None,
        progress_bar=True,
        log_output=False,
        autosave_cell_every=30,
        **kwargs,
    ):
        """
        Run ``nb`` with this engine and return the resulting notebook.

        The notebook is wrapped in a ``NotebookExecutionManager`` that is
        started before ``execute_managed_notebook`` is invoked. Whether or not
        that call raises, the manager's progress bar is cleaned up and
        ``notebook_complete`` is called. Any exception then propagates.

        Returns:
            The manager's ``nb`` attribute once execution has finished.
        """
        manager_options = dict(
            output_path=output_path,
            progress_bar=progress_bar,
            log_output=log_output,
            autosave_cell_every=autosave_cell_every,
        )
        manager = NotebookExecutionManager(nb, **manager_options)

        manager.notebook_start()
        try:
            cls.execute_managed_notebook(manager, kernel_name, log_output=log_output, **kwargs)
        finally:
            # Tear down the bar first, then finalize notebook metadata.
            manager.cleanup_pbar()
            manager.notebook_complete()

        return manager.nb

    @classmethod
    def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs):
        """Abstract hook: subclasses execute the cells tracked by ``nb_man`` here."""
        raise NotImplementedError("'execute_managed_notebook' is not implemented for this engine")

    @classmethod
    def nb_kernel_name(cls, nb, name=None):
        """Delegate to the module-level ``nb_kernel_name`` helper to pick the kernel name."""
        # Resolves to the module-level helper, not this classmethod: class
        # scope is not part of a method's name lookup.
        return nb_kernel_name(nb, name)

    @classmethod
    def nb_language(cls, nb, language=None):
        """Delegate to the module-level ``nb_language`` helper to pick the language."""
        return nb_language(nb, language)


class NBClientEngine(Engine):
    """
    Engine that executes notebooks locally through nbclient.

    The work is delegated to ``PapermillNotebookClient``, which updates
    ``nb_man.nb`` with the execution results.
    """

    @classmethod
    def execute_managed_notebook(
        cls,
        nb_man,
        kernel_name,
        log_output=False,
        stdout_file=None,
        stderr_file=None,
        start_timeout=60,
        execution_timeout=None,
        **kwargs,
    ):
        """
        Execute the parameterized notebook locally.

        Args:
            nb_man (NotebookExecutionManager): Wrapper for execution state of a notebook.
            kernel_name (str): Name of kernel to execute the notebook against.
            log_output (bool): Flag for whether or not to write notebook output to the
                               configured logger.
            stdout_file: Forwarded to the client unchanged as ``stdout_file``.
            stderr_file: Forwarded to the client unchanged as ``stderr_file``.
            start_timeout (int): Duration to wait for kernel start-up. Any
                ``startup_timeout`` keyword in ``kwargs`` is dropped in favour of this.
            execution_timeout (int): Duration to wait before failing execution (default: never).
                When falsy, a ``timeout`` keyword from ``kwargs`` is used instead.

        Returns:
            The result of ``PapermillNotebookClient(...).execute()``.
        """
        # ``input_path`` is not consumed by anything downstream.
        kwargs = remove_args(['input_path'], **kwargs)

        # These two are re-supplied below under the values the engine chooses.
        passthrough = remove_args(['timeout', 'startup_timeout'], **kwargs)

        if execution_timeout:
            effective_timeout = execution_timeout
        else:
            effective_timeout = kwargs.get('timeout')

        # Engine-set values are given priority over the pass-through ones.
        client_kwargs = merge_kwargs(
            passthrough,
            timeout=effective_timeout,
            startup_timeout=start_timeout,
            kernel_name=kernel_name,
            log=logger,
            log_output=log_output,
            stdout_file=stdout_file,
            stderr_file=stderr_file,
        )
        client = PapermillNotebookClient(nb_man, **client_kwargs)
        return client.execute()


# Instantiate a PapermillEngines instance, register Handlers and entrypoints.
# This is the module-level registry shared by papermill. NBClientEngine is
# registered both as the default (name ``None``) and under 'nbclient'.
# Entry points are registered last. Because ``register`` overwrites by name,
# an installed plugin that uses an existing name replaces the built-in entry.
papermill_engines = PapermillEngines()
papermill_engines.register(None, NBClientEngine)
papermill_engines.register('nbclient', NBClientEngine)
papermill_engines.register_entry_points()