#!/usr/bin/python 

import sys
import math
import numpy

def nCk(n,k):
    """Return the binomial coefficient "n choose k" as an exact integer.

    Floor division keeps the result an exact int on both Python 2 and 3
    (true division would give a float on Python 3 and lose precision for
    large n). Each intermediate quotient is itself an integer, so nothing
    is truncated. Raises ValueError (from math.factorial) when k > n or an
    argument is negative.
    """
    return math.factorial(n) // math.factorial(k) // math.factorial(n - k)


usage = "usage: doMS.py [ --priorfile filename | --flatmodelprior | --flatphenoprior ] [ --mode 1|2|3|4 ] input.select.multinom \n"

# Full help text shown for -h/--help or when no arguments are given: the
# one-line usage followed by a description of every option.
usageArgs = usage + "".join([
    "\n",
    "Arguments to set the analysis mode:  \n",
    "  --mode 1    - Bayes factor normalization (outputs a logBF for each model relative to the null model)\n",
    "  --mode 2    - Posterior normalization (outputs a posterior probability for each model) [requires a prior]\n",
    "  --mode 3    - Phenotype specificity summary (outputs the best phenotype-specific and shared models, with a Bayes factor in favour of the former)\n",
    "  --mode 4    - Credible set summary (outputs the best model, the posterior for this model, and a 95% credible set of models)\n",
    "\n",
    "Arguments to set the prior:  \n",
    "  --priorfile filename    - specifies a model prior file (two columns, first column is the model name and second column is the prior probability for that model) \n",
    "  --flatmodelprior        - use a default uniform prior on models \n",
    "  --flatphenoprior        - use a default uniform prior on phenotype counts \n",
    "\n",
])

### process the arguments

# c walks sys.argv over the options; once the loop ends it is the index of
# the input file (the final argument).
c = 1

priorfile = ""
priormode = 0   # 0 = none given, 1 = --flatmodelprior, 2 = --flatphenoprior
mode = 0        # 0 = not given (defaulted to mode 1 later)

helpFlags = ("-h", "--h", "--help")

# Check for help anywhere on the command line. The option loop below never
# looks at the last argument, so a lone "-h" would otherwise be treated as
# the input filename.
if len(sys.argv) == 1 or any(arg in helpFlags for arg in sys.argv[1:]):
    sys.stderr.write(usageArgs)
    sys.exit(1)

# Every argument except the last is an option; options taking a value read
# sys.argv[c + 1], which always exists while c < len(sys.argv) - 1.
while c < len(sys.argv) - 1:
    arg = sys.argv[c]
    if arg == "--priorfile":
        priorfile = sys.argv[c + 1]
        c += 2
    elif arg == "--flatmodelprior":
        priormode = 1
        c += 1
    elif arg == "--flatphenoprior":
        priormode = 2
        c += 1
    elif arg == "--mode":
        try:
            mode = int(sys.argv[c + 1])
        except ValueError:
            sys.stderr.write("Error: --mode expects an integer, got " + sys.argv[c + 1] + "\n")
            sys.stderr.write(usage)
            sys.exit(1)
        c += 2
    else:
        sys.stderr.write("Error: argument " + arg + " not recognised\n")
        sys.stderr.write(usage)
        sys.exit(1)

# A value-taking option consumed the final argument, so no input file remains.
if c >= len(sys.argv):
    sys.stderr.write("Error: too few arguments\n")
    sys.stderr.write(usage)
    sys.exit(1)

### validate the analysis mode and reconcile it with the prior options

# Message written to stderr for each recognised analysis mode.
modeMessages = {
    1: "Analysis mode 1 (Bayes factor normalization)\n",
    2: "Analysis mode 2 (posterior normalization)\n",
    3: "Analysis mode 3 (Phenotype specificity report)\n",
    4: "Analysis mode 4 (Credible set report)\n",
}

if mode == 0:
    sys.stderr.write("No analysis mode specified, defaulting to mode 1 (Bayes factor normalization)\n")
    mode = 1
elif mode in modeMessages:
    sys.stderr.write(modeMessages[mode])
else:
    # mode is an int: it must be converted before concatenation, otherwise
    # this error path itself raises TypeError.
    sys.stderr.write("Error: analysis mode " + str(mode) + " not recognised\n")
    sys.exit(1)

# Modes 1 and 3 ignore priors; modes 2 and 4 need one and fall back to a
# flat prior over models.
if (mode == 1 or mode == 3) and (priorfile != "" or priormode != 0):
    sys.stderr.write("Warning: You have specified a prior on models, but this analysis mode does not use it.\n")
if (mode == 2 or mode == 4) and (priorfile == "" and priormode == 0):
    sys.stderr.write("Warning: You have not specified a prior on models, but this analysis mode requires one. Defaulting to a flat prior on models\n")
    priormode = 1

### load an explicit model prior, if one was given

if priorfile != "":
    if priormode != 0:
        sys.stderr.write("Error: You have specified a prior mode and a prior file. Please select one or the other.\n")
        sys.exit(1)
    sys.stderr.write("Reading prespecified prior from file " + priorfile + ".\n")
    # Maps model name (first column) -> prior probability (second column).
    priordict = {}
    # 'with' guarantees the file is closed once it has been read.
    with open(priorfile) as PRIORFILE:
        for lineNumber, line_raw in enumerate(PRIORFILE, 1):
            line = line_raw.split()
            if not line:
                continue  # ignore blank lines (e.g. a trailing newline)
            try:
                priordict[line[0]] = float(line[1])
            except (IndexError, ValueError):
                sys.stderr.write("Error: could not parse line " + str(lineNumber) + " of prior file " + priorfile + " (expected: model_name prior_probability)\n")
                sys.exit(1)
    if "NULL" not in priordict:
        sys.stderr.write("Warning: No prior specified for the NULL model. Setting it to zero.\n")
        priordict["NULL"] = 0.0

if priormode == 1:
    sys.stderr.write("Using default prior (uniform across models)\n")
if priormode == 2:
    sys.stderr.write("Using default prior (uniform across phenotype counts)\n")


### read the input header and derive per-model phenotype sets

# Column layout: columns 0-4 are carried through to the output, column 5 is
# the NULL model, and each later column is a model whose name, split on "_",
# lists the phenotypes it involves after the first token.
# FILE stays open: the data rows are read further down.
FILE = open(sys.argv[c])

header = FILE.readline().split()

if priorfile != "":
    # priors[i] is the prior for model header[5 + i]; entry 0 is NULL.
    priors = [priordict["NULL"]]

phenos = set()
ModelSize = [0]  # ModelSize[i] = number of phenotypes in model header[5 + i]
for h in header[6:]:
    if priorfile != "":
        if h not in priordict:
            sys.stderr.write("Error: No prior specified for model " + h + "\n")
            sys.exit(1)
        priors.append(priordict[h])
    temp = set(h.split("_")[1:])
    phenos |= temp
    ModelSize.append(len(temp))

Npheno = len(phenos)
Nmodel = len(header) - 5  # the NULL model plus one per later column
# Every subset of the phenotypes (including the empty NULL model) must have
# exactly one column.
if Nmodel != 2**Npheno:
    sys.stderr.write("Error: Wrong number of unique phenotypes in header relative to columns. Input data misformatted. "
                     + "Found " + str(Npheno) + " phenotypes, expecting " + str(2**Npheno)
                     + " model columns (including NULL) but the header has " + str(Nmodel) + ".\n")
    sys.exit(1)

### built-in priors (a --priorfile prior was already loaded above)

# priors[i] is the prior for model header[5 + i]; entry 0 is the NULL model.
if priormode == 1:
    # Flat over models: each of the Nmodel models gets the same mass.
    priors = [1.0 / Nmodel] * len(ModelSize)
elif priormode == 2:
    # Flat over phenotype counts: each count k = 0..Npheno gets mass
    # 1/(Npheno + 1), shared equally by the nCk(Npheno, k) models of that
    # size. (k = 0 is the NULL model, for which nCk is 1.)
    priors = [1.0 / ((Npheno + 1) * nCk(Npheno, k)) for k in ModelSize]
### write the output table (tab-separated throughout)

# Header row. Modes 1 and 2 report one value per model, so they echo the
# input header. Modes 3 and 4 keep the five leading columns and append
# their summary columns.
if mode == 1 or mode == 2:
    sys.stdout.write("\t".join(header) + "\n")
elif mode == 3:
    sys.stdout.write("\t".join(header[0:5] + ["BestSpecModel", "BestSharedModel", "LogSpecificityBF"]) + "\n")
elif mode == 4:
    sys.stdout.write("\t".join(header[0:5] + ["BestModel", "BestModelProb", "CredibleSet"]) + "\n")

if mode == 2 or mode == 4:
    # Log priors, computed once. A zero prior (such as the 0.0 default for
    # NULL) becomes -inf, i.e. zero posterior weight. Silence the
    # divide-by-zero warning numpy.log raises for it.
    with numpy.errstate(divide="ignore"):
        logPriors = numpy.log(numpy.array(priors, dtype=float))

# Indices into BFs/ModelSize of phenotype-specific models (exactly one
# phenotype) and shared models (two or more); used by mode 3.
specificModels = [i for i in range(Nmodel) if ModelSize[i] == 1]
sharedModels = [i for i in range(Nmodel) if ModelSize[i] > 1]

for line_raw in FILE:
    line = line_raw.split()
    if not line:
        continue  # tolerate blank lines
    if len(line) != len(header):
        sys.stderr.write("Error: data line has " + str(len(line)) + " columns but the header has " + str(len(header)) + "\n")
        sys.exit(1)

    # Log Bayes factor of each model against the NULL model (column 5). The
    # NULL model itself gets 0. BFs[i] belongs to model header[5 + i].
    nullScore = float(line[5])
    BFs = [0.0] + [float(x) - nullScore for x in line[6:]]

    fields = line[0:5]
    if mode == 1:
        fields += [str(x) for x in BFs]
    elif mode == 3:
        # Best single-phenotype model versus best multi-phenotype model. A
        # positive difference of their log BFs favours specificity. "NA" is
        # reported when a category has no models (e.g. only one phenotype).
        bestSpec = max(specificModels, key=lambda i: BFs[i]) if specificModels else None
        bestShared = max(sharedModels, key=lambda i: BFs[i]) if sharedModels else None
        fields.append(header[5 + bestSpec] if bestSpec is not None else "NA")
        fields.append(header[5 + bestShared] if bestShared is not None else "NA")
        if bestSpec is None or bestShared is None:
            fields.append("NA")
        else:
            fields.append(str(BFs[bestSpec] - BFs[bestShared]))
    else:
        # Modes 2 and 4. Bayes' rule: posterior_i is proportional to
        # prior_i * exp(logBF_i), i.e. exp(log prior_i + logBF_i). Subtract
        # the maximum before exponentiating so large log BFs cannot overflow.
        # Assumes natural-log scores, as the use of exp() implies.
        logWeights = numpy.array(BFs) + logPriors
        weights = numpy.exp(logWeights - numpy.max(logWeights))
        posteriors = weights / numpy.sum(weights)
        if mode == 2:
            fields += [str(x) for x in posteriors]
        else:
            # 95% credible set: take models in decreasing posterior order
            # until their cumulative posterior reaches 0.95.
            order = numpy.argsort(-posteriors)
            credSet = []
            credValue = 0.0
            for idx in order:
                credSet.append(header[5 + idx])
                credValue += posteriors[idx]
                if credValue >= 0.95:
                    break
            best = order[0]
            fields += [header[5 + best], str(posteriors[best]), "&".join(credSet)]

    sys.stdout.write("\t".join(fields) + "\n")

FILE.close()