File: get_entry_from_blobs.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (83 lines) | stat: -rw-r--r-- 3,130 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################






from caffe2.python import core, schema
from caffe2.python.modeling.net_modifier import NetModifier

import numpy as np


class GetEntryFromBlobs(NetModifier):
    """
    This class modifies the net passed in by adding ops to get a certain entry
    from certain blobs.

    Args:
        blobs: list of blobs to get entry from
        logging_frequency: frequency for printing entry values to logs; values
            below 1 disable the Print op
        i1, i2: the first, second dimension of the blob. (currently, we assume
            the blobs to be 2-dimensional blobs). When i2 = -1, print all
            entries in blob[i1]
    """

    def __init__(self, blobs, logging_frequency, i1=0, i2=0):
        self._blobs = blobs
        self._logging_frequency = logging_frequency
        self._i1 = i1
        self._i2 = i2
        # Use the same "whole row" rule (i2 == -1) as modify_net, so the
        # reported field name always describes what was actually extracted.
        if i2 == -1:
            self._field_name_suffix = '_{0}_all'.format(i1)
        else:
            self._field_name_suffix = '_{0}_{1}'.format(i1, i2)

    def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
                   modify_output_record=False):
        """
        Add ops to `net` that extract entry (i1, i2) (or the whole row i1 when
        i2 == -1) from every blob in `self._blobs`.

        Args:
            net: the net to modify; each blob in `self._blobs` must already be
                defined in it.
            init_net, grad_map, blob_to_device: not used by this modifier;
                presumably accepted to match the NetModifier interface.
            modify_output_record: when True, expose each extracted value as a
                float Scalar field named `<blob><suffix>` in the net's output
                record (creating the record if the net has none).

        Raises:
            ValueError: if i1 is negative.
            AssertionError: if one of the blobs is not defined in `net`.
        """
        i1, i2 = self._i1, self._i2
        if i1 < 0:
            raise ValueError('index is out of range')

        for blob_name in self._blobs:
            blob = core.BlobReference(blob_name)
            # Explicit raise rather than `assert` so the check is not stripped
            # under `python -O`; the exception type is kept for callers.
            if not net.BlobIsDefined(blob):
                raise AssertionError(
                    'blob {} is not defined in net {} whose proto is {}'.format(
                        blob, net.Name(), net.Proto()))

            # Take row i1 (ends=-1 on the second axis presumably means "up to
            # the end" under Caffe2 Slice semantics -- confirm).
            blob_i1 = net.Slice([blob], starts=[i1, 0], ends=[i1 + 1, -1])
            # Name the output after the same suffix used for the record field.
            output_blob = net.NextScopedBlob(
                prefix=blob + self._field_name_suffix)
            if i2 == -1:
                # Whole row requested: just copy the row slice.
                blob_i1_i2 = net.Copy([blob_i1], [output_blob])
            else:
                # Single column i2 of the row slice.
                blob_i1_i2 = net.Slice([blob_i1], output_blob,
                                       starts=[0, i2], ends=[-1, i2 + 1])

            if self._logging_frequency >= 1:
                net.Print(blob_i1_i2, [], every_n=self._logging_frequency)

            if modify_output_record:
                output_field_name = str(blob) + self._field_name_suffix
                # `np.float` (removed in NumPy 1.24) was an alias of the
                # builtin `float`; use the builtin directly.
                output_scalar = schema.Scalar(float, blob_i1_i2)

                if net.output_record() is None:
                    net.set_output_record(
                        schema.Struct((output_field_name, output_scalar))
                    )
                else:
                    net.AppendOutputRecordField(output_field_name, output_scalar)

    def field_name_suffix(self):
        """Return the suffix ('_<i1>_<i2>' or '_<i1>_all') used for field names."""
        return self._field_name_suffix