File: conversion.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (88 lines) | stat: -rw-r--r-- 2,847 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
## @package onnx
# Module caffe2.python.onnx.bin.conversion





import json

from caffe2.proto import caffe2_pb2
import click
from onnx import ModelProto

from caffe2.python.onnx.backend import Caffe2Backend as c2
import caffe2.python.onnx.frontend as c2_onnx


@click.command(
    help='convert caffe2 net to onnx model',
    context_settings={
        'help_option_names': ['-h', '--help']
    }
)
@click.argument('caffe2_net', type=click.File('rb'))
@click.option('--caffe2-net-name',
              type=str,
              help="Name of the caffe2 net")
@click.option('--caffe2-init-net',
              type=click.File('rb'),
              help="Path of the caffe2 init net pb file")
@click.option('--value-info',
              type=str,
              help='A json string providing the '
              'type and shape information of the inputs')
@click.option('-o', '--output', required=True,
              type=click.File('wb'),
              help='Output path for the onnx model pb file')
def caffe2_to_onnx(caffe2_net,
                   caffe2_net_name,
                   caffe2_init_net,
                   value_info,
                   output):
    """Convert a serialized Caffe2 NetDef (plus optional init net) to ONNX.

    Args:
        caffe2_net: binary file handle holding the serialized predict NetDef.
        caffe2_net_name: optional name that overrides the one stored in the
            net; mandatory when the serialized net carries no name.
        caffe2_init_net: optional binary file handle holding the serialized
            init NetDef.
        value_info: optional JSON string with type/shape information for the
            inputs; it is decoded with ``json.loads`` before being forwarded.
        output: binary file handle the serialized ONNX model is written to.

    Raises:
        click.BadParameter: if the net has no name and ``--caffe2-net-name``
            was not given, or if ``--value-info`` is not valid JSON.
    """
    c2_net_proto = caffe2_pb2.NetDef()
    c2_net_proto.ParseFromString(caffe2_net.read())
    if not c2_net_proto.name and not caffe2_net_name:
        raise click.BadParameter(
            'The input caffe2 net does not have name, '
            '--caffe2-net-name must be provided')
    # An explicit --caffe2-net-name takes precedence over the stored name.
    c2_net_proto.name = caffe2_net_name or c2_net_proto.name

    if caffe2_init_net:
        c2_init_net_proto = caffe2_pb2.NetDef()
        c2_init_net_proto.ParseFromString(caffe2_init_net.read())
        # Derive the init-net name from the *resolved* predict-net name:
        # caffe2_net_name is None whenever the proto already had a name,
        # which would otherwise yield the bogus name 'None_init'.
        c2_init_net_proto.name = '{}_init'.format(c2_net_proto.name)
    else:
        c2_init_net_proto = None

    if value_info:
        try:
            value_info = json.loads(value_info)
        except ValueError as err:
            # json.JSONDecodeError subclasses ValueError; report it as a
            # CLI usage error instead of a raw traceback.
            raise click.BadParameter(
                'must be a valid JSON string: {}'.format(err),
                param_hint="'--value-info'") from err
    else:
        # Forward None (not an empty string) when no value info was given.
        value_info = None

    onnx_model = c2_onnx.caffe2_net_to_onnx_model(
        predict_net=c2_net_proto,
        init_net=c2_init_net_proto,
        value_info=value_info)

    output.write(onnx_model.SerializeToString())


@click.command(
    help='convert onnx model to caffe2 net',
    context_settings=dict(help_option_names=['-h', '--help']),
)
@click.argument('onnx_model', type=click.File('rb'))
@click.option('-o', '--output',
              required=True,
              type=click.File('wb'),
              help='Output path for the caffe2 net file')
@click.option('--init-net-output',
              required=True,
              type=click.File('wb'),
              help='Output path for the caffe2 init net file')
def onnx_to_caffe2(onnx_model, output, init_net_output):
    """Convert a serialized ONNX model into a pair of Caffe2 nets.

    Args:
        onnx_model: binary file handle holding a serialized ONNX ModelProto.
        output: binary file handle receiving the serialized predict net.
        init_net_output: binary file handle receiving the serialized init net.

    The init net is written first, then the predict net.
    """
    raw_bytes = onnx_model.read()
    model_proto = ModelProto()
    model_proto.ParseFromString(raw_bytes)

    # The backend hands back (init_net, predict_net); pair each with its sink.
    c2_init, c2_predict = c2.onnx_graph_to_caffe2_net(model_proto)
    for sink, net in ((init_net_output, c2_init), (output, c2_predict)):
        sink.write(net.SerializeToString())