// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/translate/core/language_detection/language_detection_model.h"

#include <utility>

#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/metrics/histogram_macros_local.h"
#include "base/metrics/metrics_hashes.h"
#include "base/strings/utf_string_conversions.h"
#include "base/timer/elapsed_timer.h"
#include "base/trace_event/trace_event.h"
#include "build/build_config.h"
#include "components/language/core/common/language_util.h"
#include "components/language_detection/core/constants.h"
#include "components/language_detection/core/language_detection_model.h"
#include "components/translate/core/common/translate_util.h"
#include "components/translate/core/language_detection/language_detection_util.h"
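
namespace translate {

// Wraps a model that is owned elsewhere. The caller must keep
// |shared_tflite_model| alive for as long as this object uses it.
LanguageDetectionModel::LanguageDetectionModel(
    language_detection::LanguageDetectionModel& shared_tflite_model)
    : tflite_model_(shared_tflite_model) {}

// Takes ownership of the model. |tflite_model_| is bound to the owned
// instance, which relies on |owned_tflite_model_| being declared (and
// therefore initialized) before |tflite_model_|.
LanguageDetectionModel::LanguageDetectionModel(
    std::unique_ptr<language_detection::LanguageDetectionModel>
        owned_tflite_model)
    : owned_tflite_model_(std::move(owned_tflite_model)),
      tflite_model_(*owned_tflite_model_) {
  // Detach the owned model from the construction sequence so that it can be
  // used on another sequence.
  owned_tflite_model_->DetachFromSequence();
}

LanguageDetectionModel::~LanguageDetectionModel() = default;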

void LanguageDetectionModel::UpdateWithFile(base::File model_file) {
  tflite_model_->UpdateWithFile(std::move(model_file));
}

void LanguageDetectionModel::UpdateWithFileAsync(base::File model_file,
                                                 base::OnceClosure callback) {
  tflite_model_->UpdateWithFileAsync(std::move(model_file),
                                     std::move(callback));
}

bool LanguageDetectionModel::IsAvailable() const {
  return tflite_model_->IsAvailable();
}

std::string LanguageDetectionModel::DeterminePageLanguage(
    const std::string& code,
    const std::string& html_lang,
    const std::u16string& contents,
    std::string* predicted_language,
    bool* is_prediction_reliable,
    float& prediction_reliability_score) const {
  DCHECK(IsAvailable());
  // Bail out if the caller did not supply the required out-parameters.
  if (!predicted_language || !is_prediction_reliable) {
    return language_detection::kUnknownLanguageCode;
  }

  // Start from an "unknown, unreliable" result so that every early return
  // leaves the out-parameters in a defined state.
  *is_prediction_reliable = false;
  *predicted_language = language_detection::kUnknownLanguageCode;
  prediction_reliability_score = 0.0f;

  // The DCHECK above is compiled out in release builds, so handle an
  // unavailable model explicitly.
  if (!tflite_model_->IsAvailable()) {
    return language_detection::kUnknownLanguageCode;
  }

  const language_detection::Prediction prediction = DetectLanguage(contents);
  prediction_reliability_score = prediction.score;

  // TODO(crbug.com/40748826): Use the model threshold provided
  // by the model itself. Not needed until threshold is finalized.
  bool is_reliable =
      prediction_reliability_score > GetTFLiteLanguageDetectionThreshold();
  std::string final_prediction = translate::FilterDetectedLanguage(
      base::UTF16ToUTF8(contents), prediction.language, is_reliable);

  // Report the filtered prediction before it is mapped to its Translate
  // synonym below.
  *predicted_language = final_prediction;
  *is_prediction_reliable = is_reliable;
  language::ToTranslateLanguageSynonym(&final_prediction);

  LOCAL_HISTOGRAM_BOOLEAN("LanguageDetection.TFLite.DidAttemptDetection", true);
  return translate::DeterminePageLanguage(code, html_lang, final_prediction,
                                          is_reliable);
}

language_detection::Prediction LanguageDetectionModel::DetectLanguage(
    const std::u16string& contents) const {
  base::ElapsedTimer timer;
  auto prediction = tflite_model_->PredictTopLanguageWithSamples(contents);
  // Record how long inference took and how much text (in UTF-16 code units)
  // was classified.
  base::UmaHistogramTimes(
      "LanguageDetection.TFLiteModel.DetectPageLanguage.Duration",
      timer.Elapsed());
  base::UmaHistogramCounts1M(
      "LanguageDetection.TFLiteModel.DetectPageLanguage.Size", contents.size());
  return prediction;
}

std::string LanguageDetectionModel::GetModelVersion() const {
  return tflite_model_->GetModelVersion();
}

}  // namespace translate