File: ai_language_model.h

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROME_BROWSER_AI_AI_LANGUAGE_MODEL_H_
#define CHROME_BROWSER_AI_AI_LANGUAGE_MODEL_H_

#include <deque>
#include <optional>

#include "base/containers/queue.h"
#include "base/functional/callback_forward.h"
#include "base/memory/weak_ptr.h"
#include "base/types/expected.h"
#include "chrome/browser/ai/ai_context_bound_object.h"
#include "chrome/browser/ai/ai_context_bound_object_set.h"
#include "chrome/browser/ai/ai_utils.h"
#include "components/optimization_guide/core/model_execution/model_broker_client.h"
#include "components/optimization_guide/core/model_execution/multimodal_message.h"
#include "components/optimization_guide/core/model_execution/safety_checker.h"
#include "components/optimization_guide/core/optimization_guide_logger.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/proto/features/prompt_api.pb.h"
#include "components/optimization_guide/public/mojom/model_broker.mojom.h"
#include "content/public/browser/browser_context.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"

// The implementation of `blink::mojom::AILanguageModel`, which exposes the APIs
// for model execution.
class AILanguageModel : public AIContextBoundObject,
                        public blink::mojom::AILanguageModel,
                        public optimization_guide::TextSafetyClient {
 public:
  using PromptApiMetadata = optimization_guide::proto::PromptApiMetadata;

  // The minimum version of the model execution config for the prompt API
  // that uses a proto instead of a string value for the request.
  static constexpr uint32_t kMinVersionUsingProto = 2;

  // The Context class manages the history of prompt input and output. Context
  // items are stored in a FIFO, and the oldest items are evicted to keep the
  // total token count below the limit when overflow occurs (see the sketch
  // after the class definition).
  class Context {
   public:
    // A piece of the prompt history and its size.
    struct ContextItem {
      ContextItem();
      ContextItem(const ContextItem&);
      ContextItem(ContextItem&&);
      ~ContextItem();

      on_device_model::mojom::InputPtr input;
      uint32_t tokens = 0;
    };

    // `max_tokens` is the number of tokens remaining after the initial prompts.
    explicit Context(uint32_t max_tokens);
    Context(const Context&);
    ~Context();

    // The status of the result returned from `ReserveSpace()`.
    enum class SpaceReservationResult {
      // The remaining space is enough for the required tokens.
      kSufficientSpace = 0,
      // The remaining space was not enough for the required tokens, but
      // enough space was made available by evicting some of the oldest
      // `ContextItem`s.
      kSpaceMadeAvailable,
      // Even after evicting all the `ContextItem`s, it's not possible to make
      // enough space. In this case, no eviction will happen.
      kInsufficientSpace
    };

    // Ensures the context has at least `num_tokens` available; if there is
    // not enough space, the oldest `ContextItem`s will be evicted.
    SpaceReservationResult ReserveSpace(uint32_t num_tokens);

    // Inserts a new context item; this may evict the oldest items to ensure
    // the total number of tokens in the context stays below the limit.
    // Returns the result of the space reservation.
    SpaceReservationResult AddContextItem(ContextItem context_item);

    // Returns an input containing all of the current prompt history excluding
    // the initial prompts. This does not include prompts removed due to
    // overflow handling.
    on_device_model::mojom::InputPtr GetNonInitialPrompts();

    // The number of tokens remaining after the initial prompts.
    uint32_t max_tokens() const { return max_tokens_; }
    uint32_t current_tokens() const { return current_tokens_; }

   private:
    uint32_t max_tokens_;
    uint32_t current_tokens_ = 0;
    std::deque<ContextItem> context_items_;
  };
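
  // A minimal sketch (not part of this header) of how the token budget
  // behaves, assuming a hypothetical 10-token limit; `ContextItem::input` is
  // left null for brevity:
  //
  //   Context context(/*max_tokens=*/10);
  //   Context::ContextItem a;
  //   a.tokens = 6;
  //   context.AddContextItem(std::move(a));  // kSufficientSpace.
  //   Context::ContextItem b;
  //   b.tokens = 6;
  //   context.AddContextItem(std::move(b));  // kSpaceMadeAvailable: `a` was
  //                                          // evicted to make room for `b`.
  //   context.ReserveSpace(11);              // kInsufficientSpace: exceeds
  //                                          // `max_tokens`; nothing evicted.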

  AILanguageModel(AIContextBoundObjectSet& context_bound_object_set,
                  on_device_model::mojom::SessionParamsPtr session_params,
                  base::WeakPtr<optimization_guide::ModelClient> model_client,
                  mojo::PendingRemote<on_device_model::mojom::Session> session,
                  base::WeakPtr<OptimizationGuideLogger> logger);
  AILanguageModel(const AILanguageModel&) = delete;
  AILanguageModel& operator=(const AILanguageModel&) = delete;

  ~AILanguageModel() override;

  // Returns the metadata from `any`, parsed as a `PromptApiMetadata`.
  static PromptApiMetadata ParseMetadata(
      const optimization_guide::proto::Any& any);

  // Formats the initial prompts, gets the token count, updates the session,
  // and reports to `create_client`.
  void Initialize(
      std::vector<blink::mojom::AILanguageModelPromptPtr> initial_prompts,
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          create_client);

  // `blink::mojom::AILanguageModel` implementation.
  void Prompt(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
              on_device_model::mojom::ResponseConstraintPtr constraint,
              mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
                  pending_responder) override;
  void Append(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
              mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
                  pending_responder) override;
  void Fork(
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client) override;
  void Destroy() override;
  void MeasureInputUsage(
      std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
      MeasureInputUsageCallback callback) override;
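
  // A minimal sketch (not part of this header) of how a client holding a
  // `mojo::Remote<blink::mojom::AILanguageModel>` might drive `Prompt()`; the
  // local names and the null constraint are illustrative assumptions:
  //
  //   mojo::Remote<blink::mojom::AILanguageModel> model(
  //       std::move(pending_remote));
  //   std::vector<blink::mojom::AILanguageModelPromptPtr> prompts;
  //   // ... populate `prompts` ...
  //   mojo::PendingRemote<blink::mojom::ModelStreamingResponder> responder;
  //   // ... bind `responder` to an object that receives streamed chunks ...
  //   model->Prompt(std::move(prompts), /*constraint=*/nullptr,
  //                 std::move(responder));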

  // AIContextBoundObject:
  void SetPriority(on_device_model::mojom::Priority priority) override;

  // optimization_guide::TextSafetyClient:
  void StartSession(
      mojo::PendingReceiver<on_device_model::mojom::TextSafetySession> session)
      override;

  blink::mojom::AILanguageModelInstanceInfoPtr GetLanguageModelInstanceInfo();

 private:
  mojo::PendingRemote<blink::mojom::AILanguageModel> BindRemote();

  class PromptState;
  void InitializeGetInputSizeComplete(
      on_device_model::mojom::InputPtr input,
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          create_client,
      std::optional<uint32_t> token_count);
  void InitializeSafetyChecksComplete(
      on_device_model::mojom::InputPtr input,
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          create_client,
      optimization_guide::SafetyChecker::Result safety_result);

  void ForkInternal(
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client,
      base::OnceClosure on_complete);

  void PromptInternal(
      std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
      on_device_model::mojom::ResponseConstraintPtr constraint,
      mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
          pending_responder,
      base::OnceClosure on_complete);
  void PromptGetInputSizeComplete(base::OnceClosure on_complete,
                                  std::optional<uint32_t> result);
  void OnPromptOutputComplete();

  void AppendInternal(
      std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
      mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
          pending_responder,
      base::OnceClosure on_complete);

  void HandleOverflow();
  void GetSizeInTokens(
      on_device_model::mojom::InputPtr input,
      base::OnceCallback<void(std::optional<uint32_t>)> callback);

  // These methods are used for implementing queueing.
  using QueueCallback = base::OnceCallback<void(base::OnceClosure)>;
  void AddToQueue(QueueCallback task);
  void TaskComplete();
  void RunNext();
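
  // A minimal sketch (not part of this header) of the queueing contract: each
  // queued task receives a completion closure and must run it when its work
  // (possibly asynchronous) finishes, allowing the next queued task to start.
  //
  //   AddToQueue(base::BindOnce([](base::OnceClosure on_complete) {
  //     // ... do the work, then signal completion:
  //     std::move(on_complete).Run();
  //   }));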

  // Contains just the initial prompts. This should not change throughout the
  // lifetime of this object. If this object is valid, `current_session_` can
  // also be assumed to be valid, as any disconnects should apply to both
  // remotes (e.g. a service crash).
  mojo::Remote<on_device_model::mojom::Session> initial_session_;

  // Contains the current committed session state. This will be replaced with
  // the latest session state after a successful prompt.
  mojo::Remote<on_device_model::mojom::Session> current_session_;

  // The session params the initial session was created with.
  on_device_model::mojom::SessionParamsPtr session_params_;

  // Holds all the input and output from the previous prompt.
  std::unique_ptr<Context> context_;
  // It's safe to store a `raw_ref` here since `this` is owned by
  // `context_bound_object_set_`, so the set is guaranteed to outlive it.
  base::raw_ref<AIContextBoundObjectSet> context_bound_object_set_;

  // Holds the queue of operations to be run.
  base::queue<QueueCallback> queue_;
  // Whether a task is currently running.
  bool task_running_ = false;

  std::unique_ptr<optimization_guide::SafetyChecker> safety_checker_;
  base::WeakPtr<optimization_guide::ModelClient> model_client_;

  // Holds state for any currently active prompt. This holds a reference to
  // `safety_checker_` so must be ordered after that member.
  std::unique_ptr<PromptState> prompt_state_;

  base::WeakPtr<OptimizationGuideLogger> logger_;

  mojo::Receiver<blink::mojom::AILanguageModel> receiver_{this};

  base::WeakPtrFactory<AILanguageModel> weak_ptr_factory_{this};
};

#endif  // CHROME_BROWSER_AI_AI_LANGUAGE_MODEL_H_