File: select_record_by_context.py

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm)

import logging

from caffe2.python import schema
from caffe2.python.layers.layers import (
    InstantiationContext,
    ModelLayer,
)


logger = logging.getLogger(__name__)


class SelectRecordByContext(ModelLayer):
    """
    Allowing model to follow different paths for each instantiation context and
    join later at some point. The implementation use `Alias` because schema
    sometimes clone fields internally so we need static blob name for output
    """

    def __init__(
        self,
        model,
        input_record,
        name='select_record_by_context',
        check_field_metas=True,
        use_copy=False,
        default_output_record_field=None,
        **kwargs
    ):
        super(SelectRecordByContext, self).__init__(model, name, input_record,
                                                    **kwargs)

        assert isinstance(input_record, schema.Struct)
        assert len(input_record) > 1

        self.use_copy = use_copy
        self.default_output_record = (
            input_record[default_output_record_field]
            if (default_output_record_field is not None) else None
        )
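        # Every per-context branch must expose the same schema, so that any one
        # of them can back the single shared output record created below.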
        ref_record = input_record[0]
        for record in input_record:
            assert schema.equal_schemas(record, ref_record,
                                        check_field_metas=check_field_metas)

        self.output_schema = schema.NewRecord(model.net, ref_record)

    def _set_output_blobs(self, net, context):
        record = self.input_record.get(context, self.default_output_record)
        assert record is not None, (
            "{} context is not in input record without providing default"
            " output".format(context)
        )
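        # `Alias` shares the underlying blobs with the selected branch (no data
        # copy, output names stay fixed); `Copy` writes the data into distinct
        # output blobs instead.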
        for in_blob, out_blob in zip(
                record.field_blobs(), self.output_schema.field_blobs()
        ):
            if self.use_copy:
                net.Copy(in_blob, out_blob)
            else:
                net.Alias(in_blob, out_blob)

    def add_ops(self, net):
        self._set_output_blobs(net, InstantiationContext.PREDICTION)

    def add_eval_ops(self, net):
        self._set_output_blobs(net, InstantiationContext.EVAL)

    def add_train_ops(self, net):
        self._set_output_blobs(net, InstantiationContext.TRAINING)

    def add_ops_to_accumulate_pred(self, net):
        self._set_output_blobs(net, InstantiationContext.ACCUMULATE_PRED)
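

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): `model`,
# `prediction_record`, and `training_record` are assumed to exist elsewhere.
# The input record is a schema.Struct keyed by instantiation context, and all
# branches must share the same schema:
#
#     input_record = schema.Struct(
#         (InstantiationContext.PREDICTION, prediction_record),
#         (InstantiationContext.TRAINING, training_record),
#     )
#     output_record = model.SelectRecordByContext(
#         input_record,
#         default_output_record_field=InstantiationContext.PREDICTION,
#     )
#
# At instantiation time each net (prediction, eval, training, accumulate_pred)
# reads from its own branch, while downstream layers see a single output record
# with static blob names.
# -----------------------------------------------------------------------------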