File: dsyev.go

package info (click to toggle)
golang-gonum-v1-gonum 0.15.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 18,792 kB
  • sloc: asm: 6,252; fortran: 5,271; sh: 377; ruby: 211; makefile: 98
file content (116 lines) | stat: -rw-r--r-- 2,597 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package testlapack

import (
	"testing"

	"golang.org/x/exp/rand"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/lapack"
)

// Dsyever is implemented by types that provide the LAPACK Dsyev routine. The
// routine computes the eigenvalues, and optionally the eigenvectors, of the
// symmetric n×n matrix whose referenced triangle (selected by uplo) is stored
// in a with row stride lda. Eigenvalues are written to w. work and lwork
// describe the workspace; lwork == -1 requests a workspace-size query.
type Dsyever interface {
	Dsyev(
		jobz lapack.EVJob, uplo blas.Uplo,
		n int, a []float64, lda int,
		w []float64, work []float64, lwork int,
	) (ok bool)
}

// DsyevTest checks an implementation of Dsyev on random matrices for both
// triangle storage modes (blas.Lower and blas.Upper), with and without extra
// row padding (lda > n).
//
// For every case it:
//   - performs a workspace query (lwork == -1) and allocates the size
//     reported in work[0],
//   - computes eigenvalues and eigenvectors (lapack.EVCompute) and verifies
//     the decomposition against the symmetric matrix rebuilt from the
//     referenced triangle of the input,
//   - recomputes only the eigenvalues (lapack.EVNone) from the original
//     input, with w and work scrambled beforehand, and checks that they agree
//     with the first run.
//
// A false ok result from a non-query Dsyev call is reported as an error.
func DsyevTest(t *testing.T, impl Dsyever) {
	rnd := rand.New(rand.NewSource(1))
	for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} {
		for _, test := range []struct {
			n, lda int
		}{
			{1, 0},
			{2, 0},
			{5, 0},
			{10, 0},
			{100, 0},

			{1, 5},
			{2, 5},
			{5, 10},
			{10, 20},
			{100, 110},
		} {
			for cas := 0; cas < 10; cas++ {
				n := test.n
				lda := test.lda
				if lda == 0 {
					lda = n
				}

				// errorf prefixes each failure with the parameters of the
				// current case so that a failing case can be identified.
				errorf := func(msg string) {
					t.Helper()
					t.Errorf("uplo=%v,n=%d,lda=%d,cas=%d: %s", uplo, n, lda, cas, msg)
				}

				a := make([]float64, n*lda)
				for i := range a {
					a[i] = rnd.NormFloat64()
				}
				aCopy := make([]float64, len(a))
				copy(aCopy, a)
				w := make([]float64, n)
				for i := range w {
					w[i] = rnd.NormFloat64()
				}

				// Workspace query. Only work[0] is used here; the ok result of
				// a query call does not indicate success, so it is not checked.
				work := make([]float64, 1)
				impl.Dsyev(lapack.EVCompute, uplo, n, a, lda, w, work, -1)
				work = make([]float64, int(work[0]))
				if !impl.Dsyev(lapack.EVCompute, uplo, n, a, lda, w, work, len(work)) {
					errorf("Dsyev with EVCompute returned ok=false")
					// The eigenvalues below would serve as the reference for the
					// EVNone check. That is meaningless after a failure.
					continue
				}

				// Rebuild the full symmetric matrix from the triangle of the
				// original input that Dsyev was told to reference.
				orig := blas64.General{
					Rows:   n,
					Cols:   n,
					Stride: n,
					Data:   make([]float64, n*n),
				}
				for i := 0; i < n; i++ {
					for j := i; j < n; j++ {
						var v float64
						if uplo == blas.Upper {
							v = aCopy[i*lda+j]
						} else {
							v = aCopy[j*lda+i]
						}
						orig.Data[i*orig.Stride+j] = v
						orig.Data[j*orig.Stride+i] = v
					}
				}

				// With EVCompute, a is overwritten by the eigenvectors.
				V := blas64.General{
					Rows:   n,
					Cols:   n,
					Stride: lda,
					Data:   a,
				}
				if !eigenDecompCorrect(w, orig, V) {
					errorf("decomposition mismatch")
				}

				// Check that the eigenvalues are the same when the
				// eigenvectors are not computed. Scramble w and work first so
				// that stale values cannot mask a failure.
				wAns := make([]float64, len(w))
				copy(wAns, w)
				copy(a, aCopy)
				for i := range w {
					w[i] = rnd.Float64()
				}
				for i := range work {
					work[i] = rnd.Float64()
				}
				if !impl.Dsyev(lapack.EVNone, uplo, n, a, lda, w, work, len(work)) {
					errorf("Dsyev with EVNone returned ok=false")
					continue
				}
				if !floats.EqualApprox(w, wAns, 1e-8) {
					errorf("eigenvalue mismatch when vectors not computed")
				}
			}
		}
	}
}