File: permissions_aiv4_executor.cc

package info (click to toggle)
chromium 141.0.7390.107-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,246,132 kB
  • sloc: cpp: 35,264,965; ansic: 7,169,920; javascript: 4,250,185; python: 1,460,635; asm: 950,788; xml: 751,751; pascal: 187,972; sh: 89,459; perl: 88,691; objc: 79,953; sql: 53,924; cs: 44,622; fortran: 24,137; makefile: 22,313; tcl: 15,277; php: 14,018; yacc: 8,995; ruby: 7,553; awk: 3,720; lisp: 3,096; lex: 1,330; ada: 727; jsp: 228; sed: 36
file content (97 lines) | stat: -rw-r--r-- 3,756 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/permissions/prediction_service/permissions_aiv4_executor.h"

#include <array>
#include <utility>
#include <vector>

#include "base/types/optional_ref.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/permissions/prediction_service/permissions_aiv4_model_metadata.pb.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h"

namespace permissions {

using ModelInput = PermissionsAiv4Executor::ModelInput;
using ModelOutput = PermissionsAiv4Executor::ModelOutput;
using ::passage_embeddings::Embedding;
using ::tflite::task::core::PopulateTensor;

namespace {

// Length of the text-embedding input tensor that is assumed whenever the
// model metadata does not supply text_embeddings_input_size.
constexpr int kDefaultTextInputSize = 768;

}  // namespace

// Builds the executor input from a page snapshot and the embedding of the
// page's inner text. Both arguments are sink parameters taken by value and
// moved into the members, so callers passing temporaries (or std::move-ing
// their values) avoid an extra copy.
PermissionsAiv4ExecutorInput::PermissionsAiv4ExecutorInput(
    SkBitmap snapshot,
    Embedding inner_text_embedding)
    : snapshot(std::move(snapshot)),
      inner_text_embedding(std::move(inner_text_embedding)) {}

// Out-of-line defaulted special members: copy and move construction are
// member-wise, and destruction needs no custom logic.
PermissionsAiv4ExecutorInput::PermissionsAiv4ExecutorInput(
    const PermissionsAiv4ExecutorInput&) = default;
PermissionsAiv4ExecutorInput::PermissionsAiv4ExecutorInput(
    PermissionsAiv4ExecutorInput&&) = default;
PermissionsAiv4ExecutorInput::~PermissionsAiv4ExecutorInput() = default;

// Fills the model's two input tensors from `input`:
//   input_tensors[0] <- the inner-text passage embedding (as floats),
//   input_tensors[1] <- the page snapshot bitmap.
// The expected embedding length comes from the model metadata when present,
// otherwise kDefaultTextInputSize is used.
//
// Returns true on success, after also configuring the relevance thresholds
// via SetThresholdValues(). Returns false (leaving the thresholds untouched)
// if the embedding length does not match, or if either tensor cannot be
// populated.
bool PermissionsAiv4Executor::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors,
    const ModelInput& input) {
  // One tensor for the text embedding, one for the snapshot. DCHECK_EQ logs
  // both operands on failure, unlike DCHECK(a == b).
  DCHECK_EQ(input_tensors.size(), 2u);

  // Prefer the size advertised by the model metadata; fall back to the
  // built-in default when the metadata or the field is absent.
  int expected_input_size = kDefaultTextInputSize;
  if (input.metadata.has_value() &&
      input.metadata.value().has_text_embeddings_input_size()) {
    expected_input_size = input.metadata.value().text_embeddings_input_size();
  }

  const auto& embedding = input.inner_text_embedding;
  if (static_cast<int>(embedding.Dimensions()) != expected_input_size) {
    VLOG(1)
        << "[PermissionsAiv4Executor]: Input Size does not match expectations: "
        << embedding.Dimensions() << " vs (expected) " << expected_input_size;
    return false;
  }

  // Copies `expected_input_size` floats into the text tensor; PopulateTensor
  // reports a non-ok status on failure.
  if (!PopulateTensor<float>(embedding.GetData().data(), expected_input_size,
                             input_tensors[0])
           .ok()) {
    VLOG(1) << "[PermissionsAiv4Executor]: Failed to copy passage "
               "embedding.";
    return false;
  }

  if (!ConvertSkBitMapToTfliteTensor(input_tensors[1], input.snapshot)) {
    VLOG(1)
        << "[PermissionsAiv4Executor]: Failed to convert skbitmap to tflite "
           "tensor data.";
    return false;
  }

  VLOG(1) << "[PermissionsAiv4Executor]: Successfully encoded input!";
  // Thresholds are only configured once the input was encoded successfully.
  SetThresholdValues(input.metadata);
  return true;
}

// Populates relevance_thresholds() with the four cut-off values used to map a
// model score onto a relevance level. The values come from `metadata` when it
// carries relevance thresholds. Otherwise, built-in defaults for the current
// request type are used; only notifications and geolocation have defaults,
// and any other type falls back to the notification values.
void PermissionsAiv4Executor::SetThresholdValues(
    base::optional_ref<const PermissionsAiv4ModelMetadata> metadata) {
  if (metadata.has_value() && metadata.value().has_relevance_thresholds()) {
    const auto& thresholds = metadata.value().relevance_thresholds();
    relevance_thresholds() = {
        thresholds.min_low_relevance(), thresholds.min_medium_relevance(),
        thresholds.min_high_relevance(), thresholds.min_very_high_relevance()};
    return;
  }

  DCHECK(request_type() == RequestType::kNotifications ||
         request_type() == RequestType::kGeolocation);

  // Empirically determined thresholds, which map to relevance enum values as
  // follows (four thresholds, indices 0..3):
  //   val <  thr[0] -> VeryLow
  //   val <  thr[1] -> Low
  //   val <  thr[2] -> Medium
  //   val <  thr[3] -> High
  //   val >= thr[3] -> VeryHigh
  if (request_type() == RequestType::kGeolocation) {
    relevance_thresholds() = {0.033f, 0.077f, 0.2f, 0.49f};
  } else {
    relevance_thresholds() = {0.008f, 0.024f, 0.11f, 0.32f};
  }
}

}  // namespace permissions