1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
|
// Package tensor provides efficient, generic n-dimensional arrays in Go.
// It also provides functions and methods commonly used in arithmetic, comparison, and linear algebra operations.
package tensor // import "gorgonia.org/tensor"
import (
"encoding/gob"
"fmt"
"io"
"github.com/pkg/errors"
)
// Compile-time checks that the concrete tensor types implement the package
// interfaces. Typed nil pointers are enough for the check; no value is built.
var (
	_ Tensor = (*Dense)(nil)
	_ Tensor = (*CS)(nil)
	_ View   = (*Dense)(nil)
)
// init registers the concrete tensor types with encoding/gob, in the order
// *Dense then *CS, so that gob can encode and decode them when they are held
// in interface-typed values.
func init() {
	for _, concrete := range []interface{}{&Dense{}, &CS{}} {
		gob.Register(concrete)
	}
}
// Tensor represents a variety of n-dimensional arrays. The most commonly used tensor is the Dense tensor.
// It can be used to represent a vector, matrix, 3D matrix and n-dimensional tensors.
//
// In this package, *Dense and *CS implement Tensor. The method set includes
// unexported methods (standardEngine, plus whatever headerer and arrayer
// require), so a type declared outside this package can only satisfy Tensor
// by embedding a type from this package.
type Tensor interface {
	// info about the ndarray
	// (shape, strides, element type, number of dimensions, element count,
	// and the size of the underlying data; NOTE(review): how Size and
	// DataSize differ, e.g. for views or sparse storage, is not visible
	// here — confirm in the implementations)
	Shape() Shape
	Strides() []int
	Dtype() Dtype
	Dims() int
	Size() int
	DataSize() int

	// Data access related
	// RequiresIterator reports whether the data must be walked with an
	// Iterator; presumably true when the layout is not a simple contiguous
	// run — TODO confirm.
	RequiresIterator() bool
	Iterator() Iterator
	DataOrder() DataOrder

	// ops
	Slicer
	At(...int) (interface{}, error)
	SetAt(v interface{}, coord ...int) error
	Reshape(...int) error
	// T takes optional axes; the note on Transpose below suggests T alone
	// does not move the data and UT undoes it — NOTE(review): confirm.
	T(axes ...int) error
	UT()
	Transpose() error // Transpose actually moves the data
	Apply(fn interface{}, opts ...FuncOpt) (Tensor, error)

	// data related interface
	Zeroer
	MemSetter
	Dataer
	Eq
	Cloner

	// type overloading methods
	IsScalar() bool
	ScalarValue() interface{}

	// engine/memory related stuff
	// all Tensors should be able to be expressed of as a slab of memory
	// Note: the size of each element can be acquired by T.Dtype().Size()
	Memory                      // Tensors all implement Memory
	Engine() Engine             // Engine can be nil
	IsNativelyAccessible() bool // Can Go access the memory
	// NOTE(review): the original note on IsManuallyManaged reads "Must Go
	// manage the memory", but the name suggests the opposite (true when the
	// memory is managed manually rather than by Go) — confirm and fix.
	IsManuallyManaged() bool // Must Go manage the memory

	// formatters
	fmt.Formatter
	fmt.Stringer

	// all Tensors are serializable to these formats
	WriteNpy(io.Writer) error
	ReadNpy(io.Reader) error
	gob.GobEncoder
	gob.GobDecoder

	// package-internal hooks (unexported, see the type comment above)
	standardEngine() standardEngine
	headerer
	arrayer
}
// New creates a new Dense Tensor configured by opts. For sparse arrays, use
// their own construction functions.
//
// The options are applied in order to a *Dense obtained from borrowDense. The
// result is then finalised with fix and validated with sanity. New panics with
// the sanity error if the options produce an inconsistent tensor.
func New(opts ...ConsOpt) *Dense {
	retVal := borrowDense()
	for i := range opts {
		opts[i](retVal)
	}
	retVal.fix()

	err := retVal.sanity()
	if err != nil {
		panic(err)
	}
	return retVal
}
// assertDense returns t as a *Dense.
//
// A *Dense is returned as-is. Any other value implementing Densor is converted
// through its Dense method. A nil interface, or a Tensor that is neither, is an
// error. The *Dense case is checked before Densor, so a *Dense is never routed
// through Dense().
func assertDense(t Tensor) (*Dense, error) {
	switch tt := t.(type) {
	case nil:
		return nil, errors.New("nil is not a *Dense")
	case *Dense:
		return tt, nil
	case Densor:
		return tt.Dense(), nil
	default:
		return nil, errors.Errorf("%T is not *Dense", t)
	}
}
// getDenseTensor returns t as a DenseTensor.
//
// Values that already implement DenseTensor are returned unchanged. Otherwise
// a Densor is converted with its Dense method. Anything else, including a nil
// t, is reported as an error.
func getDenseTensor(t Tensor) (DenseTensor, error) {
	if dt, isDense := t.(DenseTensor); isDense {
		return dt, nil
	}
	if conv, isDensor := t.(Densor); isDensor {
		return conv.Dense(), nil
	}
	return nil, errors.Errorf("Tensor %T is not a DenseTensor", t)
}
// getFloatDenseTensor extracts a DenseTensor from t after checking that t's
// dtype belongs to the float type class (floatTypes).
//
// A nil t is not an error: it yields (nil, nil). Otherwise an error is
// returned if the dtype check fails or if getDenseTensor cannot convert t.
func getFloatDenseTensor(t Tensor) (retVal DenseTensor, err error) {
	if t == nil {
		return nil, nil
	}
	if err = typeclassCheck(t.Dtype(), floatTypes); err != nil {
		return nil, errors.Wrapf(err, "getFloatDense only handles floats. Got %v instead", t.Dtype())
	}
	if retVal, err = getDenseTensor(t); err != nil {
		return nil, errors.Wrapf(err, opFail, "getFloatDense")
	}
	// The former trailing `if retVal == nil { return }` check was dead code:
	// both branches returned the same values.
	return retVal, nil
}
// getFloatComplexDenseTensor extracts a DenseTensor from t after checking that
// t's dtype belongs to the float-or-complex type class (floatcmplxTypes).
//
// A nil t is not an error: it yields (nil, nil). Otherwise an error is
// returned if the dtype check fails or if getDenseTensor cannot convert t.
// The error messages name this function. Previously they were copied from
// getFloatDenseTensor and blamed "getFloatDense".
func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) {
	if t == nil {
		return nil, nil
	}
	if err = typeclassCheck(t.Dtype(), floatcmplxTypes); err != nil {
		return nil, errors.Wrapf(err, "getFloatComplexDenseTensor only handles floats and complex. Got %v instead", t.Dtype())
	}
	if retVal, err = getDenseTensor(t); err != nil {
		return nil, errors.Wrapf(err, opFail, "getFloatComplexDenseTensor")
	}
	return retVal, nil
}
// sliceDense slices t with the given slices and returns the result as a *Dense.
//
// Errors from t.Slice are returned unchanged. If Slice returns something other
// than a *Dense (including a nil Tensor), an error is returned. The former
// unchecked type assertion would have panicked in that case.
func sliceDense(t *Dense, slices ...Slice) (retVal *Dense, err error) {
	var sliced Tensor
	if sliced, err = t.Slice(slices...); err != nil {
		return nil, err
	}
	d, ok := sliced.(*Dense)
	if !ok {
		return nil, fmt.Errorf("sliceDense: expected Slice to return *Dense, got %T", sliced)
	}
	return d, nil
}
|