File: regex_tokenizer.cpp

#include <torchtext/csrc/regex_tokenizer.h> // @manual

#include <algorithm> // std::transform
#include <cctype> // std::tolower
#include <sstream>

namespace torchtext {

RegexTokenizer::RegexTokenizer(
    const std::vector<std::string>& patterns,
    const std::vector<std::string>& replacements,
    const bool to_lower = false)
    // Note: std::move on a const reference degrades to a copy, so the members
    // are copy-initialized from the caller's vectors.
    : patterns_(std::move(patterns)),
      replacements_(std::move(replacements)),
      to_lower_(to_lower) {
  TORCH_CHECK(
      patterns.size() == replacements.size(),
      "Expected `patterns` and `replacements` to have the same size!");

  // Compile every pattern once up front so forward() can reuse the RE2 objects.
  for (const auto& pattern : patterns_) {
    compiled_patterns_.push_back(new RE2(pattern));
  }
}

std::vector<std::string> RegexTokenizer::forward(std::string str) const {
  // Optionally lower-case the input before applying the patterns.
  if (to_lower_) {
    std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
      return std::tolower(c);
    });
  }

  // Apply each compiled pattern in order, replacing every match in place.
  for (size_t i = 0; i < compiled_patterns_.size(); i++) {
    RE2::GlobalReplace(&str, *compiled_patterns_[i], replacements_[i]);
  }

  // Split the rewritten string on the default delimiter, dropping empty tokens.
  std::vector<std::string> tokens;
  split_(str, tokens);
  return tokens;
}

// Splits `str` on `delimiter` and appends the non-empty pieces to `tokens`.
void RegexTokenizer::split_(
    std::string& str,
    std::vector<std::string>& tokens,
    const char& delimiter) const {
  std::stringstream ss(str);
  std::string token;

  while (std::getline(ss, token, delimiter)) {
    if (!token.empty()) {
      tokens.push_back(token);
    }
  }
}

// Packs the tokenizer's state (patterns, replacements, to_lower flag) into a
// tuple so TorchScript can serialize it.
RegexTokenizerStates _serialize_regex_tokenizer(
    const c10::intrusive_ptr<RegexTokenizer>& self) {
  return std::make_tuple(self->patterns_, self->replacements_, self->to_lower_);
}

// Rebuilds a RegexTokenizer from a previously serialized state tuple.
c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(
    RegexTokenizerStates&& states) {
  return c10::make_intrusive<RegexTokenizer>(
      std::move(std::get<0>(states)),
      std::move(std::get<1>(states)),
      std::get<2>(states));
}

} // namespace torchtext
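
Usage sketch (not part of the upstream file): a minimal, hypothetical standalone program showing how the tokenizer above could be constructed and invoked from C++. It assumes the declarations in regex_tokenizer.h, that c10::make_intrusive is reachable through that header, and that split_'s default delimiter (declared in the header) is a single space; the file name example_main.cpp is illustrative only.

// example_main.cpp (hypothetical, illustrative only)
#include <torchtext/csrc/regex_tokenizer.h>

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Replace basic punctuation with spaces; forward() then lower-cases the
  // input, applies the patterns, and splits the result on spaces.
  std::vector<std::string> patterns{R"([,.!?])"};
  std::vector<std::string> replacements{" "};

  auto tokenizer = c10::make_intrusive<torchtext::RegexTokenizer>(
      patterns, replacements, /*to_lower=*/true);

  for (const auto& token : tokenizer->forward("Hello, World!")) {
    std::cout << token << "\n"; // expected output: "hello" then "world"
  }
  return 0;
}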