File: server.go

package info (click to toggle)
incus 6.0.5-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 24,392 kB
  • sloc: sh: 16,313; ansic: 3,121; python: 457; makefile: 337; ruby: 51; sql: 50; lisp: 6
file content (300 lines) | stat: -rw-r--r-- 6,059 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
package dns

import (
	"context"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"

	"github.com/lxc/incus/v6/internal/ports"
	"github.com/lxc/incus/v6/internal/server/db"
	dbCluster "github.com/lxc/incus/v6/internal/server/db/cluster"
	internalUtil "github.com/lxc/incus/v6/internal/util"
	"github.com/lxc/incus/v6/shared/logger"
)

// ZoneRetriever is a function which fetches a DNS zone.
//
// It receives the zone name and a full flag and returns the matching *Zone,
// or an error if the zone cannot be retrieved. The Server stores it and hands
// it to its request handler.
// NOTE(review): full presumably selects whether the complete record set (as
// opposed to a minimal view) is wanted — confirm against the implementations.
type ZoneRetriever func(name string, full bool) (*Zone, error)

// Server represents a DNS server instance.
//
// It serves the same handler over TCP and UDP on a single address. Lifecycle
// changes requested through Start, Reconfigure and Stop are sent as commands
// on cmd and applied by a single background goroutine (runDNSServer).
type Server struct {
	// TCP and UDP listeners, created by start and shut down by stop.
	tcpDNS *dns.Server
	udpDNS *dns.Server

	// External dependencies.
	db            *db.Cluster
	zoneRetriever ZoneRetriever

	// Internal state (to handle reconfiguration).
	// address is the canonical listen address; start sets it once the
	// listeners have been launched and stop clears it, so an empty value means
	// the listeners are not considered running.
	address string

	// cmd carries commands to the runDNSServer loop. It is unbuffered and is
	// created lazily by the first call to Start.
	cmd chan serverCmdInfo

	// mu guards tcpDNS, udpDNS, address and the lazy initialization of cmd.
	mu sync.Mutex
}

// serverCmd identifies a request processed by the runDNSServer command loop.
type serverCmd int

// Commands understood by runDNSServer. The values are assigned with iota, so
// the order of this list is significant.
const (
	// serverCmdStart is sent by the first call to Start.
	serverCmdStart serverCmd = iota
	// serverCmdRestart is sent ten seconds after an error has been handled, to
	// bring the listeners back up.
	serverCmdRestart
	// serverCmdStop is sent by Stop.
	serverCmdStop
	// serverCmdReconfigure is sent by every Start call after the first (and
	// therefore by Reconfigure).
	serverCmdReconfigure
	// serverCmdHandleError reports an error from a listener goroutine or from a
	// failed start.
	serverCmdHandleError
)

// serverCmdInfo is a single message sent on Server.cmd.
type serverCmdInfo struct {
	// cmd selects the action to perform.
	cmd serverCmd
	// address is the target address for serverCmdStart and serverCmdReconfigure.
	address string
	// err is the error being reported for serverCmdHandleError.
	err error
}

// NewServer builds a Server bound to the given cluster database and zone
// retriever. No listener is started until Start is called.
func NewServer(db *db.Cluster, retriever ZoneRetriever) *Server {
	return &Server{
		zoneRetriever: retriever,
		db:            db,
	}
}

// handleErr forwards err to the runDNSServer command loop as a
// serverCmdHandleError message. Because s.cmd is unbuffered, the call blocks
// until the loop receives the message.
func (s *Server) handleErr(err error) {
	msg := serverCmdInfo{cmd: serverCmdHandleError, err: err}
	s.cmd <- msg
}

// runDNSServer is the command loop that owns the listener lifecycle.
//
// It processes the serverCmdInfo messages received on s.cmd one at a time and
// holds s.mu around every start/stop. It tracks the desired state locally:
//   - shouldRun: whether the listeners are supposed to be up;
//   - address: the last requested address, reused by restarts.
//
// Errors are never handled inline. They arrive as serverCmdHandleError
// messages, which are logged, stop the listeners and schedule a
// serverCmdRestart ten seconds later.
func (s *Server) runDNSServer() {
	shouldRun := false
	address := ""

	for cmd := range s.cmd {
		switch cmd.cmd {
		case serverCmdStart:
			// Ignore duplicate start requests once the server is meant to run.
			if shouldRun {
				continue
			}

			shouldRun = true
			address = cmd.address
			s.mu.Lock()
			s.startLocked(cmd.address)
			s.mu.Unlock()
		case serverCmdRestart:
			s.mu.Lock()
			// Only start if the server should run and isn't already running (s.address is set when the server starts).
			if shouldRun && s.address == "" {
				s.startLocked(address)
			}

			s.mu.Unlock()
		case serverCmdStop:
			shouldRun = false
			s.mu.Lock()
			s.stop()
			s.mu.Unlock()
		case serverCmdReconfigure:
			s.mu.Lock()
			s.stop()

			// An empty address disables the server.
			shouldRun = cmd.address != ""
			if shouldRun {
				address = cmd.address
				s.startLocked(cmd.address)
			}

			s.mu.Unlock()
		case serverCmdHandleError:
			if cmd.err == nil {
				continue
			}

			logger.Errorf("DNS server encountered an error, restarting in 10s: %v", cmd.err)
			s.mu.Lock()
			s.stop()
			s.mu.Unlock()

			// Schedule the restart. AfterFunc runs the callback in its own
			// goroutine, so the send cannot block this loop, which is the
			// channel's only receiver.
			time.AfterFunc(10*time.Second, func() {
				s.cmd <- serverCmdInfo{cmd: serverCmdRestart}
			})
		}
	}
}

// startLocked starts the listeners on address and reports any failure back to
// the command loop. The caller must hold s.mu.
//
// The error is sent from a new goroutine because the caller is the command
// loop itself (the only receiver of s.cmd), so an inline send would deadlock.
func (s *Server) startLocked(address string) {
	err := s.start(address)
	if err != nil {
		go s.handleErr(err)
	}
}

// Start sets up the DNS listener on address.
//
// The first call spawns the runDNSServer command loop and asks it to start the
// listeners. Later calls ask it to reconfigure them instead: stop, then start
// on the new address, or stay stopped when address is empty. The work is done
// asynchronously by the loop, so the returned error is always nil.
func (s *Server) Start(address string) error {
	s.mu.Lock()
	firstStart := s.cmd == nil
	if firstStart {
		s.cmd = make(chan serverCmdInfo)
		go s.runDNSServer()
	}

	cmdChan := s.cmd
	s.mu.Unlock()

	msg := serverCmdInfo{cmd: serverCmdReconfigure, address: address}
	if firstStart {
		msg.cmd = serverCmdStart
	}

	cmdChan <- msg

	return nil
}

// start creates the TCP and UDP listeners for address and records the address
// in s.address. The caller must hold s.mu.
//
// address is canonicalized first, with ports.DNSDefaultPort filled in when no
// port is given. TSIG secrets are loaded before the listeners are launched, so
// TsigSecret is never written on a server that is already serving. Listener
// failures happen asynchronously and are reported through handleErr.
//
// If loading the TSIG secrets fails, the error is returned, the listeners are
// not launched and s.address is left unchanged.
func (s *Server) start(address string) error {
	// Set default port if needed.
	address = internalUtil.CanonicalNetworkAddress(address, ports.DNSDefaultPort)

	// Setup the handler.
	handler := dnsHandler{server: s}

	// Create the servers. The goroutines below use these local references, not
	// s.tcpDNS/s.udpDNS. A later reconfiguration may replace those fields before
	// the goroutines are scheduled, and reading the fields would then start the
	// wrong instance.
	tcpServer := &dns.Server{Addr: address, Net: "tcp", Handler: handler}
	udpServer := &dns.Server{Addr: address, Net: "udp", Handler: handler}
	s.tcpDNS = tcpServer
	s.udpDNS = udpServer

	// TSIG handling (updateTSIG writes TsigSecret on s.tcpDNS/s.udpDNS), done
	// before the servers start serving.
	err := s.updateTSIG()
	if err != nil {
		return err
	}

	// Spawn the DNS servers.
	go func() {
		err := tcpServer.ListenAndServe()
		if err != nil {
			s.handleErr(fmt.Errorf("Failed to listen on TCP DNS address %q: %w", address, err))
		}
	}()

	go func() {
		err := udpServer.ListenAndServe()
		if err != nil {
			s.handleErr(fmt.Errorf("Failed to listen on UDP DNS address %q: %w", address, err))
		}
	}()

	// Record the address.
	s.address = address

	return nil
}

// Stop tears down the DNS listener.
//
// The request is queued to the runDNSServer command loop and handled
// asynchronously, so the returned error is always nil. Calling Stop before
// Start is a no-op: no command loop exists yet, and sending on the nil
// channel would block forever.
func (s *Server) Stop() error {
	// Start initializes s.cmd under s.mu, so read it under the same lock. The
	// lock is released before sending because the command loop takes s.mu
	// while processing commands.
	s.mu.Lock()
	cmdChan := s.cmd
	s.mu.Unlock()

	if cmdChan == nil {
		return nil
	}

	cmdChan <- serverCmdInfo{
		cmd: serverCmdStop,
	}

	return nil
}

// stop shuts down both listeners and clears s.address. The caller must hold
// s.mu. Nothing happens when the listeners have not been created, and errors
// returned by Shutdown are deliberately ignored.
func (s *Server) stop() {
	if s.tcpDNS != nil && s.udpDNS != nil {
		_ = s.tcpDNS.Shutdown()
		_ = s.udpDNS.Shutdown()

		// Mark the server as no longer running.
		s.address = ""
	}
}

// Reconfigure updates the listener with a new configuration.
//
// It behaves exactly like Start with the given address and returns whatever
// Start returns.
func (s *Server) Reconfigure(address string) error {
	err := s.Start(address)
	return err
}

// UpdateTSIG reloads the TSIG keys of all network zones from the database and
// applies them to the current DNS servers. It holds s.mu for the whole
// operation so it cannot interleave with a start or stop.
func (s *Server) UpdateTSIG() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	err := s.updateTSIG()
	return err
}

// updateTSIG reads the TSIG keys of every network zone from the database and
// installs them as the TsigSecret map of both DNS servers. The caller must
// hold s.mu.
//
// A secret is produced for each zone config key that starts with "peers.",
// ends with ".key" and has at least three dot-separated fields. The second
// field is taken as the peer name. The secret's name is "<zone>_<peer>."
// (dot-terminated to form an FQDN) and its value is the config value.
//
// It returns nil without touching the database when the servers have not been
// created or no database is configured. Database errors are returned and
// leave the existing secrets in place.
//
// NOTE(review): the TsigSecret fields may be replaced while the servers are
// serving; confirm this is safe with the miekg/dns version in use.
func (s *Server) updateTSIG() error {
	if s.db == nil || s.tcpDNS == nil || s.udpDNS == nil {
		return nil
	}

	secrets := map[string]string{}

	err := s.db.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
		zones, err := dbCluster.GetNetworkZones(ctx, tx.Tx())
		if err != nil {
			return err
		}

		for _, zone := range zones {
			zoneConfig, err := dbCluster.GetNetworkZoneConfig(ctx, tx.Tx(), zone.ID)
			if err != nil {
				return err
			}

			for configKey, secret := range zoneConfig {
				isPeerKey := strings.HasPrefix(configKey, "peers.") && strings.HasSuffix(configKey, ".key")
				if !isPeerKey {
					continue
				}

				// Expect "peers", "<peer>", "<rest>"; anything shorter is invalid.
				parts := strings.SplitN(configKey, ".", 3)
				if len(parts) < 3 {
					continue
				}

				secretName := fmt.Sprintf("%s_%s.", zone.Name, parts[1])
				secrets[secretName] = secret
			}
		}

		return nil
	})
	if err != nil {
		return err
	}

	// Both servers share the same secrets map.
	s.tcpDNS.TsigSecret = secrets
	s.udpDNS.TsigSecret = secrets

	return nil
}