File: test_utils.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (104 lines) | stat: -rw-r--r-- 3,586 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#pragma once

#include <algorithm>
#include <cctype>
#include <exception>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace {
// Normalizes `s` in place so that two messages can be compared without caring
// about how they were wrapped or indented:
//   1. leading and trailing whitespace (as classified by std::isspace) is
//      stripped;
//   2. every remaining '\n' is deleted -- nothing is inserted in its place, so
//      "a\nb" becomes "ab";
//   3. every run of consecutive ' ' characters is collapsed to a single ' '.
// Interior tabs and other non-space, non-newline characters are left as is.
static inline void trim(std::string& s) {
  // The parameter is unsigned char because passing a negative value to
  // std::isspace is undefined behavior.
  const auto is_not_space = [](unsigned char ch) { return !std::isspace(ch); };

  s.erase(s.begin(), std::find_if(s.begin(), s.end(), is_not_space));
  s.erase(std::find_if(s.rbegin(), s.rend(), is_not_space).base(), s.end());

  // Steps 2 and 3 as a single linear compaction pass: characters that survive
  // are copied down to position `kept`, everything else is skipped. This
  // avoids the quadratic cost of erasing one character at a time.
  std::string::size_type kept = 0;
  for (std::string::size_type next = 0; next < s.size(); ++next) {
    const char ch = s[next];
    // A space is redundant when the last character we kept is also a space.
    const bool repeated_space = ch == ' ' && kept > 0 && s[kept - 1] == ' ';
    if (ch == '\n' || repeated_space) {
      continue;
    }
    s[kept++] = ch;
  }
  s.resize(kept);
}
} // namespace

// Asserts that evaluating `statement` throws an exception derived from
// std::exception whose what() text contains `substring`.
//
// Both the expected substring and the actual message are passed through
// trim() before the search, so differences in surrounding whitespace, line
// breaks and repeated spaces do not cause a mismatch.
//
// If `statement` completes without throwing, the test fails through FAIL().
// An exception that does not derive from std::exception is not caught here
// and propagates to the caller.
//
// FAIL() and ASSERT_NE are not defined in this header; they are expected to
// come from GoogleTest, which the including test file must provide.
//
// `statement` is parenthesized so that the (void) cast applies to the whole
// expression rather than only to its first operand.
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)             \
  try {                                                              \
    (void)(statement);                                               \
    FAIL() << "Expected an exception, but none was thrown";          \
  } catch (const std::exception& e) {                                \
    std::string substring_s(substring);                              \
    trim(substring_s);                                               \
    auto exception_string = std::string(e.what());                   \
    trim(exception_string);                                          \
    ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
        << " Error was: \n"                                          \
        << exception_string;                                         \
  }

// Declarations of helpers shared by the JIT C++ tests. Only the signatures are
// visible in this header; the definitions live in a separate source file.
namespace torch {
namespace jit {

// Shorthand for a vector of tensors, used by the declarations below.
using tensor_list = std::vector<at::Tensor>;
// NOTE(review): a using-directive in a header leaks torch::autograd names into
// torch::jit for every file that includes it. Existing tests may depend on
// that, so check the includers before removing it.
using namespace torch::autograd;

// work around the fact that variable_tensor_list doesn't duplicate all
// of std::vector's constructors.
// most constructors are never used in the implementation, just in our tests.
// Takes the tensors by rvalue reference and returns a Stack.
Stack createStack(std::vector<at::Tensor>&& list);

// Compares two tensor lists. Presumably asserts element-wise closeness; the
// tolerance and the failure mechanism are set by the definition -- confirm
// there.
void assertAllClose(const tensor_list& a, const tensor_list& b);

// Takes an interpreter state (by mutable reference) and a list of input
// tensors, and returns a list of tensors. Presumably runs `interp` on `inputs`
// and returns its outputs -- confirm against the definition.
std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs);

// Takes a Gradient spec plus two tensor lists (all by mutable reference) and
// returns a pair of tensor lists. NOTE(review): which list holds outputs and
// which holds gradients is not visible here -- see the definition.
std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in);

// Graph factories: each takes no arguments and returns a shared_ptr<Graph>.
// The contents of each graph are determined by the definitions.
std::shared_ptr<Graph> build_lstm();
std::shared_ptr<Graph> build_mobile_export_analysis_graph();
std::shared_ptr<Graph> build_mobile_export_with_out();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const();

// Tensor-in, tensor-out helpers taking their argument by value.
// NOTE(review): their semantics are not visible in this header.
at::Tensor t_use(at::Tensor x);
at::Tensor t_def(at::Tensor x);

// given the difference of output vs expected tensor, check whether the
// difference is within a relative tolerance range. This is a standard way of
// matching tensor values up to certain precision
// NOTE(review): `inputs` is taken by value, so the vector is copied on each
// call. Switching to a const reference would have to be mirrored in the
// definition.
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
// Approximate comparison of two tensors; presumably built on checkRtol --
// confirm against the definition.
bool almostEqual(const at::Tensor& a, const at::Tensor& b);

// Exact comparison, overloaded for a single pair of tensors and for two
// tensor lists.
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
bool exactlyEqual(
    const std::vector<at::Tensor>& a,
    const std::vector<at::Tensor>& b);

// Takes a graph and input tensors and returns a list of tensors. Presumably
// executes `graph` on `inputs` and returns the outputs -- confirm.
std::vector<at::Tensor> runGraph(
    std::shared_ptr<Graph> graph,
    const std::vector<at::Tensor>& inputs);

// Takes five tensors by value and returns a pair of tensors.
// NOTE(review): presumably a reference LSTM cell (input, hidden state, cell
// state, input-hidden and hidden-hidden weights) returning the new hidden and
// cell states -- confirm against the definition.
std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh);

} // namespace jit
} // namespace torch