File: _hook_iterator.py

import inspect
import functools
from enum import Enum

import torch.autograd


class _SnapshotState(Enum):
    r"""
    These are the snapshotting-related states that IterDataPipes can be in.
    `NotStarted` - allows you to restore a snapshot and create an iterator with reset
    `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe
    `Iterating` - can restore, will reset if you create a new iterator
    """
    NotStarted = 0
    Restored = 1
    Iterating = 2

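# A hedged sketch of the transitions implied by the docstring above (the
# triggers are paraphrased; they are not part of the enum itself):
#
#   NotStarted -> Restored    restore a snapshot
#   NotStarted -> Iterating   create an iterator (the DataPipe is reset)
#   Restored   -> Iterating   create an iterator (no reset)
#   Iterating  -> Iterating   a new iterator resets the DataPipe again
#   Iterating  -> Restored    restore a snapshot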

def _simplify_obj_name(obj) -> str:
    """
    Simplify the display strings of objects for the purpose of rendering within DataPipe error messages.
    """
    if inspect.isfunction(obj):
        return obj.__name__
    else:
        return repr(obj)

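# For example (a doctest-style sketch; outputs assume standard `repr`):
#
#   >>> def my_fn(x):
#   ...     return x
#   >>> _simplify_obj_name(my_fn)
#   'my_fn'
#   >>> _simplify_obj_name(3)
#   '3'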

def _generate_input_args_string(obj):
    """
    Generate a string for the input arguments of an object.
    """
    signature = inspect.signature(obj.__class__)
    input_param_names = set(signature.parameters.keys())
    result = []
    for name, member in inspect.getmembers(obj):
        if name in input_param_names:
            result.append((name, _simplify_obj_name(member)))
    return ', '.join([f'{name}={value}' for name, value in result])

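# A hedged example with a hypothetical DataPipe-like class whose constructor
# parameter `fn` is stored under an attribute of the same name, so that
# `inspect.getmembers` can recover it:
#
#   >>> class MapperLike:
#   ...     def __init__(self, fn):
#   ...         self.fn = fn
#   >>> _generate_input_args_string(MapperLike(len))
#   'fn=<built-in function len>'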

def _generate_iterdatapipe_msg(datapipe):
    return f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"


def _gen_invalid_iterdatapipe_msg(datapipe):
    return ("This iterator has been invalidated because another iterator has been created "
            f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n"
            "This may be caused by multiple references to the same IterDataPipe. We recommend "
            "using `.fork()` if that is necessary.")


_feedback_msg = ("\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free "
                 "to comment on this issue: https://github.com/pytorch/data/issues/45.")


def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None:
    r"""
    Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raises an exception.
    In the case of ChildDataPipe, the ID gets compared to the one stored in `main_datapipe` as well.
    """
    if next_method_exists:
        # This is the case where `IterDataPipe` has both `__iter__` and `__next__`.
        # The `_valid_iterator_id` should either be never set (`None`), or set by at most one
        # iterator (`0`). Otherwise, it means there are multiple iterators.
        if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0:
            extra_msg = "\nNote that this exception is raised inside your IterDataPipe's `__next__` method"
            raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg)
    elif hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
        if hasattr(datapipe, "_check_valid_iterator_id"):
            if not datapipe._check_valid_iterator_id(iterator_id):
                raise RuntimeError("This iterator has been invalidated, because a new iterator has been created "
                                   f"from one of the ChildDataPipes of "
                                   f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}." + _feedback_msg)
        else:
            raise RuntimeError("ChildDataPipe must have method `_check_valid_iterator_id`.")
    elif datapipe._valid_iterator_id != iterator_id:
        raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg)

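# A hedged sketch of the mismatch case for a plain (non-child) DataPipe `dp`:
# once a newer iterator bumps `_valid_iterator_id`, an older iterator's stale
# ID no longer matches and the check raises:
#
#   >>> dp._valid_iterator_id = 1      # owned by the newest iterator
#   >>> _check_iterator_valid(dp, 0)   # an older iterator's ID
#   RuntimeError: This iterator has been invalidated because another iterator ...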

def _set_datapipe_valid_iterator_id(datapipe):
    r"""
    Given a DataPipe, updates its valid iterator ID and reset the DataPipe.
    """
    if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
        if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"):
            datapipe._set_main_datapipe_valid_iterator_id()  # reset() is called within this method when appropriate
        else:
            raise RuntimeError("ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`.")
    else:
        if datapipe._valid_iterator_id is None:
            datapipe._valid_iterator_id = 0
        else:
            datapipe._valid_iterator_id += 1
        datapipe.reset()
    return datapipe._valid_iterator_id

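# A hedged sketch of the non-child branch, assuming `dp` is a plain
# IterDataPipe with a working `reset()`:
#
#   >>> dp._valid_iterator_id = None
#   >>> _set_datapipe_valid_iterator_id(dp)   # first iterator
#   0
#   >>> _set_datapipe_valid_iterator_id(dp)   # second iterator; `dp.reset()` runs
#   1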

def hook_iterator(namespace, profile_name):
    r"""
    Hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`. This is done for the purpose of
    profiling and checking if an iterator is still valid.
    """
    def profiler_record_fn_context():
        return torch.autograd.profiler.record_function(profile_name)

    class IteratorDecorator:
        r"""
        Wrap the iterator and modifying its `__next__` method. This decorator is applied to
        DataPipes of which `__iter__` method is NOT a generator function. Those `__iter__`
        method commonly returns `self` but not necessarily.
        """
        def __init__(self, iterator, source_dp, iterator_id, has_next_method):
            self.iterator = iterator
            self.source_dp = source_dp
            self.iterator_id = iterator_id
            self._profiler_enabled = torch.autograd._profiler_enabled()
            # Check if `__iter__` returns `self` and `DataPipe` has `__next__`
            self.self_and_has_next_method = self.iterator is self.source_dp and has_next_method

        def __iter__(self):
            return self

        def _get_next(self):
            r"""
            Return next with logic related to iterator validity, profiler, and incrementation of samples yielded.
            """
            _check_iterator_valid(self.source_dp, self.iterator_id)
            result = next(self.iterator)
            if not self.self_and_has_next_method:
                self.source_dp._number_of_samples_yielded += 1
            return result

        def __next__(self):
            # TODO: Add try-except to reduce the traceback from the Exception in place
            # See: https://github.com/pytorch/data/issues/284
            if self._profiler_enabled:
                with profiler_record_fn_context():
                    return self._get_next()
            else:  # Decided against using `contextlib.nullcontext` for performance reasons
                return self._get_next()

        def __getattr__(self, name):
            return getattr(self.iterator, name)

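    # A hedged sketch: when `__iter__` returns a plain iterator rather than a
    # generator, callers of `iter(dp)` end up holding an `IteratorDecorator`,
    # and every `next()` goes through `_get_next` (validity check + counting):
    #
    #   >>> it = iter(dp)            # `dp.__iter__` returns e.g. iter(some_list)
    #   >>> type(it).__name__
    #   'IteratorDecorator'
    #   >>> next(it)                 # checked by `_check_iterator_valid` first
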
    func = namespace['__iter__']

    # ``__iter__`` of IterDataPipe is a generator function
    if inspect.isgeneratorfunction(func):
        @functools.wraps(func)
        def wrap_generator(*args, **kwargs):
            gen = func(*args, **kwargs)
            datapipe = args[0]
            if datapipe._fast_forward_iterator:
                it = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                datapipe._snapshot_state = _SnapshotState.Iterating
                while True:
                    try:
                        yield next(it)
                    except StopIteration:
                        return
            iterator_id = _set_datapipe_valid_iterator_id(datapipe)  # This ID is tied to each created iterator
            _profiler_enabled = torch.autograd._profiler_enabled()
            try:
                if _profiler_enabled:
                    with profiler_record_fn_context():
                        response = gen.send(None)
                else:
                    response = gen.send(None)

                while True:
                    datapipe._number_of_samples_yielded += 1
                    request = yield response
                    # Pass through here every time `__next__` is called
                    if _profiler_enabled:
                        with profiler_record_fn_context():
                            _check_iterator_valid(datapipe, iterator_id)
                            response = gen.send(request)
                    else:  # Decided against using `contextlib.nullcontext` for performance reasons
                        _check_iterator_valid(datapipe, iterator_id)
                        response = gen.send(request)
            except StopIteration:
                return
            except Exception as e:
                # TODO: Simplify the traceback message to skip over `response = gen.send(None)`
                #       Part of https://github.com/pytorch/data/issues/284
                datapipe = args[0]
                msg = "thrown by __iter__ of"
                single_iterator_msg = "single iterator per IterDataPipe constraint"
                if hasattr(e.args, '__len__'):
                    full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
                    if len(e.args) == 0 or not isinstance(e.args[0], str):  # If an exception message doesn't exist
                        e.args = (f'\nThis exception is {full_msg}',)
                    elif msg not in e.args[0] and single_iterator_msg not in e.args[0]:
                        e.args = (e.args[0] + f'\nThis exception is {full_msg}',) + e.args[1:]
                raise

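        # A hedged sketch of the effect of the wrapper installed below: a second
        # iterator bumps the valid ID and resets the DataPipe, so the first
        # generator raises on its next advance:
        #
        #   >>> it1 = iter(dp)
        #   >>> next(it1)            # generator body starts, iterator_id = 0
        #   >>> it2 = iter(dp)
        #   >>> next(it2)            # iterator_id = 1, dp is reset
        #   >>> next(it1)            # stale ID 0: `_check_iterator_valid` raises
        #   RuntimeError: This iterator has been invalidated ...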
        namespace['__iter__'] = wrap_generator
    else:  # ``__iter__`` of IterDataPipe is NOT a generator function
        # IterDataPipe is an iterator with both ``__iter__`` and ``__next__``
        # And ``__iter__`` may or may not return `self`
        if '__next__' in namespace:  # If `__next__` exists, put a wrapper around it
            next_func = namespace['__next__']

            @functools.wraps(next_func)
            def wrap_next(*args, **kwargs):
                if torch.autograd._profiler_enabled():
                    with profiler_record_fn_context():
                        result = next_func(*args, **kwargs)
                else:
                    result = next_func(*args, **kwargs)
                datapipe = args[0]
                datapipe._number_of_samples_yielded += 1
                return result

            namespace['__next__'] = wrap_next

            # Note that if `__next__` and `__iter__` do something completely unrelated, issues may arise,
            # but then the user is violating the iterator protocol. Potential issues:
            # 1. The valid iterator ID may not be updated or checked properly
            # 2. The number of samples yielded will be miscounted

        # Regardless of whether `__next__` exists, `__iter__` needs a wrapper to track iterator validity
        @functools.wraps(func)
        def wrap_iter(*args, **kwargs):
            iter_ret = func(*args, **kwargs)
            datapipe = args[0]
            datapipe._snapshot_state = _SnapshotState.Iterating
            if datapipe._fast_forward_iterator:
                iter_ret = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                return iter_ret
            iterator_id = _set_datapipe_valid_iterator_id(datapipe)  # This ID is tied to each created iterator
            return IteratorDecorator(iter_ret, datapipe, iterator_id, '__next__' in namespace)

        namespace['__iter__'] = wrap_iter
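

# A hedged sketch of how this hook is applied; the metaclass wiring below is an
# assumption about `_DataPipeMeta`, which lives elsewhere in this package:
#
#   >>> namespace = dict(MyPipe.__dict__)           # `MyPipe` is hypothetical
#   >>> hook_iterator(namespace, profile_name="MyPipe")
#   >>> # namespace['__iter__'] is now wrap_generator (generator case) or
#   >>> # wrap_iter, and namespace['__next__'] (if present) is wrap_next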