
|
"""
Fast image interpolation using an image pyramid: pixels with zero alpha are filled in from progressively coarser, blurred levels.
"""
import halide as hl
from datetime import datetime
import imageio
import numpy as np
import os.path
# Halide scalar types used by this module: a 32-bit signed integer and a
# 32-bit float. float_t is the element type of the ImageParam built in main();
# int_t is not referenced elsewhere in this file.
int_t = hl.Int(32)
float_t = hl.Float(32)
def get_interpolate(input, levels, sched=None):
    """
    Build the pyramid-interpolation pipeline, schedule it, and JIT-compile it.

    The input is premultiplied by its alpha channel (channel 3) and then
    repeatedly downsampled with a separable [1, 2, 1] / 4 filter. On the way
    back up, each level keeps its own value and fills the remaining
    ``1 - alpha`` with the upsampled coarser level. The finest level is then
    divided by its accumulated alpha.

    :param input: hl.ImageParam indexed as input[x, y, c]; channel 3 is read
        as alpha, so at least four channels are required.
    :param levels: number of pyramid levels; must be >= 1.
    :param sched: optional schedule selector (0-4). ``None`` (default) keeps
        the original automatic choice: 4 (GPU) when the environment target
        has a GPU feature, otherwise 2 (parallel + vectorized CPU).
    :return: the JIT-compiled output halide.hl.Func, indexed as final[x, y, c].
    :raises ValueError: if ``levels`` < 1 or ``sched`` is not in 0-4.
    """
    if levels < 1:
        raise ValueError("levels must be >= 1, got %r" % (levels,))

    # THE ALGORITHM
    downsampled = [hl.Func('downsampled%d' % l) for l in range(levels)]
    downx = [hl.Func('downx%d' % l) for l in range(levels)]
    interpolated = [hl.Func('interpolated%d' % l) for l in range(levels)]
    upsampled = [hl.Func('upsampled%d' % l) for l in range(levels)]
    upsampledx = [hl.Func('upsampledx%d' % l) for l in range(levels)]
    x = hl.Var('x')
    y = hl.Var('y')
    c = hl.Var('c')

    # Clamp coordinates to the image so stencil taps outside it are defined.
    clamped = hl.Func('clamped')
    clamped[x, y, c] = input[hl.clamp(x, 0, input.width() - 1), hl.clamp(y, 0, input.height() - 1), c]

    # Level 0: the input premultiplied by alpha (channel 3).
    # This triggers a bug in llvm 3.3 (3.2 and trunk are fine), so we
    # rewrite it in a way that doesn't trigger the bug. The rewritten
    # form assumes the input alpha is zero or one.
    # downsampled[0][x, y, c] = hl.select(c < 3, clamped[x, y, c] * clamped[x, y, 3], clamped[x, y, 3])
    downsampled[0][x, y, c] = clamped[x, y, c] * clamped[x, y, 3]

    for l in range(1, levels):
        prev = downsampled[l - 1]
        if l == 4:
            # Also add a boundary condition at a middle pyramid level
            # to prevent the footprint of the downsamplings to extend
            # too far off the base image. Otherwise we look 512
            # pixels off each edge.
            # `//` keeps these extents integer Exprs regardless of how the
            # bindings define `/` on integer Exprs.
            w = input.width() // (1 << l)
            h = input.height() // (1 << l)
            prev = hl.lambda_func(x, y, c, prev[hl.clamp(x, 0, w), hl.clamp(y, 0, h), c])
        # Separable [1, 2, 1] / 4 filter with stride-2 decimation: x first, then y.
        downx[l][x, y, c] = (prev[x * 2 - 1, y, c] + 2.0 * prev[x * 2, y, c] + prev[x * 2 + 1, y, c]) * 0.25
        downsampled[l][x, y, c] = (downx[l][x, y * 2 - 1, c] + 2.0 * downx[l][x, y * 2, c]
                                   + downx[l][x, y * 2 + 1, c]) * 0.25

    # The coarsest level has nothing coarser to fall back on.
    interpolated[levels - 1][x, y, c] = downsampled[levels - 1][x, y, c]
    for l in reversed(range(levels - 1)):
        # Upsample the coarser level by averaging its two nearest samples,
        # first in x, then in y. `//` is integer index arithmetic; `/ 2.0`
        # is an intentional float average of pixel values.
        upsampledx[l][x, y, c] = (interpolated[l + 1][x // 2, y, c] + interpolated[l + 1][(x + 1) // 2, y, c]) / 2.0
        upsampled[l][x, y, c] = (upsampledx[l][x, y // 2, c] + upsampledx[l][x, (y + 1) // 2, c]) / 2.0
        # Keep this level's premultiplied value; fill the missing (1 - alpha)
        # share from the upsampled coarser level.
        interpolated[l][x, y, c] = downsampled[l][x, y, c] + (1.0 - downsampled[l][x, y, 3]) * upsampled[l][x, y, c]

    # Un-premultiply by the accumulated alpha.
    normalize = hl.Func('normalize')
    normalize[x, y, c] = interpolated[0][x, y, c] / interpolated[0][x, y, 3]
    final = hl.Func('final')
    final[x, y, c] = normalize[x, y, c]
    print("Finished function setup.")

    # THE SCHEDULE
    target = hl.get_target_from_environment()
    if sched is None:
        sched = 4 if target.has_gpu_feature() else 2

    if sched == 0:
        print("Flat schedule.")
        for l in range(levels):
            downsampled[l].compute_root()
            interpolated[l].compute_root()
        final.compute_root()
    elif sched == 1:
        print("Flat schedule with vectorization.")
        for l in range(levels):
            downsampled[l].compute_root().vectorize(x, 4)
            interpolated[l].compute_root().vectorize(x, 4)
        final.compute_root()
    elif sched == 2:
        print("Flat schedule with parallelization + vectorization")
        xi, yi = hl.Var('xi'), hl.Var('yi')
        clamped.compute_root().parallel(y).bound(c, 0, 4).reorder(c, x, y).reorder_storage(c, x, y).vectorize(c, 4)
        # Levels 0 and levels - 1 are left inlined (default schedule).
        for l in range(1, levels - 1):
            downsampled[l].compute_root().parallel(y).reorder(c, x, y).reorder_storage(c, x, y).vectorize(c, 4)
            interpolated[l].compute_root().parallel(y).reorder(c, x, y).reorder_storage(c, x, y).vectorize(c, 4)
            interpolated[l].unroll(x, 2).unroll(y, 2)
        final.reorder(c, x, y).bound(c, 0, 3).parallel(y)
        final.tile(x, y, xi, yi, 2, 2).unroll(xi).unroll(yi)
        final.bound(x, 0, input.width())
        final.bound(y, 0, input.height())
    elif sched == 3:
        print("Flat schedule with vectorization sometimes.")
        for l in range(levels):
            if l + 4 < levels:
                downsampled[l].compute_root().vectorize(x, 4)
                interpolated[l].compute_root().vectorize(x, 4)
            else:
                downsampled[l].compute_root()
                interpolated[l].compute_root()
        final.compute_root()
    elif sched == 4:
        print("GPU schedule.")
        # Some gpus don't have enough memory to process the entire
        # image, so we process the image in tiles.
        yo, yi, xo, xi, ci = hl.Var('yo'), hl.Var('yi'), hl.Var('xo'), hl.Var("xi"), hl.Var("ci")
        final.reorder(c, x, y).bound(c, 0, 3).vectorize(x, 4)
        # Tile extents must be integer Exprs, hence `//`.
        final.tile(x, y, xo, yo, xi, yi, input.width() // 4, input.height() // 4)
        normalize.compute_at(final, xo).reorder(c, x, y).gpu_tile(x, y, xi, yi, 16, 16).unroll(c)
        # Start from level 1 to save memory - level zero will be computed on demand
        for l in range(1, levels):
            # Halve the GPU tile edge per level, kept within [1, 16].
            tile_size = max(1, min(16, 32 >> l))
            downsampled[l].compute_root().gpu_tile(x, y, c, xi, yi, ci, tile_size, tile_size, 4)
            interpolated[l].compute_at(final, xo).gpu_tile(x, y, c, xi, yi, ci, tile_size, tile_size, 4)
    else:
        # Raise rather than exit(): exit() is a site-module convenience and
        # terminating the interpreter from a library function hides the error.
        raise ValueError("No schedule with this number: %r" % (sched,))

    # JIT compile the pipeline eagerly, so we don't interfere with timing
    final.compile_jit(target)
    return final
def get_input_data():
    """
    Load the sample RGBA image used by this demo as float32.

    The image is read from ``../../apps/images/rgba.png`` relative to this
    file. It is copied into Fortran (column-major) order and divided by
    255.0, so 8-bit input ends up in [0, 1].

    :return: numpy float32 array with the shape ``imageio.imread`` produced
        (``main`` expects the last axis to hold 4 channels).
    :raises FileNotFoundError: if the image is not at the expected path.
    """
    image_path = os.path.join(os.path.dirname(__file__), "../../apps/images/rgba.png")
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would turn this into a less helpful imageio error.
    if not os.path.exists(image_path):
        raise FileNotFoundError("Could not find %s" % image_path)
    rgba_data = imageio.imread(image_path)
    # NOTE(review): the Fortran-order copy presumably matches the memory
    # layout hl.Buffer expects when wrapping the array — confirm.
    # input data is in range [0, 1]
    input_data = np.copy(rgba_data, order="F").astype(np.float32) / 255.0
    return input_data
def main():
    """
    Run the interpolation demo end to end.

    Loads the sample RGBA image, builds and JIT-compiles a 10-level pipeline,
    and times one realization of a 3-channel output. It then writes the
    input and the result as 8-bit PNGs to the current working directory.

    :raises ValueError: if the loaded image does not have 4 channels on its
        last axis.
    """
    import time  # local import: only the timing below needs it

    # Named to avoid shadowing the builtin `input`.
    input_param = hl.ImageParam(float_t, 3, "input")
    levels = 10
    interpolate = get_interpolate(input_param, levels)

    # preparing input and output memory buffers (numpy ndarrays)
    input_data = get_input_data()
    # Explicit check (not assert) so it still runs under `python -O`.
    if input_data.ndim != 3 or input_data.shape[2] != 4:
        raise ValueError("expected an RGBA image with 4 channels, got shape %r" % (input_data.shape,))
    input_image = hl.Buffer(input_data)
    input_param.set(input_image)
    # NOTE(review): these are numpy axes 0 and 1 (rows, cols). The names
    # follow the Halide (x, y) order this script assumes hl.Buffer maps them
    # to — confirm against the installed bindings' axis convention.
    input_width, input_height = input_data.shape[:2]

    # perf_counter is monotonic and high-resolution, unlike datetime.now().
    t0 = time.perf_counter()
    output_image = interpolate.realize([input_width, input_height, 3])
    elapsed = time.perf_counter() - t0
    print('Interpolated in {:.5f} secs'.format(elapsed))

    output_data = np.asanyarray(output_image)

    # convert output; clip first so any value outside [0, 1] cannot wrap
    # around when cast to uint8
    input_data = (input_data * 255).astype(np.uint8)
    output_data = (np.clip(output_data, 0.0, 1.0) * 255).astype(np.uint8)

    # save results
    input_path = "interpolate_input.png"
    output_path = "interpolate_result.png"
    imageio.imsave(input_path, input_data)
    imageio.imsave(output_path, output_data)

    print()
    print('interpolate realized on output image. Result saved at {} (input data copy at {})'.format(output_path, input_path))
    print()
    print("End of game. Have a nice day!")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|