File: example_extension_test.go

package info (click to toggle)
golang-github-gorgonia-tensor 0.9.24-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,696 kB
  • sloc: sh: 18; asm: 18; makefile: 8
file content (101 lines) | stat: -rw-r--r-- 2,249 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package tensor_test

import (
	//"errors"
	"fmt"
	"reflect"

	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

// In this example, we want to create and handle a tensor of *MyType

// First, define MyType

// MyType is a small two-field value used as the element type of the example
// tensor. The tensor stores it by pointer ([]*MyType).
type MyType struct {
	x int
	y int
}

// Format implements fmt.Formatter. It ignores the verb and any flags and
// always renders the value as "(x, y)".
func (m MyType) Format(s fmt.State, c rune) {
	fmt.Fprintf(s, "(%d, %d)", m.x, m.y)
}

var (
	// MyDtype is the tensor.Dtype describing *MyType elements. It starts out
	// as the zero Dtype and is assigned in the init function below.
	MyDtype tensor.Dtype
)

// MyEngine is the engine used by this example. It defines its own Add to
// handle MyDtype; every method it does not define is provided by the
// embedded tensor.StdEng.
type MyEngine struct{ tensor.StdEng }

// For simplicity's sake, we only handle MyType-MyType and MyType-Int interactions.
// We also only expect Dense tensors.
// You are, of course, free to define your own rules.

// Add adds a and b element-wise.
//
// Combinations handled here:
//   - MyDtype + MyDtype: x and y of each element of a are incremented by the
//     x and y of the matching element of b.
//   - MyDtype + Int: x and y of each element of a are incremented by the
//     matching int of b.
//   - Int + MyDtype: x and y of each element of b are incremented by the
//     matching int of a.
//
// In these cases the *MyType operand is modified in place and returned as
// retVal, opts are ignored, and an error is returned if the operands hold a
// different number of elements. A MyDtype operand paired with any other Dtype
// is an error. Every combination that does not involve MyDtype is delegated,
// with opts, to the embedded StdEng.
func (e MyEngine) Add(a, b tensor.Tensor, opts ...tensor.FuncOpt) (retVal tensor.Tensor, err error) {
	switch a.Dtype() {
	case MyDtype:
		data := a.Data().([]*MyType)
		switch b.Dtype() {
		case MyDtype:
			datb := b.Data().([]*MyType)
			if len(datb) != len(data) {
				return nil, errors.Errorf("MyEngine.Add: operands hold %d and %d elements", len(data), len(datb))
			}
			for i, v := range data {
				v.x += datb[i].x
				v.y += datb[i].y
			}
			return a, nil
		case tensor.Int:
			datb := b.Data().([]int)
			if len(datb) != len(data) {
				return nil, errors.Errorf("MyEngine.Add: operands hold %d and %d elements", len(data), len(datb))
			}
			for i, v := range data {
				v.x += datb[i]
				v.y += datb[i]
			}
			return a, nil
		}
	case tensor.Int:
		if b.Dtype() != MyDtype {
			return e.StdEng.Add(a, b, opts...)
		}
		data := a.Data().([]int)
		datb := b.Data().([]*MyType)
		if len(datb) != len(data) {
			return nil, errors.Errorf("MyEngine.Add: operands hold %d and %d elements", len(data), len(datb))
		}
		for i, v := range datb {
			v.x += data[i]
			v.y += data[i]
		}
		// The *MyType operand is b here, so b carries the result. Previously
		// this branch fell through to the error below.
		return b, nil
	default:
		return e.StdEng.Add(a, b, opts...)
	}
	// Reached only when a is MyDtype and b is neither MyDtype nor Int.
	return nil, errors.Errorf("MyEngine.Add: unsupported Dtype combination %v + %v", a.Dtype(), b.Dtype())
}

// init sets MyDtype to the tensor.Dtype wrapping the reflect.Type of *MyType.
func init() {
	// Use a keyed field: go vet's composites check flags unkeyed literals of
	// struct types declared in another package.
	MyDtype = tensor.Dtype{Type: reflect.TypeOf(&MyType{})}
}

// Example_extension shows how to store a user-defined type (*MyType) in a
// tensor and give it arithmetic through a custom engine (MyEngine).
//
// MyEngine.Add modifies the *MyType operand in place and returns it. After
// the addition, T and T2 therefore print the same incremented values.
func Example_extension() {
	t := tensor.New(tensor.WithEngine(MyEngine{}),
		tensor.WithShape(2, 2),
		tensor.WithBacking([]*MyType{
			{0, 0}, {0, 1},
			{1, 0}, {1, 1},
		}))
	ones := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]int{1, 1, 1, 1}), tensor.WithEngine(MyEngine{}))

	t2, err := t.Add(ones)
	if err != nil {
		// Printing the error makes the example's output check fail visibly
		// instead of printing a nil tensor.
		fmt.Println(err)
		return
	}

	fmt.Printf("T:\n%+v", t)
	fmt.Printf("T2:\n%+v", t2)

	// output:
	// T:
	// Matrix (2, 2) [2 1]
	// ⎡(1, 1)  (1, 2)⎤
	// ⎣(2, 1)  (2, 2)⎦
	// T2:
	// Matrix (2, 2) [2 1]
	// ⎡(1, 1)  (1, 2)⎤
	// ⎣(2, 1)  (2, 2)⎦
}