1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
//------------------------------------------------------------------------------
// GB_mex_band: C = tril (triu (A,lo), hi), or with A'
//------------------------------------------------------------------------------
// SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2022, All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
//------------------------------------------------------------------------------
// Apply a select operator to a matrix
#include "GB_mex.h"
#define USAGE "C = GB_mex_band (A, lo, hi, atranspose)"

// FREE_ALL: release every GraphBLAS object that mexFunction may have created,
// then call GB_mx_put_global (true).  Each handle is released with the free
// method of its own object type.  Thunk_type is a GrB_Type, so it must be
// freed with GrB_Type_free_ (not GrB_Scalar_free_, which expects a
// GrB_Scalar handle).
#define FREE_ALL                        \
{                                       \
    GrB_Scalar_free_(&Thunk) ;          \
    GrB_Matrix_free_(&C) ;              \
    GrB_Matrix_free_(&A) ;              \
    GrB_Type_free_(&Thunk_type) ;       \
    GxB_SelectOp_free_(&op) ;           \
    GrB_Descriptor_free_(&desc) ;       \
    GB_mx_put_global (true) ;           \
}
// OK(method): evaluate a GraphBLAS call exactly once, storing its status in
// the caller's local variable `info`.  Any status other than GrB_SUCCESS
// releases all objects via FREE_ALL and then reports the failure through
// mexErrMsgTxt.
#define OK(method)                                  \
{                                                   \
    if ((info = (method)) != GrB_SUCCESS)           \
    {                                               \
        /* clean up before reporting the error */   \
        FREE_ALL ;                                  \
        mexErrMsgTxt ("GraphBLAS failed") ;         \
    }                                               \
}
// LoHi_type: the user-defined thunk value passed to the band select operator.
// An entry A(i,j) is kept when lo <= (j-i) <= hi.
typedef struct
{
    int64_t lo ;    // lowest diagonal offset kept (j-i >= lo)
    int64_t hi ;    // highest diagonal offset kept (j-i <= hi)
}
LoHi_type ;
// LoHi_band: select-operator predicate for a diagonal band.  Entry A(i,j) is
// kept when its diagonal offset k = j-i satisfies thunk->lo <= k <= thunk->hi.
// The entry value x is not examined.  The prototype is declared ahead of the
// definition so the function has a prototype in scope.
bool LoHi_band (GrB_Index i, GrB_Index j,
    /* x is unused: */ const void *x, const LoHi_type *thunk) ;

bool LoHi_band (GrB_Index i, GrB_Index j,
    /* x is unused: */ const void *x, const LoHi_type *thunk)
{
    // signed diagonal offset: 0 on the diagonal, > 0 above, < 0 below
    const int64_t k = (int64_t) j - (int64_t) i ;
    if (k < thunk->lo)
    {
        return (false) ;
    }
    return (k <= thunk->hi) ;
}
//------------------------------------------------------------------------------
// mexFunction: C = GB_mex_band (A, lo, hi, atranspose)
//------------------------------------------------------------------------------

// Inputs:  pargin [0] = A, pargin [1] = lo, pargin [2] = hi, and optionally
//          pargin [3] = atranspose (default false).
// Output:  pargout [0] = C, holding the entries of A (or of A' when
//          atranspose is true) whose diagonal offset j-i lies in [lo,hi].
//          C is computed by GxB_Matrix_select (or GxB_Vector_select for a
//          single-column result) with the user-defined LoHi_band operator.
//          The [lo,hi] pair is passed in a user-defined GrB_Scalar thunk.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{
    // NOTE(review): malloc_debug, GET_DEEP_COPY and FREE_DEEP_COPY are not
    // referenced directly below.  Presumably the METHOD macro (GB_mex.h) uses
    // them, so do not remove them.
    bool malloc_debug = GB_mx_get_global (true) ;
    GrB_Matrix C = NULL ;
    GrB_Matrix A = NULL ;
    GxB_SelectOp op = NULL ;
    GrB_Info info ;
    GrB_Descriptor desc = NULL ;
    GrB_Scalar Thunk = NULL ;
    GrB_Type Thunk_type = NULL ;
    #define GET_DEEP_COPY ;
    #define FREE_DEEP_COPY ;

    // check inputs
    if (nargout > 1 || nargin < 3 || nargin > 4)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [0], "A input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // create the Thunk: a user-defined scalar holding the [lo,hi] pair
    LoHi_type bandwidth ;
    OK (GrB_Type_new (&Thunk_type, sizeof (LoHi_type))) ;
    bandwidth.lo = (int64_t) mxGetScalar (pargin [1]) ;
    bandwidth.hi = (int64_t) mxGetScalar (pargin [2]) ;
    OK (GrB_Scalar_new (&Thunk, Thunk_type)) ;
    OK (GrB_Scalar_setElement_UDT (Thunk, (void *) &bandwidth)) ;
    OK (GrB_Scalar_wait_(Thunk, GrB_MATERIALIZE)) ;

    // get atranspose; when true, the descriptor makes the select use A'
    bool atranspose = false ;
    if (nargin > 3) atranspose = (bool) mxGetScalar (pargin [3]) ;
    if (atranspose)
    {
        OK (GrB_Descriptor_new (&desc)) ;
        OK (GxB_Desc_set (desc, GrB_INP0, GrB_TRAN)) ;
    }

    // create the user-defined select operator from the LoHi_band function
    METHOD (GxB_SelectOp_new (&op, (GxB_select_function) LoHi_band,
        NULL, Thunk_type)) ;

    // get the dimensions of A; check the status rather than ignoring it
    GrB_Index nrows, ncols ;
    OK (GrB_Matrix_nrows (&nrows, A)) ;
    OK (GrB_Matrix_ncols (&ncols, A)) ;

    // print the operator for one specific case (lo = hi = 0 on a 10-by-10 A)
    if (bandwidth.lo == 0 && bandwidth.hi == 0 && nrows == 10 && ncols == 10)
    {
        GxB_SelectOp_fprint_ (op, 3, NULL) ;
    }

    // create result matrix C, with dimensions swapped when A is transposed
    if (atranspose)
    {
        OK (GrB_Matrix_new (&C, GrB_FP64, A->vdim, A->vlen)) ;
    }
    else
    {
        OK (GrB_Matrix_new (&C, GrB_FP64, A->vlen, A->vdim)) ;
    }

    // C = select (op, A or A', Thunk); no mask and no accumulator are used
    if (GB_NCOLS (C) == 1 && !atranspose)
    {
        // a single-column result also exercises the GrB_Vector variant
        OK (GxB_Vector_select_((GrB_Vector) C, NULL, NULL, op, (GrB_Vector) A,
            Thunk, NULL)) ;
    }
    else
    {
        OK (GxB_Matrix_select_(C, NULL, NULL, op, A, Thunk, desc)) ;
    }

    // return C as a sparse matrix and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output", false) ;
    FREE_ALL ;
}
|