// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
// This is similar to blaspp/test/check_gemm.hh,
// except it uses LAPACK++ instead of calling Fortran LAPACK.
#ifndef CHECK_GEMM_HH
#define CHECK_GEMM_HH
#include "blas/util.hh"
#include "lapack.hh"
#include <limits>
// -----------------------------------------------------------------------------
// Computes error for multiplication with general matrix result.
// Covers dot, gemv, ger, geru, gemm, symv, hemv, symm, trmv, trsv?, trmm, trsm?.
// Cnorm is norm of original C, before multiplication operation.
template< typename T >
void check_gemm(
int64_t m, int64_t n, int64_t k,
T alpha,
T beta,
blas::real_type<T> Anorm,
blas::real_type<T> Bnorm,
blas::real_type<T> Cnorm,
T const* Cref, int64_t ldcref,
T* C, int64_t ldc,
blas::real_type<T> error[1],
int64_t* okay )
{
#define C(i_, j_) C[ (i_) + (j_)*ldc ]
#define Cref(i_, j_) Cref[ (i_) + (j_)*ldcref ]
using real_t = blas::real_type<T>;
require( m >= 0 );
require( n >= 0 );
require( k >= 0 );
require( ldc >= m );
require( ldcref >= m );
// C -= Cref
for (int64_t j = 0; j < n; ++j) {
for (int64_t i = 0; i < m; ++i) {
C(i,j) -= Cref(i,j);
}
}
error[0] = lapack::lange( lapack::Norm::Fro, m, n, C, ldc )
/ (sqrt(real_t(k)+2)*std::abs(alpha)*Anorm*Bnorm + 2*std::abs(beta)*Cnorm);
// Allow 3*eps; complex needs 2*sqrt(2) factor; see Higham, 2002, sec. 3.6.
real_t eps = std::numeric_limits< real_t >::epsilon();
*okay = (error[0] < 3*eps);
#undef C
#undef Cref
}
// -----------------------------------------------------------------------------
// Computes error for multiplication with symmetric or Hermitian matrix result.
// Covers syr, syr2, syrk, syr2k, her, her2, herk, her2k.
//
// C and Cref are indexed column-major with leading dimensions ldc and ldcref.
// Only the triangle selected by uplo (including the diagonal) is read and
// updated. On entry, C holds the computed result and Cref the reference
// result. On exit, that triangle of C is overwritten with C - Cref, and
//
//     error[0] = || C - Cref ||_F
//              / (sqrt(k + 2) |alpha| Anorm Bnorm + 2 |beta| Cnorm),
//
// where the Frobenius norm is computed by lapack::lanhe on the uplo triangle.
// *okay is set to 1 if error[0] < 3*eps, else 0, where eps is the machine
// epsilon of the real type underlying T.
//
// Cnorm is norm of original C, before multiplication operation.
// Anorm and Bnorm enter only as scale factors in the denominator
// (presumably norms of the input operands; confirm against callers).
//
// If both the difference norm and the denominator are exactly zero
// (e.g., n == 0, or alpha == 0 and beta == 0 with an exact result),
// error[0] is set to 0 instead of the NaN that 0/0 would produce.
//
// alpha and beta are either real or complex, depending on routine:
//          zher    zher2    zherk   zher2k   zsyr     zsyr2    zsyrk    zsyr2k
// alpha    real    complex  real    complex  complex  complex  complex  complex
// beta     --      --       real    real     --       --       complex  complex
// zsyr2 doesn't exist in standard BLAS or LAPACK.
template< typename TA, typename TB, typename T >
void check_herk(
    blas::Uplo uplo,
    int64_t n, int64_t k,
    TA alpha,
    TB beta,
    blas::real_type<T> Anorm,
    blas::real_type<T> Bnorm,
    blas::real_type<T> Cnorm,
    T const* Cref, int64_t ldcref,
    T* C, int64_t ldc,
    blas::real_type<T> error[1],
    int64_t* okay )
{
    using real_t = blas::real_type<T>;

    require( n >= 0 );
    require( k >= 0 );
    require( ldc >= n );
    require( ldcref >= n );

    // C -= Cref, restricted to the uplo triangle.
    if (uplo == blas::Uplo::Lower) {
        for (int64_t j = 0; j < n; ++j) {
            for (int64_t i = j; i < n; ++i) {
                C[ i + j*ldc ] -= Cref[ i + j*ldcref ];
            }
        }
    }
    else {
        for (int64_t j = 0; j < n; ++j) {
            for (int64_t i = 0; i <= j; ++i) {
                C[ i + j*ldc ] -= Cref[ i + j*ldcref ];
            }
        }
    }

    auto diff_norm = lapack::lanhe( lapack::Norm::Fro, uplo, n, C, ldc );
    auto scale = sqrt(real_t(k)+2)*std::abs(alpha)*Anorm*Bnorm
               + 2*std::abs(beta)*Cnorm;

    if (diff_norm == 0 && scale == 0) {
        // Exact result with zero scale: 0/0 would be NaN, and NaN < 3*eps
        // is false, which would report a spurious failure.
        error[0] = 0;
    }
    else {
        // A nonzero difference with zero scale still yields inf and fails.
        error[0] = diff_norm / scale;
    }

    // Allow 3*eps. Higham, 2002, sec. 3.6, gives an extra 2*sqrt(2) factor
    // for complex arithmetic; that factor is not applied here, so real and
    // complex use the same bound.
    real_t eps = std::numeric_limits< real_t >::epsilon();
    *okay = (error[0] < 3*eps);
}
#endif // #ifndef CHECK_GEMM_HH