1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
|
import os
import torch
def average_checkpoints(last):
    """Average the ``"state_dict"`` entries of several checkpoint files.

    Each file is loaded onto the CPU and its ``"state_dict"`` tensors are summed
    element-wise. The sum is then divided by the number of checkpoints:
    floating-point tensors use true division; all other tensors (e.g. integer
    counters) use floor division so they keep their dtype. Entries that are
    ``None`` are carried through unchanged.

    Args:
        last: Iterable of checkpoint paths. Each file must hold a dict with a
            ``"state_dict"`` mapping, and all state dicts must share the same keys.

    Returns:
        dict: The averaged state dict. It is built in place on the state dict
        loaded from the first path.

    Raises:
        ValueError: If ``last`` is empty, or a checkpoint's ``state_dict`` keys
            differ from those of the first checkpoint.
    """
    # Materialize once so generators work and the count is known.
    paths = list(last)
    if not paths:
        raise ValueError("average_checkpoints() needs at least one checkpoint path")

    avg = None
    for path in paths:
        # Load every tensor onto the CPU, whatever device it was saved from.
        states = torch.load(path, map_location="cpu")["state_dict"]
        if avg is None:
            avg = states
            continue
        if states.keys() != avg.keys():
            raise ValueError(
                f"checkpoint {path!r} has different state_dict keys than {paths[0]!r}"
            )
        for k, value in states.items():
            # Skip None entries here too, matching the averaging pass below.
            if avg[k] is not None:
                # NOTE(review): accumulation happens in each tensor's own dtype;
                # for half-precision weights consider summing in float32.
                avg[k] += value

    count = len(paths)
    for k in avg:
        if avg[k] is None:
            continue
        if avg[k].is_floating_point():
            avg[k] /= count
        else:
            avg[k] //= count
    return avg
def ensemble(args, num_avg=10):
    """Average the last ``num_avg`` epoch checkpoints of an experiment and save them.

    Reads ``epoch={n}.ckpt`` for ``n`` in ``[args.epochs - num_avg, args.epochs)``
    from ``<args.exp_dir>/<args.exp_name>/``, averages their ``state_dict`` with
    ``average_checkpoints``, and writes ``{"state_dict": <averaged>}`` to
    ``model_avg_{num_avg}.pth`` in the same directory.

    Args:
        args: Object exposing ``exp_dir``, ``exp_name`` and ``epochs`` attributes.
        num_avg: Number of trailing epochs to average. The default of 10 keeps
            the previous behaviour and output file name (``model_avg_10.pth``).

    Returns:
        str: Path of the written averaged checkpoint.

    Raises:
        ValueError: If ``num_avg`` is not within ``[1, args.epochs]``. A larger
            value would request checkpoints with negative epoch numbers.
    """
    if not 1 <= num_avg <= args.epochs:
        raise ValueError(
            f"num_avg must be between 1 and args.epochs ({args.epochs}), got {num_avg}"
        )
    exp_path = os.path.join(args.exp_dir, args.exp_name)
    last = [
        os.path.join(exp_path, f"epoch={n}.ckpt")
        for n in range(args.epochs - num_avg, args.epochs)
    ]
    model_path = os.path.join(exp_path, f"model_avg_{num_avg}.pth")
    torch.save({"state_dict": average_checkpoints(last)}, model_path)
    return model_path
|