File: clip_tokenizer.cpp (pytorch-text 0.14.1-2)

#include <torchtext/csrc/clip_tokenizer.h>
#include <torchtext/csrc/regex.h> // @manual

#include <unordered_set>

namespace torchtext {
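// Pre-tokenization pattern used by CLIP (case-insensitive). The
// alternatives match, in order: the two special tokens, common English
// contractions ('s, 't, 're, 've, 'm, 'll, 'd), runs of letters, single
// digits, and runs of other non-space, non-letter, non-digit characters.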
const Regex kCLIPRegex(
    "(?i)(<\\|startoftext\\|>|<\\|endoftext\\|>|\\'s|\\'t|\\'re|\\'ve|"
    "\\'m|\\'ll|\\'d|[\\pL]+|[\\pN]|[^\\s\\pL\\pN]+)");
const std::string kWhitespaceString("</w>");
const std::unordered_set<std::string> kSpecialTokens{
    "<|startoftext|>",
    "<|endoftext|>"};

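// Splits raw text into coarse tokens using kCLIPRegex. An illustrative
// sketch (exact splits depend on RE2's Unicode character classes):
//   clip_pre_tokenizer("A photo of 2 dogs!")
//     -> {"A", "photo", "of", "2", "dogs", "!"}
// Since [\pN] matches a single digit, "42" would split into "4", "2".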
std::vector<std::string> clip_pre_tokenizer(std::string input) {
  std::string token;
  std::vector<std::string> tokens;
  re2::StringPiece inp(input);
  while (kCLIPRegex.FindAndConsume(&inp, &token)) {
    tokens.push_back(token);
  }
  return tokens;
}

std::vector<std::string> CLIPEncoder::BPE_(
    const std::vector<std::string>& token_list) {
  // Given a list of input tokens, repeatedly apply the best-ranked BPE
  // merge, producing a new token list each time, until either
  //  1) the token list is reduced to a single token, or
  //  2) no further BPE merge can be found.
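  //
  // Illustrative sketch (the actual merges depend on bpe_merge_ranks_):
  // for token_list = {"h", "e", "l", "l", "o"}, the last token first
  // becomes "o</w>", and successive merges could yield
  // {"h", "e", "l", "lo</w>"} -> {"h", "e", "llo</w>"} -> ... -> {"hello</w>"}.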
  auto concatenated = concatenate_strings(token_list);
  if (caching_enabled_ && cache_.contains(concatenated)) {
    return cache_.at(concatenated);
  } else if (kSpecialTokens.find(concatenated) != kSpecialTokens.end()) {
    return {concatenated};
  }

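  // Append the end-of-word marker to the last input token so that merge
  // ranks can distinguish word-final symbols from word-internal ones.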
  std::vector<std::string> tok_list(token_list.begin(), token_list.end() - 1);
  tok_list.push_back(token_list[token_list.size() - 1] + kWhitespaceString);
  auto pairs = get_pairs(tok_list, seperator_);
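  // An empty pair set means the word consists of a single symbol; return
  // it whole, with the end-of-word marker attached.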
  if (pairs.empty()) {
    return {concatenated + kWhitespaceString};
  }
  while (true) {
    auto bigram = FindBestPair_(pairs);
    if (!bpe_merge_ranks_.contains(bigram))
      break;

    // Find every index i where tok_list[i] == first and
    // tok_list[i + 1] == second. After the loop, the new token list
    // contains
    //  1) each matched pair merged into first + second, and
    //  2) all other tokens from the original list, unchanged.
    //
    // For example, with first = "a", second = "w", and tok_list =
    // ["a", "w", "some", "a", "w", "e"], the result is
    // new_token_list = ["aw", "some", "aw", "e"].

    auto parts = split_tokens(bigram, seperator_);
    std::vector<std::string> new_token_list;
    std::size_t i = 0;
    while (i < tok_list.size()) {
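      // list_str_index presumably returns the first index >= i at which
      // parts.first occurs in tok_list, or -1 if there is no further
      // occurrence (an assumption inferred from its use here).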
      auto j = list_str_index(tok_list, parts.first, i);
      if (j != -1) {
        for (int k = i; k < j; k++)
          new_token_list.push_back(tok_list[k]);
        i = j;
      } else {
        for (std::size_t k = i; k < tok_list.size(); k++)
          new_token_list.push_back(tok_list[k]);
        break;
      }

      if (tok_list[i] == parts.first && i < (tok_list.size() - 1) &&
          tok_list[i + 1] == parts.second) {
        new_token_list.push_back(parts.first + parts.second);
        i += 2;
      } else {
        new_token_list.push_back(tok_list[i]);
        i += 1;
      }
    }

    tok_list = new_token_list;
    if (tok_list.size() == 1) {
      break;
    } else {
      pairs = get_pairs(tok_list, seperator_);
    }
  }

  if (caching_enabled_)
    cache_.insert(concatenated, tok_list);
  return tok_list;
}

std::vector<std::string> CLIPEncoder::PreTokenize_(std::string input) {
  return clip_pre_tokenizer(input);
}

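// Encode and Tokenize delegate to the GPT-2 byte-level BPE pipeline; the
// CLIP-specific behavior comes from the PreTokenize_ and BPE_ overrides
// above. A hypothetical usage sketch (results depend on the loaded vocab):
//   auto ids = encoder->Encode("a photo of a cat");    // token ids
//   auto toks = encoder->Tokenize("a photo of a cat"); // string tokens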
std::vector<int64_t> CLIPEncoder::Encode(const std::string& text) {
  return GPT2BPEEncoder::Encode(text);
}

std::vector<std::string> CLIPEncoder::Tokenize(const std::string& text) {
  return GPT2BPEEncoder::Tokenize(text);
}

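// The serialized encoder state is a 5-tuple:
// (bpe_encoder, bpe_merge_ranks, separator, byte_encoder, caching_enabled).
// The pybind variant builds it from accessors, the torchbind variant from
// the raw members.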
CLIPEncoderStatesPybind _serialize_clip_encoder_pybind(
    const c10::intrusive_ptr<CLIPEncoder>& self) {
  return std::make_tuple(
      self->GetBPEEncoder(),
      self->GetBPEMergeRanks(),
      self->seperator_,
      self->GetByteEncoder(),
      self->caching_enabled_);
}

CLIPEncoderStatesTorchbind _serialize_clip_encoder_torchbind(
    const c10::intrusive_ptr<CLIPEncoder>& self) {
  return std::make_tuple(
      self->bpe_encoder_,
      self->bpe_merge_ranks_,
      self->seperator_,
      self->byte_encoder_,
      self->caching_enabled_);
}

c10::intrusive_ptr<CLIPEncoder> _deserialize_clip_encoder_pybind(
    CLIPEncoderStatesPybind states) {
  auto state_size = std::tuple_size<decltype(states)>::value;
  TORCH_CHECK(
      state_size == 5,
      "Expected deserialized CLIPEncoder to have 5 states but found " +
          std::to_string(state_size) + " states");
  return c10::make_intrusive<CLIPEncoder>(
      std::move(std::get<0>(states)),
      std::move(std::get<1>(states)),
      std::get<2>(states),
      std::move(std::get<3>(states)),
      std::get<4>(states));
}

c10::intrusive_ptr<CLIPEncoder> _deserialize_clip_encoder_torchbind(
    CLIPEncoderStatesTorchbind states) {
  auto state_size = std::tuple_size<decltype(states)>::value;
  TORCH_CHECK(
      state_size == 5,
      "Expected deserialized CLIPEncoder to have 5 states but found " +
          std::to_string(state_size) + " states");
  return c10::make_intrusive<CLIPEncoder>(
      std::move(std::get<0>(states)),
      std::move(std::get<1>(states)),
      std::get<2>(states),
      std::move(std::get<3>(states)),
      std::get<4>(states));
}

} // namespace torchtext