// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef THIRD_PARTY_BLINK_RENDERER_MODULES_AI_LANGUAGE_MODEL_H_
#define THIRD_PARTY_BLINK_RENDERER_MODULES_AI_LANGUAGE_MODEL_H_

#include <optional>
#include <variant>

#include "base/types/pass_key.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom-blink.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom-blink-forward.h"
#include "third_party/blink/renderer/bindings/core/v8/idl_types.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_availability.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_language_model_append_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_language_model_clone_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_language_model_create_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_language_model_message_role.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_language_model_prompt_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_typedefs.h"
#include "third_party/blink/renderer/core/dom/events/event_target.h"
#include "third_party/blink/renderer/core/event_type_names.h"
#include "third_party/blink/renderer/core/execution_context/execution_context_lifecycle_observer.h"
#include "third_party/blink/renderer/core/streams/readable_stream.h"
#include "third_party/blink/renderer/modules/ai/language_model_params.h"
#include "third_party/blink/renderer/platform/bindings/script_wrappable.h"
#include "third_party/blink/renderer/platform/mojo/heap_mojo_remote.h"
#include "third_party/blink/renderer/platform/wtf/hash_set.h"

namespace blink {

// The class that represents a `LanguageModel` object.
class LanguageModel final : public EventTarget, public ExecutionContextClient {
  DEFINE_WRAPPERTYPEINFO();

 public:
  // Get the mojo enum value for the given V8 `role` enum value.
  static mojom::blink::AILanguageModelPromptRole ConvertRoleToMojo(
      V8LanguageModelMessageRole role);
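  // Example (a sketch; it assumes the generated V8 enum exposes
  // `V8LanguageModelMessageRole::Enum::kUser` for the IDL value "user"):
  //
  //   const mojom::blink::AILanguageModelPromptRole mojo_role =
  //       LanguageModel::ConvertRoleToMojo(V8LanguageModelMessageRole(
  //           V8LanguageModelMessageRole::Enum::kUser));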

  LanguageModel(
      ExecutionContext* execution_context,
      mojo::PendingRemote<mojom::blink::AILanguageModel> pending_remote,
      scoped_refptr<base::SequencedTaskRunner> task_runner,
      mojom::blink::AILanguageModelInstanceInfoPtr info);
  ~LanguageModel() override = default;

  void Trace(Visitor* visitor) const override;

  // EventTarget implementation.
  const AtomicString& InterfaceName() const override;
  ExecutionContext* GetExecutionContext() const override;

  DEFINE_ATTRIBUTE_EVENT_LISTENER(quotaoverflow, kQuotaoverflow)

  // language_model.idl implementation.
  static ScriptPromise<LanguageModel> create(
      ScriptState* script_state,
      LanguageModelCreateOptions* options,
      ExceptionState& exception_state);
  static ScriptPromise<V8Availability> availability(
      ScriptState* script_state,
      const LanguageModelCreateCoreOptions* options,
      ExceptionState& exception_state);
  static ScriptPromise<IDLNullable<LanguageModelParams>> params(
      ScriptState* script_state,
      ExceptionState& exception_state);

  ScriptPromise<IDLString> prompt(ScriptState* script_state,
                                  const V8LanguageModelPrompt* input,
                                  const LanguageModelPromptOptions* options,
                                  ExceptionState& exception_state);
  ReadableStream* promptStreaming(ScriptState* script_state,
                                  const V8LanguageModelPrompt* input,
                                  const LanguageModelPromptOptions* options,
                                  ExceptionState& exception_state);
  ScriptPromise<IDLUndefined> append(ScriptState* script_state,
                                     const V8LanguageModelPrompt* input,
                                     const LanguageModelAppendOptions* options,
                                     ExceptionState& exception_state);
  ScriptPromise<IDLDouble> measureInputUsage(
      ScriptState* script_state,
      const V8LanguageModelPrompt* input,
      const LanguageModelPromptOptions* options,
      ExceptionState& exception_state);

  double inputQuota() const { return input_quota_; }
  double inputUsage() const { return input_usage_; }
  uint32_t topK() const { return top_k_; }
  float temperature() const { return temperature_; }

  ScriptPromise<LanguageModel> clone(ScriptState* script_state,
                                     const LanguageModelCloneOptions* options,
                                     ExceptionState& exception_state);
  void destroy(ScriptState* script_state, ExceptionState& exception_state);

  HeapMojoRemote<mojom::blink::AILanguageModel>& GetAILanguageModelRemote();
  scoped_refptr<base::SequencedTaskRunner> GetTaskRunner();

 private:
  void OnResponseComplete(
      mojom::blink::ModelExecutionContextInfoPtr context_info);
  void OnQuotaOverflow();

  using ResolverOrStream =
      std::variant<ScriptPromiseResolverBase*, ReadableStream*>;
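
  // `ResolverOrStream` presumably lets one execution path serve both
  // `prompt()`, which settles a promise, and `promptStreaming()`, which feeds
  // a stream. A sketch of how a caller might branch on it (illustrative only;
  // the real dispatch lives in the .cc file and may differ):
  //
  //   if (auto** stream = std::get_if<ReadableStream*>(&resolver_or_stream)) {
  //     // Streaming: enqueue response chunks into `*stream`.
  //   } else {
  //     auto* resolver =
  //         std::get<ScriptPromiseResolverBase*>(resolver_or_stream);
  //     // Non-streaming: settle `resolver` once the response completes.
  //   }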

  // Helper to make AILanguageModelProxy::Prompt compatible with the
  // ConvertPromptInputsToMojo callback.
  void ExecutePrompt(
      ScriptState* script_state,
      ResolverOrStream resolver_or_stream,
      on_device_model::mojom::blink::ResponseConstraintPtr constraint,
      mojo::PendingRemote<mojom::blink::ModelStreamingResponder>
          pending_responder,
      WTF::Vector<mojom::blink::AILanguageModelPromptPtr> prompts);

  // Helper to make AILanguageModelProxy::MeasureInputUsage compatible with the
  // ConvertPromptInputsToMojo callback.
  void ExecuteMeasureInputUsage(
      ScriptPromiseResolver<IDLDouble>* resolver,
      AbortSignal* signal,
      WTF::Vector<mojom::blink::AILanguageModelPromptPtr> prompts);
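
  // Example of the adaptation (a sketch; it assumes ConvertPromptInputsToMojo
  // takes a callback of type
  // `base::OnceCallback<void(WTF::Vector<mojom::blink::AILanguageModelPromptPtr>)>`).
  // Binding every argument except the trailing `prompts` leaves a callback
  // with exactly that signature:
  //
  //   auto callback = WTF::BindOnce(&LanguageModel::ExecuteMeasureInputUsage,
  //                                 WrapPersistent(this),
  //                                 WrapPersistent(resolver),
  //                                 WrapPersistent(signal));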

  // Validates and processes the prompt input and returns the resulting
  // response constraint. Returns std::nullopt on failure.
  std::optional<on_device_model::mojom::blink::ResponseConstraintPtr>
  ValidateAndProcessPromptInput(ScriptState* script_state,
                                const V8LanguageModelPrompt* input,
                                const LanguageModelPromptOptions* options,
                                ExceptionState& exception_state);
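
  // Sketch of a call site inside a method that returns a ScriptPromise, such
  // as prompt(). Assumption: on failure the error has already been reported
  // through `exception_state`, so the caller only needs to bail out:
  //
  //   auto constraint = ValidateAndProcessPromptInput(
  //       script_state, input, options, exception_state);
  //   if (!constraint) {
  //     return EmptyPromise();
  //   }
  //   // `std::move(*constraint)` can now be handed to the execution path.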

  uint64_t input_usage_ = 0;
  uint64_t input_quota_ = 0;
  uint32_t top_k_ = 0;
  float temperature_ = 0.0f;

  // Prompt types supported by the language model in this session.
  WTF::HashSet<mojom::blink::AILanguageModelPromptType> input_types_;

  scoped_refptr<base::SequencedTaskRunner> task_runner_;
  HeapMojoRemote<mojom::blink::AILanguageModel> language_model_remote_;
};

}  // namespace blink

#endif  // THIRD_PARTY_BLINK_RENDERER_MODULES_AI_LANGUAGE_MODEL_H_