File: enforce_finite_op.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (75 lines) | stat: -rw-r--r-- 2,339 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#ifndef CAFFE_OPERATORS_ENFORCE_FINITE_OP_H_
#define CAFFE_OPERATORS_ENFORCE_FINITE_OP_H_

#include <cmath>
#include <exception>
#include <string>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "c10/util/irange.h"

namespace caffe2 {

// EnforceFiniteOp fails (via CAFFE_ENFORCE_FINITE) when its first input
// contains a non-finite element (NaN or +/-Inf). RunOnDevice dispatches float
// and double inputs to DoRunWithType<T>(), whose per-Context definition lives
// outside this header.
template <class Context>
class EnforceFiniteOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // `ws` is retained (non-owning) so that, when a non-finite value is found,
  // every blob in the workspace can be inspected by LogBlobFiniteness().
  explicit EnforceFiniteOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws), ws_(ws) {}

  // Dispatches on the element type of Input(0); only float and double are
  // listed as supported types.
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
  }

  // Per-type implementation; declared here, defined per Context elsewhere.
  template <typename T>
  bool DoRunWithType();

 private:
  // Non-owning. Assumes the workspace outlives this operator — TODO confirm.
  Workspace* ws_;
  // CPU-resident scratch tensor. Not referenced in this header; presumably a
  // host staging buffer for a device implementation — TODO confirm.
  Tensor buffer_{CPU};

  // Scans `input` element by element. Its data is dereferenced directly, so it
  // must live in host memory. On the first non-finite element, the finiteness
  // of every workspace tensor is logged (to help locate where the NaN/Inf came
  // from), then CAFFE_ENFORCE_FINITE reports the offending index and value.
  template <typename T>
  void EnforceOnCPU(const Tensor& input) {
    const T* input_data = input.template data<T>();
    const auto size = input.numel();

    for (const auto i : c10::irange(size)) {
      const bool isfinite = std::isfinite(input_data[i]);
      if (!isfinite) {
        // Diagnostics first: per-blob failures inside are caught and logged,
        // so the enforce below still reports the original problem.
        LogBlobFiniteness();
      }
      CAFFE_ENFORCE_FINITE(
          isfinite,
          "Index ",
          i,
          " is not finite (e.g., NaN, Inf): ",
          input_data[i]);
    }
  }

  // Logs, for every blob in the workspace that holds a Tensor, whether the sum
  // of its elements is finite. Any NaN/Inf element makes the sum non-finite,
  // but a non-finite sum can also result from overflow of finite values, so
  // "isfinite=false" is a strong hint rather than proof.
  void LogBlobFiniteness() {
    // This uses the aten interfaces to compute the sum and finiteness of the
    // tensors which are not present by default on xplat and mobile builds.
#if defined(EXPOSE_C2_OPS) || \
    (!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE))
    for (const std::string& blob_name : ws_->Blobs()) {
      try {
        const auto* blob = ws_->GetBlob(blob_name);
        if (blob != nullptr && blob->IsType<Tensor>()) {
          // Read-only inspection: use the const accessor, not GetMutable<>.
          const Tensor& c2_tensor = blob->Get<Tensor>();
          const at::Tensor tensor = static_cast<at::Tensor>(c2_tensor);
          // Copy the scalar result to the CPU before reading it on the host.
          const bool blob_finite =
              tensor.sum().isfinite().cpu().data_ptr<bool>()[0];
          LOG(INFO) << "blob " << blob_name
                    << " isfinite=" << (blob_finite ? "true" : "false");
        }
      } catch (const std::exception& ex) {
        // Best-effort diagnostics: a failure on one blob is logged and does
        // not stop the remaining blobs from being reported.
        LOG(ERROR) << "failed to check finiteness for " << blob_name << ": "
                   << ex.what();
      }
    }
#endif
  }
};

} // namespace caffe2

#endif // CAFFE_OPERATORS_ENFORCE_FINITE_OP_H_