1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
|
//===--- Trie.h - Trie with terms as keys ---------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2021 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
#ifndef SWIFT_RQM_TRIE_H
#define SWIFT_RQM_TRIE_H
#include "llvm/ADT/MapVector.h"
#include "Histogram.h"
#include <optional>
#include <vector>
namespace swift {
namespace rewriting {
/// Selects which stored key Trie::find() reports when more than one stored
/// key is a prefix of the query range.
enum class MatchKind {
  /// Report the first (shortest) stored key that prefixes the query.
  Shortest = 0,
  /// Report the last (longest) stored key that prefixes the query.
  Longest = 1
};
/// A trie whose keys are non-empty sequences of Symbols and whose values are
/// of type ValueType.
///
/// Each Node maps a Symbol to an Entry. The Entry optionally holds the value
/// for the key that ends at that symbol, and optionally points at the child
/// Node holding keys that continue past it.
///
/// The Kind template parameter selects whether find() reports the shortest
/// or the longest stored key that is a prefix of the query.
///
/// Every non-root node is heap-allocated by insert(), recorded in Nodes, and
/// freed only by clear() or the destructor. Because the trie owns these raw
/// pointers, it is deliberately non-copyable.
template<typename ValueType, MatchKind Kind>
class Trie {
public:
  struct Node;

  /// One outgoing edge of a Node.
  struct Entry {
    /// The value stored for the key that ends at this symbol, if any.
    std::optional<ValueType> Value;

    /// The subtrie for keys that extend past this symbol, or nullptr.
    /// The pointee is owned by the enclosing Trie (through Trie::Nodes),
    /// not by this Entry.
    Node *Children = nullptr;
  };

  struct Node {
    llvm::SmallMapVector<Symbol, Entry, 1> Entries;
  };

private:
  /// Every heap-allocated (non-root) node. We never delete nodes, except
  /// when the entire trie is cleared or torn down.
  std::vector<Node *> Nodes;

  /// The root is stored directly.
  Node Root;

public:
  Trie() = default;

  // A memberwise copy would duplicate the raw child pointers held by Root
  // and Nodes, and both tries' destructors would then delete the same nodes.
  // With no move operations declared, a "move" would also fall back to this
  // copy, so forbid copying outright.
  Trie(const Trie &) = delete;
  Trie &operator=(const Trie &) = delete;

  ~Trie() {
    clear();
  }

  /// Record the number of entries of every heap-allocated node in \p stats,
  /// and the number of entries of the root node in \p rootStats.
  void updateHistograms(Histogram &stats, Histogram &rootStats) const {
    for (const Node *node : Nodes)
      stats.add(node->Entries.size());

    rootStats.add(Root.Entries.size());
  }

  /// Delete all entries from the trie and free every heap-allocated node.
  void clear() {
    Root.Entries.clear();

    for (auto iter = Nodes.rbegin(); iter != Nodes.rend(); ++iter)
      delete *iter;

    Nodes.clear();
  }

  /// Insert an entry with the key given by the non-empty range [begin, end).
  ///
  /// Returns the old value if the trie already had an entry for this key,
  /// otherwise std::nullopt. Overwriting an existing entry is actually an
  /// invariant violation, but returning the old value lets callers produce
  /// a better assertion further up the stack.
  template <typename Iter>
  std::optional<ValueType> insert(Iter begin, Iter end, ValueType value) {
    assert(begin != end);
    Node *node = &Root;

    while (true) {
      // Creates an empty Entry if this symbol has no edge yet.
      Entry &entry = node->Entries[*begin];
      ++begin;

      if (begin == end) {
        std::optional<ValueType> oldValue = entry.Value;
        entry.Value = value;
        return oldValue;
      }

      if (entry.Children == nullptr) {
        entry.Children = new Node();
        Nodes.push_back(entry.Children);
      }

      node = entry.Children;
    }
  }

  /// Find the value of the shortest or longest stored key that is a prefix
  /// of the non-empty range [begin, end). Which one is reported depends on
  /// whether the Kind template parameter is MatchKind::Shortest or
  /// MatchKind::Longest.
  ///
  /// Returns std::nullopt if no stored key is a prefix of the range.
  template <typename Iter>
  std::optional<ValueType> find(Iter begin, Iter end) const {
    assert(begin != end);
    const Node *node = &Root;

    std::optional<ValueType> bestMatch = std::nullopt;

    while (true) {
      auto found = node->Entries.find(*begin);
      ++begin;

      if (found == node->Entries.end())
        return bestMatch;

      const Entry &entry = found->second;

      if (entry.Value) {
        // The first value met on the way down is the shortest match.
        if constexpr (Kind == MatchKind::Shortest)
          return entry.Value;

        // Otherwise keep descending; deeper values are longer matches.
        bestMatch = entry.Value;
      }

      if (begin == end)
        return bestMatch;

      if (entry.Children == nullptr)
        return bestMatch;

      node = entry.Children;
    }
  }

  /// Invoke \p fn on the value of every stored key whose first symbol is
  /// \p symbol. Fn must be callable with a single argument of type ValueType.
  template<typename Fn>
  void findAll(Symbol symbol, Fn fn) const {
    auto found = Root.Entries.find(symbol);
    if (found == Root.Entries.end())
      return;

    const Entry &entry = found->second;

    if (entry.Value)
      fn(*entry.Value);

    if (entry.Children == nullptr)
      return;

    visitChildren(entry.Children, fn);
  }

  /// Invoke \p fn on the value of every stored key that is a prefix of the
  /// non-empty range [begin, end), and of every stored key that has
  /// [begin, end) as a prefix. Fn must be callable with a single argument of
  /// type ValueType.
  template<typename Iter, typename Fn>
  void findAll(Iter begin, Iter end, Fn fn) const {
    assert(begin != end);
    const Node *node = &Root;

    while (true) {
      auto found = node->Entries.find(*begin);
      ++begin;

      if (found == node->Entries.end())
        return;

      const Entry &entry = found->second;

      // Either a stored key that prefixes the range, or (on the last symbol)
      // the key equal to the range itself.
      if (entry.Value)
        fn(*entry.Value);

      if (entry.Children == nullptr)
        return;

      node = entry.Children;

      // The range is exhausted: everything below is a key it prefixes.
      if (begin == end) {
        visitChildren(node, fn);
        return;
      }
    }
  }

private:
  /// Depth-first, pre-order traversal of every entry of \p node and of all
  /// its descendant nodes, invoking \p fn on each stored value.
  ///
  /// \p fn is taken by reference so that a single functor instance is used
  /// for the whole traversal: no copy is made per level, and any state a
  /// stateful functor accumulates is shared across all subtrees.
  template<typename Fn>
  void visitChildren(const Node *node, Fn &fn) const {
    for (const auto &pair : node->Entries) {
      const Entry &entry = pair.second;

      if (entry.Value)
        fn(*entry.Value);

      if (entry.Children)
        visitChildren(entry.Children, fn);
    }
  }
};
} // end namespace rewriting
} // end namespace swift
#endif
|