1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
|
"""
Fast image interpolation using a pyramid.
"""
import os.path
import time
from datetime import datetime

import halide as hl
import imageio
import numpy as np
# Halide scalar types used by this script: 32-bit signed int and 32-bit float.
int_t, float_t = hl.Int(32), hl.Float(32)
def get_interpolate(input, levels, sched=None):
    """
    Build the pyramid-interpolation pipeline, schedule it, and JIT-compile it.

    The algorithm premultiplies colour by alpha (channel 3), builds a
    `levels`-deep pyramid by repeatedly downsampling by 2 with a separable
    [1, 2, 1] / 4 filter, then walks back up the pyramid compositing each
    level over the upsampled coarser level weighted by (1 - alpha), and
    finally divides by the resulting alpha.

    :param input: halide.ImageParam with 3 dimensions (x, y, c); c == 3 is
        read as alpha. The level-0 expression assumes alpha is 0 or 1.
    :param levels: number of pyramid levels; must be >= 1.
    :param sched: schedule index 0-4. When None (the default) it is chosen
        from the JIT target: 4 (GPU) if the target has a GPU feature,
        otherwise 2 (CPU parallel + vectorized).
    :return: halide.hl.Func ``final`` over (x, y, c), already JIT-compiled.
    :raises ValueError: if ``levels`` < 1 or ``sched`` is not 0-4.
    """
    if levels < 1:
        raise ValueError("levels must be >= 1, got %r" % (levels,))
    # Validate before doing any work; the old code printed and called exit(1).
    if sched is not None and sched not in range(5):
        raise ValueError("No schedule with this number: %r" % (sched,))

    # THE ALGORITHM
    downsampled = [hl.Func('downsampled%d' % i) for i in range(levels)]
    downx = [hl.Func('downx%d' % i) for i in range(levels)]
    interpolated = [hl.Func('interpolated%d' % i) for i in range(levels)]
    upsampled = [hl.Func('upsampled%d' % i) for i in range(levels)]
    upsampledx = [hl.Func('upsampledx%d' % i) for i in range(levels)]

    x = hl.Var('x')
    y = hl.Var('y')
    c = hl.Var('c')

    # Repeat the edge pixels so the filters below may read outside the image.
    clamped = hl.Func('clamped')
    clamped[x, y, c] = input[hl.clamp(x, 0, input.width() - 1), hl.clamp(y, 0, input.height() - 1), c]

    # Premultiply by alpha. The natural form
    #     hl.select(c < 3, clamped[x, y, c] * clamped[x, y, 3], clamped[x, y, 3])
    # triggers a bug in LLVM 3.3 (3.2 and trunk are fine), so it is rewritten;
    # the two forms agree only when the input alpha is zero or one.
    downsampled[0][x, y, c] = clamped[x, y, c] * clamped[x, y, 3]

    for level in range(1, levels):
        prev = downsampled[level - 1]
        if level == 4:
            # Also add a boundary condition at a middle pyramid level to
            # prevent the footprint of the downsamplings from extending too
            # far off the base image. Otherwise we look 512 pixels off each
            # edge. `//` keeps the bounds integer Exprs.
            w = input.width() // (1 << level)
            h = input.height() // (1 << level)
            prev = hl.lambda_func(x, y, c, prev[hl.clamp(x, 0, w), hl.clamp(y, 0, h), c])
        # Separable [1, 2, 1] / 4 filter, decimating by 2 in x and then in y.
        downx[level][x, y, c] = (prev[x * 2 - 1, y, c] + 2.0 * prev[x * 2, y, c] + prev[x * 2 + 1, y, c]) * 0.25
        downsampled[level][x, y, c] = (downx[level][x, y * 2 - 1, c] + 2.0 * downx[level][x, y * 2, c] +
                                       downx[level][x, y * 2 + 1, c]) * 0.25

    interpolated[levels - 1][x, y, c] = downsampled[levels - 1][x, y, c]
    for level in reversed(range(levels - 1)):
        coarser = interpolated[level + 1]
        # Upsample by averaging the two nearest coarser samples. `//` is Halide
        # integer division, so indices stay integral.
        upsampledx[level][x, y, c] = (coarser[x // 2, y, c] + coarser[(x + 1) // 2, y, c]) / 2.0
        upsampled[level][x, y, c] = (upsampledx[level][x, y // 2, c] + upsampledx[level][x, (y + 1) // 2, c]) / 2.0
        # Composite this level (premultiplied) over the upsampled coarser level.
        interpolated[level][x, y, c] = (downsampled[level][x, y, c] +
                                        (1.0 - downsampled[level][x, y, 3]) * upsampled[level][x, y, c])

    # Undo the premultiplication.
    normalize = hl.Func('normalize')
    normalize[x, y, c] = interpolated[0][x, y, c] / interpolated[0][x, y, 3]

    final = hl.Func('final')
    final[x, y, c] = normalize[x, y, c]

    print("Finished function setup.")

    # THE SCHEDULE
    target = hl.get_target_from_environment()
    if sched is None:
        sched = 4 if target.has_gpu_feature() else 2

    if sched == 0:
        print("Flat schedule.")
        for level in range(levels):
            downsampled[level].compute_root()
            interpolated[level].compute_root()
        final.compute_root()
    elif sched == 1:
        print("Flat schedule with vectorization.")
        for level in range(levels):
            downsampled[level].compute_root().vectorize(x, 4)
            interpolated[level].compute_root().vectorize(x, 4)
        final.compute_root()
    elif sched == 2:
        print("Flat schedule with parallelization + vectorization")
        xi, yi = hl.Var('xi'), hl.Var('yi')
        clamped.compute_root().parallel(y).bound(c, 0, 4).reorder(c, x, y).reorder_storage(c, x, y).vectorize(c, 4)
        for level in range(1, levels - 1):
            downsampled[level].compute_root().parallel(y).reorder(c, x, y).reorder_storage(c, x, y).vectorize(c, 4)
            interpolated[level].compute_root().parallel(y).reorder(c, x, y).reorder_storage(c, x, y).vectorize(c, 4)
            interpolated[level].unroll(x, 2).unroll(y, 2)
        final.reorder(c, x, y).bound(c, 0, 3).parallel(y)
        final.tile(x, y, xi, yi, 2, 2).unroll(xi).unroll(yi)
        final.bound(x, 0, input.width())
        final.bound(y, 0, input.height())
    elif sched == 3:
        print("Flat schedule with vectorization sometimes.")
        for level in range(levels):
            if level + 4 < levels:
                downsampled[level].compute_root().vectorize(x, 4)
                interpolated[level].compute_root().vectorize(x, 4)
            else:
                downsampled[level].compute_root()
                interpolated[level].compute_root()
        final.compute_root()
    elif sched == 4:
        print("GPU schedule.")
        # Some gpus don't have enough memory to process the entire
        # image, so we process the image in tiles.
        yo, yi, xo, xi, ci = hl.Var('yo'), hl.Var('yi'), hl.Var('xo'), hl.Var("xi"), hl.Var("ci")
        final.reorder(c, x, y).bound(c, 0, 3).vectorize(x, 4)
        # Tile extents must be integer Exprs, hence floor division.
        final.tile(x, y, xo, yo, xi, yi, input.width() // 4, input.height() // 4)
        normalize.compute_at(final, xo).reorder(c, x, y).gpu_tile(x, y, xi, yi, 16, 16).unroll(c)
        # Start from level 1 to save memory - level zero will be computed on demand
        for level in range(1, levels):
            # Shrink GPU tiles as levels get coarser, kept within [1, 16].
            tile_size = min(max(32 >> level, 1), 16)
            downsampled[level].compute_root().gpu_tile(x, y, c, xi, yi, ci, tile_size, tile_size, 4)
            interpolated[level].compute_at(final, xo).gpu_tile(x, y, c, xi, yi, ci, tile_size, tile_size, 4)

    # JIT compile the pipeline eagerly, so we don't interfere with timing
    final.compile_jit(target)
    return final
def get_input_data(image_path=None):
    """
    Load an image and return it as a float32 array scaled to [0, 1].

    :param image_path: image file to load. Defaults to
        ``../../apps/images/rgba.png`` relative to this file.
    :return: numpy.ndarray of dtype float32, Fortran (column-major) order,
        with the same shape as ``imageio.imread`` returns for the file.
        Integer images are divided by their dtype's maximum (255 for uint8).
    :raises FileNotFoundError: if ``image_path`` is not an existing file.
    """
    if image_path is None:
        image_path = os.path.join(os.path.dirname(__file__), "../../apps/images/rgba.png")
    # Raise instead of assert: assertions are stripped under `python -O`.
    if not os.path.isfile(image_path):
        raise FileNotFoundError("Could not find %s" % image_path)
    rgba_data = imageio.imread(image_path)
    # Normalise by the full integer range so 8- and 16-bit images both map to
    # [0, 1]; non-integer data keeps the historical /255 scaling.
    if np.issubdtype(rgba_data.dtype, np.integer):
        scale = float(np.iinfo(rgba_data.dtype).max)
    else:
        scale = 255.0
    # Column-major copy; presumably so hl.Buffer sees dimension 0 as the
    # densest axis - confirm against the bindings' buffer handling.
    input_data = np.asfortranarray(rgba_data, dtype=np.float32) / np.float32(scale)
    return input_data
def to_uint8(data):
    """
    Convert floats in [0, 1] to uint8 in [0, 255].

    Values are scaled by 255, rounded to the nearest integer and clipped.
    Rounding (not truncation) makes an 8-bit image divided by 255 convert
    back exactly; clipping stops out-of-range values from wrapping around.
    """
    return np.clip(np.rint(np.asarray(data) * 255.0), 0, 255).astype(np.uint8)


def main():
    """
    Run the demo: build and JIT the interpolation pipeline, load the RGBA
    test image, time one realization, and save the input and output images
    as PNG files in the current directory.

    :raises ValueError: if the loaded image does not have 4 channels.
    """
    input_param = hl.ImageParam(float_t, 3, "input")
    levels = 10
    interpolate = get_interpolate(input_param, levels)

    # preparing input and output memory buffers (numpy ndarrays)
    input_data = get_input_data()
    # Raise rather than assert: the pipeline reads channel 3 as alpha.
    if input_data.ndim != 3 or input_data.shape[2] != 4:
        raise ValueError("expected a 4-channel (RGBA) image, got shape %s" % (input_data.shape,))
    input_image = hl.Buffer(input_data)
    input_param.set(input_image)

    # NOTE(review): numpy dim 0 is passed as Halide's x here; with
    # imageio's usual (rows, cols, channels) layout these names look
    # swapped, but input and output use the same convention.
    input_width, input_height = input_data.shape[:2]

    # perf_counter is monotonic and high-resolution, unlike datetime.now().
    t0 = time.perf_counter()
    output_image = interpolate.realize([input_width, input_height, 3])
    elapsed = time.perf_counter() - t0
    print('Interpolated in {:.5f} secs'.format(elapsed))

    output_data = np.asanyarray(output_image)

    # convert output (rounded, so the input copy round-trips exactly)
    input_data = to_uint8(input_data)
    output_data = to_uint8(output_data)

    # save results
    input_path = "interpolate_input.png"
    output_path = "interpolate_result.png"
    imageio.imsave(input_path, input_data)
    imageio.imsave(output_path, output_data)
    print()
    print('interpolation realized on output image. Result saved at {} (input data copy at {})'.format(
        output_path, input_path))
    print()
    print("End of game. Have a nice day!")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
|