File: sampling_ops.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (37 lines) | stat: -rw-r--r-- 1,077 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch


# https://pytorch.org/docs/stable/torch.html#random-sampling

class SamplingOpsModule(torch.nn.Module):
    """Exercise the ``torch`` random-sampling ops in a single ``forward``.

    The module has no parameters or buffers. ``forward`` takes no inputs,
    invokes each sampling op once and returns how many results it produced;
    the sampled values themselves are discarded, so the module only checks
    that the ops run, not what they return.
    """

    def __init__(self):
        super().__init__()

    def forward(self) -> int:
        """Run every sampling op once and return the number of results (19).

        ``a`` is a 3x3 float tensor drawn from U[0, 1), which makes it a valid
        probability input for ``torch.bernoulli`` and a valid (non-negative)
        rate input for ``torch.poisson``. The trailing in-place ops
        (``a.bernoulli_()`` ... ``a.uniform_()``) overwrite ``a`` one after
        another; they are placed after every out-of-place op that reads ``a``.
        """
        a = torch.empty(3, 3).uniform_(0.0, 1.0)
        size = (1, 4)
        # Exactly two non-zero categories, so drawing 2 samples with
        # torch.multinomial's default replacement=False is valid.
        weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
        # The results are gathered into a list literal because ``len`` accepts
        # exactly one argument in eager Python: passing the tensors as
        # separate arguments raised TypeError when the module was called
        # eagerly. A List[Tensor] is counted the same way in eager mode and
        # under TorchScript.
        return len(
            [
                # Not exercised: torch.seed(), torch.manual_seed(0) and
                # torch.initial_seed() return an int / Generator rather than a
                # Tensor, so they cannot be elements of this List[Tensor].
                torch.bernoulli(a),
                torch.multinomial(weights, 2),
                torch.normal(2.0, 3.0, size),
                torch.poisson(a),
                torch.rand(2, 3),
                torch.rand_like(a),
                torch.randint(10, size),
                torch.randint_like(a, 4),
                # Previously a duplicate torch.rand(4); randn belongs here in
                # the docs' ordering and was otherwise never exercised.
                torch.randn(4),
                torch.randn_like(a),
                torch.randperm(4),
                # In-place samplers: each overwrites ``a``.
                a.bernoulli_(),
                a.cauchy_(),
                a.exponential_(),
                a.geometric_(0.5),
                a.log_normal_(),
                a.normal_(),
                a.random_(),
                a.uniform_(),
            ]
        )