File: dense_assign.go

package info (click to toggle)
golang-github-gorgonia-tensor 0.9.24-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,696 kB
  • sloc: sh: 18; asm: 18; makefile: 8
file content (96 lines) | stat: -rw-r--r-- 2,108 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package tensor

import (
	"github.com/pkg/errors"
)

// overlaps reports whether the backing storage of a and b shares any memory.
//
// Tensors with zero capacity never overlap. Otherwise the start addresses of
// both underlying arrays are compared. Two arrays that start at the same
// address always overlap. Otherwise the array that starts later overlaps the
// other one if it starts before that one's end, computed as its start address
// plus cap(Header.Raw). NOTE(review): adding cap(Raw) to an address assumes
// Raw is a byte slice; confirm against the array/Header definition.
func overlaps(a, b DenseTensor) bool {
	if a.cap() == 0 || b.cap() == 0 {
		return false
	}

	left, right := a.arr(), b.arr()
	leftStart, rightStart := left.Uintptr(), right.Uintptr()

	switch {
	case leftStart == rightStart:
		return true
	case leftStart < rightStart:
		// b begins inside a's storage?
		return rightStart < leftStart+uintptr(cap(left.Header.Raw))
	default:
		// a begins inside b's storage?
		return leftStart < rightStart+uintptr(cap(right.Header.Raw))
	}
}

// assignArray copies the elements of src into dest. When src has more
// dimensions than dest, its leading size-1 dimensions are dropped. The
// remaining shape is then broadcast against dest via BroadcastStrides, so
// that, e.g., a src of shape (1, 3) can be assigned into a dest of shape (3).
//
// An error is returned when src is a scalar, when BroadcastStrides cannot
// reconcile the two shapes, or when copyDenseIter fails.
//
// TODO: dest and src storage that overlaps (see overlaps) is not first copied
// to a temporary; elements are copied directly from src's storage.
func assignArray(dest, src DenseTensor) (err error) {
	if src.IsScalar() {
		return errors.New("assignArray: scalar source is not supported")
	}

	dd := dest.Dims()
	sd := src.Dims()
	dstrides := dest.Strides()

	// Working copies of src's shape and strides that may be trimmed below.
	// The deferred ReturnInts calls evaluate their arguments here, i.e. they
	// capture the full-length borrowed slices before any reslicing happens.
	tmpShape := Shape(BorrowInts(sd))
	tmpStrides := BorrowInts(len(src.Strides()))
	copy(tmpShape, src.Shape())
	copy(tmpStrides, src.Strides())
	defer ReturnInts(tmpShape)
	defer ReturnInts(tmpStrides)

	// Drop leading size-1 dimensions of src while it still has more
	// dimensions than dest. Elements are shifted left (instead of reslicing
	// from the front, which would lose the start of the borrowed backing
	// array), then the now-stale last element is cut off so the length
	// actually shrinks.
	for len(tmpShape) > dd && tmpShape[0] == 1 {
		copy(tmpShape, tmpShape[1:])
		tmpShape = tmpShape[:len(tmpShape)-1]
		if n := len(tmpStrides); n > 0 {
			copy(tmpStrides, tmpStrides[1:])
			tmpStrides = tmpStrides[:n-1]
		}
	}

	newStrides, err := BroadcastStrides(dest.Shape(), tmpShape, dstrides, tmpStrides)
	if err != nil {
		return errors.Wrapf(err, "BroadcastStrides failed")
	}

	dap := dest.Info()
	sap := MakeAP(tmpShape, newStrides, src.Info().o, src.Info().Δ)

	diter := newFlatIterator(dap)
	siter := newFlatIterator(&sap)
	_, err = copyDenseIter(dest, src, diter, siter)

	// Only zero sap: tmpShape and tmpStrides are returned to the pool by the
	// deferred ReturnInts calls above, so releasing them here would double free.
	sap.zeroOnly()
	return err
}