// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>
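#include "base/callback_list.h"
#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/memory/weak_ptr.h"
#include "base/observer_list.h"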
#include "base/timer/elapsed_timer.h"
#include "base/types/optional_ref.h"
#include "components/optimization_guide/core/model_info.h"
#include "components/optimization_guide/proto/passage_embeddings_model_metadata.pb.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
namespace passage_embeddings {
class PassageEmbeddingsServiceController : public EmbedderMetadataProvider {
public:
PassageEmbeddingsServiceController();
~PassageEmbeddingsServiceController() override;
// Updates the paths and the metadata needed for executing the passage
// embeddings model. The original paths and metadata will be erased regardless
// of the validity of the new model paths.
// Returns true and notifies the observers if the given paths are valid.
// Virtual for testing.
virtual bool MaybeUpdateModelInfo(
base::optional_ref<const optimization_guide::ModelInfo> model_info);
// Returns true if the embedder is currently running.
bool EmbedderRunning();
// Returns the embedder used to generate embeddings.
Embedder* GetEmbedder();
protected:
// EmbedderMetadataProvider:
void AddObserver(EmbedderMetadataObserver* observer) override;
void RemoveObserver(EmbedderMetadataObserver* observer) override;
// Computes embeddings for each entry in `passages` and invokes `callback`
// when done. On success, `results` is guaranteed to have the same number of
// passages and embeddings, in the same order as `passages`. Otherwise,
// `results` will have empty passages and embeddings.
using GetEmbeddingsResultCallback = base::OnceCallback<void(
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status)>;
using GetEmbeddingsCallback =
base::RepeatingCallback<void(std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsResultCallback callback)>;
void GetEmbeddings(std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsResultCallback callback);
// Returns true if this service controller is ready for embeddings generation.
bool EmbedderReady();
// Returns the metadata about the embeddings model. This is only valid when
// EmbedderReady() returns true.
EmbedderMetadata GetEmbedderMetadata();
// Launches the passage embeddings service and binds `cpu_logger_` to the
// service process. Does nothing if the service is already launched.
virtual void MaybeLaunchService() = 0;
// Resets `service_remote_` and `cpu_logger_`. Called when the service remote
// is idle or disconnects.
virtual void ResetServiceRemote() = 0;
// Resets `embedder_remote_`. Called when the model info is updated, when
// models fail to load, or when the embedder remote is idle or disconnects.
void ResetEmbedderRemote();
mojo::Remote<mojom::PassageEmbeddingsService> service_remote_;
private:
// Identifies a single embeddings request. uint64_t is large enough that the
// counter will not overflow in practice.
using RequestId = uint64_t;
RequestId next_request_id_ = 0;
// Called when the model files on disk are opened and ready to be sent to
// the service.
void LoadModelsToService(
mojo::PendingReceiver<mojom::PassageEmbedder> receiver,
base::ElapsedTimer service_launch_timer,
mojom::PassageEmbeddingsLoadModelsParamsPtr params);
// Called when an attempt to load models to service finishes.
void OnLoadModelsResult(base::ElapsedTimer service_launch_timer,
bool success);
// Called when an attempt to generate embeddings finishes.
void OnGotEmbeddings(RequestId request_id,
GetEmbeddingsResultCallback callback,
base::ElapsedTimer generate_embeddings_timer,
PassagePriority priority,
std::vector<mojom::PassageEmbeddingsResultPtr> results);
// Version of the embeddings model.
int64_t model_version_;
// Metadata of the embeddings model.
std::optional<optimization_guide::proto::PassageEmbeddingsModelMetadata>
model_metadata_;
base::FilePath embeddings_model_path_;
base::FilePath sp_model_path_;
mojo::Remote<mojom::PassageEmbedder> embedder_remote_;
// Pending requests to generate embeddings.
std::vector<RequestId> pending_requests_;
// Observers to notify when the model metadata is updated.
base::ObserverList<EmbedderMetadataObserver> observer_list_;
// This holds the main scheduler that receives requests from multiple clients,
// prioritizes all the jobs, and ultimately submits batches of work via
// `GetEmbeddings` when the time is right.
const std::unique_ptr<Embedder> embedder_;
// Used to generate weak pointers to self.
base::WeakPtrFactory<PassageEmbeddingsServiceController> weak_ptr_factory_{
this};
};
} // namespace passage_embeddings
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_