1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
|
import logging
import os
import warnings
from contextlib import contextmanager
from functools import wraps
from .exceptions import PapermillParameterOverwriteWarning
# Module-level logger for this file's helpers; the logger name is spelled out
# explicitly as 'papermill.utils' rather than derived from ``__name__``.
logger = logging.getLogger('papermill.utils')
def any_tagged_cell(nb, tag):
    """Return whether at least one cell of the notebook is tagged ``tag``.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        The notebook to introspect
    tag : str
        The tag to look for

    Returns
    -------
    bool
        ``True`` if some cell lists ``tag`` in its ``metadata.tags``,
        ``False`` otherwise (including for a notebook with no cells).
    """
    # A generator (not a list) lets ``any`` stop at the first match. Cells
    # whose metadata has no ``tags`` entry (or ``tags`` set to None) are
    # treated as untagged instead of raising AttributeError/TypeError.
    return any(
        tag in (getattr(cell.metadata, 'tags', None) or ()) for cell in nb.cells
    )
def nb_kernel_name(nb, name=None):
    """Resolve the kernel name for a notebook.

    An explicitly supplied ``name`` takes precedence; otherwise the value of
    ``nb.metadata['kernelspec']['name']`` is used when present.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        The notebook to introspect
    name : str
        Optional override for the kernel name

    Returns
    -------
    str
        The resolved kernel name

    Raises
    ------
    ValueError
        If neither the override nor the notebook metadata yields a name
    """
    if not name:
        # Only consult the notebook when no usable override was passed.
        kernelspec = nb.metadata.get('kernelspec', {})
        name = kernelspec.get('name')
    if name:
        return name
    raise ValueError("No kernel name found in notebook and no override provided.")
def nb_language(nb, language=None):
    """Resolve the programming language of a notebook.

    An explicitly supplied ``language`` wins. Otherwise the notebook's
    ``language_info.name`` metadata is used, falling back to the older
    ``kernelspec.language`` entry.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        The notebook to introspect
    language : str
        Optional override for the language

    Returns
    -------
    str
        The resolved programming language

    Raises
    ------
    ValueError
        If no language can be determined from the override or the metadata
    """
    if language:
        return language

    metadata = nb.metadata
    detected = metadata.get('language_info', {}).get('name')
    if not detected:
        # v3 language path for old notebooks that didn't convert cleanly
        detected = metadata.get('kernelspec', {}).get('language')
    if not detected:
        raise ValueError("No language found in notebook and no override provided.")
    return detected
def find_first_tagged_cell_index(nb, tag):
    """Find the index of the first cell tagged ``tag`` in the notebook.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        The notebook to introspect
    tag : str
        The tag to look for

    Returns
    -------
    int
        Position in ``nb.cells`` of the first cell whose ``metadata.tags``
        contains ``tag``, or ``-1`` if no cell carries the tag.
    """
    for idx, cell in enumerate(nb.cells):
        # Cells without a ``tags`` entry (or with ``tags`` set to None) are
        # treated as untagged rather than raising.
        if tag in (getattr(cell.metadata, 'tags', None) or ()):
            # Stop at the first hit instead of collecting every match.
            return idx
    return -1
def merge_kwargs(caller_args, **callee_args):
    """Merge caller and callee keyword arguments.

    Returns a new dictionary containing ``caller_args`` updated with
    ``callee_args``. When the same key appears in both, the callee's value
    wins and a warning is raised that names only the overwritten argument(s),
    shown with their new (callee) values.

    Parameters
    ----------
    caller_args : dict
        Caller arguments
    **callee_args
        Keyword callee arguments

    Returns
    -------
    args : dict
        Merged arguments

    Warns
    -----
    PapermillParameterOverwriteWarning
        If any callee argument overwrites a caller argument.
    """
    # Report only the keys that actually collide; listing every callee
    # argument would wrongly suggest that new keys overwrite something.
    overwritten = [
        f'{key}={value}' for key, value in callee_args.items() if key in caller_args
    ]
    if overwritten:
        args = '; '.join(overwritten)
        msg = f"Callee will overwrite caller's argument(s): {args}"
        warnings.warn(msg, PapermillParameterOverwriteWarning)
    return dict(caller_args, **callee_args)
def remove_args(args=None, **kwargs):
    """Return the keyword arguments with the names in ``args`` dropped.

    Parameters
    ----------
    args : list
        Names of the arguments to exclude; when empty or ``None`` the
        keyword arguments are returned unchanged
    **kwargs
        Arbitrary keyword arguments

    Returns
    -------
    kwargs : dict
        Dictionary of the remaining arguments
    """
    if args:
        return {name: val for name, val in kwargs.items() if name not in args}
    return kwargs
# retry decorator
def retry(num):
    """Build a decorator that calls the wrapped function up to ``num`` times.

    Each attempt that raises an ``Exception`` is logged at DEBUG level and
    retried. Once ``num`` attempts have failed, the exception from the final
    attempt is re-raised unchanged.

    Parameters
    ----------
    num : int
        Total number of attempts; must be at least 1.

    Returns
    -------
    callable
        A decorator that adds the retry behaviour to a function.

    Raises
    ------
    ValueError
        If ``num`` is less than 1.
    """
    # With zero attempts there would be no exception to re-raise, so reject
    # the value up front instead of failing obscurely on every call.
    if num < 1:
        raise ValueError(f'retry() requires num >= 1, got {num!r}')

    def decorate(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, num + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == num:
                        # Out of attempts: propagate the last failure.
                        raise
                    logger.debug('Retrying after: %s', e)

        return wrapper

    return decorate
@contextmanager
def chdir(path):
    """Temporarily switch the working directory to ``path``.

    The previous working directory is restored when the ``with`` block exits,
    whether it exits normally or by an exception. Passing ``None`` turns the
    context manager into a no-op.

    Parameters
    ----------
    path : str or None
        Directory to switch into (anything ``os.chdir`` accepts), or ``None``
        to leave the working directory untouched.
    """
    # Remember where we were only when we are actually going to move.
    previous = None if path is None else os.getcwd()
    if previous is not None:
        os.chdir(path)
    try:
        yield
    finally:
        if previous is not None:
            os.chdir(previous)
|