File: sampling.h

package info (click to toggle)
rsem 1.3.1+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 37,588 kB
  • sloc: cpp: 19,202; perl: 1,259; python: 1,245; ansic: 547; makefile: 186; sh: 154
file content (67 lines) | stat: -rw-r--r-- 1,713 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#ifndef SAMPLING
#define SAMPLING

#include<ctime>
#include<cstdio>
#include<cassert>
#include<vector>
#include<set>

#include "boost/random.hpp"

// Integer type of the seeds drawn by, and passed to, engine_type (see engineFactory).
typedef unsigned int seedType;
// Mersenne Twister (MT19937) pseudo-random engine from Boost.Random.
typedef boost::random::mt19937 engine_type;
// Uniform real distribution over [0, 1); default template argument (double in Boost).
typedef boost::random::uniform_01<> uniform_01_dist;
// Gamma distribution; default template argument (double in Boost).
typedef boost::random::gamma_distribution<> gamma_dist;
// Generators that bind a distribution to an engine. The engine is held by
// reference (engine_type&), so it must outlive every generator built on it.
typedef boost::random::variate_generator<engine_type&, uniform_01_dist> uniform_01_generator;
typedef boost::random::variate_generator<engine_type&, gamma_dist> gamma_generator;

class engineFactory {
public:
  static void init() { seedEngine = new engine_type(time(NULL)); }
  static void init(seedType seed) { seedEngine = new engine_type(seed); }

  static void finish() { if (seedEngine != NULL) delete seedEngine; }

	static engine_type *new_engine() {
		seedType seed;
		static std::set<seedType> seedSet;			// empty set of seeds
		std::set<seedType>::iterator iter;

		do {
			seed = (*seedEngine)();
			iter = seedSet.find(seed);
		} while (iter != seedSet.end());
		seedSet.insert(seed);

		return new engine_type(seed);
	}

 private:
	static engine_type *seedEngine;
};

engine_type* engineFactory::seedEngine = NULL;

// arr should be cumulative!
// interval : [,)
// random number should be in [0, arr[len - 1])
// If by chance arr[len - 1] == 0.0, one possibility is to sample uniformly from 0...len-1
int sample(uniform_01_generator& rg, std::vector<double>& arr, int len) {
  int l, r, mid;
  double prb = rg() * arr[len - 1];

  l = 0; r = len - 1;
  while (l <= r) {
    mid = (l + r) / 2;
    if (arr[mid] <= prb) l = mid + 1;
    else r = mid - 1;
  }

  if (l >= len) { printf("%d %lf %lf\n", len, arr[len - 1], prb); }
  assert(l < len);

  return l;
}

#endif