File: test_extended_xsort.cpp

package info (click to toggle)
xtensor 0.25.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,476 kB
  • sloc: cpp: 65,302; makefile: 202; python: 171; javascript: 8
file content (847 lines) | stat: -rw-r--r-- 41,509 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
/***************************************************************************
 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/
// This file is generated from test/files/cppy_source/test_extended_xsort.cppy by preprocess.py!
// Warning: This file should not be modified directly! Instead, modify the `*.cppy` file.


#include <xtensor/xarray.hpp>
#include <xtensor/xio.hpp>
#include <xtensor/xmath.hpp>
#include <xtensor/xsort.hpp>
#include <xtensor/xview.hpp>

#include "test_common_macros.hpp"

namespace xt
{
    using namespace xt::placeholders;

    /**
     * Verifies that two 1-D partitions of the same data agree around index ``kth``.
     *
     * A partition only pins down the element at ``kth``; the order of the
     * elements on either side of it is unspecified, so ``a1`` and ``a2`` are not
     * compared element-wise. Instead this helper checks that:
     *   - both containers have the same length and ``kth`` is a valid index,
     *   - both hold the same pivot value at ``kth``,
     *   - in both containers every element before ``kth`` is ``<=`` the pivot and
     *     every element after it is ``>=`` the pivot.
     * The comparisons are deliberately non-strict: partition algorithms such as
     * numpy.partition and std::nth_element only guarantee this weaker ordering,
     * so values equal to the pivot may sit on either side of it when the input
     * contains duplicates.
     *
     * Failures are reported through the EXPECT_* macros.
     *
     * @param a1  reference partition (e.g. numpy output)
     * @param a2  partition under test (e.g. xt::partition output)
     * @param kth pivot index
     * @return true if every check passed, false otherwise
     */
    template <class T, class U = T>
    bool check_partition_equal(const T& a1, const U& a2, std::size_t kth)
    {
        EXPECT_EQ(a1.size(), a2.size());
        EXPECT_TRUE(kth < a1.size());
        if (a1.size() != a2.size() || kth >= a1.size())
        {
            // Indexing below would read out of bounds; the failure is already reported.
            return false;
        }

        const auto pivot = a1[kth];
        EXPECT_EQ(pivot, a2[kth]);
        bool ok = (pivot == a2[kth]);

        // Left side: nothing may exceed the pivot (ties are allowed).
        for (std::size_t i = 0; i < kth; ++i)
        {
            EXPECT_TRUE(a1[i] <= pivot);
            EXPECT_TRUE(a2[i] <= pivot);
            ok = ok && a1[i] <= pivot && a2[i] <= pivot;
        }

        // Right side: nothing may be below the pivot (ties are allowed).
        for (std::size_t i = kth + 1; i < a1.size(); ++i)
        {
            EXPECT_TRUE(pivot <= a1[i]);
            EXPECT_TRUE(pivot <= a2[i]);
            ok = ok && pivot <= a1[i] && pivot <= a2[i];
        }
        return ok;
    }

    /**
     * Verifies that two 1-D argpartition results for ``data`` agree around ``kth``.
     *
     * ``a1`` and ``a2`` hold indices into ``data``. Instead of requiring the two
     * index arrays to be identical, this helper compares the data values they
     * select, consistent with how the multi_partition test compares argpartition
     * results:
     *   - both index arrays have the same length and ``kth`` is a valid position,
     *   - the value selected at ``kth`` is the same for both,
     *   - values selected before ``kth`` are ``<=`` that pivot value and values
     *     selected after it are ``>=`` it, for both index arrays.
     * Comparing values (not indices) with non-strict comparisons keeps the check
     * valid when ``data`` contains duplicates, where equally correct results may
     * report different indices. For data without duplicates it is exactly as
     * strict as comparing the pivot indices.
     *
     * Failures are reported through the EXPECT_* macros.
     *
     * @param data the array that was argpartitioned
     * @param a1   reference index array (e.g. numpy output)
     * @param a2   index array under test (e.g. xt::argpartition output)
     * @param kth  pivot position
     * @return true if every check passed, false otherwise
     */
    template <class X, class Y, class Z>
    bool check_argpartition_equal(const X& data, const Y& a1, const Z& a2, std::size_t kth)
    {
        EXPECT_EQ(a1.size(), a2.size());
        EXPECT_TRUE(kth < a1.size());
        if (a1.size() != a2.size() || kth >= a1.size())
        {
            // Indexing below would read out of bounds; the failure is already reported.
            return false;
        }

        // Value of ``data`` selected by the index stored at position ``i`` of ``indices``.
        const auto value_at = [&data](const auto& indices, std::size_t i)
        {
            return data[static_cast<std::size_t>(indices[i])];
        };

        const auto pivot = value_at(a1, kth);
        EXPECT_EQ(pivot, value_at(a2, kth));
        bool ok = (pivot == value_at(a2, kth));

        // Left side: no selected value may exceed the pivot (ties are allowed).
        for (std::size_t i = 0; i < kth; ++i)
        {
            EXPECT_TRUE(value_at(a1, i) <= pivot);
            EXPECT_TRUE(value_at(a2, i) <= pivot);
            ok = ok && value_at(a1, i) <= pivot && value_at(a2, i) <= pivot;
        }

        // Right side: no selected value may be below the pivot (ties are allowed).
        for (std::size_t i = kth + 1; i < a1.size(); ++i)
        {
            EXPECT_TRUE(pivot <= value_at(a1, i));
            EXPECT_TRUE(pivot <= value_at(a2, i));
            ok = ok && pivot <= value_at(a1, i) && pivot <= value_at(a2, i);
        }
        return ok;
    }

    /*py
    a = np.random.randint(0, 1000, size=(20,))
    */
    // Validates 1-D xt::partition / xt::argpartition against numpy for pivots at the
    // first, two interior and the last position, then checks xt::median on the same
    // data. The py_* arrays are numpy output; the helpers only compare the pivot and
    // the ordering around it, since the rest of a partition's order is unspecified.
    TEST(xtest_extended, partition)
    {
        // py_a
        xarray<long> py_a = {102, 435, 860, 270, 106, 71,  700, 20,  614, 121,
                             466, 214, 330, 458, 87,  372, 99,  871, 663, 130};

        // py_p5 = np.partition(a, 5)
        xarray<long> py_p5 = {20,  71,  87,  99,  102, 106, 121, 700, 614, 435,
                              466, 214, 330, 458, 270, 372, 860, 871, 663, 130};
        // py_p0 = np.partition(a, 0)
        xarray<long> py_p0 = {20,  435, 860, 270, 106, 71,  700, 102, 614, 121,
                              466, 214, 330, 458, 87,  372, 99,  871, 663, 130};
        // py_p13 = np.partition(a, 13)
        xarray<long> py_p13 = {20,  102, 99,  87,  106, 71,  121, 270, 130, 435,
                               372, 214, 330, 458, 614, 466, 860, 871, 663, 700};
        // py_p19 = np.partition(a, 19)
        xarray<long> py_p19 = {20,  102, 99,  87,  106, 71,  121, 270, 130, 435,
                               372, 214, 330, 458, 663, 614, 466, 700, 860, 871};

        // py_a5 = np.argpartition(a, 5)
        xarray<long> py_a5 = {7, 5, 14, 16, 0, 4, 9, 6, 8, 1, 10, 11, 12, 13, 3, 15, 2, 17, 18, 19};
        // py_a0 = np.argpartition(a, 0)
        xarray<long> py_a0 = {7, 1, 2, 3, 4, 5, 6, 0, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19};
        // py_a13 = np.argpartition(a, 13)
        xarray<long> py_a13 = {7, 0, 16, 14, 4, 5, 9, 3, 19, 1, 15, 11, 12, 13, 8, 10, 2, 17, 18, 6};
        // py_a19 = np.argpartition(a, 19)
        xarray<long> py_a19 = {7, 0, 16, 14, 4, 5, 9, 3, 19, 1, 15, 11, 12, 13, 18, 8, 10, 6, 2, 17};

        // Partitions py_a around ``kth`` and compares the result with ``reference``.
        const auto expect_partition = [&py_a](const xarray<long>& reference, std::size_t kth)
        {
            check_partition_equal(reference, xt::partition(py_a, kth), kth);
        };

        // Same for the index-returning variant; indices are resolved through py_a.
        const auto expect_argpartition = [&py_a](const xarray<long>& reference, std::size_t kth)
        {
            check_argpartition_equal(py_a, reference, xt::argpartition(py_a, kth), kth);
        };

        expect_partition(py_p0, 0);
        expect_partition(py_p5, 5);
        expect_partition(py_p13, 13);
        expect_partition(py_p19, 19);

        expect_argpartition(py_a0, 0);
        expect_argpartition(py_a5, 5);
        expect_argpartition(py_a13, 13);
        expect_argpartition(py_a19, 19);

        // py_median = np.median(a)
        // py_a has an even length, so this is the mean of the two middle values (270 and 330).
        double py_median = 300.0;
        using value_type = decltype(py_a)::value_type;
        EXPECT_EQ(static_cast<value_type>(py_median), xt::median(py_a));
    }

    /*py
    a = np.random.randint(0, 20, size=(20,))
    */
    // Multi-pivot partitioning: for each requested pivot position, the value that
    // xtensor places there must match numpy's. The order elsewhere is unspecified,
    // so nothing beyond the pivot positions is compared.
    TEST(xtest_extended, multi_partition)
    {
        // py_a
        xarray<long> py_a = {1, 11, 5, 1, 0, 11, 11, 16, 9, 15, 14, 14, 18, 11, 19, 2, 4, 18, 6, 8};

        // Pivot positions used for the two partitions below (mirrors the literals
        // passed to xt::partition / xt::argpartition).
        const std::size_t kth0[] = {4, 5, 6};
        const std::size_t kth1[] = {2, 7, 12};

        // py_p0 = np.partition(a, (4, 5, 6))
        xarray<long> py_p0 = {1, 1, 0, 2, 4, 5, 6, 8, 9, 11, 14, 14, 18, 11, 19, 16, 11, 18, 11, 15};
        // py_p1 = np.partition(a, (2, 7, 12))
        xarray<long> py_p1 = {0, 1, 1, 2, 4, 5, 6, 8, 9, 11, 11, 11, 11, 15, 19, 16, 14, 18, 18, 14};

        auto xt_p0 = xt::partition(py_a, {4, 5, 6});
        auto xt_p1 = xt::partition(py_a, {2, 7, 12});

        for (std::size_t k : kth0)
        {
            EXPECT_EQ(xt_p0(k), py_p0(k));
        }
        for (std::size_t k : kth1)
        {
            EXPECT_EQ(xt_p1(k), py_p1(k));
        }

        // py_a0 = np.argpartition(a, (4, 5, 6))
        xarray<long> py_a0 = {0, 3, 4, 15, 16, 2, 18, 19, 8, 1, 10, 11, 12, 13, 14, 7, 6, 17, 5, 9};
        // py_a1 = np.argpartition(a, (2, 7, 12))
        xarray<long> py_a1 = {4, 3, 0, 15, 16, 2, 18, 19, 8, 13, 1, 6, 5, 9, 14, 7, 10, 17, 12, 11};

        auto xt_a0 = xt::argpartition(py_a, {4, 5, 6});
        auto xt_a1 = xt::argpartition(py_a, {2, 7, 12});

        // argpartition yields indices. py_a contains repeated values, so equally valid
        // results may point at different indices: compare the values they select.
        for (std::size_t k : kth0)
        {
            EXPECT_EQ(py_a[xt_a0(k)], py_a[static_cast<std::size_t>(py_a0(k))]);
        }
        for (std::size_t k : kth1)
        {
            EXPECT_EQ(py_a[xt_a1(k)], py_a[static_cast<std::size_t>(py_a1(k))]);
        }
    }

    /*py
    a = np.random.rand(4, 5, 6)
    */
    // Checks xt::quantile along each axis of a 3-D array against numpy.quantile for
    // the probabilities {0, .3, .1, 1}. The probabilities are deliberately out of
    // order, and the reference results keep that order (the .3 slice holds larger
    // values than the .1 slice), so the output must follow the input order rather
    // than a sorted one. No quantile method is passed on either side, so this relies
    // on xt::quantile's default method matching numpy's default.
    // NOTE(review): confirm the default in xsort.hpp.
    TEST(xtest_extended, quantile_axis)
    {
        // py_a
        xarray<double> py_a = {
            {{0.0650515929852795, 0.9488855372533332, 0.9656320330745594, 0.8083973481164611, 0.3046137691733707, 0.0976721140063839
             },
             {0.6842330265121569, 0.4401524937396013, 0.1220382348447788, 0.4951769101112702, 0.0343885211152184, 0.9093204020787821
             },
             {0.2587799816000169, 0.662522284353982, 0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527
             },
             {0.9695846277645586, 0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851, 0.9218742350231168
             },
             {0.0884925020519195, 0.1959828624191452, 0.0452272889105381, 0.3253303307632643, 0.388677289689482, 0.2713490317738959
             }},

            {{0.8287375091519293, 0.3567533266935893, 0.2809345096873808, 0.5426960831582485, 0.1409242249747626, 0.8021969807540397
             },
             {0.0745506436797708, 0.9868869366005173, 0.7722447692966574, 0.1987156815341724, 0.0055221171236024, 0.8154614284548342
             },
             {0.7068573438476171, 0.7290071680409873, 0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297
             },
             {0.8631034258755935, 0.6232981268275579, 0.3308980248526492, 0.0635583502860236, 0.3109823217156622, 0.325183322026747
             },
             {0.7296061783380641, 0.6375574713552131, 0.8872127425763265, 0.4722149251619493, 0.1195942459383017, 0.713244787222995
             }},

            {{0.7607850486168974, 0.5612771975694962, 0.770967179954561, 0.4937955963643907, 0.5227328293819941, 0.4275410183585496
             },
             {0.0254191267440952, 0.1078914269933045, 0.0314291856867343, 0.6364104112637804, 0.3143559810763267, 0.5085706911647028
             },
             {0.907566473926093, 0.2492922291488749, 0.4103829230356297, 0.7555511385430487, 0.2287981654916225, 0.076979909828793
             },
             {0.289751452913768, 0.1612212872540044, 0.9296976523425731, 0.808120379564417, 0.6334037565104235, 0.8714605901877177
             },
             {0.8036720768991145, 0.1865700588860358, 0.8925589984899778, 0.5393422419156507, 0.8074401551640625, 0.8960912999234932
             }},

            {{0.3180034749718639, 0.1100519245276768, 0.2279351625419417, 0.4271077886262563, 0.8180147659224931, 0.8607305832563434
             },
             {0.0069521305311907, 0.5107473025775657, 0.417411003148779, 0.2221078104707302, 0.1198653673336828, 0.337615171403628
             },
             {0.9429097039125192, 0.3232029320207552, 0.5187906217433661, 0.7030189588951778, 0.363629602379294, 0.9717820827209607
             },
             {0.9624472949421112, 0.2517822958253642, 0.4972485058923855, 0.3008783098167697, 0.2848404943774676, 0.0368869473545328
             },
             {0.6095643339798968, 0.5026790232288615, 0.0514787512499894, 0.2786464642366114, 0.9082658859666537, 0.2395618906669724
             }}
        };

        // Reference results: one leading slice per probability, followed by the axes of
        // py_a that remain after removing the reduced axis. That gives shapes (4, 5, 6)
        // for axis 0, (4, 4, 6) for axis 1 and (4, 4, 5) for axis 2.
        // py_q0 = np.quantile(a, [0., .3, .1, 1.], axis=0)
        xarray<double> py_q0 = {
            {{0.0650515929852795, 0.1100519245276768, 0.2279351625419417, 0.4271077886262563, 0.1409242249747626, 0.0976721140063839
             },
             {0.0069521305311907, 0.1078914269933045, 0.0314291856867343, 0.1987156815341724, 0.0055221171236024, 0.337615171403628
             },
             {0.2587799816000169, 0.2492922291488749, 0.311711076089411, 0.0740446517340904, 0.2287981654916225, 0.076979909828793
             },
             {0.289751452913768, 0.1612212872540044, 0.3308980248526492, 0.0635583502860236, 0.2848404943774676, 0.0368869473545328
             },
             {0.0884925020519195, 0.1865700588860358, 0.0452272889105381, 0.2786464642366114, 0.1195942459383017, 0.2395618906669724
             }},

            {{0.2927082867732054, 0.332083186476998, 0.2756345749728368, 0.4871268155905773, 0.2882448147535099, 0.394554127923333
             },
             {0.0235724271228047, 0.4069263870649716, 0.1129773299289744, 0.2197685975770745, 0.0315018807160568, 0.4914751391885953
             },
             {0.662049607622857, 0.3158118617335672, 0.4005157383410078, 0.4754656842334387, 0.3454989722390076, 0.111980144555496
             },
             {0.8057682285794109, 0.2427261949682282, 0.4806134577884118, 0.277146313863695, 0.3083681389818427, 0.2963536845595256
             },
             {0.5574571507870991, 0.1950415820658343, 0.0508536050160442, 0.3206619441105991, 0.361768985314364, 0.2681703176632035
             }},

            {{0.1409371575812548, 0.1840623451774505, 0.2438349666855734, 0.4471141309476966, 0.1900310882343451, 0.1966327853120336
             },
             {0.012492229395062, 0.2075697470171935, 0.0586119004341476, 0.2057333202151398, 0.0141820383210872, 0.3889018273319504
             },
             {0.393203190274297, 0.271465440010439, 0.3413126301732766, 0.2078516625672065, 0.2676984344074175, 0.088646654737694
             },
             {0.4617570448023157, 0.1883895898254123, 0.3808031691645701, 0.1347543381452475, 0.292683042578926, 0.1233758597561971
             },
             {0.2448140516303127, 0.1893938999459686, 0.0471027276123735, 0.2926516241946073, 0.2003191590636558, 0.2490980329990495
             }},

            {{0.8287375091519293, 0.9488855372533332, 0.9656320330745594, 0.8083973481164611, 0.8180147659224931, 0.8607305832563434
             },
             {0.6842330265121569, 0.9868869366005173, 0.7722447692966574, 0.6364104112637804, 0.3143559810763267, 0.9093204020787821
             },
             {0.9429097039125192, 0.7290071680409873, 0.7712703466859457, 0.7555511385430487, 0.5467102793432796, 0.9717820827209607
             },
             {0.9695846277645586, 0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.6334037565104235, 0.9218742350231168
             },
             {0.8036720768991145, 0.6375574713552131, 0.8925589984899778, 0.5393422419156507, 0.9082658859666537, 0.8960912999234932
             }}
        };
        // py_q1 = np.quantile(a, [0., .3, .1, 1.], axis=1)
        xarray<double> py_q1 = {
            {{0.0650515929852795, 0.1959828624191452, 0.0452272889105381, 0.3253303307632643, 0.0343885211152184, 0.0976721140063839
             },
             {0.0745506436797708, 0.3567533266935893, 0.2809345096873808, 0.0635583502860236, 0.0055221171236024, 0.1158690595251297
             },
             {0.0254191267440952, 0.1078914269933045, 0.0314291856867343, 0.4937955963643907, 0.2287981654916225, 0.076979909828793
             },
             {0.0069521305311907, 0.1100519245276768, 0.0514787512499894, 0.2221078104707302, 0.1198653673336828, 0.0368869473545328
             }},

            {{0.122549997961539, 0.4846264518624774, 0.1599728030937052, 0.5001551323245783, 0.321426473276593, 0.2021533707752008
             },
             {0.7114071107457065, 0.6261499957330889, 0.4189724892193085, 0.0989788576941068, 0.1238602417455939, 0.4027956150659966
             },
             {0.3839581720543939, 0.1662910415804107, 0.4824997744194159, 0.5587558757852766, 0.3560313507374601, 0.4437469529197803
             },
             {0.3763156467734705, 0.2660664230644424, 0.2658303306633091, 0.2830928333526431, 0.3005983159778329, 0.2591725468143035
             }},

            {{0.0744279566119355, 0.2936507149473276, 0.0759516672842344, 0.3932689625024667, 0.1424786203384793, 0.1325450506140411
             },
             {0.3274733237469094, 0.4633712467471767, 0.3009199157534881, 0.0677528708652503, 0.0511509686494821, 0.1995947645257767
             },
             {0.1311520572119643, 0.1292233710975844, 0.1830106806262924, 0.5120142545848947, 0.2630212917255041, 0.2172043532406956
             },
             {0.13137266830746, 0.1667440730467517, 0.1220613157667703, 0.2447232719770827, 0.1858554181511967, 0.1179569246795086
             }},

            {{0.9695846277645586, 0.9488855372533332, 0.9656320330745594, 0.8948273504276488, 0.5978999788110851, 0.9218742350231168
             },
             {0.8631034258755935, 0.9868869366005173, 0.8872127425763265, 0.5426960831582485, 0.3584657285442726, 0.8154614284548342
             },
             {0.907566473926093, 0.5612771975694962, 0.9296976523425731, 0.808120379564417, 0.8074401551640625, 0.8960912999234932
             },
             {0.9624472949421112, 0.5107473025775657, 0.5187906217433661, 0.7030189588951778, 0.9082658859666537, 0.9717820827209607
             }}
        };
        // py_q2 = np.quantile(a, [0., .3, .1, 1.], axis=2)
        xarray<double> py_q2 = {
            {{0.0650515929852795, 0.0343885211152184, 0.184854455525527, 0.5978999788110851, 0.0452272889105381},
             {0.1409242249747626, 0.0055221171236024, 0.0740446517340904, 0.0635583502860236, 0.1195942459383017},
             {0.4275410183585496, 0.0254191267440952, 0.076979909828793, 0.1612212872540044, 0.1865700588860358},
             {0.1100519245276768, 0.0069521305311907, 0.3232029320207552, 0.0368869473545328, 0.0514787512499894
             }},

            {{0.2011429415898773, 0.2810953642921901, 0.2852455288447139, 0.8349800868943817, 0.1422376822355323},
             {0.318843918190485, 0.1366331626069716, 0.2371673940347012, 0.3180828218712046, 0.5548861982585812},
             {0.5082642128731925, 0.0696603063400194, 0.2390451973202487, 0.4615776047120957, 0.6715071594073826},
             {0.2729693187569028, 0.1709865889022065, 0.44121011206133, 0.2683113951014159, 0.2591041774517919}},

            {{0.0813618534958317, 0.0782133779799986, 0.221817218562772, 0.6865164010860998, 0.0668598954812288},
             {0.2109293673310717, 0.0400363804016866, 0.09495685562961, 0.1872703360008429, 0.2959045855501255},
             {0.4606683073614702, 0.0284241562154147, 0.1528890376602077, 0.2254863700838862, 0.3629561504008433},
             {0.1689935435348092, 0.0634087489324368, 0.3434162672000246, 0.1443346215899485, 0.1455203209584809
             }},

            {{0.9656320330745594, 0.9093204020787821, 0.662522284353982, 0.9695846277645586, 0.388677289689482},
             {0.8287375091519293, 0.9868869366005173, 0.7712703466859457, 0.8631034258755935, 0.8872127425763265},
             {0.770967179954561, 0.6364104112637804, 0.907566473926093, 0.9296976523425731, 0.8960912999234932},
             {0.8607305832563434, 0.5107473025775657, 0.9717820827209607, 0.9624472949421112, 0.9082658859666537}}
        };

        // Interpolated quantiles need not match numpy bit-for-bit, hence allclose.
        EXPECT_TRUE(xt::allclose(py_q0, xt::quantile(py_a, {0., .3, .1, 1.}, 0)));
        EXPECT_TRUE(xt::allclose(py_q1, xt::quantile(py_a, {0., .3, .1, 1.}, 1)));
        EXPECT_TRUE(xt::allclose(py_q2, xt::quantile(py_a, {0., .3, .1, 1.}, 2)));
    }

    /*py
    a = np.random.rand(4, 5, 6)
    */
    // Checks each named xt::quantile_method against numpy.quantile's method of the
    // same name. No axis is passed on either side, so each reference is one value per
    // probability in {0, .3, .1, 1}. The reference values show that all methods agree
    // at p = 0 and p = 1; only the interior probabilities (.3 and .1) distinguish the
    // methods.
    TEST(xtest_extended, quantile_methods)
    {
        // py_a
        xarray<double> py_a = {
            {{0.1448948720912231, 0.489452760277563, 0.9856504541106007, 0.2420552715115004, 0.6721355474058786, 0.7616196153287176
             },
             {0.2376375439923997, 0.7282163486118596, 0.3677831327192532, 0.6323058305935795, 0.6335297107608947, 0.5357746840747585
             },
             {0.0902897700544083, 0.835302495589238, 0.3207800649717358, 0.1865185103998542, 0.0407751415547639, 0.5908929431882418
             },
             {0.6775643618422824, 0.0165878289278562, 0.512093058299281, 0.226495775197938, 0.6451727904094499, 0.1743664290049914
             },
             {0.690937738102466, 0.3867353463005374, 0.9367299887367345, 0.1375209441459933, 0.3410663510502585, 0.1134735212405891
             }},

            {{0.9246936182785628, 0.877339353380981, 0.2579416277151556, 0.659984046034179, 0.8172222002012158, 0.5552008115994623
             },
             {0.5296505783560065, 0.2418522909004517, 0.0931027678058992, 0.8972157579533268, 0.9004180571633305, 0.6331014572732679
             },
             {0.3390297910487007, 0.3492095746126609, 0.7259556788702394, 0.8971102599525771, 0.8870864242651173, 0.7798755458576239
             },
             {0.6420316461542878, 0.0841399649950488, 0.1616287140946138, 0.8985541885270792, 0.6064290596595899, 0.0091970516166296
             },
             {0.1014715428660321, 0.6635017691080558, 0.0050615838462187, 0.1608080514174987, 0.5487337893665861, 0.6918951976926933
             }},

            {{0.6519612595026005, 0.2242693094605598, 0.7121792213475359, 0.2372490874968001, 0.3253996981592677, 0.7464914051180241
             },
             {0.6496328990472147, 0.8492234104941779, 0.6576128923003434, 0.5683086033354716, 0.0936747678280925, 0.3677158030594335
             },
             {0.2652023676817254, 0.2439896433790836, 0.9730105547524456, 0.3930977246667604, 0.8920465551771133, 0.6311386259972629
             },
             {0.7948113035416484, 0.5026370931051921, 0.5769038846263591, 0.4925176938188639, 0.1952429877980445, 0.7224521152615053
             },
             {0.2807723624408558, 0.0243159664314538, 0.6454722959071678, 0.1771106794070489, 0.9404585843529143, 0.9539285770025874
             }},

            {{0.9148643902204485, 0.3701587002554444, 0.0154566165288674, 0.9283185625877254, 0.4281841483173143, 0.9666548190436696
             },
             {0.9636199770892528, 0.8530094554673601, 0.2944488920695857, 0.3850977286019253, 0.8511366715168569, 0.3169220051562777
             },
             {0.1694927466860925, 0.5568012624583502, 0.936154774160781, 0.696029796674973, 0.570061170089365, 0.0971764937707685
             },
             {0.6150072266991697, 0.9900538501042633, 0.140084015236524, 0.5183296523637367, 0.8773730719279554, 0.7407686177542044
             },
             {0.697015740995268, 0.7024840839871093, 0.3594911512197552, 0.2935918442644934, 0.8093611554785136, 0.8101133946791808
             }}
        };

        // py_q4 = np.quantile(a, [0., .3, .1, 1.], method="interpolated_inverted_cdf")
        xarray<double> py_q4 = {0.0050615838462187, 0.3169220051562777, 0.1014715428660321, 0.9900538501042633};
        // py_q5 = np.quantile(a, [0., .3, .1, 1.], method="hazen")
        xarray<double> py_q5 = {0.0050615838462187, 0.3188510350640067, 0.1074725320533106, 0.9900538501042633};
        // py_q6 = np.quantile(a, [0., .3, .1, 1.], method="weibull")
        xarray<double> py_q6 = {0.0050615838462187, 0.3180794231009151, 0.1026717407034878, 0.9900538501042633};
        // py_q7 = np.quantile(a, [0., .3, .1, 1.], method="linear")
        xarray<double> py_q7 = {0.0050615838462187, 0.3196226470270984, 0.1122733234031334, 0.9900538501042633};
        // py_q8 = np.quantile(a, [0., .3, .1, 1.], method="median_unbiased")
        xarray<double> py_q8 = {0.0050615838462187, 0.3185938310763095, 0.1058722682700363, 0.9900538501042633};
        // py_q9 = np.quantile(a, [0., .3, .1, 1.], method="normal_unbiased")
        xarray<double> py_q9 = {0.0050615838462187, 0.3186581320732338, 0.1062723342158549, 0.9900538501042633};

        // One comparison per method; allclose tolerates rounding differences in the
        // interpolation arithmetic between numpy and xtensor.
        EXPECT_TRUE(
            xt::allclose(py_q4, xt::quantile(py_a, {0., .3, .1, 1.}, quantile_method::interpolated_inverted_cdf))
        );
        EXPECT_TRUE(xt::allclose(py_q5, xt::quantile(py_a, {0., .3, .1, 1.}, quantile_method::hazen)));
        EXPECT_TRUE(xt::allclose(py_q6, xt::quantile(py_a, {0., .3, .1, 1.}, quantile_method::weibull)));
        EXPECT_TRUE(xt::allclose(py_q7, xt::quantile(py_a, {0., .3, .1, 1.}, quantile_method::linear)));
        EXPECT_TRUE(xt::allclose(py_q8, xt::quantile(py_a, {0., .3, .1, 1.}, quantile_method::median_unbiased)));
        EXPECT_TRUE(xt::allclose(py_q9, xt::quantile(py_a, {0., .3, .1, 1.}, quantile_method::normal_unbiased)));
    }

    /*py
    a = np.random.rand(5, 5, 5)
    */
    // Checks xt::median over the whole (5, 5, 5) array and along each axis against
    // numpy.median. Exact EXPECT_EQ on doubles is intentional here. Every median is
    // taken over an odd number of values (125 overall, 5 along each axis), so the
    // expected result is one of the input elements rather than an interpolated
    // average of two.
    TEST(xtest_extended, axis_median)
    {
        // py_a
        xarray<double> py_a = {
            {{0.8670723185801037, 0.9132405525564713, 0.5113423988609378, 0.5015162946871996, 0.7982951789667752},
             {0.6499639307777652, 0.7019668772577033, 0.795792669436101, 0.8900053418175663, 0.3379951568515358},
             {0.375582952639944, 0.093981939840869, 0.578280140996174, 0.0359422737967421, 0.4655980181324602},
             {0.5426446347075766, 0.2865412521282844, 0.5908332605690108, 0.0305002499390494, 0.0373481887492144},
             {0.8226005606596583, 0.3601906414112629, 0.1270605126518848, 0.5222432600548044, 0.7699935530986108
             }},

            {{0.2158210274968432, 0.6228904758190003, 0.085347464993768, 0.0516817211686077, 0.531354631568148},
             {0.5406351216101065, 0.6374299014982066, 0.7260913337226615, 0.9758520794625346, 0.5163003483011953},
             {0.322956472941246, 0.7951861947687037, 0.2708322512620742, 0.4389714207056361, 0.078456381342266},
             {0.0253507434154575, 0.9626484146779251, 0.8359801205122058, 0.695974206093698, 0.4089529444142699},
             {0.1732943200708458, 0.156437042671086, 0.2502428981645953, 0.5492266647061205, 0.7145959227000623}},

            {{0.6601973767177313, 0.2799338969459428, 0.9548652806631941, 0.7378969166957685, 0.5543540525114007},
             {0.6117207462343522, 0.4196000624277899, 0.2477309895011575, 0.3559726786512616, 0.7578461104643691},
             {0.0143934886297559, 0.1160726405069162, 0.0460026420217527, 0.0407288023189701, 0.8554605840110072},
             {0.7036578593800237, 0.4741738290873252, 0.0978341606510015, 0.4916158751168324, 0.4734717707805657},
             {0.1732018699100152, 0.433851649237973, 0.3985047343973734, 0.6158500980522165, 0.6350936508676438}},

            {{0.0453040097720445, 0.3746126146264712, 0.6258599157142364, 0.5031362585800877, 0.8564898411883223},
             {0.658693631618945, 0.1629344270814297, 0.0705687474004298, 0.6424192782063156, 0.0265113105416218},
             {0.5857755812734633, 0.9402302414249576, 0.575474177875879, 0.3881699262065219, 0.6432882184423532},
             {0.4582528904915166, 0.5456167893159349, 0.9414648087765252, 0.3861026378007743, 0.9611905638239142},
             {0.9053506419560637, 0.1957911347892964, 0.0693613008751655, 0.1007780013774267, 0.0182218256515497
             }},

            {{0.0944429607559284, 0.6830067734163568, 0.071188648460229, 0.3189756302937613, 0.8448753109694546},
             {0.0232719357358259, 0.8144684825889358, 0.2818547747733999, 0.1181648276216563, 0.6967371653641506},
             {0.628942846779884, 0.877472013527053, 0.7350710438038858, 0.8034809303848486, 0.2820345725713065},
             {0.1774395437797228, 0.7506147516408583, 0.806834739267264, 0.9905051420006733, 0.4126176769114265},
             {0.3720180857927832, 0.7764129607419968, 0.3408035402530178, 0.9307573256035647, 0.8584127518430118}}
        };
        // py_m = np.median(a)
        double py_m = 0.5113423988609378;

        // Each per-axis reference is (5, 5): the reduced axis is removed from py_a's shape.
        // py_m0 = np.median(a, 0)
        xarray<double> py_m0 = {
            {0.2158210274968432, 0.6228904758190003, 0.5113423988609378, 0.5015162946871996, 0.7982951789667752},
            {0.6117207462343522, 0.6374299014982066, 0.2818547747733999, 0.6424192782063156, 0.5163003483011953},
            {0.375582952639944, 0.7951861947687037, 0.575474177875879, 0.3881699262065219, 0.4655980181324602},
            {0.4582528904915166, 0.5456167893159349, 0.806834739267264, 0.4916158751168324, 0.4126176769114265},
            {0.3720180857927832, 0.3601906414112629, 0.2502428981645953, 0.5492266647061205, 0.7145959227000623}
        };
        // py_m1 = np.median(a, 1)
        xarray<double> py_m1 = {
            {0.6499639307777652, 0.3601906414112629, 0.578280140996174, 0.5015162946871996, 0.4655980181324602},
            {0.2158210274968432, 0.6374299014982066, 0.2708322512620742, 0.5492266647061205, 0.5163003483011953},
            {0.6117207462343522, 0.4196000624277899, 0.2477309895011575, 0.4916158751168324, 0.6350936508676438},
            {0.5857755812734633, 0.3746126146264712, 0.575474177875879, 0.3881699262065219, 0.6432882184423532},
            {0.1774395437797228, 0.7764129607419968, 0.3408035402530178, 0.8034809303848486, 0.6967371653641506}
        };
        // py_m2 = np.median(a, 2)
        xarray<double> py_m2 = {
            {0.7982951789667752, 0.7019668772577033, 0.375582952639944, 0.2865412521282844, 0.5222432600548044},
            {0.2158210274968432, 0.6374299014982066, 0.322956472941246, 0.695974206093698, 0.2502428981645953},
            {0.6601973767177313, 0.4196000624277899, 0.0460026420217527, 0.4741738290873252, 0.433851649237973},
            {0.5031362585800877, 0.1629344270814297, 0.5857755812734633, 0.5456167893159349, 0.1007780013774267},
            {0.3189756302937613, 0.2818547747733999, 0.7350710438038858, 0.7506147516408583, 0.7764129607419968}
        };

        // Exact comparisons: see the note above the TEST on why no tolerance is needed.
        EXPECT_EQ(py_m, xt::median(py_a));
        EXPECT_EQ(py_m0, xt::median(py_a, 0));
        EXPECT_EQ(py_m1, xt::median(py_a, 1));
        EXPECT_EQ(py_m2, xt::median(py_a, 2));
    }

    /*py
    a = np.random.permutation(np.arange(5 * 5 * 5)).reshape(5, 5, 5)
    */
    // Single-kth partition / argpartition along every axis of a 5x5x5 array,
    // checked against NumPy references (generator comments above each array).
    //
    // np.partition only pins down the value at the kth position plus the
    // "<= before / >= after" split; the ordering inside each side is
    // implementation-defined. Therefore only the kth slices are compared with
    // the reference arrays, and the split itself is checked as an invariant
    // on xtensor's own result. The input is a permutation (all values
    // distinct), which also makes the argpartition index at kth unique.
    TEST(xtest_extended, axis_partition)
    {
        // py_a
        xarray<long> py_a = {
            {{110, 67, 43, 114, 86},
             {117, 31, 40, 46, 62},
             {10, 78, 33, 103, 14},
             {23, 101, 66, 91, 89},
             {20, 123, 32, 50, 106}},

            {{69, 108, 96, 64, 65},
             {59, 55, 76, 19, 119},
             {92, 2, 42, 25, 9},
             {63, 79, 115, 30, 5},
             {35, 3, 53, 90, 105}},

            {{71, 21, 0, 44, 47},
             {7, 102, 37, 36, 28},
             {97, 1, 72, 26, 49},
             {73, 81, 39, 109, 45},
             {6, 116, 80, 100, 17}},

            {{74, 34, 4, 13, 113},
             {57, 41, 87, 38, 56},
             {93, 121, 52, 84, 95},
             {24, 118, 68, 15, 82},
             {51, 94, 77, 27, 70}},

            {{8, 75, 107, 60, 11},
             {99, 48, 18, 58, 122},
             {85, 120, 111, 83, 61},
             {124, 16, 29, 104, 98},
             {88, 22, 12, 54, 112}}
        };

        // py_p0 = np.partition(a, 2, 0)
        xarray<long> py_p0 = {
            {{8, 21, 0, 13, 11}, {7, 31, 18, 19, 28}, {10, 1, 33, 25, 9}, {23, 16, 29, 15, 5}, {6, 3, 12, 27, 17}
            },

            {{69, 34, 4, 44, 47},
             {57, 41, 37, 36, 56},
             {85, 2, 42, 26, 14},
             {24, 79, 39, 30, 45},
             {20, 22, 32, 50, 70}},

            {{71, 67, 43, 60, 65},
             {59, 48, 40, 38, 62},
             {92, 78, 52, 83, 49},
             {63, 81, 66, 91, 82},
             {35, 94, 53, 54, 105}},

            {{74, 108, 96, 114, 113},
             {117, 55, 87, 46, 119},
             {93, 121, 72, 84, 95},
             {73, 118, 68, 109, 89},
             {51, 116, 77, 90, 106}},

            {{110, 75, 107, 64, 86},
             {99, 102, 76, 58, 122},
             {97, 120, 111, 103, 61},
             {124, 101, 115, 104, 98},
             {88, 123, 80, 100, 112}}
        };
        // py_p1 = np.partition(a, 4, 1)
        xarray<long> py_p1 = {
            {{10, 31, 32, 91, 62},
             {20, 67, 33, 50, 14},
             {23, 78, 40, 46, 86},
             {110, 101, 43, 103, 89},
             {117, 123, 66, 114, 106}},

            {{63, 2, 42, 30, 5},
             {35, 3, 53, 25, 9},
             {59, 55, 76, 19, 65},
             {69, 79, 96, 64, 105},
             {92, 108, 115, 90, 119}},

            {{7, 1, 39, 36, 45},
             {6, 21, 0, 26, 17},
             {71, 81, 37, 44, 28},
             {73, 102, 72, 100, 47},
             {97, 116, 80, 109, 49}},

            {{24, 41, 4, 15, 82},
             {51, 34, 52, 13, 70},
             {57, 94, 68, 27, 56},
             {74, 118, 77, 38, 95},
             {93, 121, 87, 84, 113}},

            {{8, 16, 29, 58, 11},
             {85, 22, 12, 54, 61},
             {88, 48, 18, 60, 98},
             {99, 75, 107, 83, 112},
             {124, 120, 111, 104, 122}}
        };
        // py_p2 = np.partition(a, 3, 2)
        xarray<long> py_p2 = {
            {{67, 43, 86, 110, 114},
             {46, 40, 31, 62, 117},
             {10, 14, 33, 78, 103},
             {23, 66, 89, 91, 101},
             {20, 32, 50, 106, 123}},

            {{64, 65, 69, 96, 108},
             {19, 59, 55, 76, 119},
             {25, 9, 2, 42, 92},
             {30, 5, 63, 79, 115},
             {3, 35, 53, 90, 105}},

            {{44, 0, 21, 47, 71},
             {7, 28, 36, 37, 102},
             {26, 49, 1, 72, 97},
             {39, 45, 73, 81, 109},
             {6, 17, 80, 100, 116}},

            {{13, 4, 34, 74, 113},
             {38, 56, 41, 57, 87},
             {84, 52, 93, 95, 121},
             {15, 24, 68, 82, 118},
             {27, 51, 70, 77, 94}},

            {{8, 11, 60, 75, 107},
             {58, 18, 48, 99, 122},
             {83, 61, 85, 111, 120},
             {16, 29, 98, 104, 124},
             {54, 12, 22, 88, 112}}
        };

        auto p0 = xt::partition(py_a, 2, 0);
        auto p1 = xt::partition(py_a, 4, 1);
        auto p2 = xt::partition(py_a, 3, 2);

        // Only the kth slice has a unique expected value; the rest of the
        // reference arrays reflects NumPy's internal ordering.
        EXPECT_EQ(xt::view(py_p0, 2, all(), all()), xt::view(p0, 2, all(), all()));
        EXPECT_EQ(xt::view(py_p1, all(), 4, all()), xt::view(p1, all(), 4, all()));
        EXPECT_EQ(xt::view(py_p2, all(), all(), 3), xt::view(p2, all(), all(), 3));

        // Partition invariant: along the partitioned axis, everything before
        // kth is <= the kth value and everything after it is >=. The kth slice
        // is taken with a length-1 range (not an integer index) so it keeps
        // the partitioned axis and broadcasts against the multi-element side.
        const auto p0_kth = xt::view(p0, xt::range(2, 3), all(), all());
        EXPECT_TRUE(xt::all(xt::view(p0, xt::range(0, 2), all(), all()) <= p0_kth));
        EXPECT_TRUE(xt::all(xt::view(p0, xt::range(3, 5), all(), all()) >= p0_kth));

        // kth = 4 is the last position along axis 1, so there is no "after" side.
        const auto p1_kth = xt::view(p1, all(), xt::range(4, 5), all());
        EXPECT_TRUE(xt::all(xt::view(p1, all(), xt::range(0, 4), all()) <= p1_kth));

        const auto p2_kth = xt::view(p2, all(), all(), xt::range(3, 4));
        EXPECT_TRUE(xt::all(xt::view(p2, all(), all(), xt::range(0, 3)) <= p2_kth));
        EXPECT_TRUE(xt::all(xt::view(p2, all(), all(), xt::range(4, 5)) >= p2_kth));

        // py_a0 = np.argpartition(a, 2, 0)
        xarray<long> py_a0 = {
            {{4, 2, 2, 3, 4}, {2, 0, 4, 1, 2}, {0, 2, 0, 1, 1}, {0, 4, 4, 3, 1}, {2, 1, 4, 3, 2}},

            {{1, 3, 3, 2, 2}, {3, 3, 2, 2, 3}, {4, 1, 1, 2, 0}, {3, 1, 2, 1, 2}, {0, 4, 0, 0, 3}},

            {{2, 0, 0, 4, 1}, {1, 4, 0, 3, 0}, {1, 0, 3, 4, 2}, {1, 2, 0, 0, 3}, {1, 3, 1, 4, 1}},

            {{3, 1, 1, 0, 3}, {0, 1, 3, 0, 1}, {3, 3, 2, 3, 3}, {2, 3, 3, 2, 0}, {3, 2, 3, 1, 0}},

            {{0, 4, 4, 1, 0}, {4, 2, 1, 4, 4}, {2, 4, 4, 0, 4}, {4, 0, 1, 4, 4}, {4, 0, 2, 2, 4}}
        };
        // py_a1 = np.argpartition(a, 4, 1)
        xarray<long> py_a1 = {
            {{2, 1, 4, 3, 1}, {4, 0, 2, 4, 2}, {3, 2, 1, 1, 0}, {0, 3, 0, 2, 3}, {1, 4, 3, 0, 4}},

            {{3, 2, 2, 3, 3}, {4, 4, 4, 2, 2}, {1, 1, 1, 1, 0}, {0, 3, 0, 0, 4}, {2, 0, 3, 4, 1}},

            {{1, 2, 3, 1, 3}, {4, 0, 0, 2, 4}, {0, 3, 1, 0, 1}, {3, 1, 2, 4, 0}, {2, 4, 4, 3, 2}},

            {{3, 1, 0, 3, 3}, {4, 0, 2, 0, 4}, {1, 4, 3, 4, 1}, {0, 3, 4, 1, 2}, {2, 2, 1, 2, 0}},

            {{0, 3, 3, 1, 0}, {2, 4, 4, 4, 2}, {4, 1, 1, 0, 3}, {1, 0, 0, 2, 4}, {3, 2, 2, 3, 1}}
        };
        // py_a2 = np.argpartition(a, 3, 2)
        xarray<long> py_a2 = {
            {{1, 2, 4, 0, 3}, {3, 2, 1, 4, 0}, {0, 4, 2, 1, 3}, {0, 2, 4, 3, 1}, {0, 2, 3, 4, 1}},

            {{3, 4, 0, 2, 1}, {3, 0, 1, 2, 4}, {3, 4, 1, 2, 0}, {3, 4, 0, 1, 2}, {1, 0, 2, 3, 4}},

            {{3, 2, 1, 4, 0}, {0, 4, 3, 2, 1}, {3, 4, 1, 2, 0}, {2, 4, 0, 1, 3}, {0, 4, 2, 3, 1}},

            {{3, 2, 1, 0, 4}, {3, 4, 1, 0, 2}, {3, 2, 0, 4, 1}, {3, 0, 2, 4, 1}, {3, 0, 4, 2, 1}},

            {{0, 4, 3, 1, 2}, {3, 2, 1, 0, 4}, {3, 4, 0, 2, 1}, {1, 2, 4, 3, 0}, {3, 2, 1, 0, 4}}
        };

        auto a0 = xt::argpartition(py_a, 2, 0);
        auto a1 = xt::argpartition(py_a, 4, 1);
        auto a2 = xt::argpartition(py_a, 3, 2);

        // The references are stored as long; cast them so the element types
        // match the std::size_t indices the comparison is made against.
        EXPECT_EQ(xt::cast<std::size_t>(xt::view(py_a0, 2, all(), all())), xt::view(a0, 2, all(), all()));
        EXPECT_EQ(xt::cast<std::size_t>(xt::view(py_a1, all(), 4, all())), xt::view(a1, all(), 4, all()));
        EXPECT_EQ(xt::cast<std::size_t>(xt::view(py_a2, all(), all(), 3)), xt::view(a2, all(), all(), 3));
    }

    /*py
    a = np.random.permutation(np.arange(5 * 5 * 5)).reshape(5, 5, 5)
    */
    // Multi-kth partition / argpartition along every axis of a 5x5x5 array,
    // checked against NumPy references (generator comments above each array).
    //
    // With several kth values, each kth position holds its sorted-order value
    // and the elements between consecutive kth positions lie between those
    // values; the order inside each segment is implementation-defined. The
    // kth slices are compared with the references, and the segment ordering
    // is checked as an invariant on xtensor's own result. The input is a
    // permutation (all values distinct), so argpartition indices at kth are
    // unique as well.
    TEST(xtest_extended, multi_k_axis_partition)
    {
        // py_a
        xarray<long> py_a = {
            {{87, 38, 85, 104, 100},
             {69, 50, 60, 108, 42},
             {19, 113, 66, 122, 54},
             {81, 31, 109, 111, 78},
             {43, 93, 6, 105, 1}},

            {{98, 110, 97, 3, 77},
             {61, 44, 118, 8, 123},
             {52, 96, 18, 39, 112},
             {41, 36, 22, 119, 37},
             {51, 121, 107, 88, 94}},

            {{79, 47, 20, 120, 80},
             {92, 33, 70, 82, 67},
             {90, 58, 21, 84, 99},
             {25, 10, 124, 17, 64},
             {114, 4, 29, 55, 9}},

            {{65, 24, 46, 68, 5},
             {34, 45, 12, 28, 76},
             {83, 32, 72, 16, 62},
             {26, 63, 40, 106, 103},
             {49, 59, 57, 102, 89}},

            {{35, 91, 7, 27, 23},
             {75, 101, 71, 115, 95},
             {73, 11, 74, 56, 86},
             {15, 2, 117, 53, 116},
             {0, 30, 14, 13, 48}}
        };

        // py_p0 = np.partition(a, (1, 2), 0)
        xarray<long> py_p0 = {
            {{35, 24, 7, 3, 5}, {34, 33, 12, 8, 42}, {19, 11, 18, 16, 54}, {15, 2, 22, 17, 37}, {0, 4, 6, 13, 1}},

            {{65, 38, 20, 27, 23},
             {61, 44, 60, 28, 67},
             {52, 32, 21, 39, 62},
             {25, 10, 40, 53, 64},
             {43, 30, 14, 55, 9}},

            {{79, 47, 46, 68, 77},
             {69, 45, 70, 82, 76},
             {73, 58, 66, 56, 86},
             {26, 31, 109, 106, 78},
             {49, 59, 29, 88, 48}},

            {{98, 110, 97, 120, 100},
             {92, 50, 118, 108, 123},
             {83, 96, 72, 122, 112},
             {41, 63, 124, 111, 103},
             {114, 93, 57, 102, 89}},

            {{87, 91, 85, 104, 80},
             {75, 101, 71, 115, 95},
             {90, 113, 74, 84, 99},
             {81, 36, 117, 119, 116},
             {51, 121, 107, 105, 94}}
        };
        // py_p1 = np.partition(a, (1, 4), 1)
        xarray<long> py_p1 = {
            {{19, 31, 6, 104, 1},
             {43, 38, 60, 105, 42},
             {69, 50, 66, 108, 54},
             {81, 93, 85, 111, 78},
             {87, 113, 109, 122, 100}},

            {{41, 36, 18, 3, 37},
             {51, 44, 22, 8, 77},
             {52, 96, 97, 39, 94},
             {61, 110, 107, 88, 112},
             {98, 121, 118, 119, 123}},

            {{25, 4, 20, 17, 9},
             {79, 10, 21, 55, 64},
             {90, 33, 29, 82, 67},
             {92, 47, 70, 84, 80},
             {114, 58, 124, 120, 99}},

            {{26, 24, 12, 16, 5},
             {34, 32, 40, 28, 62},
             {49, 45, 46, 68, 76},
             {65, 59, 57, 102, 89},
             {83, 63, 72, 106, 103}},

            {{0, 2, 7, 13, 23},
             {15, 11, 14, 27, 48},
             {35, 30, 71, 53, 86},
             {73, 91, 74, 56, 95},
             {75, 101, 117, 115, 116}}
        };
        // py_p2 = np.partition(a, (1, 3), 2)
        xarray<long> py_p2 = {
            {{38, 85, 87, 100, 104},
             {42, 50, 60, 69, 108},
             {19, 54, 66, 113, 122},
             {31, 78, 81, 109, 111},
             {1, 6, 43, 93, 105}},

            {{3, 77, 97, 98, 110},
             {8, 44, 61, 118, 123},
             {18, 39, 52, 96, 112},
             {22, 36, 37, 41, 119},
             {51, 88, 94, 107, 121}},

            {{20, 47, 79, 80, 120},
             {33, 67, 70, 82, 92},
             {21, 58, 84, 90, 99},
             {10, 17, 25, 64, 124},
             {4, 9, 29, 55, 114}},

            {{5, 24, 46, 65, 68},
             {12, 28, 34, 45, 76},
             {16, 32, 62, 72, 83},
             {26, 40, 63, 103, 106},
             {49, 57, 59, 89, 102}},

            {{7, 23, 27, 35, 91},
             {71, 75, 95, 101, 115},
             {11, 56, 73, 74, 86},
             {2, 15, 53, 116, 117},
             {0, 13, 14, 30, 48}}
        };

        auto p0 = xt::partition(py_a, {1, 2}, 0);
        auto p1 = xt::partition(py_a, {1, 4}, 1);
        auto p2 = xt::partition(py_a, {1, 3}, 2);

        // Only the kth slices have unique expected values; compare each kth
        // position of every partition with the NumPy reference.
        EXPECT_EQ(xt::view(py_p0, 2, all(), all()), xt::view(p0, 2, all(), all()));
        EXPECT_EQ(xt::view(py_p1, all(), 4, all()), xt::view(p1, all(), 4, all()));
        EXPECT_EQ(xt::view(py_p2, all(), all(), 3), xt::view(p2, all(), all(), 3));

        EXPECT_EQ(xt::view(py_p0, 1, all(), all()), xt::view(p0, 1, all(), all()));
        EXPECT_EQ(xt::view(py_p1, all(), 1, all()), xt::view(p1, all(), 1, all()));
        EXPECT_EQ(xt::view(py_p2, all(), all(), 1), xt::view(p2, all(), all(), 1));

        // Segment-ordering invariants. Kth slices are taken with length-1
        // ranges (not integer indices) so they keep the partitioned axis and
        // broadcast against the multi-element segments.

        // kth = {1, 2} along axis 0: [0, 1) <= p[1] <= p[2] <= [3, 5).
        const auto p0_k1 = xt::view(p0, xt::range(1, 2), all(), all());
        const auto p0_k2 = xt::view(p0, xt::range(2, 3), all(), all());
        EXPECT_TRUE(xt::all(xt::view(p0, xt::range(0, 1), all(), all()) <= p0_k1));
        EXPECT_TRUE(xt::all(p0_k1 <= p0_k2));
        EXPECT_TRUE(xt::all(xt::view(p0, xt::range(3, 5), all(), all()) >= p0_k2));

        // kth = {1, 4} along axis 1: [0, 1) <= p[1] <= [2, 4) <= p[4]
        // (4 is the last position, so there is no trailing segment).
        const auto p1_k1 = xt::view(p1, all(), xt::range(1, 2), all());
        const auto p1_k4 = xt::view(p1, all(), xt::range(4, 5), all());
        const auto p1_mid = xt::view(p1, all(), xt::range(2, 4), all());
        EXPECT_TRUE(xt::all(xt::view(p1, all(), xt::range(0, 1), all()) <= p1_k1));
        EXPECT_TRUE(xt::all(p1_mid >= p1_k1));
        EXPECT_TRUE(xt::all(p1_mid <= p1_k4));

        // kth = {1, 3} along axis 2: [0, 1) <= p[1] <= [2, 3) <= p[3] <= [4, 5).
        const auto p2_k1 = xt::view(p2, all(), all(), xt::range(1, 2));
        const auto p2_k3 = xt::view(p2, all(), all(), xt::range(3, 4));
        const auto p2_mid = xt::view(p2, all(), all(), xt::range(2, 3));
        EXPECT_TRUE(xt::all(xt::view(p2, all(), all(), xt::range(0, 1)) <= p2_k1));
        EXPECT_TRUE(xt::all(p2_mid >= p2_k1));
        EXPECT_TRUE(xt::all(p2_mid <= p2_k3));
        EXPECT_TRUE(xt::all(xt::view(p2, all(), all(), xt::range(4, 5)) >= p2_k3));

        // py_a0 = np.argpartition(a, (1, 2), 0)
        xarray<long> py_a0 = {
            {{4, 3, 4, 1, 3}, {3, 2, 3, 1, 0}, {0, 4, 1, 3, 0}, {4, 4, 1, 2, 1}, {4, 2, 0, 4, 0}},

            {{3, 0, 2, 4, 4}, {1, 1, 0, 3, 2}, {1, 3, 2, 1, 3}, {2, 2, 3, 4, 2}, {0, 4, 4, 2, 2}},

            {{2, 2, 3, 3, 1}, {0, 3, 2, 2, 3}, {4, 2, 0, 4, 4}, {3, 0, 0, 3, 0}, {3, 3, 2, 1, 4}},

            {{1, 1, 1, 2, 0}, {2, 0, 1, 0, 1}, {3, 1, 3, 0, 1}, {1, 3, 2, 0, 3}, {2, 0, 3, 3, 3}},

            {{0, 4, 0, 0, 2}, {4, 4, 4, 4, 4}, {2, 0, 4, 2, 2}, {0, 1, 4, 1, 4}, {1, 1, 1, 0, 1}}
        };
        // py_a1 = np.argpartition(a, (1, 4), 1)
        xarray<long> py_a1 = {
            {{2, 3, 4, 0, 4}, {4, 0, 1, 4, 1}, {1, 1, 2, 1, 2}, {3, 4, 0, 3, 3}, {0, 2, 3, 2, 0}},

            {{3, 3, 2, 0, 3}, {4, 1, 3, 1, 0}, {2, 2, 0, 2, 4}, {1, 0, 4, 4, 2}, {0, 4, 1, 3, 1}},

            {{3, 4, 0, 3, 4}, {0, 3, 2, 4, 3}, {2, 1, 4, 1, 1}, {1, 0, 1, 2, 0}, {4, 2, 3, 0, 2}},

            {{3, 0, 1, 2, 0}, {1, 2, 3, 1, 2}, {4, 1, 0, 0, 1}, {0, 4, 4, 4, 4}, {2, 3, 2, 3, 3}},

            {{4, 3, 0, 4, 0}, {3, 2, 4, 0, 4}, {0, 4, 1, 3, 2}, {2, 0, 2, 2, 1}, {1, 1, 3, 1, 3}}
        };
        // py_a2 = np.argpartition(a, (1, 3), 2)
        xarray<long> py_a2 = {
            {{1, 2, 0, 4, 3}, {4, 1, 2, 0, 3}, {0, 4, 2, 1, 3}, {1, 4, 0, 2, 3}, {4, 2, 0, 1, 3}},

            {{3, 4, 2, 0, 1}, {3, 1, 0, 2, 4}, {2, 3, 0, 1, 4}, {2, 1, 4, 0, 3}, {0, 3, 4, 2, 1}},

            {{2, 1, 0, 4, 3}, {1, 4, 2, 3, 0}, {2, 1, 3, 0, 4}, {1, 3, 0, 4, 2}, {1, 4, 2, 3, 0}},

            {{4, 1, 2, 0, 3}, {2, 3, 0, 1, 4}, {3, 1, 4, 2, 0}, {0, 2, 1, 4, 3}, {0, 2, 1, 4, 3}},

            {{2, 4, 3, 0, 1}, {2, 0, 4, 1, 3}, {1, 3, 0, 2, 4}, {1, 0, 3, 4, 2}, {0, 3, 2, 1, 4}}
        };

        auto a0 = xt::argpartition(py_a, {1, 2}, 0);
        auto a1 = xt::argpartition(py_a, {1, 4}, 1);
        auto a2 = xt::argpartition(py_a, {1, 3}, 2);

        // The references are stored as long; cast them so the element types
        // match the std::size_t indices the comparison is made against.
        EXPECT_EQ(xt::cast<std::size_t>(xt::view(py_a0, 2, all(), all())), xt::view(a0, 2, all(), all()));
        EXPECT_EQ(xt::cast<std::size_t>(xt::view(py_a1, all(), 4, all())), xt::view(a1, all(), 4, all()));
        EXPECT_EQ(xt::cast<std::size_t>(xt::view(py_a2, all(), all(), 3)), xt::view(a2, all(), all(), 3));

        EXPECT_EQ(xt::cast<std::size_t>(xt::view(py_a0, 1, all(), all())), xt::view(a0, 1, all(), all()));
        EXPECT_EQ(xt::cast<std::size_t>(xt::view(py_a1, all(), 1, all())), xt::view(a1, all(), 1, all()));
        EXPECT_EQ(xt::cast<std::size_t>(xt::view(py_a2, all(), all(), 1)), xt::view(a2, all(), all(), 1));
    }
}