File: check_heev.hh

package info (click to toggle)
lapackpp 2024.10.26-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,500 kB
  • sloc: cpp: 80,181; ansic: 27,660; python: 4,838; xml: 182; perl: 99; makefile: 53; sh: 23
file content (97 lines) | stat: -rw-r--r-- 3,515 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include "blas.hh"
#include "lapack.hh"
#include "error.hh"
#include "check_ortho.hh"
#include "scale.hh"

#include <limits>
#include <vector>

//------------------------------------------------------------------------------
// Computes error measures for a Hermitian eigenvalue decomposition
// A Z = Z Lambda, e.g., as returned by heev and related drivers.
//
// If jobz == Vec:
//     result[ 0 ] = || A - Z Lambda Z^H || / (n ||A||), if nfound == n;
//     result[ 0 ] = || Z^H A Z - Lambda || / (n ||A||), otherwise.
//     result[ 1 ] = || I - Z^H Z || / n, via check_orthogonality.
//     Norms above are one-norms computed with lanhe using triangle uplo.
//     ||A|| is clamped below by the smallest normalized real_t, so a zero
//     matrix yields a finite result instead of NaN/Inf.
//     If n == 0 or nfound == 0, both are set to 0 (nothing to check).
// If jobz != Vec, result[ 0 ] and result[ 1 ] are not written.
//
// result[ 2 ] = number of indices i with Lambda[ i ] > Lambda[ i+1 ];
//     i.e., 0 if Lambda is in non-decreasing order, else >= 1.
//
// @param[in]  jobz    Job::Vec to check eigenvectors Z as well as Lambda.
// @param[in]  uplo    Triangle of A that is referenced.
// @param[in]  n       Order of A.
// @param[in]  A       Original n-by-n Hermitian matrix, column-major, lda >= n.
// @param[in]  lda     Leading dimension of A.
// @param[in]  nfound  Number of eigenvalues (and eigenvectors) found.
// @param[in]  Lambda  The nfound computed eigenvalues.
// @param[in]  Z       n-by-nfound eigenvectors, column-major (used if jobz == Vec).
// @param[in]  ldz     Leading dimension of Z.
// @param[out] result  The error measures described above.
template< typename scalar_t >
void check_heev(
    lapack::Job jobz,
    lapack::Uplo uplo, int64_t n,
    scalar_t const* A, int64_t lda,
    int64_t nfound,
    blas::real_type< scalar_t > const* Lambda,
    scalar_t const* Z, int64_t ldz,
    blas::real_type< scalar_t > result[ 3 ] )
{
    using namespace blas;
    using namespace lapack;
    using real_t = blas::real_type< scalar_t >;

    // Constants
    const scalar_t one  = 1;
    const scalar_t zero = 0;

    if (jobz == Job::Vec) {
        if (n == 0 || nfound == 0) {
            // Empty problem: residual and orthogonality are trivially 0.
            // This also avoids 0/0 below and indexing zero-size workspaces.
            result[ 0 ] = 0;
            result[ 1 ] = 0;
        }
        else {
            // Guard against ||A|| == 0 (as LAPACK's *het21 testers do);
            // a NaN norm is left as-is so it propagates into the result.
            const real_t safe_min = std::numeric_limits< real_t >::min();
            real_t Anorm = lapack::lanhe( Norm::One, uplo, n, A, lda );
            if (Anorm < safe_min)
                Anorm = safe_min;

            // R is nfound-by-nfound, whether n == nfound or not.
            int64_t ldr = nfound;
            std::vector< scalar_t > R_( ldr*nfound );
            scalar_t* R = R_.data();

            if (n == nfound) {
                // || A - Z Lambda Z^H ||
                std::vector< scalar_t > ZLambda_( ldz*n );
                scalar_t* ZLambda = ZLambda_.data();

                // ZLambda = Z Lambda is n-by-n.
                lapack::lacpy( MatrixType::General, n, n,
                               Z, ldz,
                               ZLambda, ldz );
                col_scale( n, n, ZLambda, ldz, Lambda );
                // R = A - (Z Lambda) Z^H; could use gemmtr instead of gemm.
                // The unreferenced triangle of A is copied too, but only the
                // uplo triangle of R is read by lanhe below.
                lapack::lacpy( MatrixType::General, n, n,
                               A, lda,
                               R, ldr );
                blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans,
                            n, n, n,
                            -one, ZLambda, ldz,
                                  Z, ldz,
                            one,  R, ldr );
            }
            else {
                // || Z^H A Z - Lambda ||
                std::vector< scalar_t > AZ_( lda*nfound );
                scalar_t* AZ = AZ_.data();

                // AZ = A Z is n-by-nfound.
                blas::hemm( Layout::ColMajor, Side::Left, uplo, n, nfound,
                            one,  A, lda,
                                  Z, ldz,
                            zero, AZ, lda );
                // R = Z^H (A Z); could use gemmtr instead of gemm.
                blas::gemm( Layout::ColMajor, Op::ConjTrans, Op::NoTrans,
                            nfound, nfound, n,
                            one,  Z, ldz,
                                  AZ, lda,
                            zero, R, ldr );
                // R -= Lambda, along diagonal (stride ldr + 1).
                blas::axpy( nfound, -one, Lambda, 1, R, ldr + 1 );
            }
            result[ 0 ] = lapack::lanhe( Norm::One, uplo, nfound, R, ldr )
                        / (real_t( n ) * Anorm);

            result[ 1 ] = check_orthogonality( RowCol::Col, n, nfound, Z, ldz );
        }
    }

    // Count descents to check that Lambda is non-decreasing.
    result[ 2 ] = 0;
    for (int64_t i = 0; i < nfound - 1; ++i) {
        if (Lambda[ i ] > Lambda[ i+1 ])
            result[ 2 ] += 1;
    }
}