File: triggerwatch.go

package info (click to toggle)
snapd 2.71-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 79,536 kB
  • sloc: ansic: 16,114; sh: 16,105; python: 9,941; makefile: 1,890; exp: 190; awk: 40; xml: 22
file content (161 lines) | stat: -rw-r--r-- 4,369 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// -*- Mode: Go; indent-tabs-mode: t -*-

/*
 * Copyright (C) 2020 Canonical Ltd
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

package triggerwatch

import (
	"errors"
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/snapcore/snapd/logger"
	"github.com/snapcore/snapd/osutil/udev/netlink"
	"github.com/snapcore/snapd/timeutil"
)

// timeAfter yields the channel produced by timeutil.After for the given
// duration. It is held in a package-level variable rather than called
// directly — presumably so tests can substitute it; confirm against the
// test suite.
var timeAfter = func(wait time.Duration) <-chan time.Time {
	expired := timeutil.After(wait)
	return expired
}

// triggerProvider is the source of trigger devices used by Wait.
type triggerProvider interface {
	// Open is called by Wait with the DEVNAME of a newly added input device.
	// Wait tolerates a nil device together with a nil error and simply
	// ignores that device; a non-nil error is logged and the device skipped.
	Open(filter triggerEventFilter, node string) (triggerDevice, error)
	// FindMatchingDevices is called once by Wait, before it starts waiting,
	// to obtain the initial set of devices to watch for the given filter.
	FindMatchingDevices(filter triggerEventFilter) ([]triggerDevice, error)
}

// triggerDevice is a single device that Wait watches for the trigger.
type triggerDevice interface {
	// WaitForTrigger is started by Wait in its own goroutine; results are
	// delivered as keyEvent values on the supplied channel.
	WaitForTrigger(chan keyEvent)
	// String presumably gives a printable identification of the device —
	// confirm against the implementations.
	String() string
	// Close is deferred by Wait for every device it watches, so it runs
	// when Wait returns.
	Close()
}

// ueventConnection is the subset of the netlink uevent connection that Wait
// relies on. The default implementation returned by getUEventConn is
// *netlink.UEventConn.
type ueventConnection interface {
	// Connect is called by Wait with netlink.UdevEvent before monitoring.
	Connect(mode netlink.Mode) error
	// Close is deferred by Wait right after a successful Connect.
	Close() error
	// Monitor is handed a queue for events, a channel for errors and a
	// matcher selecting the events of interest. The returned function
	// presumably stops the monitoring within the given duration — confirm
	// against the netlink package; Wait currently discards it.
	Monitor(queue chan netlink.UEvent, errors chan error, matcher netlink.Matcher) func(time.Duration) bool
}

var (
	// trigger is the trigger mechanism used by Wait; Wait panics (through
	// logger.Panicf) while it is nil. getUEventConn builds the uevent
	// connection Wait monitors for newly added input devices; being a
	// variable it can be replaced, presumably by tests.
	trigger       triggerProvider
	getUEventConn = func() ueventConnection {
		return &netlink.UEventConn{}
	}

	// triggerFilter selects the key that counts as the trigger: wait for
	// '1' to be pressed.
	triggerFilter = triggerEventFilter{Key: "KEY_1"}

	// ErrTriggerNotDetected is returned by Wait when its overall timeout
	// expires before the trigger key was seen. ErrNoMatchingInputDevices
	// is returned by Wait when its device timeout expires and no trigger
	// device has been found by then.
	ErrTriggerNotDetected     = errors.New("trigger not detected")
	ErrNoMatchingInputDevices = errors.New("no matching input devices")
)

// Wait waits for a trigger on the available trigger devices for a given amount
// of time. Returns nil if one was detected, ErrTriggerNotDetected if timeout
// was hit, or other non-nil error.
//
// deviceTimeout bounds how long Wait waits for a first usable input device:
// when it expires and no device was found at startup or opened after an "add"
// uevent, ErrNoMatchingInputDevices is returned. Once a device has been found
// the expiry of deviceTimeout has no effect.
//
// Devices that appear while waiting are picked up from netlink uevents
// (SUBSYSTEM=input, ID_INPUT_KEYBOARD=1) and watched as well.
//
// On SIGUSR1 the process changes its root to /sysroot and keeps waiting; a
// failure to do so is returned as an error.
//
// Wait calls logger.Panicf if the uevent connection cannot be established or
// if no trigger provider has been set.
func Wait(timeout time.Duration, deviceTimeout time.Duration) error {
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGUSR1)
	// Unregister the channel on return; otherwise every call to Wait leaves
	// another channel registered with the signal handler.
	defer signal.Stop(sigs)

	conn := getUEventConn()
	if err := conn.Connect(netlink.UdevEvent); err != nil {
		logger.Panicf("Unable to connect to Netlink Kobject UEvent socket: %v", err)
	}
	defer conn.Close()

	// Only keyboards being added to the input subsystem are of interest.
	add := "add"
	matcher := &netlink.RuleDefinitions{
		Rules: []netlink.RuleDefinition{
			{
				Action: &add,
				Env: map[string]string{
					"SUBSYSTEM":         "input",
					"ID_INPUT_KEYBOARD": "1",
					"DEVNAME":           ".*",
				},
			},
		},
	}

	ueventQueue := make(chan netlink.UEvent)
	// NOTE(review): nothing in this function receives from ueventErrors, and
	// the stop function returned by Monitor is discarded — confirm this is
	// intended for the short-lived process this runs in.
	ueventErrors := make(chan error)
	conn.Monitor(ueventQueue, ueventErrors, matcher)

	if trigger == nil {
		logger.Panicf("trigger is unset")
	}

	devices, err := trigger.FindMatchingDevices(triggerFilter)
	if err != nil {
		return fmt.Errorf("cannot list trigger devices: %v", err)
	}

	logger.Noticef("waiting for trigger key: %v", triggerFilter.Key)

	// One buffer slot per initially known device; len and range are fine
	// with a nil slice, so no special casing is needed when none was found.
	detectKeyCh := make(chan keyEvent, len(devices))
	for _, dev := range devices {
		go dev.WaitForTrigger(detectKeyCh)
		// Devices must stay open until Wait returns, hence defer in the loop.
		defer dev.Close()
	}
	foundDevice := len(devices) != 0

	timeoutEvent := timeAfter(timeout)
	deviceTimeoutEvent := timeAfter(deviceTimeout)
	for {
		select {
		case kev := <-detectKeyCh:
			if kev.Err != nil {
				return kev.Err
			}
			// a device reported the trigger key without an error
			logger.Noticef("%s: + got trigger key %v", kev.Dev, triggerFilter.Key)
			return nil
		case <-timeoutEvent:
			return ErrTriggerNotDetected
		case <-deviceTimeoutEvent:
			// Only fatal when there is still nothing to listen on.
			if !foundDevice {
				return ErrNoMatchingInputDevices
			}
		case uevent := <-ueventQueue:
			devname := uevent.Env["DEVNAME"]
			dev, err := trigger.Open(triggerFilter, devname)
			if err != nil {
				logger.Noticef("ignoring device %s that cannot be opened: %v", devname, err)
				continue
			}
			if dev == nil {
				// No error but no device either: nothing to watch.
				continue
			}
			foundDevice = true
			defer dev.Close()
			go dev.WaitForTrigger(detectKeyCh)
		case <-sigs:
			logger.Noticef("Switching root")
			if err := syscall.Chdir("/sysroot"); err != nil {
				return fmt.Errorf("Cannot change directory: %w", err)
			}
			if err := syscall.Chroot("/sysroot"); err != nil {
				return fmt.Errorf("Cannot change root: %w", err)
			}
			if err := syscall.Chdir("/"); err != nil {
				return fmt.Errorf("Cannot change directory: %w", err)
			}
		}
	}
}