File: hsm_util.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (70 lines) | stat: -rw-r--r-- 2,259 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
## @package hsm_util
# Module caffe2.python.hsm_util





from caffe2.proto import hsm_pb2

'''
    Hierarchical softmax utility methods that can be used to:
    1) build the NodeProto nodes of a tree from a list of word_ids or from
       a list of child NodeProtos
    2) create a HierarchyProto structure from a user-supplied TreeProto
'''


def create_node_with_words(words, name='node'):
    """Return a new NodeProto named ``name`` that holds the given word ids.

    Args:
        words: iterable of word ids; each one is appended, in order, to the
            node's ``word_ids`` field.
        name: value stored in the node's ``name`` field.

    Returns:
        The freshly created ``hsm_pb2.NodeProto``.
    """
    leaf = hsm_pb2.NodeProto()
    leaf.name = name
    for word_id in words:
        leaf.word_ids.append(word_id)
    return leaf


def create_node_with_nodes(nodes, name='node'):
    """Return a new NodeProto named ``name`` whose children copy ``nodes``.

    Args:
        nodes: iterable of NodeProto messages. Each one is merged into a
            newly added entry of the result's ``children`` field, so the
            result holds copies and the inputs are not modified.
        name: value stored in the node's ``name`` field.

    Returns:
        The freshly created ``hsm_pb2.NodeProto``.
    """
    parent = hsm_pb2.NodeProto()
    parent.name = name
    for child in nodes:
        # add() appends an empty child message; MergeFrom fills it in.
        parent.children.add().MergeFrom(child)
    return parent


def create_hierarchy(tree_proto):
    """Build a HierarchyProto listing the path to every word in a tree.

    The tree is walked depth-first from ``tree_proto.root_node``. Each
    visited node claims a block of consecutive indices, one per child plus
    one per word, starting at the current running index. That starting
    index is written into the node's ``offset`` field, so the nodes of
    ``tree_proto`` are modified in place.

    Args:
        tree_proto: message exposing a ``root_node`` whose nodes carry
            ``children``, ``word_ids`` and ``offset`` fields.

    Returns:
        A new ``hsm_pb2.HierarchyProto``. ``paths`` receives one PathProto
        per word; for each node, the words of its child subtrees are
        emitted before the node's own words. Every path node records the
        ``index`` (start of the node's block), the ``length`` (children
        plus words of that node) and the ``target`` (slot taken inside
        that block). ``size`` is raised to the highest running index
        reached and is left untouched when that index never exceeds it.
    """
    hierarchy = hsm_pb2.HierarchyProto()
    # Stack of [index, length, target] entries from the root down to the
    # node currently being visited.
    trail = []

    def emit_path(word_id):
        # Snapshot the current trail as the path leading to ``word_id``.
        entry = hsm_pb2.PathProto()
        entry.word_id = word_id
        for index, length, target in trail:
            step = entry.path_nodes.add()
            step.index = index
            step.length = length
            step.target = target
        hierarchy.paths.add().MergeFrom(entry)

    def visit(node, next_free):
        # Returns the running index after the whole subtree was numbered.
        node.offset = next_free
        child_count = len(node.children)
        fan_out = len(node.word_ids) + child_count
        current = [next_free, fan_out, 0]
        trail.append(current)
        next_free += fan_out
        # Only write ``size`` when it actually grows.
        if hierarchy.size < next_free:
            hierarchy.size = next_free
        # Children take the first slots of this node's block ...
        for slot, child in enumerate(node.children):
            current[2] = slot
            next_free = visit(child, next_free)
        # ... and the node's own words take the slots after them.
        for slot, word_id in enumerate(node.word_ids, start=child_count):
            current[2] = slot
            emit_path(word_id)
        trail.pop()
        return next_free

    visit(tree_proto.root_node, 0)
    return hierarchy