File: test_orhr_col.cc

package info (click to toggle)
lapackpp 2024.10.26-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,500 kB
  • sloc: cpp: 80,181; ansic: 27,660; python: 4,838; xml: 182; perl: 99; makefile: 53; sh: 23
file content (122 lines) | stat: -rw-r--r-- 3,922 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include "test.hh"
#include "lapack.hh"
#include "print_matrix.hh"
#include "error.hh"
#include "lapacke_wrappers.hh"

#include <vector>

#if LAPACK_VERSION >= 30900  // >= 3.9.0

// -----------------------------------------------------------------------------
/// Tests lapack::orhr_col on a random m-by-n matrix A generated by larnv
/// (idist = 1, fixed seed).
///
/// Times the lapackpp call. When a reference run or a check is requested and
/// LAPACKE_orhr_col is available (LAPACK >= 3.12 or Intel MKL), it runs
/// LAPACKE_orhr_col on an identical copy of A. It then requires the outputs
/// A, T, D and the returned info to match the reference exactly.
///
/// @param[in,out] params  Test parameters. Reads dim (m, n), nb, align, cache,
///                        ref and check. Writes time, ref_time, error, okay,
///                        or msg.
/// @param[in]     run     If false, only marks the parameters used and returns.
template< typename scalar_t >
void test_orhr_col_work( Params& params, bool run )
{
    using real_t = blas::real_type< scalar_t >;

    // get & mark input values
    int64_t m = params.dim.m();
    int64_t n = params.dim.n();
    int64_t nb = params.nb();
    int64_t align = params.align();

    // mark non-standard output values
    params.ref_time();

    if (! run)
        return;

    // ---------- setup
    // A is lda-by-n; T is ldt-by-n with ldt >= min( nb, n ); D has min( m, n ) entries.
    int64_t lda = roundup( blas::max( 1, m ), align );
    int64_t ldt = roundup( blas::max( 1, blas::min( nb, n ) ), align );
    size_t size_A = (size_t) lda * n;
    size_t size_T = (size_t) ldt * n;
    size_t size_D = (size_t) blas::min( m, n );

    // Any of these vectors may be empty (e.g., n == 0), so pointers are taken
    // with data() below: &v[0] on an empty vector is undefined behavior.
    std::vector< scalar_t > A_tst( size_A );
    std::vector< scalar_t > A_ref( size_A );
    std::vector< scalar_t > T_tst( size_T );
    std::vector< scalar_t > T_ref( size_T );
    std::vector< scalar_t > D_tst( size_D );
    std::vector< scalar_t > D_ref( size_D );

    int64_t idist = 1;
    int64_t iseed[4] = { 0, 1, 2, 3 };
    lapack::larnv( idist, iseed, A_tst.size(), A_tst.data() );
    A_ref = A_tst;

    // ---------- run test
    testsweeper::flush_cache( params.cache() );
    double time = testsweeper::get_wtime();
    int64_t info_tst = lapack::orhr_col(
        m, n, nb, A_tst.data(), lda, T_tst.data(), ldt, D_tst.data() );
    time = testsweeper::get_wtime() - time;
    if (info_tst != 0) {
        fprintf( stderr, "lapack::orhr_col returned error %lld\n",
                 llong( info_tst ) );
    }

    params.time() = time;

    if (params.ref() == 'y' || params.check() == 'y') {
    #if LAPACK_VERSION >= 31200 || defined( LAPACK_HAVE_MKL )
        // ---------- run reference
        testsweeper::flush_cache( params.cache() );
        time = testsweeper::get_wtime();
        // min works around bug in LAPACK <= 3.12
        int64_t info_ref = LAPACKE_orhr_col(
            m, n, blas::min( nb, n ), A_ref.data(), lda, T_ref.data(), ldt,
            D_ref.data() );
        time = testsweeper::get_wtime() - time;
        if (info_ref != 0) {
            fprintf( stderr, "LAPACKE_orhr_col returned error %lld\n",
                     llong( info_ref ) );
        }

        params.ref_time() = time;

        // ---------- check error compared to reference
        // A mismatched info counts as an error of 1, in addition to any
        // element-wise differences in the outputs.
        real_t error = 0;
        if (info_tst != info_ref) {
            error = 1;
        }
        error += abs_error( A_tst, A_ref );
        error += abs_error( T_tst, T_ref );
        error += abs_error( D_tst, D_ref );
        params.error() = error;
        params.okay() = (error == 0);  // expect lapackpp == lapacke
    #else
        // LAPACKE_orhr_col is unavailable: requires LAPACK >= 3.12 or MKL.
        params.msg() = "check requires LAPACK >= 3.12 or Intel MKL";
    #endif  // LAPACK_VERSION >= 31200 || LAPACK_HAVE_MKL
    }
}

#endif  // LAPACK >= 3.9.0

// -----------------------------------------------------------------------------
/// Entry point for the orhr_col tester: dispatches on params.datatype().
/// Real single precision maps to float and real double precision to double.
/// Any other datatype raises std::runtime_error. If the library was built
/// against LAPACK older than 3.9.0, which lacks orhr_col, it prints a notice
/// to stderr and exits with status 0.
void test_orhr_col( Params& params, bool run )
{
#if LAPACK_VERSION >= 30900  // >= 3.9.0
    auto const requested = params.datatype();
    if (requested == testsweeper::DataType::Single) {
        test_orhr_col_work< float >( params, run );
    }
    else if (requested == testsweeper::DataType::Double) {
        test_orhr_col_work< double >( params, run );
    }
    else {
        throw std::runtime_error( "unknown datatype" );
    }
#else
    // orhr_col was introduced in LAPACK 3.9.0; nothing to test here.
    fprintf( stderr, "orhr_col requires LAPACK >= 3.9.0\n\n" );
    exit(0);
#endif
}