File: test_tensordot.cpp

package info (click to toggle)
xtensor-blas 0.21.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,068 kB
  • sloc: cpp: 97,896; makefile: 202; perl: 178; python: 153
file content (207 lines) | stat: -rw-r--r-- 8,076 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
/***************************************************************************
 * Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/

#include "xtensor/xarray.hpp"
#include "xtensor/xbuilder.hpp"
#include "xtensor/xstrided_view.hpp"
#include "xtensor/xview.hpp"

#include "gtest/gtest.h"
#include "xtensor-blas/xlinalg.hpp"

namespace xt
{
    TEST(xtensordot, outer_product)
    {
        // With zero contracted axes tensordot is an outer product: a (3, 3, 3)
        // tensor of ones against a (2, 2) tensor of fives gives a
        // (3, 3, 3, 2, 2) tensor where every element is 1 * 5.
        xarray<double> ones_3d = xt::ones<double>({3, 3, 3});
        xarray<double> fives_2d = xt::ones<double>({2, 2}) * 5.0;

        xarray<double> expected = xt::ones<double>({3, 3, 3, 2, 2}) * 5.0;
        EXPECT_EQ(expected, linalg::tensordot(ones_3d, fives_2d, 0));
    }

    TEST(xtensordot, outer_product_cm)
    {
        using cm_array = xarray<float, layout_type::column_major>;

        // Zero contracted axes (outer product) of ones (3, 3, 3) and fives (2, 2),
        // with every operand and the expected result stored column-major.
        cm_array lhs = xt::ones<float>({3, 3, 3});
        cm_array rhs = xt::ones<float>({2, 2}) * 5.0;
        cm_array expected = xt::ones<float>({3, 3, 3, 2, 2}) * 5.0;

        EXPECT_EQ(expected, linalg::tensordot(lhs, rhs, 0));
    }

    TEST(xtensordot, outer_product_mixed_layout)
    {
        // One column-major and one row-major operand. The outer product is checked
        // in both argument orders, which also swaps the order of the result axes.
        xarray<float, layout_type::column_major> cm_ones = xt::ones<float>({3, 3, 3});
        xarray<float> rm_fives = xt::ones<float>({2, 2}) * 5.0;

        // column-major (3, 3, 3) x row-major (2, 2) -> (3, 3, 3, 2, 2)
        xarray<float, layout_type::column_major> expected_cm_first = xt::ones<float>({3, 3, 3, 2, 2}) * 5.0;
        EXPECT_EQ(expected_cm_first, linalg::tensordot(cm_ones, rm_fives, 0));

        // row-major (2, 2) x column-major (3, 3, 3) -> (2, 2, 3, 3, 3)
        xarray<float> expected_rm_first = xt::ones<float>({2, 2, 3, 3, 3}) * 5.0;
        EXPECT_EQ(expected_rm_first, linalg::tensordot(rm_fives, cm_ones, 0));
    }

    TEST(xtensordot, inner_product)
    {
        // The default tensordot pairs the last two axes of `a` with the first two
        // axes of `b`: (3, 3, [2, 2]) . ([2, 2], 10) -> (3, 3, 10). Each element
        // is a sum of 2 * 2 products of ones, i.e. 4.
        xarray<double> a = xt::ones<double>({3, 3, 2, 2});
        xarray<double> b = xt::ones<double>({2, 2, 10});
        auto r1 = linalg::tensordot(a, b);
        EXPECT_TRUE(all(equal(r1, 4)));

        // EXPECT_EQ (not EXPECT_TRUE(x == y)) reports both operands on failure.
        EXPECT_EQ(r1.dimension(), 3u);
        EXPECT_EQ(r1.shape()[0], 3u);
        EXPECT_EQ(r1.shape()[1], 3u);
        EXPECT_EQ(r1.shape()[2], 10u);

        // Both calls below pair axes whose lengths differ (a's trailing (3, 2, 2)
        // vs b's leading (2, 2, 10); b's trailing (2, 10) vs a's leading (3, 3)),
        // which tensordot must reject.
        EXPECT_THROW(linalg::tensordot(a, b, 3), std::runtime_error);
        EXPECT_THROW(linalg::tensordot(b, a), std::runtime_error);
    }

    TEST(xtensordot, inner_product_cm)
    {
        // Column-major variant of the default two-axis contraction:
        // (3, 3, [2, 2]) . ([2, 2], 10) -> (3, 3, 10), every element 2 * 2 = 4.
        xarray<double, layout_type::column_major> a = xt::ones<double>({3, 3, 2, 2});
        xarray<double, layout_type::column_major> b = xt::ones<double>({2, 2, 10});
        auto r1 = linalg::tensordot(a, b);
        EXPECT_TRUE(all(equal(r1, 4)));

        // EXPECT_EQ (not EXPECT_TRUE(x == y)) reports both operands on failure.
        EXPECT_EQ(r1.dimension(), 3u);
        EXPECT_EQ(r1.shape()[0], 3u);
        EXPECT_EQ(r1.shape()[1], 3u);
        EXPECT_EQ(r1.shape()[2], 10u);

        // Both calls below pair axes whose lengths differ, which tensordot must reject.
        EXPECT_THROW(linalg::tensordot(a, b, 3), std::runtime_error);
        EXPECT_THROW(linalg::tensordot(b, a), std::runtime_error);
    }

    TEST(xtensordot, inner_product_mixed_layout)
    {
        // Three-axis contraction across layouts: a's trailing (3, 2, 2) against
        // b's leading (3, 2, 2) leaves (3, 10), each element 3 * 2 * 2 = 12.
        xarray<double> a = xt::ones<double>({3, 3, 2, 2});
        xarray<double, layout_type::column_major> b = xt::ones<double>({3, 2, 2, 10});
        auto r1 = linalg::tensordot(a, b, 3);
        EXPECT_TRUE(all(equal(r1, 12.0)));

        // EXPECT_EQ (not EXPECT_TRUE(x == y)) reports both operands on failure.
        EXPECT_EQ(r1.dimension(), 2u);
        EXPECT_EQ(r1.shape()[0], 3u);
        EXPECT_EQ(r1.shape()[1], 10u);

        // Default two-axis contraction pairs b's trailing (2, 10) with a's leading
        // (3, 3); the lengths differ, so tensordot must reject it.
        EXPECT_THROW(linalg::tensordot(b, a), std::runtime_error);
    }

    TEST(xtensordot, tuple_ax)
    {
        // `a` holds 0..35 with shape (3, 2, 3, 2); `b` is all ones with shape (2, 3, 2, 3).
        xarray<double> a = {
            {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}},
            {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}},
            {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}}
        };
        xarray<double> b = xt::ones<double>({2, 3, 2, 3});

        // Pair a's axes (1, 3, 2) with b's axes (0, 2, 1). The free axes are a's
        // axis 0 and b's axis 3, so the result is (3, 3). Because b is all ones,
        // element (i, j) is the sum of a[i, ...].
        auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1});
        xarray<double> e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}};
        EXPECT_EQ(r1, e1);

        // Contracting all four axes sums every element of a (0 + 1 + ... + 35 = 630).
        // Compare the whole array, not just one element, so the single-element
        // shape {1} is checked too, as in the mixed-layout variant.
        auto r2 = linalg::tensordot(a, b, {1, 3, 2, 0}, {0, 2, 1, 3});
        xarray<double> e2 = {630};
        EXPECT_EQ(r2, e2);
    }

    TEST(xtensordot, tuple_ax_cm)
    {
        // Column-major `a` holding 0..35 with shape (3, 2, 3, 2); `b` is all ones
        // with shape (2, 3, 2, 3), also column-major.
        xarray<double, layout_type::column_major> a = {
            {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}},
            {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}},
            {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}}
        };
        xarray<double, layout_type::column_major> b = xt::ones<double>({2, 3, 2, 3});

        // Pair a's axes (1, 3, 2) with b's axes (0, 2, 1), leaving a (3, 3) result
        // whose element (i, j) is the sum of a[i, ...].
        auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1});
        xarray<double, layout_type::column_major> e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}};
        EXPECT_EQ(r1, e1);

        // Contracting all four axes sums every element of a (630). Compare the
        // whole array so the single-element shape {1} is checked too.
        auto r2 = linalg::tensordot(a, b, {1, 3, 2, 0}, {0, 2, 1, 3});
        xarray<double, layout_type::column_major> e2 = {630};
        EXPECT_EQ(r2, e2);
    }

    TEST(xtensordot, tuple_ax_mixed_layout)
    {
        // Column-major `a` holding 0..35 with shape (3, 2, 3, 2); row-major `b`
        // of ones with shape (2, 3, 2, 3).
        xarray<double, layout_type::column_major> a = {
            {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}},
            {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}},
            {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}}
        };
        xarray<double> b = xt::ones<double>({2, 3, 2, 3});

        // Three-axis contraction: element (i, j) is the sum of a[i, ...].
        auto contracted3 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1});
        xarray<double, layout_type::column_major> expected3 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}};
        EXPECT_EQ(contracted3, expected3);

        // Four-axis contraction: a single element holding 0 + 1 + ... + 35.
        auto contracted_all = linalg::tensordot(a, b, {1, 3, 2, 0}, {0, 2, 1, 3});
        xarray<double, layout_type::column_major> expected_all = {630};
        EXPECT_EQ(contracted_all, expected_all);
    }

    TEST(xtensordot, view)
    {
        xarray<int> a = reshape_view(arange<int>(3 * 2 * 3 * 2), {3, 2, 3, 2});
        xarray<int> b = reshape_view(arange<int>(3 * 3 * 2 * 2), {3, 3, 2, 2});

        // Fixing axis 0 yields views of shape (2, 3, 2) and (3, 2, 2).
        auto a_slice = view(a, 0, all(), all(), all());
        auto b_slice = view(b, 0, all(), all(), all());

        // Pair a_slice's axes 0 and 2 with b_slice's axes 1 and 2; the remaining
        // length-3 axis of each view forms the (3, 3) result.
        auto product = linalg::tensordot(a_slice, b_slice, {0, 2}, {1, 2});

        xarray<int> expected = {{34, 90, 146}, {46, 134, 222}, {58, 178, 298}};
        EXPECT_EQ(product, expected);

        const auto& dims = product.shape();
        EXPECT_EQ(product.dimension(), 2u);
        EXPECT_EQ(dims[0], 3u);
        EXPECT_EQ(dims[1], 3u);
    }

    TEST(xtensordot, strided_view_range)
    {
        xarray<int> a = reshape_view(arange<int>(3 * 2 * 3 * 2), {3, 2, 3, 2});
        xarray<int> b = reshape_view(arange<int>(3 * 3 * 2 * 2), {3, 3, 2, 2});

        // Range slicing trims both operands to shape (2, 2, 2, 2).
        auto a_sub = strided_view(a, {range(0, 2), all(), range(0, 2), all()});
        auto b_sub = strided_view(b, {range(0, 2), range(0, 2), all(), all()});

        // Contract the first three axes of each; the last axes form a (2, 2) result.
        auto product = linalg::tensordot(a_sub, b_sub, {0, 1, 2}, {0, 1, 2});

        xarray<int> expected = {{1064, 1144}, {1136, 1224}};
        EXPECT_EQ(product, expected);

        const auto& dims = product.shape();
        EXPECT_EQ(product.dimension(), 2u);
        EXPECT_EQ(dims[0], 2u);
        EXPECT_EQ(dims[1], 2u);
    }

    TEST(xtensordot, reducing_dim_view)
    {
        xarray<int> a = reshape_view(arange<int>(3 * 2 * 3 * 2), {3, 2, 3, 2});
        xarray<int> b = reshape_view(arange<int>(3 * 3 * 2 * 2), {3, 3, 2, 2});

        // Two integer indices per view drop two axes: a[0, 1, :, :] and
        // b[2, :, 1, :] are both (3, 2). The default two-axis contraction then
        // consumes every remaining axis, leaving one element:
        // 6*26 + 7*27 + 8*30 + 9*31 + 10*34 + 11*35 = 1589.
        auto lhs = view(a, 0, 1, all(), all());
        auto rhs = view(b, 2, all(), 1, all());

        xarray<int> expected = {1589};
        EXPECT_EQ(linalg::tensordot(lhs, rhs), expected);
    }

    TEST(xtensordot, reducing_dim_strided_view)
    {
        xarray<int> a = reshape_view(arange<int>(3 * 2 * 3 * 2), {3, 2, 3, 2});
        xarray<int> b = reshape_view(arange<int>(3 * 3 * 2 * 2), {3, 3, 2, 2});

        // Same selection as the xview-based reduction test, but expressed with
        // strided views: both operands are (3, 2), and the default two-axis
        // contraction reduces them to a single element.
        auto lhs = strided_view(a, {0, 1, all(), all()});
        auto rhs = strided_view(b, {2, all(), 1, all()});

        xarray<int> expected = {1589};
        EXPECT_EQ(linalg::tensordot(lhs, rhs), expected);
    }
}  // namespace xt