File: scheduling_embedder.h

package info (click to toggle)
chromium 138.0.7204.157-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,071,864 kB
  • sloc: cpp: 34,936,859; ansic: 7,176,967; javascript: 4,110,704; python: 1,419,953; asm: 946,768; xml: 739,967; pascal: 187,324; sh: 89,623; perl: 88,663; objc: 79,944; sql: 50,304; cs: 41,786; fortran: 24,137; makefile: 21,806; php: 13,980; tcl: 13,166; yacc: 8,925; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (174 lines) | stat: -rw-r--r-- 7,357 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_INTERNAL_SCHEDULING_EMBEDDER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_INTERNAL_SCHEDULING_EMBEDDER_H_

#include <stddef.h>

#include <deque>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "base/functional/callback.h"
#include "base/memory/weak_ptr.h"
#include "base/scoped_observation.h"
#include "base/time/time.h"
#include "base/timer/elapsed_timer.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "components/performance_manager/scenario_api/performance_scenario_observer.h"
#include "components/performance_manager/scenario_api/performance_scenarios.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"

namespace passage_embeddings {

// The SchedulingEmbedder adds scheduling control with batching and priorities
// so that high priority queries can be computed as soon as possible. Scheduling
// is also needed to avoid clogging the pipes for a slow remote embedder. Even
// single pages can take a while, and when the model changes, all existing
// passages need their embeddings recomputed, which can take a very long time
// and should be done at lower priority.
class SchedulingEmbedder
    : public Embedder,
      public EmbedderMetadataObserver,
      public performance_scenarios::PerformanceScenarioObserver {
 public:
  using GetEmbeddingsResultCallback = base::OnceCallback<void(
      std::vector<mojom::PassageEmbeddingsResultPtr> results,
      ComputeEmbeddingsStatus status)>;
  using GetEmbeddingsCallback =
      base::RepeatingCallback<void(std::vector<std::string> passages,
                                   PassagePriority priority,
                                   GetEmbeddingsResultCallback callback)>;
  SchedulingEmbedder(EmbedderMetadataProvider* embedder_metadata_provider,
                     GetEmbeddingsCallback get_embeddings_callback,
                     size_t max_jobs,
                     size_t scheduled_max_batch_size,
                     bool use_performance_scenario);
  ~SchedulingEmbedder() override;

  // Embedder:
  TaskId ComputePassagesEmbeddings(
      PassagePriority priority,
      std::vector<std::string> passages,
      ComputePassagesEmbeddingsCallback callback) override;
  void ReprioritizeTasks(PassagePriority priority,
                         const std::set<TaskId>& tasks) override;
  bool TryCancel(TaskId task_id) override;

 private:
  // A job consists of multiple passages, and each passage must have its
  // embedding computed. When all are finished, the job is done and its
  // callback will be invoked. Multiple jobs may be batched together when
  // submitting work to the `embedder_remote_proxy`, and jobs can also be broken
  // down so that partial progress is made across multiple work submissions.
  struct Job {
    Job(PassagePriority priority,
        TaskId task_id,
        std::vector<std::string> passages,
        ComputePassagesEmbeddingsCallback callback);
    ~Job();
    Job(const Job&) = delete;
    Job& operator=(const Job&) = delete;
    Job(Job&&);
    Job& operator=(Job&&);

    // Data for the job is saved from calls to `ComputePassagesEmbeddings`.
    PassagePriority priority;
    TaskId task_id;
    std::vector<std::string> passages;
    ComputePassagesEmbeddingsCallback callback;

    bool in_progress = false;

    // Completed embeddings; may be partial.
    std::vector<Embedding> embeddings;

    // Measures total job duration, from creation to completion.
    base::ElapsedTimer timer;
  };

  // EmbedderMetadataObserver:
  void EmbedderMetadataUpdated(EmbedderMetadata metadata) override;

  // PerformanceScenarioObserver:
  void OnLoadingScenarioChanged(
      performance_scenarios::ScenarioScope scope,
      performance_scenarios::LoadingScenario old_scenario,
      performance_scenarios::LoadingScenario new_scenario) override;
  void OnInputScenarioChanged(
      performance_scenarios::ScenarioScope scope,
      performance_scenarios::InputScenario old_scenario,
      performance_scenarios::InputScenario new_scenario) override;

  // Invoked after the embedding for the current job has been computed.
  // Continues processing next job if one is pending.
  void OnEmbeddingsComputed(
      std::vector<mojom::PassageEmbeddingsResultPtr> results,
      ComputeEmbeddingsStatus status);

  // Stable-sort jobs by priority and submit a batch of work to embedder.
  // This will only submit new work if the embedder is not already working.
  void SubmitWorkToEmbedder();

  // Returns true if currently in a work ready performance scenario state.
  bool IsPerformanceScenarioReady();

  // Call the callback with status, etc. and record relevant histograms.
  static void FinishJob(Job job, ComputeEmbeddingsStatus status);

  // When this is non-empty, the embedder is working and its results will be
  // applied from front to back when `OnEmbeddingsComputed` is called. Not all
  // of these jobs are necessarily being worked on by the embedder. It may
  // contain a mix of in-progress, partially completed, and not-yet-started
  // jobs. In-progress jobs are ordered first, and in the same order as
  // submitted to the embedder. Partially completed jobs may follow,
  // still in the order they were last submitted to the embedder.
  // Not-yet-started jobs are ordered last. All jobs will be re-ordered by
  // priority before submitting the next batch to the embedder.
  std::deque<Job> jobs_;

  // ID to assign to the next Job.
  TaskId next_task_id_ = 1;

  // Whether the embedder is currently working on some passages. Note, this
  // is not the same concept as having a job in progress since multiple
  // embedder work submissions may be required to complete a job.
  bool work_submitted_ = false;

  // The callback that does the actual embeddings computations.
  // May be slow; await results before sending the next request.
  GetEmbeddingsCallback get_embeddings_callback_;

  // Metadata about the embedder; Set when valid metadata is received from
  // `embedder_metadata_provider`.
  EmbedderMetadata embedder_metadata_{0, 0};

  // The maximum number of jobs to hold at once. Exceeding the cap
  // will cause job failures on last pending jobs to avoid very high memory use.
  // When the limit is reached, the last pending job is canceled instead of
  // failing to accept the new job so that queries can still be accepted even
  // if the queue is full of lower priority jobs awaiting performance scenario.
  size_t max_jobs_;

  // The maximum number of embeddings to submit to the primary embedder.
  size_t max_batch_size_;

  // Whether to block embedding work submission on performance scenario.
  bool use_performance_scenario_;

  base::ScopedObservation<
      performance_scenarios::PerformanceScenarioObserverList,
      performance_scenarios::PerformanceScenarioObserver>
      performance_scenario_observation_{this};

  base::ScopedObservation<EmbedderMetadataProvider, EmbedderMetadataObserver>
      embedder_metadata_observation_{this};

  base::WeakPtrFactory<SchedulingEmbedder> weak_ptr_factory_{this};
};

}  // namespace passage_embeddings

#endif  // COMPONENTS_PASSAGE_EMBEDDINGS_INTERNAL_SCHEDULING_EMBEDDER_H_