File: average_checkpoints.py

package info (click to toggle)
pytorch-audio 2.9.1-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 108,884 kB
  • sloc: python: 44,403; cpp: 3,384; sh: 126; makefile: 32
file content (28 lines) | stat: -rw-r--r-- 827 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os

import torch


def average_checkpoints(last):
    """Average the ``"state_dict"`` entries of several checkpoint files.

    Each file is loaded with :func:`torch.load` (storages mapped to CPU) and
    its ``"state_dict"`` mapping is summed element-wise into the mapping from
    the first file. The sum is then divided by the number of checkpoints
    loaded: floating-point tensors use true division, all other tensors use
    floor division so their dtype is preserved. ``None`` entries are passed
    through unchanged.

    Args:
        last: Iterable (list, tuple, generator, ...) of checkpoint file paths.
            Each file must contain a dict with a ``"state_dict"`` key.

    Returns:
        The averaged state dict. This is the mapping loaded from the first
        checkpoint, updated in place.

    Raises:
        ValueError: If ``last`` yields no paths.
        KeyError: If a later checkpoint lacks a key present in the first one.
    """
    avg = None
    count = 0
    for path in last:
        # NOTE(review): recent torch versions default torch.load to
        # weights_only=True; checkpoints that pickle non-tensor objects may
        # fail to load under that default — confirm for the checkpoints used.
        states = torch.load(path, map_location=lambda storage, loc: storage)["state_dict"]
        count += 1
        if avg is None:
            avg = states
            continue
        for k in avg:
            # None entries cannot be summed; leave them as-is, consistent with
            # the None guard in the averaging step below.
            if avg[k] is not None:
                avg[k] += states[k]

    if avg is None:
        raise ValueError("average_checkpoints() requires at least one checkpoint path")

    # Divide by the number of checkpoints actually loaded (works for any
    # iterable, not only sized sequences).
    for k in avg:
        if avg[k] is None:
            continue
        if avg[k].is_floating_point():
            avg[k] /= count
        else:
            # Integer tensors: floor division keeps the original dtype.
            avg[k] //= count
    return avg


def ensemble(args, num_checkpoints=10):
    """Average the last few epoch checkpoints of an experiment and save the result.

    The checkpoints are read from ``<args.exp_dir>/<args.exp_name>/epoch={n}.ckpt``
    for the final ``num_checkpoints`` epochs before ``args.epochs``. The averaged
    state dict is saved as ``{"state_dict": ...}`` to
    ``<args.exp_dir>/<args.exp_name>/model_avg_<k>.pth``, where ``k`` is the
    number of checkpoints averaged. With the default of 10 and
    ``args.epochs >= 10`` this is ``model_avg_10.pth``.

    Args:
        args: Object exposing ``exp_dir``, ``exp_name`` and ``epochs``. It is
            presumably an argparse namespace; verify against the caller.
        num_checkpoints: How many trailing epochs to average. Defaults to 10.

    Returns:
        The path the averaged model was written to.
    """
    exp_path = os.path.join(args.exp_dir, args.exp_name)
    # Clamp at epoch 0 so runs shorter than num_checkpoints epochs do not ask
    # for non-existent "epoch=-N.ckpt" files.
    first_epoch = max(0, args.epochs - num_checkpoints)
    last = [os.path.join(exp_path, f"epoch={n}.ckpt") for n in range(first_epoch, args.epochs)]
    model_path = os.path.join(exp_path, f"model_avg_{len(last)}.pth")
    torch.save({"state_dict": average_checkpoints(last)}, model_path)
    return model_path