1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
|
package main
import (
"io"
"text/template"
)
const testDenseReduceRaw = `var denseReductionTests = []struct {
of Dtype
fn interface{}
def interface{}
axis int
correct interface{}
correctShape Shape
}{
{{range .Kinds -}}
{{if isNumber . -}}
// {{.}}
{ {{asType . | title}}, execution.Add{{short .}}, {{asType .}}(0), 0, []{{asType .}}{6, 8, 10, 12, 14, 16}, Shape{3,2} },
{ {{asType . | title}}, execution.Add{{short .}}, {{asType .}}(0), 1, []{{asType .}}{6, 9, 24, 27}, Shape{2, 2}},
{ {{asType . | title}}, execution.Add{{short .}}, {{asType .}}(0), 2, []{{asType .}}{1, 5, 9, 13, 17, 21}, Shape{2, 3}},
{{end -}}
{{end -}}
}
func TestDense_Reduce(t *testing.T){
assert := assert.New(t)
for _, drt := range denseReductionTests {
T := New(WithShape(2,3,2), WithBacking(Range(drt.of, 0, 2*3*2)))
T2, err := T.Reduce(drt.fn, drt.axis, drt.def, )
if err != nil {
t.Error(err)
continue
}
assert.True(drt.correctShape.Eq(T2.Shape()))
assert.Equal(drt.correct, T2.Data())
// stupids:
_, err = T.Reduce(drt.fn, 1000, drt.def,)
assert.NotNil(err)
// wrong function type
var f interface{}
f = func(a, b float64)float64{return 0}
if drt.of == Float64 {
f = func(a, b int)int{return 0}
}
_, err = T.Reduce(f, 0, drt.correct)
assert.NotNil(err)
// wrong default value type
var def2 interface{}
def2 = 3.14
if drt.of == Float64 {
def2 = int(1)
}
_, err = T.Reduce(drt.fn, 3, def2) // only last axis requires a default value
assert.NotNil(err)
}
}
`
var (
testDenseReduce *template.Template
)
func init() {
testDenseReduce = template.Must(template.New("testDenseReduce").Funcs(funcs).Parse(testDenseReduceRaw))
}
func generateDenseReductionTests(f io.Writer, generic Kinds) {
testDenseReduce.Execute(f, generic)
}
|