File: check_gemm.hh

package info (click to toggle)
blaspp 2024.10.26-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,636 kB
  • sloc: cpp: 29,332; ansic: 8,448; python: 2,192; xml: 182; perl: 101; makefile: 53; sh: 7
file content (163 lines) | stat: -rw-r--r-- 5,044 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#ifndef CHECK_GEMM_HH
#define CHECK_GEMM_HH

#include "blas/util.hh"

// Test headers.
#include "lapack_wrappers.hh"

#include <limits>

// -----------------------------------------------------------------------------
// Computes error for multiplication with general matrix result.
// Covers dot, gemv, ger, geru, gemm, symv, hemv, symm, trmv, trsv?, trmm, trsm?.
// Cnorm is norm of original C, before multiplication operation.
template <typename T>
void check_gemm(
    int64_t m, int64_t n, int64_t k,
    T alpha,
    T beta,
    blas::real_type<T> Anorm,
    blas::real_type<T> Bnorm,
    blas::real_type<T> Cnorm,
    T const* Cref, int64_t ldcref,
    T* C, int64_t ldc,
    bool verbose,
    blas::real_type<T> error[1],
    bool* okay )
{
    #define    C(i_, j_)    C[ (i_) + (j_)*ldc ]
    #define Cref(i_, j_) Cref[ (i_) + (j_)*ldcref ]

    using std::sqrt;
    using std::abs;
    using real_t = blas::real_type<T>;

    assert( m >= 0 );
    assert( n >= 0 );
    assert( k >= 0 );
    assert( ldc >= m );
    assert( ldcref >= m );

    // C -= Cref
    for (int64_t j = 0; j < n; ++j) {
        for (int64_t i = 0; i < m; ++i) {
            C(i, j) -= Cref(i, j);
        }
    }

    real_t work[1], Cout_norm;
    Cout_norm = lapack_lange( "f", m, n, C, ldc, work );
    error[0] = Cout_norm
             / (sqrt( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm
                 + 2 * abs( beta ) * Cnorm);
    if (verbose) {
        printf( "error: ||Cout||=%.2e / (sqrt(k=%lld + 2)"
                " * |alpha|=%.2e * ||A||=%.2e * ||B||=%.2e"
                " + 2 * |beta|=%.2e * ||C||=%.2e) = %.2e\n",
                Cout_norm, llong( k ),
                abs( alpha ), Anorm, Bnorm,
                abs( beta ), Cnorm, error[0] );
    }

    // complex needs extra factor; see Higham, 2002, sec. 3.6.
    if (blas::is_complex<T>::value) {
        error[0] /= 2*sqrt(2);
    }

    real_t u = 0.5 * std::numeric_limits< real_t >::epsilon();
    *okay = (error[0] < u);

    #undef C
    #undef Cref
}

// -----------------------------------------------------------------------------
// Computes error for multiplication with symmetric or Hermitian matrix result.
// Covers syr, syr2, syrk, syr2k, her, her2, herk, her2k.
// Cnorm is norm of original C, before multiplication operation.
//
// alpha and beta are either real or complex, depending on routine:
//          zher    zher2   zherk   zher2k  zsyr    zsyr2   zsyrk   zsyr2k
// alpha    real    complex real    complex complex complex complex complex
// beta     --      --      real    real    --      --      complex complex
// zsyr2 doesn't exist in standard BLAS or LAPACK.
template <typename TA, typename TB, typename T>
void check_herk(
    blas::Uplo uplo,
    int64_t n, int64_t k,
    TA alpha,
    TB beta,
    blas::real_type<T> Anorm,
    blas::real_type<T> Bnorm,
    blas::real_type<T> Cnorm,
    T const* Cref, int64_t ldcref,
    T* C, int64_t ldc,
    bool verbose,
    blas::real_type<T> error[1],
    bool* okay )
{
    #define    C(i_, j_)    C[ (i_) + (j_)*ldc ]
    #define Cref(i_, j_) Cref[ (i_) + (j_)*ldcref ]

    using std::sqrt;
    using std::abs;
    typedef blas::real_type<T> real_t;

    assert( n >= 0 );
    assert( k >= 0 );
    assert( ldc >= n );
    assert( ldcref >= n );

    // C -= Cref
    if (uplo == blas::Uplo::Lower) {
        for (int64_t j = 0; j < n; ++j) {
            for (int64_t i = j; i < n; ++i) {
                C(i, j) -= Cref(i, j);
            }
        }
    }
    else {
        for (int64_t j = 0; j < n; ++j) {
            for (int64_t i = 0; i <= j; ++i) {
                C(i, j) -= Cref(i, j);
            }
        }
    }

    // For a rank-2k update, this should be
    // sqrt(k+3) |alpha| (norm(A)*norm(B^T) + norm(B)*norm(A^T))
    //     + 3 |beta| norm(C)
    // However, so far using the same bound as rank-k works fine.
    real_t work[1], Cout_norm;
    Cout_norm = lapack_lanhe( "f", to_c_string( uplo ), n, C, ldc, work );
    error[0] = Cout_norm
             / (sqrt( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm
                 + 2 * abs( beta ) * Cnorm);
    if (verbose) {
        printf( "error: ||Cout||=%.2e / (sqrt(k=%lld + 2)"
                " * |alpha|=%.2e * ||A||=%.2e * ||B||=%.2e"
                " + 2 * |beta|=%.2e * ||C||=%.2e) = %.2e\n",
                Cout_norm, llong( k ),
                abs( alpha ), Anorm, Bnorm,
                abs( beta ), Cnorm, error[0] );
    }

    // complex needs extra factor; see Higham, 2002, sec. 3.6.
    if (blas::is_complex<T>::value) {
        error[0] /= 2*sqrt(2);
    }

    real_t u = 0.5 * std::numeric_limits< real_t >::epsilon();
    *okay = (error[0] < u);

    #undef C
    #undef Cref
}

#endif        //  #ifndef CHECK_GEMM_HH