File: on_device_tail_tokenizer.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/omnibox/browser/on_device_tail_tokenizer.h"

#include <algorithm>
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/files/file_util.h"
#include "base/logging.h"
#include "base/strings/string_util.h"
#include "components/omnibox/browser/omnibox_field_trial.h"

namespace {
// Maximum size, in bytes, of a vocabulary file that will be loaded.
static constexpr size_t kVocabFileSizeLimit = 64 * 1024;

// The maximum number of single-char tokens, whose token IDs are directly
// mapped to byte values.
// Token IDs greater than or equal to kNumSingleChar are special control
// tokens or multi-char tokens specified by the given vocabulary file.
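// For example, TokenToId("a") returns 97 (the byte value of 'a'), while
// multi-char tokens from the vocabulary file receive IDs of kNumSingleChar
// and above, assigned in file order.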
static constexpr size_t kNumSingleChar = 256;

// Special control tokens.
static constexpr char kBeginQueryToken[] = "<Q>";
static constexpr char kEndQueryToken[] = "</Q>";
static constexpr char kEmptyPreviousQueryToken[] = "<NPQ>";
static constexpr char kUnknownToken[] = "<UNK>";

std::ostream& operator<<(std::ostream& os,
                         const base::flat_set<std::string>& tokens) {
  if (tokens.empty()) {
    return os;
  }

  auto iter = tokens.begin();
  os << *iter;
  ++iter;

  for (; iter != tokens.end(); iter++) {
    os << ", " << *iter;
  }
  return os;
}

}  // namespace

OnDeviceTailTokenizer::Tokenization::Tokenization() = default;

OnDeviceTailTokenizer::Tokenization::~Tokenization() = default;

OnDeviceTailTokenizer::OnDeviceTailTokenizer() = default;

OnDeviceTailTokenizer::~OnDeviceTailTokenizer() = default;

bool OnDeviceTailTokenizer::Init(const base::FilePath& vocabulary_filepath) {
  std::string vocabulary_content;
  Reset();
  if (!base::ReadFileToStringWithMaxSize(
          vocabulary_filepath, &vocabulary_content, kVocabFileSizeLimit)) {
    DVLOG(1) << "Failed to read the vocabulary file "
             << vocabulary_filepath.LossyDisplayName();
    return false;
  }

  base::flat_set<std::string> control_tokens = {
      kBeginQueryToken, kEndQueryToken, kEmptyPreviousQueryToken,
      kUnknownToken};

  std::string token;
  max_token_length_ = 0;

  // The first kNumSingleChar tokens map directly to single byte values.
  for (size_t i = 0; i < kNumSingleChar; i++) {
    token = static_cast<char>(i);
    InsertTokenToMaps(token);
  }

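  // The vocabulary is plain text with one token per line, in any order, and
  // must contain all four control tokens. Hypothetical contents:
  //   <Q>
  //   </Q>
  //   <NPQ>
  //   <UNK>
  //   ing
  //   tion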
  std::stringstream vocabulary(vocabulary_content);
  while (std::getline(vocabulary, token)) {
    // Stops reading at the first empty line.
    if (token.empty()) {
      break;
    }

    // Duplicate tokens are not allowed.
    if (token_to_id_.find(token) != token_to_id_.end()) {
      Reset();
      DVLOG(1) << "Duplicate token found: " << token;
      return false;
    }

    InsertTokenToMaps(token);
    if (control_tokens.find(token) != control_tokens.end()) {
      control_tokens.erase(token);
    } else {
      max_token_length_ = std::max<size_t>(max_token_length_, token.size());
    }
  }

  // A valid vocabulary should include all control tokens.
  if (!control_tokens.empty()) {
    Reset();
    DVLOG(1) << "Missing following control tokens: " << control_tokens;
    return false;
  }

  InitAmbiguousMap();

  return IsReady();
}
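
// A minimal usage sketch; the file path and prefix below are hypothetical
// (production code loads the vocabulary that ships with the on-device tail
// model):
//
//   OnDeviceTailTokenizer tokenizer;
//   if (tokenizer.Init(base::FilePath(FILE_PATH_LITERAL("vocab.txt")))) {
//     OnDeviceTailTokenizer::Tokenization tokenization;
//     tokenizer.CreatePrefixTokenization("wea", &tokenization);
//   }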

bool OnDeviceTailTokenizer::IsReady() const {
  return !token_to_id_.empty();
}

void OnDeviceTailTokenizer::Reset() {
  token_to_id_.clear();
  id_to_token_.clear();
  ambiguous_tokens_.clear();
}

std::string OnDeviceTailTokenizer::IdToToken(const TokenId token_id) const {
  if (token_id < 0 || static_cast<size_t>(token_id) >= id_to_token_.size()) {
    return kUnknownToken;
  }
  return id_to_token_[token_id];
}

OnDeviceTailTokenizer::TokenId OnDeviceTailTokenizer::TokenToId(
    const std::string& token) const {
  auto match = token_to_id_.find(token);
  if (match == token_to_id_.end()) {
    // Falls back to the ID of the unknown token.
    return token_to_id_.find(kUnknownToken)->second;
  }
  return match->second;
}

void OnDeviceTailTokenizer::InitAmbiguousMap() {
  base::flat_map<std::string, size_t> prefix_count;
  for (const std::string& token : id_to_token_) {
    // Skips any token starting with '<': the special control tokens (and
    // the single-char token "<" itself).
    if (token[0] == '<') {
      continue;
    }

    for (size_t len = 1; len <= token.size(); len++) {
      prefix_count[token.substr(0, len)] += 1;
    }
  }

  // Marks a prefix as ambiguous if it occurs in more than one token.
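  // Example: if the vocabulary contains "in" and "ing" (hypothetical
  // entries), the prefix "in" occurs under both, so the token "in" is
  // marked ambiguous: text ending in "in" may be the complete token or the
  // start of "ing".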
  for (const auto& iter : prefix_count) {
    if (iter.second > 1) {
      ambiguous_tokens_.insert(iter.first);
    }
  }
}

bool OnDeviceTailTokenizer::IsBeginQueryTokenId(TokenId token_id) const {
  return token_id == TokenToId(kBeginQueryToken);
}

bool OnDeviceTailTokenizer::IsEndQueryTokenId(TokenId token_id) const {
  return token_id == TokenToId(kEndQueryToken);
}

OnDeviceTailTokenizer::TokenId OnDeviceTailTokenizer::GetEndQueryTokenId()
    const {
  return TokenToId(kEndQueryToken);
}

bool OnDeviceTailTokenizer::IsAmbiguousToken(const std::string& token) const {
  return ambiguous_tokens_.find(token) != ambiguous_tokens_.end();
}

bool OnDeviceTailTokenizer::IsTokenPrintable(TokenId token_id) const {
  if (static_cast<size_t>(token_id) >= vocab_size()) {
    return false;
  }
  // If the token is not a single character, checks whether it is a special
  // control token. Note that other multi-char tokens, which are extracted
  // from queries, are always printable.
  if (static_cast<size_t>(token_id) >= kNumSingleChar) {
    return token_id != TokenToId(kBeginQueryToken) &&
           token_id != TokenToId(kEndQueryToken) &&
           token_id != TokenToId(kEmptyPreviousQueryToken) &&
           token_id != TokenToId(kUnknownToken);
  }
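  // Single-char token IDs map directly to byte values, so for example ID 10
  // ('\n') is not printable while ID 97 ('a') is.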
  return base::IsAsciiPrintable(static_cast<char>(token_id));
}

void OnDeviceTailTokenizer::EncodeRawString(
    const std::string& raw_string,
    std::vector<std::pair<std::string, TokenId>>* token_and_ids) const {
  size_t i = 0;
  while (i < raw_string.size()) {
    // Tries longest possible matches first and reduces the length gradually
    // until a match is found.
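    // For example, if the vocabulary contains the multi-char token "ing" (a
    // hypothetical entry, with no longer match available), "king" encodes
    // as ["k", "ing"] rather than as four single-char tokens.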
    size_t len = std::min<size_t>(max_token_length_, raw_string.size() - i);
    while (len >= 1) {
      auto iter = token_to_id_.find(raw_string.substr(i, len));
      if (iter != token_to_id_.end()) {
        token_and_ids->push_back({iter->first, iter->second});
        i += len;
        break;
      }
      len--;
    }

    // No token matched at this position; gives up and returns an empty
    // result.
    if (len == 0) {
      DVLOG(1) << "Invalid token found for raw string: " << raw_string;
      token_and_ids->clear();
      return;
    }
  }
}

void OnDeviceTailTokenizer::TokenizePrevQuery(
    const std::string& prev_query,
    TokenIds* prev_query_token_ids) const {
  prev_query_token_ids->clear();

  if (prev_query.empty()) {
    // Uses the special control token <NPQ> to mark empty previous query.
    prev_query_token_ids->push_back(TokenToId(kEmptyPreviousQueryToken));
    return;
  }

  std::vector<std::pair<std::string, TokenId>> token_and_ids;
  EncodeRawString(prev_query, &token_and_ids);

  if (OmniboxFieldTrial::ShouldEncodeLeadingSpaceForOnDeviceTailSuggest()) {
    prev_query_token_ids->push_back(TokenToId(" "));
  }

  for (const auto& pair : token_and_ids) {
    prev_query_token_ids->push_back(pair.second);
  }
}

void OnDeviceTailTokenizer::CreatePrefixTokenization(
    const std::string& prefix,
    Tokenization* tokenization) const {
  std::vector<std::pair<std::string, TokenId>> token_and_ids;

  EncodeRawString(prefix, &token_and_ids);
  if (token_and_ids.empty()) {
    return;
  }

  // Checks if the last token is ambiguous.
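  // If it is, the token is held out of the unambiguous IDs and recorded as
  // constraint_prefix (see the Tokenization struct in the header), which
  // presumably lets callers restrict the model's next expansion to tokens
  // extending that prefix.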
  size_t num_unambiguous = token_and_ids.size();
  if (IsAmbiguousToken(token_and_ids[token_and_ids.size() - 1].first)) {
    num_unambiguous--;
    tokenization->constraint_prefix =
        token_and_ids[token_and_ids.size() - 1].first;
  }

  // Always add begin query token at the front of the prefix.
  tokenization->unambiguous_ids.push_back(TokenToId(kBeginQueryToken));

  if (OmniboxFieldTrial::ShouldEncodeLeadingSpaceForOnDeviceTailSuggest()) {
    tokenization->unambiguous_ids.push_back(TokenToId(" "));
  }

  for (size_t i = 0; i < num_unambiguous; ++i) {
    tokenization->unambiguous_prefix += token_and_ids[i].first;
    tokenization->unambiguous_ids.push_back(token_and_ids[i].second);
  }
}

void OnDeviceTailTokenizer::InsertTokenToMaps(const std::string& token) {
  DCHECK(!token.empty());
  // Token IDs are dense and sequential: a new token's ID is the current
  // size of `id_to_token_`.
  token_to_id_.insert({token, id_to_token_.size()});
  id_to_token_.push_back(token);
  DCHECK_EQ(token_to_id_.size(), id_to_token_.size());
}