File: compare-fastrnn-results.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (55 lines) | stat: -rw-r--r-- 2,179 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import argparse
import json
from collections import namedtuple

# One row of the comparison report: the test's display name plus the value
# read for it from the base file and from the diff file (the script fills in
# NaN for a test that is missing from one of the two files).
Result = namedtuple("Result", field_names="name base_time diff_time")

def construct_name(fwd_bwd, test_name):
    """Return the display key ``'<suite>[<test_name>]:<fwd|bwd>'``.

    ``fwd_bwd`` is a top-level key of a results file. When it contains the
    substring ``'backward'`` the entry is tagged ``bwd``; otherwise ``fwd``.
    The suite part is ``fwd_bwd`` with every ``'-backward'`` occurrence
    removed.
    """
    if 'backward' in fwd_bwd:
        direction = 'bwd'
    else:
        direction = 'fwd'
    suite = fwd_bwd.replace('-backward', '')
    return '{}[{}]:{}'.format(suite, test_name, direction)

def get_times(json_data):
    """Flatten a two-level results mapping into ``{display_name: value}``.

    ``json_data`` maps each fwd/bwd suite key to a mapping of
    ``test_name -> value``. Every (suite key, test name) pair is turned into a
    single display key via ``construct_name``. If two pairs produce the same
    key, the one visited later wins.
    """
    return {
        construct_name(suite_key, test_name): value
        for suite_key, tests in json_data.items()
        for test_name, value in tests.items()
    }

def _percent_change(base_time, diff_time):
    """Return how much *diff_time* differs from *base_time*, in percent.

    A zero base time has no meaningful ratio, so NaN is reported for that row
    instead of raising ZeroDivisionError and aborting the whole report. NaN
    inputs (a test missing from one file) already propagate to NaN.
    """
    if base_time == 0:
        return float("nan")
    return (diff_time / base_time - 1.0) * 100.0


# NOTE: the first positional argument of ArgumentParser is ``prog``, not the
# description, so the text must be passed by keyword to show up as help text.
parser = argparse.ArgumentParser(description="compare two pytest jsons")
parser.add_argument('base', help="base json file")
parser.add_argument('diff', help='diff json file')
# ``choices`` makes argparse reject an unknown format before any file is read.
parser.add_argument('--format', default='md', type=str,
                    choices=['csv', 'md', 'json', 'table'],
                    help='output format (csv, md, json, table)')
args = parser.parse_args()

with open(args.base, "r", encoding="utf-8") as base:
    base_times = get_times(json.load(base))
with open(args.diff, "r", encoding="utf-8") as diff:
    diff_times = get_times(json.load(diff))

# Union of names: a test present in only one file still gets a row, with NaN
# standing in for the missing measurement.
all_keys = set(base_times).union(diff_times)
results = [
    Result(name, base_times.get(name, float("nan")), diff_times.get(name, float("nan")))
    for name in sorted(all_keys)
]

header_fmt = {'table' : '{:48s} {:>13s} {:>15s} {:>10s}',
              'md'    : '| {:48s} | {:>13s} | {:>15s} | {:>10s} |',
              'csv'   : '{:s}, {:s}, {:s}, {:s}'}
data_fmt = {'table' : '{:48s} {:13.6f} {:15.6f} {:9.1f}%',
            'md'    : '| {:48s} | {:13.6f} | {:15.6f} | {:9.1f}% |',
            'csv'   : '{:s}, {:.6f}, {:.6f}, {:.2f}%'}

if args.format == 'json':
    # Each Result is serialized as a [name, base_time, diff_time] array;
    # missing measurements are emitted as the (non-standard) NaN token.
    print(json.dumps(results))
else:
    header_fmt_str = header_fmt[args.format]
    data_fmt_str = data_fmt[args.format]
    print(header_fmt_str.format("name", "base time (s)", "diff time (s)", "% change"))
    if args.format == 'md':
        # Markdown alignment row: name left-aligned, numeric columns right-aligned.
        print(header_fmt_str.format(":---", "---:", "---:", "---:"))
    for r in results:
        print(data_fmt_str.format(r.name, r.base_time, r.diff_time,
                                  _percent_change(r.base_time, r.diff_time)))