/**
* This file is adapted from PyTorch/XLA
* https://github.com/pytorch/xla/blob/master/third_party/xla_client/multi_wait.h
*/
#pragma once

#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>

#include <c10/macros/Export.h>

namespace torch {
namespace lazy {
// Support waiting for a number of tasks to complete.
class TORCH_API MultiWait {
 public:
  explicit MultiWait(size_t count) : count_(count) {}

  // Signals the completion of a single task.
  void Done();

  // Waits until at least count (passed as the constructor value) completions
  // have happened.
  void Wait();

  // Same as above, but waits up to wait_seconds.
  void Wait(double wait_seconds);

  // Resets the threshold counter for the MultiWait object. The completed count
  // is also reset to zero.
  void Reset(size_t count);
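
  // Example (illustrative sketch, not part of this header; assumes <thread>
  // and <vector> are included, plus a hypothetical DoWork() helper and a
  // kNumTasks constant; each worker calls Done() exactly once):
  //
  //   MultiWait mwait(kNumTasks);
  //   std::vector<std::thread> workers;
  //   for (size_t i = 0; i < kNumTasks; ++i) {
  //     workers.emplace_back([&mwait, i]() {
  //       DoWork(i);      // hypothetical per-task work
  //       mwait.Done();   // signal one completion
  //     });
  //   }
  //   mwait.Wait();       // blocks until kNumTasks completions are signaled
  //   for (auto& w : workers) {
  //     w.join();
  //   }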
  // Creates a completer functor which signals the MultiWait object once func
  // has completed. Handles exceptions by signaling the MultiWait with the
  // proper status value. This API returns a function which captures a
  // MultiWait reference, so care must be taken such that the reference remains
  // valid for the whole lifetime of the returned function.
  std::function<void()> Completer(std::function<void()> func);

  // Similar to the above API, but with explicit capture of the MultiWait
  // shared pointer.
  static std::function<void()> Completer(
      std::shared_ptr<MultiWait> mwait,
      std::function<void()> func);
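
  // Example (illustrative sketch; assumes a hypothetical Schedule() function
  // that runs a closure on another thread and a tasks container of
  // std::function<void()> closures). The static overload keeps the MultiWait
  // alive through the shared_ptr capture, so it is the safer choice when the
  // returned completer may outlive the caller's scope:
  //
  //   auto mwait = std::make_shared<MultiWait>(tasks.size());
  //   for (auto& task_fn : tasks) {
  //     Schedule(MultiWait::Completer(mwait, task_fn));
  //   }
  //   mwait->Wait();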
 private:
  void Complete(const std::function<void()>& func);

  std::mutex mutex_;
  std::condition_variable cv_;
  size_t count_ = 0;
  size_t completed_count_ = 0;
  std::exception_ptr exptr_;
};
} // namespace lazy
} // namespace torch