File: utils.py

package info (click to toggle)
python-papermill 2.6.0-3.1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,216 kB
  • sloc: python: 4,977; makefile: 17; sh: 5
file content (192 lines) | stat: -rw-r--r-- 4,742 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import logging
import os
import warnings
from contextlib import contextmanager
from functools import wraps

from .exceptions import PapermillParameterOverwriteWarning

# Module-level logger for this helpers module; retry() uses it to record each
# failure that triggers another attempt.
logger = logging.getLogger('papermill.utils')


def any_tagged_cell(nb, tag):
    """Return whether at least one cell in the notebook is tagged ``tag``.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        The notebook to introspect
    tag : str
        The tag to look for

    Returns
    -------
    bool
        ``True`` if some cell's ``metadata.tags`` contains ``tag``, else
        ``False``. Cells whose metadata has no ``tags`` entry count as untagged.
    """
    # A generator lets any() stop at the first match instead of materializing a
    # list for every cell; getattr() tolerates cells without a ``tags`` entry.
    return any(tag in getattr(cell.metadata, 'tags', ()) for cell in nb.cells)


def nb_kernel_name(nb, name=None):
    """Return the kernel name to use for a notebook.

    An explicitly supplied ``name`` takes precedence; otherwise the name is
    read from the notebook's ``kernelspec`` metadata.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        The notebook to introspect
    name : str
        A provided name field that overrides the notebook metadata

    Returns
    -------
    str
        The name of the kernel

    Raises
    ------
    ValueError
        If neither ``name`` nor the notebook's kernelspec yields a kernel name
    """
    if not name:
        kernelspec = nb.metadata.get('kernelspec', {})
        name = kernelspec.get('name')

    if name:
        return name
    raise ValueError("No kernel name found in notebook and no override provided.")


def nb_language(nb, language=None):
    """Return the programming language of a notebook.

    An explicitly supplied ``language`` wins. Otherwise ``language_info.name``
    from the notebook metadata is used, falling back to ``kernelspec.language``.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        The notebook to introspect
    language : str
        A provided language field that overrides the notebook metadata

    Returns
    -------
    str
        The programming language of the notebook

    Raises
    ------
    ValueError
        If no notebook language is found or provided
    """
    if not language:
        metadata = nb.metadata
        # kernelspec.language is the v3 path for old notebooks that didn't
        # convert cleanly, so it is only consulted when language_info is empty.
        language = (
            metadata.get('language_info', {}).get('name')
            or metadata.get('kernelspec', {}).get('language')
        )

    if not language:
        raise ValueError("No language found in notebook and no override provided.")
    return language


def find_first_tagged_cell_index(nb, tag):
    """Find the index of the first cell tagged ``tag`` in the notebook.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        The notebook to introspect
    tag : str
        The tag to look for

    Returns
    -------
    int
        Position in ``nb.cells`` of the first cell whose ``metadata.tags``
        contains ``tag``, or ``-1`` if no cell carries it. Cells whose metadata
        has no ``tags`` entry count as untagged.
    """
    # Return on the first hit rather than collecting every matching index.
    for idx, cell in enumerate(nb.cells):
        if tag in getattr(cell.metadata, 'tags', ()):
            return idx
    return -1


def merge_kwargs(caller_args, **callee_args):
    """Merge caller arguments with callee keyword arguments.

    Returns a new dictionary built from ``caller_args`` and updated with
    ``callee_args``. If the same argument appears in both, the callee's value
    wins and a ``PapermillParameterOverwriteWarning`` is emitted that names
    only the overwritten argument(s).

    Parameters
    ----------
    caller_args : dict
        Caller arguments
    **callee_args
        Keyword callee arguments; these take precedence on conflicts

    Returns
    -------
    args : dict
       Merged arguments (a new dict; ``caller_args`` is not modified)
    """
    conflicts = set(caller_args) & set(callee_args)
    if conflicts:
        # Report just the arguments being overwritten, in callee order, rather
        # than every callee argument (which would include non-conflicting ones).
        args = '; '.join(f'{key}={value}' for key, value in callee_args.items() if key in conflicts)
        msg = f"Callee will overwrite caller's argument(s): {args}"
        warnings.warn(msg, PapermillParameterOverwriteWarning)
    return dict(caller_args, **callee_args)


def remove_args(args=None, **kwargs):
    """Return ``kwargs`` without the keys listed in ``args``.

    Parameters
    ----------
    args : list
        Argument names to drop; when empty or ``None`` nothing is removed
    **kwargs
        Arbitrary keyword arguments

    Returns
    -------
    kwargs : dict
       New dictionary of arguments
    """
    if args:
        return {name: val for name, val in kwargs.items() if name not in args}
    return kwargs


def retry(num):
    """Build a decorator that calls the wrapped function up to ``num`` times.

    Any ``Exception`` raised by the wrapped function triggers another attempt
    until ``num`` attempts have been made; the exception from the final
    attempt then propagates to the caller unchanged.

    Parameters
    ----------
    num : int
        Total number of attempts (must be at least 1)

    Returns
    -------
    callable
        A decorator that wraps a function with the retry behavior

    Raises
    ------
    ValueError
        If ``num`` is less than 1. Previously this surfaced only at call time
        as an opaque ``TypeError`` from ``raise None``.
    """
    if num < 1:
        raise ValueError(f"retry() requires at least one attempt, got {num!r}")

    def decorate(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, num + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == num:
                        # Out of attempts: re-raise with the original traceback.
                        raise
                    # Lazy %-formatting; only logged when another try follows.
                    logger.debug('Retrying after: %s', e)

        return wrapper

    return decorate


@contextmanager
def chdir(path):
    """Temporarily switch the working directory to `path`.

    The previous working directory is restored on exit, even if the body
    raises. Passing `None` leaves the working directory untouched.
    """
    if path is not None:
        previous_dir = os.getcwd()
        os.chdir(path)
        try:
            yield
        finally:
            os.chdir(previous_dir)
    else:
        yield