####################################################################
# combine_replicates.py
#   used to combine lamarc runs from multiple replicates into a
#   single region file with insumfiles for "poor man's parallelization"
####################################################################

# system imports
import getopt
import random
import re
import sys
import os.path
from xml.dom.minidom import parse, Document

import parallelCommon


# Run-directory suffixes look like "reg2_rep0": group 1 captures the text
# after "reg", group 2 the text after "_rep".
file_reg_rep_pattern = re.compile(r'reg(.*)_rep(.*)')
# Raw strings so that "\s" and "\d" reach the regex engine untouched instead
# of being parsed as (invalid) string escapes.
number_pattern       = re.compile(r'\s*(\d+)\s+(\d+)\s+(\d+)')
reg_rep_pattern      = re.compile(r'\s*(\d+)\s+(\d+)')


# Apply "pattern" to every text node found directly under each "tagName"
# descendant of myElem; where it matches, replace the text with buildText(match).
def _rewriteTagText(myElem,tagName,pattern,buildText):
    for tag in myElem.getElementsByTagName(tagName):
        for child in tag.childNodes:
            if child.nodeType == child.TEXT_NODE:
                m = pattern.match(child.data)
                if m:
                    child.data = buildText(m)


# each file assumes it is for the only region and only replicate
# and so has region and replicate numbers of zero. Here we
# replace those zeroes in the <number> and <reg_rep> tags
def replaceReplicateNumbers(myElem,repNum):
    """Write replicate number repNum into <number> and <reg_rep> tags.

    Arguments:
        myElem -- DOM element whose descendants are searched (in place).
        repNum -- replicate number, as a non-negative int or a string
                  holding one (e.g. the text captured from a file suffix).

    For each <number> text of the form "a b c" the middle value is replaced,
    giving " a repNum c "; for each <reg_rep> text of the form "a b" the
    second value is replaced, giving " a repNum ".  Text that does not match
    those forms is left alone.

    Prints an error and exits with status 2 when repNum is negative or is
    not an integer at all.
    """
    # Callers hand us the replicate number as text, so convert before the
    # range check; comparing a string with 0 would never flag bad input.
    try:
        repValue = int(repNum)
    except (TypeError, ValueError):
        repValue = -1

    if repValue < 0:
        print("bad replicate number, %s , to replaceReplicateNumbers\n" % (repNum,))
        sys.exit(2)

    _rewriteTagText(myElem,"number",number_pattern,
                    lambda m: " %s %s %s " % (m.group(1),repValue,m.group(3)))
    _rewriteTagText(myElem,"reg_rep",reg_rep_pattern,
                    lambda m: " %s %s " % (m.group(1),repValue))

# combine outsumfiles for multiple replicates over a single region
# into a single insumfile.
# argument "regString" will usually be something like "reg2", but
# can be "final" for a multiple replicate run over a single region
def combineSumfiles(lamdir,regString,regRepList):
    """Merge per-replicate outsumfiles into one insumfile.

    Arguments:
        lamdir     -- directory containing one sub-directory per run.
        regString  -- name of the combined run; the result is written to
                      <lamdir>/<regString>/insumfile_<regString>.xml
        regRepList -- suffixes of the form "reg<R>_rep<N>"; each one is read
                      from <lamdir>/<suffix>/outsumfile_<suffix>.xml

    Every child of each input's <XML-summary-file> element is copied under a
    single new <XML-summary-file> root, with the replicate number taken from
    the suffix written into the copied element nodes.

    Prints an error and exits with status 2 when a suffix does not have the
    "reg<R>_rep<N>" form.
    """

    parallelCommon.sumfileCombineWarn(regString)

    # create output dom
    sumDoc = Document()
    topElem = sumDoc.createElement("XML-summary-file")
    sumDoc.appendChild(topElem)

    for suffix in regRepList:
        # get rep number from suffix; without it we would write an empty
        # replicate number into the combined file, so stop instead
        m = file_reg_rep_pattern.match(suffix)
        if not m:
            print("bad region/replicate suffix, %s , to combineSumfiles\n" % (suffix,))
            sys.exit(2)
        repNum = m.group(2)

        # find file and parse it in
        oneRep = os.path.join(lamdir,suffix,'outsumfile_%s.xml' % suffix)
        inf = open(oneRep,'r')
        try:
            oneDom = parse(inf)
        finally:
            # release the handle even when the file is not well-formed XML
            inf.close()
        localTop = parallelCommon.getFirstTag(oneDom,"XML-summary-file")

        # add to output xml structure; importNode makes a deep copy owned by
        # sumDoc, so the source list is not modified while we walk it
        for myNode in localTop.childNodes:
            myNewNode = sumDoc.importNode(myNode,True)
            topElem.appendChild(myNewNode)
            if myNewNode.nodeType == myNewNode.ELEMENT_NODE:
                replaceReplicateNumbers(myNewNode,repNum)

    # write output file
    sumfile = os.path.join(lamdir,regString,'insumfile_%s.xml' % regString)
    outf = open(sumfile,'w')
    try:
        sumDoc.writexml(outf)
    finally:
        outf.close()
    parallelCommon.stripXmlInfo(sumfile)


# --- script body ---------------------------------------------------------
# Fetch the lamarc input file name and the working/script directories, then
# announce what this script does.
lamarcfile, lamdir, pydir = parallelCommon.getOptionsAndVerify(False)
parallelCommon.describeThisScript(
    "combine_replicates.py",
    "combine replicate runs for each region",
    lamarcfile,
    lamdir)

# Load the lamarc input file as a DOM.
lamarcHandle = open(lamarcfile,'r')
lamDom = parse(lamarcHandle)
lamarcHandle.close()

# Stamp the document so every file written from it names this script.
lamarcTag = parallelCommon.getFirstTag(lamDom,"lamarc")
lamarcTag.insertBefore(
    lamDom.createComment("Created by combine_replicates.py"),
    lamarcTag.firstChild)

# Collect the <region> elements under <data>.
dataElem = parallelCommon.getSingleTag(lamarcTag,"data")
regionElems = dataElem.getElementsByTagName("region")
singleRegion = (len(regionElems) == 1)

# With several regions, profiling is left for the final step.
if not singleRegion:
    parallelCommon.turnProfilesOff(lamDom,lamarcTag)

# Read the replicate count (left unchanged in the document); default is 1
# when there is no <replicates> tag.
chainsTag = parallelCommon.getFirstTag(lamarcTag,"chains")
replicatesTag = parallelCommon.getFirstTag(chainsTag,"replicates")
if replicatesTag:
    originalReplicateCount = parallelCommon.getLongVal(replicatesTag)
else:
    originalReplicateCount = 1

formatTag = parallelCommon.getFirstTag(lamarcTag,"format")

# Detach every region; each is re-attached on its own in the loop below so
# that each infile written contains exactly one region.
for detached in regionElems:
    dataElem.removeChild(detached)

infileList = []

for regCount, region in enumerate(regionElems):
    dataElem.appendChild(region)

    # A lone region produces the "final" run directly.
    if singleRegion:
        regString = "final"
    else:
        regString = 'reg%d' % regCount

    # Suffixes of the replicate runs to be merged for this region.
    repList = ['reg%d_rep%d' % (regCount,repCount)
               for repCount in range(originalReplicateCount)]

    parallelCommon.fixFormatTag(lamDom,formatTag,regString,True)

    # Write infile_<regString>.xml into its own run directory.
    runDir = os.path.join(lamdir,regString)
    if not os.path.exists(runDir):
        os.makedirs(runDir)
    regFile = os.path.join(runDir,'infile_%s.xml' % regString)
    regHandle = open(regFile,'w')
    lamDom.writexml(regHandle)
    regHandle.close()
    infileList.append(regFile)

    combineSumfiles(lamdir,regString,repList)

    # Detach this region again before handling the next one.
    dataElem.removeChild(region)

# output instructions: files to run, next program to run
parallelCommon.nextStep(lamdir,infileList,singleRegion)
if not singleRegion:
    parallelCommon.finalStep(pydir,"combine_regions.py",lamarcfile,lamdir)
