File: sentencepiece.cpp

package info: pytorch-text 0.14.1-2
  • area: main
  • in suites: bookworm
  • size: 11,560 kB
  • sloc: python: 14,197; cpp: 2,404; sh: 214; makefile: 20
file content:
#include <torchtext/csrc/sentencepiece.h> // @manual

#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

namespace torchtext {

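// Deserializes a SentencePiece model from the serialized protobuf held in
// `content`; throws std::runtime_error if the model cannot be loaded.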
SentencePiece::SentencePiece(const std::string& content) : content_(content) {
  const auto status = processor_.LoadFromSerializedProto(content_);
  if (!status.ok()) {
    throw std::runtime_error(
        "Failed to load SentencePiece model. Error: " + status.ToString());
  }
}

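// Tokenizes `input` into subword piece strings.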
std::vector<std::string> SentencePiece::Encode(const std::string& input) const {
  std::vector<std::string> pieces;
  processor_.Encode(input, &pieces);
  return pieces;
}

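// Tokenizes `input` into vocabulary ids, widening the processor's int ids
// to int64_t.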
std::vector<int64_t> SentencePiece::EncodeAsIds(
    const std::string& input) const {
  const auto val = processor_.EncodeAsIds(input);
  return std::vector<int64_t>(val.begin(), val.end());
}

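// Narrows the int64_t ids to the int ids expected by the processor and
// decodes them back into text.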
std::string SentencePiece::DecodeIds(const std::vector<int64_t>& ids) const {
  const std::vector<int> val(ids.begin(), ids.end());
  return processor_.DecodeIds(val);
}

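// Tokenizes `input` into subword pieces.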
std::vector<std::string> SentencePiece::EncodeAsPieces(
    const std::string& input) const {
  return processor_.EncodeAsPieces(input);
}

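// Decodes a sequence of subword pieces back into text.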
std::string SentencePiece::DecodePieces(
    const std::vector<std::string>& pieces) const {
  return processor_.DecodePieces(pieces);
}

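// Returns the number of pieces in the model's vocabulary.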
int64_t SentencePiece::GetPieceSize() const {
  return processor_.GetPieceSize();
}

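// Returns the id of the unknown token.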
int64_t SentencePiece::unk_id() const {
  return processor_.unk_id();
}

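// Looks up the vocabulary id for a piece.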
int64_t SentencePiece::PieceToId(const std::string& piece) const {
  return processor_.PieceToId(piece);
}

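// Looks up the piece string for a vocabulary id.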
std::string SentencePiece::IdToPiece(const int64_t id) const {
  return processor_.IdToPiece(id);
}

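// Trains a SentencePiece model on the raw text file `filename`, writing the
// trained model under `model_prefix`; throws std::runtime_error on failure.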
void generate_sp_model(
    const std::string& filename,
    const int64_t& vocab_size,
    const std::string& model_type,
    const std::string& model_prefix) {
  const auto status = ::sentencepiece::SentencePieceTrainer::Train(
      "--input=" + filename + " --model_prefix=" + model_prefix +
      " --vocab_size=" + std::to_string(vocab_size) +
      " --model_type=" + model_type);
  if (!status.ok()) {
    throw std::runtime_error(
        "Failed to train SentencePiece model. Error: " + status.ToString());
  }
}

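// Reads the serialized model file at `path` into memory and constructs a
// SentencePiece processor from it.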
c10::intrusive_ptr<SentencePiece> load_sp_model(const std::string& path) {
  std::ifstream file(path, std::ios::binary | std::ios::in);
  if (!file) {
    throw std::runtime_error("Failed to open file: " + path);
  }
  std::string content(
      (std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  return c10::make_intrusive<SentencePiece>(std::move(content));
}

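// Constructs a SentencePiece processor directly from serialized model
// content already held in memory.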
c10::intrusive_ptr<SentencePiece> load_sp_model_string(std::string content) {
  return c10::make_intrusive<SentencePiece>(std::move(content));
}

} // namespace torchtext