File: pairwise_similarity.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (96 lines) | stat: -rw-r--r-- 3,549 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
## @package pairwise_similarity
# Module caffe2.python.layers.pairwise_similarity





from caffe2.python import schema
from caffe2.python.layers.layers import (
    ModelLayer,
)


class PairwiseSimilarity(ModelLayer):
    """Layer that emits a flattened pairwise similarity between embeddings.

    The input record must be a ``schema.Struct`` that provides either an
    ``all_embeddings`` field (used for both operands) or both an
    ``x_embeddings`` and a ``y_embeddings`` field.  An optional
    ``indices_to_gather`` field selects entries from the flattened result.
    """

    def __init__(self, model, input_record, output_dim, pairwise_similarity_func='dot',
                 name='pairwise_similarity', **kwargs):
        """Validate the input record and declare the output schema.

        Args:
            model: Forwarded unchanged to ``ModelLayer.__init__``.
            input_record: ``schema.Struct`` carrying the embedding fields
                described in the class docstring.
            output_dim: Length of the one-dimensional shape declared for
                the output ``schema.Scalar``.
            pairwise_similarity_func: ``'dot'`` or ``'cosine_similarity'``.
                The value is stored here and only checked in ``add_ops``.
            name: Layer name, forwarded to ``ModelLayer.__init__``.
            **kwargs: Forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: If ``input_record`` is not a ``schema.Struct``,
                if the embedding fields do not satisfy the
                either/or requirement, or if any used field is not a
                ``schema.Scalar``.
        """
        super(PairwiseSimilarity, self).__init__(model, name, input_record, **kwargs)
        assert isinstance(input_record, schema.Struct), (
            "Incorrect input type. Expected Struct, but received: {0}".
            format(input_record))

        has_all = 'all_embeddings' in input_record
        has_pair = (
            'x_embeddings' in input_record and 'y_embeddings' in input_record
        )
        # Exactly one of the two input layouts may be supplied.
        assert has_all ^ has_pair, (
            "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
            "should be given."
        )

        self.pairwise_similarity_func = pairwise_similarity_func

        if has_all:
            # A single set of embeddings is compared against itself.
            x_embeddings = y_embeddings = input_record['all_embeddings']
        else:
            x_embeddings = input_record['x_embeddings']
            y_embeddings = input_record['y_embeddings']

        assert isinstance(x_embeddings, schema.Scalar), (
            "Incorrect input type for x. Expected Scalar, " +
            "but received: {0}".format(x_embeddings))
        assert isinstance(y_embeddings, schema.Scalar), (
            "Incorrect input type for y. Expected Scalar, " +
            "but received: {0}".format(y_embeddings)
        )

        # Optional gather indices; None marks "not provided" and is what
        # add_ops checks for.
        self.indices_to_gather = None
        if 'indices_to_gather' in input_record:
            indices_to_gather = input_record['indices_to_gather']
            assert isinstance(indices_to_gather, schema.Scalar), (
                "Incorrect type of indices_to_gather. "
                "Expected Scalar, but received: {0}".format(indices_to_gather)
            )
            self.indices_to_gather = indices_to_gather

        self.x_embeddings = x_embeddings
        self.y_embeddings = y_embeddings

        # The output reuses the base dtype of the first field type of x.
        dtype = x_embeddings.field_types()[0].base

        self.output_schema = schema.Scalar(
            (dtype, (output_dim,)),
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        """Add the similarity operators for this layer to ``net``.

        For ``'cosine_similarity'`` both operands are first passed through
        ``Normalize``; for ``'dot'`` they are used as-is.  The operands are
        multiplied with ``BatchMatMul`` (``trans_b=1``) and the product is
        flattened into the output blob.  When ``indices_to_gather`` was
        provided, ``BatchGather`` picks the requested entries from the
        flattened product instead.

        Args:
            net: Net object on which the operators are created.

        Raises:
            NotImplementedError: If ``pairwise_similarity_func`` is neither
                ``'dot'`` nor ``'cosine_similarity'``.
        """
        if self.pairwise_similarity_func == "cosine_similarity":
            # NOTE(review): axis=1 is assumed to be the embedding dimension
            # of the blobs as the operator sees them -- confirm against the
            # expected input shapes.
            lhs = net.Normalize(self.x_embeddings(), axis=1)
            rhs = net.Normalize(self.y_embeddings(), axis=1)
        elif self.pairwise_similarity_func == "dot":
            lhs = self.x_embeddings()
            rhs = self.y_embeddings()
        else:
            raise NotImplementedError(
                "pairwise_similarity_func={} is not valid".format(
                    self.pairwise_similarity_func
                )
            )

        # The intermediate blob is named after the left-hand operand.
        similarity = net.BatchMatMul(
            [lhs, rhs],
            [self.get_next_blob_reference(lhs + '_matmul')],
            trans_b=1,
        )

        # Compare against None explicitly: the field is optional and the
        # branch must not depend on the truthiness of the schema object.
        if self.indices_to_gather is not None:
            flattened = net.Flatten(similarity, similarity + '_flatten')
            net.BatchGather(
                [flattened, self.indices_to_gather()],
                self.output_schema(),
            )
        else:
            net.Flatten(similarity, self.output_schema())