File: libxsmm_spmdm.h

package info (click to toggle)
libxsmm 1.9-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 9,800 kB
  • sloc: ansic: 70,040; fortran: 5,281; makefile: 3,333; cpp: 3,185; sh: 2,136; f90: 1,763; pascal: 1,469; python: 762
file content (134 lines) | stat: -rw-r--r-- 5,464 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/******************************************************************************
** Copyright (c) 2016-2018, Intel Corporation                                **
** All rights reserved.                                                      **
**                                                                           **
** Redistribution and use in source and binary forms, with or without        **
** modification, are permitted provided that the following conditions        **
** are met:                                                                  **
** 1. Redistributions of source code must retain the above copyright         **
**    notice, this list of conditions and the following disclaimer.          **
** 2. Redistributions in binary form must reproduce the above copyright      **
**    notice, this list of conditions and the following disclaimer in the    **
**    documentation and/or other materials provided with the distribution.   **
** 3. Neither the name of the copyright holder nor the names of its          **
**    contributors may be used to endorse or promote products derived        **
**    from this software without specific prior written permission.          **
**                                                                           **
** THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS       **
** "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT         **
** LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR     **
** A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT      **
** HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,    **
** SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED  **
** TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR    **
** PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF    **
** LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING      **
** NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS        **
** SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.              **
******************************************************************************/
/* Nadathur Satish (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_SPMDM_H
#define LIBXSMM_SPMDM_H

#include "libxsmm_macros.h"


/**
 * Element datatype selector stored in libxsmm_spmdm_handle::datatype.
 * The enumerator values are written out explicitly; they are the same values
 * the C default numbering assigns (first enumerator 0, then increasing by one).
 */
typedef enum libxsmm_spmdm_datatype {
  LIBXSMM_SPMDM_DATATYPE_F32 = 0,
  LIBXSMM_SPMDM_DATATYPE_BFLOAT16 = 1
} libxsmm_spmdm_datatype;

/**
 * Handle describing one sparse-matrix times dense-matrix multiply (SpMDM) problem:
 * problem dimensions, block sizes, block counts, element datatype, and scratch-related fields.
 * NOTE(review): presumably populated by libxsmm_spmdm_init and released by libxsmm_spmdm_destroy;
 * neither implementation is visible in this header -- confirm before relying on it.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_spmdm_handle {
  /* The following are the matrix multiply dimensions: A (sparse): m X k, B (dense): k X n, Output C (dense): m X n */
  int m;
  int n;
  int k;
  /* The block sizes for A, B and C. */
  /* Here we fix A to be divided into 128 X 128 blocks, B/C to be 128 X 48 for HSW/BDW and 128 X 96 for SKX */
  int bm;
  int bn;
  int bk;
  /* The number of blocks for the m, n and k dimensions */
  int mb;
  int nb;
  int kb;
  /* Element datatype selector (one of the libxsmm_spmdm_datatype enumerators). */
  libxsmm_spmdm_datatype datatype;
  /* NOTE(review): by name, base address of scratch memory used for A; size and ownership are not shown here -- confirm. */
  char * base_ptr_scratch_A;
  /* NOTE(review): by name, base address of scratch memory shared by B and C; layout is not shown here -- confirm. */
  char * base_ptr_scratch_B_scratch_C;
  /* NOTE(review): by name, amount of scratch memory reserved per thread; the unit (presumably bytes) is not shown here -- confirm. */
  int memory_for_scratch_per_thread;
} libxsmm_spmdm_handle;

/**
 * This stores a single sparse slice (or block) of sparse matrix A using a CSR representation (rowidx, colidx, and values).
 * Each slice corresponds to a bm X bk region of A, and stores local indexes.
 * NOTE(review): the exact array lengths and which party allocates/frees the three arrays are not
 * visible in this header -- confirm against the implementation.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_CSR_sparseslice {
  /* Since bm and bk are assumed to be <=256, a 16-bit integer is enough to store the local rowidx, colidx */
  /* presumably the CSR row index array (local to the slice) -- TODO confirm layout */
  uint16_t * rowidx;
  /* presumably the local column index of each stored value -- TODO confirm layout */
  uint16_t * colidx;
  /* stored values, kept as float */
  float*     values;
} libxsmm_CSR_sparseslice;


/**
 * Initializes an SpMDM handle for a problem of the given dimensions.
 * NOTE(review): the behavior below is inferred from the signature only; the implementation is not
 * visible in this header -- confirm.
 *
 * @param M, N, K            Problem dimensions; presumably A is M X K, B is K X N and C is M X N
 *                           (matching the m/n/k fields of libxsmm_spmdm_handle).
 * @param max_threads        Presumably the maximum number of threads that will use the handle
 *                           (e.g., for sizing per-thread scratch) -- confirm.
 * @param handle             Handle to set up (non-const, so written by this call).
 * @param libxsmm_output_csr Output parameter: pointer to a libxsmm_CSR_sparseslice pointer; presumably
 *                           receives the array of sparse slices to be used with the other
 *                           libxsmm_spmdm_* calls -- confirm ownership and how it must be released.
 *
 * Returns nothing; how failures (e.g., allocation errors) are reported is not visible here.
 */
LIBXSMM_API void libxsmm_spmdm_init(
  int M, int N, int K,
  int max_threads,
  libxsmm_spmdm_handle* handle,
  libxsmm_CSR_sparseslice** libxsmm_output_csr);

/**
 * Tears down an SpMDM handle.
 * NOTE(review): presumably releases the scratch memory referenced by the handle; whether the
 * sparse-slice array handed out by libxsmm_spmdm_init is also released is not visible here -- confirm.
 *
 * @param handle Handle to destroy (non-const, so it may be modified by this call).
 */
LIBXSMM_API void libxsmm_spmdm_destroy(
  libxsmm_spmdm_handle * handle);

/**
 * Queries a block count associated with sparse-slice creation for the given handle.
 * NOTE(review): presumably the number of distinct block_id values accepted by the
 * libxsmm_spmdm_createSparseSlice_*_thread functions -- confirm against the implementation.
 *
 * @param handle Handle to query (read-only).
 * @return Number of blocks, as an int.
 */
LIBXSMM_API int libxsmm_spmdm_get_num_createSparseSlice_blocks(
  const libxsmm_spmdm_handle* handle);

/**
 * Queries a block count associated with the compute phase for the given handle.
 * NOTE(review): presumably the number of distinct block_id values accepted by the
 * libxsmm_spmdm_compute_*_thread functions -- confirm against the implementation.
 *
 * @param handle Handle to query (read-only).
 * @return Number of blocks, as an int.
 */
LIBXSMM_API int libxsmm_spmdm_get_num_compute_blocks(
  const libxsmm_spmdm_handle* handle);

/**
 * This converts a dense representation of the sparse matrix to a 2D array of sparse slices.
 * Single-precision (float) input variant.
 *
 * @param handle               SpMDM handle (read-only).
 * @param transA               Transpose selector for A; the accepted character values are not
 *                             visible in this header (presumably BLAS-style 'N'/'T') -- confirm.
 * @param A                    Dense input matrix holding the sparse data, as float (read-only).
 * @param libxsmm_output_csr_a Destination sparse slices (written by this call).
 * @param block_id             Identifies the block to process; presumably in the range reported by
 *                             libxsmm_spmdm_get_num_createSparseSlice_blocks -- confirm.
 * @param tid                  Presumably the calling thread's index -- confirm.
 * @param nthreads             Presumably the total number of participating threads -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_createSparseSlice_fp32_thread(
  const libxsmm_spmdm_handle* handle,
  char transA,
  const float * A,
  libxsmm_CSR_sparseslice* libxsmm_output_csr_a,
  int block_id,
  int tid, int nthreads);

/**
 * Variant of sparse-slice creation whose dense input A is passed as 16-bit unsigned integers.
 * NOTE(review): by name, each uint16_t presumably carries a bfloat16 bit pattern; the conversion
 * into the float values stored in libxsmm_CSR_sparseslice is not visible here -- confirm.
 *
 * @param handle               SpMDM handle (read-only).
 * @param transA               Transpose selector for A; the accepted character values are not
 *                             visible in this header (presumably BLAS-style 'N'/'T') -- confirm.
 * @param A                    Dense input matrix holding the sparse data, as uint16_t (read-only).
 * @param libxsmm_output_csr_a Destination sparse slices (written by this call).
 * @param block_id             Identifies the block to process; presumably in the range reported by
 *                             libxsmm_spmdm_get_num_createSparseSlice_blocks -- confirm.
 * @param tid                  Presumably the calling thread's index -- confirm.
 * @param nthreads             Presumably the total number of participating threads -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_createSparseSlice_bfloat16_thread(
  const libxsmm_spmdm_handle* handle,
  char transA,
  const uint16_t * A,
  libxsmm_CSR_sparseslice* libxsmm_output_csr_a,
  int block_id,
  int tid, int nthreads);

/**
 * Sparse (A) times dense (B) matrix multiply into dense C, single-precision (float) variant.
 * NOTE: This code currently ignores alpha input to the matrix multiply.
 * NOTE(review): presumably computes C = op(A) * op(B) + beta * C for the block selected by
 * block_id; the exact formula and storage layouts are not visible in this header -- confirm.
 *
 * @param handle   SpMDM handle (read-only).
 * @param transA   Transpose selector for A; accepted values not visible here -- confirm.
 * @param transB   Transpose selector for B; accepted values not visible here -- confirm.
 * @param alpha    Pointer to the alpha scalar (currently ignored, see note above).
 * @param A_sparse Sparse slices of A (presumably as produced by the createSparseSlice functions).
 * @param B        Dense input matrix B, as float (read-only).
 * @param transC   Transpose selector for C; accepted values not visible here -- confirm.
 * @param beta     Pointer to the beta scalar (read-only).
 * @param C        Dense output matrix C, as float (written by this call).
 * @param block_id Identifies the block to process; presumably in the range reported by
 *                 libxsmm_spmdm_get_num_compute_blocks -- confirm.
 * @param tid      Presumably the calling thread's index -- confirm.
 * @param nthreads Presumably the total number of participating threads -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_compute_fp32_thread(
  const libxsmm_spmdm_handle* handle,
  char transA,
  char transB,
  const float *alpha,
  libxsmm_CSR_sparseslice* A_sparse,
  const float *B,
  char transC,
  const float *beta,
  float* C,
  int block_id,
  int tid, int nthreads);

/**
 * Sparse (A) times dense (B) matrix multiply into dense C, variant taking 16-bit inputs.
 * B, alpha and beta are passed as uint16_t, while the output C is float.
 * NOTE: This code currently ignores alpha input to the matrix multiply.
 * NOTE(review): by name, the uint16_t inputs presumably carry bfloat16 bit patterns, and the call
 * presumably computes C = op(A) * op(B) + beta * C for the block selected by block_id; neither is
 * visible in this header -- confirm.
 *
 * @param handle   SpMDM handle (read-only).
 * @param transA   Transpose selector for A; accepted values not visible here -- confirm.
 * @param transB   Transpose selector for B; accepted values not visible here -- confirm.
 * @param alpha    Pointer to the alpha scalar, as uint16_t (currently ignored, see note above).
 * @param A_sparse Sparse slices of A (presumably as produced by the createSparseSlice functions).
 * @param B        Dense input matrix B, as uint16_t (read-only).
 * @param transC   Transpose selector for C; accepted values not visible here -- confirm.
 * @param beta     Pointer to the beta scalar, as uint16_t (read-only).
 * @param C        Dense output matrix C, as float (written by this call).
 * @param block_id Identifies the block to process; presumably in the range reported by
 *                 libxsmm_spmdm_get_num_compute_blocks -- confirm.
 * @param tid      Presumably the calling thread's index -- confirm.
 * @param nthreads Presumably the total number of participating threads -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_compute_bfloat16_thread(
  const libxsmm_spmdm_handle* handle,
  char transA,
  char transB,
  const uint16_t *alpha,
  libxsmm_CSR_sparseslice* A_sparse,
  const uint16_t *B,
  char transC,
  const uint16_t *beta,
  float* C,
  int block_id,
  int tid, int nthreads);

#endif /*LIBXSMM_SPMDM_H*/