#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-3-Clause
# Copyright 2019, Intel Corporation

"""
latency_plot.py - tool for drawing latency benchmark plots based on the
output generated by 'bench_simul' and written to the file provided with the
'latency_file' argument.
"""

import argparse
import matplotlib.pyplot as plt


# Marker shapes, cycled through so each plotted data file gets its own shape.
MARKERS = (
    'o',
    '^',
    's',
    'D',
    'X',
)
# Index of the next marker to hand out; add_data() advances it once per file.
CUR_MARKER = 0


def _add_series(series, label, marker):
    """Plot one data series as a thin dotted line with small markers.

    :param series: sequence of y values (plotted against their index)
    :param label: legend label for the series
    :param marker: matplotlib marker style for the points
    """
    line_style = {'linestyle': ':', 'linewidth': 0.5, 'markersize': 4}
    plt.plot(series, label=label, marker=marker, **line_style)


def draw_plot(yscale='linear'):
    """Configure the axes and display every data series added so far.

    :param yscale: matplotlib Y-axis scale name, e.g. 'linear' or 'log'
    """
    tick_positions = list(range(0, 101, 5))

    plt.yscale(yscale)
    plt.xticks(tick_positions)
    plt.xlabel('percentile [%]')
    plt.grid(True)
    plt.ylabel('operation time [ns]')
    # The legend is built from the series plotted so far, so it comes last.
    plt.legend()
    plt.show()


def _parse_args(argv=None):
    """Parse command line arguments.

    :param argv: list of argument strings to parse; ``None`` (the default)
        makes argparse use ``sys.argv[1:]``
    :returns: the populated ``argparse.Namespace``

    Exits via ``parser.error`` (SystemExit) when no output file is given or
    when ``--ltrim``/``--rtrim`` is negative. A negative trim would otherwise
    turn the slices in add_data() into nonsense (e.g. ``[:-(-2)]`` keeps only
    the first two values).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('out', nargs='*', help='Create a plot for all provided'
                        ' output files')
    parser.add_argument('--yscale', '-y', help='Y-axis scale',
                        default='linear')
    parser.add_argument('--hits', help='Draw hits', dest='hits',
                        action='store_true')
    parser.add_argument('--no-hits', help='Do not draw hits', dest='hits',
                        action='store_false')
    parser.set_defaults(hits=True)

    parser.add_argument('--ltrim', help='Remove a number of smallest latency '
                        'values from the plot', default=0, type=int)
    parser.add_argument('--rtrim', help='Remove a number of biggest latency '
                        'values from the plot', default=0, type=int)

    parser.add_argument('--misses', help='Draw misses', dest='misses',
                        action='store_true')
    parser.add_argument('--no-misses', help='Do not draw misses',
                        dest='misses', action='store_false')
    parser.set_defaults(misses=True)

    args = parser.parse_args(argv)
    if not args.out:
        parser.error('at least one output need to be provided')
    if args.ltrim < 0 or args.rtrim < 0:
        parser.error('--ltrim and --rtrim must be non-negative')
    return args


def _read_out(path):
    """Read a 'latency_file' output file.

    The first line holds the ';'-separated hit values and the second line the
    ';'-separated miss values. Empty fields, such as one left by a trailing
    ';', are skipped instead of crashing ``float()``.

    :param path: path of the output file
    :returns: tuple ``(hits, misses)``, each a list of floats
    :raises ValueError: if the file has fewer than two lines or a field is
        not a number
    """
    with open(path, 'r') as f:
        lines = f.readlines()

    if len(lines) < 2:
        raise ValueError('{}: expected a hits line and a misses line, '
                         'found {} line(s)'.format(path, len(lines)))

    hits, misses = (
        [float(field) for field in line.split(';') if field.strip()]
        for line in lines[:2]
    )
    return hits, misses


def add_data(output, name, hits=True, misses=True, ltrim=0, rtrim=0):
    """Add data from a 'latency_file' output file to the plot.

    Every call advances CUR_MARKER, so series from different files get
    different markers from MARKERS; hits and misses of one file share one.

    :param output: path of the output file to read
    :param name: prefix of the legend labels ('<name>_hits', '<name>_misses')
    :param hits: whether to plot the hits series
    :param misses: whether to plot the misses series
    :param ltrim: number of leading values to leave out of the plot
    :param rtrim: number of trailing values to leave out of the plot
    :raises ValueError: if ltrim or rtrim is negative
    """
    global CUR_MARKER

    if ltrim < 0 or rtrim < 0:
        # Negative values would silently produce nonsense slices.
        raise ValueError('ltrim and rtrim must be non-negative')

    def _trim(series):
        # Drop the last rtrim values, then mask (rather than drop) the first
        # ltrim values with NaN. Series are plotted against their index, and
        # the x axis is ticked/labelled as percentile, so dropping leading
        # values would shift every remaining point left and mislabel it.
        # matplotlib does not draw NaN points.
        kept = series[:max(len(series) - rtrim, 0)]
        return [float('nan')] * min(ltrim, len(kept)) + kept[ltrim:]

    h, m = _read_out(output)
    h, m = _trim(h), _trim(m)

    marker = MARKERS[CUR_MARKER % len(MARKERS)]
    if hits:
        _add_series(h, '{}_hits'.format(name), marker)
    if misses:
        _add_series(m, '{}_misses'.format(name), marker)

    # use different marker for each plotted benchmark data
    CUR_MARKER += 1


def _main():
    """Script entry point: plot every output file named on the command line."""
    opts = _parse_args()
    for path in opts.out:
        # The file path doubles as the legend label prefix.
        add_data(path, path, hits=opts.hits, misses=opts.misses,
                 ltrim=opts.ltrim, rtrim=opts.rtrim)
    draw_plot(opts.yscale)


# Run the tool only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    _main()
