#!/usr/bin/python
"""
takes templated file .xxx.src and produces .xxx file  where .xxx is
.i or .c or .h, using the following template rules

/**begin repeat  -- on a line by itself marks the start of a repeated code
                    segment
/**end repeat**/ -- on a line by itself marks its end

After the /**begin repeat and before the */, all the named templates are placed;
these should all have the same number of replacements.

Repeat blocks can be nested, with each nested block labeled with its depth,
i.e.
/**begin repeat1
 *....
 */
/**end repeat1**/

In the main body each replace will use one entry from the list of named replacements

 Note that all #..# forms in a block must have the same number of
   comma-separated entries.

Example:

    An input file containing

        /**begin repeat
         * #a = 1,2,3#
         * #b = 1,2,3#
         */

        /**begin repeat1
         * #c = ted, jim#
         */
        @a@, @b@, @c@
        /**end repeat1**/

        /**end repeat**/

    produces

        #line 1 "template.c.src"

        /*
         *********************************************************************
         **       This file was autogenerated from a template  DO NOT EDIT!!**
         **       Changes should be made to the original source (.src) file **
         *********************************************************************
         */

        #line 9
        1, 1, ted

        #line 9
        1, 1, jim

        #line 9
        2, 2, ted

        #line 9
        2, 2, jim

        #line 9
        3, 3, ted

        #line 9
        3, 3, jim

"""

__all__ = ['process_str', 'process_file']

import os
import sys
import re

# Replacement names that are defined globally, i.e. outside of any repeat
# block.  process_str passes this dict as the initial environment, so its
# entries are visible to @name@ substitution at every nesting level.
global_names = {}

# Banner placed at the front of each processed file, warning readers that
# the output is generated and must not be edited by hand.
header =\
"""
/*
 *****************************************************************************
 **       This file was autogenerated from a template  DO NOT EDIT!!!!      **
 **       Changes should be made to the original source (.src) file         **
 *****************************************************************************
 */

"""
# Parse string for repeat loops
def parse_structure(astr, level):
    """Locate the repeat blocks of nesting depth `level` inside `astr`.

    Level 0 blocks are delimited by ``/**begin repeat`` ... ``/**end repeat**/``;
    deeper levels carry their depth in the markers, e.g. ``/**begin repeat1``
    ... ``/**end repeat1**/``.

    Returns a list of tuples ``(start, text_start, text_end, end, line)``,
    in order of appearance, one per block:

    * ``start``      -- index of the begin marker,
    * ``text_start`` -- index just past the newline that follows the
                        ``*/`` closing the block header,
    * ``text_end``   -- index of the end marker,
    * ``end``        -- index just past the newline that follows the end
                        marker, or ``len(astr)`` when the end marker sits
                        on the last line without a trailing newline,
    * ``line``       -- number of newlines in `astr` before ``text_start``,
                        i.e. the zero-based line number of the first body
                        line, counted from the beginning of the string.

    Returns an empty list if no loops are found.

    Raises ValueError when a begin marker is found whose header is never
    closed by ``*/`` or that has no matching end marker.

    """
    if level == 0:
        loopbeg = "/**begin repeat"
        loopend = "/**end repeat**/"
    else:
        loopbeg = "/**begin repeat%d" % level
        loopend = "/**end repeat%d**/" % level

    total = len(astr)
    ind = 0
    line = 0
    spanlist = []
    while True:
        start = astr.find(loopbeg, ind)
        if start == -1:
            break
        # 1-based line of the begin marker, only used in error messages.
        where = astr.count("\n", 0, start) + 1

        head_close = astr.find("*/", start)
        if head_close == -1:
            raise ValueError("line %d: header of '%s' is not closed by '*/'"
                             % (where, loopbeg))
        head_eol = astr.find("\n", head_close)
        if head_eol == -1:
            # The header is the last thing in the string, so no end marker
            # can follow it.
            raise ValueError("line %d: '%s' has no matching '%s'"
                             % (where, loopbeg, loopend))
        text_start = head_eol + 1

        text_end = astr.find(loopend, head_eol)
        if text_end == -1:
            raise ValueError("line %d: '%s' has no matching '%s'"
                             % (where, loopbeg, loopend))

        eol = astr.find("\n", text_end)
        if eol == -1:
            # End marker on the final line without a trailing newline: the
            # block extends to the end of the string.  (Using -1 here would
            # make the span end at index 0 and duplicate the whole input.)
            eol = total
            end = total
        else:
            end = eol + 1

        line += astr.count("\n", ind, text_start)
        spanlist.append((start, text_start, text_end, end, line))
        line += astr.count("\n", text_start, eol)
        # Resume the search at the newline ending this block; that newline
        # is counted by the next iteration.
        ind = eol
    # Spans are appended in increasing order of `start`, so no sort needed.
    return spanlist


def paren_repl(obj):
    """Substitution callback: expand a ``<text>*<count>`` match.

    Group 1 of the match is the text to repeat and group 2 the decimal
    repeat count; the result is the text repeated that many times, joined
    with commas.
    """
    text, count = obj.group(1, 2)
    return ','.join(text for _ in range(int(count)))

# '(a,b,c)*4' -- a parenthesised list followed by a repeat count
parenrep = re.compile(r"[(]([^)]*?)[)]\*(\d+)")
# 'xxx*3' -- a single item followed by a repeat count
plainrep = re.compile(r"([^*]+)\*(\d+)")
def conv(astr):
    """Expand the repeat shorthands in `astr` and return the list of values.

    First every '(a,b,c)*4' is replaced by 'a,b,c,a,b,c,a,b,c,a,b,c'.
    Then the string is split at ',', each piece is stripped of surrounding
    whitespace and any 'xxx*3' piece is replaced by 'xxx,xxx,xxx'.  The
    pieces are finally re-joined and split at ',' once more so that the
    expanded repetitions become separate entries of the returned list.
    """
    expanded = parenrep.sub(paren_repl, astr)
    pieces = []
    for item in expanded.split(','):
        pieces.append(plainrep.sub(paren_repl, item.strip()))
    return ','.join(pieces).split(',')

# '#name = v1, v2, ...#' -- one named replacement list in a loop header
named_re = re.compile(r"#\s*([\w]*)\s*=\s*([^#]*)#")
def parse_loop_header(loophead):
    """Find all named replacements in the header

    Returns a list of dictionaries, one for each loop iteration,
    where each key is a name to be substituted and the corresponding
    value is the replacement string.

    Raises ValueError if the named replacements do not all have the same
    number of values, or if the header defines no named replacement at
    all (there would be no way to know how often to repeat the block).

    """
    # parse out the names and lists of values
    names = []
    nsub = None
    for raw_name, spec in named_re.findall(loophead):
        name = raw_name.strip()
        vals = conv(spec)
        size = len(vals)
        if nsub is None:
            # the first entry fixes the number of iterations
            nsub = size
        elif nsub != size:
            msg = "Mismatch in number: %s - %s" % (name, vals)
            raise ValueError(msg)
        names.append((name, vals))

    if nsub is None:
        # Without this check range(None) below fails with a TypeError that
        # callers expecting ValueError for malformed headers would not catch.
        raise ValueError("No substitution variables found")

    # generate list of dictionaries, one for each template iteration
    dlist = []
    for i in range(nsub):
        dlist.append(dict((name, vals[i]) for name, vals in names))
    return dlist

# '@name@' -- a reference to a named replacement
replace_re = re.compile(r"@([\w]+)@")
def parse_string(astr, env, level, line):
    """Expand the repeat blocks and @name@ references found in `astr`.

    Parameters:
        astr  -- template text to expand.
        env   -- dict mapping replacement names to their values for the
                 text outside of any repeat block of depth `level`.
        level -- nesting depth of the repeat blocks to look for in `astr`.
        line  -- line number written into the leading '#line' directive.

    Returns the expanded text as a single string: a '#line <line>'
    directive, the expansion of `astr` (each repeat block being emitted
    once per iteration, recursively), and a final newline.

    Inside a repeat block the environment is the block's own names updated
    with `env`, so a name already present in `env` keeps its outer value.

    Raises KeyError if an @name@ reference is not defined in the current
    environment, and ValueError if a loop header cannot be parsed; both
    messages are prefixed with the '#line' directive of this call.

    """
    lineno = "#line %d\n" % line

    # local function for string replacement, uses env
    def replace(match):
        name = match.group(1)
        try:
            val = env[name]
        except KeyError as e:
            msg = '%s: %s' % (lineno, e)
            raise KeyError(msg)
        return val

    code = [lineno]
    struct = parse_structure(astr, level)
    if struct:
        # recurse over inner loops
        oldend = 0
        newlevel = level + 1
        for sub in struct:
            pref = astr[oldend:sub[0]]
            head = astr[sub[0]:sub[1]]
            text = astr[sub[1]:sub[2]]
            oldend = sub[3]
            newline = line + sub[4]
            code.append(replace_re.sub(replace, pref))
            try:
                envlist = parse_loop_header(head)
            except ValueError as e:
                msg = "%s: %s" % (lineno, e)
                raise ValueError(msg)
            for newenv in envlist:
                newenv.update(env)
                # parse_string returns one string; append it whole rather
                # than extending the list character by character.
                code.append(parse_string(text, newenv, newlevel, newline))
        suff = astr[oldend:]
        code.append(replace_re.sub(replace, suff))
    else:
        # replace keys
        code.append(replace_re.sub(replace, astr))
    code.append('\n')
    return ''.join(code)

def process_str(astr):
    """Expand the template text `astr` and return the generated source.

    The result is the autogenerated-file banner followed by the expansion
    of `astr`, processed at nesting level 0 starting at line 1 with the
    global replacement names as environment.
    """
    expanded = parse_string(astr, global_names, 0, 1)
    return header + expanded


# '#include "something.src"' at the start of a line (case-insensitive)
include_src_re = re.compile(r"(\n|\A)#include\s*['\"]"
                            r"(?P<name>[\w\d./\\]+[.]src)['\"]", re.I)

def resolve_includes(source):
    """Return the lines of `source` with included .src files spliced in.

    Each line of the form ``#include "name.src"`` whose target exists is
    replaced by the (recursively resolved) lines of that file; relative
    names are looked up in the directory of `source`.  Include lines
    whose target is not an existing file are kept unchanged.  A message
    is printed for every file that gets included.

    Returns the resulting list of lines, line endings preserved.
    """
    d = os.path.dirname(source)
    lines = []
    fid = open(source)
    # try/finally so the handle is released even if a nested include fails
    try:
        for line in fid:
            fn = None
            m = include_src_re.match(line)
            if m:
                fn = m.group('name')
                if not os.path.isabs(fn):
                    fn = os.path.join(d, fn)
            if fn is not None and os.path.isfile(fn):
                # single-argument form prints the same text on Python 2 and 3
                print('Including file %s' % fn)
                lines.extend(resolve_includes(fn))
            else:
                lines.append(line)
    finally:
        fid.close()
    return lines

def process_file(source):
    """Expand the template file `source` and return the generated text.

    Included .src files are resolved first, then the combined text is
    expanded.  The result starts with a '#line 1 "<file>"' directive that
    names the case-normalised source path with backslashes doubled.
    """
    template = ''.join(resolve_includes(source))
    quoted_name = os.path.normcase(source).replace("\\", "\\\\")
    return '#line 1 "%s"\n%s' % (quoted_name, process_str(template))


def unique_key(adict):
    """Return a string that is not currently a key of `adict`.

    The candidate is built by concatenating the first n characters of
    every current key, increasing n until the candidate is not itself a
    key.  Once n has reached the length of the longest key the candidate
    can no longer change, so if it still collides (for instance when the
    dictionary holds a single key) underscores are appended until it is
    unique.

    Not particularly quick.  Keys are assumed to be strings.
    """
    allkeys = list(adict.keys())
    longest = 0
    for key in allkeys:
        longest = max(longest, len(key))

    n = 1
    while True:
        newkey = "".join([x[:n] for x in allkeys])
        if newkey not in adict:
            return newkey
        if n >= longest:
            # Longer prefixes would produce the very same candidate;
            # continuing to increase n would loop forever.
            break
        n += 1

    while newkey in adict:
        newkey += "_"
    return newkey


if __name__ == "__main__":
    # Command-line use:
    #   no argument  -- read the template from stdin, write to stdout;
    #   one argument -- read the named file and write the result to the
    #                   same name with its last extension removed
    #                   (e.g. 'foo.c.src' -> 'foo.c').

    try:
        srcname = sys.argv[1]
    except IndexError:
        sys.stdout.write(process_str(sys.stdin.read()))
    else:
        (base, ext) = os.path.splitext(srcname)
        if base == srcname:
            # Without an extension the output name equals the input name,
            # and opening it for writing would destroy the template.
            sys.exit("%s: file name has no extension to strip; "
                     "refusing to overwrite the input" % srcname)

        fid = open(srcname, 'r')
        try:
            allstr = fid.read()
        finally:
            fid.close()

        # Expand before creating the output file so that a template error
        # does not leave an empty output file behind.
        writestr = process_str(allstr)

        outfile = open(base, 'w')
        try:
            outfile.write(writestr)
        finally:
            outfile.close()
