// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/optimization_guide/core/model_execution/safety_client.h"

#include <memory>
#include <utility>

#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/task/thread_pool.h"
#include "base/types/expected.h"
#include "components/optimization_guide/core/optimization_guide_features.h"

namespace optimization_guide {

SafetyClient::SafetyClient(
    base::WeakPtr<on_device_model::ServiceClient> service_client)
    : service_client_(std::move(service_client)) {}

SafetyClient::~SafetyClient() = default;

void SafetyClient::SetLanguageDetectionModel(
    base::optional_ref<const ModelInfo> model_info) {
  if (!model_info.has_value()) {
    language_detection_model_path_.reset();
    return;
  }
  remote_.reset();  // The remote's assets are outdated.
  language_detection_model_path_ = model_info->GetModelFilePath();
}

void SafetyClient::MaybeUpdateSafetyModel(
    base::optional_ref<const ModelInfo> model_info) {
  if (safety_model_info_ && model_info &&
      safety_model_info_->GetVersion() == model_info->GetVersion()) {
    // Duplicate update notifications are expected, because this object can
    // receive model updates from multiple profiles.
    return;
  }

  // A new safety model means new configs, so fail existing sessions.
  weak_ptr_factory_.InvalidateWeakPtrs();

  auto new_info = SafetyModelInfo::Load(model_info);
  if (!new_info) {
    safety_model_info_.reset();
    return;
  }
  remote_.reset();  // The remote's assets are outdated.
  safety_model_info_ = std::move(new_info);
}
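
// Illustrative sketch only (hypothetical names, not used by SafetyClient):
// the update rules of MaybeUpdateSafetyModel() restated in standard C++.
// `FakeModelInfo`, `Session` and `Updater` are invented for this example, and
// a generation counter stands in for base::WeakPtrFactory: bumping it plays
// the role of InvalidateWeakPtrs(), so sessions issued earlier test as stale.
// Loading the model files (SafetyModelInfo::Load) is not modeled here.
namespace safety_model_update_sketch {

struct FakeModelInfo {
  int64_t version;
};

// A session remembers the generation it was created in.
struct Session {
  uint64_t generation;
};

class Updater {
 public:
  // Returns false only for a duplicate notification (same version already
  // loaded), which leaves existing sessions untouched.
  bool MaybeUpdate(std::optional<FakeModelInfo> info) {
    if (current_ && info && current_->version == info->version) {
      return false;
    }
    // Any other update invalidates sessions bound to the previous config.
    ++generation_;
    current_ = info;  // std::nullopt clears the model, like reset().
    return true;
  }

  Session StartSession() const { return Session{generation_}; }

  bool IsValid(const Session& session) const {
    return session.generation == generation_;
  }

  bool HasModel() const { return current_.has_value(); }

 private:
  std::optional<FakeModelInfo> current_;
  uint64_t generation_ = 0;
};

}  // namespace safety_model_update_sketch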

base::expected<std::unique_ptr<SafetyChecker>, OnDeviceModelEligibilityReason>
SafetyClient::MakeSafetyChecker(ModelBasedCapabilityKey feature,
                                bool can_skip) {
  if (!features::ShouldUseTextSafetyClassifierModel() || can_skip) {
    // Construct a dummy checker that always passes all checks.
    return std::make_unique<SafetyChecker>(nullptr, SafetyConfig());
  }
  if (!safety_model_info_) {
    return base::unexpected(
        OnDeviceModelEligibilityReason::kSafetyModelNotAvailable);
  }
  auto config =
      safety_model_info_->GetConfig(ToModelExecutionFeatureProto(feature));
  if (!config) {
    return base::unexpected(
        OnDeviceModelEligibilityReason::kSafetyConfigNotAvailableForFeature);
  }
  if (!config->allowed_languages().empty() && !language_detection_model_path_) {
    return base::unexpected(
        OnDeviceModelEligibilityReason::kLanguageDetectionModelNotAvailable);
  }
  return std::make_unique<SafetyChecker>(weak_ptr_factory_.GetWeakPtr(),
                                         SafetyConfig(*config));
}
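
// Illustrative sketch only (hypothetical names, not used by SafetyClient):
// the order of checks in MakeSafetyChecker() as a pure function over
// booleans. `Inputs`, `Outcome` and `Decide` are invented for this example;
// a single enum stands in for base::expected, and its error values mirror the
// OnDeviceModelEligibilityReason values returned above. The order matters:
// when the check can be skipped, a missing model is never reported.
namespace safety_checker_decision_sketch {

enum class Outcome {
  kPassThroughChecker,  // Dummy checker that passes every check.
  kConfiguredChecker,   // Checker built from the feature's safety config.
  kSafetyModelNotAvailable,
  kSafetyConfigNotAvailableForFeature,
  kLanguageDetectionModelNotAvailable,
};

struct Inputs {
  bool classifier_enabled;  // ShouldUseTextSafetyClassifierModel()
  bool can_skip;
  bool has_safety_model;
  bool has_feature_config;
  bool restricts_languages;  // !allowed_languages().empty()
  bool has_language_model;
};

inline Outcome Decide(const Inputs& in) {
  if (!in.classifier_enabled || in.can_skip) {
    return Outcome::kPassThroughChecker;
  }
  if (!in.has_safety_model) {
    return Outcome::kSafetyModelNotAvailable;
  }
  if (!in.has_feature_config) {
    return Outcome::kSafetyConfigNotAvailableForFeature;
  }
  if (in.restricts_languages && !in.has_language_model) {
    return Outcome::kLanguageDetectionModelNotAvailable;
  }
  return Outcome::kConfiguredChecker;
}

}  // namespace safety_checker_decision_sketch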

void SafetyClient::StartSession(
    mojo::PendingReceiver<on_device_model::mojom::TextSafetySession> session) {
  GetTextSafetyModelRemote()->StartSession(std::move(session));
}

on_device_model::TextSafetyLoaderParams SafetyClient::LoaderParams() const {
  on_device_model::TextSafetyLoaderParams params;
  // Populate the model paths even if they are not needed for the current
  // feature, since the base model remote could be used for subsequent features.
  if (safety_model_info_) {
    params.ts_paths.emplace();
    params.ts_paths->data = safety_model_info_->GetDataPath();
    params.ts_paths->sp_model = safety_model_info_->GetSpModelPath();
  }
  if (language_detection_model_path_) {
    params.language_paths.emplace();
    params.language_paths->model = *language_detection_model_path_;
  }
  return params;
}

SafetyClient::Remote& SafetyClient::GetTextSafetyModelRemote() {
  if (remote_) {
    return remote_;
  }
  base::ThreadPool::PostTaskAndReplyWithResult(
      FROM_HERE, {base::MayBlock()},
      base::BindOnce(&on_device_model::LoadTextSafetyParams, LoaderParams()),
      base::BindOnce(
          [](base::WeakPtr<SafetyClient> self,
             mojo::PendingReceiver<on_device_model::mojom::TextSafetyModel>
                 model,
             on_device_model::mojom::TextSafetyModelParamsPtr params) {
            if (!self || !self->service_client_) {
              // Nothing can use the loaded params; close the files on a
              // background thread rather than on this one.
              base::ThreadPool::PostTask(
                  FROM_HERE, {base::MayBlock()},
                  base::DoNothingWithBoundArgs(std::move(params)));
              return;
            }
            self->service_client_->Get()->LoadTextSafetyModel(
                std::move(params), std::move(model));
          },
          weak_ptr_factory_.GetWeakPtr(),
          remote_.BindNewPipeAndPassReceiver()));
  remote_.reset_on_disconnect();  // Maybe track disconnects?
  remote_.reset_on_idle_timeout(features::GetOnDeviceModelIdleTimeout());
  return remote_;
}
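
// Illustrative sketch only (hypothetical names, not Mojo or //base code): the
// shape of GetTextSafetyModelRemote() in standard C++. The handle is created
// lazily and cached, and exactly one asynchronous load is queued per binding.
// std::weak_ptr stands in for base::WeakPtr, and a vector of queued replies
// stands in for the thread pool. If the owner is destroyed before a reply
// runs, the payload is dropped, as the reply lambda above does when `self`
// is null.
namespace lazy_remote_sketch {

class Client;

// One queued reply: who asked for the load, and what was loaded.
struct PendingReply {
  std::weak_ptr<Client> owner;
  std::string data_path;
};

class Client : public std::enable_shared_from_this<Client> {
 public:
  explicit Client(std::vector<PendingReply>* replies) : replies_(replies) {}

  // Returns the cached handle; binds a new one and queues a load only when
  // no handle is bound.
  int GetRemote() {
    if (remote_id_ != 0) {
      return remote_id_;
    }
    remote_id_ = ++next_id_;
    replies_->push_back(PendingReply{weak_from_this(), "model.data"});
    return remote_id_;
  }

  // Analogous to remote_.reset() when the model assets change.
  void ResetRemote() { remote_id_ = 0; }

  void OnLoaded(const std::string& data_path) { loaded_.push_back(data_path); }
  std::size_t loads_delivered() const { return loaded_.size(); }

 private:
  std::vector<PendingReply>* replies_;
  std::vector<std::string> loaded_;
  int remote_id_ = 0;  // 0 means "not bound".
  int next_id_ = 0;
};

// Runs one queued reply. Returns false if the owner is gone and the payload
// was dropped.
inline bool RunReply(const PendingReply& reply) {
  std::shared_ptr<Client> self = reply.owner.lock();
  if (!self) {
    return false;
  }
  self->OnLoaded(reply.data_path);
  return true;
}

}  // namespace lazy_remote_sketch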

}  // namespace optimization_guide