File: api_unary.go

package info (click to toggle)
golang-github-gorgonia-tensor 0.9.24-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,696 kB
  • sloc: sh: 18; asm: 18; makefile: 8
file content (98 lines) | stat: -rw-r--r-- 2,147 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package main

import (
	"io"
	"text/template"
)

// APIUnary generates the public API function for a single unary operation.
// It embeds UnaryOp, so the op's methods (such as Name) are promoted onto
// APIUnary and used directly by the generator.
type APIUnary struct{ UnaryOp }

// Signature describes the generated API function for this unary op.
//
// Every op takes the tensor `a` followed by `opts` (spread as `opts...` in the
// generated body). Clamp is special: it also takes `min` and `max` bounds,
// rendered with interfaceType, between `a` and `opts`. The generated function
// returns a tensor named retVal plus an error.
func (fn *APIUnary) Signature() *Signature {
	// Parameter list shared by all ordinary unary ops.
	names := []string{"a", "opts"}
	tmpls := []*template.Template{tensorType, splatFuncOptType}

	if fn.UnaryOp.Name() == "Clamp" {
		names = []string{"a", "min", "max", "opts"}
		tmpls = []*template.Template{tensorType, interfaceType, interfaceType, splatFuncOptType}
	}

	sig := &Signature{
		Name:            fn.Name(),
		NameTemplate:    plainName,
		ParamNames:      names,
		ParamTemplates:  tmpls,
		RetVals:         []string{"retVal"},
		RetValTemplates: []*template.Template{tensorType},
		Err:             true,
	}
	return sig
}

// WriteBody writes the Go source for the body of the generated API function.
//
// The emitted code fetches the tensor's engine and type-asserts it to the
// capability interface named by interfaceName(Name). If the assertion succeeds,
// it delegates to that interface. Clamp also forwards its min/max bounds.
// Otherwise it sets err to "Engine does not perform <Name>" and returns.
//
// WriteBody panics if the template fails to execute, consistent with the
// template.Must used for parsing. Previously the Execute error was discarded,
// which could silently leave truncated, uncompilable generated code in w.
func (fn *APIUnary) WriteBody(w io.Writer) {
	// NOTE(review): the Clamp branch hard-codes the identifier `clamper`; this
	// relies on `interfaceName "Clamp" | lower` producing "clamper" for the
	// variable bound in the `if` line — confirm against interfaceName.
	body := `e := a.Engine()
	if {{interfaceName .Name | lower}}, ok := e.({{interfaceName .Name}}); ok {
		{{if eq .Name "Clamp" -}}
		return clamper.Clamp(a, min, max, opts...)
		{{else -}}
		return {{interfaceName .Name|lower}}.{{.Name}}(a, opts...)
		{{end -}}
	}
	err = errors.Errorf("Engine does not perform {{.Name}}")
	return
	`

	T := template.Must(template.New("body").Funcs(funcs).Parse(body))
	if err := T.Execute(w, fn); err != nil {
		panic("APIUnary.WriteBody: executing body template: " + err.Error())
	}
}

// Write emits one complete generated function for the op to w. It writes the
// "func" keyword, then the signature produced by Signature, then the body from
// WriteBody enclosed in braces, followed by a blank line. Errors returned by w
// are not checked here.
func (fn *APIUnary) Write(w io.Writer) {
	io.WriteString(w, "func ")
	fn.Signature().Write(w)
	io.WriteString(w, "{ \n")
	fn.WriteBody(w)
	io.WriteString(w, "}\n\n")
}

// generateUncondUnaryAPI writes an API function to f for every op in
// unconditionalUnaries, in iteration order. The kinds parameter is accepted
// for signature parity with the other generators and is not used.
func generateUncondUnaryAPI(f io.Writer, kinds Kinds) {
	for _, op := range unconditionalUnaries {
		api := &APIUnary{UnaryOp: op}
		api.Write(f)
	}
}

// generateCondUnaryAPI writes an API function to f for every op in
// conditionalUnaries, in iteration order. The kinds parameter is accepted
// for signature parity with the other generators and is not used.
func generateCondUnaryAPI(f io.Writer, kinds Kinds) {
	for _, op := range conditionalUnaries {
		api := &APIUnary{UnaryOp: op}
		api.Write(f)
	}
}

// generateSpecialUnaryAPI writes an API function to f for every op in
// specialUnaries, in iteration order. The kinds parameter is accepted for
// signature parity with the other generators and is not used.
func generateSpecialUnaryAPI(f io.Writer, kinds Kinds) {
	for _, op := range specialUnaries {
		api := &APIUnary{UnaryOp: op}
		api.Write(f)
	}
}