1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
|
import json
import logging
import os
import shutil
from collections import defaultdict
import torch
class MetricLogger(defaultdict):
    """A dict-like accumulator of named metrics that prints itself as JSON.

    Unknown metric keys default to 0.0. Calling the instance advances an
    internal step counter; every ``print_freq``-th call the current metrics
    are printed as a JSON object, unless ``disable`` is set.
    """

    def __init__(self, name, print_freq=1, disable=False):
        # float() == 0.0, so missing metrics start at zero.
        super().__init__(float)
        self.disable = disable
        self.print_freq = print_freq
        self._iter = 0
        self["name"] = name

    def __str__(self):
        # The instance is itself a dict, so it serializes directly.
        return json.dumps(self)

    def __call__(self):
        self._iter = (self._iter + 1) % self.print_freq
        should_print = self._iter == 0 and not self.disable
        if should_print:
            print(self, flush=True)
def save_checkpoint(state, is_best, filename, disable):
    """Atomically save a training checkpoint to *filename*.

    The state is serialized to a temporary file first and then moved over
    *filename*, so a signal interrupting torch.save() cannot leave a
    truncated checkpoint in place of a good one.

    Args:
        state: picklable object (e.g. dict of tensors/metadata) for torch.save.
        is_best: if True, also copy the checkpoint to "model_best.pth.tar".
        filename: destination path; an empty string disables saving.
        disable: if True, do nothing.
    """
    if disable or filename == "":
        return
    # Avoid shadowing the stdlib `tempfile` module.
    tmp_path = filename + ".temp"
    # Remove a stale temp file left over from a save that was interrupted
    # between torch.save() and the move below.
    if os.path.isfile(tmp_path):
        os.remove(tmp_path)
    torch.save(state, tmp_path)
    if os.path.isfile(tmp_path):
        # os.replace is atomic and, unlike os.rename, also overwrites an
        # existing destination on Windows.
        os.replace(tmp_path, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")
    logging.warning("Checkpoint: saved")
def count_parameters(model):
    """Return the total number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|