1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
|
/**
* Unique in this file is adapted from PyTorch/XLA
* https://github.com/pytorch/xla/blob/master/third_party/xla_client/unique.h
*/
#pragma once
#include <c10/util/Optional.h>
#include <functional>
#include <set>
namespace torch {
namespace lazy {
// Helper class to allow tracking zero or more things, which should be forcibly
// be one only thing.
template <typename T, typename C = std::equal_to<T>>
class Unique {
public:
std::pair<bool, const T&> set(const T& value) {
if (value_) {
TORCH_CHECK(C()(*value_, value), "'", *value_, "' vs '", value);
return std::pair<bool, const T&>(false, *value_);
}
value_ = value;
return std::pair<bool, const T&>(true, *value_);
}
operator bool() const {
return value_.has_value();
}
operator const T&() const {
return *value_;
}
const T& operator*() const {
return *value_;
}
const T* operator->() const {
return value_.operator->();
}
std::set<T> AsSet() const {
std::set<T> vset;
if (value_.has_value()) {
vset.insert(*value_);
}
return vset;
}
private:
c10::optional<T> value_;
};
} // namespace lazy
} // namespace torch
|