// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sampleuv

import "golang.org/x/exp/rand"

// Weighted provides sampling without replacement from a collection of items with
// non-uniform probability.
type Weighted struct {
	// weights holds the current weight of each item;
	// items that have been taken have weight zero.
	weights []float64
	// heap is a weight heap.
	//
	// It keeps a heap-organised sum of remaining
	// index weights that are available to be taken
	// from.
	//
	// Each element holds the weight of the
	// corresponding index plus the heap values of
	// its children, i.e. the total weight of the
	// subtree rooted at that element; the children
	// of an element i can be found at positions
	// 2*(i+1)-1 and 2*(i+1). The root of the
	// weight heap is at element 0, so heap[0] is
	// the total remaining weight.
	//
	// See comments in container/heap for an
	// explanation of the layout of a heap.
	heap []float64
	rnd  *rand.Rand
}
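
/*
The index arithmetic described in the heap comment can be checked in
isolation. The following is a standalone sketch, not part of this package:
the helpers children and parent are hypothetical names, and parent uses the
same expression as reset.

```go
package main

import "fmt"

// children returns the positions of the children of element i,
// 2*(i+1)-1 and 2*(i+1), as given in the Weighted.heap comment.
func children(i int) (left, right int) {
	return 2*(i+1) - 1, 2 * (i + 1)
}

// parent returns the position of the parent of element i (i > 0),
// written as in reset; it is equivalent to (i-1)/2.
func parent(i int) int {
	return ((i + 1) >> 1) - 1
}

func main() {
	for i := 0; i < 4; i++ {
		l, r := children(i)
		fmt.Printf("element %d: children %d and %d (parents %d and %d)\n",
			i, l, r, parent(l), parent(r))
	}
}
```

Every child maps back to the element it came from, so the layout is the
usual 0-based array heap (children of i at 2i+1 and 2i+2) that
container/heap also uses.
*/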

// NewWeighted returns a Weighted for the weights w. If src is nil, the
// package-level functions of golang.org/x/exp/rand (and hence its global
// source) are used as the random number generator.
//
// Note that sampling from weights with a high variance or a low overall sum
// of absolute values may be numerically unstable.
func NewWeighted(w []float64, src rand.Source) Weighted {
	s := Weighted{
		weights: make([]float64, len(w)),
		heap:    make([]float64, len(w)),
	}
	if src != nil {
		s.rnd = rand.New(src)
	}
	s.ReweightAll(w)
	return s
}
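
/*
The constructor, Take and Reweight together implement weighted sampling
without replacement. The self-contained sketch below reimplements the same
scheme with the standard library's math/rand rather than
golang.org/x/exp/rand, so it runs without this package; the sampler type,
newSampler, set and take are hypothetical stand-ins, not this package's API.
Seeding through rand.NewSource plays the role of NewWeighted's src argument.

```go
package main

import (
	"fmt"
	"math/rand"
)

// sampler is a hypothetical, stripped-down stand-in for Weighted: it keeps
// the item weights and a heap of subtree sums in the same layout.
type sampler struct {
	weights []float64
	heap    []float64
	rnd     *rand.Rand
}

func newSampler(w []float64, seed int64) *sampler {
	s := &sampler{
		weights: append([]float64(nil), w...),
		heap:    append([]float64(nil), w...),
		rnd:     rand.New(rand.NewSource(seed)),
	}
	// Add each element's finished subtree sum into its parent, last first.
	for i := len(s.heap) - 1; i > 0; i-- {
		s.heap[(i-1)/2] += s.heap[i]
	}
	return s
}

// set changes the weight of idx and recomputes the sums up to the root.
func (s *sampler) set(idx int, w float64) {
	s.weights[idx] = w
	for {
		sum := s.weights[idx]
		if r := 2*idx + 2; r < len(s.heap) {
			sum += s.heap[r]
		}
		if l := 2*idx + 1; l < len(s.heap) {
			sum += s.heap[l]
		}
		s.heap[idx] = sum
		if idx == 0 {
			return
		}
		idx = (idx - 1) / 2
	}
}

// take draws an index with probability proportional to its weight and then
// zeroes that weight, so the same index cannot be drawn again.
func (s *sampler) take() (int, bool) {
	if len(s.heap) == 0 || s.heap[0] == 0 {
		return -1, false
	}
	r := s.rnd.Float64() * s.heap[0]
	i := 0
	for {
		r -= s.weights[i]
		if r < 0 {
			break // r falls within item i.
		}
		l := 2*i + 1
		if l >= len(s.heap) {
			break // Only reachable through rounding error.
		}
		i = l
		if r >= s.heap[l] {
			r -= s.heap[l] // Skip the whole left subtree.
			if l+1 >= len(s.heap) {
				break
			}
			i = l + 1
		}
	}
	s.set(i, 0)
	return i, true
}

func main() {
	s := newSampler([]float64{1, 0, 5, 2, 0.5}, 1)
	for {
		i, ok := s.take()
		if !ok {
			break
		}
		fmt.Println("took", i)
	}
}
```

In this example each positive-weight index is returned exactly once, in a
random order, the zero-weight item 1 is never returned, and take reports
false once every weight is zero.
*/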

// Len returns the number of items held by the Weighted, including items
// already taken.
func (s Weighted) Len() int { return len(s.weights) }

// Take returns an index from the Weighted with probability proportional
// to the weight of the item. The weight of the item is then set to zero.
// Take returns -1, false if no items with non-zero weight remain.
func (s Weighted) Take() (idx int, ok bool) {
	if len(s.heap) == 0 || s.heap[0] == 0 {
		return -1, false
	}

	var r float64
	if s.rnd == nil {
		r = rand.Float64()
	} else {
		r = s.rnd.Float64()
	}
	r *= s.heap[0]

	i := 0
	for {
		r -= s.weights[i]
		if r < 0 {
			break // Fall within item i.
		}

		li := i*2 + 1 // Move to left child.
		// The left child should exist, because r is still
		// non-negative and so must fall within a subtree,
		// but floating point rounding can leave r slightly
		// too large, so we check the index explicitly.
		if li >= len(s.heap) {
			break
		}

		i = li

		d := s.heap[i]
		if r >= d {
			// If r is at least the sum of the left subtree,
			// skip that subtree and move to the right child.
			r -= d
			ri := i + 1

			if ri >= len(s.heap) {
				break
			}

			i = ri
		}
	}

	s.Reweight(i, 0)

	return i, true
}

// Reweight sets the weight of item idx to w and recomputes the heap sums
// of idx and all of its ancestors.
func (s Weighted) Reweight(idx int, w float64) {
	s.weights[idx] = w

	// We want to keep the heap state here consistent
	// with the result of a reset call, so we sum the
	// weights in the same order (right child before
	// left child), since floating point addition is
	// not associative.
	for {
		w = s.weights[idx]

		ri := idx*2 + 2
		if ri < len(s.heap) {
			w += s.heap[ri]
		}
		li := ri - 1
		if li < len(s.heap) {
			w += s.heap[li]
		}

		s.heap[idx] = w

		if idx == 0 {
			break
		}

		idx = (idx - 1) / 2
	}
}
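
/*
Summing in the same order as reset matters because the incremental update
then produces exactly the same floating point values as a full rebuild, not
just approximately equal ones. The standalone sketch below checks this with
hypothetical helpers buildHeap, reweight and identical operating on plain
slices; they mirror reset and Reweight but are not this package's API.

```go
package main

import "fmt"

// buildHeap computes the subtree sums for w in the same way as reset.
func buildHeap(w []float64) []float64 {
	h := append([]float64(nil), w...)
	for i := len(h) - 1; i > 0; i-- {
		h[((i+1)>>1)-1] += h[i]
	}
	return h
}

// reweight mirrors Reweight: set the item's weight, then recompute the sums
// from idx up to the root, adding the right child before the left child.
func reweight(w, h []float64, idx int, v float64) {
	w[idx] = v
	for {
		sum := w[idx]
		ri := idx*2 + 2
		if ri < len(h) {
			sum += h[ri]
		}
		if li := ri - 1; li < len(h) {
			sum += h[li]
		}
		h[idx] = sum
		if idx == 0 {
			return
		}
		idx = (idx - 1) / 2
	}
}

// identical reports whether a and b are element-wise equal under ==.
func identical(a, b []float64) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}

func main() {
	w := []float64{0.1, 0.7, 1e-3, 2.5, 0.3, 4.2, 1.1}
	h := buildHeap(w)
	reweight(w, h, 5, 0.9)
	fmt.Println("incremental matches rebuild:", identical(h, buildHeap(w)))
	// prints: incremental matches rebuild: true
}
```

Only the O(log n) sums on the path from idx to the root change; every
other element already equals its rebuilt value, and each recomputed sum is
formed from the same operands in the same order as in reset.
*/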

// ReweightAll sets the weight of all items in the Weighted. ReweightAll
// panics if len(w) != s.Len().
func (s Weighted) ReweightAll(w []float64) {
	if len(w) != s.Len() {
		panic("sampleuv: slice lengths do not match")
	}
	copy(s.weights, w)
	s.reset()
}
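
// reset rebuilds the weight heap from s.weights. Iterating from the last
// element towards the root means every child's sum is complete before it
// is added to its parent.
func (s Weighted) reset() {
	copy(s.heap, s.weights)
	for i := len(s.heap) - 1; i > 0; i-- {
		// Sometimes 1-based counting makes sense: ((i+1)>>1)-1
		// is the 0-based index of the parent of i, i.e. (i-1)/2.
		s.heap[((i+1)>>1)-1] += s.heap[i]
	}
}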
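
/*
reset builds every subtree sum in a single backward pass, so it costs O(n)
rather than O(n log n) for n separate Reweight calls. The standalone sketch
below uses hypothetical reset and subtreeSum helpers on plain slices (not
the package's methods) and compares the pass against a direct recursive sum.

```go
package main

import "fmt"

// reset builds the weight heap in h from w: copy the weights, then add each
// finished element into its parent, walking from the last element to 1.
func reset(h, w []float64) {
	copy(h, w)
	for i := len(h) - 1; i > 0; i-- {
		h[((i+1)>>1)-1] += h[i] // ((i+1)>>1)-1 == (i-1)/2, the parent of i.
	}
}

// subtreeSum adds up the weights in the subtree rooted at i directly.
func subtreeSum(w []float64, i int) float64 {
	if i >= len(w) {
		return 0
	}
	return w[i] + subtreeSum(w, 2*i+1) + subtreeSum(w, 2*i+2)
}

func main() {
	w := []float64{1, 2, 3, 4, 5, 6}
	h := make([]float64, len(w))
	reset(h, w)
	fmt.Println(h) // [21 11 9 4 5 6]
}
```

With integer weights the sums are exact, so each h[i] equals subtreeSum(w, i)
and h[0] is the total weight.
*/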