1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
package tensor_test
import (
"fmt"
"gorgonia.org/tensor"
)
// These examples show basic tensor operations (slicing, stacking, transposing) on arbitrary element types.
// LongStruct is an arbitrarily long struct, used in these examples as a
// custom (non-numeric) tensor element type.
type LongStruct struct {
	a uint64
	b uint64
	c uint64
	d uint64
	e uint64
}

// Format implements fmt.Formatter so that each element prints as
// "{a: …, b: …, c: …, d: …, e: …}", which keeps tensor output readable.
// The verb, width, precision and flags are all ignored, so every verb
// produces the same text.
func (v LongStruct) Format(st fmt.State, _ rune) {
	fmt.Fprintf(st, "{a: %d, b: %d, c: %d, d: %d, e: %d}", v.a, v.b, v.c, v.d, v.e)
}
// s selects a single index along one axis. Its Start/End/Step methods let it
// be passed wherever the examples need a tensor.Slice.
type s int

// Start returns the selected index itself.
func (i s) Start() int { return int(i) }

// End is one past Start, so the range covers exactly the index i.
func (i s) End() int { return i.Start() + 1 }

// Step is always 1.
func (i s) Step() int { return 1 }
// ExampleTranspose_extension transposes a 2×2 tensor whose elements are the
// user-defined LongStruct type, showing that Transpose is not limited to
// numeric element types.
func ExampleTranspose_extension() {
	// For documentation if you're reading this on godoc:
	//
	//	type LongStruct struct {
	//		a, b, c, d, e uint64
	//	}
	T := tensor.New(tensor.WithShape(2, 2),
		tensor.WithBacking([]LongStruct{
			{0, 0, 0, 0, 0},
			{1, 1, 1, 1, 1},
			{2, 2, 2, 2, 2},
			{3, 3, 3, 3, 3},
		}),
	)
	fmt.Printf("Before:\n%v\n", T)

	// An alternative would be to use T.T(); T.Transpose()
	retVal, err := tensor.Transpose(T)
	if err != nil {
		// Report the error instead of discarding it; any unexpected output
		// makes the example's Output check fail visibly.
		fmt.Println(err)
		return
	}
	fmt.Printf("After:\n%v\n", retVal)
	// Output:
	// Before:
	// ⎡{a: 0, b: 0, c: 0, d: 0, e: 0} {a: 1, b: 1, c: 1, d: 1, e: 1}⎤
	// ⎣{a: 2, b: 2, c: 2, d: 2, e: 2} {a: 3, b: 3, c: 3, d: 3, e: 3}⎦
	//
	// After:
	// ⎡{a: 0, b: 0, c: 0, d: 0, e: 0} {a: 2, b: 2, c: 2, d: 2, e: 2}⎤
	// ⎣{a: 1, b: 1, c: 1, d: 1, e: 1} {a: 3, b: 3, c: 3, d: 3, e: 3}⎦
}
// Example_stackExtension takes column 1 of one 2×2 LongStruct tensor and
// column 0 of another, then stacks the two columns side by side along axis 1.
func Example_stackExtension() {
	// For documentation if you're reading this on godoc:
	//
	//	type LongStruct struct {
	//		a, b, c, d, e uint64
	//	}
	T := tensor.New(tensor.WithShape(2, 2),
		tensor.WithBacking([]LongStruct{
			{0, 0, 0, 0, 0},
			{1, 1, 1, 1, 1},
			{2, 2, 2, 2, 2},
			{3, 3, 3, 3, 3},
		}),
	)
	S, err := T.Slice(nil, s(1)) // s is a type that implements tensor.Slice
	if err != nil {
		// Report rather than discard errors; unexpected output fails the
		// example's Output check visibly.
		fmt.Println(err)
		return
	}

	T2 := tensor.New(tensor.WithShape(2, 2),
		tensor.WithBacking([]LongStruct{
			{10, 10, 10, 10, 10},
			{11, 11, 11, 11, 11},
			{12, 12, 12, 12, 12},
			{13, 13, 13, 13, 13},
		}),
	)
	S2, err := T2.Slice(nil, s(0))
	if err != nil {
		fmt.Println(err)
		return
	}

	// an alternative would be something like this
	// T3, _ := S.(*tensor.Dense).Stack(1, S2.(*tensor.Dense))
	T3, err := tensor.Stack(1, S, S2)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("Stacked:\n%v", T3)
	// Output:
	// Stacked:
	// ⎡ {a: 1, b: 1, c: 1, d: 1, e: 1} {a: 10, b: 10, c: 10, d: 10, e: 10}⎤
	// ⎣ {a: 3, b: 3, c: 3, d: 3, e: 3} {a: 12, b: 12, c: 12, d: 12, e: 12}⎦
}
|