#pragma once
#include <c10/util/string_view.h>
#include <torch/script.h>
#include <torchtext/csrc/export.h>
#include <algorithm>
namespace torchtext {
// List of token strings.
using StringList = std::vector<std::string>;

// Insertion-ordered hash map from a string key to a 64-bit integer.
using IndexDict =
    ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>;

// Tuple of (string, integers, strings, tensors); it is the type returned by
// _serialize_vocab and consumed by _deserialize_vocab.
using VocabStates = std::tuple<
    std::string,
    std::vector<int64_t>,
    std::vector<std::string>,
    std::vector<torch::Tensor>>;
// Comparator for (token, count) pairs, usable with std::sort and ordered
// containers.
//
// Orders pairs by descending `second`; pairs with an equal `second` are
// ordered by ascending `first`. The call operator is const so the comparator
// can be invoked through const objects and const references.
struct CompareTokens {
  // Returns true when `a` must be placed before `b`.
  bool operator()(
      const std::pair<std::string, int64_t>& a,
      const std::pair<std::string, int64_t>& b) const {
    if (a.second == b.second) {
      // Tie on the integer: fall back to lexicographic order of the strings.
      return a.first < b.first;
    }
    return a.second > b.second;
  }
};
// Declaration only; the definition lives outside this header.
// Presumably returns the number of lines in the file at `file_path` —
// TODO confirm against the implementation.
TORCHTEXT_API int64_t _infer_lines(const std::string& file_path);
// Vocabulary that maps tokens to integer indices and back, derived from
// torch::CustomClassHolder.
//
// `itos_` stores the tokens by index. `stoi_` is an open-addressing hash
// table (linear probing) whose slots hold an index into `itos_`, or -1 when
// the slot is empty.
struct Vocab : torch::CustomClassHolder {
  // Size limit constant; presumably the number of slots given to stoi_ by the
  // constructors — TODO confirm against the implementation.
  static const int32_t MAX_VOCAB_SIZE = 30000000;
  // Value-initialised to 0.
  int64_t unk_index_{};
  // Hash table slots: index into itos_, or -1 for an empty slot.
  std::vector<int32_t> stoi_;
  // Version tag; presumably stored with the serialized state — TODO confirm.
  const std::string version_str_ = "0.0.2";
  // Tokens, addressed by their index.
  StringList itos_;
  // Optional default index; empty unless one is set.
  c10::optional<int64_t> default_index_ = {};

  // TODO: [can we remove this?] we need to keep this constructor, otherwise
  // torch binding gets compilation error: no matching constructor for
  // initialization of 'torchtext::Vocab'
  TORCHTEXT_API explicit Vocab(StringList tokens);
  TORCHTEXT_API explicit Vocab(
      StringList tokens,
      const c10::optional<int64_t>& default_index);

  // The member functions below are declared here and defined outside this
  // header.
  TORCHTEXT_API int64_t __len__() const;
  TORCHTEXT_API int64_t __getitem__(const c10::string_view& token) const;
  TORCHTEXT_API bool __contains__(const c10::string_view& token) const;
  TORCHTEXT_API void set_default_index(c10::optional<int64_t> index);
  TORCHTEXT_API c10::optional<int64_t> get_default_index() const;
  TORCHTEXT_API void insert_token(std::string token, const int64_t& index);
  TORCHTEXT_API void append_token(std::string token);
  TORCHTEXT_API std::string lookup_token(const int64_t& index);
  TORCHTEXT_API std::vector<std::string> lookup_tokens(
      const std::vector<int64_t>& indices);
  std::vector<int64_t> lookup_indices(
      const std::vector<c10::string_view>& tokens);
  TORCHTEXT_API std::unordered_map<std::string, int64_t> get_stoi() const;
  TORCHTEXT_API std::vector<std::string> get_itos() const;

 protected:
  // 32-bit FNV-1a hash of the bytes of `str`.
  uint32_t _hash(const c10::string_view& str) const {
    constexpr uint32_t kFnvOffsetBasis = 2166136261u;
    constexpr uint32_t kFnvPrime = 16777619u;
    uint32_t digest = kFnvOffsetBasis;
    for (const char ch : str) {
      // Go through uint8_t first so a negative char is not sign-extended.
      digest ^= static_cast<uint32_t>(static_cast<uint8_t>(ch));
      digest *= kFnvPrime;
    }
    return digest;
  }

  // Returns the slot of stoi_ that either already refers to token `w` or is
  // the first empty slot reached by probing linearly from the hash position.
  uint32_t _find(const c10::string_view& w) const {
    const uint32_t slot_count = stoi_.size();
    uint32_t slot = _hash(w) % slot_count;
    for (;;) {
      const int32_t entry = stoi_[slot];
      if (entry == -1 || itos_[entry] == w) {
        return slot;
      }
      // Occupied by a different token: step to the next slot, wrapping round.
      slot = (slot + 1) % slot_count;
    }
  }

  // Appends `w` to itos_ and records its index in stoi_, unless the token is
  // already present, in which case nothing changes.
  void _add(std::string w) {
    const uint32_t slot = _find(c10::string_view{w.data(), w.size()});
    if (stoi_[slot] != -1) {
      return;
    }
    itos_.emplace_back(std::move(w));
    stoi_[slot] = itos_.size() - 1;
  }
};
// Declaration only; the definition lives outside this header.
// Presumably packs the state of `self` into a VocabStates tuple —
// TODO confirm against the implementation.
TORCHTEXT_API VocabStates
_serialize_vocab(const c10::intrusive_ptr<Vocab>& self);
// Declaration only; the definition lives outside this header.
// Presumably rebuilds a Vocab from a VocabStates tuple — TODO confirm
// against the implementation.
TORCHTEXT_API c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
} // namespace torchtext