File: check_gehrd.hh

package info (click to toggle)
lapackpp 2024.10.26-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,500 kB
  • sloc: cpp: 80,181; ansic: 27,660; python: 4,838; xml: 182; perl: 99; makefile: 53; sh: 23
file content (59 lines) | stat: -rw-r--r-- 2,363 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include "blas.hh"
#include "lapack.hh"

#include <vector>

// -----------------------------------------------------------------------------
// Given original A and factored H as the output of gehrd, computes:
// results[0] = || A - U H U^H || / (n ||A||)
// results[1] = || I - U U^H || / n
// where U is generated by unghr from the Householder reflectors that gehrd
// stores below the first sub-diagonal of H (together with tau), and the H in
// the residual is the upper Hessenberg part of the factored matrix.
// All norms are one-norms (lange with Norm::One).
//
// n        order of the square matrices A and H.
// A        original n-by-n matrix, leading dimension lda.
// H        n-by-n output of gehrd (Hessenberg matrix plus reflectors),
//          leading dimension ldh.
// tau      reflector scalars as returned by gehrd.
// verbose  not used by this routine.
// results  output: the two scaled residuals described above.
//          For n <= 0 both are set to 0.
template< typename scalar_t >
void check_gehrd(
    int64_t n,
    scalar_t const* A, int64_t lda,
    scalar_t const* H, int64_t ldh,
    scalar_t const* tau,
    int64_t verbose,
    blas::real_type< scalar_t > results[2] )
{
    using namespace blas;
    using real_t = blas::real_type< scalar_t >;

    // Empty matrix: nothing to check. Returning early avoids indexing the
    // (empty) workspace vectors below and dividing the residuals by n = 0.
    if (n <= 0) {
        results[0] = 0;
        results[1] = 0;
        return;
    }

    // Workspaces are n-by-n stored with leading dimension lda.
    size_t size_A = size_t( lda ) * size_t( n );
    std::vector< scalar_t > U( size_A ), H2( size_A ), work( size_A );

    // H2 = Hessenberg part of H. H is read with its own leading dimension
    // ldh (not lda). The entries below the first sub-diagonal (the lower
    // triangle of the (n-2)-by-(n-2) block starting at row 2, column 0) are
    // zeroed; that block only exists for n > 2, and laset rejects negative
    // dimensions.
    lapack::lacpy( lapack::MatrixType::General, n, n, H, ldh, &H2[0], lda );
    if (n > 2) {
        lapack::laset( lapack::MatrixType::Lower, n-2, n-2, 0.0, 0.0,
                       &H2[2], lda );
    }

    // Generate U from the reflectors stored in H (again read with ldh).
    lapack::lacpy( lapack::MatrixType::General, n, n, H, ldh, &U[0], lda );
    lapack::unghr( n, 1, n, &U[0], lda, tau );

    // work = U H2
    blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans, n, n, n,
                1.0, &U[0], lda,
                     &H2[0], lda,
                0.0, &work[0], lda );

    // H2 = A - (U H2) U^H
    lapack::lacpy( lapack::MatrixType::General, n, n, A, lda, &H2[0], lda );
    blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, n, n, n,
                -1.0, &work[0], lda,
                      &U[0], lda,
                 1.0, &H2[0], lda );
    real_t Anorm = lapack::lange( lapack::Norm::One, n, n, A, lda );
    real_t error = lapack::lange( lapack::Norm::One, n, n, &H2[0], lda );
    results[0] = error / Anorm / n;

    // work = I - U U^H  (orthogonality / unitarity check)
    lapack::laset( lapack::MatrixType::General, n, n, 0.0, 1.0, &work[0], lda );
    blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, n, n, n,
                -1.0, &U[0], lda,
                      &U[0], lda,
                 1.0, &work[0], lda );
    error = lapack::lange( lapack::Norm::One, n, n, &work[0], lda );
    results[1] = error / n;
}