File: session_impl.h

package info (click to toggle)
chromium 144.0.7559.109-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,915,868 kB
  • sloc: cpp: 35,866,215; ansic: 7,599,035; javascript: 3,623,761; python: 1,639,407; xml: 833,084; asm: 716,173; pascal: 185,323; sh: 88,763; perl: 88,699; objc: 79,984; sql: 58,217; cs: 42,430; fortran: 24,101; makefile: 20,747; tcl: 15,277; php: 14,022; yacc: 9,059; ruby: 7,553; awk: 3,720; lisp: 3,233; lex: 1,330; ada: 727; jsp: 228; sed: 36
file content (126 lines) | stat: -rw-r--r-- 5,450 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_SESSION_IMPL_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_SESSION_IMPL_H_

#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "base/functional/callback_forward.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/timer/timer.h"
#include "components/optimization_guide/core/model_execution/multimodal_message.h"
#include "components/optimization_guide/core/model_execution/on_device_capability.h"
#include "components/optimization_guide/core/model_execution/on_device_context.h"
#include "components/optimization_guide/core/model_execution/on_device_execution.h"
#include "components/optimization_guide/core/model_execution/on_device_model_feature_adapter.h"
#include "components/optimization_guide/core/model_execution/optimization_guide_model_execution_error.h"
#include "components/optimization_guide/core/model_execution/safety_checker.h"
#include "components/optimization_guide/core/model_execution/substitution.h"
#include "components/optimization_guide/core/model_quality/model_quality_logs_uploader_service.h"
#include "components/optimization_guide/proto/model_quality_metadata.pb.h"
#include "components/optimization_guide/proto/model_quality_service.pb.h"
#include "components/optimization_guide/proto/text_safety_model_metadata.pb.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"

namespace optimization_guide {

class OnDeviceContext;

// Session implementation that uses either the on device model or the server
// model.
//
// SessionImpl is neither copyable nor movable: it holds a const `feature_`,
// exclusively owns its on-device state (`on_device_context_`,
// `on_device_execution_`), and owns a WeakPtrFactory bound to `this`.
class SessionImpl : public OnDeviceSession {
 public:
  // Possible outcomes of AddContext(). Maps to histogram enum
  // "OptimizationGuideOnDeviceAddContextResult".
  // These values are persisted to logs. Entries should not be renumbered
  // and numeric values should never be reused.
  enum class AddContextResult {
    kUsingServer = 0,
    kUsingOnDevice = 1,
    kFailedConstructingInput = 2,
    kMaxValue = kFailedConstructingInput,
  };

  // Creates a session for `feature` configured with `on_device_opts`.
  SessionImpl(mojom::OnDeviceFeature feature, OnDeviceOptions on_device_opts);
  // Creates a session for `feature` whose output sampling is controlled by
  // `sampling_params`.
  // NOTE(review): this overload receives no OnDeviceOptions; confirm in the
  // .cc which execution path it is intended for.
  SessionImpl(mojom::OnDeviceFeature feature,
              const SamplingParams& sampling_params);

  // Copying is explicitly disallowed. (Moves are also unavailable, since
  // declaring these suppresses the implicit move operations and the members
  // are not movable anyway.)
  SessionImpl(const SessionImpl&) = delete;
  SessionImpl& operator=(const SessionImpl&) = delete;

  ~SessionImpl() override;

  // optimization_guide::OnDeviceSession:
  const TokenLimits& GetTokenLimits() const override;
  const proto::Any& GetOnDeviceFeatureMetadata() const override;
  void SetInput(MultimodalMessage request, SetInputCallback callback) override;
  void AddContext(
      const google::protobuf::MessageLite& request_metadata) override;
  void Score(const std::string& text,
             OptimizationGuideModelScoreCallback callback) override;
  void ExecuteModel(
      const google::protobuf::MessageLite& request_metadata,
      OptimizationGuideModelExecutionResultStreamingCallback callback) override;
  void ExecuteModelWithResponseConstraint(
      const google::protobuf::MessageLite& request_metadata,
      on_device_model::mojom::ResponseConstraintPtr constraint,
      OptimizationGuideModelExecutionResultStreamingCallback callback) override;
  void GetSizeInTokens(
      const std::string& text,
      OptimizationGuideModelSizeInTokenCallback callback) override;
  void GetExecutionInputSizeInTokens(
      MultimodalMessageReadView request_metadata,
      OptimizationGuideModelSizeInTokenCallback callback) override;
  void GetContextSizeInTokens(
      MultimodalMessageReadView request_metadata,
      OptimizationGuideModelSizeInTokenCallback callback) override;
  const SamplingParams GetSamplingParams() const override;
  on_device_model::Capabilities GetCapabilities() const override;
  std::unique_ptr<OnDeviceSession> Clone() override;
  void SetPriority(on_device_model::mojom::Priority priority) override;

  // Returns true if the on-device model should be used.
  bool ShouldUseOnDeviceModel() const;

 private:
  // Handles a new context `request`. Takes the same arguments as SetInput()
  // and presumably backs both SetInput() and AddContext() — confirm in the
  // .cc. Returns an AddContextResult describing which path handled the
  // context, or kFailedConstructingInput if the input could not be built.
  AddContextResult AddContextImpl(MultimodalMessage request,
                                  SetInputCallback callback);

  // Tears down the session's on-device state.
  // NOTE(review): presumably resets `on_device_context_` and
  // `on_device_execution_`; verify against the implementation.
  void DestroyOnDeviceState();

  // Called when an on-device execution flow terminates, and can be cleaned up.
  // NOTE(review): `healthy` presumably reports whether the on-device session
  // remains usable after the flow ended; confirm in the .cc.
  void OnDeviceExecutionTerminated(bool healthy);

  // Computes the size of `request` in tokens and reports it via `callback`.
  // When `want_input_context` is true the context portion of `request` is
  // measured (as for GetContextSizeInTokens()); otherwise the execution-input
  // portion is measured (as for GetExecutionInputSizeInTokens()).
  void GetSizeInTokensInternal(
      MultimodalMessageReadView request,
      OptimizationGuideModelSizeInTokenCallback callback,
      bool want_input_context);

  // The feature this session was created for; fixed for the session lifetime.
  const mojom::OnDeviceFeature feature_;

  // Context message supplied to this session.
  MultimodalMessage context_;
  // Time associated with the current context.
  // NOTE(review): presumably recorded when the context is set (e.g. for
  // latency metrics); confirm in the .cc.
  base::TimeTicks context_start_time_;

  // Manages the on-device session holding the processed context.
  // If this is null, on-device executions cannot be started.
  std::unique_ptr<OnDeviceContext> on_device_context_;

  // Manages state for an ongoing on-device execution.
  std::optional<OnDeviceExecution> on_device_execution_;

  // Params used to control output sampling for the on device model.
  const SamplingParams sampling_params_;

  // Capabilities for this session of the on device model.
  on_device_model::Capabilities capabilities_;

  // Keep this as the last member: members are destroyed in reverse
  // declaration order, so the factory (and any weak pointers it issued) is
  // invalidated before the other members are torn down.
  base::WeakPtrFactory<SessionImpl> weak_ptr_factory_{this};
};

}  // namespace optimization_guide

#endif  // COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_SESSION_IMPL_H_