#pragma once
#include <c10/util/Exception.h>
#include <mutex>
#include <vector>
namespace torch::autograd::utils {
// Warning handler for multi-threaded contexts. Gather warnings from
// all threads into a single queue, then process together at the end
// in the main thread.
class DelayWarningHandler : public at::WarningHandler {
public:
~DelayWarningHandler() override = default;
void replay_warnings();
private:
void process(const c10::Warning& warning) override;
std::vector<c10::Warning> warnings_;
std::mutex mutex_;
};
} // namespace torch::autograd::utils