1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463
|
package tensor
import (
"fmt"
"math"
"reflect"
"unsafe"
"github.com/chewxy/hm"
"github.com/pkg/errors"
)
// Dtype represents the data type of a Tensor. It embeds a reflect.Type, so all
// reflection operations are available directly on a Dtype value. The methods
// declared below additionally allow it to be used as a hm.Type for type
// inference in Gorgonia.
type Dtype struct {
	reflect.Type
}

// Name() and String() are not declared here: they are promoted from the
// embedded reflect.Type.

// Apply returns dt unchanged; the substitutions passed in are ignored.
func (dt Dtype) Apply(_ hm.Subs) hm.Substitutable {
	return dt
}

// FreeTypeVar always returns nil.
func (dt Dtype) FreeTypeVar() hm.TypeVarSet {
	return nil
}

// Normalize returns dt as-is together with a nil error; k and v are ignored.
func (dt Dtype) Normalize(k, v hm.TypeVarSet) (hm.Type, error) {
	return dt, nil
}

// Types always returns nil.
func (dt Dtype) Types() hm.Types {
	return nil
}

// Format writes the name of the Dtype to s. The verb c is ignored, so every
// formatting verb prints the same text.
func (dt Dtype) Format(s fmt.State, c rune) {
	name := dt.Name()
	fmt.Fprintf(s, "%s", name)
}

// Eq reports whether other is equal to dt.
func (dt Dtype) Eq(other hm.Type) bool {
	same := other == dt
	return same
}
var (
	// numpyDtypes maps a Dtype to its Numpy dtype code. It is filled in by init.
	numpyDtypes map[Dtype]string

	// reverseNumpyDtypes maps a Numpy dtype code back to a fixed-width Dtype.
	// It is filled in by init.
	reverseNumpyDtypes map[string]Dtype
)

// init populates the Numpy dtype lookup tables.
//
// Every fixed-width Dtype appears in both tables. Int and Uint only appear in
// the Dtype -> code table: their code is derived from their size, so no
// separate code is listed for them in the reverse table.
func init() {
	fixedWidth := []struct {
		dt   Dtype
		code string
	}{
		{Bool, "b1"},
		{Int8, "i1"},
		{Int16, "i2"},
		{Int32, "i4"},
		{Int64, "i8"},
		{Uint8, "u1"},
		{Uint16, "u2"},
		{Uint32, "u4"},
		{Uint64, "u8"},
		{Float32, "f4"},
		{Float64, "f8"},
		{Complex64, "c8"},
		{Complex128, "c16"},
	}

	numpyDtypes = make(map[Dtype]string, len(fixedWidth)+2)
	reverseNumpyDtypes = make(map[string]Dtype, len(fixedWidth))
	for _, entry := range fixedWidth {
		numpyDtypes[entry.dt] = entry.code
		reverseNumpyDtypes[entry.code] = entry.dt
	}

	// The codes of the native-width integers depend on their size in bytes.
	numpyDtypes[Int] = fmt.Sprintf("i%d", Int.Size())
	numpyDtypes[Uint] = fmt.Sprintf("u%d", Uint.Size())
}
// numpyDtype returns the Numpy dtype code equivalent to dt. This is
// predominantly used when converting a Tensor to a Numpy ndarray.
//
// Not every Dtype is supported: for a Dtype that has no entry in numpyDtypes
// the code "v" is returned together with a non-nil error.
func (dt Dtype) numpyDtype() (string, error) {
	if code, found := numpyDtypes[dt]; found {
		return code, nil
	}
	return "v", errors.Errorf("Unsupported Dtype conversion to Numpy Dtype: %v", dt)
}
// fromNumpyDtype converts a Numpy dtype code (for example "f8") into a Dtype.
//
// A code that is not present in reverseNumpyDtypes yields the zero Dtype and a
// non-nil error. The 4- and 8-byte integer codes resolve to Int / Uint instead
// of the fixed-width Dtype whenever they match the size of Int / Uint.
func fromNumpyDtype(t string) (Dtype, error) {
	dt, known := reverseNumpyDtypes[t]
	if !known {
		return Dtype{}, errors.Errorf("Unsupported Dtype conversion from %q to Dtype", t)
	}

	// Prefer the native-width integer Dtype when the code has the same width.
	switch t {
	case "i4":
		if Int.Size() == 4 {
			return Int, nil
		}
	case "i8":
		if Int.Size() == 8 {
			return Int, nil
		}
	case "u4":
		if Uint.Size() == 4 {
			return Uint, nil
		}
	case "u8":
		if Uint.Size() == 8 {
			return Uint, nil
		}
	}
	return dt, nil
}
// typeclass is a named set of Dtypes.
type typeclass struct {
	// name is the human readable name of the typeclass; typeclassCheck uses it in its error message.
	name string
	// set lists the Dtypes that are members of the typeclass.
	set []Dtype
}
var parameterizedKinds = [...]reflect.Kind{
reflect.Array,
reflect.Chan,
reflect.Func,
reflect.Interface,
reflect.Map,
reflect.Ptr,
reflect.Slice,
reflect.Struct,
}
func isParameterizedKind(k reflect.Kind) bool {
for _, v := range parameterizedKinds {
if v == k {
return true
}
}
return false
}
// The Dtypes provided by this package.
//
// They are package-level variables and therefore mutable; oh how nice it'd be
// if they could be made immutable.
var (
	Bool       = Dtype{Type: reflect.TypeOf(false)}
	Int        = Dtype{Type: reflect.TypeOf(int(0))}
	Int8       = Dtype{Type: reflect.TypeOf(int8(0))}
	Int16      = Dtype{Type: reflect.TypeOf(int16(0))}
	Int32      = Dtype{Type: reflect.TypeOf(int32(0))}
	Int64      = Dtype{Type: reflect.TypeOf(int64(0))}
	Uint       = Dtype{Type: reflect.TypeOf(uint(0))}
	Uint8      = Dtype{Type: reflect.TypeOf(uint8(0))}
	Uint16     = Dtype{Type: reflect.TypeOf(uint16(0))}
	Uint32     = Dtype{Type: reflect.TypeOf(uint32(0))}
	Uint64     = Dtype{Type: reflect.TypeOf(uint64(0))}
	Float32    = Dtype{Type: reflect.TypeOf(float32(0))}
	Float64    = Dtype{Type: reflect.TypeOf(float64(0))}
	Complex64  = Dtype{Type: reflect.TypeOf(complex64(0))}
	Complex128 = Dtype{Type: reflect.TypeOf(complex128(0))}
	String     = Dtype{Type: reflect.TypeOf(string(""))}

	// aliases

	// Byte is the same Dtype as Uint8.
	Byte = Uint8

	// extras

	Uintptr       = Dtype{Type: reflect.TypeOf(uintptr(0))}
	UnsafePointer = Dtype{Type: reflect.TypeOf(unsafe.Pointer(nil))}
)
// The typeclasses known to this package. Each one is a named set of Dtypes.
var (
	// allTypes is used for indexing: dtypeID returns the position of a Dtype in this set.
	allTypes = &typeclass{
		name: "τ",
		set: []Dtype{
			Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer,
		},
	}

	// specializedTypes indicates that there is specialized code generated for these types.
	specializedTypes = &typeclass{
		name: "Specialized",
		set: []Dtype{
			Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String,
		},
	}

	// addableTypes holds the members of the "Addable" typeclass.
	addableTypes = &typeclass{
		name: "Addable",
		set: []Dtype{
			Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String,
		},
	}

	// numberTypes holds the members of the "Number" typeclass. RegisterNumber appends to it.
	numberTypes = &typeclass{
		name: "Number",
		set: []Dtype{
			Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128,
		},
	}

	// ordTypes holds the members of the "Ord" typeclass. RegisterOrd appends to it.
	ordTypes = &typeclass{
		name: "Ord",
		set: []Dtype{
			Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String,
		},
	}

	// eqTypes holds the members of the "Eq" typeclass. RegisterEq appends to it.
	eqTypes = &typeclass{
		name: "Eq",
		set: []Dtype{
			Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer,
		},
	}

	// unsignedTypes holds the unsigned integer Dtypes.
	unsignedTypes = &typeclass{
		name: "Unsigned",
		set: []Dtype{
			Uint, Uint8, Uint16, Uint32, Uint64,
		},
	}

	// signedTypes holds the signed integer, float and complex Dtypes.
	signedTypes = &typeclass{
		name: "Signed",
		set: []Dtype{
			Int, Int8, Int16, Int32, Int64, Float32, Float64, Complex64, Complex128,
		},
	}

	// signedNonComplexTypes is ever only used by Sub tests.
	signedNonComplexTypes = &typeclass{
		name: "Signed NonComplex",
		set: []Dtype{
			Int, Int8, Int16, Int32, Int64, Float32, Float64,
		},
	}

	// floatTypes holds the float Dtypes. RegisterFloat appends to it.
	floatTypes = &typeclass{
		name: "Float",
		set: []Dtype{
			Float32, Float64,
		},
	}

	// complexTypes holds the complex Dtypes.
	complexTypes = &typeclass{
		name: "Complex Numbers",
		set: []Dtype{
			Complex64, Complex128,
		},
	}

	// floatcmplxTypes holds the float and complex Dtypes.
	// NOTE(review): the typeclass is named "Real" although the set contains complex types - confirm this is intended.
	floatcmplxTypes = &typeclass{
		name: "Real",
		set: []Dtype{
			Float32, Float64, Complex64, Complex128,
		},
	}

	// nonComplexNumberTypes holds the integer and float Dtypes.
	nonComplexNumberTypes = &typeclass{
		name: "Non complex numbers",
		set: []Dtype{
			Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64,
		},
	}

	// generatableTypes is ever only used by Pow tests.
	generatableTypes = &typeclass{
		name: "Generatable types",
		set: []Dtype{
			Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String,
		},
	}
)
// isFloat reports whether dt is Float32 or Float64.
func isFloat(dt Dtype) bool {
	switch dt {
	case Float64, Float32:
		return true
	}
	return false
}
// typeclassCheck verifies that a is a member of the typeclass tc.
//
// It returns nil when tc is nil or when a is found in tc.set, and an error
// naming the Dtype and the typeclass otherwise.
func typeclassCheck(a Dtype, tc *typeclass) error {
	if tc == nil {
		return nil
	}

	member := false
	for i := 0; i < len(tc.set) && !member; i++ {
		member = tc.set[i] == a
	}
	if member {
		return nil
	}
	return errors.Errorf("Type %v is not a member of %v", a, tc.name)
}
// RegisterNumber registers a as a new numerical Dtype.
// This package provides the following numerical Dtypes:
//	Int
//	Int8
//	Int16
//	Int32
//	Int64
//	Uint
//	Uint8
//	Uint16
//	Uint32
//	Uint64
//	Float32
//	Float64
//	Complex64
//	Complex128
//
// If a is already on the list of numerical Dtypes, nothing happens. Otherwise
// it is appended to the list and additionally passed to RegisterEq.
func RegisterNumber(a Dtype) {
	for i := range numberTypes.set {
		if numberTypes.set[i] == a {
			// already registered
			return
		}
	}

	numberTypes.set = append(numberTypes.set, a)
	RegisterEq(a)
}
// RegisterFloat registers a as a float Dtype.
//
// If a is already on the list of float Dtypes, nothing happens. Otherwise it
// is appended to the list and additionally passed to RegisterNumber and
// RegisterOrd.
func RegisterFloat(a Dtype) {
	known := false
	for _, member := range floatTypes.set {
		if member == a {
			known = true
			break
		}
	}
	if known {
		return
	}

	floatTypes.set = append(floatTypes.set, a)
	RegisterNumber(a)
	RegisterOrd(a)
}
// RegisterOrd registers a as a member of the Ord typeclass, i.e. a Dtype that
// can be ordered.
//
// If a is already a member, nothing happens. Otherwise it is appended to the
// typeclass and additionally passed to RegisterEq.
func RegisterOrd(a Dtype) {
	for i := range ordTypes.set {
		if ordTypes.set[i] == a {
			// already registered
			return
		}
	}

	ordTypes.set = append(ordTypes.set, a)
	RegisterEq(a)
}
// RegisterEq registers a as a Dtype that can be compared for equality.
//
// If a is already a member of the Eq typeclass, nothing happens. Otherwise it
// is appended to the typeclass and additionally passed to Register.
func RegisterEq(a Dtype) {
	known := false
	for _, member := range eqTypes.set {
		if member == a {
			known = true
			break
		}
	}
	if known {
		return
	}

	eqTypes.set = append(eqTypes.set, a)
	Register(a)
}
// Register registers a new Dtype by appending it to the list of all types.
// A Dtype that is already on the list is not added a second time.
func Register(a Dtype) {
	for i := range allTypes.set {
		if allTypes.set[i] == a {
			// already registered
			return
		}
	}

	allTypes.set = append(allTypes.set, a)
}
// dtypeID returns the index of a within allTypes.set, or -1 when a is not in
// that set.
func dtypeID(a Dtype) int {
	id := -1
	for i := range allTypes.set {
		if allTypes.set[i] == a {
			id = i
			break
		}
	}
	return id
}
// NormOrder represents the order of the norm. Ideally, we'd only represent norms with a uint/byte.
// But there are norm types that are outside numerical types, such as the nuclear norm and the Frobenius norm.
// So it is internally represented by a float64. If Go could use NaN and Inf as consts, it would have been best.
// Instead, we use constructors. The unordered, Frobenius and nuclear norm types are represented as NaNs,
// each with its own bit pattern.
//
// The use of NaN and Inf as "special" norm types leads to the need for the IsInf(), IsUnordered(),
// IsFrobenius() and IsNuclear() methods: NaNs never compare equal, so the special orders have to be
// recognized by their bit patterns.
type NormOrder float64

// Norm returns the NormOrder for the integer order ord.
func Norm(ord int) NormOrder { return NormOrder(float64(ord)) }

// InfNorm returns the NormOrder represented by positive infinity.
func InfNorm() NormOrder { return NormOrder(math.Inf(1)) }

// NegInfNorm returns the NormOrder represented by negative infinity.
func NegInfNorm() NormOrder { return NormOrder(math.Inf(-1)) }

// UnorderedNorm returns the NormOrder that marks an unordered norm. It is a NaN with payload 1.
func UnorderedNorm() NormOrder { return NormOrder(math.Float64frombits(0x7ff8000000000001)) }

// FrobeniusNorm returns the NormOrder that marks the Frobenius norm. It is a NaN with payload 2.
func FrobeniusNorm() NormOrder { return NormOrder(math.Float64frombits(0x7ff8000000000002)) }

// NuclearNorm returns the NormOrder that marks the nuclear norm. It is a NaN with payload 3.
func NuclearNorm() NormOrder { return NormOrder(math.Float64frombits(0x7ff8000000000003)) }

// Valid is a helper method that determines if the norm order is valid. A valid norm order is
// one of the special orders (unordered, Frobenius, nuclear, +Inf, -Inf) or a number whose
// fraction component is 0. Any other NaN is invalid.
func (n NormOrder) Valid() bool {
	f := float64(n)
	switch {
	case math.IsNaN(f):
		// Only the three NaN bit patterns handed out by the constructors are valid.
		return n.IsUnordered() || n.IsFrobenius() || n.IsNuclear()
	case math.IsInf(f, 0):
		return true
	default:
		_, frac := math.Modf(f)
		return frac == 0.0
	}
}

// IsUnordered returns true if the NormOrder is not an ordered norm
func (n NormOrder) IsUnordered() bool {
	return math.Float64bits(float64(n)) == math.Float64bits(float64(UnorderedNorm()))
}

// IsFrobenius returns true if the NormOrder is a Frobenius norm
func (n NormOrder) IsFrobenius() bool {
	return math.Float64bits(float64(n)) == math.Float64bits(float64(FrobeniusNorm()))
}

// IsNuclear returns true if the NormOrder is a nuclear norm
func (n NormOrder) IsNuclear() bool {
	return math.Float64bits(float64(n)) == math.Float64bits(float64(NuclearNorm()))
}

// IsInf reports whether the NormOrder is an infinity, according to sign, with the same
// convention as math.IsInf: sign > 0 checks for +Inf, sign < 0 for -Inf and sign == 0 for either.
func (n NormOrder) IsInf(sign int) bool {
	return math.IsInf(float64(n), sign)
}

// String returns a human readable name for the special norm orders and "Norm <value>" for
// everything else.
func (n NormOrder) String() string {
	switch {
	case n.IsUnordered():
		return "Unordered"
	case n.IsFrobenius():
		return "Frobenius"
	case n.IsNuclear():
		return "Nuclear"
	case n.IsInf(1):
		return "+Inf"
	case n.IsInf(-1):
		return "-Inf"
	default:
		return fmt.Sprintf("Norm %v", float64(n))
	}
}
// FuncOpt is an optional argument for calling a Tensor function. It is a function that
// modifies the *OpOpt it is given.
type FuncOpt func(*OpOpt)
// WithIncr passes in a Tensor to be incremented. The returned FuncOpt stores
// incr in the incr field of the OpOpt it is applied to.
func WithIncr(incr Tensor) FuncOpt {
	return func(o *OpOpt) {
		o.incr = incr
	}
}
// WithReuse passes in a Tensor to be reused. The returned FuncOpt stores reuse
// in the reuse field of the OpOpt it is applied to.
func WithReuse(reuse Tensor) FuncOpt {
	return func(o *OpOpt) {
		o.reuse = reuse
	}
}
// UseSafe ensures that the operation is a safe operation (copies data, does
// not clobber). This is the default option for most methods and functions.
// The returned FuncOpt clears the unsafe flag of the OpOpt it is applied to.
func UseSafe() FuncOpt {
	return func(o *OpOpt) {
		o.unsafe = false
	}
}
// UseUnsafe ensures that the operation is an unsafe operation - data will be
// clobbered, and operations performed in place. The returned FuncOpt sets the
// unsafe flag of the OpOpt it is applied to.
func UseUnsafe() FuncOpt {
	return func(o *OpOpt) {
		o.unsafe = true
	}
}
// AsSameType makes sure that the returned Tensor is the same type as the input
// Tensors. The returned FuncOpt sets the same flag of the OpOpt it is applied
// to.
func AsSameType() FuncOpt {
	return func(o *OpOpt) {
		o.same = true
	}
}
// As makes sure that the returned Tensor is of the type specified. Currently
// only works for FromMat64. The returned FuncOpt stores t in the t field of
// the OpOpt it is applied to.
func As(t Dtype) FuncOpt {
	return func(o *OpOpt) {
		o.t = t
	}
}
|