1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
package tensor
import (
"math/rand"
"testing"
"testing/quick"
"time"
"github.com/stretchr/testify/assert"
)
func TestDense_ShallowClone(t *testing.T) {
T := New(Of(Float64), WithBacking([]float64{1, 2, 3, 4}))
T2 := T.ShallowClone()
T2.slice(0, 2)
T2.Float64s()[0] = 1000
assert.Equal(t, T.Data().([]float64)[0:2], T2.Data())
assert.Equal(t, T.Engine(), T2.Engine())
assert.Equal(t, T.oe, T2.oe)
assert.Equal(t, T.flag, T2.flag)
}
func TestDense_Clone(t *testing.T) {
assert := assert.New(t)
cloneChk := func(q *Dense) bool {
a := q.Clone().(*Dense)
if !q.Shape().Eq(a.Shape()) {
t.Errorf("Shape Difference: %v %v", q.Shape(), a.Shape())
return false
}
if len(q.Strides()) != len(a.Strides()) {
t.Errorf("Stride Difference: %v %v", q.Strides(), a.Strides())
return false
}
for i, s := range q.Strides() {
if a.Strides()[i] != s {
t.Errorf("Stride Difference: %v %v", q.Strides(), a.Strides())
return false
}
}
if q.o != a.o {
t.Errorf("Data Order difference : %v %v", q.o, a.o)
return false
}
if q.Δ != a.Δ {
t.Errorf("Triangle Difference: %v %v", q.Δ, a.Δ)
return false
}
if q.flag != a.flag {
t.Errorf("Flag difference : %v %v", q.flag, a.flag)
return false
}
if q.e != a.e {
t.Errorf("Engine difference; %T %T", q.e, a.e)
return false
}
if q.oe != a.oe {
t.Errorf("Optimized Engine difference; %T %T", q.oe, a.oe)
return false
}
if len(q.transposeWith) != len(a.transposeWith) {
t.Errorf("TransposeWith difference: %v %v", q.transposeWith, a.transposeWith)
return false
}
assert.Equal(q.mask, a.mask, "mask difference")
assert.Equal(q.maskIsSoft, a.maskIsSoft, "mask is soft ")
return true
}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
if err := quick.Check(cloneChk, &quick.Config{Rand: r}); err != nil {
t.Error(err)
}
}
func TestDenseMasked(t *testing.T) {
T := New(Of(Float64), WithShape(3, 2))
T.ResetMask()
assert.Equal(t, []bool{false, false, false, false, false, false}, T.mask)
}
func TestFromScalar(t *testing.T) {
T := New(FromScalar(3.14))
data := T.Float64s()
assert.Equal(t, []float64{3.14}, data)
}
func Test_recycledDense(t *testing.T) {
T := recycledDense(Float64, ScalarShape())
assert.Equal(t, float64(0), T.Data())
assert.Equal(t, StdEng{}, T.e)
assert.Equal(t, StdEng{}, T.oe)
}
func TestDense_unsqueeze(t *testing.T) {
assert := assert.New(t)
T := New(WithShape(3, 3, 2), WithBacking([]float64{
1, 2, 3, 4, 5, 6,
60, 50, 40, 30, 20, 10,
100, 200, 300, 400, 500, 600,
}))
if err := T.unsqueeze(0); err != nil {
t.Fatal(err)
}
assert.True(T.Shape().Eq(Shape{1, 3, 3, 2}))
assert.Equal([]int{6, 6, 2, 1}, T.Strides()) // if you do shapes.CalcStrides() it'd be {18,6,2,1}
// reset
T.Reshape(3, 3, 2)
if err := T.unsqueeze(1); err != nil {
t.Fatal(err)
}
assert.True(T.Shape().Eq(Shape{3, 1, 3, 2}))
assert.Equal([]int{6, 2, 2, 1}, T.Strides())
// reset
T.Reshape(3, 3, 2)
if err := T.unsqueeze(2); err != nil {
t.Fatal(err)
}
t.Logf("%v", T)
assert.True(T.Shape().Eq(Shape{3, 3, 1, 2}))
assert.Equal([]int{6, 2, 1, 1}, T.Strides())
// reset
T.Reshape(3, 3, 2)
if err := T.unsqueeze(3); err != nil {
t.Fatal(err)
}
t.Logf("%v", T)
assert.True(T.Shape().Eq(Shape{3, 3, 2, 1}))
assert.Equal([]int{6, 2, 1, 1}, T.Strides())
}
|