import numpy as np
import matplotlib

matplotlib.use('Agg')
import pylab as pl
import argparse
import os
from sklearn.metrics import auc
from matplotlib.font_manager import FontProperties


def find_bounds(scorelist, threshold=.0005):
    """Return trimmed ``(lower, upper)`` plotting bounds for a list of scored reads.

    ``scorelist`` holds ``[score, label]`` entries and is sorted in place.  On each
    side, ``ceil(threshold * len(scorelist))`` entries are skipped as outliers, and the
    scores of the first and last remaining entries are returned.

    For very short lists the trim is capped so that at least one entry remains and
    ``lower <= upper`` (previously a single entry raised IndexError and two entries
    came back inverted).

    Raises:
        ValueError: if ``scorelist`` is empty.
    """
    all_length = len(scorelist)
    if all_length == 0:
        raise ValueError('find_bounds needs at least one score')
    scorelist.sort()
    # Smallest integer >= threshold * n: the same cut the former counting loops
    # produced, clamped to [0, (n - 1) // 2] so both indices are valid and ordered.
    cut = int(np.ceil(threshold * all_length))
    cut = max(0, min(cut, (all_length - 1) // 2))
    return scorelist[cut][0], scorelist[all_length - 1 - cut][0]


def find_bounds_unassigned(scorelist, threshold=.0005):
    """Return trimmed ``(lower, upper)`` bounds, as floats, for a list of plain scores.

    Same trimming as ``find_bounds`` but for bare numeric scores without labels.
    ``scorelist`` is sorted in place; ``ceil(threshold * len(scorelist))`` values are
    skipped on each side.  The trim is capped so at least one value remains and
    ``lower <= upper`` (previously one value raised IndexError, two came back inverted).

    Raises:
        ValueError: if ``scorelist`` is empty.
    """
    all_length = len(scorelist)
    if all_length == 0:
        raise ValueError('find_bounds_unassigned needs at least one score')
    scorelist.sort()
    # Same cut as the former counting loops, clamped so both indices stay valid.
    cut = int(np.ceil(threshold * all_length))
    cut = max(0, min(cut, (all_length - 1) // 2))
    return float(scorelist[cut]), float(scorelist[all_length - 1 - cut])


def hist_of_scores(species_score, path, name1, name2, k, amount, SE_String):
    """Plot the score histogram of each simulated species and save it as a PNG.

    Args:
        species_score: list of ``[score, label]`` entries with label ``'1'`` or
            ``'2'``; sorted in place by ``find_bounds``.
        path: output directory.
        name1, name2: legend labels for species ``'1'`` and ``'2'``.
        k, amount, SE_String: only used to build the output file name.

    Returns:
        ``(species1, species2, bounds)``: the scores of each species (in the order
        of the sorted ``species_score``) and the ``(lower, upper)`` plot bounds.
    """
    nbins = 101
    bounds = find_bounds(species_score)
    lowerbound, upperbound = bounds[0], bounds[1]
    stepsize = (upperbound - lowerbound) / float(nbins)

    species1 = [entry[0] for entry in species_score if entry[1] == '1']
    species2 = [entry[0] for entry in species_score if entry[1] == '2']

    def count_in_bins(scores):
        # Histogram of the scores strictly inside the bounds; min() guards against
        # float rounding pushing a score just below the upper bound into bin nbins.
        counts = [0] * nbins
        for score in scores:
            if lowerbound < score < upperbound:
                counts[min(int((score - lowerbound) / stepsize), nbins - 1)] += 1
        return counts

    bins2 = count_in_bins(species1)
    bins3 = count_in_bins(species2)
    # One x position per bin (its left edge).  The former
    # np.arange(lower, upper + upper / 1000., range / 100.) produced fewer than 101
    # points when upper < 0 (pl.plot then fails on the 101 bins), raised when all
    # scores were equal, and was spaced range/100 while the bins are range/101 wide.
    ticks = lowerbound + stepsize * np.arange(nbins)
    pl.plot(ticks, bins2, label=name1, color='blue')
    pl.plot(ticks, bins3, label=name2, color='red')
    pl.legend(loc='upper right', numpoints=1)
    pl.savefig(os.path.join(path, 'score_histogram_k' + str(k) + '_' + str(amount) + SE_String + '.png'))
    return species1, species2, bounds


def fasta_generator(infile):
    """Yield ``[header, sequence]`` lists from a two-lines-per-record FASTA stream.

    Lines are consumed strictly in pairs and right-stripped.  A trailing header
    with no following sequence line is never yielded.
    """
    pending_header = None
    for raw_line in infile:
        cleaned = raw_line.rstrip()
        if pending_header is None:
            # First line of a pair: remember it until its sequence arrives.
            pending_header = cleaned
        else:
            yield [pending_header, cleaned]
            pending_header = None


def fastq_generator(infile):
    """Yield ``[name, sequence, '+', quality]`` lists from a (possibly multi-line) FASTQ stream.

    Sequence lines are joined until a line starting with ``'+'``.  Then as many
    quality lines are read as there were sequence lines (at least one), which
    keeps quality lines that start with ``'@'`` from being taken for headers.
    All lines are right-stripped, and the ``'+'`` line is emitted as a bare ``'+'``.
    """
    record = []
    seq_parts = []
    qual_parts = []
    seq_line_count = 0
    phase = 'name'

    for raw in infile:
        if phase == 'name':
            record.append(raw.rstrip())
            phase = 'read'
            continue

        if phase == 'read':
            if raw[0] != '+':
                seq_line_count += 1
                seq_parts.append(raw.rstrip())
                continue
            record.extend([''.join(seq_parts), '+'])
            seq_parts = []
            phase = 'qual'
            continue

        # phase == 'qual': one quality line per sequence line.
        qual_parts.append(raw.rstrip())
        if seq_line_count > 1:
            seq_line_count -= 1
            continue
        seq_line_count = 0
        record.append(''.join(qual_parts))
        qual_parts = []
        phase = 'name'
        yield record
        record = []


def _hist_read_scores(infile, filetype):
    """Return float(second-to-last '_' header field) - float(last field) for every read in *infile*."""
    generator = fastq_generator if filetype == 'fastq' else fasta_generator
    scores = []
    with open(infile) as handle:
        for block in generator(handle):
            fields = block[0].split('_')
            scores.append(float(fields[-2]) - float(fields[-1]))
    return scores


def _hist_bin(values, lowerbound, upperbound, stepsize, nbins):
    """Count *values* strictly inside (lowerbound, upperbound) into *nbins* bins of width *stepsize*."""
    counts = [0] * nbins
    for value in values:
        if lowerbound < value < upperbound:
            counts[min(int((value - lowerbound) / stepsize), nbins - 1)] += 1
    return counts


def _hist_normalize(counts):
    """Scale *counts* to sum to 1; an all-zero histogram stays all zeros instead of dividing by zero."""
    total = float(sum(counts))
    if total == 0:
        return [0.0] * len(counts)
    return [count / total for count in counts]


def hist_distribution_plot(infile, path, species_hist_scores, name1, name2, k, amount, SE_String, filetype):
    """Fit the real-read score histogram as a weighted mix of the two simulated ones and plot it.

    Args:
        infile: real reads, parsed with ``fastq_generator`` if *filetype* is
            ``'fastq'``, otherwise with ``fasta_generator``.
        path: output directory.
        species_hist_scores: ``(species1_scores, species2_scores, (lower, upper))``
            as returned by ``hist_of_scores``.
        name1, name2: species names used in the legend and title.
        k, amount, SE_String: only used to build the output file name.

    The fitted weights x and y (rescaled so x + y = 1) are shown in the title as
    percentages.  The figure is saved as a PNG in *path*.
    """
    nbins = 101
    real_scores = _hist_read_scores(infile, filetype)
    bounds = find_bounds_unassigned(real_scores)
    # The x range covers both the simulated bounds and the trimmed real-data bounds.
    lowerbound = min(species_hist_scores[2][0], bounds[0])
    upperbound = max(species_hist_scores[2][1], bounds[1])
    stepsize = (upperbound - lowerbound) / 100.

    bins = _hist_normalize(_hist_bin(real_scores, lowerbound, upperbound, stepsize, nbins))
    # Bin every simulated score of each species.  The former loop stopped at the
    # length of the shorter list; since hist_of_scores returns each list in sorted
    # order, this silently dropped the highest scores of the larger species.
    bins2 = _hist_normalize(_hist_bin(species_hist_scores[0], lowerbound, upperbound, stepsize, nbins))
    bins3 = _hist_normalize(_hist_bin(species_hist_scores[1], lowerbound, upperbound, stepsize, nbins))

    # Least-squares weights with x * bins2 + y * bins3 ~= bins, rescaled to sum to 1.
    x, y = np.linalg.lstsq(np.column_stack((bins2, bins3)), bins, rcond=None)[0]
    xplusy = x + y
    x = x / xplusy
    y = y / xplusy

    # +1 keeps every value positive on the log-scaled y axis.  The product of the
    # shifted curves is their sum in log space ("sum of simulated data").
    bins = [value + 1 for value in bins]
    bins2 = [max(0, value * x) + 1 for value in bins2]
    bins3 = [max(0, value * y) + 1 for value in bins3]
    sum_of_hist_scores = [first * second for first, second in zip(bins2, bins3)]

    # Left edge of each bin; always nbins points (np.arange with an
    # upper + upper/1000 stop produced a short array when upper < 0).
    ticks = np.linspace(lowerbound, upperbound, nbins)
    fig = pl.figure()
    fig.set_size_inches(12, 9.5)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(ticks, bins, label='real data', color='purple', linewidth=3, alpha=1)
    ax.plot(ticks, bins2, label='simulated data of ' + name1, color='blue', linewidth=3, alpha=.25)
    # offset_position was deprecated and later removed from Matplotlib collections;
    # it had no effect here (no offsets) and now makes fill_between raise.
    ax.fill_between(ticks, bins2, 1, color='blue', alpha=0.25)
    ax.plot(ticks, bins3, label='simulated data of ' + name2, color='red', linewidth=3, alpha=.25)
    ax.fill_between(ticks, bins3, 1, color='red', alpha=0.25)
    ax.plot(ticks, sum_of_hist_scores, label='sum of simulated data', color='green', linewidth=3, alpha=1)
    ax.set_yscale('log')
    ax.axes.get_yaxis().set_visible(False)
    fontp = FontProperties()
    fontp.set_size(27)
    pl.title(name1 + ' (' + str(round(x * 100, 2)) + '%) and ' + name2 + ' (' + str(round(y * 100, 2)) + '%) data',
             fontsize=40.)
    pl.ylabel('Frequency', size=40.)
    ax.set_yticks([0, 0.5, 1])
    pl.xlabel('Read score', size=40.)
    # Bottom at 1 (the +1 baseline); a bottom of 0 is ignored on a log axis.
    peak = max(max(bins), max(bins2), max(bins3))
    pl.ylim(1, 1.02 * (peak + peak / 500.))  # normalize height
    leg = pl.legend(loc='upper right', numpoints=1, prop=fontp)
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label]):
        item.set_fontsize(27)
    for item in (ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(20)
    # Hand-placed relative y labels, because the real y axis is hidden.
    ax.text(-.01, .005, '0', horizontalalignment='right', verticalalignment='center', rotation='horizontal',
            transform=ax.transAxes, size=20)
    ax.text(-.01, .475, '.5', horizontalalignment='right', verticalalignment='center', rotation='horizontal',
            transform=ax.transAxes, size=20)
    ax.text(-.01, .95, '1', horizontalalignment='right', verticalalignment='center', rotation='horizontal',
            transform=ax.transAxes, size=20)
    ax.text(-.04, .5, 'Frequency [rel. units]', horizontalalignment='right', verticalalignment='center',
            rotation='vertical', transform=ax.transAxes, size=27)

    leg.get_frame().set_alpha(0.5)
    pl.savefig(os.path.join(path, 'fitted_histograms_k' + str(k) + '_' + str(amount) + SE_String + '.png'))


def parse_fasta(infile, name1, name2):
    """Read labelled, scored reads from a two-lines-per-record FASTA file.

    A header such as ``>SPECIES..._A_B`` gives the score ``float(A) - float(B)``
    (the last two ``'_'``-separated fields).  The read is labelled ``'1'`` if the
    text after ``'>'`` starts with *name1*, or ``'2'`` if it starts with *name2*.
    Reading stops at end of file or at the first header line that does not start
    with ``'>'``.  An empty file yields an empty list (previously IndexError).

    Returns:
        list of ``[score, label]``.

    Raises:
        ValueError: (an ``Exception`` subclass) if a header starts with neither name.
    """
    species_by_cutoff = []
    with open(infile) as handle:
        while True:
            header = handle.readline()
            handle.readline()  # the sequence line itself is not needed
            if not header.startswith('>'):
                break
            fields = [part.rstrip() for part in header.rsplit('_', 2)]
            if fields[0].startswith(name1, 1):
                label = '1'
            elif fields[0].startswith(name2, 1):
                label = '2'
            else:
                raise ValueError(
                    'There seems to be a problem with the species names. Try to avoid special characters or use the default values')
            species_by_cutoff.append([float(fields[-2]) - float(fields[-1]), label])
    return species_by_cutoff


def parse_fastq(infile, name1, name2):
    """Read labelled, scored reads from a four-lines-per-record FASTQ file.

    The score of a header ``@...._A_B`` is ``float(A) - float(B)`` (last two
    ``'_'``-separated fields).  The species is decided by comparing only the
    character right after ``'@'`` with the first character of *name1* (label
    ``'1'``) or *name2* (label ``'2'``).  Reads matching neither are skipped.
    Reading stops at end of file, at a blank line, or at a header that does not
    start with ``'@'``.

    Returns:
        list of ``[score, label]``.
    """
    species_by_cutoff = []
    with open(infile) as handle:
        while True:
            header = handle.readline()
            for _ in range(3):  # sequence, '+' separator and quality lines
                handle.readline()
            # Stop before parsing: the former 'ending now' sentinel was itself
            # matched against the species initials, so a name starting with 'n'
            # crashed with IndexError at end of file.
            if len(header) <= 1 or not header.startswith('@'):
                break
            fields = header.split('_')
            initial = fields[0][1:2]  # slice: an '@_...' header no longer raises
            if initial == name1[0]:
                label = '1'
            elif initial == name2[0]:
                label = '2'
            else:
                continue
            species_by_cutoff.append([float(fields[-2]) - float(fields[-1]), label])
    return species_by_cutoff


def roc_plot(species_by_cutoff, k, path, amount, SE_String):
    """Draw the ROC curve of the simulated reads and save it as a PNG.

    *species_by_cutoff* holds ``[score, label]`` entries.  They are ranked by
    ascending score; the caller's list is not modified.  Walking that ranking,
    each label-``'2'`` read advances the true-positive rate and each
    label-``'1'`` read advances the false-positive rate.  The AUC (``sklearn.metrics.auc``)
    is shown in the legend.  *k* appears in the title; *k*, *amount* and
    *SE_String* build the output file name.
    """
    fontp = FontProperties()
    fontp.set_size(27)
    labels = [entry[1] for entry in sorted(species_by_cutoff)]
    # Normalize each rate by its own class size so both reach 1 even when the
    # species are unequally represented (the former len/2 only fit equal counts).
    # max(..., 1) avoids dividing by zero when a class is absent.
    positives = max(labels.count('2'), 1)
    negatives = max(labels.count('1'), 1)
    tpr = [0]
    fpr = [0]
    true_hits = 0
    false_hits = 0
    for label in labels:
        if label == '2':
            true_hits += 1
        elif label == '1':
            false_hits += 1
        tpr.append(true_hits / float(positives))
        fpr.append(false_hits / float(negatives))
    roc_auc = auc(fpr, tpr)
    # Plot ROC curve
    pl.clf()
    fig = pl.figure()
    fig.set_size_inches(12, 9.5)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc, color='green', zorder=10, lw=4, alpha=.6)
    ax.plot([0.0, 1], [0, 1], 'k--')
    ax.fill_between(fpr, tpr, color=(44 / 255., 160 / 255., 44 / 255.), alpha=.6)
    pl.xlabel('False Positive Rate', fontsize=27)
    pl.ylabel('True Positive Rate', fontsize=27)
    pl.title('ROC-Curve of ' + str(k) + '-mer based read distinction', fontsize=27)
    pl.legend(loc='lower right', prop=fontp)
    for item in (ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(20)
    pl.savefig(os.path.join(path, 'ROCplot_k' + str(k) + '_' + str(amount) + SE_String + '.png'))


def main(k, path, name1, name2, infile_simulated, infile_real, amount, SE, filetype='fastq'):
    """Produce the score histogram, the ROC plot and the fitted distribution plot.

    The simulated reads in *infile_simulated* are always parsed with
    ``parse_fasta``; *filetype* only selects the reader for *infile_real*.
    A truthy *SE* gives output names the ``_SE`` suffix, otherwise ``_PE``.
    All figures are closed after each plot.
    """
    labelled_scores = parse_fasta(infile_simulated, name1, name2)
    suffix = '_SE' if SE else '_PE'

    def run_step(message, plot, *plot_args):
        # Announce the step, draw it, then release every open figure.
        print(message)
        outcome = plot(*plot_args)
        pl.close('all')
        return outcome

    per_species = run_step('Plotting histogram of scores', hist_of_scores,
                           labelled_scores, path, name1, name2, k, amount, suffix)
    run_step('Plotting ROC-plot', roc_plot,
             labelled_scores, k, path, amount, suffix)
    run_step('Plotting distribution heights', hist_distribution_plot,
             infile_real, path, per_species, name1, name2, k, amount, suffix, filetype)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-k', '--kmer', required=True)
    parser.add_argument('-p', '--path', required=True)
    parser.add_argument('-a', '--amount', required=False, default=50000)
    parser.add_argument('-n', '--names', required=True, action='append',
                        help='give 2 species names (-n SPECIES1 -n SPECIES2)')
    parser.add_argument('-t', '--filetype', default='fastq',
                        help='specify whether your input is fastq or fasta-type. Default = fastq. Type -t fasta to change to fasta',
                        required=False)
    # Every option string must start with '-': argparse rejects a bare 'simulated'
    # or 'real' next to '-s'/'-r' with ValueError, so the script crashed here.
    parser.add_argument('-s', '--simulated', required=True)
    parser.add_argument('-r', '--real', required=True)
    # Single-end ('_SE') stays the default suffix, matching what the script produced before.
    parser.add_argument('--paired-end', action='store_true',
                        help='label output files as paired-end (_PE) instead of single-end (_SE)')
    args = parser.parse_args()
    if len(args.names) < 2:
        parser.error('give 2 species names (-n SPECIES1 -n SPECIES2)')
    path = args.path
    amount = int(args.amount)
    filetype = args.filetype
    name1 = args.names[0]
    name2 = args.names[1]
    k = args.kmer
    infile_simulated = str(args.simulated)
    infile_real = str(args.real)
    # Pass SE and filetype explicitly: filetype used to land in the SE slot, so
    # -t was ignored (always 'fastq') and the suffix was always '_SE'.
    main(k, path, name1, name2, infile_simulated, infile_real, amount,
         not args.paired_end, filetype=filetype)

