File: test_blas.cpp

package info (click to toggle)
faiss 1.12.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 8,572 kB
  • sloc: cpp: 85,627; python: 27,889; sh: 905; ansic: 425; makefile: 41
file content (110 lines) | stat: -rw-r--r-- 2,479 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstdio>
#include <cstdlib>
#include <random>

#undef FINTEGER
#define FINTEGER long

extern "C" {

/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

/* Single-precision general matrix multiply, Fortran calling convention
 * (every argument passed by pointer):
 *
 *     C = alpha * op(A) * op(B) + beta * C
 *
 * where op(X) is X or X^T depending on transa / transb, op(A) is m x k,
 * op(B) is k x n and C is m x n, all column-major with leading dimensions
 * lda, ldb and ldc.
 *
 * Integer arguments are declared as FINTEGER (see the #define above) so this
 * file can probe whether the linked library expects 32- or 64-bit integers.
 * alpha and beta are read-only inputs, hence both are const. */
int sgemm_(
        const char* transa,
        const char* transb,
        FINTEGER* m,
        FINTEGER* n,
        FINTEGER* k,
        const float* alpha,
        const float* a,
        FINTEGER* lda,
        const float* b,
        FINTEGER* ldb,
        const float* beta,
        float* c,
        FINTEGER* ldc);

/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */

/* QR factorization of the m x n column-major matrix a (leading dimension
 * lda), overwritten in place. tau receives the scalar factors of the
 * elementary reflectors.
 *
 * If *lwork == -1, only a workspace-size query is performed and the
 * optimal size is returned in work[0].
 *
 * *info is 0 on success and -i if argument i is invalid. */
int sgeqrf_(
        FINTEGER* m,
        FINTEGER* n,
        float* a,
        FINTEGER* lda,
        float* tau,
        float* work,
        FINTEGER* lwork,
        FINTEGER* info);
}

/* Allocates an array of `size` floats filled with pseudo-random values drawn
 * uniformly from [0, 1). Values are generated in double precision and then
 * rounded to float, so an entry may round up to exactly 1.0f.
 *
 * seed: seed for the std::mt19937 engine.
 *   - The default is the engine's own default seed, so callers that omit it
 *     get exactly the same values as before.
 *   - With the default, every call yields a prefix of one fixed sequence;
 *     pass distinct seeds to obtain independent arrays.
 *
 * The caller owns the returned array and must release it with delete[]. */
float* new_random_vec(
        int size,
        std::mt19937::result_type seed = std::mt19937::default_seed) {
    float* x = new float[size];
    std::mt19937 rng(seed);
    std::uniform_real_distribution<> distrib; // double-valued, range [0, 1)
    for (int i = 0; i < size; i++)
        x[i] = static_cast<float>(distrib(rng));
    return x;
}

/* Sanity check for the linked BLAS / LAPACK.
 *
 * 1. Multiplies two random matrices with sgemm_ and prints, for each output
 *    entry, the difference between a naive triple-loop product and the BLAS
 *    result. All printed values should be close to 0.
 * 2. Calls sgeqrf_ with integer arguments whose upper 32 bits hold a marker
 *    (0x64b) to detect whether LAPACK reads 32-bit or 64-bit INTEGERs.
 *
 * NOTE(review): the marker trick assumes `long` is 64 bits wide and a
 * little-endian layout, so that a 32-bit LAPACK only reads/writes the low
 * half of each `long`; confirm on the target platform.
 *
 * All heap buffers are released before returning. */
int main() {
    FINTEGER m = 10, n = 20, k = 30;
    // Column-major operands: a is m x k, b is k x n, c is m x n.
    float* a = new_random_vec(m * k);
    float* b = new_random_vec(n * k);
    float* c = new float[n * m];
    float one = 1.0f, zero = 0.0f;

    printf("BLAS test\n");

    // c = 1 * a * b + 0 * c. BLAS only inspects the first character ('N')
    // of the transpose arguments.
    sgemm_("Not transposed",
           "Not transposed",
           &m,
           &n,
           &k,
           &one,
           a,
           &m,
           b,
           &k,
           &zero,
           c,
           &m);

    printf("errors=\n");

    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            float accu = 0;
            for (int l = 0; l < k; l++)
                accu += a[i + l * m] * b[l + j * k];
            printf("%6.3f ", accu - c[i + j * m]);
        }
        printf("\n");
    }

    // Marker in the upper 32 bits of info:
    //   - a 32-bit LAPACK only overwrites the low half, so the marker survives;
    //   - a 64-bit LAPACK overwrites the whole value.
    long info = 0x64bL << 32;
    // Same marker on M:
    //   - a 32-bit LAPACK sees M == m;
    //   - a 64-bit LAPACK sees a huge M (> LDA == m) and reports an
    //     argument error, which is the intentional error printed below.
    long mi = 0x64bL << 32 | m;
    float* tau = new float[m]; // min(m, n) == m here
    FINTEGER lwork = -1;       // workspace-size query only

    float work1 = 0;

    printf("Intentional Lapack error (appears only for 64-bit INTEGER):\n");
    sgeqrf_(&mi, &n, c, &m, tau, &work1, &lwork, (FINTEGER*)&info);

    printf("info=%016lx\n", info);

    if (info >> 32 == 0x64b) {
        printf("Lapack uses 32-bit integers\n");
    } else {
        printf("Lapack uses 64-bit integers\n");
    }

    delete[] tau;
    delete[] c;
    delete[] b;
    delete[] a;

    return 0;
}