File: preloading_model_executor_unittest.cc

package info (click to toggle)
chromium 138.0.7204.183-1~deb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm-proposed-updates
  • size: 6,080,960 kB
  • sloc: cpp: 34,937,079; ansic: 7,176,967; javascript: 4,110,704; python: 1,419,954; asm: 946,768; xml: 739,971; pascal: 187,324; sh: 89,623; perl: 88,663; objc: 79,944; sql: 50,304; cs: 41,786; fortran: 24,137; makefile: 21,811; php: 13,980; tcl: 13,166; yacc: 8,925; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (89 lines) | stat: -rw-r--r-- 3,693 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/navigation_predictor/preloading_model_executor.h"

#include <memory>
#include <optional>
#include <vector>

#include "base/base_paths.h"
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/location.h"
#include "base/path_service.h"
#include "base/run_loop.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/time/time.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/common/features.h"

using ModelInput = PreloadingModelExecutor::ModelInput;
using ModelOutput = PreloadingModelExecutor::ModelOutput;

class PreloadingModelExecutorTest : public testing::Test {
 public:
  PreloadingModelExecutorTest() {
    scoped_feature_list_.InitAndEnableFeature(
        blink::features::kPreloadingHeuristicsMLModel);
  }
  ~PreloadingModelExecutorTest() override = default;

  void SetUp() override {
    base::FilePath source_root_dir;
    base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &source_root_dir);

    model_file_path_ = source_root_dir.AppendASCII("chrome")
                           .AppendASCII("browser")
                           .AppendASCII("navigation_predictor")
                           .AppendASCII("test")
                           .AppendASCII("preloading_heuristics.tflite");
    execution_task_runner_ = base::ThreadPool::CreateSequencedTaskRunner(
        {base::MayBlock(), base::TaskPriority::BEST_EFFORT});
    model_executor_ = std::make_unique<PreloadingModelExecutor>();
    model_executor_->InitializeAndMoveToExecutionThread(
        /*model_inference_timeout=*/std::nullopt,
        optimization_guide::proto::OPTIMIZATION_TARGET_OMNIBOX_URL_SCORING,
        execution_task_runner_, base::SequencedTaskRunner::GetCurrentDefault());
  }

  void TearDown() override {
    // Destroy model executor.
    execution_task_runner_->DeleteSoon(FROM_HERE, std::move(model_executor_));
    RunUntilIdle();
  }

  void RunUntilIdle() { task_environment_.RunUntilIdle(); }

 protected:
  base::test::ScopedFeatureList scoped_feature_list_;
  base::test::TaskEnvironment task_environment_;
  base::FilePath model_file_path_;
  scoped_refptr<base::SequencedTaskRunner> execution_task_runner_;
  std::unique_ptr<PreloadingModelExecutor> model_executor_;
};

// Loads the test model on the execution sequence, runs a single inference on
// an all-zeros feature vector, and checks that the executor produced an output.
TEST_F(PreloadingModelExecutorTest, ExecuteModel) {
  // Number of input features fed to the model.
  // NOTE(review): presumably this must match the input tensor size of
  // preloading_heuristics.tflite; confirm.
  constexpr size_t kNumInputFeatures = 17;

  // Load the model file on the execution sequence.
  execution_task_runner_->PostTask(
      FROM_HERE,
      base::BindOnce(
          &optimization_guide::ModelExecutor<ModelOutput,
                                             ModelInput>::UpdateModelFile,
          model_executor_->GetWeakPtrForExecutionThread(), model_file_path_));

  // The callback only records the result and quits the run loop; assertions
  // are made afterwards. A failing ASSERT_* inside the callback would just
  // return from the lambda, skip the quit, and leave Run() hanging until the
  // test times out instead of failing.
  base::RunLoop run_loop;
  std::optional<ModelOutput> model_output;
  base::OnceCallback<void(const std::optional<ModelOutput>&)>
      execution_callback = base::BindOnce(
          [](base::OnceClosure quit_closure,
             std::optional<ModelOutput>* result,
             const std::optional<ModelOutput>& output) {
            *result = output;
            std::move(quit_closure).Run();
          },
          run_loop.QuitClosure(), &model_output);

  const base::TimeTicks start_time = base::TimeTicks::Now();
  ModelInput input =
      std::vector<float>(/*count=*/kNumInputFeatures, /*value=*/0.0f);
  execution_task_runner_->PostTask(
      FROM_HERE,
      base::BindOnce(&optimization_guide::ModelExecutor<
                         ModelOutput, ModelInput>::SendForExecution,
                     model_executor_->GetWeakPtrForExecutionThread(),
                     std::move(execution_callback), start_time, input));
  run_loop.Run();

  // TODO(isaboori): After the trained model is approved, use realistic inputs
  // and check the output value.
  ASSERT_TRUE(model_output.has_value());
}