File: reduction_specialization.go

package info (click to toggle)
golang-github-gorgonia-tensor 0.9.24-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,696 kB
  • sloc: sh: 18; asm: 18; makefile: 8
file content (74 lines) | stat: -rw-r--r-- 1,967 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package main

import (
	"io"
	"reflect"
	"text/template"
)

// ReductionOp describes one reduction for which reductionSpecializationRaw
// generates type-specialised dispatch functions.
type ReductionOp struct {
	// OpName names the operation. Title-cased, it forms the generated function
	// names (Monotonic<OpName>, <OpName>Methods); it also appears verbatim in
	// the generated error messages.
	OpName string

	// Name prefixes of the per-kind implementations. The template appends a
	// short kind suffix to each of them:
	//   VecVec      - operates on two slices, e.g. sum(a, b []T)
	//   OpOfVec     - reduces a single slice, e.g. sum([]T)
	//   GenericName - combines two scalars, e.g. sum(T, T) T
	VecVec, OpOfVec, GenericName string

	// Kinds lists the element kinds to emit cases for. It is filled in by
	// generateReductionSpecialization with the kinds accepted by Typeclass.
	Kinds []reflect.Kind

	// Typeclass reports whether a given kind supports this operation.
	Typeclass TypeClass
}

// reductionOps lists every reduction the generator emits code for.
// Sum is generated for all number kinds. Max and Min are restricted to
// non-complex numbers (Go's complex types have no ordering).
var reductionOps = []ReductionOp{
	{
		OpName:      "Sum",
		VecVec:      "VecAdd",
		OpOfVec:     "Sum",
		GenericName: "Add",
		Typeclass:   isNumber,
	},
	{
		OpName:      "Max",
		VecVec:      "VecMax",
		OpOfVec:     "SliceMax",
		GenericName: "Max",
		Typeclass:   isNonComplexNumber,
	},
	{
		OpName:      "Min",
		VecVec:      "VecMin",
		OpOfVec:     "SliceMin",
		GenericName: "Min",
		Typeclass:   isNonComplexNumber,
	},
}

// reductionSpecializationRaw is executed once per ReductionOp. For each op it
// emits two functions that switch on an element type t and dispatch to the
// per-kind implementations named <prefix><short kind>:
//
//   - Monotonic<Op> applies the OpOfVec implementation to a's data.
//   - <Op>Methods returns the VecVec, OpOfVec and GenericName implementations.
//
// Only kinds in .Kinds that also satisfy isNumber get a case; any other t
// falls through to the default branch, which returns an error.
const reductionSpecializationRaw = `// Monotonic{{.OpName | title}} applies the {{.OpOfVec}} implementation specialised for t to a's data.
// It returns an error if t is not a supported type.
func Monotonic{{.OpName | title}}(t reflect.Type, a *storage.Header) (retVal interface{}, err error) {
	{{$opOfVec := .OpOfVec -}}
	switch t {
		{{range .Kinds -}}
		{{if isNumber . -}}
	case {{reflectKind .}}:
		retVal = {{$opOfVec}}{{short .}}(a.{{sliceOf .}})
		return
		{{end -}}
		{{end -}}
	default:
		err = errors.Errorf("Cannot perform {{.OpName}} on %v", t)
		return
	}
}

// {{.OpName | title}}Methods returns the {{.VecVec}}, {{.OpOfVec}} and {{.GenericName}} implementations specialised for t, in that order.
// It returns an error if t is not a supported type.
func {{.OpName | title}}Methods(t reflect.Type) (firstFn, lastFn, defaultFn interface{}, err error) {
	{{$vecVec := .VecVec -}}
	{{$opOfVec := .OpOfVec -}}
	{{$genericName := .GenericName -}}
	switch t {
		{{range .Kinds -}}
		{{if isNumber . -}}
	case {{reflectKind .}}:
		return {{$vecVec}}{{short .}}, {{$opOfVec}}{{short .}}, {{$genericName}}{{short .}}, nil
		{{end -}}
		{{end -}}
	default:
		return nil, nil, nil, errors.Errorf("No methods found for {{.OpName}} for %v", t)
	}
}

`

// reductionSpecialization holds the parsed reductionSpecializationRaw
// template, built in init with the template functions in funcs.
var reductionSpecialization *template.Template

func init() {
	// The template text is a compile-time constant, so a parse failure is a
	// programming error; template.Must turns it into a panic at start-up.
	base := template.New("reduction specialization").Funcs(funcs)
	reductionSpecialization = template.Must(base.Parse(reductionSpecializationRaw))
}

// generateReductionSpecialization writes the specialised dispatch functions
// for every entry in reductionOps to f. Each op is instantiated only for the
// kinds in ak.Kinds that its Typeclass accepts.
//
// It panics if the template fails to execute. Otherwise a partially written,
// uncompilable file would be produced silently. This matches the file's use
// of template.Must for template errors.
func generateReductionSpecialization(f io.Writer, ak Kinds) {
	for _, op := range reductionOps {
		// op is a per-iteration copy. The entries in reductionOps leave Kinds
		// nil, so the appends below allocate fresh storage and never write
		// back into the shared table.
		for _, k := range ak.Kinds {
			if op.Typeclass(k) {
				op.Kinds = append(op.Kinds, k)
			}
		}
		if err := reductionSpecialization.Execute(f, op); err != nil {
			panic("generating reduction specialization for " + op.OpName + ": " + err.Error())
		}
	}
}