File: AnyExpOnTerm.py

Source package: pytorch 1.13.1+dfsg-4 (Debian bookworm).

import argparse
import json
import os

import caffe2.contrib.playground.AnyExp as AnyExp
import caffe2.contrib.playground.checkpoint as checkpoint

import logging
logging.basicConfig()
log = logging.getLogger("AnyExpOnTerm")
log.setLevel(logging.DEBUG)


def runShardedTrainLoop(opts, myTrainFun):
    start_epoch = 0
    pretrained_model = opts['model_param']['pretrained_model']
    if pretrained_model != '' and os.path.exists(pretrained_model):
        # Only want to get start_epoch.
        start_epoch, prev_checkpointed_lr, best_metric = \
            checkpoint.initialize_params_from_file(
                model=None,
                weights_file=pretrained_model,
                num_xpus=1,
                opts=opts,
                broadcast_computed_param=True,
                reset_epoch=opts['model_param']['reset_epoch'],
            )
    log.info('start epoch: {}'.format(start_epoch))
    # Normalize '' to None so the per-epoch logic below can pass it straight
    # through to the shards; only the first epoch receives the pretrained model.
    pretrained_model = None if pretrained_model == '' else pretrained_model
    ret = None

    for epoch in range(start_epoch,
                       opts['epoch_iter']['num_epochs'],
                       opts['epoch_iter']['num_epochs_per_flow_schedule']):
        # checkpointing must be supported here; otherwise every flow schedule
        # after the first would restart training from the initial state
        checkpoint_model = None if epoch == start_epoch else ret['model']
        pretrained_model = None if epoch > start_epoch else pretrained_model
        shard_results = []
        # with LexicalContext('epoch{}_gang'.format(epoch),gang_schedule=False):
        for shard_id in range(opts['distributed']['num_shards']):
            opts['temp_var']['shard_id'] = shard_id
            opts['temp_var']['pretrained_model'] = pretrained_model
            opts['temp_var']['checkpoint_model'] = checkpoint_model
            opts['temp_var']['epoch'] = epoch
            opts['temp_var']['start_epoch'] = start_epoch
            shard_ret = myTrainFun(opts)
            shard_results.append(shard_ret)

        ret = None
        # take the first non-None shard result (normally shard 0)
        for shard_ret in shard_results:
            if shard_ret is not None:
                ret = shard_ret
                opts['temp_var']['metrics_output'] = ret['metrics']
                break
        log.info('ret is: {}'.format(str(ret)))

    return ret
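

# For reference, a minimal sketch of the nested opts keys that
# runShardedTrainLoop() reads above. The key names are taken from this file;
# the values are hypothetical placeholders, and a real configuration
# (produced by AnyExp.initOpts) carries many more fields.
EXAMPLE_OPTS_SKELETON = {
    'model_param': {
        'pretrained_model': '',  # checkpoint path, or '' to start from scratch
        'reset_epoch': False,    # whether to reset the epoch counter on resume
    },
    'epoch_iter': {
        'num_epochs': 90,
        'num_epochs_per_flow_schedule': 10,  # epochs per checkpointed schedule
    },
    'distributed': {
        'num_shards': 1,         # shards looped over within each schedule
    },
    'input': {
        'datasets': [],          # filled by AnyExp.aquireDatasets(opts) in __main__
    },
    'temp_var': {},              # scratch space populated inside the train loop
}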


def trainFun():
    def simpleTrainFun(opts):
        trainerClass = AnyExp.createTrainerClass(opts)
        trainerClass = AnyExp.overrideAdditionalMethods(trainerClass, opts)
        trainer = trainerClass(opts)
        return trainer.buildModelAndTrain(opts)
    return simpleTrainFun
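

# A hypothetical sketch of a replacement training function. As the comment in
# __main__ notes, trainFun() can be swapped out; runShardedTrainLoop() only
# requires that the callable accept the opts dict (with shard_id, epoch,
# checkpoint_model, etc. filled into opts['temp_var']) and return either None
# or a dict carrying at least 'model' and 'metrics'.
def customTrainFun():
    def customTrain(opts):
        shard_id = opts['temp_var']['shard_id']
        log.info('custom training for shard {}'.format(shard_id))
        # Build and train a model here; this placeholder just satisfies the
        # return contract expected by runShardedTrainLoop().
        return {'model': None, 'metrics': {}}
    return customTrain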


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Any Experiment training.')
    parser.add_argument("--parameters-json", type=json.loads,
                        help='model options in json format', dest="params")

    args = parser.parse_args()
    opts = args.params['opts']
    opts = AnyExp.initOpts(opts)
    log.info('opts is: {}'.format(str(opts)))

    AnyExp.initDefaultModuleMap()

    opts['input']['datasets'] = AnyExp.aquireDatasets(opts)

    # defined this way so that AnyExp.trainFun(opts) can be replaced with
    # some other customized training function (see the customTrainFun sketch
    # above).
    ret = runShardedTrainLoop(opts, trainFun())

    log.info('ret is: {}'.format(str(ret)))
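
# Example invocation (hypothetical values): the script takes the whole
# configuration as one JSON string, and the argument parsing above expects
# the options dict under a top-level 'opts' key, e.g.
#
#   python AnyExpOnTerm.py --parameters-json \
#       '{"opts": {"model_param": {...}, "epoch_iter": {...}, ...}}'
#
# The ellipses stand for the remaining fields; the full schema is whatever
# AnyExp.initOpts accepts.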