####################################################################
# combine_regions.py
#   used to combine lamarc runs from regions (which may include 
#   multiple replicates) into a single region file with insumfiles 
#   for "poor man's parallelization"
####################################################################

# system imports
import re
import os.path
import sys
from xml.dom.minidom import parse, Document

import parallelCommon


# Regexes used to locate and rewrite region indices. Raw strings so the
# backslash escapes (\s, \d) reach the re module verbatim rather than being
# treated as (invalid, deprecated in Python 3) string escapes.

# suffix like "reg3" -> group(1) == "3"
file_reg_pattern     = re.compile(r'reg(.*)')
# text of a <number> element: region index followed by two more integers
number_pattern       = re.compile(r'\s*(\d+)\s+(\d+)\s+(\d+)')
# text of a <reg_rep> element: region index followed by one more integer
reg_rep_pattern      = re.compile(r'\s*(\d+)\s+(\d+)')

# Whole-line variants for text-based processing: group(2) is the region index,
# groups 1 and 3 are the surrounding text on the line.
t_number_pattern       = re.compile(r'(\s*<number>\s*)(\d+)(\s+\d+\s+\d+\s*</number>\s*)')
t_reg_rep_pattern      = re.compile(r'(\s*<reg_rep>\s*)(\d+)(\s+\d+\s*</reg_rep>\s*)')
# any line mentioning the XML-summary-file wrapper tag
t_top_tag_pattern      = re.compile(r'.*XML.summary.file.*')

# each file assumes it is for the only region
# and so has region numbers of zero. Here we
# replace those zeroes in the <number> and <reg_rep> tags
def replaceRegionNumbers(myElem,regNum):
    """Overwrite the region index in <number> and <reg_rep> descendants of myElem.

    For each <number> element below myElem whose text matches
    number_pattern ("r a b"), the text becomes " regNum a b ". For each
    <reg_rep> element whose text matches reg_rep_pattern ("r a"), it
    becomes " regNum a ". Text nodes are modified in place; nothing is
    returned.

    regNum may be an int or a decimal string (the caller in this file
    passes the string captured from a "regN" suffix). If it is not a
    non-negative integer, an error is written to stderr and the process
    exits with status 2.
    """
    # Convert only for validation: comparing a str to 0 never fires under
    # Python 2 and raises TypeError under Python 3. The caller's regNum is
    # still what gets written, so its textual form is preserved.
    try:
        regIndex = int(regNum)
    except (TypeError, ValueError):
        regIndex = -1
    if regIndex < 0:
        sys.stderr.write("Bad region number, %s, to replaceRegionNumbers\n" % (regNum,))
        sys.exit(2)

    for numberTag in myElem.getElementsByTagName("number"):
        for child in numberTag.childNodes:
            if child.nodeType == child.TEXT_NODE:
                m = number_pattern.match(child.data)
                if m:
                    child.data = " %s %s %s " % (regNum,m.group(2),m.group(3))

    for regRepTag in myElem.getElementsByTagName("reg_rep"):
        for child in regRepTag.childNodes:
            if child.nodeType == child.TEXT_NODE:
                m = reg_rep_pattern.match(child.data)
                if m:
                    child.data = " %s %s " % (regNum,m.group(2))

# combine outsumfiles for single regions into a single insumfile.
#
# lamdir == top level directory containing subdirectories for runs
# rundir == where to write the sumfile
# regionList == list of strings like "reg0" and "reg1"
def combineSumfilesRegion(lamdir,rundir,regionList):
    """Concatenate per-region outsumfiles into rundir/insumfile_final.xml.

    For each suffix in regionList (e.g. "reg3"), reads
    lamdir/reg<N>/outsumfile_<suffix>.xml line by line, where N is the text
    after "reg" in the suffix. Any line mentioning the XML-summary-file
    wrapper tag is dropped. On lines matching t_number_pattern or
    t_reg_rep_pattern, the leading (region) index is replaced by N. All
    other lines are copied unchanged. The output is wrapped in a single
    <XML-summary-file> element.

    Processing is line based, so a <number>/<reg_rep> element is only
    rewritten when it sits entirely on one line.
    """
    parallelCommon.sumfileCombineWarn(rundir)

    # write output file
    sumfile = os.path.join(rundir,'insumfile_final.xml')
    outf = open(sumfile,'w')
    try:
        outf.write("<XML-summary-file>\n")

        for suffix in regionList:
            # get region number from suffix ("reg3" -> "3")
            m = file_reg_pattern.match(suffix)
            regNum = ""
            if m:
                regNum = m.group(1)

            # find the per-region summary file
            sumdir = os.path.join(lamdir,'reg%s' % regNum)
            oneRep = os.path.join(sumdir,'outsumfile_%s.xml' % suffix)
            print("processing %s" % oneRep)
            sys.stdout.flush()

            # append to existing file, changing region numbers as necessary
            inf = open(oneRep,'r')
            try:
                for line in inf:
                    # throw away outermost tags; one wrapper encloses everything
                    if t_top_tag_pattern.match(line):
                        continue

                    # replace the region index (group 2) in <number>/<reg_rep>,
                    # keeping the rest of the line exactly as it was
                    m = t_number_pattern.match(line) or t_reg_rep_pattern.match(line)
                    if m:
                        line = line[:m.start(2)] + regNum + line[m.end(2):]

                    # lines read from a file already carry their newline;
                    # add one only when missing (e.g. an unterminated last line)
                    if not line.endswith("\n"):
                        line += "\n"
                    outf.write(line)
            finally:
                inf.close()

        outf.write("</XML-summary-file>\n")
    finally:
        outf.close()

def combineSumfilesRegion__OLD(lamdir,rundir,regionList):
    """DOM-based predecessor of combineSumfilesRegion.

    Parses each lamdir/reg<N>/outsumfile_<suffix>.xml with minidom. The
    children of each file's XML-summary-file element are imported into one
    new document, and region indices in imported elements are rewritten via
    replaceRegionNumbers. The result goes to rundir/insumfile_final.xml, and
    then parallelCommon.stripXmlInfo is applied to that file.

    Nothing in the visible part of this file calls it.
    """
    parallelCommon.sumfileCombineWarn(rundir)

    # create output dom
    sumDoc = Document()
    topElem = sumDoc.createElement("XML-summary-file")
    sumDoc.appendChild(topElem)

    for suffix in regionList:
        # get region number from suffix ("reg3" -> "3")
        m = file_reg_pattern.match(suffix)
        regNum = ""
        if m:
            regNum = m.group(1)

        # find file and parse it in; given a filename, parse() opens and
        # closes the file itself, even when parsing fails
        sumdir = os.path.join(lamdir,'reg%s' % regNum)
        oneRep = os.path.join(sumdir,'outsumfile_%s.xml' % suffix)
        oneDom = parse(oneRep)
        localTop = parallelCommon.getFirstTag(oneDom,"XML-summary-file")

        # add deep copies to the output document, fixing region numbers
        for myNode in localTop.childNodes:
            myNewNode = sumDoc.importNode(myNode,True)
            topElem.appendChild(myNewNode)
            if myNewNode.nodeType == myNewNode.ELEMENT_NODE:
                replaceRegionNumbers(myNewNode,regNum)

    # write output file
    sumfile = os.path.join(rundir,'insumfile_final.xml')
    outf = open(sumfile,'w')
    try:
        sumDoc.writexml(outf)
    finally:
        outf.close()
    parallelCommon.stripXmlInfo(sumfile)

# input options
[lamarcfile,lamdir,pydir] = parallelCommon.getOptionsAndVerify(False)
parallelCommon.describeThisScript("combine_regions","combine all regions and replicates",lamarcfile,lamdir)

# given a filename, parse() opens and closes the file itself
lamDom = parse(lamarcfile)

# add comment to identify outfile as generated by this script
lamarcTag = parallelCommon.getFirstTag(lamDom,"lamarc")
commentNode = lamDom.createComment("Created by combine_regions.py")
lamarcTag.insertBefore(commentNode,lamarcTag.firstChild)

# find format tag
formatTag = parallelCommon.getFirstTag(lamarcTag,"format")

# one id per <region> element under <data>: "reg0", "reg1", ...
# These select the per-region outsumfiles combined below.
dataElem = parallelCommon.getSingleTag(lamarcTag,"data")
regionElems = dataElem.getElementsByTagName("region")
regionList = ['reg%d' % regCount for regCount in range(len(regionElems))]

# write the combined LAMARC input file into lamdir/final
# (os.makedirs raises OSError if that directory already exists)
parallelCommon.fixFormatTag(lamDom,formatTag,"final",True)
rundir = os.path.join(lamdir,'final')
os.makedirs(rundir)
regFile = os.path.join(rundir,'infile_final.xml')
outf = open(regFile,'w')
try:
    lamDom.writexml(outf)
finally:
    outf.close()

combineSumfilesRegion(lamdir,rundir,regionList)

# output instructions: files to run, next program to run
parallelCommon.nextStep(lamdir,[regFile],True)
