File: client_side_phishing_model.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (210 lines) | stat: -rw-r--r-- 8,376 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SAFE_BROWSING_CONTENT_BROWSER_CLIENT_SIDE_PHISHING_MODEL_H_
#define COMPONENTS_SAFE_BROWSING_CONTENT_BROWSER_CLIENT_SIDE_PHISHING_MODEL_H_

#include <map>
#include <memory>

#include "base/callback_list.h"
#include "base/containers/flat_map.h"
#include "base/files/file.h"
#include "base/gtest_prod_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/synchronization/lock.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/thread_annotations.h"
#include "components/optimization_guide/core/delivery/optimization_target_model_observer.h"
#include "components/safe_browsing/core/common/fbs/client_model_generated.h"
#include "components/safe_browsing/core/common/proto/client_model.pb.h"
#include "components/safe_browsing/core/common/proto/csd.pb.h"

namespace optimization_guide {
class OptimizationGuideModelProvider;
}  // namespace optimization_guide

namespace safe_browsing {

// Kind of client side phishing model currently held. The numeric values are
// spelled out explicitly so that they stay stable if enumerators are added.
enum class CSDModelType : int {
  // No usable model is held.
  kNone = 0,
  // The model is a flatbuffer.
  kFlatbuffer = 1,
};

// This holds the currently active client side phishing detection model.
//
// The data to populate it is fetched periodically from Google to get the most
// up-to-date model. We assume it is updated at most every few hours.
//
// This class lives on UI thread and can only be called there. In particular
// GetModelStr() returns a string reference, which assumes the string won't be
// used and updated at the same time.

class ClientSidePhishingModel
    : public optimization_guide::OptimizationTargetModelObserver {
 public:
  ClientSidePhishingModel(
      optimization_guide::OptimizationGuideModelProvider* opt_guide);

  ~ClientSidePhishingModel() override;

  // optimization_guide::OptimizationTargetModelObserver implementation
  void OnModelUpdated(
      optimization_guide::proto::OptimizationTarget optimization_target,
      base::optional_ref<const optimization_guide::ModelInfo> model_info)
      override;

  // Enhanced Safe Browsing users receive an additional image embedding model to
  // be attached to CSD-Phishing ping to better train the models.
  void SubscribeToImageEmbedderOptimizationGuide();

  void UnsubscribeToImageEmbedderOptimizationGuide();

  // Register a callback to be notified whenever the model changes. All
  // notifications will occur on the UI thread.
  base::CallbackListSubscription RegisterCallback(
      base::RepeatingCallback<void()> callback);

  // Returns whether we currently have a model.
  bool IsEnabled() const;

  static bool VerifyCSDFlatBufferIndicesAndFields(
      const flat::ClientSideModel* model);

  // Returns model type (flatbuffer or none).
  CSDModelType GetModelType() const;

  // Returns the shared memory region for the flatbuffer.
  base::ReadOnlySharedMemoryRegion GetModelSharedMemoryRegion() const;

  const base::File& GetVisualTfLiteModel() const;

  const base::File& GetImageEmbeddingModel() const;

  bool HasImageEmbeddingModel();

  bool IsModelMetadataImageEmbeddingVersionMatching();

  int GetTriggerModelVersion();

  void SetVisualTfLiteModelForTesting(base::File file);
  // Overrides model type.
  void SetModelTypeForTesting(CSDModelType model_type);
  // Removes mapping.
  void ClearMappedRegionForTesting();
  // Get flatbuffer memory address.
  void* GetFlatBufferMemoryAddressForTesting();
  // Notifies all the callbacks of a change in model.
  void NotifyCallbacksOfUpdateForTesting();

  const base::flat_map<std::string, TfLiteModelMetadata::Threshold>&
  GetVisualTfLiteModelThresholds() const;

  // This function is used to override internal model for testing in
  // client_side_phishing_model_unittest
  void MaybeOverrideModel();

  void OnModelAndVisualTfLiteFileLoaded(
      std::optional<optimization_guide::proto::Any> model_metadata,
      std::pair<std::string, base::File> model_and_tflite);

  void OnImageEmbeddingModelLoaded(
      std::optional<optimization_guide::proto::Any> model_metadata,
      base::File image_embedding_model_data);

  void SetModelAndVisualTfLiteForTesting(
      const base::FilePath& model_file_path,
      const base::FilePath& visual_tf_lite_model_path);

  // Updates the internal model string, when one is received from testing in
  // client_side_phishing_model_unittest
  void SetModelStringForTesting(const std::string& model_str,
                                base::File visual_tflite_model);

  bool IsSubscribedToImageEmbeddingModelUpdates();

 private:
  static const int kInitialClientModelFetchDelayMs;

  void NotifyCallbacksOnUI();

  // Callback when the file overriding the model has been read in
  // client_side_phishing_model_unittest
  void OnGetOverridenModelData(
      CSDModelType model_type,
      std::pair<std::string, base::File> model_and_tflite);

  // The list of callbacks to notify when a new model is ready. Guarded by
  // sequence_checker_. Will always be notified on the UI thread.
  base::RepeatingCallbackList<void()> callbacks_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Model protobuf string. Guarded by sequence_checker_.
  std::string model_str_ GUARDED_BY_CONTEXT(sequence_checker_);

  // Visual TFLite model file. Guarded by sequence_checker_.
  std::optional<base::File> visual_tflite_model_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Image Embedding TfLite model file. Guarded by sequence_checker_.
  std::optional<base::File> image_embedding_model_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Thresholds in visual TFLite model file to be used for comparison after
  // visual classification
  base::flat_map<std::string, TfLiteModelMetadata::Threshold> thresholds_;

  // Model type as inferred by feature flag. Guarded by sequence_checker_.
  CSDModelType model_type_ GUARDED_BY_CONTEXT(sequence_checker_) =
      CSDModelType::kNone;

  // MappedReadOnlyRegion where the flatbuffer has been copied to. Guarded by
  // sequence_checker_.
  base::MappedReadOnlyRegion mapped_region_
      GUARDED_BY_CONTEXT(sequence_checker_) = base::MappedReadOnlyRegion();

  FRIEND_TEST_ALL_PREFIXES(ClientSidePhishingModelTest, CanOverrideWithFlag);

  // Optimization Guide service that provides the client side detection
  // model files for this service. Optimization Guide Service is a
  // BrowserContextKeyedServiceFactory and should not be used after Shutdown
  raw_ptr<optimization_guide::OptimizationGuideModelProvider> opt_guide_;

  // These two integer values will be set from reading the metadata specified
  // under each optimization target. These two are used to match the model
  // pairings properly. If the two values match, then the image embedding model
  // will be sent to the renderer process along with the trigger models. They do
  // not reflect any versions used in the model file itself.
  std::optional<int> trigger_model_opt_guide_metadata_image_embedding_version_;
  std::optional<int>
      embedding_model_opt_guide_metadata_image_embedding_version_;

  // This value is set from a version set in the model file's metadata. This
  // value will be used to send to the CSD service class so that it can be added
  // to the debugging metadata so that we can understand what version has been
  // sent to the renderer.
  std::optional<int> trigger_model_version_;

  scoped_refptr<base::SequencedTaskRunner> background_task_runner_;

  // If the users subscribe to ESB, the code will add an observer to the
  // OptimizationGuide service for the image embedder model. We can choose to
  // remove the observer, but it will be on the list to be removed, and not
  // removed instantly. Therefore, if the user subscribes, unsubscribes, and
  // re-subscribes again in very quick succession, the code will crash because
  // the DCHECK fails, indicating that the observer is added already. Therefore,
  // this will be a one time use flag.
  bool subscribed_to_image_embedder_ = false;

  SEQUENCE_CHECKER(sequence_checker_);

  base::TimeTicks beginning_time_;

  base::WeakPtrFactory<ClientSidePhishingModel> weak_ptr_factory_{this};
};

}  // namespace safe_browsing

#endif  // COMPONENTS_SAFE_BROWSING_CONTENT_BROWSER_CLIENT_SIDE_PHISHING_MODEL_H_