1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
|
import logging
import os
import shutil
from collections import defaultdict, deque
import torch
class MetricLogger:
    r"""Logger for model metrics.

    Each metric key keeps only its most recent ``print_freq`` values.
    Calling the instance advances an internal counter and prints the
    latest snapshot once every ``print_freq`` calls.
    """

    def __init__(self, group, print_freq=1):
        self.print_freq = print_freq
        self._iter = 0
        # Bounded per-key history: the deque drops values older than
        # the last `print_freq` entries automatically.
        self.data = defaultdict(lambda: deque(maxlen=self.print_freq))
        self.data["group"].append(group)

    def __setitem__(self, key, value):
        # Record a new observation for `key`.
        self.data[key].append(value)

    def _get_last(self):
        # Snapshot of the most recent value for every tracked key.
        latest = {}
        for key, history in self.data.items():
            latest[key] = history[-1]
        return latest

    def __str__(self):
        return str(self._get_last())

    def __call__(self):
        self._iter = (self._iter + 1) % self.print_freq
        if self._iter == 0:
            # One full cycle completed: emit the current snapshot.
            print(self, flush=True)
def save_checkpoint(state, is_best, filename):
    r"""Save the model to a temporary file first, then atomically move it
    to ``filename``, so a signal interrupting torch.save() cannot leave a
    corrupt checkpoint at ``filename``.

    Args:
        state: picklable object (e.g. a dict of tensors) handed to torch.save.
        is_best: if True, also copy the checkpoint to "model_best.pth.tar"
            in the current working directory.
        filename: destination path; an empty string disables saving.
    """
    if filename == "":
        return
    # Renamed from `tempfile` to avoid shadowing the stdlib module name.
    tmp_path = filename + ".temp"
    # Remove a stale temp file left over from a previously interrupted save.
    if os.path.isfile(tmp_path):
        os.remove(tmp_path)
    torch.save(state, tmp_path)
    if os.path.isfile(tmp_path):
        # os.replace is atomic and, unlike os.rename, overwrites an existing
        # destination on Windows too (checkpoints are rewritten repeatedly).
        os.replace(tmp_path, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")
    logging.info("Checkpoint: saved")
def count_parameters(model):
    r"""Count the total number of trainable parameters in the model.

    Only parameters with ``requires_grad=True`` are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|