File: passage_embeddings_service_controller.h

package info (click to toggle)
chromium 138.0.7204.92-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,071,576 kB
  • sloc: cpp: 34,933,512; ansic: 7,176,967; javascript: 4,110,704; python: 1,419,953; asm: 946,768; xml: 739,956; pascal: 187,324; sh: 89,623; perl: 88,663; objc: 79,944; sql: 50,304; cs: 41,786; fortran: 24,137; makefile: 21,806; php: 13,980; tcl: 13,166; yacc: 8,925; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (135 lines) | stat: -rw-r--r-- 5,506 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_

#include <memory>
#include <vector>

#include "base/callback_list.h"
#include "base/observer_list.h"
#include "base/timer/elapsed_timer.h"
#include "base/types/optional_ref.h"
#include "components/optimization_guide/core/model_info.h"
#include "components/optimization_guide/proto/passage_embeddings_model_metadata.pb.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"

namespace passage_embeddings {

// Abstract controller for the passage embeddings service. It implements
// `EmbedderMetadataProvider`, keeps track of the current embeddings model
// (paths, version and metadata), and owns the mojo remotes used to talk to the
// service and to the embedder running inside it.
//
// Subclasses must implement `MaybeLaunchService()` and `ResetServiceRemote()`,
// which control the lifetime of `service_remote_`.
class PassageEmbeddingsServiceController : public EmbedderMetadataProvider {
 public:
  PassageEmbeddingsServiceController();
  // Not copyable or assignable: the controller owns mojo remotes and a
  // WeakPtrFactory bound to `this`.
  PassageEmbeddingsServiceController(
      const PassageEmbeddingsServiceController&) = delete;
  PassageEmbeddingsServiceController& operator=(
      const PassageEmbeddingsServiceController&) = delete;
  ~PassageEmbeddingsServiceController() override;

  // Updates the paths and the metadata needed for executing the passage
  // embeddings model. The original paths and metadata will be erased regardless
  // of the validity of the new model paths.
  // Returns true and notifies the observers if the given paths are valid.
  // Virtual for testing.
  virtual bool MaybeUpdateModelInfo(
      base::optional_ref<const optimization_guide::ModelInfo> model_info);

  // Returns true if the embedder is currently running.
  bool EmbedderRunning();

  // Returns the embedder used to generate embeddings. The returned pointer is
  // not owned by the caller.
  Embedder* GetEmbedder();

 protected:
  // EmbedderMetadataProvider:
  void AddObserver(EmbedderMetadataObserver* observer) override;
  void RemoveObserver(EmbedderMetadataObserver* observer) override;

  // Callback invoked with the outcome of a `GetEmbeddings()` call. If
  // successful, it is guaranteed that `results` will have the same number of
  // passages and embeddings and in the same order as the requested passages.
  // Otherwise `results` will have empty passages and embeddings.
  using GetEmbeddingsResultCallback = base::OnceCallback<void(
      std::vector<mojom::PassageEmbeddingsResultPtr> results,
      ComputeEmbeddingsStatus status)>;

  // Repeating callback type with the same signature as `GetEmbeddings()`.
  using GetEmbeddingsCallback =
      base::RepeatingCallback<void(std::vector<std::string> passages,
                                   PassagePriority priority,
                                   GetEmbeddingsResultCallback callback)>;

  // Computes embeddings for each entry in `passages` at the given `priority`.
  // Will invoke `callback` when done; see `GetEmbeddingsResultCallback` for
  // the guarantees on the reported results.
  void GetEmbeddings(std::vector<std::string> passages,
                     PassagePriority priority,
                     GetEmbeddingsResultCallback callback);

  // Returns true if this service controller is ready for embeddings generation.
  bool EmbedderReady();

  // Returns the metadata about the embeddings model. This is only valid when
  // EmbedderReady() returns true.
  EmbedderMetadata GetEmbedderMetadata();

  // Launches the passage embeddings service and binds `cpu_logger_` to the
  // service process. Does nothing if the service is already launched.
  virtual void MaybeLaunchService() = 0;

  // Resets `service_remote_` and `cpu_logger_`. Called when the service remote
  // is idle or disconnects.
  virtual void ResetServiceRemote() = 0;

  // Resets `embedder_remote_`. Called when the model info is updated, when
  // models fail to load, or when the embedder remote is idle or disconnects.
  void ResetEmbedderRemote();

  // Remote to the passage embeddings service. Protected so that subclasses can
  // bind and reset it from `MaybeLaunchService()` / `ResetServiceRemote()`.
  mojo::Remote<mojom::PassageEmbeddingsService> service_remote_;

 private:
  // Identifier attached to each embeddings request. uint64_t is large enough
  // to never overflow.
  using RequestId = uint64_t;
  RequestId next_request_id_ = 0;

  // Called when the model files on disk are opened and ready to be sent to
  // the service.
  void LoadModelsToService(
      mojo::PendingReceiver<mojom::PassageEmbedder> receiver,
      base::ElapsedTimer service_launch_timer,
      mojom::PassageEmbeddingsLoadModelsParamsPtr params);

  // Called when an attempt to load models to service finishes.
  void OnLoadModelsResult(base::ElapsedTimer service_launch_timer,
                          bool success);

  // Called when an attempt to generate embeddings finishes.
  void OnGotEmbeddings(RequestId request_id,
                       GetEmbeddingsResultCallback callback,
                       base::ElapsedTimer generate_embeddings_timer,
                       PassagePriority priority,
                       std::vector<mojom::PassageEmbeddingsResultPtr> results);

  // Version of the embeddings model. Explicitly initialized so the member
  // never holds an indeterminate value before the first model update.
  int64_t model_version_ = 0;

  // Metadata of the embeddings model.
  std::optional<optimization_guide::proto::PassageEmbeddingsModelMetadata>
      model_metadata_;

  // Paths of the model files on disk.
  // NOTE(review): `sp_model_path_` presumably refers to the SentencePiece
  // tokenizer model - confirm against the implementation.
  base::FilePath embeddings_model_path_;
  base::FilePath sp_model_path_;

  // Remote to the embedder; reset by `ResetEmbedderRemote()`.
  mojo::Remote<mojom::PassageEmbedder> embedder_remote_;

  // Pending requests to generate embeddings.
  std::vector<RequestId> pending_requests_;

  // Observers to notify when the model metadata is updated.
  base::ObserverList<EmbedderMetadataObserver> observer_list_;

  // This holds the main scheduler that receives requests from multiple clients,
  // prioritizes all the jobs, and ultimately submits batches of work via
  // `GetEmbeddings` when the time is right.
  const std::unique_ptr<Embedder> embedder_;

  // Used to generate weak pointers to self.
  base::WeakPtrFactory<PassageEmbeddingsServiceController> weak_ptr_factory_{
      this};
};

}  // namespace passage_embeddings

#endif  // COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_