File: language_detection_model.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (183 lines) | stat: -rw-r--r-- 7,321 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_LANGUAGE_DETECTION_CORE_LANGUAGE_DETECTION_MODEL_H_
#define COMPONENTS_LANGUAGE_DETECTION_CORE_LANGUAGE_DETECTION_MODEL_H_

#include <string>
#include <utility>
#include <vector>

#include "base/component_export.h"
#include "base/files/file.h"
#include "base/functional/callback_forward.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/category.h"

// Forward declaration so this header does not need to pull in the heavy
// TFLite Support task-library headers; only the .cc needs the full type.
namespace tflite::task::text::nlclassifier {
class NLClassifier;
}  // namespace tflite::task::text::nlclassifier

namespace language_detection {

// Even though the model only looks at the first 128 characters of the string,
// calls to ClassifyText have a run-time proportional to the size of the
// input. So we expect better performance if we truncate the string.
// We use 256 to keep in line with the existing code.
// NOTE(review): inputs are std::u16string_view, so this length is presumably
// in UTF-16 code units — confirm against the truncation site in the .cc.
// TODO(https://crbug.com/354070625): Figure out if we can drop this to 128.
inline constexpr size_t kModelTruncationLength = 256;

// A single scored language prediction produced by the detection model.
struct Prediction {
  // Takes |language| by value so callers that pass an rvalue can move it in
  // without paying for a copy (sink-parameter idiom); lvalue callers behave
  // exactly as before.
  Prediction(std::string language, float score)
      : language(std::move(language)), score(score) {}
  Prediction() = delete;

  // The predicted language.
  std::string language;
  // The model's score/confidence for |language|; higher means more confident.
  float score;

  // Orders predictions by |score| only, so algorithms such as
  // std::max_element pick the most confident prediction (see TopPrediction).
  bool operator<(const Prediction& other) const { return score < other.score; }
};

// Returns the prediction with the highest score, per Prediction::operator<.
// NOTE(review): behavior for an empty |predictions| vector is not visible in
// this header — verify the definition in the .cc before passing empty input.
COMPONENT_EXPORT(LANGUAGE_DETECTION)
Prediction TopPrediction(const std::vector<Prediction>& predictions);

// The state of the language detection model file needed for determining
// the language of the page.
//
// Keep in sync with LanguageDetectionModelState in enums.xml. Since these
// values are mirrored in enums.xml (presumably recorded in metrics), do not
// reorder or remove existing entries; add new values above kMaxValue only.
enum class LanguageDetectionModelState {
  // The language model state is not known.
  kUnknown,
  // The provided model file was not valid.
  kModelFileInvalid,
  // The language model's `base::File` is valid.
  kModelFileValid,
  // The language model is available for use with TFLite.
  kModelAvailable,

  // New values above this line.
  kMaxValue = kModelAvailable,
};

// A language detection model that will use a TFLite model to determine the
// language of a string.
// Each instance of this should only be used from a single thread.
class COMPONENT_EXPORT(LANGUAGE_DETECTION) LanguageDetectionModel {
 public:
  // Invoked with a reference to this model once it has finished loading
  // (see AddOnModelLoadedCallback).
  using ModelLoadedCallback = base::OnceCallback<void(LanguageDetectionModel&)>;

  LanguageDetectionModel();
  ~LanguageDetectionModel();

  // Runs the TFLite language detection model on |contents| and returns a
  // vector of scored language predictions. The model only considers a prefix
  // of the input (see kModelTruncationLength above), but the runtime is
  // proportional to the total length of the input.
  std::vector<Prediction> Predict(std::u16string_view contents) const;

  // Runs the TFLite language detection model on the whole string, scanning
  // over the content with a |kScanWindowSize|-character window.
  // Returns a vector of scored language predictions. The predictions are the
  // mean value of the predictions on each window.
  std::vector<Prediction> PredictWithScan(std::u16string_view contents) const;

  // Runs the TFLite language detection model on no more than three samples of
  // the string. If the contents is less than 768 characters, the function will
  // decide the language by running the model over the first 128 characters.
  // Otherwise, the first, last and the middle 256-character text piece will be
  // sampled and the return value will be the prediction with the highest
  // confidence for the three samples.
  Prediction PredictTopLanguageWithSamples(std::u16string_view contents) const;

  // Updates the language detection model for use by memory-mapping
  // |model_file| used to detect the language of the page.
  //
  // This method is blocking and should only be called in a context
  // where it is fine to block the current thread. If you cannot
  // block, use UpdateWithFileAsync(...) instead.
  void UpdateWithFile(base::File model_file);

  // Updates the language detection model for use by memory-mapping
  // |model_file| used to detect the language of the page. Performs
  // the operation on a background sequence and calls |callback| on
  // completion.
  void UpdateWithFileAsync(base::File model_file, base::OnceClosure callback);

  // Returns whether |this| is initialized and is available to handle requests
  // to determine the language of the page.
  bool IsAvailable() const;

  // Returns the size of the loaded model in bytes. If the model is not yet
  // available, the method will return 0.
  int64_t GetModelSize() const;

  // Registers |callback| to be notified about the model. NOTE(review): the
  // header does not show whether it runs immediately when the model is
  // already loaded — verify in the .cc definition.
  void AddOnModelLoadedCallback(ModelLoadedCallback callback);

  // Returns a version string for the model; exact format is defined in the
  // .cc file.
  std::string GetModelVersion() const;

  // Detach the instance from the bound sequence. Must only be used if the
  // object is created on a sequence and then moved to another sequence to
  // live on.
  void DetachFromSequence() { DETACH_FROM_SEQUENCE(sequence_checker_); }

  // The number of characters to sample and provide as a buffer to the model
  // in PredictTopLanguageWithSamples.
  static constexpr size_t kTextSampleLength = 256;

  // The number of samples of |kTextSampleLength| to evaluate the model
  // in PredictTopLanguageWithSamples.
  static constexpr int kNumTextSamples = 3;

  // The maximum window size the model runs over when predicting the language.
  static constexpr size_t kScanWindowSize = 128;

 private:
  // An owned NLClassifier.
  using OwnedNLClassifier =
      std::unique_ptr<tflite::task::text::nlclassifier::NLClassifier>;
  // The loaded classifier paired with the size in bytes of its model file.
  using ModelAndSize = std::pair<OwnedNLClassifier, int64_t>;

  // Loads model from |model_file| using |num_threads|. This can be called on
  // any thread.
  static std::optional<LanguageDetectionModel::ModelAndSize> LoadModelFromFile(
      base::File model_file,
      int num_threads);

  // Presumably runs the pending |model_loaded_callbacks_| once the model is
  // available — verify exact semantics in the .cc definition.
  void NotifyModelLoaded();

  // Execute the model on the provided |sampled_str| and return the top
  // language and the model's score/confidence in that prediction.
  Prediction DetectTopLanguage(std::u16string_view sampled_str) const;

  // Installs |model_and_size| as the current model if it holds a value.
  // (Original comment was garbled: "Updates the model if the not unset.")
  void SetModel(std::optional<ModelAndSize> model_and_size);

  SEQUENCE_CHECKER(sequence_checker_);

  // The tflite classifier that can determine the language of text.
  OwnedNLClassifier lang_detection_model_;

  // The number of threads to use for model inference. -1 tells TFLite to use
  // its internal default logic.
  const int num_threads_ = -1;

  // Upper bound on queued model-loaded callbacks. NOTE(review): enforcement
  // is not visible in this header — verify in the .cc file.
  static constexpr int kMaxPendingCallbacksCount = 100;
  // Pending callbacks waiting for the model to become available.
  std::vector<ModelLoadedCallback> model_loaded_callbacks_;

  // Records whether a file has been updated to the model.
  bool loaded_ = false;

  // Records the size of the model file loaded. The value is only valid when
  // |loaded_| is true.
  int64_t model_file_size_ = 0;

  // Used to load the data on a background sequence (see UpdateWithFileAsync).
  base::WeakPtrFactory<LanguageDetectionModel> weak_factory_{this};
};

}  // namespace language_detection
#endif  // COMPONENTS_LANGUAGE_DETECTION_CORE_LANGUAGE_DETECTION_MODEL_H_