File: net_async_task.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (109 lines) | stat: -rw-r--r-- 2,849 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include "caffe2/core/net_async_task.h"

#include "caffe2/core/net_async_task_graph.h"

namespace caffe2 {

// Builds a task that executes `ops` as a single chain. The chain must be
// non-empty and every op must have the same device option as the first op,
// which becomes the task's device option. Ends by calling Reset() so every op
// event and the task future start out in their reset state.
// NOLINTNEXTLINE(modernize-pass-by-value)
AsyncTask::AsyncTask(const std::vector<OperatorBase*>& ops) : ops_(ops) {
  CAFFE_ENFORCE(!ops_.empty(), "AsyncTask requires at least one operator");
  device_option_ = ops_.front()->device_option();
  for (const auto& op : ops_) {
    CAFFE_ENFORCE(
        IsSameDevice(device_option_, op->device_option()),
        "All operators in an AsyncTask must run on the same device");
  }
  Reset();
}

// Reports a failure of the chain:
//   - logs `err_str`, suffixed with the failing op's type when `op` is
//     non-null (or " unknown" when the op has no debug def);
//   - records the message on the last op's event, as an exception when
//     `save_exception` is true, otherwise as a plain error finish;
//   - completes the task future with the same message.
// `op` may be null, in which case no op information is appended.
void AsyncTask::handleChainError(
    OperatorBase* op,
    const char* err_str,
    bool save_exception) {
  std::string err_msg = err_str;
  if (op) {
    err_msg += ",  op " + (op->has_debug_def() ? op->type() : " unknown");
  }
  LOG(ERROR) << err_msg;

  // The last op's event acts as the chain's event, so the error is recorded
  // there regardless of which op in the chain actually failed.
  auto* last_op = ops_.back();
  if (save_exception) {
    last_op->event().SetFinishedWithException(err_msg.c_str());
  } else {
    last_op->event().SetFinished(err_msg.c_str());
  }

  // set future as completed with an error
  // TODO: exceptions in future
  future_.SetCompleted(err_msg.c_str());
}

// Runs the async part of every op in chain order on stream 0. When
// `options.finish_chain_` is set, Finish() is called on the last op afterwards.
//
// On success the future is completed immediately. The exception is a CPU chain
// whose last op has an async part: then the future is completed from that op's
// event callback, with the event's error message if the event did not succeed.
//
// Returns false, after routing the error through handleChainError(), when an
// op's RunAsync() returns false or any exception is thrown. Returns true
// otherwise.
bool AsyncTask::Run(const ExecutionOptions& options) {
  // TODO: insert CUDA's async stream waits; tracing and counters

  // Declared outside the try block so the catch handlers can attribute the
  // error to the op that was current when it was raised.
  OperatorBase* op = nullptr;
  const int stream_id = 0; // TODO: thread local stream id
  try {
    for (auto* current : ops_) {
      op = current;
      if (!op->RunAsync(stream_id)) {
        handleChainError(op, "Failed to execute an op");
        return false;
      }
    }

    if (options.finish_chain_) {
      op = ops_.back();
      op->Finish();
    }

    // set the future as successfully completed or, in case of async CPU,
    // use op's callback
    auto* last_op = ops_.back();
    if (IsCPUDeviceType(device_option_.device_type()) &&
        last_op->HasAsyncPart()) {
      auto& event = last_op->event();
      // NOTE(review): the callback captures `this` and a reference to the last
      // op's event; presumably both outlive the callback -- confirm.
      event.SetCallback([this, &event]() {
        CAFFE_ENFORCE(event.IsFinished());
        if (event.Query() == EventStatus::EVENT_SUCCESS) {
          future_.SetCompleted();
        } else {
          // TODO: support for exceptions
          future_.SetCompleted(event.ErrorMessage().c_str());
        }
      });
    } else {
      future_.SetCompleted();
    }
  } catch (const std::exception& e) {
    handleChainError(op, e.what(), /* save_exception */ true);
    return false;
  } catch (...) {
    handleChainError(
        op,
        "Failed to execute task: unknown error",
        /* save_exception */ true);
    return false;
  }

  return true;
}

// Resets the event of every op in the chain and the state of the task future.
void AsyncTask::Reset() {
  for (auto& op : ops_) {
    op->ResetEvent();
  }
  future_.ResetState();
}

// Returns (by value) the device option shared by all ops in the chain.
DeviceOption AsyncTask::GetDeviceOption() const {
  return device_option_;
}

// Returns the future that Run() / handleChainError() complete.
AsyncTaskFuture& AsyncTask::GetFuture() {
  return future_;
}

// Read-only access to the task future.
const AsyncTaskFuture& AsyncTask::GetFuture() const {
  return future_;
}

} // namespace caffe2