#include <torchtext/csrc/regex_tokenizer.h> // @manual

#include <algorithm>
#include <cctype>
#include <sstream>
namespace torchtext {
// Builds a tokenizer from parallel arrays: in forward(), patterns_[i] is
// applied (via RE2::GlobalReplace) with replacements_[i], in order.
//
// @param patterns      regex patterns, eagerly compiled with RE2
// @param replacements  replacement strings, one per pattern
// @param to_lower      when true, forward() lower-cases its input first
RegexTokenizer::RegexTokenizer(
    const std::vector<std::string>& patterns,
    const std::vector<std::string>& replacements,
    const bool to_lower = false)
    // The parameters are const lvalue references, so the previous
    // std::move() calls here were silent copies (clang-tidy:
    // performance-move-const-arg); copy explicitly instead.
    : patterns_(patterns),
      replacements_(replacements),
      to_lower_(to_lower) {
  TORCH_CHECK(
      patterns.size() == replacements.size(),
      "Expected `patterns` and `replacements` to have same size!");
  // RE2 objects are neither copyable nor movable, hence the heap allocation.
  // NOTE(review): these are raw owning pointers; confirm the destructor (not
  // visible in this file) deletes them, otherwise they leak.
  compiled_patterns_.reserve(patterns_.size());
  for (const auto& pattern : patterns_) {
    compiled_patterns_.push_back(new RE2(pattern));
  }
}
std::vector<std::string> RegexTokenizer::forward(std::string str) const {
// str tolower
if (to_lower_) {
std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
return std::tolower(c);
});
}
for (size_t i = 0; i < compiled_patterns_.size(); i++) {
RE2::GlobalReplace(&str, *compiled_patterns_[i], replacements_[i]);
}
std::vector<std::string> tokens;
split_(str, tokens);
return tokens;
}
// Appends to `tokens` every non-empty run of characters in `str` between
// occurrences of `delimiter`. Empty fields (leading, trailing, or between
// consecutive delimiters) are skipped; `str` itself is not modified.
void RegexTokenizer::split_(
    std::string& str,
    std::vector<std::string>& tokens,
    const char& delimiter) const {
  std::string::size_type begin = 0;
  while (begin < str.size()) {
    auto end = str.find(delimiter, begin);
    if (end == std::string::npos) {
      end = str.size();
    }
    // Only keep non-empty fields, matching the old getline-based behavior.
    if (end > begin) {
      tokens.push_back(str.substr(begin, end - begin));
    }
    begin = end + 1;
  }
}
// Captures the tokenizer's constructor arguments (patterns, replacements,
// to_lower flag) as a RegexTokenizerStates tuple for serialization.
RegexTokenizerStates _serialize_regex_tokenizer(
    const c10::intrusive_ptr<RegexTokenizer>& self) {
  return {self->patterns_, self->replacements_, self->to_lower_};
}
// Rebuilds a RegexTokenizer from a serialized state tuple, moving the two
// string vectors out of `states` to avoid copies.
c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(
    RegexTokenizerStates&& states) {
  auto& [patterns, replacements, to_lower] = states;
  return c10::make_intrusive<RegexTokenizer>(
      std::move(patterns), std::move(replacements), to_lower);
}
} // namespace torchtext