File: utils.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (70 lines) | stat: -rw-r--r-- 1,948 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <c10/util/irange.h>
#include <torch/csrc/autograd/functions/utils.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>

#include <sstream>
#include <vector>

namespace torch {
namespace autograd {

// Converts `outputs` into a variable_list with one entry per output, in the
// same order.
//
// - When any_variable_requires_grad(inputs) is false, each defined output is
//   passed through make_variable(..., /*requires_grad=*/false) and each
//   undefined output becomes a default-constructed entry. `ctr` is not called.
// - Otherwise a node is obtained from `ctr`, which receives
//   collect_next_edges(inputs) when GradMode::is_enabled() and an empty
//   edge_list when it is not. Every defined output is wrapped and handed to
//   create_gradient_edge together with that node; every undefined output adds
//   Node::undefined_input() to the node's input metadata and yields a
//   default-constructed entry.
//
// NOTE: although `outputs` is taken by rvalue reference, its elements are
// passed to make_variable as lvalues; they are not moved from here.
variable_list wrap_outputs(
    const variable_list& inputs,
    tensor_list&& outputs,
    const function_constructor& ctr) {
  variable_list wrapped;
  wrapped.reserve(outputs.size());

  // No input requires grad: wrap the outputs without building a node.
  if (!any_variable_requires_grad(inputs)) {
    for (auto& tensor : outputs) {
      if (!tensor.defined()) {
        wrapped.emplace_back();
        continue;
      }
      wrapped.push_back(make_variable(tensor, /*requires_grad=*/false));
    }
    return wrapped;
  }

  // At least one input requires grad: build the node once and attach every
  // defined output to it.
  auto grad_fn =
      ctr(GradMode::is_enabled() ? collect_next_edges(inputs) : edge_list());
  for (auto& tensor : outputs) {
    if (!tensor.defined()) {
      // Presumably keeps the node's input metadata slots aligned with the
      // output positions -- confirm against Node::add_input_metadata.
      grad_fn->add_input_metadata(Node::undefined_input());
      wrapped.emplace_back();
      continue;
    }
    auto wrapped_var = autograd::make_variable(tensor, /*requires_grad=*/false);
    autograd::create_gradient_edge(wrapped_var, grad_fn);
    wrapped.push_back(std::move(wrapped_var));
  }
  return wrapped;
}

// Validates the argument list passed to the function called `name`.
//
// Parameters:
//   name            - function name used as the prefix of error messages.
//   inputs          - the arguments to validate.
//   args            - exact number of arguments expected.
//   required_args   - number of leading arguments that must be defined;
//                     -1 means "all of them" (i.e. `args`).
//   allow_undefined - when true, undefined arguments are accepted and only
//                     the argument count is checked.
//
// Throws std::runtime_error when inputs.size() != args, or when
// allow_undefined is false and one of the first `required_args` inputs is
// not defined.
void check_input_variables(
    const char* name,
    const variable_list& inputs,
    int args,
    int required_args,
    bool allow_undefined) {
  if (required_args == -1) {
    required_args = args;
  }
  if (inputs.size() != static_cast<size_t>(args)) {
    std::stringstream ss;
    ss << name << ": expected " << args << " arguments (got " << inputs.size();
    ss << ")";
    throw std::runtime_error(ss.str());
  }
  // The definedness check below is a no-op when undefined inputs are allowed,
  // so skip the loop entirely instead of re-testing the flag per element.
  if (allow_undefined) {
    return;
  }
  // From here on inputs.size() == args. Clamp the number of checked arguments
  // to `args` so that a caller passing required_args > args cannot make the
  // loop index past the end of `inputs`.
  const int checked_args = required_args < args ? required_args : args;
  for (const auto i : c10::irange(checked_args)) {
    if (!inputs[i].defined()) {
      std::stringstream ss;
      ss << name << ": expected Tensor at argument " << i << " (got None)";
      throw std::runtime_error(ss.str());
    }
  }
}
} // namespace autograd
} // namespace torch