File: permissions_aiv3_encoder.cc

package info (click to toggle)
chromium 139.0.7258.127-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,122,156 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (97 lines) | stat: -rw-r--r-- 3,515 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/permissions/prediction_service/permissions_aiv3_encoder.h"

#include <array>
#include <vector>

#include "skia/ext/image_operations.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h"

namespace permissions {

namespace {
using ModelInput = PermissionsAiv3Encoder::ModelInput;
using ModelOutput = PermissionsAiv3Encoder::ModelOutput;
using ::tflite::task::core::PopulateTensor;

// Maps the model score `val` onto a PermissionRequestRelevance bucket using the
// empirically determined, ascending `thresholds`. The first threshold that
// `val` is strictly below selects the bucket:
//   val <  thr[0] -> VeryLow
//   ...
//   val <  thr[3] -> High
//   val >= thr[3] -> VeryHigh
// NOTE(review): the loop returns the enum value `i + 1` for thr[i], so this
// relies on PermissionRequestRelevance being numbered kVeryLow == 1 upward —
// confirm against the enum definition.
// A NaN `val` compares false against every threshold and therefore maps to
// kVeryHigh.
PermissionRequestRelevance ConvertToRelevance(
    float val,
    const std::array<float, 4>& thresholds) {
  for (size_t i = 0; i < thresholds.size(); ++i) {
    if (val < thresholds[i]) {
      return static_cast<PermissionRequestRelevance>(i + 1);
    }
  }
  return PermissionRequestRelevance::kVeryHigh;
}

}  // namespace

// Width and height, in pixels, that Preprocess() resizes input images to
// before handing them to the model.
const int PermissionsAiv3Encoder::kModelInputWidth{64};
const int PermissionsAiv3Encoder::kModelInputHeight{64};

// Resizes `input` to kModelInputWidth x kModelInputHeight, converts it to
// normalized RGB floats and writes them into the first input tensor.
// Returns false if there is no input tensor, the resize fails, or populating
// the tensor fails.
bool PermissionsAiv3Encoder::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors,
    const ModelInput& input) {
  // The result is written to input_tensors[0]; don't index an empty vector.
  if (input_tensors.empty()) {
    return false;
  }

  // TODO(crbug.com/405095664): Figure out if resize_best is fast enough to deal
  // with too big/small inputs.
  SkBitmap resized =
      skia::ImageOperations::Resize(input, skia::ImageOperations::RESIZE_BEST,
                                    kModelInputWidth, kModelInputHeight);
  // `data` below is sized for exactly kModelInputWidth x kModelInputHeight
  // pixels, so reject any other result instead of overrunning the buffer.
  if (resized.drawsNothing() || resized.width() != kModelInputWidth ||
      resized.height() != kModelInputHeight) {
    return false;
  }

  // Pixels are emitted row by row (y outer, x inner), three channels (R, G, B)
  // per pixel, each normalized to lie between 0 and 1.
  std::array<float, kModelInputHeight * kModelInputWidth * 3> data;
  size_t index = 0;
  for (int y = 0; y < kModelInputHeight; ++y) {
    for (int x = 0; x < kModelInputWidth; ++x) {
      // SkBitmap::getColor() takes (x, y), i.e. (column, row); passing
      // (row, column) would transpose the image.
      const SkColor color = resized.getColor(x, y);
      // TODO(crbug.com/405095664): We need to investigate if this is the
      // correct way to fill the tensors data;
      data[index++] = static_cast<float>(SkColorGetR(color)) / 255.0f;
      data[index++] = static_cast<float>(SkColorGetG(color)) / 255.0f;
      data[index++] = static_cast<float>(SkColorGetB(color)) / 255.0f;
    }
  }
  return PopulateTensor<float>(data.data(), data.size(), input_tensors[0]).ok();
}

// Reads the first value of the first output tensor and maps it to a relevance
// bucket using per-request-type thresholds. Returns std::nullopt if there is
// no output tensor, it cannot be read, or it holds no values.
std::optional<ModelOutput> PermissionsAiv3Encoder::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors) {
  DCHECK(request_type_ == RequestType::kNotifications ||
         request_type_ == RequestType::kGeolocation);

  // TODO(crbug.com/405095664): should be fetched via the model metadata proto
  // as soon as we have this.
  static constexpr std::array<float, 4> geolocation_thresholds = {0.2f, 0.4f,
                                                                  0.5f, 0.65f};
  static constexpr std::array<float, 4> notification_thresholds = {0.2f, 0.4f,
                                                                   0.7f, 0.84f};

  // Both output_tensors[0] and data[0] are read below; guard against empty
  // containers instead of reading out of bounds.
  if (output_tensors.empty()) {
    return std::nullopt;
  }
  std::vector<float> data;
  if (!tflite::task::core::PopulateVector<float>(output_tensors[0], &data)
           .ok() ||
      data.empty()) {
    return std::nullopt;
  }

  // Any request type other than notifications (only geolocation is expected,
  // per the DCHECK above) uses the geolocation thresholds.
  const std::array<float, 4>& thresholds =
      request_type_ == RequestType::kNotifications ? notification_thresholds
                                                   : geolocation_thresholds;
  return ConvertToRelevance(data[0], thresholds);
}

}  // namespace permissions