#!/usr/bin/env python3

import sys, os, jinja2
from shutil import copyfile

def render(tpl_path, context):
    """Render the Jinja2 template found at *tpl_path* and return the result.

    The directory part of *tpl_path* becomes the loader's search path. When
    *tpl_path* has no directory component, the current directory ('./') is
    used. The environment is built with ``jinja2.StrictUndefined``, so a
    template that uses a variable missing from *context* raises instead of
    silently rendering an empty value.

    :param tpl_path: path to the template file.
    :param context: mapping of template variables.
    :returns: the rendered text as a string.
    """
    template_dir = os.path.dirname(tpl_path)
    template_name = os.path.basename(tpl_path)

    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(template_dir if template_dir else './'),
        undefined=jinja2.StrictUndefined,
    )
    template = env.get_template(template_name)
    return template.render(context)

####

# Intrinsic type names handed to the templates as `types`.
types_l = ['real', 'complex', 'integer', 'logical']

# Kind descriptors: each pairs a short kind tag ('name') with a kind
# constant name ('val'). The 'val' strings look like iso_fortran_env kind
# constants -- NOTE(review): confirm against the .jf90 templates.
sp_d = dict(name='sp', val='real32')
dp_d = dict(name='dp', val='real64')
i4_d = dict(name='i4', val='int32')
i8_d = dict(name='i8', val='int64')
l4_d = dict(name='l4', val='int32')

# Type name -> list of kind descriptors, passed to the templates as `kinds`.
# 'real' and 'complex' share the same descriptor objects but own separate lists.
kinds_d = {
    'real': [sp_d, dp_d],
    'complex': [sp_d, dp_d],
    'integer': [i4_d, i8_d],
    'logical': [l4_d],
}

# Values passed to the templates as `dimensions`: the full library uses
# nranks, the generated test program uses the smaller nranks_test.
nranks = 6
nranks_test = 4

# Output file, source template, and `dimensions` value for every generated
# Fortran file, in generation order. The module and its submodules use
# nranks; the test program uses nranks_test.
_targets = [
    # Generate module
    ('devxlib_malloc.f90', 'devxlib_malloc.jf90', nranks),
    # Generate submodules
    ('devxlib_malloc_alloc.f90', 'devxlib_malloc_alloc.jf90', nranks),
    ('devxlib_malloc_free.f90', 'devxlib_malloc_free.jf90', nranks),
    ('devxlib_malloc_allocated.f90', 'devxlib_malloc_allocated.jf90', nranks),
    # Generate tests
    ('test_malloc.f90', 'test_malloc.jf90', nranks_test),
]

for _out_path, _tpl_path, _ndims in _targets:
    # Render BEFORE opening the output: open(..., 'w') truncates the file
    # immediately, so a template error (e.g. StrictUndefined raising) would
    # otherwise leave an empty, freshly-timestamped source file behind.
    _text = render(_tpl_path,
                   {'types': types_l, 'kinds': kinds_d, 'dimensions': _ndims})
    # Jinja2's FileSystemLoader reads templates as UTF-8 by default; write
    # the output with the same encoding instead of the locale default.
    with open(_out_path, 'w', encoding='utf-8') as f:
        f.write(_text)

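## Optionally also provide copies of sources under a different file extension
## (.f90 -> .F). Disabled. NOTE(review): the names below refer to
## device_memcpy sources, which this script does not generate.
#copyfile("device_memcpy.f90","device_memcpy.F")
#copyfile("device_memcpy_interf.f90","device_memcpy_interf.F")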

