File: matrix_default_mul.h

package info (click to toggle)
mldemos 0.5.1-3
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 32,224 kB
  • ctags: 46,525
  • sloc: cpp: 306,887; ansic: 167,718; ml: 126; sh: 109; makefile: 2
file content (134 lines) | stat: -rw-r--r-- 4,751 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright (C) 2008  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_DEFAULT_MULTIPLY_
#define DLIB_MATRIx_DEFAULT_MULTIPLY_

#include "../geometry.h"
#include "matrix.h"
#include "matrix_utilities.h"
#include "../enable_if.h"

namespace dlib
{

// ------------------------------------------------------------------------------------

    namespace ma
    {
        // Compile-time trait: matrix_is_vector<EXP>::value reports whether the
        // expression type EXP satisfies EXP::NR==1 || EXP::NC==1.
        // (NR/NC are presumably the static row/column counts of the
        // expression -- confirm against the matrix expression types.)
        //
        // Primary template: the fallback case, reporting false.  The second
        // template parameter exists only so the partial specialization below
        // has something to match against.
        template <typename EXP, typename enabler = void>
        struct matrix_is_vector
        {
            static const bool value = false;
        };

        // Partial specialization reporting true.  It relies on
        // enable_if_c<cond>::type (from ../enable_if.h, not visible here)
        // following the usual enable_if contract: naming void when cond holds
        // and being ill-formed otherwise, so that this specialization is only
        // considered when one of the two dimensions equals 1.
        template <typename EXP>
        struct matrix_is_vector<
            EXP,
            typename enable_if_c<(EXP::NR == 1 || EXP::NC == 1)>::type
            >
        {
            static const bool value = true;
        };
    }

// ------------------------------------------------------------------------------------

    /*!  This file defines the default_matrix_multiply() function.  It is a function 
         that conforms to the following definition:

        template <
            typename matrix_dest_type,
            typename EXP1,
            typename EXP2
            >
        void default_matrix_multiply (
            matrix_dest_type& dest,
            const EXP1& lhs,
            const EXP2& rhs
        );
            requires
                - (lhs*rhs).destructively_aliases(dest) == false
                - dest.nr() == (lhs*rhs).nr()
                - dest.nc() == (lhs*rhs).nc()
            ensures
                - #dest == dest + lhs*rhs
    !*/

// ------------------------------------------------------------------------------------

    // Overload of default_matrix_multiply() selected when ma::matrix_is_vector
    // is true for at least one of the two operand expression types.
    //
    // No blocking is attempted in this case: the product expression lhs*rhs is
    // handed straight to matrix_assign_default().  Per the contract documented
    // at the top of this file the result is #dest == dest + lhs*rhs.
    template <typename matrix_dest_type, typename EXP1, typename EXP2>
    typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true || ma::matrix_is_vector<EXP2>::value == true>::type 
    default_matrix_multiply (matrix_dest_type& dest, const EXP1& lhs, const EXP2& rhs)
    {
        // NOTE(review): the trailing (1, true) arguments are presumably the
        // scale factor and the "add to dest" flag of matrix_assign_default --
        // confirm against its declaration.
        matrix_assign_default(dest, lhs*rhs, 1, true);
    }

// ------------------------------------------------------------------------------------

    // Overload of default_matrix_multiply() selected when ma::matrix_is_vector
    // is false for both operand expression types.
    //
    // Accumulates the product of lhs and rhs into dest, i.e.
    // #dest == dest + lhs*rhs (see the contract at the top of this file).
    //
    //   - Small operands are forwarded to matrix_assign_default() on the
    //     product expression lhs*rhs.
    //   - Otherwise the product is accumulated block by block, using square
    //     blocks of edge length bs, so that the innermost loop walks along a
    //     row of rhs and a row of dest.
    //
    // dest is only ever updated with +=, so the caller is responsible for its
    // initial contents.
    template <
        typename matrix_dest_type,
        typename EXP1,
        typename EXP2
        >
    typename enable_if_c<ma::matrix_is_vector<EXP1>::value == false && ma::matrix_is_vector<EXP2>::value == false>::type 
    default_matrix_multiply (
        matrix_dest_type& dest,
        const EXP1& lhs,
        const EXP2& rhs
    )
    {
        // edge length (in elements) of the blocks used by the blocked algorithm
        const long bs = 90;

        // if the matrices are small enough then just use the simple multiply algorithm
        if (lhs.nc() <= 2 || rhs.nc() <= 2 || lhs.nr() <= 2 || rhs.nr() <= 2 || (lhs.size() <= bs*10 && rhs.size() <= bs*10) )
        {
            // NOTE(review): the trailing (1, true) arguments are presumably the
            // scale factor and the "add to dest" flag -- confirm against the
            // declaration of matrix_assign_default.
            matrix_assign_default(dest, lhs*rhs, 1, true);
        }
        else
        {
            // if the lhs and rhs matrices are big enough we should use a cache friendly
            // algorithm that computes the matrix multiply in blocks.  

            // The rectangles below are built as (left, top, right, bottom) with
            // left/right holding column indices and top/bottom holding row
            // indices; all four bounds are inclusive, which is why the far
            // edges are clamped to nc()-1 / nr()-1.

            // Loop over all the blocks in the lhs matrix
            for (long block_r = 0; block_r < lhs.nr(); block_r += bs)
            {
                for (long block_c = 0; block_c < lhs.nc(); block_c += bs)
                {
                    // make a rect for the block from lhs 
                    rectangle lhs_block(block_c, block_r, std::min(block_c+bs-1,lhs.nc()-1), std::min(block_r+bs-1,lhs.nr()-1));

                    // now loop over all the rhs blocks we have to multiply with the current lhs block
                    for (long block_i = 0; block_i < rhs.nc(); block_i += bs)
                    {
                        // make a rect for the block from rhs.  Its rows start at
                        // block_c because the columns of lhs pair up with the
                        // rows of rhs in the product.
                        rectangle rhs_block(block_i, block_c, std::min(block_i+bs-1,rhs.nc()-1), std::min(block_c+bs-1,rhs.nr()-1));

                        // This loop is optimized assuming that the data is laid out in 
                        // row major order in memory.
                        for (long r = lhs_block.top(); r <= lhs_block.bottom(); ++r)
                        {
                            for (long c = lhs_block.left(); c <= lhs_block.right(); ++c)
                            {
                                // Hoist the lhs element out of the innermost loop.  It is
                                // stored using the lhs element type so that no conversion
                                // is applied to it before the multiplication.
                                const typename EXP1::type temp = lhs(r,c);
                                for (long i = rhs_block.left(); i <= rhs_block.right(); ++i)
                                {
                                    dest(r,i) += rhs(c,i)*temp;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

// ------------------------------------------------------------------------------------

}

#endif // DLIB_MATRIx_DEFAULT_MULTIPLY_