1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
|
package main
import (
"io"
"text/template"
)
// APIUnary wraps a UnaryOp (declared elsewhere in this package) so the
// generator can emit the exported API function for that op. The emitted
// function fetches the tensor's engine and delegates to it when the engine
// implements the op's interface; see Signature, WriteBody and Write.
type APIUnary struct {
UnaryOp // the op being exposed; its Name drives all generated identifiers
}
// Signature describes the generated API function for fn.
//
// Clamp takes the tensor, two interface-typed bounds (min and max) and the
// variadic options; every other unary op takes only the tensor and the
// options. The generated function returns a single tensor named retVal and
// an error.
func (fn *APIUnary) Signature() *Signature {
	var (
		names []string
		tmpls []*template.Template
	)
	switch fn.UnaryOp.Name() {
	case "Clamp":
		// Clamp is the only unary op with extra bound parameters.
		names = []string{"a", "min", "max", "opts"}
		tmpls = []*template.Template{tensorType, interfaceType, interfaceType, splatFuncOptType}
	default:
		names = []string{"a", "opts"}
		tmpls = []*template.Template{tensorType, splatFuncOptType}
	}

	sig := &Signature{
		Name:            fn.Name(),
		NameTemplate:    plainName,
		ParamNames:      names,
		ParamTemplates:  tmpls,
		RetVals:         []string{"retVal"},
		RetValTemplates: []*template.Template{tensorType},
		Err:             true,
	}
	return sig
}
// WriteBody writes the body of the generated API function for fn to w.
//
// The generated code type-asserts the tensor's engine against the op's
// interface (named via the interfaceName template func). On success it
// delegates to that engine; otherwise it sets err to an
// "Engine does not perform <op>" error and returns. Clamp additionally
// forwards its min and max bounds.
//
// It panics if the template fails to parse or execute. Silently ignoring an
// execution error would leave truncated, uncompilable generated code behind.
func (fn *APIUnary) WriteBody(w io.Writer) {
	// Both branches call through the variable bound in the `if` header, so the
	// receiver name always matches whatever interfaceName|lower produces.
	body := `e := a.Engine()
if {{interfaceName .Name | lower}}, ok := e.({{interfaceName .Name}}); ok {
{{if eq .Name "Clamp" -}}
return {{interfaceName .Name | lower}}.Clamp(a, min, max, opts...)
{{else -}}
return {{interfaceName .Name|lower}}.{{.Name}}(a, opts...)
{{end -}}
}
err = errors.Errorf("Engine does not perform {{.Name}}")
return
`
	tmpl := template.Must(template.New("body").Funcs(funcs).Parse(body))
	if err := tmpl.Execute(w, fn); err != nil {
		panic(err)
	}
}
// Write emits the complete generated API function for fn to w. It writes the
// "func" keyword, the signature, and the body wrapped in braces, followed by
// a blank line. Write errors from w are not checked.
func (fn *APIUnary) Write(w io.Writer) {
	emit := func(s string) { w.Write([]byte(s)) }
	sig := fn.Signature()

	emit("func ")
	sig.Write(w)

	emit("{ \n")
	fn.WriteBody(w)
	emit("}\n\n")
}
func generateUncondUnaryAPI(f io.Writer, kinds Kinds) {
var unaries []*APIUnary
for _, u := range unconditionalUnaries {
fn := &APIUnary{
UnaryOp: u,
}
unaries = append(unaries, fn)
}
for _, u := range unaries {
u.Write(f)
}
}
func generateCondUnaryAPI(f io.Writer, kinds Kinds) {
var unaries []*APIUnary
for _, u := range conditionalUnaries {
fn := &APIUnary{
UnaryOp: u,
}
unaries = append(unaries, fn)
}
for _, u := range unaries {
u.Write(f)
}
}
func generateSpecialUnaryAPI(f io.Writer, kinds Kinds) {
var unaries []*APIUnary
for _, u := range specialUnaries {
fn := &APIUnary{
UnaryOp: u,
}
unaries = append(unaries, fn)
}
for _, u := range unaries {
u.Write(f)
}
}
|