File: render-tree.py

package info (click to toggle)
mpich 4.3.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 101,184 kB
  • sloc: ansic: 1,040,629; cpp: 82,270; javascript: 40,763; perl: 27,933; python: 16,041; sh: 14,676; xml: 14,418; f90: 12,916; makefile: 9,270; fortran: 8,046; java: 4,635; asm: 324; ruby: 103; awk: 27; lisp: 19; php: 8; sed: 4
file content (117 lines) | stat: -rw-r--r-- 4,146 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
##
## Copyright (C) by Argonne National Laboratory
##     See COPYRIGHT in top-level directory
##

# This Python script renders the files dumped by MPIR_CVAR_TREE_DUMP.
# Input: tree files dumped by MPIR_CVAR_TREE_DUMP.
# Output: Graph of the tree. Output filename is defined by '--output'.
# Example: python3 ./render-tree.py tree* --format png --output tree

import json
import sys
import argparse

try:
    from graphviz import Digraph
except ImportError as e:
    sys.exit('Unable to import graphviz module: {}.\n'.format(e) +
             'Please install the graphviz module by running the following command:\n'
             '\n'
             '      python -m pip install graphviz\n')


def main():
    """Load the dumped tree nodes, group them by depth and render the graph.

    Exits with status 1 and an error message when find_root() reports that
    no rank is free of a parent (it returns -1), since there is then no
    starting point for the traversal.
    """
    args = parse_args()
    input_tree = load_tree(args.node_files)
    root = find_root(input_tree)
    if root == -1:
        # Continuing would make bfs() start from input_tree[-1] (the last
        # rank) and render a misleading tree, so stop here instead.
        sys.exit('cannot find root')
    level_array = bfs(input_tree, root)
    render_tree(input_tree, level_array, args.format, args.output)

# Suppose input_data is:
# [{'rank': 0, 'nranks': 4, 'parent': -1, 'children': [2, 1]},
#  {'rank': 1, 'nranks': 4, 'parent': 0, 'children': []},
#  {'rank': 2, 'nranks': 4, 'parent': 0, 'children': [3]},
#  {'rank': 3, 'nranks': 4, 'parent': 2, 'children': []}]
# The input_tree stores the input_data as a 2D array indexed by rank:
# [[2, 1], [], [3], []]
def load_tree(node_files):
    """Read JSON node files and build a rank -> children adjacency list.

    Args:
        node_files: paths of JSON files, each holding one node object with
            at least the keys 'rank', 'nranks' and 'children'.

    Returns:
        A list with one entry per rank; entry r lists r's children in the
        order they appear in r's file. Ranks without a file get an empty
        list. An empty ``node_files`` yields an empty list.
    """
    input_data = []
    for filename in node_files:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale default.
        with open(filename, encoding='utf-8') as the_file:
            input_data.append(json.load(the_file))
    # Size the table from the largest 'nranks' seen rather than only the first
    # file's, so one inconsistent file cannot cause an IndexError later.
    nranks = max((node['nranks'] for node in input_data), default=0)
    input_tree = [[] for _ in range(nranks)]
    for node in input_data:
        input_tree[node['rank']].extend(node['children'])
    print(input_tree)
    return input_tree

# Root lookup for the rank -> children adjacency list.
def find_root(input_tree):
    """Return the lowest rank that never appears as a child, or -1 if none.

    Args:
        input_tree: adjacency list; entry r lists the children of rank r.
    """
    is_child = [False] * len(input_tree)
    for children in input_tree:
        for child in children:
            is_child[child] = True
    return next((rank for rank, flagged in enumerate(is_child) if not flagged), -1)

# Perform a breadth-first search
# The level_array would be: [[0], [2, 1], [3]]
# Level 0 has rank 0, level 1 has ranks 2 and 1, and level 2 has rank 3
def bfs(input_tree, root):
    """Group the ranks reachable from ``root`` by depth.

    Args:
        input_tree: adjacency list; entry r lists the children of rank r.
        root: rank to start the traversal from.

    Returns:
        A list of levels: level d holds the ranks at depth d, in visiting
        order (children keep the order they have in input_tree). Each rank
        appears at most once, so malformed input containing a cycle cannot
        make the traversal loop forever.
    """
    level_array = []
    seen = {root}
    frontier = [root]
    # Build each level from the previous one; this avoids the O(n) cost of
    # list.pop(0) on a single shared queue.
    while frontier:
        level_array.append(frontier)
        next_frontier = []
        for cur_node in frontier:
            for child in input_tree[cur_node]:
                if child not in seen:
                    seen.add(child)
                    next_frontier.append(child)
        frontier = next_frontier
    print(level_array)
    return level_array


def render_tree(input_tree, level_array, format, output, view=True):
    """Draw the tree with graphviz and render it to a file.

    Args:
        input_tree: adjacency list; entry r lists the children of rank r.
        level_array: ranks grouped by depth, as returned by bfs().
        format: graphviz output format (e.g. 'svg', 'png').
        output: output file name, passed as the Digraph's name.
        view: open the rendered file in a viewer after rendering. Defaults
            to True (the previous hard-wired behavior); pass False for
            headless or batch use.
    """
    dot = Digraph(output)
    for level in level_array:
        with dot.subgraph() as s:
            # Keep all ranks of one depth on the same row.
            s.attr(rank='same')
            # NOTE(review): rankdir is documented as a root-graph attribute in
            # Graphviz, so this likely has no effect; kept so the DOT output
            # is unchanged.
            s.attr(rankdir='LR')
            for cur_node in level:
                s.node(str(cur_node))
            # Chain neighbouring ranks with invisible edges so the children
            # keep their order left to right. These are added to `dot` while
            # the subgraph is still open, as before, so the DOT statement
            # order is unchanged.
            for left, right in zip(level, level[1:]):
                dot.edge(str(left), str(right), style='invis')
    for parent, children in enumerate(input_tree):
        for child in children:
            dot.edge(str(parent), str(child))
    dot.render(format=format, view=view)


def parse_args():
    """Parse the command line.

    Returns:
        An argparse.Namespace with ``node_files`` (one or more paths),
        ``format`` (default 'svg') and ``output`` (default 'tree').
    """
    parser = argparse.ArgumentParser(
        description='Render a topology-aware collective tree as a graphical tree.')
    add = parser.add_argument
    add('node_files', metavar='FILE', nargs='+',
        help='File(s) containing the JSON-formatted tree nodes to be rendered (e.g. tree-node-*.json).')
    add('--format', '-f', default='svg',
        help='Output format. Can be any format supported by graphviz. Default: svg.')
    add('--output', '-o', default='tree',
        help='Output file name. Default: tree.')
    return parser.parse_args()


# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()