1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
|
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/data_collection/training_data_cache.h"
#include "base/time/time.h"
namespace segmentation_platform {
using TrainingData = proto::TrainingData;
TrainingDataCache::TrainingDataCache(SegmentInfoDatabase* segment_info_database)
: segment_info_database_(segment_info_database) {}
TrainingDataCache::~TrainingDataCache() = default;
// Stores the training inputs in |data| for (|segment_id|, |model_source|).
//
// When |save_to_db| is true, the data is handed to the segment info database
// with a no-op completion callback. Otherwise it is kept in the in-memory
// cache, keyed by the request id carried inside |data|, and overwrites any
// entry already cached under that id.
//
// |data| is borrowed (const&), so it is copied into storage on both paths.
void TrainingDataCache::StoreInputs(SegmentId segment_id,
                                    ModelSource model_source,
                                    const TrainingData& data,
                                    bool save_to_db) {
  if (save_to_db) {
    // TODO(ritikagup): Add handling for default models, if required.
    segment_info_database_->SaveTrainingData(segment_id, model_source, data,
                                             base::DoNothing());
  } else {
    cache[std::make_pair(segment_id, model_source)]
         [TrainingRequestId::FromUnsafeValue(data.request_id())] = data;
  }
}
// Retrieves the training inputs stored for |request_id| under
// (|segment_id|, |model_source|) and removes them from storage.
//
// On an in-memory cache hit, the entry is taken out of the cache and
// |callback| is run before this function returns. On a miss, the lookup is
// forwarded to the segment info database with deletion enabled, and
// |callback| is passed along to that call.
void TrainingDataCache::GetInputsAndDelete(SegmentId segment_id,
                                           ModelSource model_source,
                                           TrainingRequestId request_id,
                                           TrainingDataCallback callback) {
  // Resolve the per-segment map and the entry with a single lookup each,
  // instead of re-building the key and re-searching for every check.
  auto segment_it = cache.find(std::make_pair(segment_id, model_source));
  if (segment_it != cache.end()) {
    auto& segment_data = segment_it->second;
    auto data_it = segment_data.find(request_id);
    if (data_it != segment_data.end()) {
      // TrainingRequestId found in cache: take the data out and erase the
      // entry before running the callback, so the callback never sees a
      // half-consumed cache entry.
      absl::optional<TrainingData> result = std::move(data_it->second);
      segment_data.erase(data_it);
      // Hand the data over without copying the proto again.
      std::move(callback).Run(std::move(result));
      return;
    }
  }

  // Not cached in memory: ask the database, deleting the stored copy there.
  segment_info_database_->GetTrainingData(segment_id, model_source, request_id,
                                          /*delete_from_db=*/true,
                                          std::move(callback));
}
// Returns the id of a training request currently held in the in-memory cache
// for (|segment_id|, |model_source|), or absl::nullopt when none is cached.
// When several are cached, the first key in the per-segment map's iteration
// order is returned. The database is not consulted.
absl::optional<TrainingRequestId> TrainingDataCache::GetRequestId(
    SegmentId segment_id,
    ModelSource model_source) {
  // TODO(haileywang): Add a metric to record how many request at a given time
  // every time this function is triggered.
  auto it = cache.find(std::make_pair(segment_id, model_source));
  if (it == cache.end() || it->second.empty()) {
    return absl::nullopt;
  }
  return it->second.begin()->first;
}
// Delegates to |request_id_generator| and returns the TrainingRequestId it
// yields next.
TrainingRequestId TrainingDataCache::GenerateNextId() {
  TrainingRequestId next_id = request_id_generator.GenerateNextId();
  return next_id;
}
} // namespace segmentation_platform
|