File: api_utils.go

package info (click to toggle)
golang-github-gorgonia-tensor 0.9.24-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,696 kB
  • sloc: sh: 18; asm: 18; makefile: 8
file content (125 lines) | stat: -rw-r--r-- 2,034 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package tensor

import (
	"log"
	"math"
	"math/rand"
	"reflect"
	"sort"

	"github.com/chewxy/math32"
)

// SortIndex is similar to numpy's argsort: it returns the indices that would
// put in into ascending order, so that out[i] is the position in the original
// input of the i-th smallest element.
//
// Supported inputs:
//   - []int and []float64: a permutation of 0..len-1 is returned. Equal
//     elements keep their original relative order (the sort is stable), and
//     NaNs are ordered before all other float64 values, matching
//     sort.Float64s. As a side effect the input slice itself is sorted in
//     place.
//   - sort.Interface: the value is sorted in place, but no indices are
//     computed yet; a message is logged and nil is returned.
//
// Any other type yields nil.
func SortIndex(in interface{}) (out []int) {
	switch list := in.(type) {
	case []int:
		out = make([]int, len(list))
		for i := range out {
			out[i] = i
		}
		// Sort the indices by the values they point at. list is still in its
		// original order here, so out ends up holding original positions.
		sort.SliceStable(out, func(a, b int) bool {
			return list[out[a]] < list[out[b]]
		})
		// Callers have always received their slice back sorted; keep that.
		sort.Ints(list)
	case []float64:
		out = make([]int, len(list))
		for i := range out {
			out[i] = i
		}
		sort.SliceStable(out, func(a, b int) bool {
			x, y := list[out[a]], list[out[b]]
			// Same ordering as sort.Float64Slice.Less: NaN sorts first.
			return x < y || (math.IsNaN(x) && !math.IsNaN(y))
		})
		sort.Float64s(list)
	case sort.Interface:
		sort.Sort(list)

		log.Printf("TODO: SortIndex for sort.Interface not yet done.")
	}

	return out
}

// SampleIndex samples an index from a slice or a Tensor, treating the
// elements as weights: it walks the elements, keeps a running total, and
// returns the first index at which the total exceeds a random threshold.
//
// Supported inputs:
//   - []int: the threshold is drawn uniformly from [0, sum of all weights).
//   - []float64: the threshold is rand.Float64(), i.e. in [0, 1), so the
//     weights are presumably expected to sum to about 1.
//   - *Dense holding float64 or float32 data: as for []float64 (using
//     rand.Float32 for float32), except that the index of the first NaN or
//     ±Inf element reached during the walk is returned immediately.
//
// If the running total never exceeds the threshold (rounding error,
// unnormalised or non-positive weights, a single-element input) the last
// index is returned instead of running off the end of the data; for empty
// input that is -1.
//
// It panics for any other input type or *Dense element kind.
//
// NOTE(review): the `i > 0` guard is kept from the original code. It means
// index 0 is never chosen by the threshold test, so its weight is effectively
// folded into index 1. Confirm whether that is intended.
func SampleIndex(in interface{}) int {
	switch list := in.(type) {
	case []int:
		// Draw against the actual total; an unbounded rand.Int() would almost
		// never be exceeded by realistic integer weights.
		var total int
		for _, w := range list {
			total += w
		}
		if total <= 0 {
			// rand.Intn panics for n <= 0, and nothing can be sampled anyway.
			return len(list) - 1
		}
		r := rand.Intn(total)

		var sum int
		for i, w := range list {
			sum += w
			if sum > r && i > 0 {
				return i
			}
		}
		return len(list) - 1
	case []float64:
		r := rand.Float64()

		var sum float64
		for i, w := range list {
			sum += w
			if sum > r && i > 0 {
				return i
			}
		}
		return len(list) - 1
	case *Dense:
		switch list.t.Kind() {
		case reflect.Float64:
			r := rand.Float64()
			data := list.Float64s()

			var sum float64
			for i, datum := range data {
				// A non-finite weight would poison the running total, so stop
				// at it and report its position.
				if math.IsNaN(datum) || math.IsInf(datum, 0) {
					return i
				}

				sum += datum
				if sum > r && i > 0 {
					return i
				}
			}
			return len(data) - 1
		case reflect.Float32:
			r := rand.Float32()
			data := list.Float32s()

			var sum float32
			for i, datum := range data {
				if math32.IsNaN(datum) || math32.IsInf(datum, 0) {
					return i
				}

				sum += datum
				if sum > r && i > 0 {
					return i
				}
			}
			return len(data) - 1
		default:
			panic("not yet implemented")
		}
	default:
		panic("Not yet implemented")
	}
}