File: type_test.go

package info (click to toggle)
golang-github-gorgonia-tensor 0.9.24-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,696 kB
  • sloc: sh: 18; asm: 18; makefile: 8
file content (66 lines) | stat: -rw-r--r-- 1,699 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package tensor

import (
	"reflect"
	"testing"
)

// Float16 is a uint16-backed type declared for the tests in this file. They
// wrap its reflect.Type in a Dtype to exercise type registration and to probe
// numpy dtype conversion with a type the test expects to be rejected.
type Float16 uint16

// TestRegisterType registers a Dtype built from Float16 through RegisterFloat
// and then uses typeclassCheck to verify that the Dtype is accepted by
// floatTypes, numberTypes, ordTypes and eqTypes.
func TestRegisterType(t *testing.T) {
	half := Dtype{reflect.TypeOf(Float16(0))}
	RegisterFloat(half)

	// Float typeclass.
	err := typeclassCheck(half, floatTypes)
	if err != nil {
		t.Errorf("Expected %v to be in floatTypes: %v", half, err)
	}

	// Number typeclass.
	err = typeclassCheck(half, numberTypes)
	if err != nil {
		t.Errorf("Expected %v to be in numberTypes: %v", half, err)
	}

	// Ord typeclass.
	err = typeclassCheck(half, ordTypes)
	if err != nil {
		t.Errorf("Expected %v to be in ordTypes: %v", half, err)
	}

	// Eq typeclass.
	err = typeclassCheck(half, eqTypes)
	if err != nil {
		t.Errorf("Expected %v to be in eqTypes: %v", half, err)
	}
}

// TestDtypeConversions checks the conversion between Dtype values and numpy
// dtype strings in both directions, using reverseNumpyDtypes and numpyDtypes
// as the tables of expected pairs. It also checks that each conversion
// function returns an error for an input the test expects to be unsupported.
func TestDtypeConversions(t *testing.T) {
	// Dtype -> numpy string: each value of reverseNumpyDtypes must convert
	// back to the key it is stored under.
	for want, dtype := range reverseNumpyDtypes {
		got, err := dtype.numpyDtype()
		// Check the error first: when the conversion fails the returned string
		// carries no meaning, so report the failure itself rather than a
		// misleading mismatch.
		if err != nil {
			t.Errorf("Error: %v", err)
			continue
		}
		if got != want {
			t.Errorf("Expected %v to return numpy dtype of %q. Got %q instead", dtype, want, got)
		}
	}

	// A Dtype wrapping the test-only Float16 type is expected to be rejected.
	unknown := Dtype{reflect.TypeOf(Float16(0))}
	if _, err := unknown.numpyDtype(); err == nil {
		t.Error("Expected an error when passing in type unknown to np")
	}

	// numpy string -> Dtype: each value of numpyDtypes must convert back to
	// the key it is stored under, modulo the special cases below.
	for want, npName := range numpyDtypes {
		got, err := fromNumpyDtype(npName)
		if err != nil {
			t.Errorf("Error: %v", err)
			continue
		}
		if got == want {
			continue
		}

		// special cases
		// NOTE(review): presumably fromNumpyDtype reports the platform-sized
		// Int/Uint for the numpy name of matching width, so a result of
		// Int/Uint is accepted even though the table key is a different
		// Dtype - confirm against fromNumpyDtype.
		if Int.Size() == 4 && npName == "i4" && got == Int {
			continue
		}
		if Int.Size() == 8 && npName == "i8" && got == Int {
			continue
		}

		if Uint.Size() == 4 && npName == "u4" && got == Uint {
			continue
		}
		if Uint.Size() == 8 && npName == "u8" && got == Uint {
			continue
		}
		t.Errorf("Expected %q to return %v. Got %v instead", npName, want, got)
	}

	// A string that is not a numpy dtype name must be rejected.
	if _, err := fromNumpyDtype("EDIUH"); err == nil {
		t.Error("Expected error when nonsense is passed into fromNumpyDtype")
	}
}