File: build_index.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (70 lines) | stat: -rw-r--r-- 1,937 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70





import numpy as np

from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer


class MapToRange(ModelLayer):
    """
    Layer that incrementally builds a mapping from raw keys to indices in
    [0, max_index). The mapping keeps growing while training ops run;
    during evaluation/prediction the underlying index is frozen, and keys
    never seen before map to index 0.
    """

    def __init__(
        self, model,
        input_record,
        max_index,
        name='map_to_range',
        **kwargs
    ):
        super(MapToRange, self).__init__(model, name, input_record, **kwargs)

        assert isinstance(input_record, schema.Scalar)
        assert max_index > 0

        self.max_index = max_index

        # Blob holding the index handle. LongIndexCreate caps the table at
        # max_index entries; NoOptim because the handle itself is not a
        # learnable parameter.
        self.handler = self.create_param(
            param_name='handler',
            shape=[],
            initializer=('LongIndexCreate', {'max_elements': self.max_index}),
            optimizer=model.NoOptim
        )

        indices_field = schema.Scalar(
            np.int64, self.get_next_blob_reference("indices")
        )
        handler_field = schema.Scalar(np.void, self.handler)
        self.output_schema = schema.Struct(
            ('indices', indices_field),
            ('handler', handler_field),
        )

    def add_train_ops(self, net):
        # IndexGet wants int64 keys; cast only when the input record holds
        # some other base type.
        raw_keys = self.input_record()
        if self.input_record.field_type().base == np.int64:
            int64_keys = raw_keys
        else:
            int64_keys = net.Cast(
                raw_keys,
                net.NextScopedBlob("indices_before_mapping"),
                to=core.DataType.INT64
            )

        # Look up keys (inserting unseen ones while the index is unfrozen).
        mapped = net.IndexGet(
            [self.handler, int64_keys], self.output_schema.indices()
        )

        # The mapped indices are categorical ids — block gradients through them.
        net.StopGradient(mapped, mapped)

    def add_eval_ops(self, net):
        # Freeze the index first so evaluation never adds new keys; the
        # lookup path itself is identical to training.
        net.IndexFreeze(self.handler, self.handler)
        self.add_train_ops(net)

    def add_ops(self, net):
        # Prediction reuses the frozen-index evaluation path.
        self.add_eval_ops(net)