File: check_ortho.hh

package info (click to toggle)
lapackpp 2024.10.26-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,500 kB
  • sloc: cpp: 80,181; ansic: 27,660; python: 4,838; xml: 182; perl: 99; makefile: 53; sh: 23
file content (53 lines) | stat: -rw-r--r-- 1,714 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include "blas.hh"
#include "lapack.hh"
#include "error.hh"

#include <vector>

// -----------------------------------------------------------------------------
// Computes the orthogonality error measure:
//     || I - U^H U ||_1 / m   if rowcol == Col (cols are orthogonal; m >= n)
// or
//     || I - U U^H ||_1 / n   if rowcol == Row (rows are orthogonal; m <= n)
// Similar to LAPACK testing zunt01. The result is not scaled by machine epsilon.
//
// @param[in] rowcol
//     Col: measure orthonormality of the n columns of U (requires m >= n).
//     Row: measure orthonormality of the m rows of U    (requires m <= n).
// @param[in] m    Number of rows of U.
// @param[in] n    Number of columns of U.
// @param[in] U    The m-by-n matrix U, column-major, leading dimension ldu.
// @param[in] ldu  Leading dimension of U.
//
// @return The error measure above; 0 if U is empty (min( m, n ) == 0).
//
// @throws lapack::Error if rowcol == Row and m > n,
//                       or rowcol == Col and m < n.
template< typename scalar_t >
blas::real_type< scalar_t > check_orthogonality(
    lapack::RowCol rowcol,
    int64_t m, int64_t n,
    scalar_t const* U, int64_t ldu )
{
    using namespace blas;
    using namespace lapack;
    using real_t = blas::real_type< scalar_t >;

    // Col: R = I - U^H U is n-by-n; the inner dimension is k = m.
    // Row: R = I - U U^H is m-by-m; the inner dimension is k = n.
    Op transU;
    int64_t k;
    if (rowcol == RowCol::Row) {
        if (m > n)
            throw lapack::Error( "rowcol == row && m > n", __func__ );
        transU = Op::NoTrans;
        k = n;
    }
    else {
        if (m < n)
            throw lapack::Error( "rowcol == col && m < n", __func__ );
        transU = Op::ConjTrans;
        k = m;
    }

    int64_t minmn = min( m, n );

    // An empty U is trivially orthonormal. Returning early also avoids
    // indexing R[0] of an empty vector, passing ldr = 0 to BLAS (which
    // requires ldc >= max( 1, n )), and computing 0 / 0 when m == n == 0.
    if (minmn == 0)
        return real_t( 0 );

    int64_t ldr = minmn;
    std::vector< scalar_t > R( ldr * minmn );

    // herk and lanhe below are called with Uplo::Upper, so only the upper
    // triangle of R is initialized: R = I, then R = I - op(U) op(U)^H.
    laset( MatrixType::Upper, minmn, minmn, scalar_t( 0 ), scalar_t( 1 ),
           R.data(), ldr );
    herk( Layout::ColMajor, Uplo::Upper, transU, minmn, k,
          real_t( -1 ), U, ldu, real_t( 1 ), R.data(), ldr );

    // resid = || R ||_1 / k
    real_t resid = lanhe( Norm::One, Uplo::Upper, minmn, R.data(), ldr ) / k;
    return resid;
}