1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
|
package main
import (
"fmt"
"io"
"text/template"
)
// checkNativeSelectable is emitted verbatim (it is not a template) by
// generateNativeSelect, once per generated file. The emitted helper rejects
// tensors the Select<Short> functions cannot handle:
//   - data that is not natively accessible,
//   - an axis outside [0, dims), except that a scalar accepts axis 0,
//   - column-major or iterator-requiring layouts,
//   - a dtype other than dt.
//
// The explicit axis < 0 bound matters: the upper-bound check alone lets
// negative axes through, and the generated Select<Short> would then index
// t.Strides() / slice t.Shape() with them. That is an out-of-range panic
// rather than an error.
const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt Dtype) error {
	if !t.IsNativelyAccessible() {
		return errors.New("Cannot select on non-natively accessible data")
	}
	if (axis < 0 || axis >= t.Shape().Dims()) && !(t.IsScalar() && axis == 0) {
		return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape())
	}
	if t.F() || t.RequiresIterator() {
		return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices")
	}
	if t.Dtype() != dt {
		return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype())
	}
	return nil
}
`
// nativeSelectRaw is the text/template source for one Select<Short> function
// per kind. The generated function splits the slice returned by
// t.<sliceOf> into rows along axis and returns those rows as sub-slices
// (no element copying).
//
// Each row is produced with a full slice expression (data[lo:hi:hi]) rather
// than by writing a reflect.SliceHeader through unsafe. The result is the
// same: len == cap == stride, sharing data's backing array. The slice
// expression is bounds-checked and avoids the deprecated SliceHeader type.
// NOTE(review): the generated code no longer references reflect or unsafe;
// confirm the generated file's import block (importUnqualifiedTensor and any
// goimports pass) does not leave them as unused imports.
const nativeSelectRaw = `// Select{{short .}} creates a slice of flat data types. See Example of NativeSelectF64.
func Select{{short .}}(t *Dense, axis int) (retVal [][]{{asType .}}, err error) {
	if err := checkNativeSelectable(t, axis, {{reflectKind .}}); err != nil {
		return nil, err
	}

	switch t.Shape().Dims() {
	case 0, 1:
		retVal = make([][]{{asType .}}, 1)
		retVal[0] = t.{{sliceOf .}}
	case 2:
		if axis == 0 {
			return Matrix{{short .}}(t)
		}
		fallthrough
	default:
		// checkNativeSelectable rejects colmajor/unpacked layouts, so row r
		// along axis is the contiguous run data[r*stride : (r+1)*stride].
		data := t.{{sliceOf .}}
		stride := t.Strides()[axis]
		upper := ProdInts(t.Shape()[:axis+1])
		retVal = make([][]{{asType .}}, 0, upper)
		for r := 0; r < upper; r++ {
			start := r * stride
			retVal = append(retVal, data[start:start+stride:start+stride])
		}
		return retVal, nil
	}
	return
}
`
// nativeSelectTestRaw is the text/template source for one TestSelect<Short>
// test per kind; generateNativeSelectTests executes it via NativeSelectTest.
// Each generated test builds tensors with New. It checks only the number of
// rows returned by Select<Short> and the length of the first row:
//   - shape (2, 3, 4, 5): axis 1 -> 6 rows of 20, axis 0 -> 2 rows of 60,
//     axis 3 -> 120 rows of 1;
//   - shape (2, 3): axis 0 -> 2 rows of 3, axis 1 -> 6 rows of 1;
//   - a zero-valued scalar: axis 0 -> 1 row of 1, and axis 10 must error.
//
// The scalar's zero value is spelled per kind: false for bool, "" for
// string, and <asType>(0) otherwise. The template's dot value has a String
// method and is fed to short/asType/reflectKind, so it is presumably a
// reflect.Kind (confirm against the caller).
const nativeSelectTestRaw = `func TestSelect{{short .}}(t *testing.T) {
assert := assert.New(t)
var T *Dense
var err error
var x [][]{{asType .}}
T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), )
if x, err = Select{{short .}}(T, 1); err != nil {
t.Fatal(err)
}
assert.Equal(6, len(x))
assert.Equal(20, len(x[0]))
T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), )
if x, err = Select{{short .}}(T, 0); err != nil {
t.Fatal(err)
}
assert.Equal(2, len(x))
assert.Equal(60, len(x[0]))
T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), )
if x, err = Select{{short .}}(T, 3); err != nil {
t.Fatal(err)
}
assert.Equal(120, len(x))
assert.Equal(1, len(x[0]))
T = New(Of({{reflectKind .}}), WithShape(2, 3), )
if x, err = Select{{short .}}(T, 0); err != nil {
t.Fatal(err)
}
assert.Equal(2, len(x))
assert.Equal(3, len(x[0]))
T = New(Of({{reflectKind .}}), WithShape(2, 3), )
if x, err = Select{{short .}}(T, 1); err != nil {
t.Fatal(err)
}
assert.Equal(6, len(x))
assert.Equal(1, len(x[0]))
T = New(FromScalar({{if eq .String "bool" -}}false{{else if eq .String "string" -}}""{{else -}}{{asType .}}(0) {{end -}} ))
if x, err = Select{{short .}}(T, 0); err != nil {
t.Fatal(err)
}
assert.Equal(1, len(x))
assert.Equal(1, len(x[0]))
if _, err = Select{{short .}}(T, 10); err == nil{
t.Fatal("Expected errors")
}
}
`
// NativeSelect renders one Select<Short> function per kind. It is nil until
// package initialization parses nativeSelectRaw into it.
var NativeSelect *template.Template

// NativeSelectTest renders one TestSelect<Short> test per kind. It is nil
// until package initialization parses nativeSelectTestRaw into it.
var NativeSelectTest *template.Template
// init parses the Select and Select-test templates, registering the shared
// funcs helpers on each. A parse error panics via template.Must.
func init() {
	selectTmpl := template.New("NativeSelect").Funcs(funcs)
	NativeSelect = template.Must(selectTmpl.Parse(nativeSelectRaw))

	selectTestTmpl := template.New("NativeSelectTest").Funcs(funcs)
	NativeSelectTest = template.Must(selectTestTmpl.Parse(nativeSelectTestRaw))
}
// generateNativeSelect writes, in order:
//   - the importUnqualifiedTensor header,
//   - the checkNativeSelectable helper,
//   - one Select<Short> function (rendered from NativeSelect) for each kind
//     in ak that passes isSpecialized.
//
// A template execution error panics. The output written so far would likely
// be truncated and uncompilable, so this code generator fails loudly instead
// of silently producing a broken file.
func generateNativeSelect(f io.Writer, ak Kinds) {
	// Fprint, not Fprintf: the header is data, not a format string, and any
	// '%' in it must be written literally.
	fmt.Fprint(f, importUnqualifiedTensor)
	fmt.Fprintln(f, checkNativeSelectable)
	for _, k := range filter(ak.Kinds, isSpecialized) {
		fmt.Fprintf(f, "/* Native Select for %v */\n\n", k)
		if err := NativeSelect.Execute(f, k); err != nil {
			panic(fmt.Errorf("executing NativeSelect template for %v: %w", k, err))
		}
		fmt.Fprint(f, "\n\n")
	}
}
// generateNativeSelectTests writes the importUnqualifiedTensor header. It then
// writes one TestSelect<Short> test (rendered from NativeSelectTest) for each
// kind in ak that passes isSpecialized.
//
// A template execution error panics. The test file written so far would
// likely be truncated and uncompilable, so the generator fails loudly rather
// than dropping the error.
func generateNativeSelectTests(f io.Writer, ak Kinds) {
	// Fprint, not Fprintf: the header is data, not a format string.
	fmt.Fprint(f, importUnqualifiedTensor)
	for _, k := range filter(ak.Kinds, isSpecialized) {
		if err := NativeSelectTest.Execute(f, k); err != nil {
			panic(fmt.Errorf("executing NativeSelectTest template for %v: %w", k, err))
		}
		fmt.Fprint(f, "\n\n")
	}
}
|