1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389
|
package tensor
import (
"fmt"
"github.com/pkg/errors"
)
// scalarShape is the single zero-dimension Shape value handed out by ScalarShape.
var scalarShape = Shape{}

// ScalarShape returns the shape of a scalar: a non-nil Shape with no dimensions and therefore no sizes.
func ScalarShape() Shape {
	return scalarShape
}
// Shape represents the dimensions of a Tensor. A (2,3) matrix has a shape of (2,3) - 2 rows, 3 columns.
// Likewise, a shape of (2,3,4) means a Tensor has 3 dimensions: 2 layers, 3 rows, 4 columns.
// The empty Shape (no dimensions) describes a scalar.
//
// Vectors are of particular note. This package defines a shape of (x, 1) as a column vector and
// a (1, x) as a row vector, in both cases only for x > 1. A shape with exactly one dimension, (x), is a
// "vanilla" vector. Row vectors and column vectors are matrices as well. It is important to note that
// row, column and vanilla vectors can compare as equal under some circumstances (see Eq).
type Shape []int
// TotalSize returns the number of elements expected in a Tensor of this shape, computed by handing the
// dimension sizes to ProdInts. For the empty (scalar) shape the result is whatever ProdInts yields for an
// empty slice.
func (s Shape) TotalSize() int {
	dims := []int(s)
	return ProdInts(dims)
}
// CalcStrides calculates the default (row-major) strides for the shape. The last dimension gets stride 1,
// and each earlier dimension's stride is the product of the sizes of all dimensions after it. For example,
// (2, 3, 4) yields (12, 4, 1).
//
// A scalar has no strides, so nil is returned for it. Vectors are not special-cased: (x, 1) yields (1, 1).
// The returned slice is obtained from BorrowInts. It panics if any dimension is negative.
func (s Shape) CalcStrides() []int {
	if s.IsScalar() {
		return nil
	}

	strides := BorrowInts(len(s))
	running := 1 // product of the sizes of every dimension to the right of axis
	for axis := len(s) - 1; axis >= 0; axis-- {
		strides[axis] = running
		dim := s[axis]
		if dim < 0 {
			panic("negative dimension size does not make sense")
		}
		running *= dim
	}
	return strides
}
// CalcStridesWithMask is like CalcStrides (row-major), except that dimensions whose mask entry is false
// are masked out. A masked-out dimension gets a stride of 0 and does not contribute to the strides of the
// dimensions to its left.
//
// Scalar-equivalent shapes (every size 1) return nil. Vectors return the single stride [1] without
// consulting mask at all. Otherwise len(mask) must equal s.Dims(), or it panics. It also panics on any
// negative dimension. The returned slice is obtained from BorrowInts.
func (s Shape) CalcStridesWithMask(mask []bool) []int {
	if s.IsScalarEquiv() {
		return nil
	}

	strides := BorrowInts(len(s))
	if s.IsVector() {
		strides[0] = 1
		return strides[:1]
	}

	if len(mask) != s.Dims() {
		panic("mask length must be equal to number of shape dimensions")
	}

	running := 1 // product of the sizes of the unmasked dimensions to the right of axis
	for axis := len(s) - 1; axis >= 0; axis-- {
		dim := s[axis]
		if dim < 0 {
			panic("negative dimension size does not make sense")
		}
		if !mask[axis] {
			// Explicitly zeroed: a borrowed slice may hold stale values.
			strides[axis] = 0
			continue
		}
		strides[axis] = running
		running *= dim
	}
	return strides
}
// CalcStridesColMajor is like CalcStrides, but assumes a column-major layout. The first dimension gets
// stride 1, and each later dimension's stride is the product of the sizes of the dimensions before it.
// For example, (2, 3, 4) yields (1, 2, 6).
//
// Scalar-equivalent shapes (every size 1) return nil. Vectors, including (x, 1) and (1, x), return the
// single stride [1]. It panics on any negative dimension. The returned slice is obtained from BorrowInts.
func (s Shape) CalcStridesColMajor() []int {
	if s.IsScalarEquiv() {
		return nil
	}

	strides := BorrowInts(len(s))
	if s.IsVector() {
		strides[0] = 1
		return strides[:1]
	}

	running := 1 // product of the sizes of every dimension to the left of axis
	for axis, dim := range s {
		strides[axis] = running
		if dim < 0 {
			panic("negative dimension size does not make sense")
		}
		running *= dim
	}
	return strides
}
// Eq reports whether s and other describe the same shape.
//
// Equality is exact (same number of dimensions, same sizes), with one softening for vectors. A vanilla
// vector (x) is considered equal to a column vector (x, 1) or a row vector (1, x) of the same length, in
// either argument order. A row vector and a column vector are never equal to each other.
func (s Shape) Eq(other Shape) bool {
	if s.IsScalar() && other.IsScalar() {
		return true
	}

	if s.IsVector() && other.IsVector() {
		// For a mixed pair, pick out the 2-D (row/col) operand and the 1-D operand.
		var mat, vec Shape
		switch {
		case len(s) == 2 && len(other) == 1:
			mat, vec = s, other
		case len(s) == 1 && len(other) == 2:
			mat, vec = other, s
		}
		if mat != nil {
			colMatch := mat.IsColVec() && mat[0] == vec[0]
			rowMatch := mat.IsRowVec() && mat[1] == vec[0]
			return colMatch || rowMatch
		}
	}

	if len(s) != len(other) {
		return false
	}
	for i := range s {
		if s[i] != other[i] {
			return false
		}
	}
	return true
}
// Clone returns a copy of the shape that does not share storage with s. The copy's backing slice is
// obtained from BorrowInts.
func (s Shape) Clone() Shape {
	dup := Shape(BorrowInts(len(s)))
	copy(dup, s)
	return dup
}
// IsScalar reports whether the shape describes a true scalar, i.e. it has no dimensions at all.
// Shapes such as (1) or (1, 1) are not scalars by this test; IsScalarEquiv is the looser check.
func (s Shape) IsScalar() bool {
	noDims := len(s) == 0
	return noDims
}
// IsScalarEquiv reports whether the shape describes exactly one element. That is either a true scalar
// (no dimensions) or a shape whose every dimension has size 1, such as (1) or (1, 1, 1).
func (s Shape) IsScalarEquiv() bool {
	// The empty shape falls straight through the loop and is reported as scalar-equivalent.
	for _, d := range s {
		if d != 1 {
			return false
		}
	}
	return true
}
// IsVector reports whether the shape is one of the three kinds of vector this package recognises:
//   - a vanilla vector: exactly one dimension, (x)
//   - a column vector: (x, 1), as defined by IsColVec
//   - a row vector: (1, x), as defined by IsRowVec
func (s Shape) IsVector() bool {
	switch {
	case len(s) == 1:
		return true
	case s.IsColVec(), s.IsRowVec():
		return true
	}
	return false
}
// IsColVec reports whether the shape is a column vector: exactly two dimensions, (x, 1), with x > 1.
// In particular (1, 1) is not a column vector.
func (s Shape) IsColVec() bool {
	if len(s) != 2 {
		return false
	}
	rows, cols := s[0], s[1]
	return rows > 1 && cols == 1
}
// IsRowVec reports whether the shape is a row vector: exactly two dimensions, (1, x), with x > 1.
// In particular (1, 1) is not a row vector.
func (s Shape) IsRowVec() bool {
	if len(s) != 2 {
		return false
	}
	rows, cols := s[0], s[1]
	return rows == 1 && cols > 1
}
// IsVectorLike reports whether the shape looks like a vector: at most one dimension has a size other
// than 1, e.g. (1, 1, ... 1, 10, 1, 1 ... 1). Shapes made entirely of 1s (scalar-like), and the empty
// shape, also qualify.
func (s Shape) IsVectorLike() bool {
	sawNonOne := false
	for _, d := range s {
		if d == 1 {
			continue
		}
		if sawNonOne {
			return false // a second non-1 dimension: not vector-like
		}
		sawNonOne = true
	}
	return true
}
// IsMatrix reports whether the shape has exactly two dimensions. This is mostly a convenience method:
// row vectors, column vectors and even (1, 1) all count as matrices.
func (s Shape) IsMatrix() bool {
	const matrixRank = 2
	return len(s) == matrixRank
}
// Dims returns how many dimensions the shape has; it is 0 for a scalar.
func (s Shape) Dims() int {
	n := len(s)
	return n
}
// DimSize returns the size of dimension d.
//
// A scalar has no dimensions; as a special case, asking a scalar for dimension 0 returns 0 with no error.
// An error is returned when d is negative or not less than the number of dimensions. A negative d used
// to reach s[d] unchecked and panic instead of returning an error.
//
// This method implements the DimSizer interface in Gorgonia.
func (s Shape) DimSize(d int) (size int, err error) {
	if s.IsScalar() {
		if d != 0 {
			return 0, errors.Errorf(dimMismatch, len(s), d)
		}
		return 0, nil
	}

	if d < 0 || d >= len(s) {
		return 0, errors.Errorf(dimMismatch, len(s), d)
	}
	return s[d], nil
}
// S gives the new shape after a shape has been sliced. It's repeated from the AP S() method mainly because
// there are other functions in Gorgonia that use only shape.
//
// slices[d] applies to dimension d. Dimensions beyond len(slices) are handled as if sliced by a nil Slice;
// their extent is whatever SliceDetails reports for a nil Slice (presumably the full dimension — TODO confirm).
// Passing more slices than s has dimensions is an error. Any error from SliceDetails is returned as-is.
//
// Any dimension that ends up with size 1 AND was given an explicit non-nil Slice is dropped from the
// result. Dimensions left at size 1 by a nil or absent Slice are kept. If every dimension is dropped, the
// borrowed slice is handed back via ReturnInts and ScalarShape() is returned.
func (s Shape) S(slices ...Slice) (retVal Shape, err error) {
	opDims := len(s)
	if len(slices) > opDims {
		err = errors.Errorf(dimMismatch, opDims, len(slices))
		return
	}

	// retVal starts as a copy of s (from Clone); each entry is overwritten with its sliced size below.
	retVal = s.Clone()

	for d, size := range s {
		var sl Slice // default is a nil Slice
		if d <= len(slices)-1 {
			sl = slices[d]
		}

		var start, end, step int
		if start, end, step, err = SliceDetails(sl, size); err != nil {
			return
		}

		if step > 0 {
			// NOTE(review): integer division truncates. If end is exclusive, a span that is not a multiple
			// of step (e.g. 0:5:2) gives 2 here although 3 elements are selected. This is meant to mirror
			// AP.S — confirm both agree before changing the rounding.
			retVal[d] = (end - start) / step

			// fix: clamp non-positive sizes to 1. Note this also turns an empty span (start == end)
			// into size 1.
			if retVal[d] <= 0 {
				retVal[d] = 1
			}
		} else {
			// Non-positive step: the raw span end - start is used without dividing by step.
			retVal[d] = (end - start)
		}

	}

	// Drop every dimension whose sliced size is 1 and which was given an explicit non-nil Slice.
	// The last dimension gets no special treatment: it is dropped too if it meets both conditions.
	// offset counts the dimensions removed so far, so offset+d maps the current index d in the
	// shrinking retVal back to the original dimension index into slices.
	offset := 0
	dims := s.Dims()
	for d := 0; d < dims; d++ {
		if retVal[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil {
			retVal = append(retVal[:d], retVal[d+1:]...)
			d-- // re-examine the entry that just shifted into position d
			dims--
			offset++
		}
	}

	// Every dimension was dropped: hand the borrowed slice back and use the shared scalar shape.
	if retVal.IsScalar() {
		ReturnInts(retVal)
		return ScalarShape(), nil
	}

	return
}
// Repeat returns the expected new shape after repeating the elements of a Tensor of shape s along axis.
//
// repeats gives one repetition count for each of the size elements along axis. As a convenience, a
// single count is applied to every one of them. The axis dimension of the new shape is the sum of the
// counts (SumInts).
//
// Special cases:
//   - axis == AllAxes: s is treated as a flat run of s.TotalSize() elements and the result is a vanilla vector.
//   - s is a scalar: axis 0 repeats it into a vanilla vector (n); axis 1 into a row-shaped (1, n).
//   - s is a vanilla vector (x) and axis is 1: s is treated as (x, 1), giving (x, n).
//
// Besides newShape, it returns the per-element counts actually used (finalRepeats) and the number of
// elements along axis (size). An error is returned in two cases:
//   - axis is out of range for s. Negative axes other than AllAxes, and scalar axes other than 0 and 1,
//     previously panicked with index out of range.
//   - len(repeats) is neither 1 nor size.
func (s Shape) Repeat(axis int, repeats ...int) (newShape Shape, finalRepeats []int, size int, err error) {
	switch {
	case axis == AllAxes:
		size = s.TotalSize()
		newShape = Shape{size}
		axis = 0
	case axis < 0:
		// No negative indexing: without this check a negative axis reaches an index expression below and panics.
		err = errors.Errorf(invalidAxis, axis, s.Dims())
		return
	case s.IsScalar():
		size = 1
		switch axis {
		case 0:
			// repeated into a vanilla vector
			newShape = Shape{0}
		case 1:
			// special case for row vecs
			newShape = Shape{1, 0}
		default:
			// A scalar can only grow along axis 0 or 1; any other axis would index past newShape.
			err = errors.Errorf(invalidAxis, axis, s.Dims())
			return
		}
	case s.IsVector() && !s.IsRowVec() && !s.IsColVec() && axis == 1:
		// vanilla vector repeated along axis 1: treat (x) as (x, 1)
		size = 1
		newShape = s.Clone()
		newShape = append(newShape, 1)
	default:
		if axis >= len(s) {
			err = errors.Errorf(invalidAxis, axis, s.Dims())
			return
		}
		size = s[axis]
		newShape = s.Clone()
	}

	// special case to allow generic repeats: one count applies to every element along axis
	if len(repeats) == 1 {
		rep := repeats[0]
		repeats = make([]int, size)
		for i := range repeats {
			repeats[i] = rep
		}
	}

	if reps := len(repeats); reps != size {
		err = errors.Errorf(broadcastError, size, reps)
		return
	}

	newShape[axis] = SumInts(repeats)
	finalRepeats = repeats
	return
}
// Concat returns the expected shape after concatenating s with each shape in ss along axis.
//
// Every shape in ss must have the same number of dimensions as s, and every dimension other than axis
// must match. The result's size along axis is the sum of all the inputs' sizes along it. AllAxes is
// treated as axis 0, and negative axes are rejected. The returned shape is obtained from BorrowInts.
func (s Shape) Concat(axis int, ss ...Shape) (newShape Shape, err error) {
	dims := s.Dims()

	// All operands must agree on rank before any per-dimension comparison makes sense.
	for _, other := range ss {
		if otherDims := other.Dims(); otherDims != dims {
			err = errors.Errorf(dimMismatch, dims, otherDims)
			return
		}
	}

	if axis == AllAxes {
		axis = 0
	}

	// No negative indexing: only 0 <= axis < dims is accepted.
	if axis < 0 || axis >= dims {
		err = errors.Errorf(invalidAxis, axis, len(s))
		return
	}

	newShape = Shape(BorrowInts(dims))
	copy(newShape, s)
	for _, other := range ss {
		for d := 0; d < dims; d++ {
			if d == axis {
				newShape[d] += other[d]
				continue
			}
			if newShape[d] != other[d] {
				err = errors.Wrapf(errors.Errorf(dimMismatch, newShape[d], other[d]), "Axis: %d, dimension it failed at: %d", axis, d)
				return
			}
		}
	}
	return
}
// Format implements fmt.Formatter. The %v and %s verbs render the shape as a parenthesised,
// comma-separated list such as (2, 3), and the empty shape as (). Any other verb falls back to printing
// the underlying []int with %v, e.g. [2 3]. Width, precision and flags are ignored.
func (s Shape) Format(st fmt.State, r rune) {
	if r != 'v' && r != 's' {
		fmt.Fprintf(st, "%v", []int(s))
		return
	}

	st.Write([]byte("("))
	for i, v := range s {
		if i > 0 {
			st.Write([]byte(", "))
		}
		fmt.Fprintf(st, "%d", v)
	}
	st.Write([]byte(")"))
}
|