File: dgetri.go

package info (click to toggle)
golang-gonum-v1-gonum 0.15.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 18,792 kB
  • sloc: asm: 6,252; fortran: 5,271; sh: 377; ruby: 211; makefile: 98
file content (116 lines) | stat: -rw-r--r-- 3,183 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

import (
	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
)

// Dgetri computes the inverse of the matrix A using the LU factorization computed
// by Dgetrf. On entry, a contains the PLU decomposition of A as computed by
// Dgetrf and ipiv contains the corresponding pivot indices; on exit, a contains
// the inverse of the original matrix A.
//
// Dgetri will not perform the inversion if the matrix is singular, and returns
// a boolean indicating whether the inversion was successful. Singularity is
// detected by the triangular inversion of U (Dtrtri); when that reports failure,
// Dgetri returns false immediately, before the L and pivot stages run.
//
// work is temporary storage, and lwork specifies the usable memory length.
// At minimum, lwork >= n and this function will panic otherwise.
// Dgetri is a blocked inversion, but the block size is limited
// by the temporary space available. If lwork == -1, instead of performing Dgetri,
// the optimal work length will be stored into work[0].
// On a successful inversion, work[0] is set to the workspace size the algorithm
// wanted: max(1, n), or n*nb when Ilaenv's block size nb satisfies 1 < nb < n
// (even if a short lwork forced a smaller block to be used). For n == 0,
// work[0] is set to 1.
//
// Dgetri panics if n < 0, lda < max(1, n), lwork < max(1, n) with lwork != -1,
// or len(work) < max(1, lwork). Unless n == 0 or lwork == -1, it also panics
// if len(a) < (n-1)*lda+n or len(ipiv) != n.
func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) {
	// iws is the workspace size reported in work[0] on success; the
	// unblocked algorithm needs only max(1, n).
	iws := max(1, n)
	switch {
	case n < 0:
		panic(nLT0)
	case lda < max(1, n):
		panic(badLdA)
	case lwork < iws && lwork != -1:
		panic(badLWork)
	case len(work) < max(1, lwork):
		panic(shortWork)
	}

	// Quick return for an empty matrix.
	if n == 0 {
		work[0] = 1
		return true
	}

	// Block size suggested by Ilaenv; it also determines the answer to a
	// workspace query.
	nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1)
	if lwork == -1 {
		work[0] = float64(n * nb)
		return true
	}

	// a and ipiv are only validated when real work is going to be done.
	switch {
	case len(a) < (n-1)*lda+n:
		panic(shortA)
	case len(ipiv) != n:
		panic(badLenIpiv)
	}

	// Form inv(U).
	ok = impl.Dtrtri(blas.Upper, blas.NonUnit, n, a, lda)
	if !ok {
		return false
	}

	// Choose the block size. The blocked code needs n*nb elements of work;
	// if lwork is smaller, shrink nb to what fits and ask Ilaenv for the
	// smallest block size that is still worth using. iws records the
	// unreduced requirement.
	nbmin := 2
	if 1 < nb && nb < n {
		iws = max(n*nb, 1)
		if lwork < iws {
			nb = lwork / n
			nbmin = max(2, impl.Ilaenv(2, "DGETRI", " ", n, -1, -1, -1))
		}
	}
	// In the blocked code, work is used as an n×nb row-major buffer with
	// row stride ldwork.
	ldwork := nb

	bi := blas64.Implementation()
	// Solve the equation inv(A)*L = inv(U) for inv(A).
	// TODO(btracey): Replace this with a more row-major oriented algorithm.
	if nb < nbmin || n <= nb {
		// Unblocked code.
		for j := n - 1; j >= 0; j-- {
			for i := j + 1; i < n; i++ {
				// Copy current column of L to work and replace with zeros.
				work[i] = a[i*lda+j]
				a[i*lda+j] = 0
			}
			// Compute current column of inv(A):
			// column j of a -= a[:, j+1:n] * work[j+1:n].
			if j < n-1 {
				bi.Dgemv(blas.NoTrans, n, n-j-1, -1, a[(j+1):], lda, work[(j+1):], 1, 1, a[j:], lda)
			}
		}
	} else {
		// Blocked code.
		// Start at the last block column (which may be narrower than nb)
		// and move left one block at a time.
		nn := ((n - 1) / nb) * nb
		for j := nn; j >= 0; j -= nb {
			jb := min(nb, n-j)
			// Copy current block column of L to work and replace
			// with zeros.
			for jj := j; jj < j+jb; jj++ {
				for i := jj + 1; i < n; i++ {
					work[i*ldwork+(jj-j)] = a[i*lda+jj]
					a[i*lda+jj] = 0
				}
			}
			// Compute current block column of inv(A).
			// a[:, j:j+jb] -= a[:, j+jb:n] * (rows j+jb:n of the saved L block).
			if j+jb < n {
				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, jb, n-j-jb, -1, a[(j+jb):], lda, work[(j+jb)*ldwork:], ldwork, 1, a[j:], lda)
			}
			// Solve X * L_jj = a[:, j:j+jb] in place, where L_jj is the
			// unit lower triangular diagonal block saved in work.
			bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda)
		}
	}
	// Apply column interchanges, from the last pivot to the first.
	// NOTE(review): ipiv[n-1] is never read; presumably Dgetrf always leaves
	// the last pivot equal to n-1 — confirm.
	for j := n - 2; j >= 0; j-- {
		jp := ipiv[j]
		if jp != j {
			bi.Dswap(n, a[j:], lda, a[jp:], lda)
		}
	}
	// Report the workspace size computed above (before any reduction of nb).
	work[0] = float64(iws)
	return true
}