1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
|
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/policy/test_support/policy_storage.h"
#include "base/numerics/byte_conversions.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "base/strings/string_view_util.h"
#include "crypto/hash.h"
namespace policy {
namespace {

// Separator placed between the policy type and the entity id when both are
// part of a policy key.
constexpr char kPolicyKeySeparator[] = "/";

// Builds the storage key for a payload. When `entity_id` is empty the key is
// just `policy_type`; otherwise it is `policy_type` + "/" + `entity_id`.
std::string GetPolicyKey(const std::string& policy_type,
                         const std::string& entity_id) {
  return entity_id.empty()
             ? policy_type
             : base::StrCat({policy_type, kPolicyKeySeparator, entity_id});
}

}  // namespace
PolicyStorage::PolicyStorage()
: signature_provider_(std::make_unique<SignatureProvider>()) {}
PolicyStorage::PolicyStorage(PolicyStorage&& policy_storage) = default;
PolicyStorage& PolicyStorage::operator=(PolicyStorage&& policy_storage) =
default;
PolicyStorage::~PolicyStorage() = default;
// Looks up the payload stored for (`policy_type`, `entity_id`). Returns an
// empty string when nothing has been stored under that key.
std::string PolicyStorage::GetPolicyPayload(
    const std::string& policy_type,
    const std::string& entity_id) const {
  const auto entry =
      policy_payloads_.find(GetPolicyKey(policy_type, entity_id));
  if (entry == policy_payloads_.end())
    return std::string();
  return entry->second;
}
// Returns the entity ids of every stored payload whose key has the form
// "<policy_type>/<entity_id>", in the iteration order of `policy_payloads_`.
// A payload stored without an entity id (key == `policy_type`) is not
// included, because it is shorter than the "<policy_type>/" prefix.
// NOTE(review): if another policy type itself starts with "<policy_type>/",
// its keys also match and the remainder after the prefix is returned as an id.
std::vector<std::string> PolicyStorage::GetEntityIdsForType(
    const std::string& policy_type) {
  const std::string type_prefix =
      base::StrCat({policy_type, kPolicyKeySeparator});
  std::vector<std::string> entity_ids;
  for (const auto& entry : policy_payloads_) {
    const std::string& key = entry.first;
    if (!base::StartsWith(key, type_prefix))
      continue;
    entity_ids.push_back(key.substr(type_prefix.size()));
  }
  return entity_ids;
}
// Stores `policy_payload` for `policy_type` with no entity id, i.e. under the
// bare policy type key. Delegates to the three-argument overload.
void PolicyStorage::SetPolicyPayload(const std::string& policy_type,
                                     const std::string& policy_payload) {
  const std::string no_entity_id;
  SetPolicyPayload(policy_type, no_entity_id, policy_payload);
}
// Stores (or overwrites) the payload for (`policy_type`, `entity_id`).
void PolicyStorage::SetPolicyPayload(const std::string& policy_type,
                                     const std::string& entity_id,
                                     const std::string& policy_payload) {
  policy_payloads_.insert_or_assign(GetPolicyKey(policy_type, entity_id),
                                    policy_payload);
}
// Looks up the external policy payload stored for (`policy_type`,
// `entity_id`). Returns an empty string when nothing has been stored under
// that key.
std::string PolicyStorage::GetExternalPolicyPayload(
    const std::string& policy_type,
    const std::string& entity_id) {
  const std::string policy_key = GetPolicyKey(policy_type, entity_id);
  // A single find() replaces the previous contains() + at() pair, which
  // searched the map twice. This matches GetPolicyPayload().
  auto it = external_policy_payloads_.find(policy_key);
  return it == external_policy_payloads_.end() ? std::string() : it->second;
}
// Stores (or overwrites) the external policy payload for (`policy_type`,
// `entity_id`).
void PolicyStorage::SetExternalPolicyPayload(
    const std::string& policy_type,
    const std::string& entity_id,
    const std::string& policy_payload) {
  external_policy_payloads_.insert_or_assign(
      GetPolicyKey(policy_type, entity_id), policy_payload);
}
// Records `psm_entry` for `brand_serial_id`, replacing any previous entry.
void PolicyStorage::SetPsmEntry(const std::string& brand_serial_id,
                                const PolicyStorage::PsmEntry& psm_entry) {
  auto& slot = psm_entries_[brand_serial_id];
  slot = psm_entry;
}
// Returns a pointer to the PSM entry recorded for `brand_serial_id`, or
// nullptr if none was set. The pointer refers to storage inside
// `psm_entries_`; it is not owned by the caller.
const PolicyStorage::PsmEntry* PolicyStorage::GetPsmEntry(
    const std::string& brand_serial_id) const {
  const auto entry = psm_entries_.find(brand_serial_id);
  return entry != psm_entries_.end() ? &entry->second : nullptr;
}
// Records `initial_enrollment_state` for `brand_serial_id`, replacing any
// previous state.
void PolicyStorage::SetInitialEnrollmentState(
    const std::string& brand_serial_id,
    const PolicyStorage::InitialEnrollmentState& initial_enrollment_state) {
  auto& slot = initial_enrollment_states_[brand_serial_id];
  slot = initial_enrollment_state;
}
// Returns a pointer to the initial enrollment state recorded for
// `brand_serial_id`, or nullptr if none was set. The pointer refers to
// storage inside `initial_enrollment_states_`; it is not owned by the caller.
const PolicyStorage::InitialEnrollmentState*
PolicyStorage::GetInitialEnrollmentState(
    const std::string& brand_serial_id) const {
  const auto entry = initial_enrollment_states_.find(brand_serial_id);
  return entry != initial_enrollment_states_.end() ? &entry->second : nullptr;
}
// For every serial with a recorded initial enrollment state, computes
// SHA-256(serial) and reads its first 8 bytes as a big-endian uint64. Each
// serial whose value satisfies `value % modulus == remainder` contributes
// those 8 raw (binary, not hex-encoded) bytes to the result.
//
// A `modulus` of 0 matches nothing and yields an empty result. Without this
// guard the modulo would be an integer division by zero, which is undefined
// behavior.
std::vector<std::string> PolicyStorage::GetMatchingSerialHashes(
    uint64_t modulus,
    uint64_t remainder) const {
  std::vector<std::string> hashes;
  if (modulus == 0)
    return hashes;
  for (const auto& [serial, enrollment_state] : initial_enrollment_states_) {
    auto hash = crypto::hash::Sha256(serial);
    auto hash_first8 = base::span<uint8_t>(hash).first<8>();
    if (base::U64FromBigEndian(hash_first8) % modulus == remainder) {
      hashes.emplace_back(base::as_string_view(hash_first8));
    }
  }
  return hashes;
}
} // namespace policy
|