1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
|
/* SPDX-License-Identifier: GPL-3.0-or-later
* Copyright © 2022 The TokTok team.
*/
#include "shared_key_cache.h"
#include <stdint.h>
#include <string.h> // memcpy(...)
#include "attributes.h"
#include "ccompat.h"
#include "crypto_core.h"
#include "logger.h"
#include "mem.h"
#include "mono_time.h"
/** One cache entry: a peer public key together with its precomputed shared key. */
typedef struct Shared_Key {
    /** Peer public key this entry was computed for (compared with pk_equal on lookup). */
    uint8_t public_key[CRYPTO_PUBLIC_KEY_SIZE];
    /** Result of encrypt_precompute(public_key, self_secret_key, ...). Secret material. */
    uint8_t shared_key[CRYPTO_SHARED_KEY_SIZE];
    /**
     * mono_time_get() value of the last lookup that hit or inserted this entry.
     * The value 0 marks the slot as empty (see shared_key_is_empty).
     */
    uint64_t time_last_requested;
} Shared_Key;
/**
 * Fixed-size shared key cache.
 *
 * `keys` holds 256 buckets (indexed by one byte of the peer public key), each
 * containing `keys_per_slot` consecutive entries.
 */
struct Shared_Key_Cache {
    /** 256 * keys_per_slot entries; owned by the cache and released in shared_key_cache_free. */
    Shared_Key *keys;
    /** Our secret key. Only the pointer is stored (not copied, not freed here), so the caller keeps it alive. */
    const uint8_t *self_secret_key;
    uint64_t timeout; /** After this time (in seconds), a key is erased on the next housekeeping cycle */
    /** Borrowed time source; never freed by the cache. */
    const Mono_Time *mono_time;
    /** Borrowed allocator used for the cache object and `keys`. */
    const Memory *mem;
    /** Borrowed logger used for assertions. */
    const Logger *log;
    /** Number of entries in each of the 256 buckets; non-zero. */
    uint8_t keys_per_slot;
};
/**
 * Tell whether a cache slot is free.
 *
 * A zero `time_last_requested` is the "free" marker: shared_key_cache_new refuses
 * to create a cache while mono time still reads 0, so a populated slot never
 * carries that value. A slot wiped with crypto_memzero is therefore free too,
 * which keeps clearing and emptiness the same operation.
 */
non_null()
static bool shared_key_is_empty(const Logger *log, const Shared_Key *k)
{
    LOGGER_ASSERT(log, k != nullptr, "shared key must not be NULL");

    const uint64_t last_requested = k->time_last_requested;
    const bool is_free = (last_requested == UINT64_C(0));
    return is_free;
}
/**
 * Release a cache slot by wiping it completely (public key, shared key and
 * timestamp). Afterwards shared_key_is_empty() reports the slot as free.
 */
non_null()
static void shared_key_set_empty(const Logger *log, Shared_Key *k)
{
    // The slot contains secret key material, hence crypto_memzero rather than memset.
    const size_t slot_bytes = sizeof(*k);
    crypto_memzero(k, slot_bytes);

    LOGGER_ASSERT(log, shared_key_is_empty(log, k), "shared key must be empty after clearing it");
}
Shared_Key_Cache *shared_key_cache_new(const Logger *log, const Mono_Time *mono_time, const Memory *mem, const uint8_t *self_secret_key, uint64_t timeout, uint8_t keys_per_slot)
{
if (mono_time == nullptr || self_secret_key == nullptr || timeout == 0 || keys_per_slot == 0) {
return nullptr;
}
// Time must not be zero, since we use that as special value for empty slots
if (mono_time_get(mono_time) == 0) {
// Fail loudly in debug environments
LOGGER_FATAL(log, "time must not be zero (mono_time not initialised?)");
return nullptr;
}
Shared_Key_Cache *res = (Shared_Key_Cache *)mem_alloc(mem, sizeof(Shared_Key_Cache));
if (res == nullptr) {
return nullptr;
}
res->self_secret_key = self_secret_key;
res->mono_time = mono_time;
res->mem = mem;
res->log = log;
res->keys_per_slot = keys_per_slot;
// We take one byte from the public key for each bucket and store keys_per_slot elements there
const size_t cache_size = 256 * keys_per_slot;
Shared_Key *keys = (Shared_Key *)mem_valloc(mem, cache_size, sizeof(Shared_Key));
if (keys == nullptr) {
mem_delete(mem, res);
return nullptr;
}
crypto_memlock(keys, cache_size * sizeof(Shared_Key));
res->keys = keys;
return res;
}
/**
 * Destroy a cache created by shared_key_cache_new().
 *
 * All entries are wiped before their memory is unlocked and returned to the
 * allocator. Passing nullptr is a no-op. The borrowed secret key, mono time,
 * allocator and logger are left untouched.
 */
void shared_key_cache_free(Shared_Key_Cache *cache)
{
    if (cache != nullptr) {
        Shared_Key *const entries = cache->keys;
        const Memory *const allocator = cache->mem;
        const size_t entry_count = 256 * cache->keys_per_slot;
        const size_t entry_bytes = entry_count * sizeof(*entries);

        // Scrub every entry first: they hold shared secrets.
        crypto_memzero(entries, entry_bytes);
        crypto_memunlock(entries, entry_bytes);

        mem_delete(allocator, entries);
        mem_delete(allocator, cache);
    }
}
/**
 * Return the shared key for `public_key`, computing and caching it on a miss.
 *
 * NOTE: On each lookup housekeeping is performed on the selected bucket to
 * evict keys that did timeout.
 *
 * On a miss, the least recently used slot of the bucket is reused (empty slots
 * win, since their timestamp is 0).
 *
 * @param cache      the cache.
 * @param public_key peer public key.
 * @return pointer to the shared key stored inside the cache, or nullptr if it
 *         could not be computed. The buffer belongs to the cache and may be
 *         overwritten by a later lookup that reuses the slot.
 */
const uint8_t *shared_key_cache_lookup(Shared_Key_Cache *cache, const uint8_t public_key[CRYPTO_PUBLIC_KEY_SIZE])
{
    // caching the time is not necessary, but calls to mono_time_get(...) are not free
    const uint64_t cur_time = mono_time_get(cache->mono_time);

    // We can't use the first and last bytes because they are masked in curve25519. Selected 8 for good alignment.
    const uint8_t bucket_idx = public_key[8];
    Shared_Key *bucket_start = &cache->keys[(size_t)bucket_idx * cache->keys_per_slot];

    const uint8_t *found = nullptr;

    // Perform lookup
    for (size_t i = 0; i < cache->keys_per_slot; ++i) {
        if (shared_key_is_empty(cache->log, &bucket_start[i])) {
            continue;
        }

        if (pk_equal(public_key, bucket_start[i].public_key)) {
            found = bucket_start[i].shared_key;
            bucket_start[i].time_last_requested = cur_time;
            break;
        }
    }

    // Perform housekeeping for this bucket
    for (size_t i = 0; i < cache->keys_per_slot; ++i) {
        if (shared_key_is_empty(cache->log, &bucket_start[i])) {
            continue;
        }

        const uint64_t last_requested = bucket_start[i].time_last_requested;
        // Written as a subtraction instead of `last_requested + timeout < cur_time`:
        // the addition can overflow for a very large timeout (e.g. UINT64_MAX),
        // which would wrap around and evict every entry immediately.
        const bool timed_out = cur_time > last_requested
                               && cur_time - last_requested > cache->timeout;

        if (timed_out) {
            shared_key_set_empty(cache->log, &bucket_start[i]);
        }
    }

    if (found != nullptr) {
        return found;
    }

    // Insert into cache
    uint64_t oldest_timestamp = UINT64_MAX;
    size_t oldest_index = 0;

    /*
     * Find least recently used entry, unused entries are prioritised,
     * because their time_last_requested field is zeroed.
     */
    for (size_t i = 0; i < cache->keys_per_slot; ++i) {
        if (bucket_start[i].time_last_requested < oldest_timestamp) {
            oldest_timestamp = bucket_start[i].time_last_requested;
            oldest_index = i;
        }
    }

    Shared_Key *const victim = &bucket_start[oldest_index];

    // Compute the shared key for the cache
    if (encrypt_precompute(public_key, cache->self_secret_key, victim->shared_key) != 0) {
        // Don't put anything in the cache on error. The computation targeted this
        // slot's shared_key buffer, so clear the whole slot: otherwise the public
        // key previously stored here could be paired with a garbled shared key.
        shared_key_set_empty(cache->log, victim);
        return nullptr;
    }

    // update cache entry
    memcpy(victim->public_key, public_key, CRYPTO_PUBLIC_KEY_SIZE);
    victim->time_last_requested = cur_time;

    return victim->shared_key;
}
|