1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
|
package tensor
import (
"math/rand"
"testing"
"testing/quick"
"time"
"unsafe"
)
// getMutateVal returns, boxed in an interface{}, the value that the Apply
// tests in this file expect every element to hold after the dtype's mutation
// function (see getMutateFn) has been applied. The tests Memset a clone of the
// input to this value and compare it against Apply's result.
//
// Any dtype not listed below yields nil; callers in this file treat a nil
// result as "skip this input".
func getMutateVal(dt Dtype) interface{} {
	var val interface{}
	switch dt {
	// Signed integers.
	case Int:
		val = int(1)
	case Int8:
		val = int8(1)
	case Int16:
		val = int16(1)
	case Int32:
		val = int32(1)
	case Int64:
		val = int64(1)
	// Unsigned integers.
	case Uint:
		val = uint(1)
	case Uint8:
		val = uint8(1)
	case Uint16:
		val = uint16(1)
	case Uint32:
		val = uint32(1)
	case Uint64:
		val = uint64(1)
	// Floating point and complex.
	case Float32:
		val = float32(1)
	case Float64:
		val = float64(1)
	case Complex64:
		val = complex64(1)
	case Complex128:
		val = complex128(1)
	// Non-numeric kinds.
	case Bool:
		val = true
	case String:
		val = "Hello World"
	case Uintptr:
		val = uintptr(0xdeadbeef)
	case UnsafePointer:
		// A fixed, recognizable address (0xdeadbeef) used as the expected value.
		val = unsafe.Pointer(uintptr(0xdeadbeef))
	}
	return val
}
// getMutateFn returns the mutation function associated with dt, boxed in an
// interface{} so it can be passed straight to (*Dense).Apply. It returns nil
// when dt has no mutation function; callers in this file then skip the input.
func getMutateFn(dt Dtype) interface{} {
	var fn interface{}
	switch dt {
	// Signed integers.
	case Int:
		fn = mutateI
	case Int8:
		fn = mutateI8
	case Int16:
		fn = mutateI16
	case Int32:
		fn = mutateI32
	case Int64:
		fn = mutateI64
	// Unsigned integers.
	case Uint:
		fn = mutateU
	case Uint8:
		fn = mutateU8
	case Uint16:
		fn = mutateU16
	case Uint32:
		fn = mutateU32
	case Uint64:
		fn = mutateU64
	// Floating point and complex.
	case Float32:
		fn = mutateF32
	case Float64:
		fn = mutateF64
	case Complex64:
		fn = mutateC64
	case Complex128:
		fn = mutateC128
	// Non-numeric kinds.
	case Bool:
		fn = mutateB
	case String:
		fn = mutateStr
	case Uintptr:
		fn = mutateUintptr
	case UnsafePointer:
		fn = mutateUnsafePointer
	}
	return fn
}
// TestDense_Apply property-tests (*Dense).Apply with default options. For each
// randomly generated *Dense whose dtype has both a mutation value and a
// mutation function, applying the function must produce data equal to a clone
// of the input Memset to the mutation value. It also checks that Apply returns
// an error when handed a function of the wrong type.
//
// The random seed is included in the failure message so a failing run can be
// reproduced.
func TestDense_Apply(t *testing.T) {
	mut := func(q *Dense) bool {
		mutVal := getMutateVal(q.Dtype())
		if mutVal == nil {
			return true // temporarily skip dtypes we cannot get a mutation value for
		}
		fn := getMutateFn(q.Dtype())
		if fn == nil {
			return true // skip dtypes we cannot mutate
		}

		// we: whether an error is expected, as reported by willerr or because
		// the engine does not implement Mapper.
		we, eqFail := willerr(q, nil, nil)
		_, ok := q.Engine().(Mapper)
		we = we || !ok

		a := q.Clone().(*Dense)
		correct := q.Clone().(*Dense)
		correct.Memset(mutVal)

		ret, err := a.Apply(fn)
		if cerr, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly {
			return cerr == nil
		}
		if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) {
			return false
		}

		// Wrong function type / illogical values must be rejected.
		if _, err = a.Apply(getMutateFn); err == nil {
			t.Error("Expected an error")
			return false
		}
		return true
	}

	seed := time.Now().UnixNano()
	r := rand.New(rand.NewSource(seed))
	if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil {
		t.Errorf("Applying mutation function failed (seed %d): %v", seed, err)
	}
}
// TestDense_Apply_unsafe property-tests (*Dense).Apply with the UseUnsafe()
// option. The result data must equal a clone of the input Memset to the
// dtype's mutation value, and the returned tensor must be the receiver itself.
//
// The random seed is included in the failure message so a failing run can be
// reproduced.
func TestDense_Apply_unsafe(t *testing.T) {
	mut := func(q *Dense) bool {
		mutVal := getMutateVal(q.Dtype())
		if mutVal == nil {
			return true // temporarily skip dtypes we cannot get a mutation value for
		}
		fn := getMutateFn(q.Dtype())
		if fn == nil {
			return true // skip dtypes we cannot mutate
		}

		// we: whether an error is expected, as reported by willerr or because
		// the engine does not implement Mapper.
		we, eqFail := willerr(q, nil, nil)
		_, ok := q.Engine().(Mapper)
		we = we || !ok

		a := q.Clone().(*Dense)
		correct := q.Clone().(*Dense)
		correct.Memset(mutVal)

		ret, err := a.Apply(fn, UseUnsafe())
		if cerr, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly {
			return cerr == nil
		}
		if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) {
			return false
		}
		// With UseUnsafe the receiver itself must be returned.
		if ret != a {
			t.Error("Expected ret == a (Unsafe option was used)")
			return false
		}
		return true
	}

	seed := time.Now().UnixNano()
	r := rand.New(rand.NewSource(seed))
	if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil {
		t.Errorf("Applying mutation function failed (seed %d): %v", seed, err)
	}
}
// TestDense_Apply_reuse property-tests (*Dense).Apply with the WithReuse
// option. A zeroed clone of the input is supplied as the reuse tensor. The
// result data must equal a clone of the input Memset to the dtype's mutation
// value, and the returned tensor must be the reuse tensor.
//
// The random seed is included in the failure message so a failing run can be
// reproduced.
func TestDense_Apply_reuse(t *testing.T) {
	mut := func(q *Dense) bool {
		mutVal := getMutateVal(q.Dtype())
		if mutVal == nil {
			return true // temporarily skip dtypes we cannot get a mutation value for
		}
		fn := getMutateFn(q.Dtype())
		if fn == nil {
			return true // skip dtypes we cannot mutate
		}

		// we: whether an error is expected, as reported by willerr or because
		// the engine does not implement Mapper.
		we, eqFail := willerr(q, nil, nil)
		_, ok := q.Engine().(Mapper)
		we = we || !ok

		a := q.Clone().(*Dense)
		reuse := q.Clone().(*Dense)
		reuse.Zero()
		correct := q.Clone().(*Dense)
		correct.Memset(mutVal)

		ret, err := a.Apply(fn, WithReuse(reuse))
		if cerr, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly {
			return cerr == nil
		}
		if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) {
			return false
		}
		// With WithReuse the supplied reuse tensor must be returned.
		if ret != reuse {
			t.Error("Expected ret == reuse (WithReuse option was used)")
			return false
		}
		return true
	}

	seed := time.Now().UnixNano()
	r := rand.New(rand.NewSource(seed))
	if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil {
		t.Errorf("Applying mutation function failed (seed %d): %v", seed, err)
	}
}
|