File: vocab.cpp

package info (click to toggle)
pytorch-text 0.14.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 11,560 kB
  • sloc: python: 14,197; cpp: 2,404; sh: 214; makefile: 20
file content (183 lines) | stat: -rw-r--r-- 5,357 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#include <ATen/Parallel.h> // @manual
#include <torch/torch.h> // @manual
#include <torchtext/csrc/common.h>
#include <torchtext/csrc/vocab.h> // @manual

#include <iostream>
#include <stdexcept>
#include <string>
namespace torchtext {

// Builds a vocabulary from `tokens` with an optional fallback index.
//
// stoi_ starts out as MAX_VOCAB_SIZE entries of -1; every token is then
// registered through _add. A token that is already present in the table makes
// the TORCH_CHECK below fail.
//
// tokens:        strings to register; consumed (moved from) by this call.
// default_index: value later returned by __getitem__ for unknown tokens, or
//                empty to make such lookups an error.
Vocab::Vocab(StringList tokens, const c10::optional<int64_t>& default_index)
    : stoi_(MAX_VOCAB_SIZE, -1), default_index_{default_index} {
  for (auto& entry : tokens) {
    // A slot that no longer holds -1 means this string was registered earlier
    // in the list.
    const auto slot = _find(c10::string_view{entry});
    TORCH_CHECK(
        stoi_[slot] == -1, "Duplicate token found in tokens list: " + entry);

    // `entry` is not read again, so its storage can be handed to _add.
    _add(std::move(entry));
  }
}

// Convenience constructor: builds the vocabulary from `tokens` with no default
// index set, by delegating to the two-argument constructor.
Vocab::Vocab(StringList tokens) : Vocab(std::move(tokens), {}) {}

// Returns the number of tokens currently held in itos_.
int64_t Vocab::__len__() const {
  // The container size is unsigned; the public interface reports it as a
  // signed 64-bit value.
  return static_cast<int64_t>(itos_.size());
}

// Returns true when `token` is present in the vocabulary, i.e. when the slot
// _find reports for it holds something other than the -1 "empty" marker.
bool Vocab::__contains__(const c10::string_view& token) const {
  const int64_t slot = _find(token);
  return stoi_[slot] != -1;
}

// Maps `token` to its index.
//
// Returns the stored index when the token is known. For an unknown token the
// default index is returned if one has been set; otherwise the TORCH_CHECK
// below fails with a message naming the token.
int64_t Vocab::__getitem__(const c10::string_view& token) const {
  const int64_t slot = _find(token);
  const auto stored_index = stoi_[slot];

  if (stored_index == -1) {
    // Out-of-vocabulary: only legal when a fallback index was configured.
    TORCH_CHECK(
        default_index_.has_value(),
        "Token " + std::string(token) +
            " not found and default index is not set");
    return default_index_.value();
  }

  return stored_index;
}

// Replaces the default index returned by __getitem__ for unknown tokens.
// Passing an empty optional clears it, which makes such lookups fail instead.
void Vocab::set_default_index(c10::optional<int64_t> index) {
  default_index_ = index;
}

// Returns the currently configured default index; the optional is empty when
// none has been set.
c10::optional<int64_t> Vocab::get_default_index() const {
  return default_index_;
}

// Registers `token` through _add.
//
// The token must not already be present: if its slot holds an index, the
// TORCH_CHECK below fails and reports that existing index.
void Vocab::append_token(std::string token) {
  const auto slot = _find(c10::string_view{token});
  const auto existing_index = stoi_[slot];

  TORCH_CHECK(
      existing_index == -1,
      "Token " + token + " already exists in the Vocab with index: " +
          std::to_string(existing_index));

  // `token` is not used after this point, so move it into the vocabulary.
  _add(std::move(token));
}

// Inserts `token` at position `index`, shifting every token that currently
// sits at `index` or later up by one.
//
// token: string to insert; must not already be in the vocabulary.
// index: target position, valid in the closed range [0, __len__()].
//        Passing __len__() places the token after the current last entry.
//
// Fails a TORCH_CHECK when `index` is out of range or `token` is already
// present.
void Vocab::insert_token(std::string token, const int64_t& index) {
  const int64_t vocab_size = __len__();

  TORCH_CHECK(
      index >= 0 && index <= vocab_size,
      "Specified index " + std::to_string(index) +
          " is out of bounds for vocab of size " + std::to_string(vocab_size));

  // The message previously claimed the token was "not found", which is the
  // opposite of the condition that makes this check fail.
  TORCH_CHECK(
      !__contains__(token), "Token " + token + " already exists in the Vocab");

  // Resolve every slot that is about to change while stoi_ and itos_ are
  // still consistent with each other.
  // NOTE(review): _find is defined outside this file. If it resolves
  // collisions by reading stoi_/itos_, calling it on a half-shifted table
  // could return the wrong slot or index one past the end of itos_ -- confirm
  // against the header. Looking everything up first is correct either way.
  const int64_t new_slot = _find(c10::string_view{token});
  std::vector<int64_t> shifted_slots;
  shifted_slots.reserve(static_cast<size_t>(vocab_size - index));
  for (int64_t i = index; i < vocab_size; ++i) {
    shifted_slots.push_back(_find(c10::string_view{itos_[i]}));
  }

  // Tokens formerly at index, index + 1, ... now live at index + 1,
  // index + 2, ...
  int64_t shifted_index = index;
  for (const int64_t slot : shifted_slots) {
    stoi_[slot] = ++shifted_index;
  }

  stoi_[new_slot] = index;
  itos_.insert(itos_.begin() + index, std::move(token));
}

// Returns a copy of the token stored at position `index`.
//
// `index` must lie in [0, __len__()); otherwise the TORCH_CHECK below fails.
std::string Vocab::lookup_token(const int64_t& index) {
  const int64_t vocab_size = __len__();
  const bool in_range = index >= 0 && index < vocab_size;

  TORCH_CHECK(
      in_range,
      "Specified index " + std::to_string(index) +
          " is out of bounds for vocab of size " + std::to_string(vocab_size));

  return itos_[index];
}

// Translates a list of indices into the corresponding tokens, preserving the
// order of `indices`.
//
// All indices are validated before any token is copied; the first one outside
// [0, __len__()) makes the TORCH_CHECK fail, reporting both the offending
// value and its position in the input.
StringList Vocab::lookup_tokens(const std::vector<int64_t>& indices) {
  const int64_t vocab_size = __len__();
  const size_t count = indices.size();

  // Pass 1: validation only.
  for (size_t pos = 0; pos < count; ++pos) {
    const int64_t candidate = indices[pos];
    TORCH_CHECK(
        candidate >= 0 && candidate < vocab_size,
        "Specified index " + std::to_string(candidate) + " at position " +
            std::to_string(pos) + " is out of bounds for vocab of size " +
            std::to_string(vocab_size));
  }

  // Pass 2: every index is known to be valid, so copy the tokens out.
  std::vector<std::string> tokens;
  tokens.reserve(count);
  for (const int64_t valid_index : indices) {
    tokens.push_back(itos_[valid_index]);
  }
  return tokens;
}

// Translates each token in `tokens` to its index via __getitem__, preserving
// input order. Unknown tokens therefore follow __getitem__'s rules: they map
// to the default index when one is set and fail otherwise.
std::vector<int64_t> Vocab::lookup_indices(
    const std::vector<c10::string_view>& tokens) {
  std::vector<int64_t> indices;
  indices.reserve(tokens.size());
  for (const auto& token : tokens) {
    indices.push_back(__getitem__(token));
  }
  return indices;
}

// Returns a freshly built token -> index map covering every token in itos_.
// Each index is obtained through __getitem__.
std::unordered_map<std::string, int64_t> Vocab::get_stoi() const {
  std::unordered_map<std::string, int64_t> mapping;
  for (size_t pos = 0; pos < itos_.size(); ++pos) {
    const auto& token = itos_[pos];
    mapping[token] = __getitem__(c10::string_view{token});
  }
  return mapping;
}

// Returns a copy of the token list (itos_), so callers cannot modify the
// vocabulary through the result.
StringList Vocab::get_itos() const {
  return itos_;
}

// Packs a Vocab into a VocabStates tuple of
// (version string, integers, strings, tensors):
//   - integers holds the default index when one is set, otherwise is empty;
//   - strings is a copy of the token list;
//   - tensors is always empty.
VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab>& self) {
  std::vector<int64_t> int_state;
  const auto& fallback = self->default_index_;
  if (fallback.has_value()) {
    int_state.push_back(fallback.value());
  }

  // Copy the token list so the vocabulary being serialized is left untouched.
  StringList token_state = self->itos_;
  std::vector<torch::Tensor> tensor_state;

  return std::make_tuple(
      self->version_str_,
      std::move(int_state),
      std::move(token_state),
      std::move(tensor_state));
}

// Rebuilds a Vocab from a VocabStates tuple of
// (version string, integers, strings, tensors).
//
// Requirements, each enforced with TORCH_CHECK:
//   - the tuple has exactly 4 elements;
//   - the tensors element is empty;
//   - the version string does not order before "0.0.2".
//
// The first entry of integers, if any, becomes the default index; strings
// becomes the token list (moved out of `states`).
c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
  auto state_size = std::tuple_size<VocabStates>::value;
  TORCH_CHECK(
      state_size == 4,
      "Expected deserialized Vocab to have 4 states but found " +
          std::to_string(state_size) + " states");

  auto& tensors = std::get<3>(states);
  TORCH_CHECK(tensors.size() == 0, "Expected `tensors` states to be empty");

  // NOTE(review): this is a lexicographic string comparison, not a semantic
  // version comparison.
  auto& version_str = std::get<0>(states);
  TORCH_CHECK(
      version_str >= "0.0.2",
      "Found unexpected version for serialized Vocab: " + version_str);

  auto& integers = std::get<1>(states);
  c10::optional<int64_t> default_index = {};
  if (!integers.empty()) {
    default_index = integers.front();
  }

  auto& strings = std::get<2>(states);
  return c10::make_intrusive<Vocab>(std::move(strings), default_index);
}

} // namespace torchtext