File: debug_embed_params.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (54 lines) | stat: -rw-r--r-- 1,535 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import sys

# Guard the torch import before anything that needs torch. `torch.jit`,
# `torch.autograd.Variable` and `torch.set_default_tensor_type` below all
# require it, and the other project imports presumably pull it in as well.
# If any of them ran first, a missing torch would raise ImportError there and
# this graceful "skip" path would be unreachable.
try:
    import torch
except ImportError:
    print("Cannot import torch, hence caffe2-torch test will not run.")
    sys.exit(0)

import caffe2.python.onnx.backend as c2

import onnx
import pytorch_test_common
import torch.jit
from torch.autograd import Variable

torch.set_default_tensor_type("torch.FloatTensor")


def run_embed_params(proto, model, input, state_dict=None, use_gpu=True):
    """
    Run an exported ONNX model on Caffe2 with the parameters fed as inputs.

    Debug-only helper that lets the PyTorch side exercise the
    ``embed_params=False`` case as well. It should likely be removed from
    the release version of the code.

    Args:
        proto: serialized ONNX ``ModelProto`` bytes. They are validated with
            ``onnx.checker.check_model`` before being prepared.
        model: the PyTorch module. The key order of ``model.state_dict()``
            fixes the order in which parameters are fed.
        input: model input(s), flattened and fed ahead of the parameters.
        state_dict: optional parameter mapping. If truthy, only its entries
            whose keys also occur in ``model.state_dict()`` are used, in the
            model's key order. Otherwise the model's own values are used.
        use_gpu: prepare on device ``"CUDA"`` when true, ``"CPU"`` otherwise.

    Returns:
        The result of running the prepared Caffe2 backend on the feed dict.
    """
    model_def = onnx.ModelProto.FromString(proto)
    onnx.checker.check_model(model_def)
    prepared = c2.prepare(model_def, device="CUDA" if use_gpu else "CPU")

    own_state = model.state_dict()
    if not state_dict:
        parameters = list(own_state.values())
    else:
        # The caller's mapping may be ordered differently, so walk the
        # model's own key order to stay consistent with the model.
        # TODO: Even better: keyword arguments!
        parameters = [state_dict[key] for key in own_state if key in state_dict]

    def _to_numpy(value):
        # Variables are unwrapped through `.data` before the CPU/NumPy copy.
        if isinstance(value, Variable):
            value = value.data
        return value.cpu().numpy()

    # NOTE(review): values are bound to graph inputs purely by position, and
    # zip() stops at the shorter sequence. A parameter missing from
    # `state_dict` therefore shifts every later one onto the wrong input name.
    flat_values = pytorch_test_common.flatten((input, parameters))
    feed = {
        graph_input.name: _to_numpy(value)
        for graph_input, value in zip(model_def.graph.input, flat_values)
    }
    return prepared.run(inputs=feed)