File: example_byindices_test.go

package info (click to toggle)
golang-github-gorgonia-tensor 0.9.24-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,696 kB
  • sloc: sh: 18; asm: 18; makefile: 8
file content (74 lines) | stat: -rw-r--r-- 1,300 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package tensor

import "fmt"

// ExampleByIndices gathers rows of a 2×2 matrix along axis 0. Indices may
// repeat, so the result has one row per index entry (five here), each a copy
// of the selected source row.
func ExampleByIndices() {
	data := []float64{
		100, 200,
		300, 400,
	}
	rowsToTake := []int{1, 1, 1, 0, 1}

	a := New(WithShape(2, 2), WithBacking(data))
	indices := New(WithBacking(rowsToTake))

	// Axis 0 means each index picks a whole row of a.
	b, err := ByIndices(a, indices, 0)
	if err != nil {
		fmt.Println(err)
		return
	}

	fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\n", a, indices, b)

	// Output:
	// a:
	// ⎡100  200⎤
	// ⎣300  400⎦
	//
	// indices: [1  1  1  0  1]
	// b:
	// ⎡300  400⎤
	// ⎢300  400⎥
	// ⎢300  400⎥
	// ⎢100  200⎥
	// ⎣300  400⎦
}

// ExampleByIndicesB shows the backward pass of ByIndices. Rows of a are first
// gathered along axis 0, then an upstream gradient of all ones (one row per
// gathered row) is passed to ByIndicesB. In the resulting gradient, each row
// of a carries the sum of the upstream rows that were taken from it: row 0 was
// selected once and row 1 four times.
func ExampleByIndicesB() {
	a := New(WithShape(2, 2), WithBacking([]float64{
		100, 200,
		300, 400,
	}))
	indices := New(WithBacking([]int{1, 1, 1, 0, 1}))
	b, err := ByIndices(a, indices, 0) // we select rows 1, 1, 1, 0, 1
	if err != nil {
		fmt.Println(err)
		return
	}

	// Build an upstream gradient with the same shape as b. Use a checked
	// assertion so an unexpected concrete type is reported rather than
	// panicking.
	outGrad, ok := b.Clone().(*Dense)
	if !ok {
		fmt.Printf("unexpected clone type %T\n", b.Clone())
		return
	}
	// Memset can fail (e.g. on a dtype mismatch). Do not silently continue
	// with an unfilled gradient.
	if err := outGrad.Memset(1.0); err != nil {
		fmt.Println(err)
		return
	}

	grad, err := ByIndicesB(a, outGrad, indices, 0)
	if err != nil {
		fmt.Println(err)
		return
	}

	fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\ngrad:\n%v", a, indices, b, grad)

	// Output:
	// a:
	// ⎡100  200⎤
	// ⎣300  400⎦
	//
	// indices: [1  1  1  0  1]
	// b:
	// ⎡300  400⎤
	// ⎢300  400⎥
	// ⎢300  400⎥
	// ⎢100  200⎥
	// ⎣300  400⎦
	//
	// grad:
	// ⎡1  1⎤
	// ⎣4  4⎦
}