File: assert_op.cc

pytorch 1.13.1+dfsg-4
  • area: main
  • in suites: bookworm
#include "caffe2/operators/assert_op.h"

namespace caffe2 {

REGISTER_CPU_OPERATOR(Assert, AssertOp<CPUContext>);

OPERATOR_SCHEMA(Assert)
    .NumInputs(1)
    .NumOutputs(0)
    .SetDoc(R"DOC(
Takes in a tensor of type *bool*, *int*, *long*, or *long long* and checks whether all values are True when coerced into a boolean. In other words, for non-bool types this asserts that all values in the tensor are non-zero. If any value is False after being coerced into a boolean, the operator throws an error. Otherwise, if all values are True, the operator does nothing and returns no outputs. For traceability, a custom error message can be set using the `error_msg` argument.
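
The coercion rule above can be illustrated with a small NumPy sketch. It is not the operator's implementation: `assert_all_true` is a hypothetical helper, and its error-message text is illustrative rather than the exact message the Assert operator produces.

```python
import numpy as np


def assert_all_true(x, error_msg=""):
    # Hypothetical helper mirroring the documented Assert semantics;
    # this is NOT the caffe2 AssertOp implementation.
    flat = np.asarray(x).ravel()
    # Coerce every element to bool; zero (or False) becomes False.
    failed = np.flatnonzero(~flat.astype(bool))
    if failed.size:
        i = failed[0]
        # Illustrative message format, not the operator's exact text.
        msg = "Assert failed at element {} (value {})".format(i, flat[i])
        if error_msg:
            msg += ": " + error_msg
        raise RuntimeError(msg)
    # Like the operator, nothing is returned when every element is True.


assert_all_true(np.array([[7, 5, 6], [1, 2, 4]], dtype=np.int32))  # passes silently
try:
    assert_all_true(np.array([1, 0, 5], dtype=np.int64), error_msg="found a zero")
except RuntimeError as e:
    print(e)  # prints: Assert failed at element 1 (value 0): found a zero
```

A single zero anywhere in the tensor is enough to fail the check, which is why the integer case reduces to "all values are non-zero".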

GitHub Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/assert_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Assert",
    ["A"],
    [],
    error_msg="Failed assertion from Assert operator"
)

# Draw values from [1, 10) so every element is non-zero and the assertion passes
workspace.FeedBlob("A", np.random.randint(1, 10, size=(3,3)).astype(np.int32))
print("A:", workspace.FetchBlob("A"))
try:
    workspace.RunOperatorOnce(op)
except RuntimeError:
    print("Assertion Failed!")
else:
    print("Assertion Passed!")

```

**Result**

```

A: [[7 5 6]
 [1 2 4]
 [5 3 7]]
Assertion Passed!

```

</details>

        )DOC")
    .Arg(
        "error_msg",
        "(*string*): custom error message included in the error thrown when the input fails the assertion",
        false)
    .Input(0, "X", "(*Tensor*): input tensor");

} // namespace caffe2