File: test_blasdot.py

package info (click to toggle)
python-numpy 1:1.8.2-2
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 21,336 kB
  • ctags: 18,503
  • sloc: ansic: 149,662; python: 85,440; cpp: 968; makefile: 367; f90: 164; sh: 130; fortran: 125; perl: 58
file content (153 lines) | stat: -rw-r--r-- 5,483 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from __future__ import division, absolute_import, print_function

import numpy as np
import sys
from numpy.core import zeros, float64
from numpy.testing import dec, TestCase, assert_almost_equal, assert_, \
     assert_raises, assert_array_equal, assert_allclose, assert_equal
from numpy.core.multiarray import inner as inner_

# Decimal places required by assert_almost_equal in TestInner.
DECPREC = 14


class TestInner(TestCase):
    def test_vecself(self):
        """Ticket 844: inner product of a zero vector with itself.

        This case used to segfault or give a meaningless result; the
        inner product of an all-zero (1, 80) array with itself must be 0.
        """
        vec = zeros(shape=(1, 80), dtype=float64)
        assert_almost_equal(inner_(vec, vec), 0, decimal=DECPREC)

# Optional BLAS-backed dot module; stays None when this numpy build does not
# provide numpy.core._dotblas (tests that need it are skipped in that case).
_dotblas = None
try:
    import numpy.core._dotblas as _dotblas
except ImportError:
    pass  # deliberate: keep the None fallback assigned above

@dec.skipif(_dotblas is None, "Numpy is not compiled with _dotblas")
def test_blasdot_used():
    """numpy.core's dot family must be the very objects exported by _dotblas."""
    from numpy.core import dot, vdot, inner, alterdot, restoredot
    public_funcs = (dot, vdot, inner, alterdot, restoredot)
    blas_names = ("dot", "vdot", "inner", "alterdot", "restoredot")
    # Identity (not equality): numpy.core must re-export the _dotblas objects.
    for func, name in zip(public_funcs, blas_names):
        assert_(func is getattr(_dotblas, name))


def test_dot_2args():
    """Two-argument dot returns the ordinary 2x2 float matrix product."""
    from numpy.core import dot

    lhs = np.array([[1., 2.], [3., 4.]])
    rhs = np.array([[1., 0.], [1., 1.]])
    expected = np.array([[3., 2.], [7., 4.]])

    assert_allclose(expected, dot(lhs, rhs))

def test_dot_3args():
    """np.dot with an explicit output array (third argument / ``out=``).

    For a matrix-matrix and a matrix-vector product, checks that:

    * repeated calls writing into ``out`` do not leak references to it,
    * the values written into ``out`` match a freshly allocated result,
    * ``out`` itself is the object returned.
    """
    np.random.seed(22)
    f = np.random.random_sample((1024, 16))
    v = np.random.random_sample((16, 32))

    r = np.empty((1024, 32))
    # Compare against the count measured beforehand rather than a hard-coded
    # 2: the absolute value reported by sys.getrefcount is an interpreter
    # detail (it includes its own temporary argument reference and can vary
    # between CPython versions). A reference leaked by each call would still
    # show up as a difference after the 12 calls.
    refcount_before = sys.getrefcount(r)
    for _ in range(12):
        np.dot(f, v, r)
    assert_equal(sys.getrefcount(r), refcount_before)
    r2 = np.dot(f, v, out=None)
    assert_array_equal(r2, r)
    assert_(r is np.dot(f, v, out=r))

    # Matrix-vector case.
    v = v[:, 0].copy()  # v.shape == (16,)
    r = r[:, 0].copy()  # r.shape == (1024,)
    r2 = np.dot(f, v)
    assert_(r is np.dot(f, v, r))
    assert_array_equal(r2, r)

def test_dot_3args_errors():
    """np.dot rejects unsuitable ``out`` arrays with ValueError.

    The product of the (1024, 16) and (16, 32) float64 operands has shape
    (1024, 32); each candidate below differs from that in shape, memory
    layout or dtype.
    """
    np.random.seed(22)
    lhs = np.random.random_sample((1024, 16))
    rhs = np.random.random_sample((16, 32))

    transposed = np.empty((32, 1024))
    wide = np.empty((1024, 64))
    bad_outputs = [
        np.empty((1024, 31)),                    # wrong number of columns
        np.empty((1024,)),                       # wrong number of dimensions
        np.empty((32,)),                         # wrong number of dimensions
        transposed,                              # dimensions swapped
        transposed.T,                            # right shape, Fortran-ordered view
        wide[:, ::2],                            # right shape, strided view
        wide[:, :32],                            # right shape, non-contiguous slice
        np.empty((1024, 32), dtype=np.float32),  # wrong dtype
        np.empty((1024, 32), dtype=int),         # wrong dtype
    ]
    for out in bad_outputs:
        assert_raises(ValueError, np.dot, lhs, rhs, out)

def test_dot_array_order():
    """Test numpy dot with operands in C and Fortran memory order.

    Products are checked against equivalent formulations and against
    ``numpy.core.multiarray.dot``. The equivalent formulations are:

    * ``np.dot`` versus the ``ndarray.dot`` method,
    * transposed views versus explicit transposed copies,
    * transpose identities such as (a.b)' = b'.a'.

    Double and single precision arrays are compared to 7 and 5 decimals
    respectively, and every product of ``a`` with itself must come back
    C-contiguous.
    """
    _dot = np.core.multiarray.dot
    a_dim, b_dim, c_dim = 10, 4, 7
    orders = ["C", "F"]
    # (dtype, decimals) pairs in a tuple rather than a dict. All random data
    # comes from one seeded global RNG, so the order in which dtypes are
    # visited decides which data each one is tested with. Dict iteration
    # order is only guaranteed from Python 3.7 on.
    dtypes_prec = ((np.float64, 7), (np.float32, 5))
    np.random.seed(7)

    for arr_type, prec in dtypes_prec:
        for a_order in orders:
            a = np.asarray(np.random.randn(a_dim, a_dim),
                           dtype=arr_type, order=a_order)
            assert_array_equal(np.dot(a, a), a.dot(a))
            # (a.a)' = a'.a'. The two evaluation orders need not round
            # identically, hence the approximate comparison.
            assert_almost_equal(a.dot(a), a.T.dot(a.T).T, decimal=prec)

            #
            # Check with making explicit copy
            #
            a_T = a.T.copy(order=a_order)
            assert_almost_equal(a_T.dot(a_T), a.T.dot(a.T), decimal=prec)
            assert_almost_equal(a.dot(a_T), a.dot(a.T), decimal=prec)
            assert_almost_equal(a_T.dot(a), a.T.dot(a), decimal=prec)

            #
            # Compare with multiarray dot
            #
            assert_almost_equal(a.dot(a), _dot(a, a), decimal=prec)
            assert_almost_equal(a.T.dot(a), _dot(a.T, a), decimal=prec)
            assert_almost_equal(a.dot(a.T), _dot(a, a.T), decimal=prec)
            assert_almost_equal(a.T.dot(a.T), _dot(a.T, a.T), decimal=prec)
            # Results must be C-contiguous whatever the operand layout.
            # Use assert_ (not a bare assert) so the check survives python -O.
            for res in (a.dot(a), a.T.dot(a), a.dot(a.T), a.T.dot(a.T)):
                assert_(res.flags.c_contiguous)

            for b_order in orders:
                b = np.asarray(np.random.randn(a_dim, b_dim),
                               dtype=arr_type, order=b_order)
                b_T = b.T.copy(order=b_order)
                assert_almost_equal(a_T.dot(b), a.T.dot(b), decimal=prec)
                assert_almost_equal(b_T.dot(a), b.T.dot(a), decimal=prec)
                # (b'.a)' = a'.b
                assert_almost_equal(b.T.dot(a), a.T.dot(b).T, decimal=prec)
                assert_almost_equal(a.dot(b), _dot(a, b), decimal=prec)
                assert_almost_equal(b.T.dot(a), _dot(b.T, a), decimal=prec)

                for c_order in orders:
                    c = np.asarray(np.random.randn(b_dim, c_dim),
                                   dtype=arr_type, order=c_order)
                    c_T = c.T.copy(order=c_order)
                    assert_almost_equal(c.T.dot(b.T), c_T.dot(b_T), decimal=prec)
                    # (c'.b')' = b.c
                    assert_almost_equal(c.T.dot(b.T).T, b.dot(c), decimal=prec)
                    assert_almost_equal(b.dot(c), _dot(b, c), decimal=prec)
                    assert_almost_equal(c.T.dot(b.T), _dot(c.T, b.T), decimal=prec)