File: distributions_stubs.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (75 lines) | stat: -rw-r--r-- 2,161 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#ifndef CAFFE2_CORE_DISTRIBUTIONS_STUBS_H_
#define CAFFE2_CORE_DISTRIBUTIONS_STUBS_H_

#include <cstdint>
#include <random>
#include <utility>

#include <c10/macros/Macros.h>

/**
 * This file provides distributions compatible with
 * ATen/core/DistributionsHelper.h but backed with the std RNG implementation
 * instead of the ATen one.
 *
 * Caffe2 mobile builds currently do not depend on all of ATen so this is
 * required to allow using the faster ATen RNG for normal builds but keep the
 * build size small on mobile. RNG performance typically doesn't matter on
 * mobile builds since the models are small and rarely use random
 * initialization.
 */

namespace at {
// NOTE(review): an anonymous namespace in a header gives every definition
// below internal linkage (one private copy per translation unit). Presumably
// this is intentional so these std-backed stubs can never clash (ODR) with the
// real ATen definitions of the same names elsewhere in the same binary —
// confirm before changing it.
namespace {

/**
 * Wraps a standard-library distribution `T` so it is invoked like an ATen
 * distribution: constructed from the distribution parameters, then called with
 * a pointer-like generator handle.
 *
 * @tparam R value type returned to callers; the std distribution's result is
 *           implicitly converted to R.
 * @tparam T the wrapped std distribution type.
 */
template <typename R, typename T>
struct distribution_adapter {
  /// Passes every argument through to the wrapped distribution's constructor.
  /// Arguments are taken by value and moved on. A forwarding-reference
  /// (`Args&&`) signature would out-rank the implicit copy constructor when
  /// copying a non-const adapter; taking them by value avoids that hijack.
  template <typename... Args>
  C10_HOST_DEVICE inline distribution_adapter(Args... args)
      : distribution_(std::move(args)...) {}

  /// Draws one sample.
  /// @param generator pointer-like handle. It is dereferenced and the result is
  ///        handed to the std distribution, so `*generator` must meet the
  ///        std UniformRandomBitGenerator requirements.
  /// @return the sample, converted to R.
  template <typename RNG>
  C10_HOST_DEVICE inline R operator()(RNG generator) {
    return distribution_(*generator);
  }

 private:
  T distribution_;
};

/**
 * Uniform integer in the half-open range [base, base + range).
 *
 * Takes (range, base) like the ATen distribution it stands in for. ATen's
 * upper bound is exclusive, but std::uniform_int_distribution's bounds are
 * inclusive. The upper bound handed to std is therefore base + range - 1.
 *
 * Precondition: range >= 1. With range == 0 the std distribution would get an
 * upper bound below its lower bound, which violates its contract.
 */
template <typename T>
struct uniform_int_from_to_distribution
    : distribution_adapter<T, std::uniform_int_distribution<T>> {
  C10_HOST_DEVICE inline uniform_int_from_to_distribution(
      uint64_t range,
      int64_t base)
      : distribution_adapter<T, std::uniform_int_distribution<T>>(
            static_cast<T>(base),
            // The sum is computed in uint64_t on purpose. Unsigned arithmetic
            // wraps modulo 2^64, so a negative `base` still produces the
            // correct inclusive upper bound once converted back to T.
            static_cast<T>(static_cast<uint64_t>(base) + range - 1)) {}
};

/// Samples std::uniform_real_distribution<T>: a real value in [a, b) for
/// constructor arguments (a, b).
template <typename T>
using uniform_real_distribution =
    distribution_adapter<T, std::uniform_real_distribution<T>>;

/// Samples std::normal_distribution<T>; constructor arguments are
/// (mean, stddev).
template <typename T>
using normal_distribution =
    distribution_adapter<T, std::normal_distribution<T>>;

/// Samples std::bernoulli_distribution (constructor argument: probability p).
/// The std result is a bool, which the adapter converts to T.
template <typename T>
using bernoulli_distribution =
    distribution_adapter<T, std::bernoulli_distribution>;

/// Samples std::exponential_distribution<T>; constructor argument is the
/// rate lambda.
template <typename T>
using exponential_distribution =
    distribution_adapter<T, std::exponential_distribution<T>>;

/// Samples std::cauchy_distribution<T>; constructor arguments are
/// (location a, scale b).
template <typename T>
using cauchy_distribution =
    distribution_adapter<T, std::cauchy_distribution<T>>;

/// Samples std::lognormal_distribution<T>; constructor arguments (m, s) are
/// the mean and standard deviation of the underlying normal distribution.
template <typename T>
using lognormal_distribution =
    distribution_adapter<T, std::lognormal_distribution<T>>;

} // namespace
} // namespace at

#endif // CAFFE2_CORE_DISTRIBUTIONS_STUBS_H_