1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
/**
* Most of the utils in this file is adapted from PyTorch/XLA
* https://github.com/pytorch/xla/blob/master/third_party/xla_client/util.h
*/
#pragma once
#include <exception>
#include <functional>
#include <vector>
#include <c10/util/Optional.h>
#include <c10/util/OptionalArrayRef.h>
namespace torch {
namespace lazy {
// Similar to c10::scope_exit but with a status.
// TODO(alanwaketan): Consolidate it with c10::scope_exit.
template <typename T>
class Cleanup {
public:
using StatusType = T;
explicit Cleanup(std::function<void(StatusType&&)>&& func)
: func_(std::move(func)) {}
Cleanup(Cleanup&& ref) noexcept
: func_(std::move(ref.func_)), status_(std::move(ref.status_)) {}
Cleanup(const Cleanup&) = delete;
~Cleanup() {
if (func_ != nullptr) {
func_(std::move(status_));
}
}
Cleanup& operator=(const Cleanup&) = delete;
Cleanup& operator=(Cleanup&& ref) {
if (this != &ref) {
func_ = std::move(ref.func_);
status_ = std::move(ref.status_);
}
return *this;
}
void Release() {
func_ = nullptr;
}
void SetStatus(StatusType&& status) {
status_ = std::move(status);
}
const StatusType& GetStatus() const {
return status_;
}
private:
std::function<void(StatusType&&)> func_;
StatusType status_;
};
using ExceptionCleanup = Cleanup<std::exception_ptr>;
// Allows APIs which might return const references and values, to not be forced
// to return values in the signature.
// TODO(alanwaketan): This is clever, but is there really no std or c10
// supports? Needs more investigations.
template <typename T>
class MaybeRef {
public:
/* implicit */ MaybeRef(const T& ref) : ref_(ref) {}
/* implicit */ MaybeRef(T&& value)
: storage_(std::move(value)), ref_(*storage_) {}
const T& Get() const {
return ref_;
}
const T& operator*() const {
return Get();
}
operator const T&() const {
return Get();
}
bool IsStored() const {
return storage_.has_value();
}
private:
c10::optional<T> storage_;
const T& ref_;
};
template <typename T>
std::vector<T> Iota(size_t size, T init = 0, T incr = 1) {
std::vector<T> result(size);
T value = init;
for (size_t i = 0; i < size; ++i, value += incr) {
result[i] = value;
}
return result;
}
template <typename T, typename S>
std::vector<T> ToVector(const S& input) {
return std::vector<T>(input.begin(), input.end());
}
template <typename T>
c10::optional<std::vector<T>> ToOptionalVector(
c10::OptionalArrayRef<T> arrayRef) {
if (arrayRef) {
return arrayRef->vec();
}
return c10::nullopt;
}
template <typename T>
typename std::underlying_type<T>::type GetEnumValue(T value) {
return static_cast<typename std::underlying_type<T>::type>(value);
}
} // namespace lazy
} // namespace torch
|