1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
|
#include "Halide.h"
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <stdio.h>
#ifdef _MSC_VER
#pragma warning(disable : 4800) // forcing value to bool 'true' or 'false'
#endif
using namespace Halide;
// Pure Vars naming the three dimensions of the lerp_test Func built in
// check_range: the zero-value index, the one-value index, and the weight index.
Var zero_val;
Var one_val;
Var weight;
// The weight_t value that stands for a weight of exactly 1.0:
// the maximum representable value for integer weight types, 1.0 otherwise.
// Dividing a raw weight by this yields the fractional weight in double.
template<typename weight_t>
double weight_type_scale() {
    typedef std::numeric_limits<weight_t> limits;
    return limits::is_integer ? static_cast<double>(limits::max()) : 1.0;
}
// Rounding bias applied to a reference result before it is truncated to
// value_t: 0.5 for integer value types (so truncation rounds to nearest),
// 0.0 for non-integer types. check_range subtracts it for negative results
// and adds it otherwise.
template<typename value_t>
double conversion_rounding() {
    return std::numeric_limits<value_t>::is_integer ? 0.5 : 0.0;
}
// Converts an already-biased interpolated double to value_t.
// The generic version simply truncates via static_cast.
template<typename value_t>
value_t convert_to_value(double interpolated) {
    const value_t converted = static_cast<value_t>(interpolated);
    return converted;
}
// bool: the rounding bias (0.5, since bool is an integer type) has already
// been added by the caller, so the result is true once the value reaches 1.
template<>
bool convert_to_value<bool>(double interpolated) {
    const bool reached_one = 1 <= interpolated;
    return reached_one;
}
// Prevent iostream from printing 8-bit numbers as character constants.
// promote_if_char<T>::promoted is T itself, except for the three distinct
// character types -- plain char, signed char and unsigned char -- which are
// widened to int32_t so that operator<< prints them as numbers.
// (Plain char is a separate type from both signed and unsigned char, so it
// needs its own specialization.)
template<typename t>
struct promote_if_char {
    typedef t promoted;
};
template<>
struct promote_if_char<char> {
    typedef int32_t promoted;
};
template<>
struct promote_if_char<signed char> {
    typedef int32_t promoted;
};
template<>
struct promote_if_char<unsigned char> {
    typedef int32_t promoted;
};
// Compares a reference value `a` with a computed value `b`.
// Exactly equal values always match; integer types must match exactly.
// Non-integer values also match when their absolute difference is below
// 1e-4, or when their relative error (taken against the larger-magnitude
// operand) is below 2e-7. A non-integer mismatch is reported on std::cerr
// before returning false.
template<typename value_t>
bool relatively_equal(value_t a, value_t b) {
    if (a == b) {
        return true;
    }
    if (std::numeric_limits<value_t>::is_integer) {
        return false;
    }
    const double lhs = static_cast<double>(a);
    const double rhs = static_cast<double>(b);
    const double delta = rhs - lhs;
    // NOTE(review): this absolute tolerance looks generous for small-magnitude values.
    if (fabs(delta) < .0001) {
        return true;
    }
    const double reference = (fabs(lhs) > fabs(rhs)) ? lhs : rhs;
    const double relative_error = fabs(delta / reference);
    if (relative_error < .0000002) {
        return true;
    }
    std::cerr << "relatively_equal failed for (" << a << ", " << b << ") "
              << "with relative error " << relative_error << "\n";
    return false;
}
template<typename value_t, typename weight_t>
void check_range(int32_t zero_min, int32_t zero_extent, value_t zero_offset, value_t zero_scale,
int32_t one_min, int32_t one_extent, value_t one_offset, value_t one_scale,
int32_t weight_min, int32_t weight_extent, weight_t weight_offset, weight_t weight_scale,
const char *name) {
// Stuff everything in Params as these can represent uint32_t where
// that fails in converting to Expr is we just use the raw C++ variables.
Param<value_t> zero_scale_p, zero_offset_p;
zero_scale_p.set(zero_scale);
zero_offset_p.set(zero_offset);
Param<value_t> one_scale_p, one_offset_p;
one_scale_p.set(one_scale);
one_offset_p.set(one_offset);
Param<weight_t> weight_scale_p, weight_offset_p;
weight_scale_p.set(weight_scale);
weight_offset_p.set(weight_offset);
Func lerp_test("lerp_test");
lerp_test(zero_val, one_val, weight) =
lerp(cast<value_t>((zero_val + zero_min) * zero_scale_p + zero_offset_p),
cast<value_t>((one_val + one_min) * one_scale_p + one_offset_p),
cast<weight_t>((weight + weight_min) * weight_scale_p + weight_offset_p));
Buffer<value_t> result(zero_extent, one_extent, weight_extent);
lerp_test.realize(result);
for (int32_t i = 0; i < result.extent(0); i++) {
for (int32_t j = 0; j < result.extent(1); j++) {
for (int32_t k = 0; k < result.extent(2); k++) {
value_t zero_verify = ((i + zero_min) * zero_scale + zero_offset);
value_t one_verify = ((j + one_min) * one_scale + one_offset);
weight_t weight_verify = (weight_t)((k + weight_min) * weight_scale + weight_offset);
double actual_weight = weight_verify / weight_type_scale<weight_t>();
double verify_val_full = zero_verify * (1.0 - actual_weight) + one_verify * actual_weight;
if (verify_val_full < 0)
verify_val_full -= conversion_rounding<value_t>();
else
verify_val_full += conversion_rounding<value_t>();
value_t verify_val = convert_to_value<value_t>(verify_val_full);
value_t computed_val = result(i, j, k);
if (!relatively_equal(verify_val, computed_val)) {
std::cerr << "Expected "
<< (typename promote_if_char<value_t>::promoted)(verify_val)
<< " got " << (typename promote_if_char<value_t>::promoted)(computed_val)
<< " for lerp(" << (typename promote_if_char<value_t>::promoted)(zero_verify)
<< ", " << (typename promote_if_char<value_t>::promoted)(one_verify)
<< ", " << (typename promote_if_char<weight_t>::promoted)(weight_verify)
<< ") " << actual_weight << ". " << name << "\n";
assert(false);
}
}
}
}
}
int main(int argc, char **argv) {
// Test bool
check_range<bool, uint8_t>(0, 2, 0, 1,
0, 2, 0, 1,
0, 256, 0, 1,
"<bool, uint8_t> exhaustive");
// Exhaustively test 8-bit cases
check_range<uint8_t, uint8_t>(0, 256, 0, 1,
0, 256, 0, 1,
0, 256, 0, 1,
"<uint8_t, uint8_t> exhaustive");
check_range<int8_t, uint8_t>(0, 256, -128, 1,
0, 256, -128, 1,
0, 256, 0, 1,
"<int8_t, uint8_t> exhaustive");
check_range<uint8_t, float>(0, 256, 0, 1,
0, 256, 0, 1,
0, 256, 0, 1 / 255.0f,
"<uint8_t, float> exhaustive");
check_range<int8_t, float>(0, 256, -128, 1,
0, 256, -128, 1,
0, 256, 0, 1 / 255.0f,
"<int8_t, float> exhaustive");
// Check all delta values for 16-bit, verify swapping arguments doesn't break
check_range<uint16_t, uint16_t>(0, 65536, 0, 1,
65535, 1, 0, 1,
0, 257, 255, 1,
"<uint16_t, uint16_t> all zero starts");
check_range<uint16_t, uint16_t>(65535, 1, 0, 1,
0, 65536, 0, 1,
0, 257, 255, 1,
"<uint16_t, uint16_t> all one starts");
// Verify different bit sizes for value and weight types
check_range<uint16_t, uint8_t>(0, 1, 0, 1,
65535, 1, 0, 1,
0, 255, 1, 1,
"<uint16_t, uint8_t> zero, one uint8_t weight test");
check_range<uint16_t, uint32_t>(0, 1, 0, 1,
65535, 1, 0, 1,
std::numeric_limits<int32_t>::min(), 257, 255 * 65535, 1,
"<uint16_t, uint8_t> zero, one uint32_t weight test");
check_range<uint32_t, uint8_t>(0, 1, 0, 1,
0x80000000, 1, 0, 1,
0, 255, 0, 1,
"<uint32_t, uint8_t> weight test");
check_range<uint32_t, uint16_t>(0, 1, 0, 1,
0x80000000, 1, 0, 1,
0, 65535, 0, 1,
"<uint32_t, uint16_t> weight test");
// Verify float weights with integer values
check_range<uint16_t, float>(0, 1, 0, 1,
65535, 1, 0, 1,
0, 257, 0, 255.0f / 65535.0f,
"<uint16_t, float> zero, one float weight test");
check_range<int16_t, uint16_t>(0, 65536, -32768, 1,
0, 1, 0, 1,
0, 257, 0, 255,
"<int16_t, uint16_t> all zero starts");
#if 0 // takes too long, difficult to test with uint32_t
// Check all delta values for 32-bit, do it in signed arithmetic
check_range<int32_t, uint32_t>(std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max(), 0, 1,
0x80000000, 1, 0, 1,
0, 1, 0x80000000, 1,
"<uint32_t, uint32_t> all zero starts");
#endif
check_range<float, float>(0, 100, 0, .01f,
0, 100, 0, .01f,
0, 100, 0, .01f,
"<float, float> float values 0 to 1 by 1/100ths");
check_range<float, float>(0, 100, -5, .1f,
0, 100, 0, .1f,
0, 100, 0, .1f,
"<float, float> float values -5 to 5 by 1/100ths");
// Verify float values with integer weights
check_range<float, uint8_t>(0, 100, -5, .1f,
0, 100, 0, .1f,
0, 255, 0, 1,
"<float, uint8_t> float values -5 to 5 by 1/100ths");
check_range<float, uint16_t>(0, 100, -5, .1f,
0, 100, 0, .1f,
0, 255, 0, 257,
"<float, uint16_t> float values -5 to 5 by 1/100ths");
check_range<float, uint32_t>(0, 100, -5, .1f,
0, 100, 0, .1f,
std::numeric_limits<int32_t>::min(), 257, 255 * 65535, 1,
"<float, uint32_t> float values -5 to 5 by 1/100ths");
// Check constant and constant case:
Func lerp_constants("lerp_constants");
lerp_constants() = lerp(0, cast<uint32_t>(1023), .5f);
Buffer<uint32_t> result = lerp_constants.realize();
uint32_t expected = evaluate<uint32_t>(cast<uint32_t>(lerp(0, cast<uint16_t>(1023), .5f)));
if (result(0) != expected) {
std::cerr << "Expected " << expected << " got " << result(0) << "\n";
}
assert(result(0) == expected);
// Add a little more coverage for uint32_t as this was failing
// without being detected for a long time.
Buffer<uint8_t> input_a_img(16, 16);
Buffer<uint8_t> input_b_img(16, 16);
for (int i = 0; i < 16; i++) {
for (int j = 0; j < 16; j++) {
input_a_img(i, j) = (i << 4) + j;
input_b_img(i, j) = ((15 - i) << 4) + (15 - j);
}
}
ImageParam input_a(UInt(8), 2);
ImageParam input_b(UInt(8), 2);
Var x, y;
Func lerp_with_casts;
Param<float> w;
lerp_with_casts(x, y) = lerp(cast<int32_t>(input_a(x, y)), cast<int32_t>(input_b(x, y)), w);
lerp_with_casts.vectorize(x, 4);
input_a.set(input_a_img);
input_b.set(input_b_img);
w.set(0.0f);
Buffer<int32_t> result_should_be_a = lerp_with_casts.realize({16, 16});
w.set(1.0f);
Buffer<int32_t> result_should_be_b = lerp_with_casts.realize({16, 16});
for (int i = 0; i < 16; i++) {
for (int j = 0; j < 16; j++) {
assert(input_a_img(i, j) == result_should_be_a(i, j));
assert(input_b_img(i, j) == result_should_be_b(i, j));
}
}
std::cout << "Success!\n";
}
|