File: vocab.h

package info (click to toggle)
pytorch-text 0.14.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 11,560 kB
  • sloc: python: 14,197; cpp: 2,404; sh: 214; makefile: 20
file content (94 lines) | stat: -rw-r--r-- 3,100 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#pragma once
#include <c10/util/string_view.h>
#include <torch/script.h>
#include <torchtext/csrc/export.h>
#include <algorithm>

namespace torchtext {

// Ordered list of token strings; a token's position doubles as its integer id.
using StringList = std::vector<std::string>;
// Insertion-order-preserving token -> index map (third-party flat hash map).
using IndexDict =
    ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>;
// Flattened Vocab state used by _serialize_vocab/_deserialize_vocab:
// (version string, integer fields, string fields, tensor fields).
// NOTE(review): exact field semantics are defined in the .cpp -- verify there.
using VocabStates = std::tuple<
    std::string,
    std::vector<int64_t>,
    std::vector<std::string>,
    std::vector<torch::Tensor>>;

// Comparator for (token, frequency) pairs: orders by frequency descending,
// breaking ties by token string ascending (lexicographic), which makes the
// ordering strict, total, and deterministic.
// Fix: operator() is now const -- STL algorithms and const comparator
// instances expect a const-callable comparator; the original non-const
// overload rejected such uses.
struct CompareTokens {
  bool operator()(
      const std::pair<std::string, int64_t>& a,
      const std::pair<std::string, int64_t>& b) const {
    // Equal frequencies: fall back to lexicographic token order.
    if (a.second == b.second) {
      return a.first < b.first;
    }
    // Otherwise higher-frequency pairs sort first.
    return a.second > b.second;
  }
};

// Returns the number of lines in the file at `file_path` (declaration only;
// implementation lives in the corresponding .cpp). NOTE(review): presumably
// used to pre-size structures when building a vocab from a file -- confirm
// against callers.
TORCHTEXT_API int64_t _infer_lines(const std::string& file_path);

// TorchScript-scriptable vocabulary: a bidirectional mapping between token
// strings and contiguous integer ids. Token -> id lookup goes through a
// hand-rolled open-addressing hash table (stoi_) with linear probing; the
// id -> token direction is a plain vector (itos_). Most methods are declared
// here and defined in the corresponding .cpp.
struct Vocab : torch::CustomClassHolder {
  // Hard cap on vocabulary size. NOTE(review): presumably sized so int32_t
  // slot values in stoi_ stay valid and the probe table keeps a workable
  // load factor -- confirm against the .cpp implementation.
  static const int32_t MAX_VOCAB_SIZE = 30000000;
  // NOTE(review): looks like the id used for unknown tokens; its exact
  // interaction with default_index_ is defined in the .cpp -- verify there.
  int64_t unk_index_{};
  // Open-addressing probe table: each slot holds an index into itos_, or -1
  // for an empty slot (see _find/_add below).
  std::vector<int32_t> stoi_;
  // Version tag carried through serialization (see VocabStates).
  const std::string version_str_ = "0.0.2";
  // Tokens in insertion order; a token's position here is its integer id.
  StringList itos_;
  // Optional fallback id for out-of-vocabulary lookups; unset means no
  // fallback (lookup behavior is defined in the .cpp).
  c10::optional<int64_t> default_index_ = {};

  // TODO: [can we remove this?] we need to keep this constructor, otherwise
  // torch binding gets compilation error: no matching constructor for
  // initialization of 'torchtext::Vocab'
  TORCHTEXT_API explicit Vocab(StringList tokens);
  TORCHTEXT_API explicit Vocab(
      StringList tokens,
      const c10::optional<int64_t>& default_index);
  // Number of tokens in the vocabulary.
  TORCHTEXT_API int64_t __len__() const;
  // Id for `token`; OOV handling (default_index_/unk_index_) is in the .cpp.
  TORCHTEXT_API int64_t __getitem__(const c10::string_view& token) const;
  TORCHTEXT_API bool __contains__(const c10::string_view& token) const;
  TORCHTEXT_API void set_default_index(c10::optional<int64_t> index);
  TORCHTEXT_API c10::optional<int64_t> get_default_index() const;
  TORCHTEXT_API void insert_token(std::string token, const int64_t& index);
  TORCHTEXT_API void append_token(std::string token);
  // Reverse lookup: token string for a given id.
  TORCHTEXT_API std::string lookup_token(const int64_t& index);
  TORCHTEXT_API std::vector<std::string> lookup_tokens(
      const std::vector<int64_t>& indices);
  std::vector<int64_t> lookup_indices(
      const std::vector<c10::string_view>& tokens);
  // Snapshot accessors; get_stoi returns a std::unordered_map copy rather
  // than exposing the internal probe table.
  TORCHTEXT_API std::unordered_map<std::string, int64_t> get_stoi() const;
  TORCHTEXT_API std::vector<std::string> get_itos() const;

 protected:
  // 32-bit FNV-1a hash of `str` (offset basis 2166136261, prime 16777619):
  // xor each byte into the accumulator, then multiply by the prime.
  uint32_t _hash(const c10::string_view& str) const {
    uint32_t h = 2166136261;
    for (size_t i = 0; i < str.size(); i++) {
      h = h ^ uint32_t(uint8_t(str[i]));
      h = h * 16777619;
    }
    return h;
  }

  // Returns the probe-table slot for `w`: either the slot whose itos_ entry
  // already equals `w`, or the first empty (-1) slot reached by linear
  // probing from the hash position.
  // NOTE(review): this loops forever if the table is full and `w` is absent,
  // and `% stoi_size` is undefined when stoi_ is empty -- callers must keep
  // the table non-empty and never full (MAX_VOCAB_SIZE presumably helps
  // enforce this; confirm in the .cpp).
  uint32_t _find(const c10::string_view& w) const {
    uint32_t stoi_size = stoi_.size();
    uint32_t id = _hash(w) % stoi_size;
    while (stoi_[id] != -1 && itos_[stoi_[id]] != w) {
      id = (id + 1) % stoi_size;
    }
    return id;
  }

  // Appends `w` with id == new itos_ size - 1 unless it is already present;
  // duplicate insertions are silently ignored.
  void _add(std::string w) {
    uint32_t h = _find(c10::string_view{w.data(), w.size()});
    if (stoi_[h] == -1) {
      itos_.emplace_back(std::move(w));
      stoi_[h] = itos_.size() - 1;
    }
  }
};

// Flattens a Vocab into the VocabStates tuple for TorchScript serialization
// (implementation in the .cpp).
TORCHTEXT_API VocabStates
_serialize_vocab(const c10::intrusive_ptr<Vocab>& self);
// Reconstructs a Vocab from a previously serialized VocabStates tuple.
TORCHTEXT_API c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
} // namespace torchtext