1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
|
// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromeos/ash/components/string_matching/fuzzy_tokenized_string_match.h"
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iterator>
#include <optional>
#include <set>
#include <string>
#include <vector>
#include "base/i18n/case_conversion.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "chromeos/ash/components/string_matching/acronym_matcher.h"
#include "chromeos/ash/components/string_matching/diacritic_utils.h"
#include "chromeos/ash/components/string_matching/prefix_matcher.h"
#include "chromeos/ash/components/string_matching/sequence_matcher.h"
namespace ash::string_matching {
namespace {
using Hits = FuzzyTokenizedStringMatch::Hits;
// Penalty base used by PartialRatio(): a candidate's score is multiplied by
// this rate once for every character between the start of the compared window
// and the preceding space (or the beginning of the string).
constexpr double kPartialMatchPenaltyRate = 0.9;
// Score returned by PartialRatio() when either input is empty. Relevance()
// also requires the prefix score to exceed it before applying the short-query
// boost.
constexpr double kMinScore = 0.0;
// Score returned by PartialRatio() as soon as a candidate exceeds 0.995.
constexpr double kMaxScore = 1.0;
// The maximum supported size for a prefix matching scoring boost.
constexpr size_t kMaxBoostSize = 2;
// The scale ratio for non exact matching results.
constexpr double kNonExactMatchScaleRatio = 0.97;
// Copies the tokens of |text| into a new vector and returns them sorted in
// ascending order (std::u16string's operator<).
std::vector<std::u16string> ProcessAndSort(const TokenizedString& text) {
  const auto& tokens = text.tokens();
  std::vector<std::u16string> sorted_tokens(tokens.begin(), tokens.end());
  std::sort(sorted_tokens.begin(), sorted_tokens.end());
  return sorted_tokens;
}
// Maps |relevance| to 1 - 0.5^relevance: a relevance of 0 yields 0, and the
// result approaches (but never reaches) 1 as |relevance| grows.
double ScaledRelevance(const double relevance) {
  const double decay = std::pow(0.5, relevance);
  return 1.0 - decay;
}
} // namespace
// Out-of-line definitions of the defaulted destructor and default constructor.
FuzzyTokenizedStringMatch::~FuzzyTokenizedStringMatch() = default;
FuzzyTokenizedStringMatch::FuzzyTokenizedStringMatch() = default;
// Compares |query| and |text| as sets of tokens, so duplicate tokens collapse
// and token order is ignored. Three strings are derived from the two sets:
//   - the tokens common to both sides, joined by single spaces;
//   - the common tokens followed by the tokens found only in |query|;
//   - the common tokens followed by the tokens found only in |text|.
// Returns the highest pairwise score among those strings, computed with
// PartialRatio() when |partial| is true and SequenceMatcher::Ratio() otherwise.
//
// NOTE(review): when there are common tokens but one side contributes no
// extra tokens, its rewritten string appears to end with the separator space
// (assuming joining an empty list yields an empty string) - confirm this is
// intended.
double FuzzyTokenizedStringMatch::TokenSetRatio(const TokenizedString& query,
                                                const TokenizedString& text,
                                                bool partial) {
  // std::set de-duplicates and orders the tokens, which is what the set
  // algorithms below require of their input ranges.
  const std::set<std::u16string> query_set(query.tokens().begin(),
                                           query.tokens().end());
  const std::set<std::u16string> text_set(text.tokens().begin(),
                                          text.tokens().end());

  std::vector<std::u16string> common;
  std::set_intersection(query_set.begin(), query_set.end(), text_set.begin(),
                        text_set.end(), std::back_inserter(common));

  std::vector<std::u16string> only_in_query;
  std::set_difference(query_set.begin(), query_set.end(), text_set.begin(),
                      text_set.end(), std::back_inserter(only_in_query));

  std::vector<std::u16string> only_in_text;
  std::set_difference(text_set.begin(), text_set.end(), query_set.begin(),
                      query_set.end(), std::back_inserter(only_in_text));

  const std::u16string common_joined = base::JoinString(common, u" ");

  // Joins |extras| and prefixes them with the common tokens. Without common
  // tokens the joined extras are used alone, so no leading space is produced.
  const auto with_common_prefix =
      [&](const std::vector<std::u16string>& extras) -> std::u16string {
    std::u16string extras_joined = base::JoinString(extras, u" ");
    if (common.empty()) {
      return extras_joined;
    }
    return base::StrCat({common_joined, u" ", extras_joined});
  };

  const std::u16string query_rewritten = with_common_prefix(only_in_query);
  const std::u16string text_rewritten = with_common_prefix(only_in_text);

  if (!partial) {
    return std::max(
        {SequenceMatcher(common_joined, query_rewritten).Ratio(),
         SequenceMatcher(common_joined, text_rewritten).Ratio(),
         SequenceMatcher(query_rewritten, text_rewritten).Ratio()});
  }
  return std::max({PartialRatio(common_joined, query_rewritten),
                   PartialRatio(common_joined, text_rewritten),
                   PartialRatio(query_rewritten, text_rewritten)});
}
// Scores |query| against |text| after sorting each side's tokens and joining
// them with single spaces, so the original token order does not affect the
// result. Uses PartialRatio() when |partial| is true and
// SequenceMatcher::Ratio() otherwise.
double FuzzyTokenizedStringMatch::TokenSortRatio(const TokenizedString& query,
                                                 const TokenizedString& text,
                                                 bool partial) {
  const std::u16string sorted_query =
      base::JoinString(ProcessAndSort(query), u" ");
  const std::u16string sorted_text =
      base::JoinString(ProcessAndSort(text), u" ");
  return partial ? PartialRatio(sorted_query, sorted_text)
                 : SequenceMatcher(sorted_query, sorted_text).Ratio();
}
// Scores how well the shorter of |query| / |text| matches some window of the
// longer one.
//
// For every matching block reported by SequenceMatcher, the shorter string is
// compared against the window of the longer string that lines up with that
// block. Each window score is multiplied by kPartialMatchPenaltyRate once per
// character separating the window start from the preceding space (or the
// start of the string). The best penalized score is returned; kMinScore is
// returned if either input is empty, and kMaxScore as soon as a candidate
// exceeds 0.995.
double FuzzyTokenizedStringMatch::PartialRatio(const std::u16string& query,
                                               const std::u16string& text) {
  if (query.empty() || text.empty()) {
    return kMinScore;
  }

  // Bind references instead of copying the inputs. On a size tie |query| is
  // treated as the shorter string.
  const bool query_is_longer = query.size() > text.size();
  const std::u16string& shorter = query_is_longer ? text : query;
  const std::u16string& longer = query_is_longer ? query : text;

  const auto matching_blocks =
      SequenceMatcher(shorter, longer).GetMatchingBlocks();
  double partial_ratio = 0;
  for (const auto& block : matching_blocks) {
    // Start of the window in |longer| that aligns the block with |shorter|,
    // clamped to the beginning of |longer|.
    const int long_start =
        block.pos_second_string > block.pos_first_string
            ? block.pos_second_string - block.pos_first_string
            : 0;

    // Penalizes the match if it is not close to the beginning of a token:
    // walk back to the closest preceding space (-1 if there is none). A space
    // has no letter case, so a direct comparison is sufficient.
    int current = long_start - 1;
    while (current >= 0 && longer[current] != u' ') {
      --current;
    }
    const double penalty =
        std::pow(kPartialMatchPenaltyRate, long_start - current - 1);

    // TODO(crbug.com/40638914): currently this part recalculates the ratio
    // for every pair. Improve this to reduce latency.
    partial_ratio = std::max(
        partial_ratio,
        SequenceMatcher(shorter, longer.substr(long_start, shorter.size()))
                .Ratio() *
            penalty);

    // Close enough to a perfect match; no other window can do better.
    if (partial_ratio > 0.995) {
      return kMaxScore;
    }
  }
  return partial_ratio;
}
// Runs several fuzzy comparisons between |query| and |text| and returns the
// highest weighted score:
//   - a SequenceMatcher ratio over the space-joined tokens (unscaled);
//   - PartialRatio() over the same strings, only when their lengths differ by
//     a factor of at least 1.5;
//   - TokenSortRatio() and TokenSetRatio(), both scaled by 0.95.
// When partial matching is in use, every score except the first is further
// scaled by 0.9, or by 0.6 if one string is more than 8 times longer than the
// other.
double FuzzyTokenizedStringMatch::WeightedRatio(const TokenizedString& query,
                                                const TokenizedString& text) {
  // All token based comparisons are scaled by 0.95 (on top of any partial
  // scalars), as per original implementation:
  // https://github.com/seatgeek/fuzzywuzzy/blob/af443f918eebbccff840b86fa606ac150563f466/fuzzywuzzy/fuzz.py#L245
  const double unbase_scale = 0.95;

  // query.text() and text.text() are not normalized, so the compared strings
  // are rebuilt from query.tokens() and text.tokens() instead.
  const std::u16string query_normalized(base::JoinString(query.tokens(), u" "));
  const std::u16string text_normalized(base::JoinString(text.tokens(), u" "));

  std::vector<double> weighted_ratios;
  weighted_ratios.emplace_back(
      SequenceMatcher(query_normalized, text_normalized)
          .Ratio(/*text_length_agnostic=*/true));

  const size_t longer_size =
      std::max(query_normalized.size(), text_normalized.size());
  const size_t shorter_size =
      std::min(query_normalized.size(), text_normalized.size());

  // Never divide by an empty string's length: floating-point division by zero
  // is undefined behavior in C++. The special cases below keep the outcomes
  // IEEE-754 arithmetic used to produce: exactly one empty side behaves like
  // an infinite ratio, two empty sides behave like "no size difference".
  const bool one_side_empty = shorter_size == 0 && longer_size > 0;
  const double length_ratio =
      shorter_size == 0 ? 1.0
                        : static_cast<double>(longer_size) / shorter_size;

  // Use partial if two strings are quite different in sizes.
  const bool use_partial = one_side_empty || length_ratio >= 1.5;
  double length_ratio_scale = 1;
  if (use_partial) {
    // TODO(crbug.com/1336160): Consider scaling |length_ratio_scale| smoothly
    // with |length_ratio|, instead of using a step function.
    //
    // If one string is much much shorter than the other, set
    // |length_ratio_scale| to be 0.6, otherwise set it to be 0.9.
    length_ratio_scale = (one_side_empty || length_ratio > 8) ? 0.6 : 0.9;
    weighted_ratios.emplace_back(
        PartialRatio(query_normalized, text_normalized) * length_ratio_scale);
  }

  weighted_ratios.emplace_back(TokenSortRatio(query, text, use_partial) *
                               unbase_scale * length_ratio_scale);

  // Do not use partial match for token set because the match between the
  // intersection string and query/text rewrites will always return an extremely
  // high value.
  weighted_ratios.emplace_back(TokenSetRatio(query, text, false /*partial*/) *
                               unbase_scale * length_ratio_scale);

  // Return the maximum of all included weighted ratios.
  return *std::max_element(weighted_ratios.begin(), weighted_ratios.end());
}
double FuzzyTokenizedStringMatch::PrefixMatcher(const TokenizedString& query,
const TokenizedString& text) {
string_matching::PrefixMatcher match(query, text);
match.Match();
return ScaledRelevance(match.relevance());
}
double FuzzyTokenizedStringMatch::AcronymMatcher(const TokenizedString& query,
const TokenizedString& text) {
string_matching::AcronymMatcher match(query, text);
const double relevance = match.CalculateRelevance();
return ScaledRelevance(relevance);
}
double FuzzyTokenizedStringMatch::PrefixMatcher(
const TokenizedString& query,
const TokenizedString& text,
std::vector<Hits>& hits_vector) {
string_matching::PrefixMatcher match(query, text);
match.Match();
hits_vector.emplace_back(match.hits());
return ScaledRelevance(match.relevance());
}
double FuzzyTokenizedStringMatch::AcronymMatcher(
const TokenizedString& query,
const TokenizedString& text,
std::vector<Hits>& hits_vector) {
string_matching::AcronymMatcher match(query, text);
const double relevance = match.CalculateRelevance();
hits_vector.emplace_back(match.hits());
return ScaledRelevance(relevance);
}
double FuzzyTokenizedStringMatch::Relevance(const TokenizedString& query_input,
const TokenizedString& text_input,
bool use_weighted_ratio,
bool strip_diacritics,
bool use_acronym_matcher) {
// If the query is much longer than the text then it's often not a match.
if (query_input.text().size() >= text_input.text().size() * 2) {
return 0.0;
}
std::optional<TokenizedString> stripped_query;
std::optional<TokenizedString> stripped_text;
if (strip_diacritics) {
stripped_query.emplace(RemoveDiacritics(query_input.text()));
stripped_text.emplace(RemoveDiacritics(text_input.text()));
}
const TokenizedString& query =
strip_diacritics ? stripped_query.value() : query_input;
const TokenizedString& text =
strip_diacritics ? stripped_text.value() : text_input;
// If there is an exact match, relevance will be 1.0 and there is only 1
// hit that is the entire text/query.
const auto& query_text = query.text();
const auto& text_text = text.text();
const auto query_size = query_text.size();
const auto text_size = text_text.size();
if (query_size > 0 && query_size == text_size &&
base::EqualsCaseInsensitiveASCII(query_text, text_text)) {
hits_.emplace_back(0, query_size);
return 1.0;
}
// The |relevances| stores the |relevance_scores| calculated from different
// string matching methods. The highest result among them will be returned.
std::vector<double> relevances;
// The |hits_vector| stores the |hits| calculated from different string
// matching methods. The final selected instance corresponds to the hits
// generated by the matching algorithm which yielded the highest relevance
// score. The final selected instance will be assigned to |hits_| then.
std::vector<Hits> hits_vector;
double prefix_score = PrefixMatcher(query, text, hits_vector);
// A scoring boost for short prefix matching queries.
if (query_size <= kMaxBoostSize && prefix_score > kMinScore) {
prefix_score = std::min(
1.0, prefix_score + 2.0 / (query_size * (query_size + text_size)));
}
relevances.emplace_back(prefix_score);
// Find hits using SequenceMatcher on original query and text.
Hits sequence_hits;
size_t match_size = 0;
for (const auto& match :
SequenceMatcher(query_text, text_text).GetMatchingBlocks()) {
if (match.length > 0) {
match_size += match.length;
sequence_hits.emplace_back(match.pos_second_string,
match.pos_second_string + match.length);
}
}
hits_vector.emplace_back(sequence_hits);
relevances.emplace_back(use_weighted_ratio
? WeightedRatio(query, text)
: SequenceMatcher(base::i18n::ToLower(query_text),
base::i18n::ToLower(text_text))
.Ratio(/*text_length_agnostic=*/true));
if (use_acronym_matcher) {
relevances.emplace_back(AcronymMatcher(query, text, hits_vector));
}
size_t best_match_pos =
std::max_element(relevances.begin(), relevances.end()) -
relevances.begin();
hits_ = hits_vector[best_match_pos];
return match_size == text_size
? relevances[best_match_pos]
: relevances[best_match_pos] * kNonExactMatchScaleRatio;
}
} // namespace ash::string_matching
|