File: test_qr.cpp

Package info:
xtensor-blas 0.23.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 13,860 kB
  • sloc: cpp: 98,000; makefile: 201; perl: 178; python: 153
File content (449 lines) | stat: -rw-r--r-- 18,067 bytes
(Source-viewer line-number gutter for lines 1–449 omitted.)
/***************************************************************************
 * Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/
// This file is generated from test/files/cppy_source/test_qr.cppy by preprocess.py!


#include <algorithm>

#include "xtensor/containers/xarray.hpp"
#include "xtensor/containers/xfixed.hpp"
#include "xtensor/containers/xtensor.hpp"
#include "xtensor/core/xnoalias.hpp"
#include "xtensor/views/xstrided_view.hpp"
#include "xtensor/views/xview.hpp"

#include "doctest/doctest.h"
#include "xtensor-blas/xlinalg.hpp"

namespace xt
{
    using namespace xt::placeholders;

    TEST_SUITE("xqr_extended")
    {
        /*py
        a = np.random.random((6, 3))
        res_q1 = np.linalg.qr(a, 'raw')
        res_q2 = np.linalg.qr(a, 'complete')
        res_q3 = np.linalg.qr(a, 'reduced')
        res_q4 = np.linalg.qr(a, 'r')
        */
        // QR decomposition of a tall 6 x 3 matrix in every qrmode, checked against the NumPy
        // results recorded from the snippet above.
        // NOTE(review): this file is generated from test_qr.cppy (see the file header); mirror
        // changes there, otherwise they are lost when the file is regenerated.
        TEST_CASE("qr1")
        {
            // py_a
            xarray<double> py_a = {
                {0.3745401188473625, 0.9507143064099162, 0.7319939418114051},
                {0.5986584841970366, 0.1560186404424365, 0.1559945203362026},
                {0.0580836121681995, 0.8661761457749352, 0.6011150117432088},
                {0.7080725777960455, 0.0205844942958024, 0.9699098521619943},
                {0.8324426408004217, 0.2123391106782762, 0.1818249672071006},
                {0.1834045098534338, 0.3042422429595377, 0.5247564316322378}
            };

            // ---- mode 'raw' ----
            // py_resq1_h = res_q1[0]  (3 x 6: NumPy's h is transposed relative to a)
            xarray<double> py_resq1_h = {
                {-1.3152987216651169,
                 0.3542695728401418,
                 0.0343722790456067,
                 0.4190178144924799,
                 0.4926165861757361,
                 0.1085337284576868},
                {-0.567877094797874,
                 1.2223138676385652,
                 -0.507377563354501,
                 0.3838046167052855,
                 0.3339455785740943,
                 -0.0869071101793681},
                {-1.0163710885529547,
                 0.7215655008695085,
                 0.7854784971183754,
                 -0.8184018010449026,
                 0.3355103841692942,
                 -0.2743559826773575}
            };
            // py_resq1_tau = res_q1[1]
            xarray<double> py_resq1_tau = {1.2847566964660388, 1.3124991842889797, 1.0766465015522177};

            auto res1 = linalg::qr(py_a, linalg::qrmode::raw);
            CHECK(allclose(std::get<0>(res1), py_resq1_h));
            CHECK(allclose(std::get<1>(res1), py_resq1_tau));

            // ---- mode 'complete': Q is 6 x 6, R is 6 x 3 ----
            // py_resq2_q_cmpl = res_q2[0]
            xarray<double> py_resq2_q_cmpl = {
                {-0.2847566964660388,
                 0.6455031901264903,
                 -0.0295327810119745,
                 -0.5849049416686276,
                 -0.0730618203174815,
                 -0.3923203408230155},
                {-0.4551502060605353,
                 -0.0838170448559192,
                 -0.3133472182914375,
                 0.0819245453270296,
                 -0.7892351407115688,
                 0.2408791714587238},
                {-0.0441600156766425,
                 0.6881200538051697,
                 0.0760152664601147,
                 0.7143224973945713,
                 0.0235700722943726,
                 0.0891638112668339},
                {-0.538335943107778,
                 -0.2332659103773061,
                 0.7525061466150681,
                 0.1447692100263398,
                 -0.0279639819291247,
                 -0.2603378924852559},
                {-0.6328924578795164,
                 -0.1203177215897514,
                 -0.4769214096589271,
                 0.1040507467269484,
                 0.5878955555305321,
                 0.0326957112268428},
                {-0.1394394344284399,
                 0.1841243791750922,
                 0.3185019359677401,
                 -0.330353243868553,
                 0.1575155429538277,
                 0.8433664457979998}
            };
            // py_resq2_r_cmpl = res_q2[1]
            xarray<double> py_resq2_r_cmpl = {
                {-1.3152987216651169, -0.567877094797874, -1.0163710885529547},
                {0., 1.2223138676385652, 0.7215655008695085},
                {0., 0., 0.7854784971183754},
                {0., 0., 0.},
                {0., 0., 0.},
                {0., 0., 0.}
            };

            auto res2 = linalg::qr(py_a, linalg::qrmode::complete);
            const auto& q2 = std::get<0>(res2);
            const auto& r2 = std::get<1>(res2);
            CHECK(allclose(q2, py_resq2_q_cmpl));
            CHECK(allclose(r2, py_resq2_r_cmpl));
            // Independent of the recorded reference values: the factors must reproduce the input.
            CHECK(allclose(linalg::dot(q2, r2), py_a));

            // ---- mode 'reduced': Q is 6 x 3, R is 3 x 3 ----
            // py_resq3_q_red = res_q3[0]
            xarray<double> py_resq3_q_red = {
                {-0.2847566964660388, 0.6455031901264903, -0.0295327810119745},
                {-0.4551502060605353, -0.0838170448559192, -0.3133472182914375},
                {-0.0441600156766425, 0.6881200538051697, 0.0760152664601147},
                {-0.538335943107778, -0.2332659103773061, 0.7525061466150681},
                {-0.6328924578795164, -0.1203177215897514, -0.4769214096589271},
                {-0.1394394344284399, 0.1841243791750922, 0.3185019359677401}
            };
            // py_resq3_r_red = res_q3[1]
            xarray<double> py_resq3_r_red = {
                {-1.3152987216651169, -0.567877094797874, -1.0163710885529547},
                {0., 1.2223138676385652, 0.7215655008695085},
                {0., 0., 0.7854784971183754}
            };

            auto res3 = linalg::qr(py_a, linalg::qrmode::reduced);
            const auto& q3 = std::get<0>(res3);
            const auto& r3 = std::get<1>(res3);
            CHECK(allclose(q3, py_resq3_q_red));
            CHECK(allclose(r3, py_resq3_r_red));
            CHECK(allclose(linalg::dot(q3, r3), py_a));

            // ---- mode 'r': only R is compared ----
            // py_resq4_r_r = res_q4; NumPy returned exactly the reduced-mode R for this input.
            const xarray<double>& py_resq4_r_r = py_resq3_r_red;

            auto res4 = linalg::qr(py_a, linalg::qrmode::r);
            CHECK(allclose(std::get<1>(res4), py_resq4_r_r));
        }

        /*py
        a = np.random.random((5, 10))
        res_q1 = np.linalg.qr(a, 'raw')
        res_q2 = np.linalg.qr(a, 'complete')
        res_q3 = np.linalg.qr(a, 'reduced')
        res_q4 = np.linalg.qr(a, 'r')
        */
        // QR decomposition of a wide 5 x 10 matrix in every qrmode, checked against the NumPy
        // results recorded from the snippet above.
        // NOTE(review): this file is generated from test_qr.cppy (see the file header); mirror
        // changes there, otherwise they are lost when the file is regenerated.
        TEST_CASE("qr2")
        {
            // py_a
            xarray<double> py_a = {
                {0.4319450186421158,
                 0.2912291401980419,
                 0.6118528947223795,
                 0.1394938606520418,
                 0.2921446485352182,
                 0.3663618432936917,
                 0.4560699842170359,
                 0.7851759613930136,
                 0.1996737821583597,
                 0.5142344384136116},
                {0.5924145688620425,
                 0.0464504127199977,
                 0.6075448519014384,
                 0.1705241236872915,
                 0.0650515929852795,
                 0.9488855372533332,
                 0.9656320330745594,
                 0.8083973481164611,
                 0.3046137691733707,
                 0.0976721140063839},
                {0.6842330265121569,
                 0.4401524937396013,
                 0.1220382348447788,
                 0.4951769101112702,
                 0.0343885211152184,
                 0.9093204020787821,
                 0.2587799816000169,
                 0.662522284353982,
                 0.311711076089411,
                 0.5200680211778108},
                {0.5467102793432796,
                 0.184854455525527,
                 0.9695846277645586,
                 0.7751328233611146,
                 0.9394989415641891,
                 0.8948273504276488,
                 0.5978999788110851,
                 0.9218742350231168,
                 0.0884925020519195,
                 0.1959828624191452},
                {0.0452272889105381,
                 0.3253303307632643,
                 0.388677289689482,
                 0.2713490317738959,
                 0.8287375091519293,
                 0.3567533266935893,
                 0.2809345096873808,
                 0.5426960831582485,
                 0.1409242249747626,
                 0.8021969807540397}
            };

            // ---- mode 'raw' ----
            // py_resq1_h = res_q1[0]  (10 x 5: NumPy's h is transposed relative to a)
            xarray<double> py_resq1_h = {
                {-1.1430852952870696, 0.3761289948662397, 0.4344253062693247, 0.3471109568548026, 0.0287151863113738},
                {-0.4988738747365855, 0.4145384440977923, -0.1456730968857619, 0.1343802288038164, -0.4549175132696515},
                {-1.0982282164248067, 0.0432498341745755, 0.8009723247566577, -0.2697221220857602, -0.2118640849148782},
                {-0.8189559577243967, 0.2159221672678355, 0.2467828455102149, -0.4358731022610104, 0.0126894274012749},
                {-0.6468222288756241, 0.5399745339753012, 0.9011434603476536, -0.351682869414533, 0.120561296448323},
                {-1.6166030169206462, 0.0627336303098122, 0.1745159258713337, -0.1676233275811678, 0.3369911999240203},
                {-1.1247642047094615, -0.1631138338388989, 0.4469666475320985, 0.2296736319774871, 0.3155802843315489},
                {-1.5746170823854422, 0.2876936477590399, 0.5186696050660639, 0.0972324032495857, 0.1124970816045022},
                {-0.4678059691956431, 0.0924634343088704, -0.0398310260167535, 0.1199094213119632, 0.1189824829973467},
                {-0.6817147826175952, 0.820970464835294, 0.1936105292921997, 0.1556371881989975, 0.1610633542281176}
            };
            // py_resq1_tau = res_q1[1]
            xarray<double> py_resq1_tau =
                {1.3778764545594464, 1.604841948190939, 1.7894907284949315, 1.9996780087119976, 0.};

            auto res1 = linalg::qr(py_a, linalg::qrmode::raw);
            CHECK(allclose(std::get<0>(res1), py_resq1_h));
            CHECK(allclose(std::get<1>(res1), py_resq1_tau));

            // ---- mode 'complete': Q is 5 x 5, R is 5 x 10 ----
            // py_resq2_q_cmpl = res_q2[0]
            xarray<double> py_resq2_q_cmpl = {
                {-0.3778764545594464, 0.2477850983490846, 0.2323946032026168, 0.6442783657634201, -0.571585571841376},
                {-0.5182592859033026, -0.5116427882060656, 0.0755411199296714, 0.3718390470559858, 0.570664728309004},
                {-0.5985844007732788, 0.3414264138444293, -0.6868036045376356, -0.2311050279755065, 0.0039992397751237},
                {-0.4782760145698324, -0.1296501056095487, 0.5617369601749854, -0.625900329033334, -0.2171250093654842},
                {-0.0395659791067296, 0.7371859035261209, 0.3912015899373973, 0.0384753772446493, 0.5481536630508248}
            };
            // py_resq2_r_cmpl = res_q2[1]
            xarray<double> py_resq2_r_cmpl = {
                {-1.1430852952870696,
                 -0.4988738747365855,
                 -1.0982282164248067,
                 -0.8189559577243967,
                 -0.6468222288756241,
                 -1.6166030169206462,
                 -1.1247642047094615,
                 -1.5746170823854422,
                 -0.4678059691956431,
                 -0.6817147826175952},
                {0.,
                 0.4145384440977923,
                 0.0432498341745755,
                 0.2159221672678355,
                 0.5399745339753012,
                 0.0627336303098122,
                 -0.1631138338388989,
                 0.2876936477590399,
                 0.0924634343088704,
                 0.820970464835294},
                {0.,
                 0.,
                 0.8009723247566577,
                 0.2467828455102149,
                 0.9011434603476536,
                 0.1745159258713337,
                 0.4469666475320985,
                 0.5186696050660639,
                 -0.0398310260167535,
                 0.1936105292921997},
                {0.,
                 0.,
                 0.,
                 -0.4358731022610104,
                 -0.351682869414533,
                 -0.1676233275811678,
                 0.2296736319774871,
                 0.0972324032495857,
                 0.1199094213119632,
                 0.1556371881989975},
                {0.,
                 0.,
                 0.,
                 0.,
                 0.120561296448323,
                 0.3369911999240203,
                 0.3155802843315489,
                 0.1124970816045022,
                 0.1189824829973467,
                 0.1610633542281176}
            };

            auto res2 = linalg::qr(py_a, linalg::qrmode::complete);
            const auto& q2 = std::get<0>(res2);
            const auto& r2 = std::get<1>(res2);
            CHECK(allclose(q2, py_resq2_q_cmpl));
            CHECK(allclose(r2, py_resq2_r_cmpl));
            // Independent of the recorded reference values: the factors must reproduce the input.
            CHECK(allclose(linalg::dot(q2, r2), py_a));

            // ---- mode 'reduced' ----
            // For a wide matrix (rows <= columns), k = min(m, n) = m, so the reduced factors
            // coincide with the complete ones. NumPy returned identical values for res_q3, so
            // the complete-mode references are reused instead of duplicating the literals.
            // py_resq3_q_red = res_q3[0]
            const xarray<double>& py_resq3_q_red = py_resq2_q_cmpl;
            // py_resq3_r_red = res_q3[1]
            const xarray<double>& py_resq3_r_red = py_resq2_r_cmpl;

            auto res3 = linalg::qr(py_a, linalg::qrmode::reduced);
            const auto& q3 = std::get<0>(res3);
            const auto& r3 = std::get<1>(res3);
            CHECK(allclose(q3, py_resq3_q_red));
            CHECK(allclose(r3, py_resq3_r_red));
            CHECK(allclose(linalg::dot(q3, r3), py_a));

            // ---- mode 'r': only R is compared ----
            // py_resq4_r_r = res_q4; NumPy returned the same R as the other modes for this input.
            const xarray<double>& py_resq4_r_r = py_resq2_r_cmpl;

            auto res4 = linalg::qr(py_a, linalg::qrmode::r);
            CHECK(allclose(std::get<1>(res4), py_resq4_r_r));
        }
    }
}