File: compute_histogram_for_blobs.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (92 lines) | stat: -rw-r--r-- 3,508 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92





from caffe2.python import core, schema
from caffe2.python.modeling.net_modifier import NetModifier

import numpy as np


class ComputeHistogramForBlobs(NetModifier):
    """
    This class modifies the net passed in by adding ops to compute histogram for
    certain blobs.

    For every blob, the modifier casts it to float, runs AccumulateHistogram
    on it, L1-normalizes either the per-batch or the accumulated histogram,
    and optionally prints it and/or appends it to the net's output record.

    Args:
        blobs: list of blobs to compute histogram for
        logging_frequency: print the normalized histogram every
            `logging_frequency` runs (passed as Print's `every_n`); values
            below 1 disable printing
        num_buckets: number of buckets to use in [lower_bound, upper_bound);
            must be greater than 0
        lower_bound: left boundary of histogram values
        upper_bound: right boundary of histogram values; must be strictly
            greater than lower_bound
        accumulate: if True, output the accumulated histogram; otherwise
            output the per-batch histogram
    """

    def __init__(self, blobs, logging_frequency, num_buckets=30,
                 lower_bound=0.0, upper_bound=1.0, accumulate=False):
        self._blobs = blobs
        self._logging_frequency = logging_frequency
        self._accumulate = accumulate
        # The suffix names both the normalized-histogram blob and the output
        # record field, so it records which histogram is being exported.
        if self._accumulate:
            self._field_name_suffix = '_acc_normalized_hist'
        else:
            self._field_name_suffix = '_curr_normalized_hist'

        self._num_buckets = int(num_buckets)
        assert self._num_buckets > 0, (
            "num_buckets need to be greater than 0, got {}".format(num_buckets))
        self._lower_bound = float(lower_bound)
        self._upper_bound = float(upper_bound)
        # An empty or inverted [lower_bound, upper_bound) interval cannot hold
        # any value, so the regular buckets would always stay empty.
        assert self._lower_bound < self._upper_bound, (
            "lower_bound ({}) needs to be smaller than upper_bound ({})".format(
                lower_bound, upper_bound))

    def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
                   modify_output_record=False):
        """
        Add histogram ops for every configured blob to `net`.

        Args:
            net: net to which the ops are added; every blob in `blobs` must
                already be defined in it
            init_net, grad_map, blob_to_device: accepted for NetModifier
                interface compatibility; not used by this modifier
            modify_output_record: if True, expose each normalized histogram
                as a field of the net's output record
        """
        for blob_name in self._blobs:
            blob = core.BlobReference(blob_name)
            assert net.BlobIsDefined(blob), (
                'blob {} is not defined in net {} whose proto is {}'.format(
                    blob, net.Name(), net.Proto()))

            # AccumulateHistogram is fed a float copy of the blob.
            blob_float = net.Cast(
                blob,
                net.NextScopedBlob(prefix=blob + '_float'),
                to=core.DataType.FLOAT)
            curr_hist, acc_hist = net.AccumulateHistogram(
                [blob_float],
                [net.NextScopedBlob(prefix=blob + '_curr_hist'),
                 net.NextScopedBlob(prefix=blob + '_acc_hist')],
                num_buckets=self._num_buckets,
                lower_bound=self._lower_bound,
                upper_bound=self._upper_bound)

            # Only one of the two histograms is exported. It is cast to float
            # before NormalizeL1 (presumably the raw histograms are integer
            # counts).
            selected_hist = acc_hist if self._accumulate else curr_hist
            hist = net.Cast(
                selected_hist,
                net.NextScopedBlob(prefix=blob + '_cast_hist'),
                to=core.DataType.FLOAT)

            normalized_hist = net.NormalizeL1(
                hist,
                net.NextScopedBlob(prefix=blob + self._field_name_suffix)
            )

            if self._logging_frequency >= 1:
                net.Print(normalized_hist, [], every_n=self._logging_frequency)

            if modify_output_record:
                output_field_name = str(blob) + self._field_name_suffix
                # num_buckets + 2 entries: per the AccumulateHistogram op docs,
                # the two extra buckets count values below lower_bound and
                # values at/above upper_bound.
                output_scalar = schema.Scalar(
                    (np.float32, (self._num_buckets + 2,)),
                    normalized_hist)

                if net.output_record() is None:
                    net.set_output_record(
                        schema.Struct((output_field_name, output_scalar))
                    )
                else:
                    net.AppendOutputRecordField(
                        output_field_name,
                        output_scalar)

    def field_name_suffix(self):
        """
        Return the suffix appended to each blob name to form the name of the
        normalized-histogram blob and of its output record field.
        """
        return self._field_name_suffix