1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
|
/***************************************************************************
* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay *
* Copyright (c) QuantStack *
* *
* Distributed under the terms of the BSD 3-Clause License. *
* *
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/
// This file is generated from test/files/cppy_source/test_dot_extended.cppy by
// preprocess.py!
#include <algorithm>
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "gtest/gtest.h"
#include "xtensor-blas/xlinalg.hpp"
namespace xt
{
using namespace xt::placeholders;
/*py
a = np.random.random((2, 3, 5))
b = np.random.random((4, 5))
dr = np.dot(a, b.T)
*/
TEST(xtest_extended, dot_broadcast)
{
    // Checks xt::linalg::dot of a (2, 3, 5) stack against the transpose of a
    // (4, 5) matrix, i.e. np.dot(a, b.T), whose reference result is (2, 3, 4).
    // The right-hand operand is supplied both as an xt::transpose expression
    // and as an evaluated xtensor copy; both results must match the reference.

    // py_a
    xarray<double> py_a = {
        {{0.3745401188473625, 0.9507143064099162, 0.7319939418114051, 0.5986584841970366, 0.1560186404424365},
         {0.1559945203362026, 0.0580836121681995, 0.8661761457749352, 0.6011150117432088, 0.7080725777960455},
         {0.0205844942958024, 0.9699098521619943, 0.8324426408004217, 0.2123391106782762, 0.1818249672071006}},
        {{0.1834045098534338, 0.3042422429595377, 0.5247564316322378, 0.4319450186421158, 0.2912291401980419},
         {0.6118528947223795, 0.1394938606520418, 0.2921446485352182, 0.3663618432936917, 0.4560699842170359},
         {0.7851759613930136, 0.1996737821583597, 0.5142344384136116, 0.5924145688620425, 0.0464504127199977}}
    };
    // py_b
    xarray<double> py_b = {
        {0.6075448519014384, 0.1705241236872915, 0.0650515929852795, 0.9488855372533332, 0.9656320330745594},
        {0.8083973481164611, 0.3046137691733707, 0.0976721140063839, 0.6842330265121569, 0.4401524937396013},
        {0.1220382348447788, 0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169},
        {0.662522284353982, 0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527}
    };
    // py_dr
    xarray<double> py_dr = {
        {{1.1560019913607258, 1.1421672030085086, 1.1263990512143978, 1.2813094834150083},
         {1.415151366639716, 0.9513625344824885, 0.807426629014782, 1.0314517921651605},
         {0.6091122748507029, 0.6187149240291543, 0.7515524775267591, 0.898595256683809}},
        {{0.8885299172713558, 0.7159304454839006, 0.6592223836380569, 0.7792380767202456},
         {1.2025508600129964, 1.0170636073271262, 0.6049520893427571, 0.8853834024749684},
         {1.15151820221699, 1.1715787914743192, 0.763094187597877, 1.182339688054495}}
    };

    xt::xtensor<double, 2> b_transposed = xt::transpose(py_b);
    auto xres = xt::linalg::dot(py_a, xt::transpose(py_b));
    auto xres2 = xt::linalg::dot(py_a, b_transposed);

    // xt::allclose broadcasts its operands, so by itself it would accept a
    // result whose shape merely broadcasts against py_dr. Pin the exact shape.
    EXPECT_TRUE(std::equal(xres.shape().begin(), xres.shape().end(),
                           py_dr.shape().begin(), py_dr.shape().end()));
    EXPECT_TRUE(std::equal(xres2.shape().begin(), xres2.shape().end(),
                           py_dr.shape().begin(), py_dr.shape().end()));
    EXPECT_TRUE(xt::allclose(xres, py_dr));
    EXPECT_TRUE(xt::allclose(xres2, py_dr));
}
/*py
a = np.random.random((2, 3, 5))
b = np.random.random((5))
dr = np.dot(a, b)
*/
TEST(xtest_extended, dot_broadcast_2)
{
    // Checks xt::linalg::dot of a (2, 3, 5) stack with a length-5 vector,
    // i.e. np.dot(a, b), which contracts the last axis of `a` against `b`
    // and yields the (2, 3) reference result below.

    // py_a
    xarray<double> py_a = {
        {{0.9695846277645586, 0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851},
         {0.9218742350231168, 0.0884925020519195, 0.1959828624191452, 0.0452272889105381, 0.3253303307632643},
         {0.388677289689482, 0.2713490317738959, 0.8287375091519293, 0.3567533266935893, 0.2809345096873808}},
        {{0.5426960831582485, 0.1409242249747626, 0.8021969807540397, 0.0745506436797708, 0.9868869366005173},
         {0.7722447692966574, 0.1987156815341724, 0.0055221171236024, 0.8154614284548342, 0.7068573438476171},
         {0.7290071680409873, 0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297}}
    };
    // py_b
    xarray<double> py_b =
        {0.8631034258755935, 0.6232981268275579, 0.3308980248526492, 0.0635583502860236, 0.3109823217156622};
    // py_dr
    xarray<double> py_dr = {
        {1.8736790686065976, 1.0197269167779506, 0.8888679673881792},
        {1.1333287572487494, 1.0638629967411402, 1.1932578950872312}
    };

    auto xres = xt::linalg::dot(py_a, py_b);

    // xt::allclose broadcasts its operands, so by itself it would accept a
    // result whose shape merely broadcasts against py_dr. Pin the exact shape.
    EXPECT_TRUE(std::equal(xres.shape().begin(), xres.shape().end(),
                           py_dr.shape().begin(), py_dr.shape().end()));
    EXPECT_TRUE(xt::allclose(xres, py_dr));
}
/*py
a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
b = np.arange(4 * 5 * 3).reshape(4, 5, 3)
dr = np.dot(a, b)
*/
TEST(xtest_extended, dot_broadcast_3)
{
    // Checks xt::linalg::dot of two integer N-D arrays against np.dot(a, b).
    // `a` has shape (2, 3, 4, 5) and `b` has shape (4, 5, 3): the last axis of
    // `a` is contracted with the second-to-last axis of `b`, giving the
    // (2, 3, 4, 4, 3) reference result below.

    // py_a
    xarray<long> py_a = {
        {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}, {10, 11, 12, 13, 14}, {15, 16, 17, 18, 19}},
         {{20, 21, 22, 23, 24}, {25, 26, 27, 28, 29}, {30, 31, 32, 33, 34}, {35, 36, 37, 38, 39}},
         {{40, 41, 42, 43, 44}, {45, 46, 47, 48, 49}, {50, 51, 52, 53, 54}, {55, 56, 57, 58, 59}}},
        {{{60, 61, 62, 63, 64}, {65, 66, 67, 68, 69}, {70, 71, 72, 73, 74}, {75, 76, 77, 78, 79}},
         {{80, 81, 82, 83, 84}, {85, 86, 87, 88, 89}, {90, 91, 92, 93, 94}, {95, 96, 97, 98, 99}},
         {{100, 101, 102, 103, 104},
          {105, 106, 107, 108, 109},
          {110, 111, 112, 113, 114},
          {115, 116, 117, 118, 119}}}
    };
    // py_b
    xarray<long> py_b = {
        {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}, {12, 13, 14}},
        {{15, 16, 17}, {18, 19, 20}, {21, 22, 23}, {24, 25, 26}, {27, 28, 29}},
        {{30, 31, 32}, {33, 34, 35}, {36, 37, 38}, {39, 40, 41}, {42, 43, 44}},
        {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}, {54, 55, 56}, {57, 58, 59}}
    };
    // py_dr
    xarray<long> py_dr = {
        {{{{90, 100, 110}, {240, 250, 260}, {390, 400, 410}, {540, 550, 560}},
          {{240, 275, 310}, {765, 800, 835}, {1290, 1325, 1360}, {1815, 1850, 1885}},
          {{390, 450, 510}, {1290, 1350, 1410}, {2190, 2250, 2310}, {3090, 3150, 3210}},
          {{540, 625, 710}, {1815, 1900, 1985}, {3090, 3175, 3260}, {4365, 4450, 4535}}},
         {{{690, 800, 910}, {2340, 2450, 2560}, {3990, 4100, 4210}, {5640, 5750, 5860}},
          {{840, 975, 1110}, {2865, 3000, 3135}, {4890, 5025, 5160}, {6915, 7050, 7185}},
          {{990, 1150, 1310}, {3390, 3550, 3710}, {5790, 5950, 6110}, {8190, 8350, 8510}},
          {{1140, 1325, 1510}, {3915, 4100, 4285}, {6690, 6875, 7060}, {9465, 9650, 9835}}},
         {{{1290, 1500, 1710}, {4440, 4650, 4860}, {7590, 7800, 8010}, {10740, 10950, 11160}},
          {{1440, 1675, 1910}, {4965, 5200, 5435}, {8490, 8725, 8960}, {12015, 12250, 12485}},
          {{1590, 1850, 2110}, {5490, 5750, 6010}, {9390, 9650, 9910}, {13290, 13550, 13810}},
          {{1740, 2025, 2310}, {6015, 6300, 6585}, {10290, 10575, 10860}, {14565, 14850, 15135}}}},
        {{{{1890, 2200, 2510}, {6540, 6850, 7160}, {11190, 11500, 11810}, {15840, 16150, 16460}},
          {{2040, 2375, 2710}, {7065, 7400, 7735}, {12090, 12425, 12760}, {17115, 17450, 17785}},
          {{2190, 2550, 2910}, {7590, 7950, 8310}, {12990, 13350, 13710}, {18390, 18750, 19110}},
          {{2340, 2725, 3110}, {8115, 8500, 8885}, {13890, 14275, 14660}, {19665, 20050, 20435}}},
         {{{2490, 2900, 3310}, {8640, 9050, 9460}, {14790, 15200, 15610}, {20940, 21350, 21760}},
          {{2640, 3075, 3510}, {9165, 9600, 10035}, {15690, 16125, 16560}, {22215, 22650, 23085}},
          {{2790, 3250, 3710}, {9690, 10150, 10610}, {16590, 17050, 17510}, {23490, 23950, 24410}},
          {{2940, 3425, 3910}, {10215, 10700, 11185}, {17490, 17975, 18460}, {24765, 25250, 25735}}},
         {{{3090, 3600, 4110}, {10740, 11250, 11760}, {18390, 18900, 19410}, {26040, 26550, 27060}},
          {{3240, 3775, 4310}, {11265, 11800, 12335}, {19290, 19825, 20360}, {27315, 27850, 28385}},
          {{3390, 3950, 4510}, {11790, 12350, 12910}, {20190, 20750, 21310}, {28590, 29150, 29710}},
          {{3540, 4125, 4710}, {12315, 12900, 13485}, {21090, 21675, 22260}, {29865, 30450, 31035}}}}
    };

    auto xres = xt::linalg::dot(py_a, py_b);

    // Integer inputs produce exact integer sums, so compare exactly rather than
    // with the tolerance-based (and broadcasting) xt::allclose. xtensor's
    // operator== on two expressions returns a single bool that requires equal
    // dimension, equal shape and equal elements.
    EXPECT_TRUE(xres == py_dr);
}
} // namespace xt
|