File: socket.go

package info (click to toggle)
golang-github-digitalocean-go-libvirt 0.0~git20250317.13bf9b4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,188 kB
  • sloc: yacc: 188; sh: 76; xml: 50; makefile: 3
file content (376 lines) | stat: -rw-r--r-- 8,980 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
package socket

import (
	"bufio"
	"encoding/binary"
	"errors"
	"io"
	"net"
	"sync"
	"syscall"
	"time"
	"unsafe"

	"github.com/digitalocean/go-libvirt/internal/constants"
)

// disconnectTimeout is how long Disconnect waits for the reader goroutine to
// finish after closing the connection before it gives up with an error.
const disconnectTimeout time.Duration = 5 * time.Second

// Request and response statuses, carried in Header.Status. These values are
// sent on the wire, so they are spelled out explicitly rather than via iota.
const (
	// StatusOK is always set for method calls or events.
	// For replies it indicates successful completion of the method.
	// For streams it indicates confirmation of the end of file on the stream.
	StatusOK = 0

	// StatusError for replies indicates that the method call failed
	// and error information is being returned. For streams this indicates
	// that not all data was sent and the stream has aborted.
	StatusError = 1

	// StatusContinue is only used for streams.
	// This indicates that further data packets will be following.
	StatusContinue = 2
)

// Request and response types, carried in Header.Type. These values are sent
// on the wire, so they are spelled out explicitly rather than via iota.
const (
	// Call is used when making calls to the remote server.
	Call = 0

	// Reply indicates a server reply.
	Reply = 1

	// Message is an asynchronous notification.
	Message = 2

	// Stream represents a stream data packet.
	Stream = 3

	// CallWithFDs is used by a client to indicate the request has
	// arguments with file descriptors.
	CallWithFDs = 4

	// ReplyWithFDs is used by a server to indicate the reply carries
	// file descriptors.
	ReplyWithFDs = 5
)

// Dialer is an interface for connecting to libvirt's underlying socket.
// Socket.Connect calls Dial once per connection attempt and uses the returned
// net.Conn for both reading and writing packets.
type Dialer interface {
	// Dial opens a new connection. A non-nil error aborts Connect and is
	// returned to Connect's caller unchanged.
	Dial() (net.Conn, error)
}

// Router is an interface used to route packets to the appropriate clients.
type Router interface {
	// Route is called by the Socket's reader goroutine once per received
	// packet, with the decoded header and the payload bytes that followed it.
	// Calls are made one after another from that single loop, so a slow
	// Route delays reading of subsequent packets.
	Route(*Header, []byte)
}

// Socket represents a libvirt Socket and its connection state.
// Create one with New; the zero value is not usable.
type Socket struct {
	// dialer opens a new connection on each Connect.
	dialer Dialer
	// router receives every packet read from the connection.
	router Router

	// conn is the current connection; reader and writer buffer it. All three
	// are replaced together by Connect.
	conn   net.Conn
	reader *bufio.Reader
	writer *bufio.Writer
	// used to serialize any Socket writes and any updates to conn, reader,
	// writer, or disconnected
	mu *sync.Mutex

	// disconnected is closed when the listen goroutine associated with a
	// Socket connection has returned. Connect installs a fresh channel for
	// each connection; New starts with an already-closed channel, meaning
	// "no connection".
	disconnected chan struct{}
}

// packet represents a RPC request or response.
//
// It describes the fixed-size prefix that precedes the payload (if any) of
// every packet: a length word followed by the header. Its fields mirror the
// wire order. Every field is 4 bytes wide, so the struct has no padding and
// its unsafe.Sizeof (28 bytes) equals the encoded prefix length.
type packet struct {
	// Size of packet, in bytes, including length.
	// Len + Header + Payload
	Len    uint32
	Header Header
}

// Global packet instance, for use with unsafe.Sizeof(). In this file it is
// only ever an unsafe.Sizeof operand (its own size or a field's), so its
// value is never read or written.
var _p packet

// Header is a libvirt rpc packet header.
//
// On the wire it is six consecutive big-endian 32-bit words, in the order the
// fields are declared below.
type Header struct {
	// Program identifier
	Program uint32

	// Program version
	Version uint32

	// Remote procedure identifier
	Procedure uint32

	// Call type, e.g., Reply
	Type uint32

	// Call serial number. Transmitted as the same 32 bits as a uint32 and
	// reinterpreted as signed.
	Serial int32

	// Request status, e.g., StatusOK
	Status uint32
}

// New returns a Socket that dials with dialer and delivers every incoming
// packet to router. The returned Socket is not connected yet; call Connect.
func New(dialer Dialer, router Router) *Socket {
	// An already-closed channel is how a Socket represents "no connection":
	// isDisconnected reports true and Connect is allowed to dial.
	idle := make(chan struct{})
	close(idle)

	return &Socket{
		dialer:       dialer,
		router:       router,
		mu:           new(sync.Mutex),
		disconnected: idle,
	}
}

// Connect uses the dialer provided on creation to establish the underlying
// physical connection to the desired libvirt, then starts the goroutine that
// reads packets and hands them to the Router.
//
// It returns an error if the socket is still connected (its reader goroutine
// has not exited), or the dialer's error if dialing fails.
func (s *Socket) Connect() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !s.isDisconnected() {
		return errors.New("already connected to socket")
	}

	// A previous connection may still be open: its reader goroutine also
	// stops on a read error (e.g. the peer hung up), and Disconnect returns
	// early once the socket is disconnected, so nothing else would close it.
	// Close it before it is replaced so its descriptor is not leaked. The
	// error is ignored because the connection is being discarded and may
	// already have been closed by Disconnect.
	if s.conn != nil {
		_ = s.conn.Close()
	}

	conn, err := s.dialer.Dial()
	if err != nil {
		return err
	}

	s.conn = conn
	s.reader = bufio.NewReader(conn)
	s.writer = bufio.NewWriter(conn)
	// Fresh channel for this connection; listenAndRoute closes it when the
	// reader stops.
	s.disconnected = make(chan struct{})

	go s.listenAndRoute()

	return nil
}

// Disconnect closes the Socket connection to libvirt and waits for the reader
// goroutine to shut down.
//
// It returns nil straight away if the socket is already disconnected.
// Otherwise it returns the error from closing the connection, if any, or an
// error if the reader goroutine has not exited within disconnectTimeout.
//
// NOTE(review): s.conn and s.disconnected are read here without s.mu, while
// Connect writes them under s.mu; a Disconnect racing a Connect is therefore
// unsynchronized. Taking s.mu here is not a safe fix, because SendPacket holds
// it while writing and a stuck write could then never be unblocked by Close.
func (s *Socket) Disconnect() error {
	if s.isDisconnected() {
		return nil
	}

	if err := s.conn.Close(); err != nil {
		return err
	}

	// Closing the connection should make the reader's pending read fail,
	// after which listenAndRoute closes s.disconnected. Wait for that so the
	// reader is known to have stopped using the connection.
	timeout := time.NewTimer(disconnectTimeout)
	defer timeout.Stop()

	select {
	case <-s.disconnected:
		return nil
	case <-timeout.C:
		return errors.New("timed out waiting for Disconnect cleanup")
	}
}

// Disconnected returns a channel that is closed once the current connection
// has gone away, either through an explicit call to Disconnect or because the
// reader goroutine stopped after a read error.
//
// Every successful Connect installs a new channel, so a channel obtained
// before reconnecting stays closed; call Disconnected again to watch the new
// connection. Before the first Connect the returned channel is already closed.
func (s *Socket) Disconnected() <-chan struct{} {
	var done <-chan struct{} = s.disconnected
	return done
}

// isDisconnected reports, without blocking, whether the current disconnected
// channel has been closed, i.e. whether there is no live connection.
func (s *Socket) isDisconnected() bool {
	closed := false
	select {
	case <-s.disconnected:
		closed = true
	default:
		// channel still open: the reader goroutine is running
	}
	return closed
}

// listenAndRoute runs the read loop for the current connection, passing each
// packet to the Router. When the loop stops, it closes the connection's
// disconnected channel, so Disconnect and watchers of Disconnected() learn
// that routing has ended.
func (s *Socket) listenAndRoute() {
	// Signal any clients listening that the connection has been disconnected,
	// whatever made listen return.
	defer close(s.disconnected)

	// listen blocks until it gives up on the underlying connection.
	listen(s.reader, s.router)
}

// listen processes incoming data and routes
// responses to their respective callback handler.
func listen(s io.Reader, router Router) {
	for {
		// response packet length
		length, err := pktlen(s)
		if err != nil {
			if isTemporary(err) {
				continue
			}
			// connection is no longer valid, so shutdown
			return
		}

		// response header
		h, err := extractHeader(s)
		if err != nil {
			// invalid packet
			continue
		}

		// payload: packet length minus what was previously read
		size := int(length) - int(unsafe.Sizeof(_p))
		buf := make([]byte, size)
		_, err = io.ReadFull(s, buf)
		if err != nil {
			// invalid packet
			continue
		}

		// route response to caller
		router.Route(h, buf)
	}
}

// isTemporary returns true if the error returned from a read is transient.
// If the error type is an OpError, check whether the net connection
// error condition is temporary (which means we can keep using the
// connection).
// Errors not of the net.OpError type tend to be things like io.EOF,
// syscall.EINVAL, or io.ErrClosedPipe (i.e. all things that
// indicate the connection in use is no longer valid.)
func isTemporary(err error) bool {
	opErr, ok := err.(*net.OpError)
	if ok {
		return opErr.Temporary()
	}
	return false
}

// pktlen returns the length of an incoming RPC packet.  Read errors will
// result in a returned response length of 0 and a non-nil error.
func pktlen(r io.Reader) (uint32, error) {
	buf := make([]byte, unsafe.Sizeof(_p.Len))

	// extract the packet's length from the header
	_, err := io.ReadFull(r, buf)
	if err != nil {
		return 0, err
	}

	return binary.BigEndian.Uint32(buf), nil
}

// extractHeader returns the decoded header from an incoming response.
func extractHeader(r io.Reader) (*Header, error) {
	buf := make([]byte, unsafe.Sizeof(_p.Header))

	// extract the packet's header from r
	_, err := io.ReadFull(r, buf)
	if err != nil {
		return nil, err
	}

	return &Header{
		Program:   binary.BigEndian.Uint32(buf[0:4]),
		Version:   binary.BigEndian.Uint32(buf[4:8]),
		Procedure: binary.BigEndian.Uint32(buf[8:12]),
		Type:      binary.BigEndian.Uint32(buf[12:16]),
		Serial:    int32(binary.BigEndian.Uint32(buf[16:20])),
		Status:    binary.BigEndian.Uint32(buf[20:24]),
	}, nil
}

// SendPacket sends a packet to libvirt on the socket connection.
func (s *Socket) SendPacket(
	serial int32,
	proc uint32,
	program uint32,
	payload []byte,
	typ uint32,
	status uint32,
) error {
	p := packet{
		Header: Header{
			Program:   program,
			Version:   constants.ProtocolVersion,
			Procedure: proc,
			Type:      typ,
			Serial:    serial,
			Status:    status,
		},
	}

	size := int(unsafe.Sizeof(p.Len)) + int(unsafe.Sizeof(p.Header))
	if payload != nil {
		size += len(payload)
	}
	p.Len = uint32(size)

	if s.isDisconnected() {
		// this mirrors what a lot of net code return on use of a no
		// longer valid connection
		return syscall.EINVAL
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	err := binary.Write(s.writer, binary.BigEndian, p)
	if err != nil {
		return err
	}

	// write payload
	if payload != nil {
		err = binary.Write(s.writer, binary.BigEndian, payload)
		if err != nil {
			return err
		}
	}

	return s.writer.Flush()
}

// SendStream copies stream to libvirt as a sequence of Stream packets for the
// given serial, procedure, and program.
//
// Each chunk read from stream is sent with StatusContinue. The stream ends as
// follows:
//   - On io.EOF, a final empty StatusOK packet is sent, and its send error
//     (or nil) is returned.
//   - On any other read error, an empty StatusError packet is sent and the
//     read error is returned, unless sending that packet fails, in which case
//     the send error is returned.
//   - If abort is ready (a value sent or the channel closed) before a read,
//     an empty StatusError packet is sent and its send error (or nil) is
//     returned. abort is only polled between reads, so a blocked Read is not
//     interrupted.
func (s *Socket) SendStream(serial int32, proc uint32, program uint32,
	stream io.Reader, abort chan bool) error {
	send := func(data []byte, status uint32) error {
		return s.SendPacket(serial, proc, program, data, Stream, status)
	}

	// Keep total packet length under 4 MiB to follow possible limitation in
	// libvirt server code: each chunk plus the fixed packet prefix must fit.
	chunk := make([]byte, 4*MiB-unsafe.Sizeof(_p))
	for {
		select {
		case <-abort:
			return send(nil, StatusError)
		default:
		}

		n, readErr := stream.Read(chunk)
		if n > 0 {
			if err := send(chunk[:n], StatusContinue); err != nil {
				return err
			}
		}
		if readErr == nil {
			continue
		}
		if readErr == io.EOF {
			return send(nil, StatusOK)
		}
		// Tell libvirt the stream failed, but keep the original read error.
		if err := send(nil, StatusError); err != nil {
			return err
		}
		return readErr
	}
}