File: privilege_linux.go

package info (click to toggle)
singularity-container 4.1.5%2Bds4-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 43,876 kB
  • sloc: asm: 14,840; sh: 3,190; ansic: 1,751; awk: 414; makefile: 413; python: 99
file content (157 lines) | stat: -rw-r--r-- 4,166 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
// Copyright (c) 2018-2022, Sylabs Inc. All rights reserved.
// This software is licensed under a 3-clause BSD license. Please consult the
// LICENSE.md file distributed with the sources of this project regarding your
// rights to use or distribute this software.

package test

import (
	"bufio"
	"fmt"
	"log"
	"os"
	"os/user"
	"runtime"
	"strconv"
	"testing"

	"golang.org/x/sys/unix"
)

// Identities recorded once by init: the identity the test binary was started
// with, and an unprivileged identity found by walking up the process's
// ancestors. DropPrivilege and ResetPrivilege switch between the two.
var (
	origUID, origGID int
	origHome         string

	unprivUID, unprivGID int
	unprivHome           string
)

// EnsurePrivilege ensures elevated privileges are available during a test.
// It fails the test immediately (via t.Fatal) unless the process's UID, as
// reported by os.Getuid, is 0. It is marked as a test helper so the failure
// is attributed to the calling test rather than to this function.
func EnsurePrivilege(t *testing.T) {
	t.Helper()

	if os.Getuid() != 0 {
		t.Fatal("test must be run with privilege")
	}
}

// DropPrivilege drops privilege. Use this at the start of a test that does
// not require elevated privileges. A matching call to ResetPrivilege must
// occur before the test completes (a defer statement is recommended.)
//
// The calling goroutine is locked to its OS thread; ResetPrivilege unlocks it.
// If the real GID is 0, the real and effective GIDs become the unprivileged
// GID and the original GID is kept as the saved GID. Then, if the real UID is
// 0, the same is done for the UID and HOME is pointed at the unprivileged
// user's home directory. The group is changed before the user because, once
// the UID is dropped, the process would typically no longer be allowed to
// change its GID.
func DropPrivilege(t *testing.T) {
	t.Helper()

	// In older x/sys/unix and Go <1.16, setresuid/setresgid modifies the
	// current thread only. To ensure our new uid/gid sticks, we need to lock
	// ourselves to the current OS thread.
	runtime.LockOSThread()

	if os.Getgid() == 0 {
		// Keep origGID as the saved GID so ResetPrivilege can restore it.
		if err := unix.Setresgid(unprivGID, unprivGID, origGID); err != nil {
			t.Fatalf("failed to set group identity: %v", err)
		}
	}
	if os.Getuid() == 0 {
		// Keep origUID as the saved UID so ResetPrivilege can restore it.
		if err := unix.Setresuid(unprivUID, unprivUID, origUID); err != nil {
			t.Fatalf("failed to set user identity: %v", err)
		}

		if err := os.Setenv("HOME", unprivHome); err != nil {
			t.Fatalf("failed to set HOME environment variable: %v", err)
		}
	}
}

// ResetPrivilege returns effective privilege to the original user.
//
// It undoes DropPrivilege in reverse order. First the real and effective UID
// are restored to the original UID, keeping the unprivileged UID as the saved
// UID. The GID is then restored the same way. Finally HOME is reset to the
// original user's home directory and the goroutine is unlocked from its OS
// thread. Any failure aborts the test via t.Fatalf.
func ResetPrivilege(t *testing.T) {
	t.Helper()

	if err := unix.Setresuid(origUID, origUID, unprivUID); err != nil {
		t.Fatalf("failed to reset user identity: %v", err)
	}
	if err := unix.Setresgid(origGID, origGID, unprivGID); err != nil {
		t.Fatalf("failed to reset group identity: %v", err)
	}

	// We might want restoration of HOME env var to persist past this individual
	// test, so use os.Setenv() rather than t.Setenv()
	//nolint:tenv
	if err := os.Setenv("HOME", origHome); err != nil {
		t.Fatalf("failed to reset HOME environment variable: %v", err)
	}

	// Balances the runtime.LockOSThread call made by DropPrivilege.
	runtime.UnlockOSThread()
}

// WithPrivilege adapts f into a test function that first requires elevated
// privileges (via EnsurePrivilege, which fails the test when they are
// missing) and only then runs f with the same *testing.T.
func WithPrivilege(f func(t *testing.T)) func(t *testing.T) {
	runPrivileged := func(t *testing.T) {
		t.Helper()
		EnsurePrivilege(t)
		f(t)
	}
	return runPrivileged
}

// WithoutPrivilege adapts f into a test function that drops privilege (via
// DropPrivilege), runs f, and then restores the original identity through a
// deferred ResetPrivilege, so restoration also happens if f stops the test
// early.
func WithoutPrivilege(f func(t *testing.T)) func(t *testing.T) {
	runUnprivileged := func(t *testing.T) {
		t.Helper()

		DropPrivilege(t)
		defer ResetPrivilege(t)

		f(t)
	}
	return runUnprivileged
}

// getProcInfo returns the parent PID, UID, and GID of the process identified
// by pid, parsed from /proc/<pid>/status. For the "Uid:" and "Gid:" lines only
// the first value is used (the real ID, per proc(5)). Any field that is not
// found is returned as 0. It calls log.Fatalf (which exits the process) if the
// status file cannot be opened or read.
func getProcInfo(pid int) (ppid int, uid int, gid int) {
	f, err := os.Open(fmt.Sprintf("/proc/%v/status", pid))
	if err != nil {
		log.Fatalf("failed to open /proc/%v/status: %v", pid, err)
	}
	defer f.Close()

	s := bufio.NewScanner(f)
	for s.Scan() {
		line := s.Text()
		var v int
		// Sscanf errors are deliberately ignored: most lines do not match a
		// given pattern, and a match is detected by n == 1. Once a line has
		// matched one pattern, the remaining patterns are not tried.
		if n, _ := fmt.Sscanf(line, "PPid:\t%d", &v); n == 1 {
			ppid = v
		} else if n, _ := fmt.Sscanf(line, "Uid:\t%d", &v); n == 1 {
			uid = v
		} else if n, _ := fmt.Sscanf(line, "Gid:\t%d", &v); n == 1 {
			gid = v
		}
	}
	// Without this check a read error would silently yield zero values,
	// which callers would take to mean "root".
	if err := s.Err(); err != nil {
		log.Fatalf("failed to read /proc/%v/status: %v", pid, err)
	}
	return ppid, uid, gid
}

// getUnprivIDs walks up the process parent chain, starting at pid itself, to
// find the nearest process whose UID (as reported by getProcInfo) is not 0.
// It returns that process's UID and GID. It calls log.Fatal (which exits the
// process) if the walk reaches PID 1, or a PID below 1, without finding one.
func getUnprivIDs(pid int) (uid int, gid int) {
	for {
		// PID 1 ends the search. getProcInfo reports a parent PID of 0 when
		// none is listed; stopping on it too avoids trying to open the
		// nonexistent /proc/0/status with a misleading error.
		if pid <= 1 {
			log.Fatal("no unprivileged process found")
		}

		var ppid int
		ppid, uid, gid = getProcInfo(pid)
		if uid != 0 {
			return uid, gid
		}
		pid = ppid
	}
}

// init records the identity the test binary was started with, and an
// unprivileged identity found among the process's ancestors, together with
// each user's home directory. Any lookup failure is fatal.
func init() {
	homeOf := func(uid int) string {
		u, err := user.LookupId(strconv.Itoa(uid))
		if err != nil {
			log.Fatalf("err: %s", err)
		}
		return u.HomeDir
	}

	origUID, origGID = os.Getuid(), os.Getgid()
	origHome = homeOf(origUID)

	unprivUID, unprivGID = getUnprivIDs(os.Getpid())
	unprivHome = homeOf(unprivUID)
}