1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
|
#pragma once
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <algorithm>
#include <cstddef>
#include <iterator>
namespace torch {
namespace jit {
namespace fuser {
namespace ir_utils {
// Iterator adaptor over the range [current, end) that only stops on elements
// whose dynamic type is FilterType (tested with dynamic_cast). Dereferencing
// yields the current element converted via its `as<FilterType>()` member.
//
// NOTE(review): Iterator is assumed to dereference to a pointer to a
// polymorphic type exposing a member template `as<T>()` (presumably the fuser
// IR's Val hierarchy) — confirm at call sites.
template <typename FilterType, typename Iterator>
class FilterIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using value_type = FilterType*;
  using pointer = value_type*;
  // NOTE(review): operator* returns value_type by value, so this is not a true
  // reference type; kept as-is for source compatibility with existing users.
  using reference = value_type&;

  // Value-initialized iterator. Forward iterators must be default
  // constructible; such an iterator may only be assigned to or compared with
  // another value-initialized iterator.
  FilterIterator() = default;

  // Positions the iterator on the first FilterType element in [begin, end),
  // or on `end` if there is none.
  FilterIterator(Iterator begin, Iterator end) : current_(begin), end_(end) {
    advance();
  }

  // Returns the current element as a FilterType*. Must not be called on an
  // end iterator.
  FilterType* operator*() const {
    return (*current_)->template as<FilterType>();
  }

  // Member access on the current element. Must go through operator*:
  // returning `*this` would yield the iterator itself, which does not convert
  // to FilterType*.
  FilterType* operator->() const {
    return **this;
  }

  // Pre-increment: step past the current element, then skip ahead to the next
  // FilterType element (or to end).
  FilterIterator& operator++() {
    ++current_;
    advance();
    return *this;
  }

  // Post-increment: advances like pre-increment but returns the prior state.
  FilterIterator operator++(int) {
    FilterIterator before_increment = *this;
    ++*this;
    return before_increment;
  }

  // Iterators are equal when they sit at the same position. Comparing
  // iterators built over different ranges is a programming error.
  bool operator==(const FilterIterator& other) const {
    TORCH_INTERNAL_ASSERT(
        end_ == other.end_,
        "Comparing two FilteredViews that originate from different containers");
    return current_ == other.current_;
  }

  bool operator!=(const FilterIterator& other) const {
    return !(*this == other);
  }

 private:
  // Moves current_ forward to the first element at or after it whose dynamic
  // type is FilterType; leaves it at end_ when none remains.
  void advance() {
    current_ = std::find_if(current_, end_, [](const auto& val) {
      return dynamic_cast<const FilterType*>(val) != nullptr;
    });
  }

 private:
  Iterator current_{};
  // Deliberately not const: a const member would delete copy assignment, and
  // forward iterators must be CopyAssignable.
  Iterator end_{};
};
// An iterable view to a given container of Val pointers. Only returns
// Vals of a given Val type.
// NOTE: Add a non-const iterator if needed.
template <typename FilterType, typename InputIt>
class FilteredView {
public:
using value_type = FilterType*;
using const_iterator = FilterIterator<FilterType, InputIt>;
FilteredView(InputIt first, InputIt last) : input_it_(first), last_(last) {}
const_iterator cbegin() const {
return const_iterator(input_it_, last_);
}
const_iterator begin() const {
return cbegin();
}
const_iterator cend() const {
return const_iterator(last_, last_);
}
const_iterator end() const {
return cend();
}
private:
const InputIt input_it_;
const InputIt last_;
};
template <typename FilterType, typename InputIt>
auto filterByType(InputIt first, InputIt last) {
return FilteredView<FilterType, InputIt>(first, last);
}
template <typename FilterType, typename ContainerType>
auto filterByType(const ContainerType& inputs) {
return filterByType<FilterType>(inputs.cbegin(), inputs.cend());
}
} // namespace ir_utils
} // namespace fuser
} // namespace jit
} // namespace torch
|