File: _utils.py

package info (click to toggle)
python-nbstripout 0.8.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 436 kB
  • sloc: python: 870; sh: 18; makefile: 13
file content (156 lines) | stat: -rw-r--r-- 5,522 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from collections import defaultdict
import sys

# Public API: the names exported by a star-import of this module.
__all__ = ["pop_recursive", "strip_output", "strip_zeppelin_output", "MetadataError"]


class MetadataError(Exception):
    """Signals that a cell's `keep_output` metadata contradicts its tags."""


def pop_recursive(d, key, default=None):
    """Pop `key` from the nested dict `d`, where `key` is a `.`-delimited path.

    At every level a key that is literally present (dots included) wins over
    splitting it into nested keys. When the path cannot be resolved, or a
    non-dict is reached on the way, `default` is returned and nothing is
    removed.

    >>> d = {'a': {'b': 1, 'c': 2}}
    >>> pop_recursive(d, 'a.c')
    2
    >>> d
    {'a': {'b': 1}}
    """
    node, path = d, key
    # Walk down one nesting level per iteration instead of recursing.
    while isinstance(node, dict):
        if path in node:
            return node.pop(path, default)
        if '.' not in path:
            return default
        head, remainder = path.split('.', maxsplit=1)
        if head not in node:
            return default
        node, path = node[head], remainder
    return default


def _cells(nb, conditionals):
    """Remove cells not satisfying any conditional in conditionals and yield all other cells."""
    if hasattr(nb, 'nbformat') and nb.nbformat < 4:
        for ws in nb.worksheets:
            for conditional in conditionals:
                ws.cells = list(filter(conditional, ws.cells))
            for cell in ws.cells:
                yield cell
    else:
        for conditional in conditionals:
            nb.cells = list(filter(conditional, nb.cells))
        for cell in nb.cells:
            yield cell


def get_size(item):
    """Return the total number of characters held by `item`.

    A string counts its own length. Lists and dicts are summed recursively
    (for dicts only the values count, not the keys). Any other object
    contributes the length of its `str()` representation.
    """
    if isinstance(item, str):
        return len(item)
    if isinstance(item, dict):
        # A dict is measured exactly like the list of its values.
        item = list(item.values())
    if isinstance(item, list):
        return sum(map(get_size, item))
    return len(str(item))


def determine_keep_output(cell, default, strip_init_cells=False):
    """Decide whether the outputs of `cell` should be kept.

    Resolution order:

    1. A cell without `metadata` yields `default`.
    2. If the metadata has `init_cell`, the answer is its truthiness, forced
       to False when `strip_init_cells` is set. `keep_output` is not consulted.
    3. Otherwise a `keep_output` metadata entry and/or a `keep_output` tag
       decide; when neither is present, `default` is returned.

    Raises MetadataError when the metadata says `keep_output` is false while
    the tags contain `keep_output`.
    """
    if 'metadata' not in cell:
        return default

    metadata = cell.metadata
    if 'init_cell' in metadata:
        return bool(metadata.init_cell) and not strip_init_cells

    flag_present = 'keep_output' in metadata
    flag_value = bool(metadata.get('keep_output', False))
    tagged = 'keep_output' in metadata.get('tags', [])

    # Nothing on the cell expresses a preference: defer to the caller.
    if not (flag_present or tagged):
        return default

    # keep_output between metadata and tags should not contradict each other
    if flag_present and tagged and not flag_value:
        raise MetadataError(
            'cell metadata contradicts tags: `keep_output` is false, but `keep_output` in tags'
        )

    return flag_value or tagged


def _zeppelin_cells(nb):
    for pg in nb['paragraphs']:
        yield pg


def strip_zeppelin_output(nb):
    """Reset the `results` of every paragraph in `nb` to an empty dict.

    Paragraphs without a `results` key are left as they are. The notebook is
    modified in place and the same object is returned.
    """
    paragraphs_with_results = (pg for pg in _zeppelin_cells(nb) if 'results' in pg)
    for paragraph in paragraphs_with_results:
        paragraph['results'] = {}
    return nb


def strip_output(nb, keep_output, keep_count, keep_id, extra_keys=[], drop_empty_cells=False, drop_tagged_cells=[],
                 strip_init_cells=False, max_size=0):
    """
    Strip the outputs, execution count/prompt number and miscellaneous
    metadata from a notebook object, unless specified to keep either the outputs
    or counts.

    The notebook is modified in place and returned.

    Parameters:
        nb: the notebook object to strip.
        keep_output: notebook-wide default for keeping cell outputs. When it is
            None and `nb.metadata` has a `keep_output` entry, the truthiness of
            that entry is used instead. Individual cells may override it (see
            `determine_keep_output`).
        keep_count: when falsy, `prompt_number` / `execution_count` are set to
            None on cells and `execution_count` on their remaining outputs.
        keep_id: when falsy, each cell `id` is replaced by the cell's index
            (as a string) among the cells that were not dropped.
        extra_keys: iterable of `.`-delimited keys to remove, each prefixed by
            its namespace, e.g. ['metadata.foo', 'cell.metadata.bar',
            'metadata.baz']. `metadata.*` keys are popped from the notebook
            metadata, `cell.*` keys from every cell. Keys in any other form
            are reported on stderr and ignored.
        drop_empty_cells: when truthy, drop cells without any non-whitespace
            `source` line.
        drop_tagged_cells: iterable of tags; a cell carrying any of them is
            dropped.
        strip_init_cells: forwarded to `determine_keep_output`.
        max_size: outputs of a stripped cell survive when their `get_size` is
            at most this value; the default 0 strips all outputs.

    Raises:
        MetadataError: propagated from `determine_keep_output` when a cell's
            metadata and tags contradict each other.
    """
    # NOTE: the list defaults in the signature are never mutated below.
    if keep_output is None and 'keep_output' in nb.metadata:
        keep_output = bool(nb.metadata['keep_output'])

    # Group the requested keys by namespace ('metadata' or 'cell').
    keys = defaultdict(list)
    for key in extra_keys:
        namespace, dot, subkey = key.partition('.')
        if not dot or namespace not in ('cell', 'metadata'):
            sys.stderr.write(f'Ignoring invalid extra key `{key}`\n')
        else:
            keys[namespace].append(subkey)

    for field in keys['metadata']:
        pop_recursive(nb.metadata, field)

    conditionals = []
    # Keep cells if they have any `source` line that contains non-whitespace
    if drop_empty_cells:
        conditionals.append(lambda c: any(line.strip() for line in c.get('source', [])))
    for tag_to_drop in drop_tagged_cells:
        # Bind the tag as a default argument: the predicates only run later,
        # inside the lazy `_cells` generator, by which time a plain closure
        # over the loop variable would see only the last tag.
        conditionals.append(
            lambda c, tag=tag_to_drop: tag not in c.get("metadata", {}).get("tags", [])
        )

    for i, cell in enumerate(_cells(nb, conditionals)):
        keep_output_this_cell = determine_keep_output(cell, keep_output, strip_init_cells)

        # Remove the outputs, unless directed otherwise
        if 'outputs' in cell:

            # Default behavior (max_size == 0) strips all outputs.
            if not keep_output_this_cell:
                cell['outputs'] = [output for output in cell['outputs']
                                   if get_size(output) <= max_size]

            # Strip the counts from the outputs that were kept if not keep_count.
            if not keep_count:
                for output in cell['outputs']:
                    if 'execution_count' in output:
                        output['execution_count'] = None

            # If keep_output_this_cell and keep_count, do nothing.

        # Remove the prompt_number/execution_count, unless directed otherwise
        if not keep_count:
            if 'prompt_number' in cell:
                cell['prompt_number'] = None
            if 'execution_count' in cell:
                cell['execution_count'] = None
        # Replace the cell id with an incremental value that will be consistent across runs
        if 'id' in cell and not keep_id:
            cell['id'] = str(i)
        for field in keys['cell']:
            pop_recursive(cell, field)
    return nb