1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
|
package tensor
import (
"github.com/pkg/errors"
)
// This file contains package-level matrix operations (matops). While most of these matops should already
// be defined as part of the Tensor interface, not all of them are possible for every Tensor (for example,
// concatenating a sparse tensor), hence the need for the following functions.

// Narrow returns a view of t that is narrowed along dimension dim to the
// range built by S(start, start+length, 1). Only dimension dim is sliced.
//
// dim is normalized with resolveAxis before use. An error is returned, rather
// than a panic, when t has no dimensions or when the resolved axis does not
// index one of t's dimensions.
func Narrow(t Tensor, dim, start, length int) (View, error) {
	dims := t.Dims()
	if dims == 0 {
		return nil, errors.New("Cannot narrow a tensor with no dimensions")
	}
	resolved := resolveAxis(dim, dims)
	if resolved < 0 || resolved >= dims {
		return nil, errors.Errorf("Cannot narrow along axis %d. Input only has %d dims", dim, dims)
	}
	// Only entries up to and including the narrowed axis are supplied.
	// NOTE(review): presumably t.Slice treats nil entries and omitted trailing
	// dims as "take everything" — confirm against the Slice implementation.
	slices := make([]Slice, resolved+1)
	slices[resolved] = S(start, start+length, 1)
	return t.Slice(slices...)
}
// Repeat repeats the Tensor t along the given axis, using repeats as the
// repetition counts. The work is delegated to t's engine, which must implement
// Repeater; otherwise an error is returned.
func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) {
	rep, supported := t.Engine().(Repeater)
	if !supported {
		return nil, errors.New("Engine does not support Repeat")
	}
	return rep.Repeat(t, axis, repeats...)
}
// RepeatReuse repeats the Tensor t along the given axis, like Repeat, but asks
// the engine to place the result in the provided reuse tensor. t's engine must
// implement Repeater; otherwise an error is returned.
//
// NOTE(review): handling of an incorrectly sized reuse tensor (reported as an
// error, with results described as still valid) is up to the Repeater
// implementation — confirm there.
func RepeatReuse(t, reuse Tensor, axis int, repeats ...int) (retval Tensor, err error) {
	rep, supported := t.Engine().(Repeater)
	if !supported {
		return nil, errors.New("Engine does not support Repeat")
	}
	return rep.RepeatReuse(t, reuse, axis, repeats...)
}
// T safely transposes a Tensor. The returned tensor is not a view of the input
// tensor; rather, the data is all copied (via (*Dense).SafeT).
//
// axes is forwarded unchanged to (*Dense).SafeT. Only *Dense is currently
// supported; any other Tensor implementation yields an error instead of a panic.
func T(t Tensor, axes ...int) (retVal Tensor, err error) {
	switch tt := t.(type) {
	case *Dense:
		transposed, terr := tt.SafeT(axes...)
		if terr != nil {
			// Return an untyped nil so callers never see a non-nil interface
			// wrapping a nil *Dense.
			return nil, terr
		}
		return transposed, nil
	default:
		return nil, errors.Errorf("T is not supported for %T", t)
	}
}
// Transpose performs transposition of a tensor according to its axes.
//
// For a *Dense, a safe (copied) transpose is made with SafeT(axes...), and
// Transpose() is then called on that copy, so the input tensor is not modified.
// Any other Tensor implementation yields an error instead of a panic.
func Transpose(t Tensor, axes ...int) (retVal Tensor, err error) {
	switch tt := t.(type) {
	case *Dense:
		var ret *Dense
		if ret, err = tt.SafeT(axes...); err != nil {
			return nil, err
		}
		// NOTE(review): if (*Dense).Transpose returns an error, it is discarded
		// here — confirm its signature and propagate if so.
		ret.Transpose()
		return ret, nil
	default:
		return nil, errors.Errorf("Transpose is not supported for %T", t)
	}
}
// Concat concatenates a list of Tensors along axis. At the moment the operation only supports Tensors of the same type
// (*Dense can only be concatenated with a bunch of *Dense, CSCs can only be concatenated with a bunch of CSC, etc).
//
// If others is empty, t itself is returned unchanged. Currently only *Dense is
// implemented; any other type for t yields an error instead of a panic, as does
// any element of others that is not a *Dense.
func Concat(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) {
	if len(others) == 0 {
		return t, nil
	}
	switch tt := t.(type) {
	case *Dense:
		ts := make([]*Dense, len(others))
		for i, o := range others {
			ot, ok := o.(*Dense)
			if !ok {
				return nil, errors.Errorf("Expected all Tensors to be *Dense. others[%d] is %T", i, o)
			}
			ts[i] = ot
		}
		joined, cerr := tt.Concat(axis, ts...)
		if cerr != nil {
			// Avoid returning a non-nil interface that wraps a nil *Dense.
			return nil, cerr
		}
		return joined, nil
	default:
		return nil, errors.Errorf("Concat is not supported for %T", t)
	}
}
// Copy copies the data of src into dst. For *Dense views, only the relevant slots are copied.
//
// Both src and dst must be DenseTensors; any other combination returns an
// error. When either side reports RequiresIterator(), the copy is driven by the
// two tensors' iterators (copyDenseIter). Otherwise copyDense is used directly.
func Copy(dst, src Tensor) error {
	switch st := src.(type) {
	case DenseTensor:
		dt, ok := dst.(DenseTensor)
		if !ok {
			return errors.Errorf("Cannot copy from DenseTensor to %T", dst)
		}
		if st.RequiresIterator() || dt.RequiresIterator() {
			// At least one side cannot be copied as a flat block, so walk both
			// with iterators.
			siter := st.Iterator()
			diter := dt.Iterator()
			_, err := copyDenseIter(dt, st, diter, siter)
			return err
		}
		copyDense(dt, st)
		return nil
	default:
		return errors.Errorf("NYI for Copy %T", src)
	}
}
// Stack stacks a list of other Tensors onto t along axis. At the moment the operation only supports Tensors of the same type.
// (*Dense can only be stacked with *Dense... etc)
//
// If others is empty, t itself is returned unchanged. Only DenseTensor inputs
// are currently implemented; any other type for t yields an error instead of a
// panic, as does a failure to convert others into DenseTensors.
func Stack(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) {
	if len(others) == 0 {
		return t, nil
	}
	switch dt := t.(type) {
	case DenseTensor:
		var dts []DenseTensor
		if dts, err = tensorsToDenseTensors(others); err != nil {
			return nil, errors.Wrap(err, "Cannot convert others into a slice of DenseTensors")
		}
		return dt.stackDense(axis, dts...)
	default:
		return nil, errors.Errorf("Stack is not supported for %T", t)
	}
}
// Materialize copies the data of a View out into a new allocation by calling
// its Materialize method. A Tensor that is not a View (including a nil Tensor)
// is returned as-is.
func Materialize(t Tensor) Tensor {
	if v, isView := t.(View); isView {
		return v.Materialize()
	}
	return t
}
// Diag delegates to t's engine, which must implement Diager, and returns that
// engine's result. If the engine does not implement Diager, an error naming the
// engine type is returned.
//
// NOTE(review): whether this extracts a diagonal or builds a diagonal tensor is
// defined by the Diager implementation — confirm there.
func Diag(t Tensor) (retVal Tensor, err error) {
	if d, ok := t.Engine().(Diager); ok {
		return d.Diag(t)
	}
	return nil, errors.Errorf("Unable to perform diagonalization of tensor. Engine %T does not support that.", t.Engine())
}
// ByIndices allows for selection of values of `a` by the indices listed in the `indices` tensor, along axis.
// The `indices` tensor has to be a vector-like tensor of ints.
//
// axis must satisfy 0 <= axis < a.Shape().Dims(). Out-of-range axes, including
// negative ones, are rejected here instead of being handed to the engine. The
// selection itself is delegated to a's engine, which must implement ByIndiceser.
func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if dims := a.Shape().Dims(); axis < 0 || axis >= dims {
		return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, dims)
	}
	if sbi, ok := a.Engine().(ByIndiceser); ok {
		return sbi.SelectByIndices(a, indices, axis, opts...)
	}
	return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
}
// ByIndicesB is the backpropagation of ByIndices.
//
// NOTE(review): presumably `a` is the original input and `b` the incoming
// gradient of the ByIndices output — confirm against ByIndiceser.SelectByIndicesB.
//
// axis must satisfy 0 <= axis < a.Shape().Dims(). Out-of-range axes, including
// negative ones, are rejected here instead of being handed to the engine. The
// work is delegated to a's engine, which must implement ByIndiceser.
func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if dims := a.Shape().Dims(); axis < 0 || axis >= dims {
		return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, dims)
	}
	if sbi, ok := a.Engine().(ByIndiceser); ok {
		return sbi.SelectByIndicesB(a, b, indices, axis, opts...)
	}
	return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
}
// LogSoftMax applies log softmax to x along axis. The computation is delegated
// to x's engine, which must implement SoftMaxer; otherwise an error naming the
// engine type is returned.
func LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	eng := x.Engine()
	softmaxer, ok := eng.(SoftMaxer)
	if !ok {
		return nil, errors.Errorf("Unable to apply LogSoftMax. Engine %T does not support that.", eng)
	}
	return softmaxer.LogSoftMax(x, axis, opts...)
}
// SoftMax applies softmax to x along axis. The computation is delegated to x's
// engine, which must implement SoftMaxer; otherwise an error naming the engine
// type is returned.
func SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	eng := x.Engine()
	softmaxer, ok := eng.(SoftMaxer)
	if !ok {
		return nil, errors.Errorf("Unable to apply SoftMax. Engine %T does not support that.", eng)
	}
	return softmaxer.SoftMax(x, axis, opts...)
}
// SoftMaxB applies the softmax backwards operation along axis, delegating to
// the SoftMaxer implemented by output's engine; otherwise an error naming the
// engine type is returned.
//
// NOTE(review): output is presumably the forward SoftMax result and grad the
// upstream gradient — confirm against SoftMaxer.SoftMaxB.
func SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	eng := output.Engine()
	softmaxer, ok := eng.(SoftMaxer)
	if !ok {
		return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", eng)
	}
	return softmaxer.SoftMaxB(output, grad, axis, opts...)
}
// LogSoftMaxB applies the log softmax backwards operation along axis,
// delegating to the SoftMaxer implemented by output's engine. If the engine
// does not implement SoftMaxer, an error naming the engine type is returned.
//
// NOTE(review): output is presumably the forward LogSoftMax result and grad the
// upstream gradient — confirm against SoftMaxer.LogSoftMaxB.
func LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if sm, ok := output.Engine().(SoftMaxer); ok {
		return sm.LogSoftMaxB(output, grad, axis, opts...)
	}
	return nil, errors.Errorf("Unable to apply LogSoftMaxB. Engine %T does not support that.", output.Engine())
}
|