1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
|
// SPDX-License-Identifier: GPL-2.0-only
/*
* Copyright (c) 2025, Google LLC.
* Pasha Tatashin <pasha.tatashin@soleen.com>
*/
#define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <getopt.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/syscall.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <errno.h>
#include <stdarg.h>
#include "luo_test_utils.h"
/*
 * Open the LUO control device (LUO_DEVICE) for reading and writing.
 * Returns the new file descriptor, or -1 with errno set by open().
 */
int luo_open_device(void)
{
	const int open_flags = O_RDWR;

	return open(LUO_DEVICE, open_flags);
}
int luo_create_session(int luo_fd, const char *name)
{
struct liveupdate_ioctl_create_session arg = { .size = sizeof(arg) };
snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s",
LIVEUPDATE_SESSION_NAME_LENGTH - 1, name);
if (ioctl(luo_fd, LIVEUPDATE_IOCTL_CREATE_SESSION, &arg) < 0)
return -errno;
return arg.fd;
}
int luo_retrieve_session(int luo_fd, const char *name)
{
struct liveupdate_ioctl_retrieve_session arg = { .size = sizeof(arg) };
snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s",
LIVEUPDATE_SESSION_NAME_LENGTH - 1, name);
if (ioctl(luo_fd, LIVEUPDATE_IOCTL_RETRIEVE_SESSION, &arg) < 0)
return -errno;
return arg.fd;
}
int create_and_preserve_memfd(int session_fd, int token, const char *data)
{
struct liveupdate_session_preserve_fd arg = { .size = sizeof(arg) };
long page_size = sysconf(_SC_PAGE_SIZE);
void *map = MAP_FAILED;
int mfd = -1, ret = -1;
mfd = memfd_create("test_mfd", 0);
if (mfd < 0)
return -errno;
if (ftruncate(mfd, page_size) != 0)
goto out;
map = mmap(NULL, page_size, PROT_WRITE, MAP_SHARED, mfd, 0);
if (map == MAP_FAILED)
goto out;
snprintf(map, page_size, "%s", data);
munmap(map, page_size);
arg.fd = mfd;
arg.token = token;
if (ioctl(session_fd, LIVEUPDATE_SESSION_PRESERVE_FD, &arg) < 0)
goto out;
ret = 0;
out:
if (ret != 0 && errno != 0)
ret = -errno;
if (mfd >= 0)
close(mfd);
return ret;
}
/*
 * Retrieve the file preserved under @token in @session_fd, map its first page
 * read-only and, when @expected_data is non-NULL, check that the page holds
 * that string.
 *
 * Returns the retrieved fd (now owned by the caller) on success, -EINVAL on a
 * data mismatch, or the negative errno of the failing ioctl()/mmap(). Once the
 * fd has been retrieved it is closed on every failure path.
 */
int restore_and_verify_memfd(int session_fd, int token,
			     const char *expected_data)
{
	struct liveupdate_session_retrieve_fd arg = { .size = sizeof(arg) };
	long page_size = sysconf(_SC_PAGE_SIZE);
	void *map;
	int mfd, ret;

	arg.token = token;
	if (ioctl(session_fd, LIVEUPDATE_SESSION_RETRIEVE_FD, &arg) < 0)
		return -errno;
	mfd = arg.fd;

	map = mmap(NULL, page_size, PROT_READ, MAP_SHARED, mfd, 0);
	if (map == MAP_FAILED) {
		/* Capture errno now, before close() can overwrite it. */
		ret = -errno;
		goto out_close;
	}

	/*
	 * Bound both the comparison and the diagnostic by the mapping size:
	 * nothing guarantees the page contains a NUL terminator.
	 */
	if (expected_data && strncmp(expected_data, map, page_size) != 0) {
		ksft_print_msg("Data mismatch! Expected '%s', Got '%.*s'\n",
			       expected_data, (int)page_size, (const char *)map);
		/* Reported as-is; never replaced by a possibly stale errno. */
		ret = -EINVAL;
		goto out_munmap;
	}

	munmap(map, page_size);
	return mfd;

out_munmap:
	munmap(map, page_size);
out_close:
	close(mfd);
	return ret;
}
int luo_session_finish(int session_fd)
{
struct liveupdate_session_finish arg = { .size = sizeof(arg) };
if (ioctl(session_fd, LIVEUPDATE_SESSION_FINISH, &arg) < 0)
return -errno;
return 0;
}
/*
 * Record @next_stage (as decimal text) in a memfd preserved under @token in a
 * newly created session named @session_name. Terminates the test via
 * fail_exit() if either step fails.
 */
void create_state_file(int luo_fd, const char *session_name, int token,
		       int next_stage)
{
	int sfd = luo_create_session(luo_fd, session_name);
	char stage_text[32];

	if (sfd < 0)
		fail_exit("luo_create_session for state tracking");

	snprintf(stage_text, sizeof(stage_text), "%d", next_stage);

	if (create_and_preserve_memfd(sfd, token, stage_text) < 0)
		fail_exit("create_and_preserve_memfd for state tracking");

	/*
	 * sfd is deliberately left open: closing the session fd would
	 * unpreserve the session.
	 */
}
/*
 * Retrieve the state memfd preserved under @token in @state_session_fd, read
 * the decimal stage number it holds, and store it in *@stage (0 if the text
 * is not a number, per atoi()). Terminates the test via fail_exit() on error.
 */
void restore_and_read_stage(int state_session_fd, int token, int *stage)
{
	char text[32] = { 0 };
	int fd = restore_and_verify_memfd(state_session_fd, token, NULL);

	if (fd < 0)
		fail_exit("failed to restore state memfd");

	/* Read one byte less than the buffer so text stays NUL-terminated. */
	if (read(fd, text, sizeof(text) - 1) < 0)
		fail_exit("failed to read state mfd");

	*stage = atoi(text);
	close(fd);
}
/*
 * Fork a child that keeps this process's open session fds alive, then exit
 * the parent with EXIT_SUCCESS. The child starts a new session, closes its
 * standard streams, moves to "/" and sleeps forever. Never returns.
 */
void daemonize_and_wait(void)
{
	pid_t child;

	ksft_print_msg("[STAGE 1] Forking persistent child to hold sessions...\n");

	child = fork();
	if (child < 0)
		fail_exit("fork failed");

	if (child == 0) {
		/* Detach from the controlling terminal so closing it won't kill us. */
		if (setsid() < 0)
			fail_exit("setsid failed");

		for (int fd = STDIN_FILENO; fd <= STDERR_FILENO; fd++)
			close(fd);

		/* Avoid pinning the original working directory's filesystem. */
		if (chdir("/") < 0)
			exit(EXIT_FAILURE);

		for (;;)
			sleep(60);
	}

	ksft_print_msg("[STAGE 1] Child PID: %d. Resources are pinned.\n", child);
	ksft_print_msg("[STAGE 1] You may now perform kexec reboot.\n");
	exit(EXIT_SUCCESS);
}
/*
 * Parse "-s N" / "--stage N" from the command line.
 * Returns the requested stage (1 or 2), or 1 when no stage option is given.
 * An unknown option, or a stage value that is not exactly the number 1 or 2
 * (trailing characters such as "2x" are rejected), terminates the test via
 * fail_exit().
 */
static int parse_stage_args(int argc, char *argv[])
{
	static const struct option long_options[] = {
		{ "stage", required_argument, NULL, 's' },
		{ NULL, 0, NULL, 0 }
	};
	int stage = 1;
	char *end;
	long val;
	int opt;

	/* Restart getopt scanning from argv[1] in case it ran before. */
	optind = 1;
	while ((opt = getopt_long(argc, argv, "s:", long_options, NULL)) != -1) {
		switch (opt) {
		case 's':
			/* Unlike atoi(), strtol() lets us reject malformed input. */
			val = strtol(optarg, &end, 10);
			if (end == optarg || *end != '\0' || (val != 1 && val != 2))
				fail_exit("Invalid stage argument");
			stage = (int)val;
			break;
		default:
			fail_exit("Unknown argument");
		}
	}
	return stage;
}
/*
 * Common entry point for two-stage LUO tests.
 *
 * The stage the system is in is inferred from @state_session_name: if the
 * session cannot be retrieved (-ENOENT) this is stage 1, if it can this is
 * stage 2. The stage requested on the command line must match, otherwise the
 * test fails. Runs @stage1(luo_fd) or @stage2(luo_fd, state_session_fd)
 * accordingly and returns 0. Skips the test if the LUO device cannot be opened.
 */
int luo_test(int argc, char *argv[],
	     const char *state_session_name,
	     luo_test_stage1_fn stage1,
	     luo_test_stage2_fn stage2)
{
	int target_stage = parse_stage_args(argc, argv);
	int luo_fd = luo_open_device();
	int detected_stage;
	int state_fd;

	if (luo_fd < 0)
		ksft_exit_skip("Failed to open %s. Is the luo module loaded?\n",
			       LUO_DEVICE);

	state_fd = luo_retrieve_session(luo_fd, state_session_name);
	if (state_fd >= 0) {
		detected_stage = 2;
	} else if (state_fd == -ENOENT) {
		detected_stage = 1;
	} else {
		fail_exit("Failed to check for state session");
	}

	if (detected_stage != target_stage)
		ksft_exit_fail_msg("Stage mismatch Requested --stage %d, but system is in stage %d.\n"
				   "(State session %s: %s)\n",
				   target_stage, detected_stage, state_session_name,
				   detected_stage == 2 ? "EXISTS" : "MISSING");

	if (target_stage == 1)
		stage1(luo_fd);
	else
		stage2(luo_fd, state_fd);

	return 0;
}
|