1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
|
# -*- coding: utf-8 -*-
"""TQDM logger."""
from collections import OrderedDict
from typing import Any, Callable, List, Optional, Union
from ignite.engine import Engine, Events
from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle
from ignite.handlers.base_logger import BaseLogger, BaseOutputHandler
class ProgressBar(BaseLogger):
    """
    TQDM progress bar handler to log training progress and computed metrics.

    Args:
        persist: set to ``True`` to persist the progress bar after completion (default = ``False``)
        bar_format : Specify a custom bar string formatting. May impact performance.
            [default: '{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]'].
            Set to ``None`` to use ``tqdm`` default bar formatting: '{l_bar}{bar}{r_bar}', where
            l_bar='{desc}: {percentage:3.0f}%|' and
            r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'. For more details on the
            formatting, see `tqdm docs <https://tqdm.github.io/docs/tqdm/>`_.
        tqdm_kwargs: kwargs passed to tqdm progress bar.
            By default, progress bar description displays "Epoch [5/10]" where 5 is the current epoch and 10 is the
            number of epochs; however, if ``max_epochs`` is set to 1, the progress bar instead displays
            "Iteration: [5/10]". If tqdm_kwargs defines `desc`, e.g. "Predictions", then the description is
            "Predictions [5/10]" if the number of epochs is more than one, otherwise it is simply "Predictions".

    Examples:
        Simple progress bar

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            pbar = ProgressBar()
            pbar.attach(trainer)

            # Progress bar will look like
            # Epoch [2/50]: [64/128]  50%|█████      [06:17<12:34]

        Log output to a file instead of stderr (tqdm's default output)

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            log_file = open("output.log", "w")
            pbar = ProgressBar(file=log_file)
            pbar.attach(trainer)

        Attach metrics that have already been computed at :attr:`~ignite.engine.events.Events.ITERATION_COMPLETED`
        (such as :class:`~ignite.metrics.RunningAverage`)

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

            pbar = ProgressBar()
            pbar.attach(trainer, ['loss'])

            # Progress bar will look like
            # Epoch [2/50]: [64/128]  50%|█████      , loss=0.123 [06:17<12:34]

        Directly attach the engine's output

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            pbar = ProgressBar()
            pbar.attach(trainer, output_transform=lambda x: {'loss': x})

            # Progress bar will look like
            # Epoch [2/50]: [64/128]  50%|█████      , loss=0.123 [06:17<12:34]

        Example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta``
        are also logged along with the NLL and Accuracy after each iteration:

        .. code-block:: python

            pbar.attach(
                trainer,
                metric_names=["nll", "accuracy"],
                state_attributes=["alpha", "beta"],
            )

    Note:
        When attaching the progress bar to an engine, it is recommended that you replace
        every print operation in the engine's handlers triggered every iteration with
        ``pbar.log_message`` to guarantee the correct format of the stdout.

    Note:
        When using inside jupyter notebook, `ProgressBar` automatically uses `tqdm_notebook`. For correct rendering,
        please install `ipywidgets <https://ipywidgets.readthedocs.io/en/stable/user_install.html#installation>`_.
        Due to `tqdm notebook bugs <https://github.com/tqdm/tqdm/issues/594>`_, the bar format may need to be set
        to an empty string.

    .. versionchanged:: 0.4.7
        `attach` now accepts an optional list of `state_attributes`
    """

    # Chronological order of the events this logger understands; used by ``_compare_lt``
    # to check that the logging event fires before the closing event.
    _events_order: List[Union[Events, CallableEventWithFilter]] = [
        Events.STARTED,
        Events.EPOCH_STARTED,
        Events.ITERATION_STARTED,
        Events.ITERATION_COMPLETED,
        Events.EPOCH_COMPLETED,
        Events.COMPLETED,
    ]

    def __init__(
        self,
        persist: bool = False,
        bar_format: Union[
            str, None
        ] = "{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]",
        **tqdm_kwargs: Any,
    ):
        try:
            # autonotebook picks the notebook widget bar inside Jupyter, the console bar otherwise.
            from tqdm.autonotebook import tqdm
        except ImportError as exc:
            raise ModuleNotFoundError(
                "This contrib module requires tqdm to be installed. "
                "Please install it with command: \n pip install tqdm"
            ) from exc

        self.pbar_cls = tqdm
        # The bar itself is created lazily by the output handler on the first logged event.
        self.pbar = None
        self.persist = persist
        self.bar_format = bar_format
        self.tqdm_kwargs = tqdm_kwargs

    def _reset(self, pbar_total: Optional[int]) -> None:
        """Create a fresh tqdm bar with ``pbar_total`` steps, starting at position 1."""
        self.pbar = self.pbar_cls(
            total=pbar_total, leave=self.persist, bar_format=self.bar_format, initial=1, **self.tqdm_kwargs
        )

    def _close(self, engine: Engine) -> None:
        """Close the current bar (if any) so the next logged event starts a new one."""
        if self.pbar is not None:
            # See tqdm/notebook.py in https://github.com/tqdm/tqdm
            # issue #1115 : notebook backend of tqdm checks if n < total (error or KeyboardInterrupt)
            # and the bar persists in 'danger' mode
            if self.pbar.total is not None:
                self.pbar.n = self.pbar.total
            self.pbar.close()
        self.pbar = None

    @staticmethod
    def _compare_lt(
        event1: Union[Events, CallableEventWithFilter], event2: Union[Events, CallableEventWithFilter]
    ) -> bool:
        """Return ``True`` if ``event1`` comes strictly before ``event2`` in ``_events_order``.

        Raises:
            ValueError: if either event is not one of the events listed in ``_events_order``.
        """
        order = ProgressBar._events_order
        try:
            i1 = order.index(event1)
            i2 = order.index(event2)
        except ValueError:
            # list.index only says "<x> is not in list"; name the events that are actually supported.
            supported = ", ".join(str(event) for event in order)
            raise ValueError(
                f"ProgressBar only supports the following events: {supported}; got {event1} and {event2}"
            ) from None
        return i1 < i2

    def log_message(self, message: str) -> None:
        """
        Logs a message, preserving the progress bar correct output format.

        Args:
            message: string you wish to log.
        """
        from tqdm import tqdm

        # tqdm.write prints above the active bar instead of breaking its line.
        tqdm.write(message, file=self.tqdm_kwargs.get("file", None))

    def attach(  # type: ignore[override]
        self,
        engine: Engine,
        metric_names: Optional[Union[str, List[str]]] = None,
        output_transform: Optional[Callable] = None,
        event_name: Union[Events, CallableEventWithFilter] = Events.ITERATION_COMPLETED,
        closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED,
        state_attributes: Optional[List[str]] = None,
    ) -> None:
        """
        Attaches the progress bar to an engine object.

        Args:
            engine: engine object.
            metric_names: list of metric names to plot or a string "all" to plot all available
                metrics.
            output_transform: a function to select what you want to print from the engine's
                output. This function may return either a dictionary with entries in the format of ``{name: value}``,
                or a single scalar, which will be displayed with the default name `output`.
            event_name: event's name on which the progress bar advances. Valid events are from
                :class:`~ignite.engine.events.Events`.
            closing_event_name: event's name on which the progress bar is closed. Valid events are from
                :class:`~ignite.engine.events.Events`.
            state_attributes: list of attributes of the ``trainer.state`` to plot.

        Raises:
            ValueError: if ``event_name`` is not allowed for ``engine``, if ``closing_event_name`` is a
                filtered event, if either event is not supported by the progress bar, or if ``event_name``
                does not fire before ``closing_event_name``.

        Note:
            Accepted output value types are numbers, 0d and 1d torch tensors and strings.
        """
        desc = self.tqdm_kwargs.get("desc", None)

        if event_name not in engine._allowed_events:
            raise ValueError(f"Logging event {event_name.name} is not in allowed events for this engine")

        if isinstance(closing_event_name, CallableEventWithFilter):
            if closing_event_name.filter is not None:
                raise ValueError("Closing Event should not be a filtered event")

        if not self._compare_lt(event_name, closing_event_name):
            raise ValueError(f"Logging event {event_name} should be called before closing event {closing_event_name}")

        log_handler = _OutputHandler(
            desc,
            metric_names,
            output_transform,
            closing_event_name=closing_event_name,
            state_attributes=state_attributes,
        )

        super(ProgressBar, self).attach(engine, log_handler, event_name)
        engine.add_event_handler(closing_event_name, self._close)

    def attach_opt_params_handler(  # type: ignore[empty-body]
        self, engine: Engine, event_name: Union[str, Events], *args: Any, **kwargs: Any
    ) -> RemovableEventHandle:
        """Intentionally empty"""
        pass

    def _create_output_handler(self, *args: Any, **kwargs: Any) -> "_OutputHandler":
        return _OutputHandler(*args, **kwargs)

    def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable:  # type: ignore[empty-body]
        """Intentionally empty"""
        pass
class _OutputHandler(BaseOutputHandler):
    """Output handler used by ``ProgressBar``: advances the bar and renders output, metrics and state attributes.

    Args:
        description: progress bar description. When empty/``None``, "Epoch" is used, or "Iteration"
            if ``engine.state.max_epochs == 1``.
        metric_names: list of metric names to plot or a string "all" to plot all available
            metrics.
        output_transform: output transform function to prepare `engine.state.output` as a number.
            For example, `output_transform = lambda output: output`
            This function can also return a dictionary, e.g `{'loss': loss1, 'another_loss': loss2}` to label the plot
            with corresponding keys.
        closing_event_name: event's name on which the progress bar is closed. Valid events are from
            :class:`~ignite.engine.events.Events` or any `event_name` added by
            :meth:`~ignite.engine.engine.Engine.register_events`.
        state_attributes: list of attributes of the ``trainer.state`` to plot.
    """

    def __init__(
        self,
        description: str,
        metric_names: Optional[Union[str, List[str]]] = None,
        output_transform: Optional[Callable] = None,
        closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED,
        state_attributes: Optional[List[str]] = None,
    ):
        # With neither metrics nor an output transform, substitute an empty metric list so
        # BaseOutputHandler does not complain 'Either metric_names or output_transform should be defined';
        # such a bar simply shows progress without a postfix.
        if output_transform is None and metric_names is None:
            metric_names = []
        super().__init__(
            description, metric_names, output_transform, global_step_transform=None, state_attributes=state_attributes
        )
        self.closing_event_name = closing_event_name

    @staticmethod
    def get_max_number_events(event_name: Union[str, Events, CallableEventWithFilter], engine: Engine) -> Optional[int]:
        """Number of times ``event_name`` fires per cycle of the bar.

        Iteration events map to ``engine.state.epoch_length``, epoch events to ``engine.state.max_epochs``,
        and every other event counts as a single occurrence.
        """
        per_iteration = (Events.ITERATION_STARTED, Events.ITERATION_COMPLETED)
        per_epoch = (Events.EPOCH_STARTED, Events.EPOCH_COMPLETED)
        if event_name in per_iteration:
            return engine.state.epoch_length
        elif event_name in per_epoch:
            return engine.state.max_epochs
        else:
            return 1

    def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[str, Events]) -> None:
        # Lazily (re)create the bar: it is None before the first event and after each closing event.
        bar_length = self.get_max_number_events(event_name, engine)
        if logger.pbar is None:
            logger._reset(pbar_total=bar_length)
        bar = logger.pbar

        # Description: "<tag or default> [k/K]" while more than one closing event is expected.
        fallback = "Iteration" if engine.state.max_epochs == 1 else "Epoch"
        label = self.tag or fallback
        closings_expected = self.get_max_number_events(self.closing_event_name, engine)
        if closings_expected and closings_expected > 1:
            closings_seen = engine.state.get_event_attrib_value(self.closing_event_name)
            label += f" [{closings_seen}/{closings_expected}]"
        bar.set_description(label)  # type: ignore[attr-defined]

        # Postfix: first key component is dropped because tqdm already shows the tag as the description.
        rendered = self._setup_output_metrics_state_attrs(engine, log_text=True)
        postfix = OrderedDict(("_".join(name_parts[1:]), shown) for name_parts, shown in rendered.items())
        if postfix:
            bar.set_postfix(postfix)  # type: ignore[attr-defined]

        # Map the engine's running counter onto the bar's 1..bar_length range, then move by the difference.
        position = engine.state.get_event_attrib_value(event_name)
        if bar_length is not None:
            position = (position - 1) % bar_length + 1
        bar.update(position - bar.n)  # type: ignore[attr-defined]
|