File: strict_float.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (297 lines) | stat: -rw-r--r-- 10,805 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
#include "Halide.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iomanip>
#include <ios>
#include <iostream>
#include <string>

using namespace Halide;

#if defined(__SSE2__) || defined(__AVX__)
#include <immintrin.h>
#endif

#ifdef __SSE2__
// Reference sum of squares of in[0, count) using 4-wide SSE multiplies and adds
// (no FMA intrinsics). Each lane keeps its own partial sum; the lanes are added
// in order 0..3, then any trailing elements (count % 4 of them) are added one at
// a time. Unaligned loads are used, so `in` needs no particular alignment.
float no_fma_dot_prod_sse(const float *in, int count) {
    const int vector_count = count / 4;
    __m128 sum = _mm_setzero_ps();
    for (int i = 0; i < vector_count; i++) {
        const __m128 v = _mm_loadu_ps(in + 4 * i);
        const __m128 prod = _mm_mul_ps(v, v);
        sum = _mm_add_ps(prod, sum);
    }
    // Spill the lanes to memory instead of type-punning the vector register.
    float lanes[4];
    _mm_storeu_ps(lanes, sum);
    float result = 0.0f;
    for (float lane : lanes) {
        result += lane;
    }
    // Scalar tail for elements that do not fill a whole vector.
    for (int i = vector_count * 4; i < count; i++) {
        result += in[i] * in[i];
    }
    return result;
}
#endif

#if defined(__SSE2__) && defined(__FMA__)
// Reference sum of squares of in[0, count) using 4-wide fused multiply-adds.
// Each lane keeps its own partial sum; the lanes are added in order 0..3, then any
// trailing elements (count % 4 of them) are accumulated with scalar std::fma.
// Unaligned loads are used, so `in` needs no particular alignment.
float fma_dot_prod_sse(const float *in, int count) {
    const int vector_count = count / 4;
    __m128 sum = _mm_setzero_ps();
    for (int i = 0; i < vector_count; i++) {
        const __m128 v = _mm_loadu_ps(in + 4 * i);
        sum = _mm_fmadd_ps(v, v, sum);
    }
    // Spill the lanes to memory instead of type-punning the vector register.
    float lanes[4];
    _mm_storeu_ps(lanes, sum);
    float result = 0.0f;
    for (float lane : lanes) {
        result += lane;
    }
    // Scalar tail, still fused, for elements that do not fill a whole vector.
    for (int i = vector_count * 4; i < count; i++) {
        result = std::fma(in[i], in[i], result);
    }
    return result;
}
#endif

#if defined(__AVX__)
// Reference sum of squares of in[0, count) using 8-wide AVX multiplies and adds
// (no FMA intrinsics). Each lane keeps its own partial sum; the lanes are added
// in order 0..7, then any trailing elements (count % 8 of them) are added one at
// a time. Unaligned loads are used, so `in` needs no particular alignment.
float no_fma_dot_prod_avx(const float *in, int count) {
    const int vector_count = count / 8;
    __m256 sum = _mm256_setzero_ps();
    for (int i = 0; i < vector_count; i++) {
        const __m256 v = _mm256_loadu_ps(in + 8 * i);
        const __m256 prod = _mm256_mul_ps(v, v);
        sum = _mm256_add_ps(prod, sum);
    }
    // Spill the lanes to memory instead of type-punning the vector register.
    float lanes[8];
    _mm256_storeu_ps(lanes, sum);
    float result = 0.0f;
    for (float lane : lanes) {
        result += lane;
    }
    // Scalar tail for elements that do not fill a whole vector.
    for (int i = vector_count * 8; i < count; i++) {
        result += in[i] * in[i];
    }
    return result;
}
#endif

#if defined(__AVX__) && defined(__FMA__)
// Reference sum of squares of in[0, count) using 8-wide fused multiply-adds.
// Each lane keeps its own partial sum; the lanes are added in order 0..7, then any
// trailing elements (count % 8 of them) are accumulated with scalar std::fma.
// Unaligned loads are used, so `in` needs no particular alignment.
float fma_dot_prod_avx(const float *in, int count) {
    const int vector_count = count / 8;
    __m256 sum = _mm256_setzero_ps();
    for (int i = 0; i < vector_count; i++) {
        const __m256 v = _mm256_loadu_ps(in + 8 * i);
        sum = _mm256_fmadd_ps(v, v, sum);
    }
    // Spill the lanes to memory instead of type-punning the vector register.
    float lanes[8];
    _mm256_storeu_ps(lanes, sum);
    float result = 0.0f;
    for (float lane : lanes) {
        result += lane;
    }
    // Scalar tail, still fused, for elements that do not fill a whole vector.
    for (int i = vector_count * 8; i < count; i++) {
        result = std::fma(in[i], in[i], result);
    }
    return result;
}
#endif

// Realizes a 1-D buffer of 1,000,000 values produced by Halide's random_float().
Buffer<float> one_million_rando_floats() {
    const int count = 1000000;
    Var x("x");
    Func randos;
    randos(x) = random_float();
    Buffer<float> values = randos.realize({count});
    return values;
}

// One-dimensional float32 input shared by every Halide reduction in this file.
ImageParam in(Float(32), 1);

// Per-element summand for all reductions: the input value at `index`, squared.
Expr term(Expr index) {
    Expr squared = in(index) * in(index);
    return squared;
}

// Selects whether apply_strictness() wraps expressions in strict_float().
enum class FloatStrictness {
    Default,
    Strict
};

// Mode read by apply_strictness(); run_one_condition() assigns it before building pipelines.
FloatStrictness global_strictness = FloatStrictness::Default;

// Label for a strictness mode, used in progress output.
std::string strictness_to_string(FloatStrictness strictness) {
    switch (strictness) {
    case FloatStrictness::Strict:
        return "strict_float";
    default:
        return "default";
    }
}

// Returns strict_float(x) when global_strictness is Strict; otherwise returns x untouched.
Expr apply_strictness(Expr x) {
    const bool want_strict = global_strictness == FloatStrictness::Strict;
    return want_strict ? strict_float(x) : x;
}

// Builds a zero-dimensional Func holding the sum of term(k) = in(k) * in(k) over the whole
// input. The sum is accumulated in type Accum and converted to float at the end. Each
// accumulation step and the final conversion go through apply_strictness(), so the value of
// global_strictness when this is called decides whether those operations are strict_float.
//
// vectorize == 0: a single serial reduction over the input.
// vectorize != 0: `vectorize` partial sums, where total_inner(i) accumulates elements
//     i, i + vectorize, i + 2 * vectorize, ..., computed as a vectorized update and then
//     added together serially. This assumes in.width() is a multiple of vectorize: the
//     division below truncates, so remainder elements would not be summed (main() declares
//     the input bounds so this holds for the sizes used here).
template<typename Accum>
Func simple_sum(int vectorize) {
    Func total("total");
    // The lane split is written out by hand instead of using rfactor(), because
    // strict_float operations are not associative.
    if (vectorize != 0) {
        Func total_inner("total_inner");
        RDom r_outer(0, in.width() / vectorize);
        RDom r_lanes(0, vectorize);
        Var i("i");
        total_inner(i) = cast<Accum>(0);
        total_inner(i) = apply_strictness(total_inner(i) + cast<Accum>(term(r_outer * vectorize + i)));
        total() = cast<Accum>(0);
        total() = apply_strictness(total() + total_inner(r_lanes));
        total_inner.compute_at(total, Var::outermost());
        total_inner.vectorize(i);
        total_inner.update(0).vectorize(i);
    } else {
        RDom r(0, in.width(), "r");

        total() = apply_strictness(cast<Accum>(0));
        total() = apply_strictness(total() + cast<Accum>(term(r)));
    }
    return lambda(apply_strictness(cast<float>(total())));
}

Func kahan_sum(int vectorize) {
    // Item 0 of the tuple valued k_sum is the sum and item 1 is an error compensation term.
    // See: https://en.wikipedia.org/wiki/Kahan_summation_algorithm
    Func k_sum("k_sum");

    // rfactor cannot prove associativity for the non-strict formulation and strict_float is not associative.
    if (vectorize != 0) {
        Func k_sum_inner("k_sum_inner");
        RDom r_outer(0, in.width() / vectorize);
        RDom r_lanes(0, vectorize);
        Var i("i");
        k_sum_inner(i) = Tuple(0.0f, 0.0f);
        k_sum_inner(i) = Tuple(apply_strictness(k_sum_inner(i)[0] + (term(r_outer * vectorize + i) - k_sum_inner(i)[1])),
                               apply_strictness((k_sum_inner(i)[0] + (term(r_outer * vectorize + i) - k_sum_inner(i)[1])) - k_sum_inner(i)[0]) - (term(r_outer * vectorize + i) - k_sum_inner(i)[1]));
        k_sum() = Tuple(0.0f, 0.0f);
        k_sum() = Tuple(apply_strictness(k_sum()[0] + (k_sum_inner(r_lanes)[0] - k_sum()[1])),
                        apply_strictness((k_sum()[0] + (k_sum_inner(r_lanes)[0] - k_sum()[1])) - k_sum()[0]) - (k_sum_inner(r_lanes)[0] - k_sum()[1]));
        k_sum_inner.compute_at(k_sum, Var::outermost());
        k_sum_inner.vectorize(i);
        k_sum_inner.update(0).vectorize(i);
    } else {
        RDom r(0, in.width(), "r");

        k_sum() = Tuple(0.0f, 0.0f);
        k_sum() = Tuple(apply_strictness(k_sum()[0] + (term(r) - k_sum()[1])),
                        apply_strictness((k_sum()[0] + (term(r) - k_sum()[1])) - k_sum()[0]) - (term(r) - k_sum()[1]));
    }

    return lambda(k_sum()[0]);
}

// Realizes the zero-dimensional Func `f` for target `t` and prints "name: value". When
// `expected` is non-zero it also prints the residual (value - expected). Returns the
// realized value. `suffix` is currently unused.
float eval(Func f, const Target &t, const std::string &name, const std::string &suffix, float expected) {
    Buffer<float> realized = f.realize({}, t);
    const float val = realized();
    std::cout << "        " << name << ": " << val;
    const bool has_reference = expected != 0.0f;
    if (has_reference) {
        std::cout << " residual: " << val - expected;
    }
    std::cout << '\n';
    return val;
}

// Reports a failed expectation on stderr and terminates the test with exit status 1.
// Used instead of assert() so the checks still run when NDEBUG is defined.
static void check_accuracy(bool ok, const char *description) {
    if (!ok) {
        std::cerr << "Check failed: " << description << "\n";
        std::exit(1);
    }
}

// Runs every summation variant for one (target, strictness) combination. It prints each result
// and its residual against the serial double-accumulated sum. In Strict mode it then checks that
// every Kahan variant is at least as accurate as the plain float sum. Sets global_strictness as
// a side effect.
//
// The Halide pipelines read the global ImageParam `in`, while the hand-written SIMD references
// read `vals` directly; callers are expected to have bound `in` to the same data.
void run_one_condition(const Target &t, FloatStrictness strictness, Buffer<float> vals) {
    global_strictness = strictness;
    std::string suffix = "_" + t.to_string() + "_" + strictness_to_string(strictness);

    std::cout << "    Target: " << t.to_string() << " Strictness: " << strictness_to_string(strictness) << "\n";

    float simple_double = eval(simple_sum<double>(0), t, "simple_double", suffix, 0.0f);
    float simple_double_vec_4 = eval(simple_sum<double>(4), t, "simple_double_vec_4", suffix, simple_double);
    float simple_double_vec_8 = eval(simple_sum<double>(8), t, "simple_double_vec_8", suffix, simple_double);
    float simple_float = eval(simple_sum<float>(0), t, "simple_float", suffix, simple_double);
    float simple_float_vec_4 = eval(simple_sum<float>(4), t, "simple_float_vec_4", suffix, simple_double);
    float simple_float_vec_8 = eval(simple_sum<float>(8), t, "simple_float_vec_8", suffix, simple_double);
    float kahan = eval(kahan_sum(0), t, "kahan", suffix, simple_double);
    float kahan_vec_4 = eval(kahan_sum(4), t, "kahan_vec_4", suffix, simple_double);
    float kahan_vec_8 = eval(kahan_sum(8), t, "kahan_vec_8", suffix, simple_double);

#ifdef __SSE2__
    float vec_dot_prod_4 = no_fma_dot_prod_sse(&vals(0), vals.width());
    std::cout << "        four wide no fma: " << vec_dot_prod_4 << " residual: " << vec_dot_prod_4 - simple_double << "\n";
#endif

#if defined(__SSE2__) && defined(__FMA__)
    float fma_dot_prod_4 = fma_dot_prod_sse(&vals(0), vals.width());
    std::cout << "        four wide fma: " << fma_dot_prod_4 << " residual: " << fma_dot_prod_4 - simple_double << "\n";
#endif

#if defined(__AVX__)
    float vec_dot_prod_8 = no_fma_dot_prod_avx(&vals(0), vals.width());
    std::cout << "        eight wide no fma: " << vec_dot_prod_8 << " residual: " << vec_dot_prod_8 - simple_double << "\n";
#endif

#if defined(__AVX__) && defined(__FMA__)
    float fma_dot_prod_8 = fma_dot_prod_avx(&vals(0), vals.width());
    std::cout << "        eight wide fma: " << fma_dot_prod_8 << " residual: " << fma_dot_prod_8 - simple_double << "\n";
#endif

    if (strictness == FloatStrictness::Strict) {
        const float simple_float_error = std::fabs(simple_double - simple_float);
        // Kahan summation must be at least as accurate as the simple float method...
        check_accuracy(std::fabs(simple_double - kahan) <= simple_float_error,
                       "kahan is less accurate than simple_float");
        // ...and so must the vectorized Kahan variants.
        check_accuracy(std::fabs(simple_double - kahan_vec_4) <= simple_float_error,
                       "kahan_vec_4 is less accurate than simple_float");
        check_accuracy(std::fabs(simple_double - kahan_vec_8) <= simple_float_error,
                       "kahan_vec_8 is less accurate than simple_float");
        // Sanity check that the vectorized simple sums produced something.
        check_accuracy(simple_double_vec_4 != 0 && simple_double_vec_8 != 0 && simple_float_vec_4 != 0 && simple_float_vec_8 != 0,
                       "a vectorized simple sum returned zero");
    }
}

void run_all_conditions(const char *name, Buffer<float> &vals) {
    std::cout << "Running on " << name << " data:\n";

    Target loose{get_jit_target_from_environment().without_feature(Target::StrictFloat)};
    Target strict{loose.with_feature(Target::StrictFloat)};

    run_one_condition(loose, FloatStrictness::Default, vals);
    run_one_condition(strict, FloatStrictness::Default, vals);
    run_one_condition(loose, FloatStrictness::Strict, vals);
    run_one_condition(strict, FloatStrictness::Strict, vals);
}

// Returns a reordered copy of the 1-D buffer `buf`. The input is viewed as `vectorize`
// consecutive blocks of block_size = width / vectorize elements, and output element
// i * vectorize + j is input element j * block_size + i. A reduction whose lane j takes every
// vectorize-th element starting at j therefore walks block j of the original data in order.
//
// Elements past the last full block (width % vectorize of them) are copied through unchanged.
// A non-positive `vectorize` is treated as 1, which is an identity copy. Every output element
// is therefore always written, and there is no division by zero.
Buffer<float> block_transposed_by_n(Buffer<float> &buf, int vectorize) {
    if (vectorize < 1) {
        vectorize = 1;
    }
    const int width = buf.width();
    Buffer<float> result(width);

    const int block_size = width / vectorize;
    for (int32_t i = 0; i < block_size; i++) {
        for (int32_t j = 0; j < vectorize; j++) {
            result(i * vectorize + j) = buf(j * block_size + i);
        }
    }
    // Remainder that does not form a whole block keeps its position.
    for (int32_t k = block_size * vectorize; k < width; k++) {
        result(k) = buf(k);
    }
    return result;
}

int main(int argc, char **argv) {
    std::cout << std::setprecision(10);
    Buffer<float> vals = one_million_rando_floats();
    Buffer<float> transposed;
    in.set(vals);
    // Clean up stmt file by asserting clean division. Also eliminates needing boundary conditions.
    in.dim(0).set_bounds(0, 1000000);

    // Random data, average case for error.
    run_all_conditions("random", vals);
    transposed = block_transposed_by_n(vals, 4);
    in.set(transposed);
    run_all_conditions("random transposed", transposed);

    // Originally the comments stipulated that ascending
    // was best case and descending was worst case, neither
    // of which are strictly true. Main idea is to compare
    // the relative error of two significantly different orders.

    // Ascending.
    std::sort(vals.begin(), vals.end());
    in.set(vals);
    run_all_conditions("sorted ascending", vals);
    transposed = block_transposed_by_n(vals, 4);
    in.set(transposed);
    run_all_conditions("sorted ascending transposed", transposed);

    // Descending.
    std::sort(vals.begin(), vals.end(), std::greater<float>());
    in.set(vals);
    run_all_conditions("sorted descending", vals);
    transposed = block_transposed_by_n(vals, 4);
    in.set(transposed);
    run_all_conditions("sorted descending transposed", transposed);

    printf("Success!\n");

    return 0;
}