1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
|
#pragma once
#include <c10/util/Exception.h>
namespace torch {
namespace jit {
// Intrusive doubly linked lists with sane reverse iterators.
// The header file is named generic_graph_node_list.h because it is ONLY
// used for Graph's Node lists, and if you want to use it for other
// things, you will have to do some refactoring.
//
// At the moment, the templated type T must support a few operations:
//
// - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
// which is used for the intrusive linked list pointers.
//
// - It must have a method 'destroy()', which removes T from the
// list and frees a T.
//
// In practice, we are only using it with Node and const Node. 'destroy()'
// needs to be renegotiated if you want to use this somewhere else.
//
// Regardless of the iteration direction, iterators always physically point
// to the element they logically point to, rather than exhibiting the
// off-by-one behavior of standard library reverse iterators (such as those
// of std::list).
// The list includes two sentinel nodes, one at the beginning and one at the
// end, with a circular link between them. It is an error to insert nodes after
// the end sentinel node but before the beginning node:
// Visualization showing only the next() links:
// HEAD -> first -> second -> ... -> last -> TAIL
// ^------------------------------------------
// Visualization showing only the prev() links:
// HEAD <- first <- second <- ... <- last <- TAIL
// ------------------------------------------^
// Indices into a node's next_in_graph[2] link array.
// Slot 0 holds the link followed when walking forward.
static constexpr int kNextDirection{0};
// Slot 1 holds the link followed when walking backward.
static constexpr int kPrevDirection{1};
// Forward declarations: the element type and the two list templates are
// defined elsewhere; only their names are needed for the aliases below.
struct Node;

template <class T>
struct generic_graph_node_list_iterator;

template <class T>
struct generic_graph_node_list;

// Convenience names for the list and iterator instantiated over mutable and
// const Node.
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
using const_graph_node_list_iterator =
    generic_graph_node_list_iterator<const Node>;
using graph_node_list = generic_graph_node_list<Node>;
using const_graph_node_list = generic_graph_node_list<const Node>;
// Iterator over an intrusive doubly linked list of T.
//
// An iterator is a (node pointer, direction) pair. The direction selects
// which slot of T::next_in_graph operator++ follows; operator-- follows the
// opposite slot. Dereferencing yields the node pointer itself (T*), not a
// reference to T.
template <typename T>
struct generic_graph_node_list_iterator {
  // Default state: no current node, forward direction.
  generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
  // Start at `cur`, advancing through next_in_graph[d].
  generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {}
  generic_graph_node_list_iterator(
      const generic_graph_node_list_iterator& rhs) = default;
  generic_graph_node_list_iterator(generic_graph_node_list_iterator&& rhs) =
      default;
  generic_graph_node_list_iterator& operator=(
      const generic_graph_node_list_iterator& rhs) = default;
  generic_graph_node_list_iterator& operator=(
      generic_graph_node_list_iterator&& rhs) = default;

  // Both accessors return the stored node pointer unchanged.
  T* operator*() const {
    return cur;
  }
  T* operator->() const {
    return cur;
  }

  // Pre-increment: step one link in the iteration direction.
  // Asserts that the iterator currently refers to a node.
  generic_graph_node_list_iterator& operator++() {
    AT_ASSERT(cur);
    cur = cur->next_in_graph[d];
    return *this;
  }
  // Post-increment: advance, returning a copy of the previous position.
  generic_graph_node_list_iterator operator++(int) {
    generic_graph_node_list_iterator old = *this;
    ++(*this);
    return old;
  }
  // Pre-decrement: step one link against the iteration direction.
  // Asserts that the iterator currently refers to a node.
  generic_graph_node_list_iterator& operator--() {
    AT_ASSERT(cur);
    cur = cur->next_in_graph[reverseDir()];
    return *this;
  }
  // Post-decrement: step back, returning a copy of the previous position.
  generic_graph_node_list_iterator operator--(int) {
    generic_graph_node_list_iterator old = *this;
    --(*this);
    return old;
  }

  // erase cur without invalidating this iterator
  // named differently from destroy so that ->/. bugs do not
  // silently cause the wrong one to be called.
  // iterator will point to the previous entry after call
  void destroyCurrent() {
    // Same precondition as ++/--: there must be a current node to destroy.
    AT_ASSERT(cur);
    T* n = cur;
    // Move off the node before destroy() is called on it, so this iterator
    // never holds a pointer to the destroyed node.
    cur = cur->next_in_graph[reverseDir()];
    n->destroy();
  }

  // Returns an iterator at the same node that walks in the opposite
  // direction. Does not modify this iterator, hence const.
  generic_graph_node_list_iterator reverse() const {
    return generic_graph_node_list_iterator(cur, reverseDir());
  }

 private:
  // The direction opposite to d.
  int reverseDir() const {
    return d == kNextDirection ? kPrevDirection : kNextDirection;
  }
  T* cur;
  int d; // direction 0 is forward 1 is reverse, see next_in_graph
};
// A view over an intrusive doubly linked node list.
//
// The view stores only a pointer to the head sentinel and a direction d.
// Following next_in_graph[d] from the head reaches the first real node;
// following next_in_graph[!d] from the head reaches the tail sentinel (the
// sentinels are linked circularly). Copying the view does not copy nodes.
template <typename T>
struct generic_graph_node_list {
  using iterator = generic_graph_node_list_iterator<T>;
  using const_iterator = generic_graph_node_list_iterator<const T>;

  // `head` is the head sentinel for direction `d`.
  generic_graph_node_list(T* head, int d) : head(head), d(d) {}

  // Iterator at the first node after the head sentinel.
  generic_graph_node_list_iterator<T> begin() {
    return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
  }
  generic_graph_node_list_iterator<const T> begin() const {
    return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
  }

  // Iterator at the tail sentinel, i.e. one past the last real node.
  generic_graph_node_list_iterator<T> end() {
    return generic_graph_node_list_iterator<T>(head->next_in_graph[!d], d);
  }
  generic_graph_node_list_iterator<const T> end() const {
    return generic_graph_node_list_iterator<const T>(
        head->next_in_graph[!d], d);
  }

  // Reverse iteration is expressed as forward iteration over reverse().
  generic_graph_node_list_iterator<T> rbegin() {
    return reverse().begin();
  }
  generic_graph_node_list_iterator<const T> rbegin() const {
    return reverse().begin();
  }
  generic_graph_node_list_iterator<T> rend() {
    return reverse().end();
  }
  generic_graph_node_list_iterator<const T> rend() const {
    return reverse().end();
  }

  // A view of the same nodes walked in the opposite direction: its head is
  // this view's tail sentinel and its direction is !d.
  generic_graph_node_list reverse() {
    return generic_graph_node_list(head->next_in_graph[!d], !d);
  }
  // Returned as a const object so that chained calls such as
  // reverse().begin() select the const overloads.
  const generic_graph_node_list reverse() const {
    return generic_graph_node_list(head->next_in_graph[!d], !d);
  }

  // First real node in iteration order (the node begin() refers to).
  T* front() {
    return head->next_in_graph[d];
  }
  const T* front() const {
    return head->next_in_graph[d];
  }

  // Last real node in iteration order.
  // head->next_in_graph[!d] is the tail sentinel (it is what end() refers
  // to), so one more step against the iteration direction is needed to
  // reach the last real node. Previously this returned the tail sentinel.
  // NOTE(review): confirm no caller relied on receiving the sentinel.
  T* back() {
    return head->next_in_graph[!d]->next_in_graph[!d];
  }
  const T* back() const {
    return head->next_in_graph[!d]->next_in_graph[!d];
  }

 private:
  T* head; // both head and tail are sentinel nodes
  // the first real node is head->next_in_graph[d]
  // the tail sentinel is head->next_in_graph[!d]
  int d;
};
// Two iterators are equal when they refer to the same node. The iteration
// direction is not part of the comparison.
template <typename T>
static inline bool operator==(
    generic_graph_node_list_iterator<T> a,
    generic_graph_node_list_iterator<T> b) {
  T* const lhs_node = *a;
  T* const rhs_node = *b;
  return lhs_node == rhs_node;
}
// Two iterators differ when they refer to different nodes. The iteration
// direction is not part of the comparison.
template <typename T>
static inline bool operator!=(
    generic_graph_node_list_iterator<T> a,
    generic_graph_node_list_iterator<T> b) {
  T* const lhs_node = *a;
  T* const rhs_node = *b;
  return lhs_node != rhs_node;
}
} // namespace jit
} // namespace torch
namespace std {

// Traits for the graph node list iterator. The iterator is declared
// bidirectional, and its value type is the node pointer (T*) because
// dereferencing the iterator yields a pointer rather than a reference to T.
template <typename T>
struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> {
  using iterator_category = bidirectional_iterator_tag;
  using value_type = T*;
  using reference = T*&;
  using pointer = T**;
  using difference_type = int64_t;
};

} // namespace std
|