1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
package main
import (
"io"
"reflect"
"text/template"
)
// Signature describes one generated Go function signature: a templated name,
// named parameters and named results whose types are rendered from templates,
// and an optional trailing `err error` result. Write renders it.
type Signature struct {
	// Name is visible to NameTemplate, which Write executes with the whole
	// Signature as its data (so the template may also use .Kind).
	Name         string
	NameTemplate *template.Template

	// ParamNames[i] is written before the output of ParamTemplates[i].
	// ParamNames must be at least as long as ParamTemplates.
	ParamNames     []string
	ParamTemplates []*template.Template

	// RetVals[i] names the result whose type RetValTemplates[i] renders.
	// A non-empty RetVals enables the result list; RetVals must be at least
	// as long as RetValTemplates.
	RetVals         []string
	RetValTemplates []*template.Template

	// Kind is the data passed to every parameter and result template.
	Kind reflect.Kind

	// Err adds a named `err error` result after any other results.
	Err bool
}

// Write renders s to w in the form
//
//	<name>(<p0> <T0>, <p1> <T1>)(<r0> <R0>, err error)
//
// The name comes from NameTemplate executed with s; each parameter or result
// type comes from its template executed with s.Kind. Entries are separated by
// ", " with no trailing separator. The result list is written only when
// RetVals is non-empty or Err is set; with Err alone it is "(err error)".
//
// Write has no error return: errors from w and from template execution are
// ignored.
func (s *Signature) Write(w io.Writer) {
	// writeList writes "name type" pairs. Separators go between entries only,
	// so no trailing ", " is emitted after the last one.
	writeList := func(names []string, types []*template.Template) {
		for i, tmpl := range types {
			if i > 0 {
				io.WriteString(w, ", ")
			}
			io.WriteString(w, names[i])
			io.WriteString(w, " ")
			tmpl.Execute(w, s.Kind)
		}
	}

	s.NameTemplate.Execute(w, s)
	io.WriteString(w, "(")
	writeList(s.ParamNames, s.ParamTemplates)
	io.WriteString(w, ")")

	switch {
	case len(s.RetVals) > 0:
		io.WriteString(w, "(")
		writeList(s.RetVals, s.RetValTemplates)
		if s.Err {
			// Only separate `err error` from results that were actually written.
			if len(s.RetValTemplates) > 0 {
				io.WriteString(w, ", ")
			}
			io.WriteString(w, "err error")
		}
		io.WriteString(w, ")")
	case s.Err:
		io.WriteString(w, "(err error)")
	}
}
// Raw templates for the rendered function name. They reference .Name and
// .Kind (the fields of the same name on Signature) and call the template
// functions short, vecPkg and getalias, which must be registered on the
// template before parsing.
const (
	plainNameRaw         = "{{.Name}}"
	typeAnnotatedNameRaw = "{{.Name}}{{short .Kind}}"
	golinkPragmaRaw      = "//go:linkname {{.Name}}{{short .Kind}} github.com/chewxy/{{vecPkg .Kind}}{{getalias .Name}}\n"
)
// Raw templates that render Go type expressions. The kind-dependent ones call
// the template function asType on the template data (presumably a
// reflect.Kind, as Signature.Write passes s.Kind to its type templates —
// confirm against callers); the fixed ones ignore the data entirely.
const (
	// Element and slice-of-element types.
	scalarTypeRaw = `{{asType .}}`
	sliceTypeRaw  = `[]{{asType .}}`

	// Function types over the element type. unaryFuncTypeRaw now has a space
	// before its result type and no trailing space, matching its siblings.
	// NOTE(review): the reduction templates have no compiled *template.Template
	// in the visible code.
	unaryFuncTypeRaw        = `func({{asType .}}) {{asType .}}`
	unaryFuncErrTypeRaw     = `func({{asType .}}) ({{asType .}}, error)`
	reductionFuncTypeRaw    = `func(a, b {{asType .}}) {{asType .}}`
	reductionFuncTypeErrRaw = `func(a, b {{asType .}}) ({{asType .}}, error)`

	// Fixed types.
	iteratorTypeRaw     = `Iterator`
	interfaceTypeRaw    = "interface{}"
	boolsTypeRaw        = `[]bool`
	boolTypeRaw         = `bool`
	intTypeRaw          = `int`
	intsTypeRaw         = `[]int`
	reflectTypeRaw      = `reflect.Type`
	arrayTypeRaw        = `*storage.Header` // disabled alternative: `Array`
	tensorTypeRaw       = `Tensor`
	splatFuncOptTypeRaw = `...FuncOpt`
	denseTypeRaw        = `*Dense`
	testingTypeRaw      = `*testing.T`
)
// Compiled templates, one per *Raw constant of the same stem; init assigns
// them at program start.
var (
	// Name and pragma templates.
	golinkPragma, typeAnnotatedName, plainName *template.Template

	// Element-dependent and fixed type templates.
	scalarType, sliceType, iteratorType, interfaceType *template.Template
	boolsType, boolType, intType, intsType             *template.Template
	reflectType, arrayType                             *template.Template
	unaryFuncType, unaryFuncErrType                    *template.Template
	tensorType, splatFuncOptType, denseType            *template.Template
	testingType                                        *template.Template
)
// init compiles every *Raw template constant into its package-level
// *template.Template, registering the shared template functions (funcs) on
// each one. A malformed template panics at program start via template.Must.
func init() {
	// mustParse keeps every template on the same FuncMap and Must-panics on a
	// parse error.
	mustParse := func(name, raw string) *template.Template {
		return template.Must(template.New(name).Funcs(funcs).Parse(raw))
	}

	golinkPragma = mustParse("golinkPragma", golinkPragmaRaw)
	typeAnnotatedName = mustParse("type annotated name", typeAnnotatedNameRaw)
	plainName = mustParse("plainName", plainNameRaw)
	scalarType = mustParse("scalarType", scalarTypeRaw)
	sliceType = mustParse("sliceType", sliceTypeRaw)
	iteratorType = mustParse("iteratorType", iteratorTypeRaw)
	interfaceType = mustParse("interfaceType", interfaceTypeRaw)
	boolsType = mustParse("boolsType", boolsTypeRaw)
	boolType = mustParse("boolType", boolTypeRaw)
	intType = mustParse("intType", intTypeRaw) // name was misspelled "intTYpe"
	intsType = mustParse("intsType", intsTypeRaw)
	reflectType = mustParse("reflectType", reflectTypeRaw)
	arrayType = mustParse("arrayType", arrayTypeRaw)
	unaryFuncType = mustParse("unaryFuncType", unaryFuncTypeRaw)
	unaryFuncErrType = mustParse("unaryFuncErrType", unaryFuncErrTypeRaw)
	tensorType = mustParse("tensorType", tensorTypeRaw)
	splatFuncOptType = mustParse("splatFuncOpt", splatFuncOptTypeRaw)
	denseType = mustParse("*Dense", denseTypeRaw)
	testingType = mustParse("*testing.T", testingTypeRaw)
}
|