File: template.py

package info (click to toggle)
bottleneck 1.2.1+ds1-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 564 kB
  • sloc: ansic: 4,414; python: 1,742; makefile: 68
file content (178 lines) | stat: -rw-r--r-- 5,085 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import re
import ast


def make_c_files():
    """Regenerate the C sources from their ``*_template.c`` counterparts.

    For each of the modules ``reduce``, ``move``, ``nonreduce`` and
    ``nonreduce_axis``, the file ``<module>_template.c`` located in the same
    directory as this script is read, expanded with ``template``, and the
    result is written to ``<module>.c`` in that directory (overwriting it).
    """
    here = os.path.dirname(__file__)
    for name in ('reduce', 'move', 'nonreduce', 'nonreduce_axis'):
        source_path = os.path.join(here, name + '_template.c')
        with open(source_path, 'r') as source:
            raw = source.read()
        rendered = template(raw)
        target_path = os.path.join(here, name + '.c')
        with open(target_path, 'w') as target:
            target.write(rendered)


def template(src_str):
    """Expand every template directive in a C template source string.

    The source is split into lines and run through three passes in a fixed
    order: repeat blocks, dtype blocks, then multiline-string blocks. The
    lines are re-joined with newlines, and every run of three or more
    newlines (with any whitespace between them) is collapsed to exactly two,
    i.e. a single blank line.
    """
    lines = src_str.splitlines()
    for expand in (repeat_templating, dtype_templating, string_templating):
        lines = expand(lines)
    joined = '\n'.join(lines)
    return re.sub(r'\n\s*\n\s*\n', r'\n\n', joined)


# repeat --------------------------------------------------------------------

# A repeat block opens on a line starting with "/*", optional whitespace,
# "repeat", then "=" (whitespace allowed around it); the dict literal that
# follows is parsed by repeat_info. The block is closed by a line starting
# with "/*", optional whitespace, then "repeat end".
REPEAT_BEGIN = r'^/\*\s*repeat\s*=\s*'
REPEAT_END = r'^/\*\s*repeat end'
# Matches any line containing "*/"; used to find the last line of a block's
# header comment (by both the repeat and the dtype passes).
COMMENT_END = r'.*\*\/.*'


def repeat_templating(lines):
    """Expand every ``/* repeat = {...} */ ... /* repeat end */`` block.

    Each block found by ``next_block`` (header comment plus body, but not the
    end-marker line) is expanded by ``expand_functions_repeat`` and spliced
    back in place of the block; the end-marker line is dropped. Scanning
    resumes at the start of the spliced text, so the expansion is itself
    re-scanned for further blocks. Returns the new list of lines.
    """
    start = 0
    idx0, idx1 = next_block(lines, start, REPEAT_BEGIN, REPEAT_END)
    while idx0 is not None:
        expanded = expand_functions_repeat(lines[idx0:idx1])
        # lines[idx1] is the "/* repeat end */" marker; skip it
        lines = lines[:idx0] + expanded + lines[idx1 + 1:]
        start = idx0
        idx0, idx1 = next_block(lines, start, REPEAT_BEGIN, REPEAT_END)
    return lines


def expand_functions_repeat(lines):
    """Expand the contents of a single repeat block.

    ``lines`` begins with the repeat header comment, which ends at the first
    line containing ``*/``. The header is parsed by ``repeat_info``; all lines
    after it are joined with newlines and expanded by ``expand_repeat``,
    whose list of lines is returned.
    """
    body_start = first_occurence(COMMENT_END, lines) + 1
    spec = repeat_info(lines[:body_start])
    body = '\n'.join(lines[body_start:])
    return expand_repeat(body, spec)


def repeat_info(lines):
    """Parse the replacement mapping from a repeat block's header comment.

    The header lines are concatenated with no separator, and the text from
    the first ``{`` to the last ``}`` is evaluated with ``ast.literal_eval``.

    Parameters
    ----------
    lines : list of str
        The header comment lines, e.g. ``["/* repeat = {'NAME': ['a', 'b']} */"]``.

    Returns
    -------
    dict
        Mapping from placeholder text to its list of replacement values.

    Raises
    ------
    ValueError
        If no ``{...}`` specification is present, or it does not evaluate to
        a dict. Malformed literal text may also raise from
        ``ast.literal_eval`` itself.
    """
    line = ''.join(lines)
    # The joined text has no newlines and `.*` is greedy, so findall yields
    # at most one match; zero matches means the spec is missing. Report that
    # explicitly (as dtype_info does) instead of a bare IndexError.
    repeat = re.findall(r'\{.*\}', line)
    if len(repeat) != 1:
        raise ValueError("expecting exactly one repeat specification")
    repeat_dict = ast.literal_eval(repeat[0])
    # A braces literal could also be a set, which would only fail later with
    # an obscure TypeError when expand_repeat indexes it.
    if not isinstance(repeat_dict, dict):
        raise ValueError("repeat specification must be a dict")
    return repeat_dict


def expand_repeat(func_str, repeat_dict):
    """Instantiate ``func_str`` once per position in the replacement lists.

    For the i-th instantiation, every key of ``repeat_dict`` is replaced (in
    dict order, via ``str.replace``) by the i-th element of its list. Each
    instantiation is prefixed with a newline, all are concatenated, and the
    result is returned split into lines.

    Raises ValueError unless every list has the same length (this includes
    an empty ``repeat_dict``).
    """
    lengths = {len(values) for values in repeat_dict.values()}
    if len(lengths) != 1:
        raise ValueError("All repeat lists must be the same length")
    keys = list(repeat_dict)
    pieces = []
    for row in zip(*(repeat_dict[k] for k in keys)):
        text = func_str
        for key, value in zip(keys, row):
            text = text.replace(key, value)
        pieces.append('\n' + text)
    return ''.join(pieces).splitlines()


# dtype ---------------------------------------------------------------------

# A dtype block opens on a line starting with "/*", optional whitespace,
# "dtype", then "=" (whitespace allowed around it); the list literal that
# follows is parsed by dtype_info. The block is closed by a line starting
# with "/*", optional whitespace, then "dtype end".
DTYPE_BEGIN = r'^/\*\s*dtype\s*=\s*'
DTYPE_END = r'^/\*\s*dtype end'


def dtype_templating(lines):
    """Expand every ``/* dtype = [...] */ ... /* dtype end */`` block.

    Each block found by ``next_block`` (header comment plus body, but not the
    end-marker line) is expanded by ``expand_functions_dtype`` and spliced
    back in place of the block; the end-marker line is dropped. Scanning
    resumes at the start of the spliced text. Returns the new list.
    """
    start = 0
    idx0, idx1 = next_block(lines, start, DTYPE_BEGIN, DTYPE_END)
    while idx0 is not None:
        expanded = expand_functions_dtype(lines[idx0:idx1])
        # lines[idx1] is the "/* dtype end */" marker; skip it
        lines = lines[:idx0] + expanded + lines[idx1 + 1:]
        start = idx0
        idx0, idx1 = next_block(lines, start, DTYPE_BEGIN, DTYPE_END)
    return lines


def expand_functions_dtype(lines):
    """Expand the contents of a single dtype block.

    ``lines`` begins with the dtype header comment, which ends at the first
    line containing ``*/``. The header is parsed by ``dtype_info``; all lines
    after it are joined with newlines and expanded by ``expand_dtypes``,
    whose result list is returned unchanged.
    """
    body_start = first_occurence(COMMENT_END, lines) + 1
    dtype_list = dtype_info(lines[:body_start])
    body = '\n'.join(lines[body_start:])
    return expand_dtypes(body, dtype_list)


def dtype_info(lines):
    """Parse the dtype list from a dtype block's header comment.

    The header lines are concatenated with no separator, and the text from
    the first ``[`` to the last ``]`` is evaluated with ``ast.literal_eval``.
    Raises ValueError when no such ``[...]`` text is present.
    """
    matches = re.findall(r'\[.*\]', ''.join(lines))
    if len(matches) == 1:
        return ast.literal_eval(matches[0])
    raise ValueError("expecting exactly one dtype specification")


def expand_dtypes(func_str, dtypes):
    """Instantiate a dtype-templated body once per dtype entry.

    Parameters
    ----------
    func_str : str
        Template body. ``DTYPE0``, ``DTYPE1``, ... are replaced by the
        corresponding elements of each dtype entry.
    dtypes : list of sequences of str
        One entry per instantiation, e.g. ``[['float64'], ['int32']]``.

    Returns
    -------
    list of str
        Output lines, none of which contains a newline. Each instantiation is
        preceded by two empty lines and followed by one extra empty line per
        dtype position beyond the first.

    Raises
    ------
    ValueError
        If ``func_str`` does not contain ``DTYPE`` anywhere.
    """
    if 'DTYPE' not in func_str:
        raise ValueError("cannot find dtype marker")
    func_list = []
    for dtype in dtypes:
        f = func_str
        for i, dt in enumerate(dtype):
            f = f.replace('DTYPE%d' % i, dt)
            if i > 0:
                # existing output format: one extra newline per additional
                # dtype position
                f = f + '\n'
        # Emit individual lines (as expand_repeat does) rather than one
        # multi-line chunk. Later passes such as string_templating match
        # markers per list element with re.match, and '.' never crosses a
        # newline, so markers inside a chunk were invisible. split('\n') (not
        # splitlines) keeps '\n'.join of the result byte-identical to joining
        # the old chunks.
        func_list.extend(('\n\n' + f).split('\n'))
    return func_list


# multiline strings ---------------------------------------------------------

# A multiline-string region starts at any line containing the text
# "MULTILINE STRING BEGIN" and ends at any line containing
# "MULTILINE STRING END" (the text may appear anywhere on the line).
STRING_BEGIN = r'.*MULTILINE STRING BEGIN.*'
STRING_END = r'.*MULTILINE STRING END.*'


def string_templating(lines):
    """Replace each MULTILINE STRING BEGIN/END region with quoted C lines.

    Both marker lines are removed; the lines strictly between them are
    converted by ``quote_string`` and spliced in their place. Scanning
    resumes at the start of the spliced text. Returns the new list.
    """
    pos = 0
    while True:
        begin, end = next_block(lines, pos, STRING_BEGIN, STRING_END)
        if begin is None:
            return lines
        quoted = quote_string(lines[begin + 1:end])
        lines = lines[:begin] + quoted + lines[end + 1:]
        pos = begin


def quote_string(lines):
    """Convert raw text lines into the lines of a C string-literal initializer.

    Each line becomes ``"<line>\\n"``, i.e. a C string literal whose content
    ends with an escaped newline. A ``;`` is appended to the last literal to
    terminate the statement.

    Parameters
    ----------
    lines : list of str
        Text lines to quote. The list is not modified.

    Returns
    -------
    list of str
        A new list of quoted lines. An empty input yields ``['"";']`` (an
        empty C string) instead of raising IndexError.
    """
    # NOTE(review): double quotes and backslashes inside a line are not
    # escaped, so such characters would produce an invalid or altered C
    # literal; presumably templates avoid them — confirm before relying on it.
    quoted = ['"' + line + r'\n' + '"' for line in lines]
    if not quoted:
        return ['"";']
    quoted[-1] += ';'
    return quoted


# utility -------------------------------------------------------------------

def first_occurence(pattern, lines):
    """Return the index of the first line for which ``re.match`` succeeds.

    Raises ValueError if no line in ``lines`` matches ``pattern``.
    """
    hit = next(
        (pos for pos, text in enumerate(lines) if re.match(pattern, text)),
        None,
    )
    if hit is None:
        raise ValueError("`pattern` not found")
    return hit


def next_block(lines, index, begine_pattern, end_pattern):
    """Find the next begin/end marker pair, scanning from ``lines[index]``.

    Lines are tested with ``re.match``. A line matching ``begine_pattern``
    records a candidate start; a later begin line replaces it, so the last
    begin line seen before the end line is used. A line matching both
    patterns counts as a begin line. The first line matching ``end_pattern``
    closes the block.

    Returns ``(begin_index, end_index)``, or ``(None, None)`` if no end line
    is reached. Raises ValueError if an end line appears before any begin
    line.
    """
    start = None
    pos = index
    total = len(lines)
    while pos < total:
        text = lines[pos]
        if re.match(begine_pattern, text):
            start = pos
        elif re.match(end_pattern, text):
            if start is None:
                raise ValueError("found end of function before beginning")
            return start, pos
        pos += 1
    return None, None