ignite.engine
==============

Main module of the library containing:

ignite.engine.engine
--------------------

.. currentmodule:: ignite.engine.engine

.. autosummary::
    :nosignatures:
    :toctree: generated

    Engine

ignite.engine.events
--------------------

.. currentmodule:: ignite.engine.events

.. autosummary::
    :nosignatures:
    :toctree: generated

    CallableEventWithFilter
    EventEnum
    Events
    EventsList
    State
    RemovableEventHandle

ignite.engine.deterministic
---------------------------

Helper methods for deterministic training.

.. currentmodule:: ignite.engine.deterministic

.. autosummary::
    :nosignatures:
    :toctree: generated

    DeterministicEngine
    ReproducibleBatchSampler
    keep_random_state
    update_dataloader

Helper methods to define supervised trainer and evaluator
---------------------------------------------------------

.. currentmodule:: ignite.engine

.. autosummary::
    :nosignatures:
    :toctree: generated

    create_supervised_trainer
    create_supervised_evaluator
    supervised_training_step
    supervised_training_step_amp
    supervised_training_step_apex
    supervised_training_step_tpu
    supervised_evaluation_step
    supervised_evaluation_step_amp

Resuming the training
---------------------

It is possible to resume the training from a checkpoint and approximately reproduce the original run's behaviour.
With Ignite, this can easily be done using the :class:`~ignite.handlers.checkpoint.Checkpoint` handler. The engine
provides two methods to serialize and deserialize its internal state: :meth:`~ignite.engine.engine.Engine.state_dict` and
:meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing the model, optimizer, lr scheduler,
metrics, etc., the user can store the trainer itself and later resume the training from it. For example:

.. code-block:: python

    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver

    trainer = ...
    model = ...
    optimizer = ...
    lr_scheduler = ...
    data_loader = ...
    metric = ...

    to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
    handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
    trainer.run(data_loader, max_epochs=100)

.. code-block:: bash

    ls /tmp/training
    > "checkpoint_50000.pt"

We can then restore the training from the last checkpoint.

.. code-block:: python

    import torch
    from ignite.handlers import Checkpoint

    trainer = ...
    model = ...
    optimizer = ...
    lr_scheduler = ...
    data_loader = ...
    metric = ...
    checkpoint_file = "/tmp/training/checkpoint_50000.pt"

    to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
    checkpoint = torch.load(checkpoint_file)
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

    # The trainer continues from the restored epoch/iteration
    trainer.run(data_loader, max_epochs=100)

It is also possible to store checkpoints every N iterations and resume the training from one of them, i.e. from a given
iteration rather than from the end of an epoch.

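
As a minimal sketch, assuming a toy linear model, random tensors as data and a fresh temporary directory (all placeholders chosen for illustration only), checkpoints can be saved every 5 iterations by attaching the same ``Checkpoint`` handler to a filtered ``Events.ITERATION_COMPLETED(every=5)`` event instead of ``Events.EPOCH_COMPLETED``; ``n_saved=2`` keeps only the two most recent files:

.. code-block:: python

    import os
    import tempfile

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from ignite.engine import Events, create_supervised_trainer
    from ignite.handlers import Checkpoint, DiskSaver

    # Toy model and data: placeholders for a real training setup
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    trainer = create_supervised_trainer(model, optimizer, nn.MSELoss())
    dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    data_loader = DataLoader(dataset, batch_size=8)  # 8 iterations per epoch

    # Fresh, empty directory (assumption: any writable location works)
    checkpoint_dir = tempfile.mkdtemp()

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer}
    handler = Checkpoint(to_save, DiskSaver(checkpoint_dir, create_dir=True), n_saved=2)
    # Save every 5 iterations instead of at the end of every epoch
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=5), handler)

    trainer.run(data_loader, max_epochs=2)  # 16 iterations: saves at 5, 10 and 15
    saved_files = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(".pt"))

Resuming from one of these files then follows the same ``Checkpoint.load_objects`` pattern as in the restore example above.
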
Complete examples that resume the training from a checkpoint can be found here:

- `save/resume MNIST <https://github.com/pytorch/ignite/tree/master/examples/mnist#training-save--resume>`_
- `save/resume Distributed CIFAR10 <https://github.com/pytorch/ignite/tree/master/examples/cifar10#check-resume-training>`_

Deterministic training
----------------------

In general, achieving deterministic and reproducible training is a rather difficult task, as it depends on multiple
aspects, e.g. data version, code version, software environment, hardware, etc. According to the `PyTorch note on randomness <https://pytorch.org/docs/stable/notes/randomness.html>`_,
there are some steps to take in order to make computations deterministic for your specific problem on one specific
platform and PyTorch release:

- set up the random state seed
- set `cudnn to deterministic <https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking>`_ if applicable

These two options are often enough to run and rerun experiments in a deterministic way. Ignite's engine does not
change this behaviour.

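
A minimal sketch of these two steps with plain PyTorch calls might look as follows. ``seed_everything`` is a hypothetical helper name; the examples in this page use :func:`ignite.utils.manual_seed`, which seeds ``random``, ``torch`` and, when available, ``numpy`` in a single call:

.. code-block:: python

    import random

    import torch


    def seed_everything(seed: int) -> None:
        """Hypothetical helper: seed Python's and PyTorch's random number generators."""
        random.seed(seed)
        torch.manual_seed(seed)  # seeds the RNGs of all devices, including CUDA


    # Ask cuDNN for deterministic convolution algorithms and disable auto-tuning,
    # which may otherwise select different algorithms from one run to the next
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    seed_everything(12)
    first = torch.rand(3)
    seed_everything(12)
    second = torch.rand(3)  # identical to ``first`` because the seed was reset
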
In this module, we provide helper methods and classes to perform additional :ref:`Dataflow synchronization`
to ensure that the model sees the same data for a given epoch:

- :class:`~ignite.engine.deterministic.DeterministicEngine`
- :class:`~ignite.engine.deterministic.ReproducibleBatchSampler`

Dataflow synchronization
------------------------

Ignite provides an option to control the dataflow by synchronizing the random state at the start of each epoch. In this
way, for a given seed, the dataflow is the same for a given iteration/epoch. Roughly, it looks like this:

.. code-block:: python

    for e in range(num_epochs):
        set_seed(seed + e)
        do_single_epoch_iterations(dataloader)

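
The loop above can be illustrated with a self-contained sketch that only uses Python's ``random`` module. ``run_epochs`` and the shuffled index list are hypothetical stand-ins for the engine and a shuffling data loader, not Ignite code. Because the random state is re-seeded with ``seed + e`` at the start of every epoch, replaying from epoch 1 reproduces exactly the sample orders of the original epochs 1 and 2:

.. code-block:: python

    import random


    def run_epochs(num_samples, seed, start_epoch, num_epochs):
        """Return the sample order seen in each epoch from start_epoch to num_epochs - 1."""
        orders = []
        for e in range(start_epoch, num_epochs):
            random.seed(seed + e)  # synchronize the random state at the start of the epoch
            indices = list(range(num_samples))
            random.shuffle(indices)  # stands in for a shuffling data loader
            orders.append(indices)
        return orders


    original = run_epochs(num_samples=10, seed=56, start_epoch=0, num_epochs=3)
    resumed = run_epochs(num_samples=10, seed=56, start_epoch=1, num_epochs=3)
    assert resumed == original[1:]
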
In addition, if the data provider is a ``torch.utils.data.DataLoader``, batch data indices can be made completely
deterministic. Here is a trivial usage example:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader
    from ignite.engine import DeterministicEngine, Events
    from ignite.utils import manual_seed

    def random_train_data_loader(size):
        data = torch.arange(0, size)
        return DataLoader(data, batch_size=4, shuffle=True)

    def print_train_data(engine, batch):
        i = engine.state.iteration
        e = engine.state.epoch
        print("train", e, i, batch.tolist())

    trainer = DeterministicEngine(print_train_data)

    print("Original Run")
    manual_seed(56)
    trainer.run(random_train_data_loader(40), max_epochs=2, epoch_length=5)

    print("Resumed Run")
    # Resume from 2nd epoch
    trainer.load_state_dict({"epoch": 1, "epoch_length": 5, "max_epochs": 2, "rng_states": None})
    manual_seed(56)
    trainer.run(random_train_data_loader(40))

.. code-block:: text

    Original Run
    train 1 1 [31, 13, 3, 4]
    train 1 2 [23, 18, 6, 16]
    train 1 3 [10, 8, 33, 36]
    train 1 4 [1, 37, 19, 9]
    train 1 5 [20, 30, 14, 26]
    train 2 6 [29, 35, 38, 34]
    train 2 7 [7, 22, 12, 17]
    train 2 8 [25, 21, 24, 15]
    train 2 9 [39, 5, 2, 28]
    train 2 10 [27, 11, 32, 0]
    Resumed Run
    train 2 6 [29, 35, 38, 34]
    train 2 7 [7, 22, 12, 17]
    train 2 8 [25, 21, 24, 15]
    train 2 9 [39, 5, 2, 28]
    train 2 10 [27, 11, 32, 0]

We can see that the data samples are exactly the same in the original and resumed runs.

Complete examples that simulate a crash at a defined iteration and resume the training from a checkpoint can be found
here:

- `save/resume MNIST <https://github.com/pytorch/ignite/tree/master/examples/mnist#training-save--resume>`_
- `save/resume Distributed CIFAR10 <https://github.com/pytorch/ignite/tree/master/examples/cifar10#check-resume-training>`_

.. note::

    When the input data is a ``torch.utils.data.DataLoader``, previous batches are skipped and the first provided batch
    is the one right after the checkpoint iteration. Internally, while resuming, the indices of previous data points are
    skipped without fetching the data.

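
The principle can be sketched in plain Python, assuming a hypothetical map-style ``RecordingDataset`` that logs every item access. Batch indices are generated for the whole epoch, but only the batches after the checkpoint iteration are materialized. This illustrates the idea only; it is not the implementation of Ignite's ``ReproducibleBatchSampler``:

.. code-block:: python

    import random


    class RecordingDataset:
        """Hypothetical map-style dataset that records which items are fetched."""

        def __init__(self, size):
            self.size = size
            self.fetched = []

        def __len__(self):
            return self.size

        def __getitem__(self, index):
            self.fetched.append(index)
            return index * 10


    def epoch_batches(dataset, batch_size, seed, start_iteration=0):
        random.seed(seed)  # same seed -> same batch indices for the epoch
        indices = list(range(len(dataset)))
        random.shuffle(indices)
        batch_indices = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
        # Batches before start_iteration are dropped as index lists: their data is never fetched
        for batch in batch_indices[start_iteration:]:
            yield [dataset[i] for i in batch]


    full_run = list(epoch_batches(RecordingDataset(12), batch_size=4, seed=56))
    resumed_dataset = RecordingDataset(12)
    resumed_run = list(epoch_batches(resumed_dataset, batch_size=4, seed=56, start_iteration=2))

Here ``resumed_run`` equals the last batch of ``full_run``, and ``resumed_dataset`` only fetched the 4 items of that batch.
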
.. warning::

    However, when resuming from an iteration in the middle of an epoch, random data augmentations are not synchronized,
    and thus the batches remaining until the end of the epoch can differ from those of the initial run.

.. warning::

    Keep in mind that dataflow synchronization on every epoch can break if a user's handler resets the random state,
    for example, by periodically calling ``torch.manual_seed(seed)`` during the run. This can affect the dataflow:

    .. code-block:: python

        # print_train_data is defined in the example above
        def random_train_data_generator():
            while True:
                yield torch.randint(0, 100, size=(1, ))

        trainer = DeterministicEngine(print_train_data)

        @trainer.on(Events.ITERATION_COMPLETED(every=3))
        def user_handler():
            # handler resets the global torch random state
            torch.manual_seed(12)
            a = torch.rand(1)

        trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)

    .. code-block:: text

        train 1 1 [32]
        train 1 2 [29]
        train 1 3 [40]
        train 1 4 [3] <---
        train 1 5 [22]
        train 2 6 [77]
        train 2 7 [3] <---
        train 2 8 [22]
        train 2 9 [77]
        train 2 10 [3] <---
        train 3 11 [22]
        train 3 12 [77]
        train 3 13 [3] <---
        train 3 14 [22]
        train 3 15 [77]

    Initially, ``random_train_data_generator()`` generates random data batches using the random state set up by
    ``trainer``. This is the intended behaviour until ``user_handler()`` is called. After ``user_handler()`` runs, the
    random state is reset to the same seed, so ``random_train_data_generator()`` produces batches from this altered state:
    as the arrows show, the values ``3``, ``22`` and ``77`` then repeat every three iterations.

    We provide the helper decorator :func:`~ignite.engine.deterministic.keep_random_state` to save and restore the random
    states of ``torch``, ``numpy`` and ``random``. It can be used to deal with the issue described above:

    .. code-block:: python

        from ignite.engine.deterministic import keep_random_state

        @trainer.on(Events.ITERATION_COMPLETED(every=3))
        @keep_random_state
        def user_handler():
            # handler resets the random state, which is restored by the decorator afterwards
            torch.manual_seed(12)
            a = torch.rand(1)

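
To show what such a decorator does, here is a minimal sketch of the same save-and-restore idea restricted to Python's ``random`` module. ``keep_python_random_state`` is a hypothetical name; unlike :func:`~ignite.engine.deterministic.keep_random_state`, it does not handle ``torch`` or ``numpy``:

.. code-block:: python

    import functools
    import random


    def keep_python_random_state(func):
        """Hypothetical decorator: restore Python's random state after ``func`` runs."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            state = random.getstate()
            try:
                return func(*args, **kwargs)
            finally:
                random.setstate(state)  # restored even if func raises

        return wrapper


    @keep_python_random_state
    def user_handler():
        random.seed(12)  # would otherwise reset the random state used by the dataflow
        random.random()


    random.seed(0)
    expected = [random.random() for _ in range(3)]

    random.seed(0)
    observed = [random.random()]
    user_handler()  # does not disturb the sequence thanks to the decorator
    observed += [random.random() for _ in range(2)]

With the decorator, ``observed`` matches ``expected``; without it, the last two values would come from the state seeded with ``12``.
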