File: test_dot_extended.cpp

package info (click to toggle)
xtensor-blas 0.21.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,068 kB
  • sloc: cpp: 97,896; makefile: 202; perl: 178; python: 153
file content (189 lines) | stat: -rw-r--r-- 8,606 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
/***************************************************************************
 * Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/

// This file is generated from test/files/cppy_source/test_dot_extended.cppy by
// preprocess.py!

#include <algorithm>

#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"

#include "gtest/gtest.h"
#include "xtensor-blas/xlinalg.hpp"

namespace xt
{
    using namespace xt::placeholders;

    // Extended tests for xt::linalg::dot with operands of rank > 2.
    //
    // NOTE: this file is generated from test/files/cppy_source/test_dot_extended.cppy
    // by preprocess.py; the reference arrays are the NumPy results of the
    // /*py ... */ snippet preceding each test. Changes made only here would
    // presumably be overwritten on regeneration, so mirror them in the .cppy.
    //
    // Every comparison streams the computed result into the gtest failure
    // message, so the array is printed only when an assertion fails instead of
    // being written to std::cout on every run.

    // Rank-3 `a` of shape (2, 3, 5) dotted with the transpose of a (4, 5)
    // matrix; expected result shape (2, 3, 4). The right operand is passed both
    // as the xt::transpose expression and as an evaluated xtensor<double, 2>.
    /*py
    a = np.random.random((2, 3, 5))
    b = np.random.random((4, 5))
    dr = np.dot(a, b.T)
    */
    TEST(xtest_extended, dot_broadcast)
    {
        // py_a
        xarray<double> py_a = {
            {{0.3745401188473625, 0.9507143064099162, 0.7319939418114051, 0.5986584841970366, 0.1560186404424365},
             {0.1559945203362026, 0.0580836121681995, 0.8661761457749352, 0.6011150117432088, 0.7080725777960455},
             {0.0205844942958024, 0.9699098521619943, 0.8324426408004217, 0.2123391106782762, 0.1818249672071006}},

            {{0.1834045098534338, 0.3042422429595377, 0.5247564316322378, 0.4319450186421158, 0.2912291401980419},
             {0.6118528947223795, 0.1394938606520418, 0.2921446485352182, 0.3663618432936917, 0.4560699842170359},
             {0.7851759613930136, 0.1996737821583597, 0.5142344384136116, 0.5924145688620425, 0.0464504127199977}}
        };
        // py_b
        xarray<double> py_b = {
            {0.6075448519014384, 0.1705241236872915, 0.0650515929852795, 0.9488855372533332, 0.9656320330745594},
            {0.8083973481164611, 0.3046137691733707, 0.0976721140063839, 0.6842330265121569, 0.4401524937396013},
            {0.1220382348447788, 0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169},
            {0.662522284353982, 0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527}
        };
        // py_dr
        xarray<double> py_dr = {
            {{1.1560019913607258, 1.1421672030085086, 1.1263990512143978, 1.2813094834150083},
             {1.415151366639716, 0.9513625344824885, 0.807426629014782, 1.0314517921651605},
             {0.6091122748507029, 0.6187149240291543, 0.7515524775267591, 0.898595256683809}},

            {{0.8885299172713558, 0.7159304454839006, 0.6592223836380569, 0.7792380767202456},
             {1.2025508600129964, 1.0170636073271262, 0.6049520893427571, 0.8853834024749684},
             {1.15151820221699, 1.1715787914743192, 0.763094187597877, 1.182339688054495}}
        };

        // Evaluated copy of b.T, to exercise dot on a container rather than
        // on the transpose expression.
        xt::xtensor<double, 2> bas = xt::transpose(py_b);

        auto xres = xt::linalg::dot(py_a, xt::transpose(py_b));
        auto xres2 = xt::linalg::dot(py_a, bas);
        // The streamed array is only formatted if the assertion fails.
        EXPECT_TRUE(xt::allclose(xres, py_dr)) << xres;
        EXPECT_TRUE(xt::allclose(xres2, py_dr)) << xres2;
    }

    // Rank-3 `a` of shape (2, 3, 5) dotted with a length-5 vector, contracting
    // the last axis of `a`; expected result shape (2, 3).
    /*py
    a = np.random.random((2, 3, 5))
    b = np.random.random((5))
    dr = np.dot(a, b)
    */
    TEST(xtest_extended, dot_broadcast_2)
    {
        // py_a
        xarray<double> py_a = {
            {{0.9695846277645586, 0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851},
             {0.9218742350231168, 0.0884925020519195, 0.1959828624191452, 0.0452272889105381, 0.3253303307632643},
             {0.388677289689482, 0.2713490317738959, 0.8287375091519293, 0.3567533266935893, 0.2809345096873808}},

            {{0.5426960831582485, 0.1409242249747626, 0.8021969807540397, 0.0745506436797708, 0.9868869366005173},
             {0.7722447692966574, 0.1987156815341724, 0.0055221171236024, 0.8154614284548342, 0.7068573438476171},
             {0.7290071680409873, 0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297}}
        };
        // py_b
        xarray<double> py_b =
            {0.8631034258755935, 0.6232981268275579, 0.3308980248526492, 0.0635583502860236, 0.3109823217156622};
        // py_dr
        xarray<double> py_dr = {
            {1.8736790686065976, 1.0197269167779506, 0.8888679673881792},
            {1.1333287572487494, 1.0638629967411402, 1.1932578950872312}
        };

        auto xres = xt::linalg::dot(py_a, py_b);
        EXPECT_TRUE(xt::allclose(xres, py_dr)) << xres;
    }

    // N-D dot following np.dot semantics: rank-4 `a` (2, 3, 4, 5) with rank-3
    // `b` (4, 5, 3) sums over the last axis of `a` and the second-to-last axis
    // of `b`, giving shape (2, 3, 4, 4, 3). Integer (long) inputs built from
    // np.arange.
    /*py
    a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
    b = np.arange(4 * 5 * 3).reshape(4, 5, 3)
    dr = np.dot(a, b)
    */
    TEST(xtest_extended, dot_broadcast_3)
    {
        // py_a
        xarray<long> py_a = {
            {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}, {10, 11, 12, 13, 14}, {15, 16, 17, 18, 19}},

             {{20, 21, 22, 23, 24}, {25, 26, 27, 28, 29}, {30, 31, 32, 33, 34}, {35, 36, 37, 38, 39}},

             {{40, 41, 42, 43, 44}, {45, 46, 47, 48, 49}, {50, 51, 52, 53, 54}, {55, 56, 57, 58, 59}}},

            {{{60, 61, 62, 63, 64}, {65, 66, 67, 68, 69}, {70, 71, 72, 73, 74}, {75, 76, 77, 78, 79}},

             {{80, 81, 82, 83, 84}, {85, 86, 87, 88, 89}, {90, 91, 92, 93, 94}, {95, 96, 97, 98, 99}},

             {{100, 101, 102, 103, 104},
              {105, 106, 107, 108, 109},
              {110, 111, 112, 113, 114},
              {115, 116, 117, 118, 119}}}
        };
        // py_b
        xarray<long> py_b = {
            {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}, {12, 13, 14}},

            {{15, 16, 17}, {18, 19, 20}, {21, 22, 23}, {24, 25, 26}, {27, 28, 29}},

            {{30, 31, 32}, {33, 34, 35}, {36, 37, 38}, {39, 40, 41}, {42, 43, 44}},

            {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}, {54, 55, 56}, {57, 58, 59}}
        };
        // py_dr
        xarray<long> py_dr = {
            {{{{90, 100, 110}, {240, 250, 260}, {390, 400, 410}, {540, 550, 560}},

              {{240, 275, 310}, {765, 800, 835}, {1290, 1325, 1360}, {1815, 1850, 1885}},

              {{390, 450, 510}, {1290, 1350, 1410}, {2190, 2250, 2310}, {3090, 3150, 3210}},

              {{540, 625, 710}, {1815, 1900, 1985}, {3090, 3175, 3260}, {4365, 4450, 4535}}},

             {{{690, 800, 910}, {2340, 2450, 2560}, {3990, 4100, 4210}, {5640, 5750, 5860}},

              {{840, 975, 1110}, {2865, 3000, 3135}, {4890, 5025, 5160}, {6915, 7050, 7185}},

              {{990, 1150, 1310}, {3390, 3550, 3710}, {5790, 5950, 6110}, {8190, 8350, 8510}},

              {{1140, 1325, 1510}, {3915, 4100, 4285}, {6690, 6875, 7060}, {9465, 9650, 9835}}},

             {{{1290, 1500, 1710}, {4440, 4650, 4860}, {7590, 7800, 8010}, {10740, 10950, 11160}},

              {{1440, 1675, 1910}, {4965, 5200, 5435}, {8490, 8725, 8960}, {12015, 12250, 12485}},

              {{1590, 1850, 2110}, {5490, 5750, 6010}, {9390, 9650, 9910}, {13290, 13550, 13810}},

              {{1740, 2025, 2310}, {6015, 6300, 6585}, {10290, 10575, 10860}, {14565, 14850, 15135}}}},

            {{{{1890, 2200, 2510}, {6540, 6850, 7160}, {11190, 11500, 11810}, {15840, 16150, 16460}},

              {{2040, 2375, 2710}, {7065, 7400, 7735}, {12090, 12425, 12760}, {17115, 17450, 17785}},

              {{2190, 2550, 2910}, {7590, 7950, 8310}, {12990, 13350, 13710}, {18390, 18750, 19110}},

              {{2340, 2725, 3110}, {8115, 8500, 8885}, {13890, 14275, 14660}, {19665, 20050, 20435}}},

             {{{2490, 2900, 3310}, {8640, 9050, 9460}, {14790, 15200, 15610}, {20940, 21350, 21760}},

              {{2640, 3075, 3510}, {9165, 9600, 10035}, {15690, 16125, 16560}, {22215, 22650, 23085}},

              {{2790, 3250, 3710}, {9690, 10150, 10610}, {16590, 17050, 17510}, {23490, 23950, 24410}},

              {{2940, 3425, 3910}, {10215, 10700, 11185}, {17490, 17975, 18460}, {24765, 25250, 25735}}},

             {{{3090, 3600, 4110}, {10740, 11250, 11760}, {18390, 18900, 19410}, {26040, 26550, 27060}},

              {{3240, 3775, 4310}, {11265, 11800, 12335}, {19290, 19825, 20360}, {27315, 27850, 28385}},

              {{3390, 3950, 4510}, {11790, 12350, 12910}, {20190, 20750, 21310}, {28590, 29150, 29710}},

              {{3540, 4125, 4710}, {12315, 12900, 13485}, {21090, 21675, 22260}, {29865, 30450, 31035}}}}
        };

        auto xres = xt::linalg::dot(py_a, py_b);
        EXPECT_TRUE(xt::allclose(xres, py_dr)) << xres;
    }
}  // namespace xt