File: GB_AxB_dot_generic.c

package info (click to toggle)
suitesparse 1:7.10.1+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, trixie
  • size: 254,920 kB
  • sloc: ansic: 1,134,743; cpp: 46,133; makefile: 4,875; fortran: 2,087; java: 1,826; sh: 996; ruby: 725; python: 495; asm: 371; sed: 166; awk: 44
file content (219 lines) | stat: -rw-r--r-- 8,178 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
//------------------------------------------------------------------------------
// GB_AxB_dot_generic: generic template for all dot-product methods
//------------------------------------------------------------------------------

// SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2025, All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

//------------------------------------------------------------------------------

// This template serves the dot2 and dot3 methods, but not dot4, since dot4 is
// not implemented for generic kernels.  The #including file defines
// GB_DOT2_GENERIC or GB_DOT3_GENERIC.

// This file does not use GB_DECLARE_TERMINAL_CONST (zterminal).  Instead, it
// defines zterminal itself.

#include "mxm/include/GB_mxm_shared_definitions.h"
#include "generic/GB_generic.h"

{

    //--------------------------------------------------------------------------
    // get operators, functions, workspace, contents of A, B, C
    //--------------------------------------------------------------------------

    ASSERT (!C->iso) ;

    GxB_binary_function fmult = mult->binop_function ;    // NULL if positional
    GxB_index_binary_function fmult_idx = mult->idxbinop_function ;
    GxB_binary_function fadd  = add->op->binop_function ;
    GB_Opcode opcode = mult->opcode ;
//  bool op_is_builtin_positional =
//      GB_IS_BUILTIN_BINOP_CODE_POSITIONAL (opcode) ;

    ASSERT (C->type == add->op->ztype) ;
    size_t csize = C->type->size ;
    size_t asize = A_is_pattern ? 0 : A->type->size ;
    size_t bsize = B_is_pattern ? 0 : B->type->size ;

    size_t zsize = csize ;      // C->type always matches add->op->ztype
    size_t xsize = mult->xtype->size ;
    size_t ysize = mult->ytype->size ;

    // scalar workspace: because of typecasting, the x/y types need not
    // be the same as the size of the A and B types.
    // flipxy false: aki = (xtype) A(k,i) and bkj = (ytype) B(k,j)
    // flipxy true:  aki = (ytype) A(k,i) and bkj = (xtype) B(k,j)
    size_t aki_size = flipxy ? ysize : xsize ;
    size_t bkj_size = flipxy ? xsize : ysize ;

    bool is_terminal = (add->terminal != NULL) ;

    GB_cast_function cast_A, cast_B ;
    if (flipxy)
    { 
        // A is typecasted to y, and B is typecasted to x
        cast_A = A_is_pattern ? NULL : 
                 GB_cast_factory (mult->ytype->code, A->type->code) ;
        cast_B = B_is_pattern ? NULL : 
                 GB_cast_factory (mult->xtype->code, B->type->code) ;
    }
    else
    { 
        // A is typecasted to x, and B is typecasted to y
        cast_A = A_is_pattern ? NULL :
                 GB_cast_factory (mult->xtype->code, A->type->code) ;
        cast_B = B_is_pattern ? NULL :
                 GB_cast_factory (mult->ytype->code, B->type->code) ;
    }

    //--------------------------------------------------------------------------
    // C = A'*B via dot products, function pointers, and typecasting
    //--------------------------------------------------------------------------

    // aki = A(i,k), located in Ax [A_iso?0:(pA)]
    #undef  GB_A_IS_PATTERN
    #define GB_A_IS_PATTERN 0
    #undef  GB_DECLAREA
    #define GB_DECLAREA(aki)                                        \
        GB_void aki [GB_VLA(aki_size)] ;
    #undef  GB_GETA
    #define GB_GETA(aki,Ax,pA,A_iso)                                \
        if (!A_is_pattern) cast_A (aki, Ax +((A_iso) ? 0:(pA)*asize), asize)

    // bkj = B(k,j), located in Bx [B_iso?0:pB]
    #undef  GB_B_IS_PATTERN
    #define GB_B_IS_PATTERN 0
    #undef  GB_DECLAREB
    #define GB_DECLAREB(bkj)                                        \
        GB_void bkj [GB_VLA(bkj_size)] ;
    #undef  GB_GETB
    #define GB_GETB(bkj,Bx,pB,B_iso)                                \
        if (!B_is_pattern) cast_B (bkj, Bx +((B_iso) ? 0:(pB)*bsize), bsize)

    // instead of GB_DECLARE_TERMINAL_CONST (zterminal):
    GB_void *restrict zterminal = (GB_void *) add->terminal ;
    GB_void *restrict zidentity = (GB_void *) add->identity ;

    // define cij for each task
    #undef  GB_DECLARE_IDENTITY
    #define GB_DECLARE_IDENTITY(cij)            \
        GB_void cij [GB_VLA(zsize)] ;           \
        memcpy (cij, zidentity, zsize)

    // Cx [p] = cij (note csize == zsize)
    #undef  GB_PUTC
    #define GB_PUTC(cij,Cx,p) memcpy (Cx +((p)*csize), cij, csize)

    // break if cij reaches the terminal value
    #undef  GB_IF_TERMINAL_BREAK
    #define GB_IF_TERMINAL_BREAK(z,zterminal)                       \
        if (is_terminal && memcmp (z, zterminal, zsize) == 0)       \
        {                                                           \
            break ;                                                 \
        }
    #undef  GB_TERMINAL_CONDITION
    #define GB_TERMINAL_CONDITION(z,zterminal)                      \
        (is_terminal && memcmp (z, zterminal, zsize) == 0)

    // C(i,j) += (A')(i,k) * B(k,j)
    #undef  GB_MULTADD
    #define GB_MULTADD(cij, aki, bkj, i, k, j)                      \
        GB_void zwork [GB_VLA(zsize)] ;                             \
        GB_MULT (zwork, aki, bkj, i, k, j) ;                        \
        fadd (cij, cij, zwork)

    // generic types for C and Z
    #undef  GB_C_TYPE
    #define GB_C_TYPE GB_void

    #undef  GB_Z_TYPE
    #define GB_Z_TYPE GB_void

    if (opcode == GB_FIRST_binop_code)
    { 
        // t = A(i,k)
        // fmult is not used and can be NULL (for user-defined types)
        ASSERT (!flipxy) ;
        ASSERT (B_is_pattern) ;
        #undef  GB_MULT
        #define GB_MULT(t, aik, bkj, i, k, j) memcpy (t, aik, zsize)
        #if defined ( GB_DOT2_GENERIC )
        #include "mxm/template/GB_AxB_dot2_meta.c"
        #elif defined ( GB_DOT3_GENERIC )
        #include "mxm/template/GB_AxB_dot3_meta.c"
        #endif
    }
    else if (opcode == GB_SECOND_binop_code)
    { 
        // t = B(i,k)
        // fmult is not used and can be NULL (for user-defined types)
        ASSERT (!flipxy) ;
        ASSERT (A_is_pattern) ;
        #undef  GB_MULT
        #define GB_MULT(t, aik, bkj, i, k, j) memcpy (t, bkj, zsize)
        #if defined ( GB_DOT2_GENERIC )
        #include "mxm/template/GB_AxB_dot2_meta.c"
        #elif defined ( GB_DOT3_GENERIC )
        #include "mxm/template/GB_AxB_dot3_meta.c"
        #endif
    }
    else if (fmult != NULL)
    {
        // standard binary op
        if (flipxy)
        { 
            // t = B(k,j) * (A')(i,k)
            #undef  GB_MULT
            #define GB_MULT(t, aki, bkj, i, k, j) fmult (t, bkj, aki)
            #if defined ( GB_DOT2_GENERIC )
            #include "mxm/template/GB_AxB_dot2_meta.c"
            #elif defined ( GB_DOT3_GENERIC )
            #include "mxm/template/GB_AxB_dot3_meta.c"
            #endif
        }
        else
        { 
            // t = (A')(i,k) * B(k,j)
            #undef  GB_MULT
            #define GB_MULT(t, aki, bkj, i, k, j) fmult (t, aki, bkj)
            #if defined ( GB_DOT2_GENERIC )
            #include "mxm/template/GB_AxB_dot2_meta.c"
            #elif defined ( GB_DOT3_GENERIC )
            #include "mxm/template/GB_AxB_dot3_meta.c"
            #endif
        }
    }
    else
    {
        // index binary op
        ASSERT (fmult_idx != NULL) ;
        const void *theta = mult->theta ;
        if (flipxy)
        { 
            // t = B(k,j) * (A')(i,k)
            #undef  GB_MULT
            #define GB_MULT(t, aki, bkj, i, k, j) \
                fmult_idx (t, bkj, j, k, aki, k, i, theta)
            #if defined ( GB_DOT2_GENERIC )
            #include "mxm/template/GB_AxB_dot2_meta.c"
            #elif defined ( GB_DOT3_GENERIC )
            #include "mxm/template/GB_AxB_dot3_meta.c"
            #endif
        }
        else
        { 
            // t = (A')(i,k) * B(k,j)
            #undef  GB_MULT
            #define GB_MULT(t, aki, bkj, i, k, j) \
                fmult_idx (t, aki, i, k, bkj, k, j, theta)
            #if defined ( GB_DOT2_GENERIC )
            #include "mxm/template/GB_AxB_dot2_meta.c"
            #elif defined ( GB_DOT3_GENERIC )
            #include "mxm/template/GB_AxB_dot3_meta.c"
            #endif
        }
    }
}