File: checkpoint.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (174 lines) | stat: -rw-r--r-- 6,881 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174





import numpy as np
import pickle
from collections import OrderedDict

from caffe2.proto import caffe2_pb2

from caffe2.python import workspace, core, scope

import logging
logging.basicConfig()
log = logging.getLogger("AnyExpOnTerm")
log.setLevel(logging.DEBUG)


def initialize_params_from_file(
        model, weights_file, num_xpus, opts,
        broadcast_computed_param=False, reset_epoch=False):
    """Restore model parameters from a checkpoint file, then replicate
    the master XPU's parameters to every other XPU.

    Args:
        model: model whose parameters are restored (may be None to only
            read checkpoint metadata — see initialize_master_xpu_model_params).
        weights_file: path to the pickled checkpoint.
        num_xpus: number of devices to broadcast the loaded params to.
        opts: options dict (reads 'model_param' and 'distributed' keys).
        broadcast_computed_param: also broadcast computed (non-trainable)
            params when True.
        reset_epoch: ignore the checkpoint's stored epoch/best_metric.

    Returns:
        (start_epoch, lr, best_metric) as read from the checkpoint.
    """
    restored_state = initialize_master_xpu_model_params(
        model, weights_file, opts, reset_epoch)
    broadcast_parameters(opts, model, num_xpus, broadcast_computed_param)
    return restored_state


def initialize_master_xpu_model_params(model, weights_file, opts, reset_epoch):
    """Read a pickled checkpoint and feed matching blobs into the master
    (first) XPU's workspace blobs.

    Args:
        model: model whose parameters are initialized; if None, only the
            checkpoint metadata (epoch/lr/best_metric) is read.
        weights_file: path to a pickled checkpoint file (as written by
            save_model_params_blob).
        opts: options dict; reads opts['model_param']['base_learning_rate']
            and opts['distributed'] ('first_xpu_id', 'device').
        reset_epoch: if True, start from epoch 0 and ignore the stored
            best_metric.

    Returns:
        (start_epoch, lr, best_metric)
    """
    log.info("Initializing model params from file: {}".format(weights_file))
    # BUGFIX: pickled checkpoints are binary — on Python 3 text mode ('r')
    # makes pickle.load fail; open in 'rb'.
    with open(weights_file, 'rb') as fopen:
        blobs = pickle.load(fopen)
    # Checkpoints written by save_model_params_blob nest everything under
    # a top-level 'blobs' key; unwrap when present.
    if 'blobs' in blobs:
        blobs = blobs['blobs']

    start_epoch = 0
    best_metric = float('-inf')
    if 'epoch' in blobs:
        log.info('epoch {} is found in model file'.format(blobs['epoch']))
        if not reset_epoch:
            start_epoch = blobs['epoch']
        else:
            log.info('Reset epoch')
    else:
        log.info('no epoch is found in model file')
    # Fall back to the configured base LR when the checkpoint has none.
    lr = opts['model_param']['base_learning_rate']
    if 'lr' in blobs:
        lr = blobs['lr']
    if 'best_metric' in blobs and not reset_epoch:
        best_metric = blobs['best_metric']

    if model is not None:
        log.info('initialize model parameters using weights file: {}'.format(
            weights_file
        ))
        ws_blobs = workspace.Blobs()
        # OrderedDict de-duplicates unscoped names while preserving the
        # model's parameter order.
        unscoped_blob_names = OrderedDict()
        for blob in model.GetAllParams():
            unscoped_blob_names[unscope_name(str(blob))] = True
        root_xpu_id = opts['distributed']['first_xpu_id']
        device = opts['distributed']['device']
        caffe2_pb2_DEVICE =\
            caffe2_pb2.CUDA if opts['distributed']['device'] == 'gpu'\
            else caffe2_pb2.CPU
        # Feed only the master XPU (device index 0 in its name scope);
        # broadcast_parameters replicates to the remaining XPUs afterwards.
        with core.NameScope('{}_{}'.format(device, root_xpu_id)):
            with core.DeviceScope(core.DeviceOption(caffe2_pb2_DEVICE, 0)):
                for unscoped_blob_name in unscoped_blob_names.keys():
                    scoped_blob_name = scoped_name(unscoped_blob_name)
                    if unscoped_blob_name not in blobs:
                        log.info('{:s} not found'.format(unscoped_blob_name))
                        continue
                    log.info(
                        '{:s} loaded from weights file into: {:s}'.format(
                            unscoped_blob_name, scoped_blob_name
                        )
                    )
                    if scoped_blob_name in ws_blobs:
                        ws_blob = workspace.FetchBlob(scoped_blob_name)
                        # On shape mismatch, warn and skip rather than fail:
                        # the workspace blob keeps its current value.
                        if not ws_blob.shape == blobs[unscoped_blob_name].shape:
                            log.info(
                                ('Workspace blob {} with shape {} does '
                                    'not match weights file shape {}').format(
                                            unscoped_blob_name, ws_blob.shape,
                                            blobs[unscoped_blob_name].shape)
                            )
                        else:
                            workspace.FeedBlob(
                                scoped_blob_name,
                                blobs[unscoped_blob_name].astype(
                                    np.float32, copy=False))
    else:
        log.info('Skip initializing model parameters from file: {}'.format(
            weights_file
        ))
    log.info('Complete initialize_master_xpu_model_params')
    return start_epoch, lr, best_metric


def broadcast_parameters(opts, model, num_xpus, broadcast_computed_param=False):
    """Copy each parameter blob from the first XPU to every other XPU.

    Args:
        opts: options dict; reads opts['distributed']['device'] to pick
            the target device type (CUDA vs CPU).
        model: model providing GetParams()/GetComputedParams().
        num_xpus: total number of devices; must evenly divide the number
            of parameter blobs.
        broadcast_computed_param: also broadcast computed params when True.
    """
    if num_xpus == 1:
        log.info("only 1 device. Skip parameter broadcast")
        return
    param_groups = [model.GetParams()]
    if broadcast_computed_param:
        param_groups.append(model.GetComputedParams())
    device_type = (
        caffe2_pb2.CUDA if opts['distributed']['device'] == 'gpu'
        else caffe2_pb2.CPU)
    for group in param_groups:
        assert len(group) % num_xpus == 0, \
            "Current model doesn't match device number when loading checkpoint"
        per_xpu = len(group) // num_xpus
        for offset in range(per_xpu):
            # Strided slice picks the same (offset-th) parameter on every
            # XPU; the first entry is the master copy to replicate.
            replicas = list(group[offset::per_xpu])
            master_value = workspace.FetchBlob(replicas[0])
            log.info('Broadcasting {} to'.format(str(replicas[0])))
            for xpu_idx, dst_blob in enumerate(replicas[1:]):
                log.info(' |-> {}'.format(str(dst_blob)))
                dev_opt = core.DeviceOption(device_type, xpu_idx + 1)
                with core.DeviceScope(dev_opt):
                    workspace.FeedBlob(dst_blob, master_value)
    log.info("Complete parameter broadcast")


def save_model_params(is_checkpoint, model, checkpoint_path, epoch, opts,
                      best_metric=float('-inf')):
    """Best-effort save of model parameters to checkpoint_path.

    Args:
        is_checkpoint: flag passed by callers; not consulted here.
        model: model whose parameters are saved.
        checkpoint_path: destination file path, or None to skip saving.
        epoch: epoch number stored alongside the parameters.
        opts: options dict forwarded to save_model_params_blob.
        best_metric: best metric so far; defaults to -inf (restores the
            default hinted at by the previously commented-out code).

    Returns:
        checkpoint_path, or None when checkpoint_path is None.
    """
    if checkpoint_path is None:
        return None

    # Saving is best-effort: a failed checkpoint write should not kill
    # training, so log the exception and carry on.
    try:
        save_model_params_blob(
            model, checkpoint_path, epoch, opts, best_metric
        )
    except Exception as e:
        log.warning('Exception from save_model_params {}'.format(str(e)))
    return checkpoint_path


def save_model_params_blob(model, params_file, epoch, opts,
                           best_metric=float('-inf')):
    """Pickle the master XPU's parameters plus training metadata to a file.

    The output dict has a single top-level 'blobs' key holding the
    unscoped parameter arrays plus 'epoch', 'lr' and 'best_metric', the
    format expected by initialize_master_xpu_model_params.

    Args:
        model: model providing GetParams()/GetComputedParams().
        params_file: destination file path.
        epoch: epoch number to store.
        opts: options dict; reads opts['distributed']
            ('first_xpu_id', 'device').
        best_metric: best metric so far; defaults to -inf (restores the
            default hinted at by the previously commented-out code).
    """
    log.info("Saving model params...")
    root_xpu_id = opts['distributed']['first_xpu_id']
    device = opts['distributed']['device']
    save_params = [str(param) for param in
                   model.GetParams('{}_{}'.format(device, root_xpu_id))]
    save_computed_params = [str(param) for param in
                            model.GetComputedParams('{}_{}'
                            .format(device, root_xpu_id))]
    save_blobs = {}
    save_blobs['epoch'] = epoch
    save_blobs['best_metric'] = best_metric
    save_blobs['lr'] = \
        workspace.FetchBlob('{}_{}/lr'.format(device, root_xpu_id))
    for param in save_params + save_computed_params:
        scoped_blob_name = str(param)
        # Store under the unscoped name so the checkpoint is device-agnostic.
        unscoped_blob_name = unscope_name(scoped_blob_name)
        if unscoped_blob_name not in save_blobs:
            save_blobs[unscoped_blob_name] = workspace.FetchBlob(
                scoped_blob_name)
            log.debug(
                '{:s} -> {:s}'.format(scoped_blob_name, unscoped_blob_name))
    log.info('to weights file {}'.format(params_file))
    try:
        # BUGFIX: pickle output is binary — on Python 3 text mode ('w')
        # makes pickle.dump fail; open in 'wb'.
        with open(params_file, 'wb') as fwrite:
            pickle.dump(dict(blobs=save_blobs), fwrite, pickle.HIGHEST_PROTOCOL)
    except IOError as e:
        log.error('I/O error({0}): {1}'.format(e.errno, e.strerror))


def unscope_name(blob_name):
    """Strip the leading name scope(s), returning the bare blob name.

    Keeps everything after the last occurrence of the scope separator;
    returns the input unchanged when no separator is present.
    """
    last_sep_idx = blob_name.rfind(scope._NAMESCOPE_SEPARATOR)
    # rfind yields -1 when absent, so the slice below starts at 0 then.
    return blob_name[last_sep_idx + 1:]


def scoped_name(blob_name):
    """Prefix blob_name with the currently active caffe2 name scope."""
    current_scope = scope.CurrentNameScope()
    return current_scope + blob_name