1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
|
// Copyright (c) 2018-2022, Sylabs Inc. All rights reserved.
// This software is licensed under a 3-clause BSD license. Please consult the
// LICENSE.md file distributed with the sources of this project regarding your
// rights to use or distribute this software.
package test
import (
"bufio"
"fmt"
"log"
"os"
"os/user"
"runtime"
"strconv"
"testing"
"golang.org/x/sys/unix"
)
// Identities captured once by init: the user the tests were started as
// (orig*) and the non-root user that DropPrivilege switches to (unpriv*).
var (
	// User and group IDs of the test process at startup.
	origUID, origGID int
	// User and group IDs of the nearest non-root process in the parent chain.
	unprivUID, unprivGID int
	// Home directories of the two users above; used to set HOME.
	origHome, unprivHome string
)
// EnsurePrivilege ensures elevated privileges are available during a test.
// It fails the test immediately unless the process's user ID is 0 (root).
func EnsurePrivilege(t *testing.T) {
	// Mark as a helper so the failure is reported at the calling test's line.
	t.Helper()

	if os.Getuid() != 0 {
		t.Fatal("test must be run with privilege")
	}
}
// DropPrivilege drops privilege. Use this at the start of a test that does
// not require elevated privileges. A matching call to ResetPrivilege must
// occur before the test completes (a defer statement is recommended.)
//
// If the group ID is 0, the group identity becomes unprivGID. If the user ID
// is 0, the user identity becomes unprivUID and HOME is set to unprivHome.
// In both cases the original ID is kept as the saved set-ID so that
// ResetPrivilege can regain it. The test fails immediately if any step fails.
func DropPrivilege(t *testing.T) {
	// Mark as a helper so failures are reported at the calling test's line.
	t.Helper()

	// In older x/sys/unix and Go <1.16, setresuid/setresgid modifies the
	// current thread only. To ensure our new uid/gid sticks, we need to lock
	// ourselves to the current OS thread. ResetPrivilege releases the lock.
	runtime.LockOSThread()

	// Drop the group first: once the effective UID is no longer 0, the
	// process would lack the privilege to switch to an arbitrary group.
	if os.Getgid() == 0 {
		if err := unix.Setresgid(unprivGID, unprivGID, origGID); err != nil {
			t.Fatalf("failed to set group identity: %v", err)
		}
	}

	if os.Getuid() == 0 {
		if err := unix.Setresuid(unprivUID, unprivUID, origUID); err != nil {
			t.Fatalf("failed to set user identity: %v", err)
		}
		if err := os.Setenv("HOME", unprivHome); err != nil {
			t.Fatalf("failed to set HOME environment variable: %v", err)
		}
	}
}
// ResetPrivilege returns effective privilege to the original user. It undoes
// DropPrivilege. It restores origUID and origGID, keeping the unprivileged
// IDs as the saved set-IDs. It sets HOME back to origHome and unlocks the
// goroutine from its OS thread.
func ResetPrivilege(t *testing.T) {
	// Mark as a helper so failures are reported at the calling test's line.
	t.Helper()

	if err := unix.Setresuid(origUID, origUID, unprivUID); err != nil {
		t.Fatalf("failed to reset user identity: %v", err)
	}
	if err := unix.Setresgid(origGID, origGID, unprivGID); err != nil {
		t.Fatalf("failed to reset group identity: %v", err)
	}

	// We might want restoration of HOME env var to persist past this individual
	// test, so use os.Setenv() rather than t.Setenv()
	//nolint:tenv
	if err := os.Setenv("HOME", origHome); err != nil {
		// Errorf rather than Fatalf: credentials are already restored here,
		// so it is still safe (and necessary) to unlock the thread below.
		t.Errorf("failed to restore HOME environment variable: %v", err)
	}

	// Deliberately not deferred. If restoring credentials failed above,
	// Fatalf exits the goroutine while it is still locked. The runtime then
	// terminates the thread instead of reusing one with dropped credentials.
	runtime.UnlockOSThread()
}
// WithPrivilege wraps the supplied test function with calls to ensure
// the test is run with elevated privileges.
//
// The returned function first calls EnsurePrivilege, which fails the test
// when it is not running as root. Otherwise it invokes f with the same *testing.T.
func WithPrivilege(f func(t *testing.T)) func(t *testing.T) {
	wrapped := func(t *testing.T) {
		t.Helper()
		EnsurePrivilege(t)
		f(t)
	}
	return wrapped
}
// WithoutPrivilege wraps the supplied test function with calls to ensure
// the test is run without elevated privileges.
//
// The returned function drops privilege via DropPrivilege, runs f, and then
// restores the original identity via ResetPrivilege.
func WithoutPrivilege(f func(t *testing.T)) func(t *testing.T) {
	wrapped := func(t *testing.T) {
		t.Helper()
		DropPrivilege(t)
		// Deferred so privilege is restored even if f stops early with
		// t.Fatal/t.FailNow, which exit via runtime.Goexit and still run
		// deferred calls.
		defer ResetPrivilege(t)
		f(t)
	}
	return wrapped
}
// getProcInfo returns the parent PID, UID, and GID associated with the
// supplied PID. The values come from the "PPid:", "Uid:" and "Gid:" lines of
// /proc/<pid>/status. For the Uid and Gid lines only the first (real) ID is
// used. A field absent from the file is returned as 0. Calls os.Exit (via
// log.Fatalf) if the file cannot be opened or read.
func getProcInfo(pid int) (ppid int, uid int, gid int) {
	path := fmt.Sprintf("/proc/%v/status", pid)
	f, err := os.Open(path)
	if err != nil {
		log.Fatalf("failed to open %s: %v", path, err)
	}
	defer f.Close()

	s := bufio.NewScanner(f)
	for s.Scan() {
		line := s.Text()
		var v int
		// A line can match at most one of these prefixes, so move on to the
		// next line as soon as one matches.
		if n, _ := fmt.Sscanf(line, "PPid:\t%d", &v); n == 1 {
			ppid = v
			continue
		}
		if n, _ := fmt.Sscanf(line, "Uid:\t%d", &v); n == 1 {
			uid = v
			continue
		}
		if n, _ := fmt.Sscanf(line, "Gid:\t%d", &v); n == 1 {
			gid = v
		}
	}
	// Without this check a read error would silently yield zero IDs, which
	// look like a root-owned process to callers.
	if err := s.Err(); err != nil {
		log.Fatalf("failed to read %s: %v", path, err)
	}
	return ppid, uid, gid
}
// getUnprivIDs walks up the process parent chain, starting at pid itself, to
// find a process with a non-root UID, then returns the UID and GID of that
// process. PID 1 is never inspected. The walk also stops when a parent PID of
// 0 is reported, which happens when getProcInfo finds no PPid field or,
// presumably, when the parent lies outside the current PID namespace.
// Calls os.Exit (via log.Fatal) on error, or if no non-root process is found.
func getUnprivIDs(pid int) (uid int, gid int) {
	for pid > 1 {
		var ppid int
		ppid, uid, gid = getProcInfo(pid)
		if uid != 0 {
			return uid, gid
		}
		pid = ppid
	}
	log.Fatal("no unprivileged process found")
	return 0, 0 // unreachable: log.Fatal exits the process
}
// init records two identities:
//   - the user running the tests (origUID, origGID, origHome);
//   - the nearest non-root process in the parent chain (unprivUID,
//     unprivGID, unprivHome), which DropPrivilege switches to.
//
// It calls os.Exit (via log.Fatalf) if either user cannot be looked up. It
// also exits (via getUnprivIDs) if no non-root process is found.
func init() {
	origUID = os.Getuid()
	origGID = os.Getgid()
	origUser, err := user.LookupId(strconv.Itoa(origUID))
	if err != nil {
		log.Fatalf("failed to look up original user %d: %v", origUID, err)
	}
	origHome = origUser.HomeDir

	unprivUID, unprivGID = getUnprivIDs(os.Getpid())
	unprivUser, err := user.LookupId(strconv.Itoa(unprivUID))
	if err != nil {
		log.Fatalf("failed to look up unprivileged user %d: %v", unprivUID, err)
	}
	unprivHome = unprivUser.HomeDir
}
|